Source code for pyunfold.callbacks


import sys
import numpy as np
from scipy.interpolate import UnivariateSpline


class Callback(object):
    """Callback base class
    """
    def __init__(self):
        pass

    def on_unfolding_begin(self, status=None):
        pass

    def on_unfolding_end(self, status=None):
        pass

    def on_iteration_begin(self, iteration, status=None):
        pass

    def on_iteration_end(self, iteration, status=None):
        pass


class CallbackList(object):
    """Container for Callback objects
    """
    def __init__(self, callbacks=None):
        self.callbacks = validate_callbacks(callbacks)

    def __len__(self):
        return len(self.callbacks)

    def __iter__(self):
        return iter(self.callbacks)

    def on_unfolding_begin(self, status=None):
        for callback in self.callbacks:
            callback.on_unfolding_begin(status=status)

    def on_unfolding_end(self, status=None):
        for callback in self.callbacks:
            callback.on_unfolding_end(status=status)

    def on_iteration_begin(self, iteration, status=None):
        for callback in self.callbacks:
            callback.on_iteration_begin(iteration=iteration, status=status)

    def on_iteration_end(self, iteration, status=None):
        for callback in self.callbacks:
            callback.on_iteration_end(iteration=iteration, status=status)


class Regularizer(object):
    """Regularizer callback base class
    """
    def __init__(self):
        pass


[docs]class Logger(Callback): """Logger callback Writes test statistic information for each iteration to sys.stdout. """ def __init__(self): super(Callback, self).__init__() pass def on_iteration_end(self, iteration, status): """Writes to sys.stdout Parameters ---------- iteration : int Unfolding iteration (i.e. iteration=1 after first unfolding iteration, etc.). status : dict Dictionary containing key value pairs for the current test statistic value (``'ts_iter'``) and the final test statistic stopping condition (``'ts_stopping'``). """ output = ('Iteration {}: ts = {:0.4f}, ts_stopping =' ' {}\n'.format(iteration, status['ts_iter'], status['ts_stopping'])) sys.stdout.write(output)
[docs]class SplineRegularizer(Callback, Regularizer): """Spline regularization callback Smooths the unfolded distribution at each iteration using ``UnivariateSpline`` from ``scipy.interpolate``. For more information about ``UnivariateSpline``, see the `UnivariateSpline API documentation <https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.UnivariateSpline.html>`_. Parameters ---------- degree : int, optional Degree of the smoothing spline. Must be <= 5 (default is 3, a cubic spline). smooth : float or None, optional Positive smoothing factor used to choose the number of knots. If 0, spline will interpolate through all data points (default is None). groups : array_like, optional Group labels for each cause bin. If groups are specified, then each cause group will be regularized independently (default is None). Notes ----- The number of causes must be larger than the spline ``degree``. Examples -------- Specify the spline degree and smoothing factor: >>> from pyunfold.callbacks import SplineRegularizer >>> reg = SplineRegularizer(degree=3, smooth=1.25) Different cause groups are also supported. For instance, in a problem with seven cause bins, if the first three cause bins belong to their own group, the next two cause bins belong to another group, and the last two cause bins belong to yet another group, an array can be constructed that identifies the group each cause bin belongs to. E.g. >>> groups = [0, 0, 0, 1, 1, 2, 2] >>> reg = SplineRegularizer(degree=3, smooth=1.25, groups=groups) If provided with a ``groups`` parameter, ``SplineRegularizer`` will regularize the unfolded distribution for each group independently. """ def __init__(self, degree=3, smooth=None, groups=None): super(SplineRegularizer, self).__init__() self.degree = degree self.smooth = smooth self.groups = np.asarray(groups) if groups is not None else None def on_iteration_end(self, iteration, status=None): y = status['unfolded'] x = np.arange(len(y), dtype=float) if self.groups is None: spline = UnivariateSpline(x, y, k=self.degree, s=self.smooth) fitted_unfolded = spline(x) else: # Check that a group is specified for each cause bin if len(self.groups) != len(y): err_msg = ('Invalid groups array. There should be an entry ' 'for each cause bin. However, got len(groups)={} ' 'while there are {} cause bins.'.format(len(self.groups), len(y))) raise ValueError(err_msg) fitted_unfolded = np.empty(len(y)) group_ids = np.unique(self.groups) for group in group_ids: group_mask = self.groups == group x_group = x[group_mask] y_group = y[group_mask] spline_group = UnivariateSpline(x_group, y_group, k=self.degree, s=self.smooth) fitted_unfolded_group = spline_group(x_group) fitted_unfolded[group_mask] = fitted_unfolded_group status['unfolded'] = fitted_unfolded
def validate_callbacks(callbacks): """Checks that input callbacks are indeed Callback object instances Parameters ---------- callbacks : Callback or iterable Input Callbacks. Can be either a single Callback or an interable of Callbacks. Returns ------- callbacks : list List of Callbacks. If ``callbacks`` is already a list of ``Callback`` objects, then it is returned. If ``callbacks`` is a ``Callback`` object, then a list containing ``callbacks`` is returned. If ``callbacks`` is None, an empty list is returned. """ if callbacks is None: callbacks = [] elif isinstance(callbacks, Callback): callbacks = [callbacks] else: if not all([isinstance(c, Callback) for c in callbacks]): invalid_callbacks = [c for c in callbacks if not isinstance(c, Callback)] raise TypeError('Found non-callback object in callbacks: {}'.format(invalid_callbacks)) return callbacks def extract_regularizer(callbacks): """Returns a Regularizer Callback from an input list of Callbacks Parameters ---------- callbacks : Callback or iterable Input Callbacks. Can be either a single Callback or an interable of Callbacks. Returns ------- regularizer : Regularizer Regularizer Callback if one exists in the input callbacks, otherwise None is returned. Raises ------ NotImplementedError Multiple Regularizers are in the input Callbacks. """ callbacks = validate_callbacks(callbacks) regularizers = [c for c in callbacks if isinstance(c, Regularizer)] if len(regularizers) > 1: raise NotImplementedError('Multiple regularizer callbacks where provided.') regularizer = regularizers[0] if len(regularizers) == 1 else None return regularizer def setup_callbacks_regularizer(callbacks): """Validates and formats input callbacks Parameters ---------- callbacks : Callback or iterable Input Callbacks. Can be either a single Callback or an interable of Callbacks. Returns ------- callbacks : CallbackList regularizer : Regularizer """ callbacks = validate_callbacks(callbacks) regularizer = extract_regularizer(callbacks) callbacks = CallbackList([c for c in callbacks if c is not regularizer]) return callbacks, regularizer