python4enlightening

the open notebook

Plotting Polynomial Interpolation

This post is a note where I gain some experience with Python's matplotlib and with solving linear equations in NumPy. It starts with a specific solution for the quadratic polynomial only. At the end of the post there is a program that generalizes the solution to polynomials of any order, and therefore to any number of points to be fitted.

Polynomial Interpolation

Suppose we want to determine the quadratic polynomial \(p(x) = c_0 + c_1x + c_2x^2\) that passes through three given data points \((x_i,y_i)\) for \(i = 1,2,3\).

Requiring that \(p(x_i) = y_i\) for these three points gives a linear system of 3 equations in 3 unknowns (the coefficients \(c_i\)).
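Written out in matrix form for arbitrary points, the system reads as follows. The matrix of powers of the \(x_i\) is known as a Vandermonde matrix; it is invertible whenever the \(x_i\) are distinct, so in that case the interpolating quadratic exists and is unique.

\[\begin{equation} \begin{bmatrix} 1 & x_1 & x_1^2 \\ 1 & x_2 & x_2^2 \\ 1 & x_3 & x_3^2 \end{bmatrix} \begin{bmatrix} c_0 \\ c_1 \\ c_2 \end{bmatrix} = \begin{bmatrix} y_1 \\ y_2 \\ y_3 \end{bmatrix} \end{equation}\]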

For example, if the data points are \((x_i,y_i) = (-1, 1), (0, -1), (2, 7)\) then the system of equations is

\[\begin{align} c_0 - c_1 + c_2 & = 1 \\ c_0 + 0c_1 + 0c_2 & = -1 \\ c_0 + 2c_1 + 4c_2 & = 7 \end{align}\]

This has the form \(Ac = y\), where \(c\) is the vector of coefficients we seek, \(y\) is the vector of data values, and \(A\) is the \(3 \times 3\) matrix

\[\begin{equation} A = \begin{bmatrix} 1 & -1 & 1 \\ 1 & 0 & 0 \\ 1 & 2 & 4 \end{bmatrix} \end{equation}\]

To solve this system we can use the solve function from the numpy.linalg module. The solution is \(c = [-1, 0, 2]^T\), corresponding to the polynomial \(p(x) = 2x^2 - 1\), as is easily verified. I will now demonstrate how to write a Python script for this.
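Substituting the three data points into \(p(x) = 2x^2 - 1\) confirms the solution:

\[\begin{align} p(-1) & = 2(-1)^2 - 1 = 1 \\ p(0) & = 2(0)^2 - 1 = -1 \\ p(2) & = 2(2)^2 - 1 = 7 \end{align}\]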

Solving the Equations in Python

In [94]:
"""
Script for quadratic interpolation.
"""
import numpy as np
from numpy.linalg import solve

# Data points:
xi = np.array([-1., 0., 2])
yi = np.array([1., -1., 7.])

A = np.array([[xi[0]**0, xi[0]**1, xi[0]**2], 
              [xi[1]**0, xi[1]**1, xi[1]**2], 
              [xi[2]**0, xi[2]**1, xi[2]**2]]);
print "A = \n", A

b = yi; 
c = solve(A,b);

print "\nThe polynomial coefficients are:"; print "c =", c
A = 
[[ 1. -1.  1.]
 [ 1.  0.  0.]
 [ 1.  2.  4.]]

The polynomial coefficients are:
c = [-1.  0.  2.]
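As a side note, the same coefficients can be obtained by forming the inverse explicitly with numpy.linalg.inv, but solve is generally preferred: it factorizes \(A\) (NumPy calls the LAPACK routine _gesv, an LU factorization with partial pivoting) and never builds \(A^{-1}\), which is cheaper and usually more accurate. A small comparison, reusing the data from the script above:

```python
import numpy as np
from numpy.linalg import solve, inv

# Same data points and matrix as in the script above
xi = np.array([-1., 0., 2.])
yi = np.array([1., -1., 7.])
A = np.vstack([xi**0, xi**1, xi**2]).T

c_solve = solve(A, yi)   # factorizes A and solves A c = y directly
c_inv = inv(A) @ yi      # forms the inverse explicitly, then multiplies

print(np.allclose(c_solve, c_inv))  # True
```

For a 3 x 3 system the difference is negligible; it matters more for larger or ill-conditioned systems.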

A Function for Plotting the Solution

In [95]:
import matplotlib.pyplot as plt

def plot_quad(xi, yi, c):
    """
    Plot p(x) = c[0] + c[1]*x + c[2]*x**2 together with the data points.
    """
    x = np.linspace(-2, 3, 1001)
    y = c[0] + c[1]*x + c[2]*x**2

    fig, ax = plt.subplots()
    ax.plot(x, y, 'b', label='Interpolating Polynomial')
    ax.plot(xi, yi, 'ro', label='Data Points')
    legend = ax.legend(loc='upper center')
    ax.set_title("Data points and interpolating polynomial", fontsize=14)
    ax.grid()
    ax.set_ylim(min(yi) - 50, max(yi) + 50)

    # Enlarge the legend text and thicken the legend lines
    for label in legend.get_texts():
        label.set_fontsize('large')

    for line in legend.get_lines():
        line.set_linewidth(1.5)

    # Enlarge the tick labels on both axes
    ax.tick_params(axis='both', labelsize=16)

    plt.show()

plot_quad(xi, yi, c)
[Figure: Data points and interpolating polynomial]

Rewrite as a Function for Quadratic Interpolation

In [96]:
"""
Demonstration module for quadratic interpolation.
"""
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg import solve

def quad_interp(xi,yi):
    """
    Quadratic interpolation.  Compute the coefficients of the polynomial
    interpolating the points (xi[i],yi[i]) for i = 0,1,2.
    Returns c, an array containing the coefficients of
      p(x) = c[0] + c[1]*x + c[2]*x**2.
    """
    # Check the inputs; an AssertionError with the message is raised if they are not valid:
    error_message = "xi and yi should have type numpy.ndarray"
    assert isinstance(xi, np.ndarray) and isinstance(yi, np.ndarray), error_message

    error_message = "xi and yi should have length 3"
    assert len(xi) == 3 and len(yi) == 3, error_message

    # The linear system to interpolate through the data points
    A = np.vstack([xi**0, xi**1, xi**2]).T
    c = solve(A,yi)
    return c

To test the function with varying inputs, the following loop generates three sets of data values, calls the interpolation function on each, and plots the result.

In [97]:
for i in range(3):
    xi = np.array([-1., 0., 2.])
    yi = np.array([1. + 100*i, -1. - 2*i, 7. + 15*i])

    c = quad_interp(xi, yi)
    plot_quad(xi, yi, c)
    print("The Solution is :\nc = ", c, "\ny = ", yi)
[Figure: Data points and interpolating polynomial]
The Solution is :
c =  [-1.  0.  2.] 
y =  [ 1. -1.  7.]

[Figure: Data points and interpolating polynomial]
The Solution is :
c =  [ -3.         -65.16666667  38.83333333] 
y =  [ 101.   -3.   22.]

[Figure: Data points and interpolating polynomial]
The Solution is :
c =  [  -5.         -130.33333333   75.66666667] 
y =  [ 201.   -5.   37.]
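The plots give a visual check; a numerical check is also easy: evaluate the computed polynomial at the data points and compare with yi. The sketch below does this for the same three inputs. The helper name check_interp is hypothetical, and the polynomial is evaluated with numpy.polynomial.polynomial.polyval, which expects coefficients ordered from the constant term upward, the same order as c.

```python
import numpy as np
from numpy.linalg import solve
from numpy.polynomial import polynomial as P

def check_interp(xi, yi, c):
    """Hypothetical helper: True if p(x) = c[0] + c[1]*x + c[2]*x**2 hits every data point."""
    return np.allclose(P.polyval(xi, c), yi)

xi = np.array([-1., 0., 2.])
results = []
for i in range(3):
    # Same inputs as the loop above
    yi = np.array([1. + 100*i, -1. - 2*i, 7. + 15*i])
    A = np.vstack([xi**0, xi**1, xi**2]).T   # same matrix that quad_interp builds
    c = solve(A, yi)
    results.append(check_interp(xi, yi, c))

print(results)  # [True, True, True]
```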

Generalizing the Functions to Polynomial Interpolation of Any Order

As seen in the plots and the printed coefficients, the result is correct for these three inputs. Finally, a generalized solution is written so that higher-order systems can be solved as well. To build the matrix for an arbitrary number of points, it uses a Python list comprehension, which is quite useful for producing vectors on the fly.
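For \(n\) points, the list comprehension produces one row per power \(x^0, x^1, \ldots, x^{n-1}\), and the transpose turns these into the columns of the Vandermonde matrix. NumPy also provides numpy.vander, which builds the same matrix; with increasing=True the powers run from \(x^0\) upward, matching the coefficient ordering used here. A quick comparison on four sample points, chosen only for illustration:

```python
import numpy as np

xi = np.array([-1., 0., 2., -3.])   # illustrative points
n = len(xi)

# Rows x**0, x**1, ..., x**(n-1), stacked and then transposed into columns
A_listcomp = np.vstack([xi**j for j in range(n)]).T

# NumPy's Vandermonde constructor; increasing=True puts the x**0 column first
A_vander = np.vander(xi, increasing=True)

print(A_listcomp.shape)                     # (4, 4)
print(np.allclose(A_listcomp, A_vander))    # True
```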

In [98]:
def poly_interp(xi_multi,yi_multi):
    """
    General polynomial interpolation. 

    Compute the coefficients of the polynomial
    interpolating the points (xi_multi[i], yi_multi[i]) for i = 0,1,2,...,n-1
    where n = len(xi_multi) = len(yi_multi).

    Returns c, an array containing the coefficients of
      p(x) = c[0] + c[1]*x + c[2]*x**2 + ... + c[n-1]*x**(n-1).

    """
    # Check the inputs; an AssertionError with the message is raised if they are not valid:
    error_message = "xi and yi should have type numpy.ndarray"
    assert isinstance(xi_multi, np.ndarray) and isinstance(yi_multi, np.ndarray), error_message

    error_message = "xi and yi should have the same length"
    assert len(xi_multi) == len(yi_multi), error_message

    # The linear system to interpolate through data points:
    # Uses a list comprehension
    
    n = len(xi_multi)
    A = np.vstack([xi_multi**j for j in range(n)]).T
    c = solve(A,yi_multi)

    return c
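One way to test poly_interp is to sample a polynomial with known coefficients at as many distinct points as it has coefficients and check that those coefficients come back. The sketch below repeats the core of poly_interp (without the input checks) so that it runs on its own; the cubic and the sample points are arbitrary choices for illustration.

```python
import numpy as np
from numpy.linalg import solve

def poly_interp(xi_multi, yi_multi):
    # Condensed copy of poly_interp above, input checks omitted
    n = len(xi_multi)
    A = np.vstack([xi_multi**j for j in range(n)]).T
    return solve(A, yi_multi)

# Known cubic p(x) = 1 - 2x + 0.5x^2 + 3x^3, sampled at four distinct points
c_true = np.array([1., -2., 0.5, 3.])
xi_multi = np.array([-2., -1., 1., 3.])
yi_multi = sum(c_true[j] * xi_multi**j for j in range(len(c_true)))

c = poly_interp(xi_multi, yi_multi)
print(np.allclose(c, c_true))  # True
```

Four distinct points determine a unique cubic, so the recovered coefficients agree with c_true up to rounding.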
In [99]:
def plot_poly(xi_multi, yi_multi, c):
    """
    Plot the resulting function along with the data points.
    """
    x = np.linspace(xi_multi.min() - 1,  xi_multi.max() + 1, 1000)

    # Evaluate the interpolating polynomial at x using Horner's rule:
    n = len(xi_multi)
    y = c[n-1]
    for j in range(n-1, 0, -1):
        y = y*x + c[j-1]

    ## Plotting (plt.figure() already starts an empty figure, so no clf() is needed)
    plt.figure()
    plt.plot(x, y, 'b-')
    plt.plot(xi_multi, yi_multi, 'ro')
    plt.ylim(yi_multi.min() - 1, yi_multi.max() + 1)

    plt.title("Data points and interpolating polynomial")
    plt.savefig('poly.png')  # each call overwrites poly.png
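The loop in plot_poly evaluates the polynomial with Horner's rule, i.e. by nesting the products as \(p(x) = c_0 + x(c_1 + x(c_2 + \cdots + x\,c_{n-1}))\), which takes only \(n-1\) multiplications and \(n-1\) additions per point and no explicit powers. Pulled out as a standalone function (the name horner is just an illustrative choice), it can be checked against numpy.polyval, which expects the coefficients in the opposite order, highest power first:

```python
import numpy as np

def horner(c, x):
    """Evaluate c[0] + c[1]*x + ... + c[n-1]*x**(n-1) by Horner's rule."""
    y = c[-1] * np.ones_like(x)    # start from the highest-order coefficient
    for coef in reversed(c[:-1]):  # then c[n-2], ..., c[0]
        y = y * x + coef
    return y

x = np.linspace(-2, 3, 11)
c = np.array([-1., 0., 2.])        # p(x) = 2x^2 - 1 from the quadratic example

print(np.allclose(horner(c, x), 2*x**2 - 1))             # True
print(np.allclose(horner(c, x), np.polyval(c[::-1], x)))  # True
```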
In [100]:
for i in range(2):
    xi_multi = np.array([-1., 0., 2, -3, -23, 5, 8])
    yi_multi = np.array([1.+100*i, -1.-3*i, 7.+34*i, 3-3*i, -34+8*i, -5-23*i, -56+56*i])

    c = poly_interp(xi_multi, yi_multi)
    plot_poly(xi_multi, yi_multi, c)
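As a cross-check, numpy.polyfit with degree \(n-1\) fits a polynomial to \(n\) points by least squares; when the \(x_i\) are distinct the least-squares residual is zero, so the fit coincides with the interpolating polynomial. polyfit returns the coefficients highest power first, so they are reversed before comparing. The small data set below is illustrative only. For points spread as widely as in the seven-point example above, the Vandermonde matrix becomes badly conditioned, and np.linalg.cond(A) is one way to gauge how much accuracy may be lost.

```python
import numpy as np
from numpy.linalg import solve

xi = np.array([-1., 0., 2., -3.])   # illustrative data
yi = np.array([1., -1., 7., 4.])
n = len(xi)

# Interpolation through the Vandermonde system, as poly_interp does
c = solve(np.vander(xi, increasing=True), yi)

# Degree n-1 least-squares fit through n points; reverse to constant-term-first order
c_fit = np.polyfit(xi, yi, n - 1)[::-1]

print(np.allclose(c, c_fit))  # True
```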

This was a short note on how to solve linear systems with NumPy and do some simple plotting with matplotlib.
