Skip to content

linear Module

This module contains the code defining a JIT-compatible linear interpolation scheme.

LinearInterpolator

Simple JIT-ed wrapper around numpy.interp for use with vector valued functions (i.e. f(x) -> [y,z])

We use numba to just-in-time compile these methods which results in a dramatic speed up on the order of a factor of 50.

Parameters:

Name Type Description Default
x_vals ndarray

The independent variable values, must be monotonically increasing.

required
y_vals ndarray

The corresponding function output. Should be a two dimensional array. The first dimension should be the same length as x. The second dimension is the length of the output of the interpolated function.

required

Attributes:

Name Type Description
x ndarray

The independent variable values, must be monotonically increasing.

y ndarray

The function values. Should be a two dimensional array. The first dimension should be the same length as x. The second dimension is the length of the output of the interpolated function.

x_min float

The low edge of the interpolation

x_max float

The high edge of the interpolation

Methods:

Name Description
LinearInterpolator

Construct the LinearInterpolator

check_values

Internal consistency check. Raises an Exception on failure.

interpolate

Perform interpolation for a set of x-coordinate values.

Source code in src/spyral/interpolate/linear.py
@jitclass(spec=linear_spec)  # type: ignore
class LinearInterpolator:
    """Simple JIT-ed wrapper around numpy.interp for use with vector valued functions (i.e. f(x) -> [y,z])

    We use numba to just-in-time compile these methods which results in a dramatic speed up on the order of
    a factor of 50.


    Parameters
    ----------
    x_vals: numpy.ndarray
        The independent variable values, must be monotonically increasing.
    y_vals: numpy.ndarray
        The corresponding function output. Should be a two dimensional array. The first dimension should be the
        same length as x. The second dimension is the length of the output of the interpolated function.

    Attributes
    ----------
    x: numpy.ndarray
        The independent variable values, must be monotonically increasing.
    y: numpy.ndarray
        The function values. Should be a two dimensional array. The first dimension should be the
        same length as x. The second dimension is the length of the output of the interpolated function.
    x_min: float
        The low edge of the interpolation
    x_max: float
        The high edge of the interpolation

    Methods
    -------
    LinearInterpolator(x_vals: ndarray, y_vals: ndarray)
        Construct the LinearInterpolator
    check_values()
        Internal consistency check. Raises an Exception on failure.
    interpolate(xs: ndarray) -> ndarray
        Perform interpolation for a set of x-coordinate values.
    """

    def __init__(self, x_vals: np.ndarray, y_vals: np.ndarray):
        self.x = x_vals
        self.y = y_vals
        self.x_min = x_vals[0]
        self.x_max = x_vals[-1]
        self.check_values()

    def check_values(self):
        """Internal consistency check. Raises an Exception on failure."""
        if len(self.y.shape) < 2:
            raise LinearInterpError(
                f"LinearInterpolator y-axis has invalid shape! Shape {self.y.shape} must be at minimum 2."
            )
        if len(self.x) != len(self.y[1]):
            raise LinearInterpError(
                f"LinearInterpolator axes do not match in length! x: {len(self.x)} y: {len(self.y[1])}"
            )

    def interpolate(self, xs: np.ndarray) -> np.ndarray:
        """Perform interpolation for a set of x-coordinate values.

        Parameters
        ----------
        xs: numpy.ndarray
            The x-values to perform interpolation on.

        Returns
        -------
        numpy.ndarray
            An 2-D array. Each row contains the corresponding interpolated value.

        """
        results = np.empty((len(xs), len(self.y)))
        for idx in range(len(self.y)):
            results[:, idx] = np.interp(xs, self.x, self.y[idx])
        for idx, x in enumerate(xs):
            if x < self.x_min:
                results[idx, :] = self.y[:, 0]
            elif x > self.x_max:
                results[idx, :] = self.y[:, -1]
        return results

check_values()

Internal consistency check. Raises an Exception on failure.

Source code in src/spyral/interpolate/linear.py
def check_values(self):
    """Internal consistency check. Raises an Exception on failure."""
    if len(self.y.shape) < 2:
        raise LinearInterpError(
            f"LinearInterpolator y-axis has invalid shape! Shape {self.y.shape} must be at minimum 2."
        )
    if len(self.x) != len(self.y[1]):
        raise LinearInterpError(
            f"LinearInterpolator axes do not match in length! x: {len(self.x)} y: {len(self.y[1])}"
        )

interpolate(xs)

Perform interpolation for a set of x-coordinate values.

Parameters:

Name Type Description Default
xs ndarray

The x-values to perform interpolation on.

required

Returns:

Type Description
ndarray

An 2-D array. Each row contains the corresponding interpolated value.

Source code in src/spyral/interpolate/linear.py
def interpolate(self, xs: np.ndarray) -> np.ndarray:
    """Perform interpolation for a set of x-coordinate values.

    Parameters
    ----------
    xs: numpy.ndarray
        The x-values to perform interpolation on.

    Returns
    -------
    numpy.ndarray
        An 2-D array. Each row contains the corresponding interpolated value.

    """
    results = np.empty((len(xs), len(self.y)))
    for idx in range(len(self.y)):
        results[:, idx] = np.interp(xs, self.x, self.y[idx])
    for idx, x in enumerate(xs):
        if x < self.x_min:
            results[idx, :] = self.y[:, 0]
        elif x > self.x_max:
            results[idx, :] = self.y[:, -1]
    return results