Using Keras and TensorFlow.js to classify seven types of skin lesions
After doing research on Convolutional Neural Networks, I became interested in developing an end-to-end machine learning solution. I decided to use the HAM10000 dataset to build a web app to classify skin lesions. In this article, I’ll provide some background information and explain some of the important concepts I learned while working on this project, including Transfer Learning, Data Augmentation, Keras Callbacks, and TensorFlow.js.
Artificial intelligence is shaping the world around us. We interact with things that have been touched by machine learning on a daily basis, from song and video recommendations to the smart assistants in our phones. But these are consumer applications of AI; what about AI on a larger scale?
“Just as electricity transformed almost everything 100 years ago, today I actually have a hard time thinking of an industry that I don’t think AI will transform in the next several years.” — Andrew Ng
Personally, I believe that one field with huge potential for deep learning is healthcare. Even though our technology has advanced significantly, one problem still remains a significant issue today: diagnostic error. It’s been reported that about 10 percent of deaths and 6 to 17 percent of hospital complications are due to diagnostic errors. Imagine reducing that number to less than 5% by helping medical professionals diagnose patients with the help of machine learning models. The impact would be huge!
In my last article on Convolutional Neural Networks, I touched on how Computer Vision is being applied in various industries. I highly recommend you check it out here.
A Bit of Background Information on Skin Cancer
- More people are diagnosed with skin cancer each year in the U.S. than all other cancers combined.
- One in five Americans will develop skin cancer by the age of 70.
- Actinic keratosis affects more than 58 million Americans.
I wanted to build a solution that leverages a Convolutional Neural Network to help people classify different types of skin lesions quickly and accurately. My main goal was to create a project that is easily accessible and effective, so I ultimately decided to build a web app.
For this project, I used the publicly available HAM10000 dataset, which contains 10,015 dermatoscopic images of skin lesions.
The categories of skin lesions include:
- Actinic keratoses and intraepithelial carcinoma (akiec): common non-invasive variants of squamous cell carcinoma. They are sometimes seen as precursors that may progress to invasive squamous cell carcinoma.
- Basal cell carcinoma (bcc): a common variant of epithelial skin cancer that rarely metastasizes but grows if it isn’t treated.
- Benign keratosis (bkl): contains three subgroups (seborrheic keratoses, solar lentigo, and lichen planus-like keratoses (LPLK)). These groups may look different but are biologically similar.
- Dermatofibroma (df): a benign skin lesion regarded as either a benign proliferation or an inflammatory reaction to minimal trauma.
- Melanoma (mel): a malignant neoplasm that can appear in different variants. Melanomas are usually, but not always, chaotic, and some criteria depend on the anatomical site.
- Melanocytic nevi (nv): benign neoplasms of melanocytes whose variants can differ significantly from a dermatoscopic point of view but are usually symmetric in their distribution of color and structure.
- Vascular lesions (vasc): generally characterized by a red or purple color and solid, well-circumscribed structures known as red clods or lacunes.
For more information on the dataset or the skin cancer classifications please refer to this paper.
Using Transfer Learning for a Convolutional Neural Network Model
If you’ve worked with any kind of data, you’ll know that data is the most important ingredient when developing deep learning models. However, most of the time your dataset probably won’t be large enough for optimal performance; by large, we’re talking about at least 50,000 images. Networks with many layers are also expensive to train, and training can take a very long time without a powerful GPU (or several).
The entire idea behind Transfer Learning is that you can take a model that has already been pre-trained on a large dataset, modify it and retrain it on the dataset you’re currently working with.
As I explained in my previous article, Convolutional Neural Networks look for different features in images, such as edges and shapes. We can take a network with millions of connections that has already learned to identify such features and retrain only part of it by “freezing” the earlier layers. After adding a new fully connected output layer and training only the last few layers, we obtain a model that reuses the general features learned during pre-training while learning to make predictions specific to our dataset.
The Keras Applications module includes several pre-trained deep learning models, such as VGG16, VGG19, ResNet50, and MobileNet. Their ImageNet weights were trained on the ImageNet (ILSVRC 2012) classification dataset, which contains about 1.2 million images across 1,000 classes. That’s a pretty noticeable difference when compared to our 10,000-image dataset.
For this project, I chose the MobileNet architecture, which is optimized for mobile applications with less computing power. It uses depthwise separable convolutions, which split a standard convolution into a per-channel (depthwise) convolution followed by a 1×1 (pointwise) convolution. This greatly reduces the number of parameters and computations and makes the model more lightweight. For more information on MobileNet, check out this paper.
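To make the parameter savings concrete, the short sketch below compares the weight count of a standard convolution with that of a depthwise separable one. It ignores biases, and the 3×3 kernel with 512 input and output channels is an illustrative layer size chosen for this example.

```python
def standard_conv_params(k: int, c_in: int, c_out: int) -> int:
    """Weights in a standard k x k convolution (biases ignored)."""
    return k * k * c_in * c_out


def separable_conv_params(k: int, c_in: int, c_out: int) -> int:
    """Weights in a depthwise separable convolution (biases ignored).

    Depthwise step: one k x k filter per input channel.
    Pointwise step: a 1 x 1 convolution mixing c_in channels into c_out.
    """
    depthwise = k * k * c_in
    pointwise = c_in * c_out
    return depthwise + pointwise


if __name__ == "__main__":
    k, c_in, c_out = 3, 512, 512  # illustrative layer size
    std = standard_conv_params(k, c_in, c_out)
    sep = separable_conv_params(k, c_in, c_out)
    print(f"standard:  {std:,}")           # standard:  2,359,296
    print(f"separable: {sep:,}")           # separable: 266,752
    print(f"reduction: {std / sep:.2f}x")  # reduction: 8.84x
```

The ratio separable/standard works out to 1/c_out + 1/k², so with 3×3 kernels the saving approaches 9× as the channel count grows.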
Here’s how we can set up MobileNet for transfer learning in Keras.
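from tensorflow import keras
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model
mobile = keras.applications.mobilenet.MobileNet()  # MobileNet with ImageNet weights
mobile.summary()  # check the layer order: layers[-6] was global average pooling in 2018-era Keras
x = mobile.layers[-6].output
x = Dropout(0.25)(x)
predictions = Dense(7, activation='softmax')(x)  # one output per lesion class
model = Model(inputs=mobile.input, outputs=predictions)
for layer in model.layers[:-23]:  # freeze all but the last 23 layers (illustrative cut-off)
    layer.trainable = False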
Preprocessing the Images of Skin Lesions
One nice thing about the HAM10000 dataset is that all of the images are the same size, 600×450. However, the class distribution shows that a significant majority of the images belong to the melanocytic nevi (nv) class.
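One quick way to quantify this imbalance is to compute each class’s share of the dataset and how many times smaller each class is than the majority class. The plain Python sketch below does this. The counts are the per-class totals published with HAM10000 (10,015 images in all), and the function name `imbalance_report` is hypothetical.

```python
from typing import Dict, Tuple

# Per-class image counts for HAM10000 (total 10,015)
HAM10000_COUNTS: Dict[str, int] = {
    "akiec": 327, "bcc": 514, "bkl": 1099, "df": 115,
    "mel": 1113, "nv": 6705, "vasc": 142,
}


def imbalance_report(counts: Dict[str, int]) -> Dict[str, Tuple[float, float]]:
    """Map each class to (share of the dataset, majority count / class count),
    ordered from the largest class to the smallest."""
    total = sum(counts.values())
    majority = max(counts.values())
    return {
        label: (n / total, majority / n)
        for label, n in sorted(counts.items(), key=lambda kv: -kv[1])
    }


if __name__ == "__main__":
    for label, (share, ratio) in imbalance_report(HAM10000_COUNTS).items():
        print(f"{label:>5}: {share:6.1%} of images, {ratio:5.1f}x fewer than the majority class")
```

With these counts, nv makes up roughly two thirds of the images, while df has about 58 times fewer examples than nv. That gap is why the next step augments the training data.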
Augmenting the Training Data
Data augmentation is very useful for increasing the number of training examples we can work with. We augment the training data by applying random transformations, such as rotations, shifts, zooms, and flips, to the existing images. For this, we use the Keras ImageDataGenerator class from the Keras Preprocessing library, which generates batches of tensor image data with real-time augmentation by looping through the data in batches. Some of the parameters that we pass in are:
- rotation_range: the degree range for random rotations
- width_shift_range: the fraction of the total width by which the image can be randomly shifted
- height_shift_range: the fraction of the total height by which the image can be randomly shifted
- zoom_range=0.1: the range for random zoom; a float z means the zoom factor is drawn from [1 - z, 1 + z]
- horizontal_flip=True: randomly flips inputs horizontally
- vertical_flip=True: randomly flips inputs vertically
- fill_mode='nearest': how points outside the input boundaries are filled; 'nearest' repeats the closest edge pixel
We can declare an augmented data generator by running the following code. Our target size is 224×224 because those are the dimensions that are needed for the MobileNet input layer.
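from tensorflow.keras.preprocessing.image import ImageDataGenerator
# The rotation and shift values are illustrative; the other values match the list above.
# path (a folder with one subfolder per class) and batch_size are placeholders.
datagen = ImageDataGenerator(rotation_range=180, width_shift_range=0.1, height_shift_range=0.1, zoom_range=0.1, horizontal_flip=True, vertical_flip=True, fill_mode='nearest')
aug_data = datagen.flow_from_directory(path, target_size=(224, 224), batch_size=batch_size)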
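To make a few of these options concrete, here is a pure Python sketch (no Keras) of what a horizontal flip and a width shift with fill_mode='nearest' do to a tiny single-channel image stored as a list of rows. The sketch assumes whole-pixel shifts. Keras applies these transforms to real image tensors using affine transformations, and the function names here are hypothetical.

```python
import random
from typing import List

Image = List[List[int]]  # rows of pixel values (single channel)


def horizontal_flip(img: Image) -> Image:
    """Mirror the image left-to-right."""
    return [row[::-1] for row in img]


def width_shift_nearest(img: Image, shift: int) -> Image:
    """Shift columns right (shift > 0) or left (shift < 0).

    Positions that would come from outside the image are filled with the
    nearest edge pixel, mimicking fill_mode='nearest'.
    """
    width = len(img[0])
    out = []
    for row in img:
        # output column x takes input column x - shift, clamped to the edges
        out.append([row[min(max(x - shift, 0), width - 1)] for x in range(width)])
    return out


def random_width_shift(img: Image, width_shift_range: float, rng: random.Random) -> Image:
    """Draw a shift uniformly from [-range, +range] * width and apply it."""
    width = len(img[0])
    shift = round(rng.uniform(-width_shift_range, width_shift_range) * width)
    return width_shift_nearest(img, shift)


if __name__ == "__main__":
    img = [[1, 2, 3, 4],
           [5, 6, 7, 8]]
    print(horizontal_flip(img))          # [[4, 3, 2, 1], [8, 7, 6, 5]]
    print(width_shift_nearest(img, 1))   # [[1, 1, 2, 3], [5, 5, 6, 7]]
    print(width_shift_nearest(img, -2))  # [[3, 4, 4, 4], [7, 8, 8, 8]]
```

Because each epoch draws fresh random parameters, the network rarely sees exactly the same pixels twice, which is what lets augmentation stretch a small dataset.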
Compiling the Model
Keras callbacks are objects that can be applied at various stages of the training process. They can be used to inspect the internal states and statistics of the model during training, and to act on them. Two of the callbacks used in this program are ReduceLROnPlateau and ModelCheckpoint.
ReduceLROnPlateau reduces the learning rate when a monitored metric has stopped improving. Models often benefit from reducing the learning rate by a factor of 2–10 once learning stagnates for several epochs. Some important parameters are:
- monitor: the metric used to evaluate whether or not the model is improving
- factor=0.5: the factor by which the learning rate is multiplied when it is reduced
- patience=2: the number of epochs with no improvement after which the learning rate is reduced
- mode='max': the learning rate is reduced when the monitored metric stops increasing
- min_lr=0.00001: the lower bound on the learning rate
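from tensorflow.keras.callbacks import ReduceLROnPlateau
# 'accuracy' must match a metric name passed to model.compile()
reduce_lr = ReduceLROnPlateau(monitor='accuracy', factor=0.5, patience=2, verbose=1, mode='max', min_lr=0.00001)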
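The scheduling rule itself is simple. The sketch below is a simplified re-implementation in plain Python, not the Keras source, of the behavior configured above. In 'max' mode, an epoch counts as an improvement only if the metric beats the best value so far by more than a small min_delta (Keras’s default is 1e-4). After `patience` epochs without improvement, the learning rate is multiplied by `factor` but never falls below `min_lr`. Keras’s `cooldown` option is left out, and the accuracy values in the example are made up.

```python
from typing import List


def simulate_reduce_lr(metric_history: List[float], lr: float, factor: float = 0.5,
                       patience: int = 2, min_lr: float = 1e-5,
                       min_delta: float = 1e-4) -> List[float]:
    """Return the learning rate in effect after each epoch ('max' mode)."""
    best = float("-inf")
    wait = 0
    lrs = []
    for value in metric_history:
        if value > best + min_delta:  # improvement: remember it, reset the counter
            best = value
            wait = 0
        else:  # no improvement this epoch
            wait += 1
            if wait >= patience and lr > min_lr:
                lr = max(lr * factor, min_lr)
                wait = 0
        lrs.append(lr)
    return lrs


if __name__ == "__main__":
    accuracies = [0.50, 0.60, 0.60, 0.60, 0.70, 0.70, 0.70]  # hypothetical values
    print(simulate_reduce_lr(accuracies, lr=0.01))
    # [0.01, 0.01, 0.01, 0.005, 0.005, 0.005, 0.0025]
```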
ModelCheckpoint checks the model at the end of every epoch and saves it to disk. With save_best_only=True, the model is saved only when the monitored metric improves, so the best model so far is never overwritten by a worse one.
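from tensorflow.keras.callbacks import ModelCheckpoint
# 'val_top_3_accuracy' assumes a custom top_3_accuracy metric was passed to model.compile(); filepath is a placeholder such as 'model.h5'
checkpoint = ModelCheckpoint(filepath, monitor='val_top_3_accuracy', verbose=1, save_best_only=True, mode='max')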
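The checkpoint above monitors val_top_3_accuracy, a metric the article does not define. Top-3 accuracy counts a prediction as correct when the true class is among the three highest-scoring classes. In Keras, such a metric is commonly written as a small wrapper function around keras.metrics.top_k_categorical_accuracy with k=3, and the wrapper’s name determines the logged key. The plain Python sketch below shows only the computation itself, on made-up probabilities.

```python
from typing import Sequence


def top_k_accuracy(y_true: Sequence[int], y_prob: Sequence[Sequence[float]], k: int = 3) -> float:
    """Fraction of samples whose true class index is among the k highest scores."""
    hits = 0
    for label, probs in zip(y_true, y_prob):
        # indices of the k largest probabilities for this sample
        top_k = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
        hits += label in top_k
    return hits / len(y_true)


if __name__ == "__main__":
    # Hypothetical softmax outputs over the 7 lesion classes
    y_prob = [
        [0.05, 0.10, 0.15, 0.02, 0.08, 0.55, 0.05],  # true class 5 ranks 1st
        [0.30, 0.25, 0.20, 0.05, 0.10, 0.05, 0.05],  # true class 2 ranks 3rd
        [0.40, 0.30, 0.20, 0.04, 0.03, 0.02, 0.01],  # true class 6 ranks last
    ]
    y_true = [5, 2, 6]
    print(top_k_accuracy(y_true, y_prob, k=1))  # 1/3
    print(top_k_accuracy(y_true, y_prob, k=3))  # 2/3
```

Top-3 accuracy is a forgiving metric, which suits a tool meant to suggest likely diagnoses rather than give a single answer.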
Plotting a Confusion Matrix of the Predictions
We can see that our model has acceptable performance and that a significant number of test examples for label nv were classified correctly. The model mostly predicts akiec, bcc, bkl, nv, and vasc correctly, but struggles with df. It sometimes confuses Melanoma (mel) with Melanocytic nevi (nv), as well as nv with Benign keratosis (bkl). There is still plenty of room for improvement, and finer tuning of the hyperparameters may help.
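A confusion matrix is a table of counts. Row i, column j holds the number of test images whose true class is i and whose predicted class is j. The diagonal therefore holds correct predictions, and off-diagonal cells show confusions such as mel predicted as nv. The sketch below builds one in plain Python and derives per-class recall from it. The labels are hypothetical, not the article’s results (scikit-learn’s confusion_matrix computes the same table).

```python
from typing import Dict, List, Sequence

CLASSES = ["akiec", "bcc", "bkl", "df", "mel", "nv", "vasc"]


def confusion_matrix(y_true: Sequence[str], y_pred: Sequence[str],
                     classes: Sequence[str] = CLASSES) -> List[List[int]]:
    """Rows are true classes, columns are predicted classes."""
    index = {c: i for i, c in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


def per_class_recall(matrix: List[List[int]], classes: Sequence[str] = CLASSES) -> Dict[str, float]:
    """Correct predictions (diagonal) divided by the number of true examples (row sum)."""
    recall = {}
    for i, c in enumerate(classes):
        row_total = sum(matrix[i])
        if row_total:  # skip classes that never occur in y_true
            recall[c] = matrix[i][i] / row_total
    return recall


if __name__ == "__main__":
    # Hypothetical labels, for illustration only
    y_true = ["nv", "nv", "nv", "mel", "mel", "df", "bkl"]
    y_pred = ["nv", "nv", "bkl", "nv", "mel", "bkl", "bkl"]
    cm = confusion_matrix(y_true, y_pred)
    for name, row in zip(CLASSES, cm):
        print(f"{name:>5} {row}")
    print(per_class_recall(cm))
    # {'bkl': 1.0, 'df': 0.0, 'mel': 0.5, 'nv': 0.6666666666666666}
```

Per-class recall is a useful companion to the plot on an imbalanced dataset like this one, because a high overall accuracy can hide a rare class such as df that is almost never predicted correctly.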
Saving the Model and Converting it to TensorFlow.js
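After training, we can find the Keras model in the local directory as model.h5. To run it in the browser, we convert the trained model to the TensorFlow.js format (a model.json file plus binary weight files) with the tensorflowjs package by running the following code.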
!pip install tensorflowjs
import os
import tensorflowjs as tfjs
os.makedirs('tfjs_dir', exist_ok=True)  # output folder for model.json and weight files
tfjs.converters.save_keras_model(model, 'tfjs_dir')
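As a quick sanity check after conversion, the sketch below reads the generated model.json and lists the weight shard files along with an estimate of the total weight size, which is what the browser will have to download. It assumes the layout the converter writes for Keras models: a weightsManifest list whose groups have paths (shard file names) and weights (each with a shape). The byte estimate also assumes float32 weights. The helper name `summarize_manifest` is hypothetical.

```python
import json
import math
import os
from typing import List, Tuple

BYTES_PER_VALUE = 4  # assumption: float32 weights, the usual case for a Keras model


def summarize_manifest(model_json: dict) -> Tuple[List[str], int]:
    """Return (weight shard file names, estimated weight size in bytes)."""
    shards: List[str] = []
    n_values = 0
    for group in model_json["weightsManifest"]:
        shards.extend(group["paths"])  # binary shard files stored next to model.json
        for weight in group["weights"]:
            n_values += math.prod(weight["shape"])  # number of scalars in this tensor
    return shards, n_values * BYTES_PER_VALUE


if __name__ == "__main__":
    path = os.path.join("tfjs_dir", "model.json")
    if os.path.exists(path):
        with open(path) as f:
            shards, size = summarize_manifest(json.load(f))
        print(f"{len(shards)} shard file(s), about {size / 1e6:.1f} MB of weights")
    else:
        print(f"{path} not found; run the conversion step first")
```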
Running Machine Learning in the Browser
TensorFlow.js is the JavaScript version of Google’s popular deep learning framework TensorFlow. It consists of a low-level core API and a high-level Layers API. There are three main reasons why I think TensorFlow.js is pretty cool:
- Through WebGL, TensorFlow.js can run on most GPUs, including Nvidia, AMD, and phone GPUs.
- You can convert existing models into TensorFlow.js models and repurpose them easily.
- Models run locally in the browser, meaning that the user’s data never leaves their computer.
The last point is particularly important: if online self-diagnosis becomes widespread, keeping users’ medical images on their own devices will matter a great deal. Running training and inference on the client side helps make solutions privacy-friendly!
You can find the code for my project here and a live version of the model here.
Key Takeaways:
- Transfer learning is useful when you don’t necessarily have a lot of data or computing power to work with
- Data augmentation is another way to make sure you have enough training data for your model to perform well
- TensorFlow.js lets you run machine learning models easily in the browser, on the client side
Date: December 19, 2018
Source: Towards Data Science