Curve Fitting with SciPy

The usual approach to prediction problems these days is to create a machine learning model. However, machine learning models can struggle to train on sparse data sets. They may also produce obviously incorrect results because the model is unaware of real-world constraints. For instance, they may fit a decay curve that goes negative for a target variable that cannot go negative in the real world.

The good news is that SciPy’s scipy.optimize.curve_fit function can come to the rescue here. It is capable of producing an accurate fit from a very limited amount of data. Better still, it can handle complex curves with multiple parameters. For example, here is a curve fit calculated from just 10 points:

[Figure: an example of the close fit produced by SciPy's curve_fit function]

As you can see, the parameter of the curve has been correctly retrieved. Armed with this and the equation of the curve, we can make excellent predictions.

A Simple Example

So how do we use curve_fit? You need two things: a function that defines the equation of the curve you want to fit, and a data set of input and output variable values. The function must take the independent variable as its first argument and the parameters to fit as the remaining arguments. You then call curve_fit with the function and the input and output values. curve_fit calculates the best parameters for the curve and returns them along with a covariance matrix.

import numpy as np
from scipy.optimize import curve_fit

# create a data set for y = 2(x^2)
x_values = np.array([1,2,3,4,5])
y_values = np.array([2,8,18,32,50])

# define the function we are going to curve fit to
def curve_function(x, a):
    return a*(x**2)

# perform our curve fit
parameter, covariance = curve_fit(curve_function, x_values, y_values)
print(parameter[0],covariance[0][0])

Given this dataset, curve_fit returns a value of 2 for the “a” parameter. Because the fit is exact, the covariance is effectively zero (it may print as a tiny floating-point number rather than exactly 0). We could now use our retrieved value of “a” along with a new value of “x” to predict an unknown “y” value.
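
As a minimal sketch of that prediction step, the fitted parameter can be passed straight back into the curve function. The value new_x = 6 is an illustrative assumption and is not part of the original data set.

```python
import numpy as np
from scipy.optimize import curve_fit

# same curve and data as the simple example: y = a * x^2
def curve_function(x, a):
    return a * (x**2)

x_values = np.array([1, 2, 3, 4, 5])
y_values = np.array([2, 8, 18, 32, 50])

parameter, covariance = curve_fit(curve_function, x_values, y_values)

# new_x is a hypothetical unseen input chosen for illustration
new_x = 6
predicted_y = curve_function(new_x, *parameter)
print(predicted_y)  # approximately 72, since 2 * 6^2 = 72
```

Unpacking the parameters with * means the same call works unchanged if the curve function later gains more parameters.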

Fitting a 3D Curve

Obviously this approach would be of limited utility if we could only fit one parameter, as in the toy example above. However, curve_fit is capable of fitting multiple parameters and can even fit higher-dimensional curves. Let's look at another example:

# 3D curve example
import numpy as np
from scipy.optimize import curve_fit

# create a data set for z = 3(x^2) + 2y
x_values = np.array([1,2,3,4,5])
y_values = np.array([5,4,3,2,1])
z_values = np.array([13,20,33,52,77])

# define the function we are going to curve fit to
def curve_function(variables, a, b):
    # unpack our variables
    x, y = variables
    return a*(x**2)+b*(y)

# perform our curve fit
values, covariance = curve_fit(
    curve_function, (x_values, y_values), z_values
)
print(values)

In this more complex case the fit is approximate rather than exact, but curve_fit returns values extremely close to the correct values of 3 and 2 for the parameters “a” and “b”, accurate to several decimal places. The covariance should be zero and is in fact extremely small.

Note that all the input variables we pass to curve_fit, and thus to the curve function, are passed together as a single tuple and then unpacked inside the curve function.
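
The same inputs can also be passed as a single two-row array rather than a tuple, which may be convenient when the data already lives in one array. The sketch below assumes the 3D example above; the name stacked_inputs is introduced here purely for illustration.

```python
import numpy as np
from scipy.optimize import curve_fit

x_values = np.array([1, 2, 3, 4, 5])
y_values = np.array([5, 4, 3, 2, 1])
z_values = np.array([13, 20, 33, 52, 77])

def curve_function(variables, a, b):
    # each row of the stacked array is one input variable
    x, y = variables
    return a * (x**2) + b * y

# stacked_inputs is a hypothetical name; shape is (2, 5),
# one row per input variable and one column per data point
stacked_inputs = np.vstack((x_values, y_values))

values, covariance = curve_fit(curve_function, stacked_inputs, z_values)
print(values)  # close to [3, 2]
```

Either form works because the curve function only needs something it can unpack into x and y.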

Fitting Noisy Data

These toy examples fit exact data, but it is also possible to fit noisy data. An example can be found in this Colaboratory workbook. In such a case the diagonal of the covariance matrix holds the variance of each fitted parameter, which can be converted to a standard deviation for the corresponding parameter using:


std_dev = np.sqrt(np.diag(covariance))
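
To make this concrete, here is a minimal sketch of a noisy fit. It is not the workbook's example: the noise level, the number of points and the fixed random seed are all assumptions chosen for illustration.

```python
import numpy as np
from scipy.optimize import curve_fit

def curve_function(x, a):
    return a * (x**2)

# assumed setup: 50 points of y = 2 * x^2 with Gaussian noise (std dev 1)
rng = np.random.default_rng(seed=0)  # fixed seed so the run is repeatable
x_values = np.linspace(1, 5, 50)
noise = rng.normal(loc=0.0, scale=1.0, size=x_values.size)
y_values = curve_function(x_values, 2.0) + noise

parameter, covariance = curve_fit(curve_function, x_values, y_values)

# one standard deviation per fitted parameter
std_dev = np.sqrt(np.diag(covariance))
print(parameter[0], std_dev[0])
```

The fitted value of “a” should land near 2, with a small but non-zero standard deviation that reflects the noise in the data.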

So how can we use this practically? Suppose that we are attempting to develop a machine learning model. A first step might be to produce a simple curve-fitted model. The fitted curve provides a quick, cheap and simple way to make predictions of unknown values, and so gives us a baseline against which to assess the performance of a machine learning algorithm. Any useful machine learning model should be able to outperform a curve-fitted model.

The key barrier to producing a curve fit model is that we need to know what equation we should be attempting to fit. However, this should not deter us. People are generally quite good at recognising curve shapes, which gives us a candidate equation for fitting. In case of doubt, there is no reason not to fit a number of different curves and see which provides the best fit. Candidates can include linear, polynomial, exponential, logarithmic, reciprocal or more unusual curves.
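
One rough way to compare candidates is to fit each in turn and look at the sum of squared residuals. The sketch below reuses the data from the simple example; the three candidate functions and the choice of error measure are assumptions made for illustration, not a recommendation from SciPy.

```python
import numpy as np
from scipy.optimize import curve_fit

x_values = np.array([1, 2, 3, 4, 5], dtype=float)
y_values = np.array([2, 8, 18, 32, 50], dtype=float)

# hypothetical candidate curves
def linear(x, a, b):
    return a * x + b

def quadratic(x, a):
    return a * x**2

def logarithmic(x, a, b):
    return a * np.log(x) + b

candidates = {
    "linear": linear,
    "quadratic": quadratic,
    "logarithmic": logarithmic,
}

# fit each candidate and record its sum of squared residuals
errors = {}
for name, func in candidates.items():
    params, _ = curve_fit(func, x_values, y_values)
    residuals = y_values - func(x_values, *params)
    errors[name] = float(np.sum(residuals**2))

best = min(errors, key=errors.get)
print(errors)
print(best)
```

Bear in mind that a curve with more parameters will tend to show a smaller error on the data it was fitted to, so on real data it is worth checking the candidates against points held back from the fit.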

Applying Constraints

Another useful aspect of the curve_fit function is that we can apply constraints to the fitted parameters. This helps ensure that we do not fit a curve that produces nonsensical results, for example by keeping the parameters of a decay curve in a range where the curve remains positive even at high time values. To do this, set the bounds parameter. It takes a tuple of two items: the minimum and the maximum allowable values for the fitted parameters. A single number applies the same limit to every parameter. To define different minima and maxima for different parameters, use a tuple of lists with one entry per parameter. When doing this, np.inf and -np.inf can be used to bound a parameter in only one direction.

To further reduce the possibility of nonsensical results, you can set the check_finite parameter to True (in most SciPy versions this check is already on by default). curve_fit will then raise a ValueError if the input data contains NaN or infinite values, which might otherwise lead to a nonsensical fit. Here is an example of usage:

# perform 3D curve fit above, constraining a and b parameter values
values, covariance = curve_fit(
    curve_function,
    (x_values, y_values),
    z_values,
    check_finite=True,
    bounds=(
        [0, 1], # set minimum permissible values for a and b
        [4, np.inf], # set max permissible value for a but not b
    ),
)
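
The decay curve case mentioned earlier can be sketched in the same way. The curve form a*exp(-k*t) + c, the data and the parameter names below are assumptions for illustration. With all three parameters bounded below by zero, the fitted curve cannot go negative for any non-negative time value.

```python
import numpy as np
from scipy.optimize import curve_fit

# hypothetical decay curve: amplitude a, decay rate k, floor c
def decay_function(t, a, k, c):
    return a * np.exp(-k * t) + c

# assumed exact data generated with a=5, k=0.8, c=0.5
t_values = np.linspace(0, 10, 20)
y_values = decay_function(t_values, 5.0, 0.8, 0.5)

values, covariance = curve_fit(
    decay_function,
    t_values,
    y_values,
    bounds=(
        [0, 0, 0],                 # a, k and c may not go below zero
        [np.inf, np.inf, np.inf],  # no upper limit on any parameter
    ),
)
print(values)

# even far beyond the fitted range the prediction stays non-negative
late_prediction = decay_function(1000.0, *values)
print(late_prediction)
```

Note that the bounds act on the parameters, not on the curve's output, so it is up to us to pick bounds that imply the behaviour we want.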

Further Reading

The docs for scipy.optimize.curve_fit, including other parameters and options, can be found at https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html. The example workbook can be found on GitHub or at Google Colaboratory.
