A better manager for supervising Python functions

| categories: programming | tags:

In the previous post I introduced a supervisor decorator to automate rerunning functions with new arguments to fix issues in them. Almost immediately after posting it, two things started bugging me. First, I thought it was annoying to have two separate arguments for results and exceptions. I would prefer one list of functions that do the right thing. Second, and most annoying, you have to be very careful in writing your checker functions to be consistent with how you called the function so you use exactly the same positional and keyword arguments. That is tedious and limits reusability/flexibility.

So, I wrote a new manager decorator that solves these two problems. Now, you can write checker functions that work on all the arguments of a function. You decorate the checker functions to indicate if they are for results or exceptions. This was a little more of a rabbit hole than I anticipated, but I persevered, and got to a solution that works for these examples. You can find all the code here.

Here is an example where we have a test function that we want to rerun with new arguments until we get a positive result. We start from an argument that leads to a ZeroDivisionError along the way, and we handle that too.

from pycse.supyrvisor import manager, check_result, check_exception

@check_exception
def check1(args, exc):
    if isinstance(exc, ZeroDivisionError):
        print('ooo. caught 1/0, incrementing x')
        return {'x': 1}

@check_result
def check2(args, result):
    print(args)
    if result < 0:
        args['x'] += 1
        return args
        

@manager(checkers=[check1, check2])
def test(x, a=1):
    return a / x

test(-1)

{'x': -1, 'a': 1}
ooo. caught 1/0, incrementing x
{'x': 1}
1.0

Calling it with keyword arguments in a different order also works. This is the improvement over the previous version, which would break if you changed how you called the function.

test(a=1, x=-1)

{'x': -1, 'a': 1}
ooo. caught 1/0, incrementing x
{'x': 1}
1.0
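The control flow behind a decorator like this can be sketched in a few lines. This is a toy illustration, not the pycse implementation: `toy_manager`, `on_result`, and `on_exception` are hypothetical names, and the sketch assumes the decorated function has no positional-only or `*args` parameters, so it can be re-called with keywords only.

```python
import functools
import inspect


def on_result(checker):
    checker.kind = "result"  # tag: run this checker on return values
    return checker


def on_exception(checker):
    checker.kind = "exception"  # tag: run this checker on raised exceptions
    return checker


def toy_manager(checkers, max_errors=5):
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            current = dict(bound.arguments)
            for _ in range(max_errors + 1):
                try:
                    outcome, kind = func(**current), "result"
                except Exception as exc:
                    outcome, kind = exc, "exception"
                for checker in checkers:
                    if checker.kind != kind:
                        continue
                    fix = checker(dict(current), outcome)
                    if fix is not None:
                        current = fix  # rerun with the proposed arguments
                        break
                else:
                    # No checker proposed a fix: we are done.
                    if kind == "exception":
                        raise outcome
                    return outcome
            raise RuntimeError("too many retries")

        return wrapper

    return decorator


@on_exception
def fix_zero(args, exc):
    if isinstance(exc, ZeroDivisionError):
        return {"x": 1}


@on_result
def fix_negative(args, result):
    if result < 0:
        args["x"] += 1
        return args


@toy_manager(checkers=[fix_zero, fix_negative])
def divide(x, a=1):
    return a / x


print(divide(-1))         # 1.0
print(divide(a=1, x=-1))  # 1.0
```

A checker that returns `None` means "nothing to fix", which is why the checkers above simply fall through when their condition is not met.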

1. The previous examples with manager

Here is the new syntax with manager.

import numpy as np
from scipy.optimize import minimize

def objective(x):
    return np.exp(x**2) - 10 * np.exp(x)


@check_result
def maxIterationsExceeded(args, sol):
    if sol.message == 'Maximum number of iterations has been exceeded.':
        args['maxiter'] *= 2
        return args

@manager(checkers=[maxIterationsExceeded], verbose=True)
def get_result(maxiter=2):
    return minimize(objective, 0.0, options={'maxiter': maxiter})

get_result(2)
Proposed fix in wrapper: {'maxiter': 4}
Proposed fix in wrapper: {'maxiter': 8}
  message: Optimization terminated successfully.
  success: True
   status: 0
      fun: -36.86307468296428
        x: [ 1.662e+00]
      nit: 5
      jac: [-4.768e-07]
 hess_inv: [[ 6.481e-03]]
     nfev: 26
     njev: 13

It works!

2. Stateful supervision

In this example, we aim to find the steady state concentrations of two species by integrating a mass balance to steady state. This is visually easy to see below: the concentrations are essentially flat after 10 min or so. Computationally, though, it is somewhat tricky to detect. One way to do it is to compare successive windows of integration to see whether the values have stopped changing. For instance, you could average the values from t = 10 to 11, compare that to the average from 11 to 12, and keep doing that until the two are close enough.

def ode(t, C):
    Ca, Cb = C
    dCadt = -0.2 * Ca + 0.3 * Cb
    dCbdt = -0.3 * Cb + 0.2 * Ca
    return dCadt, dCbdt

tspan = (0, 20)

from scipy.integrate import solve_ivp
sol = solve_ivp(ode, tspan, (1, 0))

import matplotlib.pyplot as plt
plt.plot(sol.t, sol.y.T)
plt.xlabel('t (min)')
plt.ylabel('C')
plt.legend(['A', 'B']);
sol.y.T[-1]
array([0.60003278, 0.39996722])
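The window-comparison idea described above can be sketched on its own, without a supervisor. To keep the sketch dependency-free it swaps in a fixed-step forward Euler integrator for solve_ivp; the helper names, step size, and tolerance are all assumptions made for illustration.

```python
def rates(ca, cb):
    """Right-hand side of the same two-species mass balance."""
    return -0.2 * ca + 0.3 * cb, -0.3 * cb + 0.2 * ca


def window_means(ca, cb, window=1.0, dt=0.001):
    """Integrate one window with forward Euler; return (means, end state)."""
    n = int(round(window / dt))
    sum_a = sum_b = 0.0
    for _ in range(n):
        da, db = rates(ca, cb)
        ca, cb = ca + dt * da, cb + dt * db
        sum_a += ca
        sum_b += cb
    return (sum_a / n, sum_b / n), (ca, cb)


def integrate_to_steady_state(c0=(1.0, 0.0), tol=1e-4, max_windows=100):
    """Keep integrating windows until two successive window means agree."""
    previous, state = window_means(*c0)
    for count in range(2, max_windows + 1):
        current, state = window_means(*state)
        if all(abs(p - c) <= tol for p, c in zip(previous, current)):
            return state, count
        previous = current
    raise RuntimeError("no steady state found")


state, windows = integrate_to_steady_state()
print(state, windows)
```

Each window starts from the end state of the previous one, which is the same restart pattern the supervised version below uses.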

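To do this with a checker, the checker has to remember the previous solution so it can compare it to the current one, which means it needs to store state. It is not crucial to use a class for this; you could also use global variables or function attributes. A class is a standard way of encapsulating state though. We just have to make the class callable so it acts like a function when we need it to.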

class ReachedSteadyState:        
    def __init__(self, tolerance=0.01):
        self.tolerance = tolerance
        self.last_solution = None
        self.count = 0

    def __str__(self):
        return 'ReachedSteadyState'

    @check_result
    def __call__(self, args, sol):
        if self.last_solution is None:
            self.last_solution = sol
            self.count += 1
            args['C0'] = sol.y.T[-1]
            return args

        # we have a previous solution
        if not np.allclose(self.last_solution.y.mean(axis=1),
                           sol.y.mean(axis=1),
                           rtol=self.tolerance,
                           atol=self.tolerance):
            self.last_solution = sol
            self.count += 1
            args['C0'] = sol.y.T[-1]
            return args

rss = ReachedSteadyState(0.0001)

@manager(checkers=[rss], max_errors=20, verbose=True)        
def get_sol(C0=(1, 0), window=1):
    sol = solve_ivp(ode, t_span=(0, window), y0=C0)
    return sol

sol = get_sol((1, 0), window=2)
sol

Proposed fix in ReachedSteadyState: {'C0': array([0.74716948, 0.25283052]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.65414484, 0.34585516]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.61992776, 0.38007224]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.60733496, 0.39266504]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.60269957, 0.39730043]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.60099346, 0.39900654]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.60036557, 0.39963443]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.60013451, 0.39986549]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.60004949, 0.39995051]), 'window': 2}

  message: The solver successfully reached the end of the integration interval.
  success: True
   status: 0
        t: [ 0.000e+00  7.179e-01  2.000e+00]
        y: [[ 6.000e-01  6.000e-01  6.000e-01]
            [ 4.000e-01  4.000e-01  4.000e-01]]
      sol: None
 t_events: None
 y_events: None
     nfev: 14
     njev: 0
      nlu: 0
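As noted earlier, a class is not the only way to hold state between calls. A closure is one alternative, sketched here on plain numbers. `make_change_checker` is a hypothetical name, and the sketch is untested with manager: to use it there it would presumably still need the check_result decorator applied to the inner function.

```python
def make_change_checker(tolerance=0.01):
    """Build a checker that remembers the last value it saw."""
    state = {"last": None, "count": 0}

    def checker(args, value):
        previous, state["last"] = state["last"], value
        state["count"] += 1
        if previous is None or abs(previous - value) > tolerance:
            return args  # not settled yet: ask for another run
        return None  # settled: no fix needed

    checker.state = state  # expose the state for inspection
    return checker


checker = make_change_checker(tolerance=0.1)
print(checker({}, 1.0))   # {} (first call, nothing to compare against)
print(checker({}, 0.5))   # {} (changed by more than the tolerance)
print(checker({}, 0.45))  # None (within tolerance of the last value)
```

The class version keeps the same information in `self.last_solution` and `self.count`, and has the advantage that the stored solution is easy to reach afterwards through the instance.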

We can plot the last two solutions, the previous one stored in the checker and the current one, to see how different they are.

import matplotlib.pyplot as plt
plt.plot(rss.last_solution.t, rss.last_solution.y.T, label=['A previous' ,'B previous'])
plt.plot(sol.t, sol.y.T, '--', label=['A current', 'B current'])
plt.legend()
plt.xlabel('relative t')
plt.ylabel('C');

Those look pretty similar on this graph.

3. Handling exceptions

Suppose you have a function that randomly fails. This could be because something does not converge with a randomly chosen initial guess, converges to an unphysical answer, etc. In these cases, it makes sense to simply try again with a new initial guess.

For this example, say we have this objective function with two minima. We will say that any solution above 0.5 is unphysical.

def f(x):
    return -(np.exp(-50 * (x - 0.25)**2) + 0.5 * np.exp(-100 * (x - 0.75)**2))


x = np.linspace(0, 1)
plt.plot(x, f(x))
plt.xlabel('x')
plt.ylabel('y');

Here we define a function that takes a guess, and gets a solution. If the solution is unphysical, we raise an exception. We define a custom exception so we can handle it specifically.

class UnphysicalSolution(Exception):
    pass

def get_minima(guess):
    sol = minimize(f, guess)

    if sol.x > 0.5:
        raise UnphysicalSolution
    else:
        return sol

@check_exception
def try_again(args, exc):
    if isinstance(exc, UnphysicalSolution):
        args['guess'] = np.random.random()
        return args
  
@manager(checkers=(try_again,), verbose=True)    
def get_minima(guess):
    sol = minimize(f, guess)

    if sol.x > 0.5:
        raise UnphysicalSolution
    else:
        return sol

get_minima(np.random.random())
Proposed fix in wrapper: {'guess': 0.03789731690063758}
  message: Optimization terminated successfully.
  success: True
   status: 0
      fun: -1.0000000000069411
        x: [ 2.500e-01]
      nit: 4
      jac: [ 0.000e+00]
 hess_inv: [[ 1.000e-02]]
     nfev: 18
     njev: 9

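In the run shown here, the manager proposed one new guess before it found a physical solution (the nit: 4 in the output is the optimizer's own iteration count, not the number of retries). Other runs might need no retries, or several; it depends on where the random guesses fall.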

4. Summary

This solution works as well as supervisor did. It was a somewhat deeper rabbit hole to go down, mostly because of some subtlety in making the result and exception decorators work for both functions and class methods. I think it is more robust now: it should not matter how you call the function, and any combination of args and kwargs should work.

Copyright (C) 2023 by John Kitchin. See the License for information about copying.


Org-mode version = 9.7-pre
