Building a convolutional neural network to quickly predict the presence of malaria-parasitized cells in a thin blood smear.
Although the malaria parasite doesn’t take the form of a mutant mosquito, it sure feels like a mutant problem. The deadly disease has reached epidemic, and even endemic, proportions in different parts of the world, killing around 400,000 people annually. In other areas of the world, it’s virtually nonexistent. Some areas are particularly prone to outbreaks because of factors that make malaria more likely to spread:
- High poverty levels
- Lack of access to proper healthcare
- Political instability
- Presence of disease transmission vectors (e.g., mosquitoes)
Given this combination of problems, we must keep a few constraints in mind when building our model:
- There may be a lack of a reliable power source
- Battery-powered devices have less computational power
- There may be a lack of Internet connection (so training/storing on the cloud may be hard!)
While we want the highest accuracy possible, we must also make the model as small and computationally efficient as possible, so that it can be deployed to edge and Internet of Things (IoT) devices.
Current diagnostic methods for this disease are tedious and time-consuming.
The most widely used method (so far) is examining thin blood smears under a microscope and visually searching for infected cells. The patient’s blood is smeared on a glass slide and stained with contrasting agents so that parasites inside infected red blood cells are easier to identify.
Then, a clinician manually counts the number of parasitized red blood cells, sometimes up to 5,000 cells (according to WHO protocol).
Why a convolutional neural network?
Convolutional neural networks (CNNs) can automatically extract features and learn filters. In previous machine learning solutions, features such as the size, color, and morphology of the cells had to be engineered by hand. Using CNNs can greatly speed up prediction while matching (or even exceeding) the accuracy of clinicians.
We’ll be using Keras with a TensorFlow backend, so install them if you haven’t already. Be sure to also install NumPy, scikit-learn, Matplotlib, and imutils (a package of image processing and deep learning convenience functions created by Adrian Rosebrock).
Thankfully, we have a great dataset of labeled and preprocessed images to train and evaluate our model. The NIH provides a malaria dataset of 27,558 cell images with an equal number of parasitized and uninfected cells. A level-set-based algorithm was applied to detect and segment the red blood cells, and the images were collected and annotated by medical professionals; more information can be found here. Download the data from that page; the file is named cell_images.zip.
I replicated Adrian Rosebrock’s blog post Deep Learning and Medical Image Analysis with Keras, which can be found here. Following his code, I posted my version on GitHub. You can download the source code to the project here to follow along or create your own.
Let’s get started!
First, create a folder for the project. Inside it, create a directory called malaria, download the dataset into it, and unzip it:
$ cd whatever-you-named-your-directory
$ mkdir malaria
$ cd malaria
$ wget https://ceb.nlm.nih.gov/proj/malaria/cell_images.zip
$ unzip cell_images.zip
Next, switch back to the parent directory and create another directory called cnn, where we’ll store our configuration file and network architecture.
$ cd ..
$ mkdir cnn
$ cd cnn
Create a module inside cnn named config.py. This is our configuration file, and it will store all our constants.
Our configuration file initializes all the paths to our
- original dataset (Line 4)
- directory that contains the split between training and testing (Line 8)
- and our newly separated training, validation, and testing datasets (Lines 11–13).
80% of the original dataset is set aside for training (Line 16), and the remaining 20% is used for testing. Then, 10% of the training data is held out as validation data (Line 20).
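The line references above point to the original config.py, which isn’t reproduced here. As a rough guide, a minimal sketch of such a file might look like the following; the constant names (ORIG_INPUT_DATASET, BASE_PATH, TRAIN_PATH, VAL_PATH, TEST_PATH, TRAIN_SPLIT, VAL_SPLIT) are assumptions for illustration, and its line numbers will not match the references exactly.

```python
# config.py: a sketch; all constant names here are assumptions
import os

# path to the original, unsplit dataset (the unzipped cell_images folder)
ORIG_INPUT_DATASET = os.path.join("malaria", "cell_images")

# base directory that will hold the new training/validation/testing splits
BASE_PATH = "malaria"

# derived output directory for each split
TRAIN_PATH = os.path.join(BASE_PATH, "training")
VAL_PATH = os.path.join(BASE_PATH, "validation")
TEST_PATH = os.path.join(BASE_PATH, "testing")

# fraction of the original dataset used for training (the rest is testing)
TRAIN_SPLIT = 0.8

# fraction of the *training* data held out for validation
VAL_SPLIT = 0.1
```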
Building the dataset
Create a module named build_dataset.py in your project’s root directory (the one containing malaria and cnn). build_dataset.py creates the training, validation, and testing directories on your filesystem and copies the images into them. Open it up and insert the following code.
After importing the necessary packages (Lines 3–5), we shuffle all the image paths in our original dataset (Lines 8–10).
First, we split the data into training and testing sets: we compute an index from the training fraction set in our configuration file (Lines 2–4), then use that index to slice the list of image paths.
Next, we set aside some of the training data for validation (Lines 7–9) by overwriting the index with a new value computed from the validation fraction and repeating the same slicing.
Afterward, we define our newly separated training, validation, and testing datasets. The datasets list contains 3-tuples, each consisting of:
- The name of the split
- The image paths for the split
- The path to the output directory for the split
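Here is a minimal sketch of the shuffle-and-slice logic, assuming the split fractions from the configuration (0.8 and 0.1) and a hypothetical helper named split_dataset that returns the same kind of 3-tuples. Seeding a random.Random instance so the split is reproducible is a choice made for this sketch.

```python
import os
import random


def split_dataset(image_paths, base_path="malaria", train_split=0.8,
                  val_split=0.1, seed=42):
    """Shuffle image paths, then split them into training/validation/testing.

    Returns a list of 3-tuples: (split name, image paths, output directory).
    """
    paths = list(image_paths)
    # seeded shuffle so the split is reproducible (an assumption of this sketch)
    random.Random(seed).shuffle(paths)

    # training vs. testing: slice at the index given by the training fraction
    i = int(len(paths) * train_split)
    train_paths = paths[:i]
    test_paths = paths[i:]

    # carve validation data out of the training data by overwriting the index
    i = int(len(train_paths) * val_split)
    val_paths = train_paths[:i]
    train_paths = train_paths[i:]

    return [
        ("training", train_paths, os.path.join(base_path, "training")),
        ("validation", val_paths, os.path.join(base_path, "validation")),
        ("testing", test_paths, os.path.join(base_path, "testing")),
    ]
```

With the 27,558 NIH images, int(27558 * 0.8) = 22,046 paths go to training and validation, leaving 5,512 for testing; int(22046 * 0.1) = 2,204 of those become validation data, leaving 19,842 for training.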
Now, we can loop over the datasets.
Line 4 will print which data split it’s creating, and Lines 7–9 will create an output directory if it doesn’t exist.
On Line 12, we begin looping over the input image paths. Line 14 extracts the filename of the input image, and Line 15 extracts the corresponding class label (Parasitized or Uninfected).
Then, we build the paths to the class subdirectories (Line 18) and create them if they don’t exist (Lines 21–23). Each split directory (training, validation, and testing) is divided into Parasitized and Uninfected subdirectories.
Finally, we construct the destination path for each image and copy the image into its subdirectory (Lines 27–28).
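The copy loop can be sketched as follows. The function name copy_split is hypothetical, and the sketch assumes each image’s class label is the name of its parent folder (Parasitized or Uninfected), as in the unzipped cell_images directory. The log messages follow the same format as the script’s output.

```python
import os
import shutil


def copy_split(split_name, image_paths, base_output):
    """Copy each image into base_output/<label>/, where <label> is the name
    of the folder that contains the image (e.g. Parasitized or Uninfected)."""
    print(f"[INFO] building '{split_name}' split")

    # create the split's output directory if it doesn't exist
    if not os.path.exists(base_output):
        print(f"[INFO] 'creating {base_output}' directory")
        os.makedirs(base_output)

    for input_path in image_paths:
        filename = os.path.basename(input_path)
        # the class label is the name of the image's parent folder
        label = os.path.basename(os.path.dirname(input_path))

        # create the class subdirectory if it doesn't exist
        label_dir = os.path.join(base_output, label)
        if not os.path.exists(label_dir):
            print(f"[INFO] 'creating {label_dir}' directory")
            os.makedirs(label_dir)

        # copy the image (with its metadata) into its class subdirectory
        shutil.copy2(input_path, os.path.join(label_dir, filename))
```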
Now we can build our dataset directories! Open up a terminal and execute the following command:
$ python build_dataset.py
The output should look like this:
[INFO] building 'training' split
[INFO] 'creating malaria/training' directory
[INFO] 'creating malaria/training/Uninfected' directory
[INFO] 'creating malaria/training/Parasitized' directory
[INFO] building 'validation' split
[INFO] 'creating malaria/validation' directory
[INFO] 'creating malaria/validation/Uninfected' directory
[INFO] 'creating malaria/validation/Parasitized' directory
[INFO] building 'testing' split
[INFO] 'creating malaria/testing' directory
[INFO] 'creating malaria/testing/Uninfected' directory
[INFO] 'creating malaria/testing/Parasitized' directory
Now that we’ve dealt with the data, let’s start training our model. Before we do, let’s take a quick look at the network architecture we’ll be using: a ResNet model implemented by Adrian Rosebrock in his book Deep Learning for Computer Vision with Python. His model is based on the 2015 paper Deep Residual Learning for Image Recognition by He et al., but is smaller and more compact, which helps keep the storage size of our model down.
Model Architecture (ResNet)
First of all, why did we choose ResNet? Rajaraman et al. used pre-trained convolutional neural networks as feature extractors to classify images from the same dataset we are using. They compared six models: AlexNet, VGG-16, ResNet-50, Xception, DenseNet-121, and a custom model they built. Their results showed that ResNet-50 consistently gave the most accurate results on this dataset; it also performed best on metrics such as the Matthews correlation coefficient (MCC) and F1 score, which are important in healthcare applications.
However, DenseNet-121 outperformed it in sensitivity (arguably one of the most important metrics in a healthcare setting) when features were extracted from its optimal layer rather than its final layer.
You can download the code from here; if you already have the project’s source code, it’s included in the cnn folder. For a detailed explanation of how to implement ResNet from scratch (as well as a basic explanation of why it is so effective), check out my article here.
Training the model
Create a file called train_model.py in your project’s root directory.
Take a look at the packages we’re going to import:
- keras to train our model (Lines 1–3)
- our custom ResNet model class (Line 4) and configuration file (Line 5)
- scikit-learn to print a classification report (Line 6)
- imutils to grab the image paths from our dataset (Line 7)
- Matplotlib for plotting our training history (Line 8)
- NumPy for numerical processing (Line 9)
- argparse for command-line argument parsing (Line 10)
Notice that we’re using the “Agg” backend for matplotlib, as we’re saving our plot to disk.
The only command-line argument we need to parse is --plot, the path to the output plot. It defaults to plot.png in the current working directory (in this case, your project’s root directory). You can pass a different filename at the command line when you execute the program.
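A minimal argparse sketch for this argument is below. Converting the parsed namespace to a dictionary with vars() is a common pattern but an assumption here, and parse_args is a hypothetical wrapper that takes an explicit argument list so it can be tested without touching sys.argv.

```python
import argparse


def parse_args(argv=None):
    """Parse train_model.py's single option; argv=None reads sys.argv."""
    ap = argparse.ArgumentParser()
    # path to the output loss/accuracy plot; defaults to ./plot.png
    ap.add_argument("--plot", type=str, default="plot.png",
                    help="path to output loss/accuracy plot")
    return vars(ap.parse_args(argv))
```

For example, to write the plot somewhere else:
$ python train_model.py --plot malaria_plot.png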
Training parameters and learning rate decay function
First, we set the number of epochs to train for to 50 (Line 2). The learning rate is initialized to 0.1 (Line 3) and decays according to our decay function (poly_decay, Lines 9–20). The batch size is set to 32 (Line 4), which is a good value if you’re running on a CPU; you can increase it to 64 if you’re using a GPU.
Our polynomial decay function lowers the learning rate after each epoch (Line 9). We set the power to 1.0 (Line 14), which turns the polynomial decay into a linear decay: the learning rate drops by the same amount every epoch, approaching zero by the end of training. Gradually lowering the learning rate helps the model converge and can help reduce overfitting.
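The schedule can be sketched as below, assuming constants named NUM_EPOCHS = 50 and INIT_LR = 0.1 to match the training parameters above. With power = 1.0, the formula lr = INIT_LR * (1 - epoch / NUM_EPOCHS) ** power is linear in the epoch number.

```python
NUM_EPOCHS = 50   # assumed name for the number of training epochs
INIT_LR = 0.1     # assumed name for the initial learning rate


def poly_decay(epoch):
    """Polynomial learning-rate decay; power = 1.0 makes it linear."""
    max_epochs = NUM_EPOCHS
    base_lr = INIT_LR
    power = 1.0
    # the rate falls from base_lr toward 0 as epoch approaches max_epochs
    return base_lr * (1 - (epoch / float(max_epochs))) ** power
```

Keras numbers epochs from 0, so this rate starts at 0.1, is 0.05 at epoch 25, and is 0.002 at the final epoch (49). To use it, pass the function to Keras’s LearningRateScheduler callback, e.g. LearningRateScheduler(poly_decay).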
Next, we count the image paths in the training, validation, and testing sets to determine the number of steps per epoch for training and validation.
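The step counts come from integer division of each image count by the batch size. To stay dependency-free, this sketch walks each split directory with os.walk instead of using imutils; count_images and steps_per_epoch are hypothetical helper names, and the accepted file extensions are an assumption.

```python
import os

# image extensions to count (an assumption; the NIH images are .png files)
IMAGE_EXTS = (".png", ".jpg", ".jpeg")


def count_images(directory):
    """Count image files anywhere under directory, recursively."""
    total = 0
    for _, _, files in os.walk(directory):
        total += sum(1 for f in files if f.lower().endswith(IMAGE_EXTS))
    return total


def steps_per_epoch(num_images, batch_size=32):
    """Number of full batches in one pass over the data."""
    return num_images // batch_size
```

For example, with 19,842 training images and BS = 32, training takes 19842 // 32 = 620 steps per epoch, and 2,204 validation images give 2204 // 32 = 68 validation steps.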
Applying data augmentation to our images acts as a form of regularization, which helps mitigate overfitting. Our network should generalize better to the validation data (even if that means performing slightly worse on the training set).
First, we perform data augmentation on our training data by initializing an ImageDataGenerator. We rescale our pixel values to the range [0, 1] (Line 4) and apply random transformations to each training example (Lines 5–11).
For our validation and testing data, we simply rescale the pixel values to the range [0, 1].
Now let’s initialize our training, validation, and testing generators, which load images from our input directories.
The Keras flow_from_directory function has two prerequisites: a base input directory for the data split, and N subdirectories inside the base input directory, where each subdirectory corresponds to a class label (in our case N = 2: Parasitized and Uninfected).
Take a look at the parameters inputted into each generator:
- class_mode is set to “categorical” to perform one-hot encoding on class labels (Lines 4/13/22)
- target_size: images are resized to 64 x 64 pixels (Lines 5/14/23)
- color_mode is set to “rgb” so images are loaded with three color channels (Lines 6/15/24)
- shuffle is set to True only for the training generator (Line 7)
- batch_size is set to BS = 32, which we initialized earlier with the training parameters (Lines 8/17/26)
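flow_from_directory infers the class labels from the subdirectory names and assigns indices in alphanumeric order (the resulting mapping is exposed on the generator as class_indices). The dependency-free sketch below mimics that behavior, plus the one-hot encoding produced by class_mode="categorical"; infer_class_indices and one_hot are hypothetical names, not Keras functions.

```python
import os


def infer_class_indices(base_dir):
    """Map each class subdirectory of base_dir to an index, sorted by name."""
    classes = sorted(
        d for d in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, d))  # ignore stray files
    )
    return {name: idx for idx, name in enumerate(classes)}


def one_hot(label, class_indices):
    """One-hot vector for a class label, as categorical mode yields."""
    vec = [0.0] * len(class_indices)
    vec[class_indices[label]] = 1.0
    return vec
```

For our data this gives {"Parasitized": 0, "Uninfected": 1}, so an uninfected cell is labeled [0.0, 1.0].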
Initializing ResNet model
On Lines 2–3, we initialize ResNet with the following parameters.
- Images are 64 x 64 x 3 (width, height, and depth: 3-channel RGB images)
- 2 classes (Parasitized and Uninfected)
- Stages = (3, 4, 6)
- Filters = (64, 128, 256, 512)
This implies that the first CONV layer (before reducing spatial dimensions) will have 64 total filters.
First, we stack 3 residual modules; the three CONV layers in each module learn 32, 32, and 128 filters, respectively. Spatial dimensions are then reduced.
Next, we stack 4 residual modules, whose three CONV layers learn 64, 64, and 256 filters; spatial dimensions are reduced again.
Finally, we stack 6 residual modules, where the three CONV layers learn 128, 128, and 512 filters. Spatial dimensions are reduced a final time (check out my article here for an explanation of stages and filters).
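The stage/filter bookkeeping can be made concrete with a small sketch, assuming the bottleneck pattern described above: the first two CONV layers in each residual module learn K/4 filters and the third learns K, where K is the stage’s entry in filters[1:] (filters[0] = 64 goes to the initial CONV layer). resnet_filter_schedule is a hypothetical helper that only computes the numbers; it does not build a model.

```python
def resnet_filter_schedule(stages=(3, 4, 6), filters=(64, 128, 256, 512)):
    """List the (K/4, K/4, K) filter counts for every residual module."""
    schedule = []
    # filters[0] is used by the first CONV layer, so stages pair with filters[1:]
    for num_modules, k in zip(stages, filters[1:]):
        # bottleneck: two layers with K/4 filters, then one with K filters
        module = (k // 4, k // 4, k)
        schedule.append([module] * num_modules)
    return schedule
```

Across the three stages there are 3 + 4 + 6 = 13 residual modules, or 39 main-path CONV layers (not counting any shortcut projections).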
The optimizer we’ll be using is stochastic gradient descent (SGD) with momentum (Line 4): the learning rate is set to 0.1 and the momentum term to 0.9.
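For intuition, here is one SGD-with-momentum update for a single scalar weight, using the non-Nesterov update rule documented for Keras’s SGD optimizer (velocity = momentum * velocity - learning_rate * gradient; weight = weight + velocity). sgd_momentum_step is a hypothetical helper for illustration, not part of Keras.

```python
def sgd_momentum_step(w, v, grad, lr=0.1, momentum=0.9):
    """Apply one SGD-with-momentum update to scalar weight w.

    v is the running velocity; returns the updated (w, v).
    """
    # velocity keeps a decaying memory of previous steps
    v = momentum * v - lr * grad
    # move the weight by the velocity
    return w + v, v
```

Starting from w = 1.0 with a constant gradient of 0.5, the first step moves w to 0.95 and the second to 0.855: the velocity accumulates, so consecutive steps in the same direction grow larger.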
Finally, we compile our model on Lines 5–9. Our loss function is set as binary_crossentropy as we’re performing binary, 2-class classification.
Now, let’s define our callbacks, which Keras calls at set points during training. To decay our learning rate every epoch, we wrap poly_decay in a LearningRateScheduler callback (Line 2), which sets the new learning rate at the start of each epoch.
The model.fit_generator call on Lines 3–9 starts the training process. The trainGen generator automatically loads our images from disk and parses the class labels from the subdirectory names; valGen does the same for the validation data. (In newer TensorFlow 2.x releases, fit_generator is deprecated, and model.fit accepts generators directly.)
Now that training is finished, we can evaluate the model on our test set. We make predictions on the test data (Lines 4–5) and find the label with the largest predicted probability for each image (Line 10).
Then, we’re going to print a classification_report in the terminal (Lines 13–14).
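The prediction-to-report step can be sketched without NumPy or scikit-learn: take the index of the largest probability for each image, then compute per-class precision, recall, and F1 plus overall accuracy, which are the main figures classification_report prints (it also reports support and averaged scores, omitted here). The helper names are hypothetical, and the sketch assumes class index 0 is Parasitized and 1 is Uninfected, matching alphanumeric ordering.

```python
def argmax(probs):
    """Index of the largest value in a list of class probabilities."""
    return max(range(len(probs)), key=lambda i: probs[i])


def classification_summary(y_true, y_pred, class_names):
    """Per-class precision, recall, and F1, plus overall accuracy."""
    report = {}
    pairs = list(zip(y_true, y_pred))
    for c, name in enumerate(class_names):
        tp = sum(1 for t, p in pairs if t == c and p == c)  # true positives
        fp = sum(1 for t, p in pairs if t != c and p == c)  # false positives
        fn = sum(1 for t, p in pairs if t == c and p != c)  # false negatives
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        report[name] = {"precision": precision, "recall": recall, "f1": f1}
    report["accuracy"] = sum(1 for t, p in pairs if t == p) / len(pairs)
    return report
```

In the malaria setting, recall on the Parasitized class is the sensitivity mentioned earlier: the fraction of truly infected cells the model catches.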
Finally, we plot our training and validation loss, as well as our training and validation accuracy, to see how well we did.
Make sure your project has the right structure by referring to my repository on GitHub. Now, open up a terminal and execute the following command:
$ python train_model.py
After your model is done training, take a look at the classification report.
You should obtain approximately:
- 96.50% accuracy on the training data
- 96.78% accuracy on the validation data
- 97% accuracy on the testing data
Overall, the serialized model file is only 17.7 MB. Quantizing the model’s weights could bring it under 10 MB.
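As a rough illustration of why quantization shrinks the model: storing each weight as an 8-bit integer instead of a 32-bit float cuts weight storage by a factor of four. The sketch below shows generic symmetric linear int8 quantization of one weight array with NumPy. It is an assumption for illustration, not the procedure used by any particular toolkit (such as TensorFlow Lite), and a saved model file may hold more than weights, so the real size reduction can differ.

```python
import numpy as np


def quantize_int8(weights):
    """Symmetric linear quantization of a float32 array to int8."""
    # one scale for the whole array: the largest magnitude maps to 127
    scale = float(np.max(np.abs(weights))) / 127.0
    if scale == 0.0:
        scale = 1.0  # all-zero array: any scale works
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    """Recover approximate float32 weights from int8 values and the scale."""
    return q.astype(np.float32) * scale
```

Each dequantized weight is off by at most half a quantization step (scale / 2), which is why accuracy often drops only slightly after quantization.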
And there you have it: a complete end-to-end malaria classification system!
Now you can save this model to disk and load new images for prediction. You can also deploy it on your website or in your app.
If you have any questions, feel free to reach out in the comments or through the following:
- LinkedIn: https://www.linkedin.com/in/gracelynshi/
- Email me at email@example.com
Special thanks to Dr. Adrian Rosebrock of PyImageSearch for his blog post on this topic and the code accompanying it.
World Health Organization, Fact Sheet: World Malaria Report 2016 (13 December 2016), https://www.who.int/malaria/media/world-malaria-report-2016/en/
World Health Organization, Malaria (19 November 2018), https://www.who.int/news-room/fact-sheets/detail/malaria
Carlos Atico Ariza, Malaria Hero: A web app for faster malaria diagnosis (6 November 2018), https://blog.insightdatascience.com/https-blog-insightdatascience-com-malaria-hero-a47d3d5fc4bb
Rajaraman et al., Pre-trained convolutional neural networks as feature extractors toward improved malaria parasite detection in thin blood smear images (2018), PeerJ 6:e4568, DOI 10.7717/peerj.4568
A. Rosebrock, Deep Learning and Medical Image Analysis with Keras (2018), https://www.pyimagesearch.com/2018/12/03/deep-learning-and-medical-image-analysis-with-keras/
A. Rosebrock, Deep Learning for Computer Vision with Python (2017)