When building a machine learning model, our customers often ask: "How much training data will we need?" It's rather like kids in the back seat on a long journey, asking "how much further until we get there?" Just as a map helps us work out how far we still have to travel, we can generate Learning Curves that show how our model is progressing and roughly how many training samples we will need to complete it.
Obtaining data to train machine learning models can be an expensive and time-consuming process, so asking how much data is required is a very reasonable question.
It is easy to say that "more data is always better", but there are diminishing returns to adding more training data as a dataset grows.
When there are only 50 datapoints to learn from, adding 10 more can make a big difference to the model. When there are 500 datapoints, the next 10 will make less of a difference, and when there are 5000 datapoints they might make almost no difference at all.
To answer the question of how much data is needed, we can use 'Learning Curves'. A learning curve is created by building our machine learning models on successively larger portions of the current dataset - for example, we might train models with 10% of the current data, then 20%, 30%, and so on. We evaluate the performance of each of these models, and plot those results on a graph (as shown below).
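In code, the procedure might look something like the following minimal sketch (Python with scikit-learn; the synthetic dataset and the choice of RandomForestClassifier are stand-ins for your own data and model):

```python
# A minimal sketch of building a learning curve by hand, assuming a
# classification dataset (X, y) -- here a synthetic stand-in.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

# Train on 10%, 20%, ..., 100% of the training data and score each
# model on the same held-out test set.
for fraction in np.linspace(0.1, 1.0, 10):
    n = int(fraction * len(X_train))
    model = RandomForestClassifier(random_state=0).fit(X_train[:n], y_train[:n])
    print(f"{n:4d} samples -> accuracy {model.score(X_test, y_test):.3f}")
```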
The exact shape of the learning curve is different for every dataset, but learning curves share a general trend - model performance flattens out as the effect of diminishing returns kicks in. There is also an element of randomness in each curve's exact shape - for example, from the randomness in selecting the specific 10% used for the first portion of the dataset. We therefore calculate the learning curve many times over, and the final curve is an average - the variation between the individual curves is shown with a shaded area on the graph.
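Scikit-learn's learning_curve helper handles the repetition for us via cross-validation. A sketch of plotting the averaged curve might look like this - here the shaded band is the standard deviation across cross-validation folds, one simple way to show the variation described above (again using a synthetic stand-in dataset):

```python
# A sketch of an averaged learning curve with a shaded variation band.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import learning_curve

X, y = make_classification(n_samples=1000, n_features=20, random_state=0)

sizes, _, test_scores = learning_curve(
    RandomForestClassifier(random_state=0), X, y,
    train_sizes=np.linspace(0.1, 1.0, 10), cv=5, scoring="accuracy",
)

mean = test_scores.mean(axis=1)  # the averaged curve
std = test_scores.std(axis=1)    # spread between the individual curves

plt.plot(sizes, mean, label="mean accuracy")
plt.fill_between(sizes, mean - std, mean + std, alpha=0.3, label="±1 std")
plt.xlabel("training samples")
plt.ylabel("accuracy")
plt.legend()
plt.show()
```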
The final slope of the learning curve can be used to estimate the expected benefit (in terms of improved accuracy) of collecting new data, which can be weighed up against the cost (in time or money) of collecting it. Alternatively, if the customer requires a specific level of accuracy from their model, we can extrapolate the learning curve to estimate how much more data we'd need to reach that point (assuming we haven't passed it already!).
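There are several ways to do this extrapolation; one common approach is to fit an empirical inverse power law to the measured points and invert it to find the sample count that reaches a target accuracy. A sketch, using hypothetical measurements purely for illustration:

```python
# Extrapolating a learning curve by fitting an inverse power law.
# The functional form is a common empirical choice, not the only one,
# and the data points below are hypothetical.
import numpy as np
from scipy.optimize import curve_fit

def power_law(n, a, b, c):
    # accuracy approaches the asymptote a as n grows
    return a - b * np.power(n, -c)

# Hypothetical points measured from a learning curve
sizes = np.array([100.0, 200.0, 400.0, 800.0, 1600.0])
accuracy = np.array([0.71, 0.78, 0.83, 0.86, 0.88])

params, _ = curve_fit(power_law, sizes, accuracy, p0=[0.95, 1.0, 0.5], maxfev=10000)
a, b, c = params

target = 0.90
if target < a:
    # invert the fitted curve: n = (b / (a - target)) ** (1 / c)
    n_needed = (b / (a - target)) ** (1 / c)
    print(f"Estimated samples for {target:.0%} accuracy: ~{n_needed:,.0f}")
else:
    print(f"Target {target:.0%} exceeds the fitted asymptote ({a:.0%}).")
```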
Of course, collecting a dataset is not just about the raw number of samples - it is also important to collect the right distribution of samples, so that the dataset is representative and the resulting model robust (stay tuned for future posts!). At Sagitto, we work directly with our customers to understand their use-cases and to ensure their datasets and models are of the highest possible quality.