Detecting Keratoconus from Corneal Imaging Using Deep Learning

A few weeks ago I was diagnosed with an eye condition known as keratoconus. Keratoconus affects approximately one in 2,000 individuals worldwide and is typically associated with a decrease in visual acuity. Before this, I had no idea what corneal topography was. When I received my corneal topography images, I immediately thought this would be a great opportunity to apply computer vision, since one of my interests is the application of computer vision to medical image analysis.

Normal Cornea vs Keratoconus Cornea

Image Dataset Download and Setup

Here is the data used. We will use Google Colab for this task. Place the data in your Google Drive and mount the drive in your Colab notebook.

We will be classifying the images in eyes_data, so we created two folders, one for each eye class: normal eyes and eyes with keratoconus, as shown below.

Note: we have very little data. There are 48 training images and 12 validation images, which is a small amount of data for almost any machine learning task.
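With a folder-per-class layout, each image's label is simply the name of the folder it sits in. Here is a minimal sketch of reading that layout with the standard library. It assumes class folders named kt_eyes and normal_eyes (the class names used in the results later in this article) and image files ending in .jpg, .jpeg or .png; the helper names are hypothetical.

```python
from collections import Counter
from pathlib import Path

# Assumption: images are stored as .jpg/.jpeg/.png files.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def labelled_images(root):
    """Return (path, label) pairs, using each image's parent folder name as its label.

    Assumes a hypothetical layout like root/kt_eyes/*.jpg and root/normal_eyes/*.jpg.
    """
    root = Path(root)
    pairs = []
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        for image in sorted(class_dir.iterdir()):
            if image.suffix.lower() in IMAGE_EXTENSIONS:
                pairs.append((image, class_dir.name))
    return pairs


def class_counts(pairs):
    """Count how many images each class has."""
    return Counter(label for _, label in pairs)
```

For example, `class_counts(labelled_images("/content/drive/My Drive/eyes_data"))` would report the number of images per class; that Drive path is hypothetical and depends on where you placed the data.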

Approach

We will be using a Convolutional Neural Network (CNN) with a ResNet34 architecture.

First, what is a CNN again? We can think of a Convolutional Neural Network (CNN or ConvNet) as a sequence of layers that transform the input image volume into an output volume, which can be a set of class scores, as is the case in this tutorial.
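To make "layers transforming a volume" concrete, here is a sketch that traces the height/width and channel count of an image through the first two layers of a ResNet: a 7x7 convolution with stride 2 and 64 filters, then a 3x3 max pool with stride 2. The 224x224 RGB input size is an assumption for illustration; this article does not state the image size used.

```python
def conv_output_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer on a square input."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_volume(size, channels, layers):
    """Apply each (name, kernel, stride, padding, out_channels) layer in turn.

    Returns a list of (name, height/width, channels) after every layer.
    Pooling layers pass out_channels=None because they keep the channel count.
    """
    volume = [("input", size, channels)]
    for name, kernel, stride, padding, out_channels in layers:
        size = conv_output_size(size, kernel, stride, padding)
        if out_channels is not None:
            channels = out_channels
        volume.append((name, size, channels))
    return volume


# The ResNet stem: conv1 (7x7, stride 2, padding 3, 64 filters), then max pool (3x3, stride 2, padding 1).
RESNET_STEM = [
    ("conv1", 7, 2, 3, 64),
    ("maxpool", 3, 2, 1, None),
]

# Assumed 224x224 RGB input.
for name, hw, ch in trace_volume(224, 3, RESNET_STEM):
    print(f"{name:>8}: {hw}x{hw}x{ch}")
```

Each layer halves the spatial size here (224 to 112 to 56) while the convolution expands 3 colour channels into 64 feature channels; later stages continue this trade until a final layer produces one score per class.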

Transfer Learning

Transfer learning is a technique that addresses two problems with training deep networks from scratch:

  1. Training a CNN requires a large number of images due to the high number of parameters it contains.
  2. Training a CNN requires a lot of computational power.

Transfer learning relies on the fact that the lower layers of CNNs trained on natural images learn generic, low-level features such as edges and colors, regardless of the specific images in the dataset. Using transfer learning makes it possible to fine-tune a CNN on a small dataset in a relatively short time.

What is special about the ResNet architecture is how it tackles the degradation problem common in very deep networks: as depth increases, accuracy saturates and then degrades rapidly.

The ResNet architecture introduces an “identity shortcut connection”, often referred to as a “skip connection”, which skips one or more layers. The shortcut connections simply perform identity mappings, and their outputs are added to the outputs of the stacked layers, as shown in the figure below. The stacked layers learn a residual function, F(x) in our figure, and the block outputs F(x) + x; this residual block is where the name Residual Networks (ResNets) comes from.
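The residual idea can be sketched without any framework. This toy version assumes the stacked layers are two plain matrix multiplications with a ReLU in between; real ResNet blocks use convolutions and batch normalization, so treat it as an illustration of y = relu(F(x) + x), not as ResNet34's actual block.

```python
def relu(v):
    """Element-wise ReLU."""
    return [max(0.0, a) for a in v]


def matvec(w, v):
    """Multiply matrix w (a list of rows) by vector v."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def residual_block(x, w1, w2):
    """Toy residual block: y = relu(F(x) + x) with F(x) = w2 @ relu(w1 @ x).

    The two matrix multiplications are a simplified stand-in (assumption)
    for the stacked convolutional layers of a real ResNet block.
    """
    fx = matvec(w2, relu(matvec(w1, x)))          # residual function F(x)
    return relu([f + s for f, s in zip(fx, x)])   # add the identity shortcut, then ReLU
```

With all-zero weights, F(x) = 0 and the block passes a non-negative input (such as the output of a previous ReLU) through unchanged. That is why adding a block does not have to make the network worse: the stacked layers only need to learn a correction to the identity.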

Comprehensive empirical evidence has shown that adding these identity mappings allows the model to go deeper without degrading performance, and that such networks are easier to optimize than plain stacked layers. There are several ResNet variants, such as ResNet18, ResNet34 (used here), ResNet50, ResNet101 and ResNet152; the number is the network's depth, counted as its number of weighted (convolutional and fully connected) layers.
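The depth in each name can be checked by counting weighted layers: one stem convolution, the convolutions inside the residual blocks, and the final fully connected layer. This sketch uses the standard block configurations from the original ResNet paper; by convention, the 1x1 projection shortcuts are not counted.

```python
# name: (residual blocks in each of the four stages, conv layers per block)
RESNET_CONFIGS = {
    "ResNet18": ([2, 2, 2, 2], 2),    # basic blocks: two 3x3 convs each
    "ResNet34": ([3, 4, 6, 3], 2),
    "ResNet50": ([3, 4, 6, 3], 3),    # bottleneck blocks: 1x1, 3x3, 1x1 convs
    "ResNet101": ([3, 4, 23, 3], 3),
    "ResNet152": ([3, 8, 36, 3], 3),
}


def resnet_depth(blocks_per_stage, convs_per_block):
    """Weighted layers = stem conv + convs in all residual blocks + final fully connected layer."""
    stem_conv = 1
    fully_connected = 1
    return stem_conv + sum(blocks_per_stage) * convs_per_block + fully_connected


for name, (blocks, convs) in RESNET_CONFIGS.items():
    print(name, resnet_depth(blocks, convs))
```

For ResNet34 this gives 1 + (3 + 4 + 6 + 3) × 2 + 1 = 34.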

We split the data into 80% training and 20% validation data using fastai's ImageDataBunch, then create a learner.
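The 80/20 split is what turns our 60 images into 48 for training and 12 for validation. In fastai v1 this corresponds to the `valid_pct` argument of `ImageDataBunch.from_folder`; the sketch below only shows the idea with the standard library, and the function name and seed value are hypothetical.

```python
import random


def split_train_valid(items, valid_pct=0.2, seed=42):
    """Shuffle items reproducibly and hold out valid_pct of them for validation.

    The default seed of 42 is an arbitrary, hypothetical choice; fixing it
    makes the same images land in the validation set on every run.
    """
    items = list(items)              # copy so the caller's list is not reordered
    rng = random.Random(seed)
    rng.shuffle(items)
    n_valid = round(len(items) * valid_pct)
    return items[n_valid:], items[:n_valid]
```

With 60 file names, `split_train_valid(names)` returns 48 training and 12 validation items, matching the counts given earlier.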

Picking a learning rate

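Use the above plot to pick adequate learning rates for your model. We pass two learning rates because fastai spreads the range between them across the network's layer groups (discriminative learning rates), so the earlier, pretrained layers are fine-tuned more gently than the new classification head; training itself follows the one-cycle learning-rate schedule:

  • The first learning rate should sit just before the point where the loss starts to increase; preferably, pick one 10x smaller than the rate at which the loss starts to increase.
  • The second learning rate is 10x smaller than the first learning rate, so 1e-03 in our example.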
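The rule of thumb above can be written as a small function over the (learning rate, loss) pairs recorded by a learning-rate finder. This is a sketch, not fastai's own logic: it assumes the losses have already been smoothed (raw finder curves are noisy) and treats the lowest recorded loss as the point where the loss starts to increase.

```python
def pick_learning_rates(lrs, losses):
    """Apply the two-learning-rate rule of thumb to an LR-finder sweep.

    lrs and losses are parallel lists recorded while the learning rate grows.
    Assumption: the loss "starts to increase" at the lowest recorded loss.
    Returns (first_lr, second_lr): first_lr is 10x below that turning point,
    second_lr is another 10x below first_lr.
    """
    turn = min(range(len(losses)), key=losses.__getitem__)  # index of the lowest loss
    first_lr = lrs[turn] / 10
    second_lr = first_lr / 10
    return first_lr, second_lr
```

In fastai v1, the pair would typically be passed smallest first, e.g. `learn.fit_one_cycle(epochs, max_lr=slice(second_lr, first_lr))`, so the earliest layer group gets the smaller rate.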

Result Visualization

We use a confusion matrix and look at how sample images were classified.

What can we learn from this matrix?

  • There are two possible predicted classes: “kt_eyes” and “normal_eyes”. The classifier made a total of 12 predictions.
  • Out of those 12 cases, the classifier predicted “kt_eyes” 8 times, and “normal_eyes” 4 times.
  • In reality, 5 patients in the sample have normal_eyes, and 7 have kt_eyes.

Let’s now define the most basic terms, which are whole numbers (not rates):

  • true positives (TP): These are cases in which we predicted kt_eyes (they have the disease), and they do have the disease.
  • true negatives (TN): We predicted normal eyes, and they don’t have the disease.
  • false positives (FP): We predicted kt_eyes, but they don’t actually have the disease. (Also known as a “Type I error.”)
  • false negatives (FN): We predicted normal_eyes, but they actually do have the disease. (Also known as a “Type II error.”)

Our accuracy, (TP + TN) / total, was 11/12, or about 91.7%.
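The four counts can be recovered from the numbers above: 8 predicted kt_eyes against 7 actual, with 11 of the 12 predictions correct, gives TP = 7, TN = 4, FP = 1 and FN = 0. The sketch below tallies these counts from label lists and computes the accuracy. The two label lists are reconstructed from those totals for illustration only; they are not the notebook's actual per-image predictions.

```python
from collections import Counter


def confusion_counts(actual, predicted, positive="kt_eyes"):
    """Tally TP, TN, FP and FN, treating `positive` as the disease class."""
    counts = Counter()
    for a, p in zip(actual, predicted):
        if p == positive:
            counts["TP" if a == positive else "FP"] += 1
        else:
            counts["FN" if a == positive else "TN"] += 1
    return counts


def accuracy(counts):
    """(TP + TN) / total."""
    return (counts["TP"] + counts["TN"]) / sum(counts.values())


# Illustrative labels rebuilt from the totals above: 7 kt_eyes and 5 normal_eyes,
# with one normal eye predicted as kt_eyes.
actual = ["kt_eyes"] * 7 + ["normal_eyes"] * 5
predicted = ["kt_eyes"] * 7 + ["kt_eyes"] + ["normal_eyes"] * 4

counts = confusion_counts(actual, predicted)
print(f"Accuracy: {accuracy(counts):.1%}")  # prints Accuracy: 91.7%
```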

We can also plot the images with the top losses; in other words, the images the model was most confused about. A high loss means the model assigned a low probability to the true class, which typically happens when it is confident about the wrong answer.
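Here is a sketch of the idea behind ranking images by loss, assuming cross-entropy loss (the usual choice for classification). The function names and the (name, probabilities, true class index) input format are hypothetical; fastai computes the per-image losses for you.

```python
import math


def cross_entropy(probs, true_idx):
    """Cross-entropy loss for one image: -log of the probability given to the true class."""
    return -math.log(probs[true_idx])


def top_losses(samples, k=3):
    """Return the k (loss, name) pairs with the highest loss, largest first.

    Each sample is a (name, class_probabilities, true_class_index) tuple.
    """
    scored = [(cross_entropy(probs, true_idx), name) for name, probs, true_idx in samples]
    return sorted(scored, reverse=True)[:k]
```

An image predicted as the wrong class with 95% confidence gets a loss of -log(0.05) ≈ 3.0, while a correct but hesitant 60% prediction only gets -log(0.6) ≈ 0.51, so confident mistakes rise to the top of the ranking.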

Areas for Further Research/Improvement

  1. Dataset: more images are required to build a robust model.
  2. Further work: build the CNN with Keras.

Conclusion

The full code can be found here.

Thank you for reading! Follow @itsmuriuki on Twitter.

Back to code and more learning.