The most important lesson in Machine Learning
Most applied machine learning research looks beautiful in the paper but fails spectacularly in real life. These projects all fall prey to the same mistake: the original sin of machine learning.
When the COVID-19 pandemic struck, numerous researchers from laboratories worldwide took on the challenge of leveraging machine learning to address the most significant public health crisis we had ever faced. One prevalent approach involved training an image classifier to analyze X-ray images of lungs for COVID detection.
Initially, this task appeared deceptively simple. By 2020, eight years had passed since AlexNet's 2012 breakthrough on the ImageNet benchmark, which kicked off the deep learning era in computer vision. With ConvNets, ResNets, YOLO, vision transformers, and other sophisticated architectures, the field had made remarkable progress in image classification, seemingly solving many related challenges.
Expectations soared as researchers believed that with a suitable dataset containing COVID and non-COVID lung X-ray images, training a classifier would be straightforward. Teams worldwide embarked on this mission, collaborating extensively to gather and exchange datasets, resulting in approximately 400 research papers pursuing similar objectives.
Following more or less the same methodology, these researchers sourced COVID and non-COVID lung image datasets, trained classifiers, and achieved impressive accuracy rates ranging from 80% to 95%. Many of them claimed to have successfully solved the task of COVID detection from X-ray images.
It turns out all of this was useless. The vast majority fell prey to the original sin in machine learning. To understand why, let’s go back a few decades to one of the longest standing urban legends in artificial intelligence, and learn the most important lesson in Machine Learning.
That is not a tank! Or is it?
It was the early 1960s—or so the legend tells. There are many variants of this tale, and some of them place it in different years or have different details. But details don’t matter. So let’s move on.
Where was I? Oh yeah… It was the early 60s, and artificial intelligence researchers were very excited about this new connectionist idea that you could have small computational units similar to brain neurons connected together performing a small amount of computation. This would allow you to compute some really complex functions when these neurons—these artificial neural networks, as they were called—grew large enough.
And so—our tale continues—the US government hired some of these early connectionist researchers to work on a neural network that could detect whether an image contained an enemy tank or it was an image of an empty space, a forest, some trees, a landscape, or… something not tank.
These researchers already knew that in order to have a statistically valid estimation of how well your model works, you have to have separate training and testing data sets. So they set out to take pictures of tanks and pictures of forests and savannas and places with no tanks and made a data set of, let's say, 100 tanks and 100 non-tanks.
Then they split the data set into 50 pictures of each class for training and another 50 pictures of each class for testing, and implemented their artificial neural network, which at the time was still very simple (probably a simple perceptron, the details don't really matter). The results were spectacular. Not only was the neural network able to completely learn the training set, it also showed very high performance on the testing set.
Very happy with the results, the researchers sent the model to their contractors, only to receive a letter a few weeks later telling them that whenever someone tested the model in real life, with real tanks and real non-tanks, it didn't perform any better than random guessing.
So they went back to the drawing board, scratching their heads, trying to find out what was missing. The statistics were sound, the mathematics was sound, the code was sound, but there was one problem in the data that they couldn't solve. It turns out two different teams took the pictures of the tank and the non-tank. The people who took the pictures of the tank did it on a cloudy day, while the people who were tasked to take pictures of non-tanks did it on a slightly more sunny day.
In these pictures, you could accurately classify which of the two sets they belonged to by simply looking at the average brightness of the pixels. The brighter pictures were more often correlated with non-tanks than the darker pictures. So the neural network had learned exactly what the researchers asked it to do, but not what they wanted it to do.
It learned to differentiate pictures from two different sets, but it didn't learn it by looking at the thing the researchers wanted. They wanted the network to detect a tank. Instead, the network learned to look at the average brightness of the pixels and perfectly captured the most salient difference between the two sets of photos. The brighter pictures belong to one set, and the darker pictures belong to the other set. Problem solved. Now give me the money.
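To make the brightness shortcut concrete, here is a minimal sketch on made-up data. The "photos" are just lists of random pixel intensities, and the brightness levels (0.35 for the cloudy day, 0.65 for the sunny day) are assumptions chosen for illustration, not values from the legend. The classifier never looks for a tank; it only thresholds the average brightness.

```python
import random
from statistics import mean

random.seed(0)

def fake_photo(base_brightness, n_pixels=64):
    """A hypothetical 'photo': a flat list of pixel intensities clipped to [0, 1]."""
    return [min(1.0, max(0.0, random.gauss(base_brightness, 0.15)))
            for _ in range(n_pixels)]

# Assumption: tanks were photographed on a cloudy day, non-tanks on a sunny day.
tanks = [fake_photo(0.35) for _ in range(100)]
no_tanks = [fake_photo(0.65) for _ in range(100)]

def predict_tank(photo, threshold=0.5):
    """'Detects a tank' using nothing but the average brightness."""
    return mean(photo) < threshold

correct = sum(predict_tank(p) for p in tanks) + sum(not predict_tank(p) for p in no_tanks)
accuracy = correct / (len(tanks) + len(no_tanks))

# In the field, tanks also show up on sunny days.
sunny_tanks = [fake_photo(0.65) for _ in range(100)]
detected_fraction = sum(predict_tank(p) for p in sunny_tanks) / len(sunny_tanks)

print(f"accuracy on the collected photos: {accuracy:.2f}")
print(f"sunny-day tanks detected: {detected_fraction:.2f}")
```

On the collected photos the brightness rule looks like an excellent tank detector. On sunny-day tanks it misses nearly everything, because it was never detecting tanks in the first place.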
This is just a cautionary tale, probably apocryphal. Many people have told this story at different times, and nobody can find an actual reference. But it's a story that we keep telling because it highlights one of the most important issues at the core of machine learning. This is a story that I tell my students over and over, every course. And they nod in agreement, but they always fail to understand what the core of the issue is here, and they make this same mistake, over and over.
What is the problem?
The crux of the issue lies in the fact that machine learning, and particularly classification, primarily deals with discriminative models. A discriminative model aims to identify the most prominent feature that distinguishes objects among different classes.
In classification problems, you typically have a set of objects, each described by a specific number of features. These features can be words in a text message, pixels in an image, or structured features, for example, in credit card transaction validation. The objective is to classify these objects into various classes according to those features.
There are three primary types of classification problems: binary, multi-class, and multi-label. A binary problem separates objects into two classes, such as positive and negative. This is common in the medical domain, where the goal is often to distinguish between patients with and without a particular disease. A multi-class problem involves more than two classes, and the goal is to separate objects into one of these specific classes. A multi-label problem has more than one class, and a single object can belong to one or more classes.
In all these cases, a classifier attempts to find a combination of features that correlates with one class more than the others. The classifier then fits an explanation to this correlation by finding a mathematical formula that computes the most likely class for a specific combination of features.
But here is the kicker. Classification algorithms are designed to find the simplest or smallest possible explanation. To understand why, it's essential to revisit the basics of machine learning and the concepts of overfitting and underfitting.
When training a machine learning algorithm, you typically have a training set and a testing set. The algorithm is trained on a subset of data, and its performance is evaluated on a different subset of data that the algorithm hasn't seen before. There are three possible regimes in this scenario.
First, the algorithm may perform poorly on both the training and testing data, indicating underfitting. This usually means that the algorithm isn't powerful enough to capture the complex patterns and correlations in the data. The solution is to use a more intelligent or powerful algorithm.
Second, the algorithm may perform well on the training data but poorly on the testing data, indicating overfitting. In this case, the algorithm fails to capture a generalizable pattern, often because it's too complex or the training data is too small. The algorithm is learning to answer questions from the training data instead of learning the actual knowledge.
Finally, the algorithm may perform similarly on both the training and testing data, indicating that the algorithm has captured a real pattern in the training set that extends to the testing set.
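The three regimes can be sketched as a small helper that compares a training score with a testing score. The function name `diagnose` and both cutoffs (`good_enough`, `max_gap`) are hypothetical values picked for illustration; sensible values depend entirely on the task.

```python
def diagnose(train_score, test_score, good_enough=0.85, max_gap=0.05):
    """Rough label for the regime a model is in, given scores in [0, 1].

    The cutoffs are illustrative assumptions, not standard values.
    """
    if train_score < good_enough:
        # Cannot even fit the data it was trained on.
        return "underfitting"
    if train_score - test_score > max_gap:
        # Fits the training data but does not carry over to unseen data.
        return "overfitting"
    # Similar performance on both sets.
    return "generalizing"

print(diagnose(0.60, 0.58))
print(diagnose(0.99, 0.70))
print(diagnose(0.93, 0.91))
```

Note that "generalizing" here only means generalizing from the training set to the testing set. It says nothing about whether both sets share the same flaw.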
To avoid underfitting, you can make the model stronger by adding more neurons in a neural network, deeper trees in a random forest, or more complex kernel functions in a support vector machine.
To avoid overfitting, you can regularize the model by adding constraints to the hypothesis or formula the model is building, which penalizes overly complex formulas. Regularization makes the model prefer solutions with smaller parameters, leading to smoother formulas that are more likely to generalize to unseen data.
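As a minimal sketch of what "prefer solutions with smaller parameters" means, consider the simplest possible case: a linear model with a single weight and an L2 penalty (ridge regression). This is one common form of regularization, used here as an assumption for illustration. Minimizing sum((y - w*x)^2) + lam * w^2 has the closed-form solution w = sum(x*y) / (sum(x^2) + lam), so a larger penalty `lam` shrinks the weight toward zero. The data points are made up.

```python
def ridge_weight(xs, ys, lam):
    """Closed-form minimizer of sum((y - w*x)^2) + lam * w^2 for one weight w."""
    return sum(x * y for x, y in zip(xs, ys)) / (sum(x * x for x in xs) + lam)

# Made-up data lying roughly on the line y = 2x.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.1, 3.9, 6.2, 7.8]

penalties = (0.0, 1.0, 10.0, 100.0)
weights = [ridge_weight(xs, ys, lam) for lam in penalties]

for lam, w in zip(penalties, weights):
    print(f"lam={lam:>5}: w={w:.3f}")
```

With `lam=0.0` the weight is the ordinary least-squares fit, and every increase in the penalty pulls it closer to zero.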
In summary, good classification algorithms almost always regularize, leading to the simplest solution that explains the difference between your classes. For example, if you give a neural network two sets of images, one with tanks and one without, and the images with tanks are darker on average, the neural network will learn to differentiate between the darker and lighter images. This is the simplest possible explanation, and the neural network is more likely to find this explanation than a more complex one, provided they have more or less equally predictive power.
How can we avoid this?
By now, I hope you understand the difficulty of the problem we have in even the most basic machine learning scenarios. We want to make a classifier to decide whether an image of a lung has COVID or not. However, we have no way to explain to our machine learning algorithm that the thing we care about is COVID. The only thing the algorithm knows is that we have two sets of images, one for class A and one for class B. It needs to find the simplest way to discriminate between these two sets of images.
If we want to build a COVID classifier, we have to make sure that the images in the two sets are such that the simplest explanation that separates them is the presence or absence of COVID. But here's the problem: there are a thousand ways in which two different images of lungs can differ. They can be from different people of different ages, genders, races, taken with different types of X-ray machines, using different techniques, having different sizes, resolutions, brightness, and contrast.
If we want our algorithm to learn the right discriminant—the right explanation—we have to make sure that our two sets of images have almost exactly the same composition across all the things that we don't care about. We cannot have a set that has more images from one gender or race than the other set, or from one specific hospital than the other set—which might have used a different machine, and thus the images might be a little bit less or more blurry, or centered differently. Nor can we have images taken from before 2020 and images taken from after 2020 that may have—because they've updated the software—a different type of watermark in a small corner in the image.
We have to account for all of the possible variations that are irrelevant to our problem and make sure that our dataset is equally balanced across all classes for these variations. Otherwise, we run the risk that one of those variations is sufficiently prevalent in one class with respect to the other that our machine learning algorithm captures that as the most important difference. This becomes even more important the harder your problem is to solve. The harder it is to detect COVID in an image, the more carefully everything else must be normalized, so that the only reasonable explanation for differentiating one image from another is the presence of COVID.
I'm sure by now you know how our original story ends, and why all these hundreds of research papers on COVID failed catastrophically. They all had methodological issues like the one in the tank urban legend.
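One rough way to audit a dataset for this kind of imbalance is to ask, for each irrelevant attribute, how well you could guess the class from that attribute alone. The sketch below assumes each image comes with a metadata record; the field names (`label`, `hospital`, `sex`) and the records themselves are hypothetical. It only catches attributes you thought to record, so it is a sanity check, not a guarantee.

```python
from collections import Counter, defaultdict

def shortcut_accuracy(records, attribute, label="label"):
    """Accuracy of the best rule that sees only `attribute`:
    predict the majority label for each value of that attribute."""
    by_value = defaultdict(Counter)
    for record in records:
        by_value[record[attribute]][record[label]] += 1
    hits = sum(max(counts.values()) for counts in by_value.values())
    return hits / len(records)

# Hypothetical metadata: every COVID image comes from hospital A,
# every non-COVID image from hospital B; sex is balanced in both classes.
records = []
for i in range(200):
    has_covid = i < 100
    records.append({
        "label": "covid" if has_covid else "non_covid",
        "hospital": "A" if has_covid else "B",
        "sex": "F" if i % 2 == 0 else "M",
    })

hospital_score = shortcut_accuracy(records, "hospital")
sex_score = shortcut_accuracy(records, "sex")
print(f"hospital alone: {hospital_score:.2f}")
print(f"sex alone: {sex_score:.2f}")
```

An attribute that predicts the class far better than the base rate, like `hospital` here, is a shortcut the classifier can take instead of learning the thing you care about.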
Some of them had COVID images from one hospital and non-COVID images from a different hospital. Others had COVID images recently taken and non-COVID images that were from an older dataset. Others still had small differences in the data collection that made classes in one set and the other set differ in things other than the presence or absence of COVID. And then this huge performance that you see is just your classifier learning to differentiate between images taken in 2021 and images taken in 2018, or images taken with one X-ray machine versus another machine.
But it gets worse
Now, the most insidious part of this is that there is little you can do to notice this problem while doing the research. Here’s why.
The standard way in which you estimate the performance of your algorithms is by splitting your data into several chunks. Some of those chunks are used for training, parameter tuning, model selection, etc. In any case, you always leave at least one part of unseen data in which you test your final algorithm. This is the way you ensure generalization from training to test—generalization from seen to unseen data.
But if your data collection protocol is flawed—if you collected COVID images from one hospital and non-COVID images from another—then your test set will have the exact same flaws! Any statistics that you do will lie to you and tell you that you are doing it right because you are finding the same irrelevant differentiators in training as in test.
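Here is a small simulation of that trap, under made-up assumptions. Each sample has a weak real `signal` and a strong `artifact` (think of a scanner-specific offset). In the collected data the artifact follows the label exactly; in deployment it is unrelated to the label. The learner is a simple decision stump (one feature, one threshold) chosen only because it is easy to read, not because anyone in the story used it.

```python
import random

random.seed(1)

def make_sample(label, artifact_follows_label):
    """Return ((signal, artifact), label) for one hypothetical image."""
    signal = random.gauss(0.3 if label else -0.3, 1.0)  # weak real signal
    source = label if artifact_follows_label else random.random() < 0.5
    artifact = random.gauss(1.0 if source else -1.0, 0.1)  # collection artifact
    return (signal, artifact), label

def make_dataset(n, artifact_follows_label):
    return [make_sample(i % 2 == 0, artifact_follows_label) for i in range(n)]

def predict(rule, x):
    feature, threshold, positive_above = rule
    return (x[feature] > threshold) == positive_above

def accuracy(rule, data):
    return sum(predict(rule, x) == y for x, y in data) / len(data)

def fit_stump(train):
    """Pick the single feature/threshold/direction with the best training accuracy."""
    best_rule, best_acc = None, -1.0
    for feature in range(2):
        values = sorted(x[feature] for x, _ in train)
        for low, high in zip(values, values[1:]):
            for positive_above in (True, False):
                rule = (feature, (low + high) / 2, positive_above)
                acc = accuracy(rule, train)
                if acc > best_acc:
                    best_rule, best_acc = rule, acc
    return best_rule

collected = make_dataset(400, artifact_follows_label=True)
random.shuffle(collected)
train, test = collected[:200], collected[200:]
deployment = make_dataset(1000, artifact_follows_label=False)

rule = fit_stump(train)
test_acc = accuracy(rule, test)
deploy_acc = accuracy(rule, deployment)
print(f"chosen feature index: {rule[0]} (0 = signal, 1 = artifact)")
print(f"held-out test accuracy: {test_acc:.2f}")
print(f"deployment accuracy: {deploy_acc:.2f}")
```

The held-out test set comes from the same flawed collection, so it rewards the artifact just as much as the training set does. Only data collected differently exposes the problem.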
So these researchers were set up to fail from the get-go. From the moment they collected the data, they were doomed. There was nothing they could do once they decided on that collection mechanism. There was nothing they could do during their research process, during their validation process, or during their statistical analysis to realize the massive flaw they had at the very core of their assumptions.
Now, you can imagine this happens all over the place. It happened to two of my graduate students when they were in undergrad. They also trained a pretty standard ResNet on COVID vs. non-COVID images and obtained results that were too good to be true, because they weren’t. They collected the COVID images from one dataset and the non-COVID images from a different dataset, and there were tiny differences in the image format, color composition, etc., not enough to be seen by the naked eye, but enough to help the neural net cheat on the COVID exam and gain extra points by looking at the wrong features.
And it also happened to superstar machine learning practitioner Andrew Ng and his superstar team of researchers. They were also analyzing images of lungs, and did all the right things. Except, when they split the data, different pictures of the same person ended up in both the training and the test set. So the model could learn to detect subtle cues that appeared in both sets (maybe a patient had a scar from a previous condition) and obtain a few extra points of accuracy. Not exactly the same level of critical failure as the others, but still an example of how even the most professional among us can fall for this insidious mistake.
But this common pitfall is not restricted to COVID images or even to image classification. It is just one of the most salient recent examples of a problem that shows up all over machine learning. This is why the vast majority of machine learning research that you see in papers fails to generalize.
I’m at the moment reviewing a very good paper on using LLMs to detect fake news. Very good, except they also fell prey to this mistake. They collected truthful news from around the web, but they couldn’t find a reliable source of fake news, of course. So they resorted to using ChatGPT to generate plausible-sounding but fabricated news. And they got something like 95% accuracy, breaking several benchmarks! But, and this is my hypothesis, what their classifier learned was to differentiate regular news text from ChatGPT-generated text. And not any ChatGPT-generated text (which is unsolvable in practice) but ChatGPT using their specific prompt.
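The usual guard against this kind of leakage is to split by patient rather than by image, so that all images of a given person land on the same side. Below is a minimal sketch; the record layout and the `patient_id` field are hypothetical, and this is not a reconstruction of what that team actually did.

```python
import random

def split_by_patient(records, test_fraction=0.2, seed=0):
    """Assign whole patients, not individual images, to train or test."""
    patients = sorted({r["patient_id"] for r in records})
    rng = random.Random(seed)
    rng.shuffle(patients)
    n_test = max(1, round(len(patients) * test_fraction))
    test_patients = set(patients[:n_test])
    train = [r for r in records if r["patient_id"] not in test_patients]
    test = [r for r in records if r["patient_id"] in test_patients]
    return train, test

def shared_patients(train, test):
    """Patients that appear on both sides of a split."""
    return {r["patient_id"] for r in train} & {r["patient_id"] for r in test}

# Hypothetical dataset: 50 patients with 3 images each.
records = [{"patient_id": p, "image": f"patient{p}_scan{k}.png"}
           for p in range(50) for k in range(3)]

# Naive split: shuffle images and cut, ignoring who they belong to.
shuffled = records[:]
random.Random(0).shuffle(shuffled)
naive_train, naive_test = shuffled[:120], shuffled[120:]

grouped_train, grouped_test = split_by_patient(records)

print("patients shared by the naive split:", len(shared_patients(naive_train, naive_test)))
print("patients shared by the patient-level split:", len(shared_patients(grouped_train, grouped_test)))
```

The patient-level split gives up a little control over the exact test set size in exchange for a test set made only of people the model has never seen.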
The bigger lesson
Most applied ML research looks excellent in the paper. However, when you go to put it into practice and have to deal with the nitty-gritty of reality, it turns out that this performance you see in the paper is nowhere near real-life performance.
This is the hardest type of generalization. Generalizing from test set to application is the moment in which all your assumptions about data collection, about how data is distributed in real life, and about what is important or not in the data, start to fall apart.
So the bigger lesson from this super common and super insidious pitfall is this: data collection is where your assumptions begin. No ML problem starts at the training phase, or even the modeling phase. The way you decide to collect, filter, and sanitize your data already encodes a huge set of assumptions that separate the problem you’re actually solving from the problem you think you’re solving.
But we are teaching whole generations of machine learning researchers and practitioners to focus on finding the best algorithms and architectures. And that’s important, of course. But it is not the most critical part. Data collection and curation is the make it or break it of machine learning. Good data plus mediocre model beats bad data with awesome model every single day.
There were three acronyms in computing in the late '70s and early '80s: KISS, GIGO, RTKP.
Keep it simple, stupid.
Garbage in, garbage out.
Reduce it to a known problem.
Guess they don’t teach those ideas anymore?
It doesn't matter how well we've coded the algorithm if it's on top of bad data...