Adding new backends to hashcache

| categories: programming | tags:

I have been working on hashcache to make it more flexible. I like the base functionality that uses the filesystem for caching. That still works.

Here I set up a timeit decorator to show how this works.

from pycse.hashcache import hashcache
import time

!rm -fr ./cache

def timeit(func):
    def wrapper(*args, **kwargs):
        t0 = time.time()
        res = func(*args, **kwargs)
        print(f'Elapsed time = {time.time() - t0}s')
        return res
    return wrapper

Now we decorate a function that is "expensive". The first time we run it, it takes a long time.

@timeit
@hashcache
def expensive_func(x):
    time.sleep(3)
    return x

expensive_func(2)

Elapsed time = 3.007030963897705s
2

The second time is very fast, since we just look it up.

expensive_func(2)

Elapsed time = 0.0012097358703613281s
2

Where did we look it up from? It is stored on disk. You can see where by adding a verbose option to the decorator. This shows you all the data that was stored in the cache.

@hashcache(verbose=True)
def expensive_func(x):
    time.sleep(3)
    return x

expensive_func(2)

{ 'args': (2,),
  'cwd': '/Users/jkitchin/Dropbox/emacs/journal/2023/09/23',
  'elapsed_time': 3.0048787593841553,
  'func': 'expensive_func',
  'hash': 'b5436cc21714a7ea619729cc9768b8c5b3a03307',
  'kwargs': {},
  'module': '__main__',
  'output': 2,
  'run-at': 1695572717.2020931,
  'run-at-human': 'Sun Sep 24 12:25:17 2023',
  'standardized-kwargs': {'x': 2},
  'user': 'jkitchin',
  'version': '0.0.2'}
2

1. Alternative backends for hashcache

The file system is an amazing cache with many benefits. There are a few reasons you might want something different, though. For example, it is slow to search if you have to iterate over all the directories and read the files, and it might be slow to sync lots of directories to another place.

hashcache is more flexible now, so you can define the functions that load and dump the cache. Here we use lmdb as a key-value database. lmdb expects the keys and values to be bytes, so we use io.BytesIO to capture in memory the bytes that joblib.dump would otherwise write to a file.

The load function has the signature (hash, verbose) and returns a tuple (found, output), where found is True if the key is in the cache. The dump function has the signature (hash, data, verbose). In both cases, hash is the string key the data is stored under, data is a dictionary that should be saved in a form that can be reloaded, and verbose is a flag that you can ignore or use to provide some kind of logging.

from pycse.hashcache import hashcache

import io, joblib, lmdb

def lmdb_dump(hsh, data, verbose=False):
    if verbose:
        print('running lmdb_dump')
    with io.BytesIO() as f:
        joblib.dump(data, f)
        value = f.getvalue()

    db = lmdb.Environment(hashcache.cache)
    with db.begin(write=True) as txn:
        txn.put(hsh.encode('utf-8'), value)

def lmdb_load(hsh, verbose=False):
    if verbose:
        print('running lmdb_load')
    db = lmdb.Environment(hashcache.cache)
    with db.begin() as txn:
        val = txn.get(hsh.encode('utf-8'))
        if val is None:
            return False, None
        else:
            return True, joblib.load(io.BytesIO(val))['output']
                                    
! rm -fr cache.lmdb

hashcache.cache = 'cache.lmdb'


@hashcache(loader=lmdb_load, dumper=lmdb_dump, verbose=True)
def f(x):
    return x

f(2)   

running lmdb_load
running lmdb_dump
2

And we can recall the result as easily.

f(2)

running lmdb_load
2

2. A shelve version

Maybe you prefer a built-in library like shelve. This is also quite simple.

from pycse.hashcache import hashcache

import io, joblib, shelve

def shlv_dump(hsh, data, verbose=False):
    print('running shlv_dump')
    with io.BytesIO() as f:
        joblib.dump(data, f)
        value = f.getvalue()

    with shelve.open(hashcache.cache) as db:
        db[hsh] = value

def shlv_load(hsh, verbose=False):
    print('running shlv_load')
    with shelve.open(hashcache.cache) as db:
        if hsh in db:
            return True, joblib.load(io.BytesIO(db[hsh]))['output']
        else:
            return False, None

hashcache.cache = 'cache.shlv'
! rm -f cache.shlv.db

@hashcache(loader=shlv_load, dumper=shlv_dump)
def f(x):
    return x

f(2)
    

running shlv_load
running shlv_dump
2

And again loading is easy.

f(2)

running shlv_load
2
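Since shelve already pickles whatever you assign to a key, the io.BytesIO/joblib step is not strictly needed for this backend. Here is a sketch of that simplification, with hypothetical function names and a throwaway temporary directory. One reason to keep joblib anyway is that it is designed to handle large NumPy arrays efficiently.

```python
import os
import shelve
import tempfile

# Throwaway location for the sketch; the post uses hashcache.cache instead.
cache_path = os.path.join(tempfile.mkdtemp(), 'cache.shlv')


def shlv_dump_direct(hsh, data, verbose=False):
    """Hypothetical variant: shelve pickles values itself, so store the dict as-is."""
    if verbose:
        print('running shlv_dump_direct')
    with shelve.open(cache_path) as db:
        db[hsh] = data


def shlv_load_direct(hsh, verbose=False):
    """Return (found, output), following the same loader contract."""
    if verbose:
        print('running shlv_load_direct')
    with shelve.open(cache_path) as db:
        if hsh in db:
            return True, db[hsh]['output']
        return False, None


shlv_dump_direct('b5436cc2', {'func': 'f', 'args': (2,), 'output': 2})
print(shlv_load_direct('b5436cc2'))   # (True, 2)
print(shlv_load_direct('missing'))    # (False, None)
```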

3. A sqlite version

I am a big fan of sqlite. Here I use a simple table mapping a key to a value. It could be interesting to store the value as JSON, which would make it more searchable, or to use a more complex table, but here we keep it simple.
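As one sketch of the more searchable layout, assuming the data dictionary carries 'func' and 'standardized-kwargs' keys like the verbose output earlier in this post, you could keep that metadata in ordinary columns (with the keyword arguments as JSON text) and the full pickled record in a BLOB. The function names, the column layout, and the use of pickle instead of joblib are assumptions for illustration.

```python
import json
import pickle
import sqlite3

json_con = sqlite3.connect(':memory:')   # use a file name like 'cache.sql' to persist
# Searchable metadata in ordinary columns; the full record stays a BLOB.
json_con.execute("""CREATE TABLE IF NOT EXISTS cache(
                        hash TEXT PRIMARY KEY,
                        func TEXT,
                        kwargs TEXT,
                        value BLOB)""")


def sql_json_dump(hsh, data, verbose=False):
    """Hypothetical dumper: kwargs stored as JSON text, whole record pickled."""
    if verbose:
        print('running sql_json_dump')
    kwargs_json = json.dumps(data['standardized-kwargs'], sort_keys=True)
    with json_con:   # commits the transaction on success
        json_con.execute("INSERT OR REPLACE INTO cache VALUES (?, ?, ?, ?)",
                         (hsh, data['func'], kwargs_json, pickle.dumps(data)))


def sql_json_load(hsh, verbose=False):
    """Return (found, output), following the loader contract."""
    if verbose:
        print('running sql_json_load')
    row = json_con.execute("SELECT value FROM cache WHERE hash = ?",
                           (hsh,)).fetchone()
    if row is None:
        return False, None
    return True, pickle.loads(row[0])['output']


def find_by_func(name):
    """Return the kwargs of every cached call to `name`; a bare key-value table can't do this."""
    rows = json_con.execute(
        "SELECT kwargs FROM cache WHERE func = ? ORDER BY kwargs", (name,))
    return [json.loads(r[0]) for r in rows]


# Hand-built records shaped like the verbose output shown earlier.
sql_json_dump('h1', {'func': 'f', 'standardized-kwargs': {'x': 2}, 'output': 2})
sql_json_dump('h2', {'func': 'f', 'standardized-kwargs': {'x': 3}, 'output': 3})
sql_json_dump('h3', {'func': 'g', 'standardized-kwargs': {'y': 1}, 'output': 1})
print(find_by_func('f'))   # [{'x': 2}, {'x': 3}]
```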

from pycse.hashcache import hashcache

import io, joblib, sqlite3

def sql_dump(hsh, data, verbose=False):
    print('running sql_dump')
    with io.BytesIO() as f:
        joblib.dump(data, f)
        value = f.getvalue()

    # INSERT OR REPLACE avoids an IntegrityError if the same hash is dumped twice
    with con:
        con.execute("INSERT OR REPLACE INTO cache(hash, value) VALUES(?, ?)",
                    (hsh, value))

def sql_load(hsh, verbose=False):
    print('running sql_load')
    with con:        
        cur = con.execute("SELECT value FROM cache WHERE hash = ?",
                          (hsh,))
        value = cur.fetchone()
        if value is None:
            return False, None
        else:
            return True, joblib.load(io.BytesIO(value[0]))['output']

! rm -f cache.sql
hashcache.cache = 'cache.sql'
# sql_dump and sql_load use this module-level connection
con = sqlite3.connect(hashcache.cache)
con.execute("CREATE TABLE IF NOT EXISTS cache(hash TEXT UNIQUE, value BLOB)")
        
@hashcache(loader=sql_load, dumper=sql_dump)
def f(x):
    return x

f(2)    

running sql_load
running sql_dump
2

Once again, loading is easy.

f(2)

running sql_load
2

4. A redis version

Finally, you might like a server to cache in. This opens the door to running the server remotely, so the cache is accessible to multiple processes on different machines. We use redis for this example, but only run it locally. Make sure the server is running first, for example with redis-server --daemonize yes.

from pycse.hashcache import hashcache

import io, joblib, redis

db = redis.Redis(host='localhost', port=6379)

def redis_dump(hsh, data, verbose=False):
    print('running redis_dump')
    with io.BytesIO() as f:
        joblib.dump(data, f)
        value = f.getvalue()

    db.set(hsh, value)

def redis_load(hsh, verbose=False):
    print('running redis_load')
    value = db.get(hsh)  # None if the key does not exist; one round trip instead of two
    if value is None:
        return False, None
    else:
        return True, joblib.load(io.BytesIO(value))['output']

    
import functools

# Bind the redis loader and dumper once, so the decorator can be reused without repeating them
hashcache_redis = functools.partial(hashcache,
                                    loader=redis_load,
                                    dumper=redis_dump)

@hashcache_redis
def f(x):
    return x

f(2)    

running redis_load
running redis_dump
2

No surprise here, loading is the same as before.

f(2)

running redis_load
2
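functools.partial works for hashcache_redis because a decorator is just a callable applied to the function: @hashcache_redis on f evaluates hashcache(f, loader=redis_load, dumper=redis_dump). That requires hashcache to accept the function plus keyword options in a single call. Its use both bare (@hashcache) and configured (@hashcache(verbose=True)) in this post suggests it does, but that is an assumption about pycse's internals. A common way to write such a decorator is sketched below with a hypothetical tagged decorator.

```python
import functools


def tagged(func=None, *, tag='default'):
    """Hypothetical decorator usable bare (@tagged) or with options (@tagged(tag=...))."""
    if func is None:
        # Called with options only: hand back a decorator that still needs `func`.
        return functools.partial(tagged, tag=tag)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return tag, func(*args, **kwargs)
    return wrapper


# Pre-bind the options once, as hashcache_redis does with loader and dumper.
tagged_redis = functools.partial(tagged, tag='redis')


@tagged_redis            # same as tagged(f, tag='redis')
def f(x):
    return x


@tagged                  # bare use keeps the default option
def g(x):
    return x


print(f(2), g(2))   # ('redis', 2) ('default', 2)
```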

5. Summary

I have refactored hashcache to make it much easier to add new backends. You might do that for performance, for easier backup or transfer, or to add new search capabilities. The new code is a little cleaner than it was before, IMO. I am not sure it is API-stable yet, but it is getting there.

Copyright (C) 2023 by John Kitchin. See the License for information about copying.

org-mode source

Org-mode version = 9.7-pre

Discuss on Twitter

A better manager for supervising Python functions

| categories: programming | tags:

In the previous post I introduced a supervisor decorator to automate rerunning functions with new arguments to fix issues in them. Almost immediately after posting it, two things started bugging me. First, I thought it was annoying to have two separate arguments for results and exceptions. I would prefer one list of functions that do the right thing. Second, and most annoying, you have to be very careful in writing your checker functions to be consistent with how you called the function so you use exactly the same positional and keyword arguments. That is tedious and limits reusability/flexibility.

So, I wrote a new manager decorator that solves these two problems. Now, you can write checker functions that work on all the arguments of a function. You decorate the checker functions to indicate if they are for results or exceptions. This was a little more of a rabbit hole than I anticipated, but I persevered, and got to a solution that works for these examples. You can find all the code here.

Here is an example where we have a test function that we want to run with new arguments until we get a positive result. We start in a way that it is possible to get a ZeroDivisionError, and we handle that too.

from pycse.supyrvisor import manager, check_result, check_exception

@check_exception
def check1(args, exc):
    if isinstance(exc, ZeroDivisionError):
        print('ooo. caught 1/0, incrementing x')
        return {'x': 1}

@check_result
def check2(args, result):
    print(args)
    if result < 0:
        args['x'] += 1
        return args
        

@manager(checkers=[check1, check2])
def test(x, a=1):
    return a / x

test(-1)

{'x': -1, 'a': 1}
ooo. caught 1/0, incrementing x
{'x': 1}
1.0

This also works. That is the improvement over the previous version, which broke if you called the function with a different mix of positional and keyword arguments than the checker expected.

test(a=1, x=-1)

{'x': -1, 'a': 1}
ooo. caught 1/0, incrementing x
{'x': 1}
1.0
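One way to get this independence from how the function is called is to normalize every call with inspect.signature(...).bind(...) and apply_defaults(), so checkers always see a dict of all arguments by name, including defaults (which matches the {'x': -1, 'a': 1} printed above). The sketch below is a hypothetical toy_manager built on that idea. The mark_result/mark_exception tags and the retry logic are assumptions, not pycse's implementation, and calling func(**call_args) assumes the function has no positional-only or *args parameters. The example function is renamed divide.

```python
import functools
import inspect


def mark_result(fn):
    fn.checks = 'result'       # hypothetical tag; pycse's check_result may work differently
    return fn


def mark_exception(fn):
    fn.checks = 'exception'
    return fn


def toy_manager(checkers, max_errors=5):
    """Hypothetical sketch of a manager-style retry loop (not pycse's implementation)."""
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Normalize any mix of positional/keyword arguments into one dict by name.
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            call_args = dict(bound.arguments)
            n_fixes = 0
            while True:
                try:
                    result, exc = func(**call_args), None
                except Exception as e:
                    result, exc = None, e
                fix = None
                for checker in checkers:
                    kind = getattr(checker, 'checks', None)
                    if exc is not None and kind == 'exception':
                        fix = checker(dict(call_args), exc)
                    elif exc is None and kind == 'result':
                        fix = checker(dict(call_args), result)
                    if fix is not None:
                        break              # first checker with a proposal wins
                if fix is None:            # nothing to fix: return or re-raise
                    if exc is not None:
                        raise exc
                    return result
                n_fixes += 1
                if n_fixes > max_errors:
                    raise RuntimeError(f'gave up after {max_errors} fixes')
                call_args = fix            # rerun with the proposed arguments
        return wrapper
    return decorator


@mark_exception
def check1(args, exc):
    if isinstance(exc, ZeroDivisionError):
        return {'x': 1}


@mark_result
def check2(args, result):
    if result < 0:
        args['x'] += 1
        return args


@toy_manager(checkers=[check1, check2])
def divide(x, a=1):
    return a / x


print(divide(-1), divide(a=1, x=-1))   # 1.0 1.0
```

Because the checkers only ever receive the normalized dict, check2 can use args['x'] no matter how divide was called.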

1. The previous examples with manager

Here is the new syntax with manager.

import numpy as np
from scipy.optimize import minimize

def objective(x):
    return np.exp(x**2) - 10 * np.exp(x)


@check_result
def maxIterationsExceeded(args, sol):
    if sol.message == 'Maximum number of iterations has been exceeded.':
        args['maxiter'] *= 2
        return args

@manager(checkers=[maxIterationsExceeded], verbose=True)
def get_result(maxiter=2):
    return minimize(objective, 0.0, options={'maxiter': maxiter})

get_result(2)
Proposed fix in wrapper: {'maxiter': 4}
Proposed fix in wrapper: {'maxiter': 8}
  message: Optimization terminated successfully.
  success: True
   status: 0
      fun: -36.86307468296428
        x: [ 1.662e+00]
      nit: 5
      jac: [-4.768e-07]
 hess_inv: [[ 6.481e-03]]
     nfev: 26
     njev: 13

It works!

2. Stateful supervision

In this example, we aim to find the steady-state concentrations of two species by integrating a mass balance until it reaches steady state. This is easy to see visually in the plot below: the concentrations are essentially flat after 10 min or so. It is trickier to detect computationally, though. One way is to compare successive windows of integration and check whether the values are still changing. For instance, you could average the values from t = 10 to 11, compare that to the average from 11 to 12, and keep going until the two are close enough.

def ode(t, C):
    Ca, Cb = C
    dCadt = -0.2 * Ca + 0.3 * Cb
    dCbdt = -0.3 * Cb + 0.2 * Ca
    return dCadt, dCbdt

tspan = (0, 20)

from scipy.integrate import solve_ivp
sol = solve_ivp(ode, tspan, (1, 0))

import matplotlib.pyplot as plt
plt.plot(sol.t, sol.y.T)
plt.xlabel('t (min)')
plt.ylabel('C')
plt.legend(['A', 'B']);
sol.y.T[-1]
array([0.60003278, 0.39996722])
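The window-comparison idea can be sketched as a plain loop before any decorator is involved. To keep the sketch dependency-free, it swaps solve_ivp for a hand-written fixed-step fourth-order Runge-Kutta integrator (an assumption made only for illustration) and compares the sampled averages of successive windows with an absolute tolerance.

```python
def ode(t, C):
    """Same mass balance as above: A -> B at rate 0.2, B -> A at rate 0.3."""
    Ca, Cb = C
    return (-0.2 * Ca + 0.3 * Cb, -0.3 * Cb + 0.2 * Ca)


def rk4_window(C0, window=1.0, n_steps=50):
    """Integrate ode over [0, window] with fixed-step RK4; return every state visited."""
    h = window / n_steps
    t, y = 0.0, list(C0)
    states = [tuple(y)]
    for _ in range(n_steps):
        k1 = ode(t, y)
        k2 = ode(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k1)])
        k3 = ode(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k2)])
        k4 = ode(t + h, [yi + h * ki for yi, ki in zip(y, k3)])
        y = [yi + h / 6 * (a + 2 * b + 2 * c + d)
             for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]
        t += h
        states.append(tuple(y))
    return states


def window_mean(states):
    """Average each species over the sampled states in one window."""
    return [sum(col) / len(col) for col in zip(*states)]


def integrate_to_steady_state(C0, window=1.0, tol=1e-4, max_windows=100):
    """Restart from the end of each window until successive window means agree within tol."""
    states = rk4_window(C0, window)
    last_mean = window_mean(states)
    for count in range(1, max_windows + 1):
        states = rk4_window(states[-1], window)
        mean = window_mean(states)
        if all(abs(m - lm) <= tol for m, lm in zip(mean, last_mean)):
            return states[-1], count
        last_mean = mean
    raise RuntimeError(f'not steady after {max_windows} windows')


C_ss, n_windows = integrate_to_steady_state((1, 0), window=2, tol=1e-4)
print(C_ss, n_windows)
```

The result should land close to the analytical steady state Ca = 0.6, Cb = 0.4, where 0.2 Ca = 0.3 Cb and Ca + Cb = 1.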

It is not crucial to use a class here; global variables or function attributes would also work. A class is the standard way to encapsulate state, though. We just have to make the class callable (by defining __call__) so it acts like a function when we need it to.

class ReachedSteadyState:        
    def __init__(self, tolerance=0.01):
        self.tolerance = tolerance
        self.last_solution = None
        self.count = 0

    def __str__(self):
        return 'ReachedSteadyState'

    @check_result
    def __call__(self, args, sol):
        if self.last_solution is None:
            self.last_solution = sol
            self.count += 1
            args['C0'] = sol.y.T[-1]
            return args

        # we have a previous solution
        if not np.allclose(self.last_solution.y.mean(axis=1),
                           sol.y.mean(axis=1),
                           rtol=self.tolerance,
                           atol=self.tolerance):
            self.last_solution = sol
            self.count += 1
            args['C0'] = sol.y.T[-1]
            return args

rss = ReachedSteadyState(0.0001)

@manager(checkers=[rss], max_errors=20, verbose=True)        
def get_sol(C0=(1, 0), window=1):
    sol = solve_ivp(ode, t_span=(0, window), y0=C0)
    return sol

sol = get_sol((1, 0), window=2)
sol

Proposed fix in ReachedSteadyState: {'C0': array([0.74716948, 0.25283052]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.65414484, 0.34585516]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.61992776, 0.38007224]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.60733496, 0.39266504]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.60269957, 0.39730043]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.60099346, 0.39900654]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.60036557, 0.39963443]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.60013451, 0.39986549]), 'window': 2}
Proposed fix in ReachedSteadyState: {'C0': array([0.60004949, 0.39995051]), 'window': 2}

  message: The solver successfully reached the end of the integration interval.
  success: True
   status: 0
        t: [ 0.000e+00  7.179e-01  2.000e+00]
        y: [[ 6.000e-01  6.000e-01  6.000e-01]
            [ 4.000e-01  4.000e-01  4.000e-01]]
      sol: None
 t_events: None
 y_events: None
     nfev: 14
     njev: 0
      nlu: 0

We can plot the last two solutions to see how different they are.

import matplotlib.pyplot as plt
plt.plot(rss.last_solution.t, rss.last_solution.y.T, label=['A previous' ,'B previous'])
plt.plot(sol.t, sol.y.T, '--', label=['A current', 'B current'])
plt.legend()
plt.xlabel('relative t')
plt.ylabel('C');

Those look pretty similar on this graph.

3. Handling exceptions

Suppose you have a function that randomly fails. This could be because something does not converge with a randomly chosen initial guess, converges to an unphysical answer, etc. In these cases, it makes sense to simply try again with a new initial guess.

For this example, say we have this objective function with two minima. We will say that any solution above 0.5 is unphysical.

def f(x):
    return -(np.exp(-50 * (x - 0.25)**2) + 0.5 * np.exp(-100 * (x - 0.75)**2))


x = np.linspace(0, 1)
plt.plot(x, f(x))
plt.xlabel('x')
plt.ylabel('y');

Here we define a function that takes a guess, and gets a solution. If the solution is unphysical, we raise an exception. We define a custom exception so we can handle it specifically.

class UnphysicalSolution(Exception):
    pass

def get_minima(guess):
    sol = minimize(f, guess)

    if sol.x > 0.5:
        raise UnphysicalSolution
    else:
        return sol

@check_exception
def try_again(args, exc):
    if isinstance(exc, UnphysicalSolution):
        args['guess'] = np.random.random()
        return args
  
@manager(checkers=(try_again,), verbose=True)    
def get_minima(guess):
    sol = minimize(f, guess)

    if sol.x > 0.5:
        raise UnphysicalSolution
    else:
        return sol

get_minima(np.random.random())
Proposed fix in wrapper: {'guess': 0.03789731690063758}
  message: Optimization terminated successfully.
  success: True
   status: 0
      fun: -1.0000000000069411
        x: [ 2.500e-01]
      nit: 4
      jac: [ 0.000e+00]
 hess_inv: [[ 1.000e-02]]
     nfev: 18
     njev: 9

You can see it took one retry (two attempts in total) to find a physical solution this time. Other times it might take zero, one, or more retries; it depends on where the random guesses fall.

4. Summary

This solution works as well as supervisor did. It was a deeper rabbit hole than I expected, mostly because of some subtleties in making the result and exception decorators work for both functions and class methods. I think it is more robust now, since it should not matter how you call the function: any combination of args and kwargs should work.


Supervising Python functions

| categories: programming | tags:

[UPDATE [2023-09-21 Thu]]: See this new post for an updated and improved version of this post.

In the last post I talked about using custodian to supervise Python functions. I noted it felt a little heavy, so I wrote a new decorator that does basically the same thing. TL;DR I am not sure this is less heavy, but I learned some things doing it. The code I used is part of pycse at https://github.com/jkitchin/pycse/blob/master/pycse/supyrvisor.py. Check out the code to see how this works.

Here is the prototype problem it solves. This code runs, but does not succeed because it exceeds the maximum number of iterations.

import numpy as np
from scipy.optimize import minimize

def objective(x):
    return np.exp(x**2) - 10 * np.exp(x)

minimize(objective, 0.0, options={'maxiter': 2})
  message: Maximum number of iterations has been exceeded.
  success: False
   status: 1
      fun: -36.86289091418059
        x: [ 1.661e+00]
      nit: 2
      jac: [-2.374e-01]
 hess_inv: [[ 6.889e-03]]
     nfev: 20
     njev: 10

The solution is simple: increase the number of iterations. That is tedious to do manually, though, and not practical if you do it hundreds of times in a study. Enter pycse.supyrvisor, which provides a decorator to do this. As with custodian, we have to define a function whose arguments control the setting we want to change. We do that here. On its own, this function still does not succeed.

def get_result(maxiter=2):
    return minimize(objective, 0.0, options={'maxiter': maxiter})

get_result(2)
  message: Maximum number of iterations has been exceeded.
  success: False
   status: 1
      fun: -36.86289091418059
        x: [ 1.661e+00]
      nit: 2
      jac: [-2.374e-01]
 hess_inv: [[ 6.889e-03]]
     nfev: 20
     njev: 10

Next, we need a "checker" function. The role of this function is to check the output of the function, determine if it is ok, and if not, to return a new set of arguments to run the function with. There are some subtleties in this. You can call your function with a combination of args and kwargs, and you have to write this function in a way that is consistent with how you call the function. In the example above, we called get_result(2) where the 2 is a positional argument. In this checker function we write it with that in mind. If we detect that the minimizer failed because of exceeding the maximum number of iterations, we get the argument and double it. Then, we return the new args and kwargs. Otherwise this function returns None, indicating the solution is fine as far as this function is concerned.

def maxIterationsExceeeded(args, kwargs, sol):
    if sol.message == 'Maximum number of iterations has been exceeded.':
        maxiter = args[0]
        maxiter *= 2
        return (maxiter,), kwargs
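The checker above assumes get_result is always called with maxiter as the first positional argument. A sketch of a checker that tolerates the other calling styles is below. It is hypothetical, and it checks sol.status == 1 instead of the message string, assuming (as in the failed run shown earlier) that status 1 signals the iteration limit for the default BFGS method; the fallback for a call with no arguments hard-codes the default of 2.

```python
from types import SimpleNamespace


def max_iterations_exceeded(args, kwargs, sol):
    """Hypothetical checker for get_result(maxiter=2) that tolerates any calling style."""
    # status 1 accompanied the iteration-limit message in the failed run (assumed BFGS-specific).
    if sol.success or sol.status != 1:
        return None                                          # nothing to fix
    if 'maxiter' in kwargs:                                  # get_result(maxiter=2)
        return args, {**kwargs, 'maxiter': 2 * kwargs['maxiter']}
    if args:                                                 # get_result(2)
        return (2 * args[0],) + tuple(args[1:]), kwargs
    return args, {**kwargs, 'maxiter': 4}                    # get_result(): double the default 2


# Stand-in for a failed OptimizeResult, since only .success and .status are read.
failed = SimpleNamespace(success=False, status=1)
print(max_iterations_exceeded((2,), {}, failed))             # ((4,), {})
print(max_iterations_exceeded((), {'maxiter': 2}, failed))   # ((), {'maxiter': 4})
```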
        

Finally, we import the supervisor decorator and apply it to the function.

from pycse.supyrvisor import supervisor

get_result = supervisor(check_funcs=[maxIterationsExceeeded], verbose=True)(get_result)

get_result(2)
Proposed fix in maxIterationsExceeeded: ((4,), {})
Proposed fix in maxIterationsExceeeded: ((8,), {})
  message: Optimization terminated successfully.
  success: True
   status: 0
      fun: -36.86307468296428
        x: [ 1.662e+00]
      nit: 5
      jac: [-4.768e-07]
 hess_inv: [[ 6.481e-03]]
     nfev: 26
     njev: 13

It works!

1. Stateful supervision

In this example, we aim to find the steady-state concentrations of two species by integrating a mass balance until it reaches steady state. This is easy to see visually in the plot below: the concentrations are essentially flat after 10 min or so. It is trickier to detect computationally, though. One way is to compare successive windows of integration and check whether the values are still changing. For instance, you could average the values from t = 10 to 11, compare that to the average from 11 to 12, and keep going until the two are close enough.

def ode(t, C):
    Ca, Cb = C
    dCadt = -0.2 * Ca + 0.3 * Cb
    dCbdt = -0.3 * Cb + 0.2 * Ca
    return dCadt, dCbdt

tspan = (0, 20)

from scipy.integrate import solve_ivp
sol = solve_ivp(ode, tspan, (1, 0))

import matplotlib.pyplot as plt
plt.plot(sol.t, sol.y.T)
plt.xlabel('t (min)')
plt.ylabel('C')
plt.legend(['A', 'B']);
sol.y.T[-1]
array([0.60003278, 0.39996722])

The goal then is to have a supervisor function that keeps track of the last solution and the current one, and compares their averages. You could do something more sophisticated, but this is simple enough to try out now. If the difference between two integrations is small enough, we say we have reached steady state; if not, we integrate forward again from the end of the last solution. That means we have to store some state information so we can compare the current solution to the last one.

Let's start by defining a function that returns a solution from some initial condition. Next, we show that if you run it 12 or so times, each time starting from the last state, you get something that looks roughly steady, in the sense that the y values change only in the fourth decimal place across the window. You might consider that close enough to steady state.

def get_sol(C0=(1, 0), window=1):
    sol = solve_ivp(ode, t_span=(0, window), y0=C0)
    return sol

sol = get_sol()
sol = get_sol(sol.y.T[-1])
sol = get_sol(sol.y.T[-1])
sol = get_sol(sol.y.T[-1])
sol = get_sol(sol.y.T[-1])
sol = get_sol(sol.y.T[-1])
sol = get_sol(sol.y.T[-1])
sol = get_sol(sol.y.T[-1])
sol = get_sol(sol.y.T[-1])
sol = get_sol(sol.y.T[-1])
sol = get_sol(sol.y.T[-1])
sol = get_sol(sol.y.T[-1])
sol
  message: The solver successfully reached the end of the integration interval.
  success: True
   status: 0
        t: [ 0.000e+00  3.565e-01  1.000e+00]
        y: [[ 6.016e-01  6.014e-01  6.010e-01]
            [ 3.984e-01  3.986e-01  3.990e-01]]
      sol: None
 t_events: None
 y_events: None
     nfev: 14
     njev: 0
      nlu: 0

That is obviously tedious, so now we devise a supervisor function to do it for us. Since we will save state between calls, I will use a class here. We define a tolerance, and treat the system as steady when the averages of two sequential solutions differ by less than it. We have to be a little careful here. There are many ways to call get_sol. All of the following are valid, but each one passes different args and kwargs to the checker function.

get_sol()                  # no args:             args=(), kwargs={}
get_sol((1, 0), 2)         # all positional args: args=((1, 0), 2), kwargs={}
get_sol((1, 0))            # one positional arg:  args=((1, 0),), kwargs={}
get_sol((1, 0), window=2)  # positional + kwarg:  args=((1, 0),), kwargs={'window': 2}

We have to either assume one of these conventions or write a function that can handle all of them. I am going to assume here that args will always contain just the initial condition, and anything else will be in kwargs. That is a convention for this problem, and if you break it, you will get errors. For example, get_sol(C0=(1, 0)) will cause an error, because C0 is then passed as a keyword argument instead of the positional argument the checker expects.

It is not crucial to use a class here; global variables or function attributes would also work. A class is the standard way to encapsulate state, though. We just have to make the class callable (by defining __call__) so it acts like a function when we need it to.

class ReachedSteadyState:        
    def __init__(self, tolerance=0.01):
        self.tolerance = tolerance
        self.last_solution = None
        self.count = 0

    def __str__(self):
        return 'ReachedSteadyState'
        
    def __call__(self, args, kwargs, sol):
        if self.last_solution is None:
            self.last_solution = sol
            self.count += 1
            C0 = sol.y.T[-1]
            return (C0,), kwargs

        # we have a previous solution
        if not np.allclose(self.last_solution.y.mean(axis=1),
                           sol.y.mean(axis=1),
                           rtol=self.tolerance,
                           atol=self.tolerance):
            self.last_solution = sol
            self.count += 1
            C0 = sol.y.T[-1]            
            return (C0,), kwargs

Now, we decorate the get_sol function, and then run it. Since we used a bigger window, it only takes 9 iterations to get to an approximate steady state.

def get_sol(C0=(1, 0), window=1):
    sol = solve_ivp(ode, t_span=(0, window), y0=C0)
    return sol

rss = ReachedSteadyState(0.0001)
get_sol = supervisor(check_funcs=(rss,), verbose=True, max_errors=20)(get_sol)
sol = get_sol((1, 0), window=2)
sol

Proposed fix in ReachedSteadyState: ((array([0.74716948, 0.25283052]),), {'window': 2})
Proposed fix in ReachedSteadyState: ((array([0.65414484, 0.34585516]),), {'window': 2})
Proposed fix in ReachedSteadyState: ((array([0.61992776, 0.38007224]),), {'window': 2})
Proposed fix in ReachedSteadyState: ((array([0.60733496, 0.39266504]),), {'window': 2})
Proposed fix in ReachedSteadyState: ((array([0.60269957, 0.39730043]),), {'window': 2})
Proposed fix in ReachedSteadyState: ((array([0.60099346, 0.39900654]),), {'window': 2})
Proposed fix in ReachedSteadyState: ((array([0.60036557, 0.39963443]),), {'window': 2})
Proposed fix in ReachedSteadyState: ((array([0.60013451, 0.39986549]),), {'window': 2})
Proposed fix in ReachedSteadyState: ((array([0.60004949, 0.39995051]),), {'window': 2})

  message: The solver successfully reached the end of the integration interval.
  success: True
   status: 0
        t: [ 0.000e+00  7.179e-01  2.000e+00]
        y: [[ 6.000e-01  6.000e-01  6.000e-01]
            [ 4.000e-01  4.000e-01  4.000e-01]]
      sol: None
 t_events: None
 y_events: None
     nfev: 14
     njev: 0
      nlu: 0

We can plot the last two solutions to see how different they are.

import matplotlib.pyplot as plt
plt.plot(rss.last_solution.t, rss.last_solution.y.T, label=['A previous' ,'B previous'])
plt.plot(sol.t, sol.y.T, '--', label=['A current', 'B current'])
plt.legend()
plt.xlabel('relative t')
plt.ylabel('C');

Those look pretty similar on this graph.

2. Handling exceptions

Suppose you have a function that randomly fails. This could be because something does not converge with a randomly chosen initial guess, converges to an unphysical answer, etc. In these cases, it makes sense to simply try again with a new initial guess.

For this example, say we have this objective function with two minima. We will say that any solution above 0.5 is unphysical.

def f(x):
    return -(np.exp(-50 * (x - 0.25)**2) + 0.5 * np.exp(-100 * (x - 0.75)**2))


x = np.linspace(0, 1)
plt.plot(x, f(x))
plt.xlabel('x')
plt.ylabel('y');

Here we define a function that takes a guess, and gets a solution. If the solution is unphysical, we raise an exception. We define a custom exception so we can handle it specifically.

class UnphysicalSolution(Exception):
    pass

def get_minima(guess):
    sol = minimize(f, guess)

    if sol.x > 0.5:
        raise UnphysicalSolution
    else:
        return sol

Some initial guesses work fine.

get_minima(0.2)    
  message: Optimization terminated successfully.
  success: True
   status: 0
      fun: -1.0000000000069416
        x: [ 2.500e-01]
      nit: 4
      jac: [ 4.470e-08]
 hess_inv: [[ 1.000e-02]]
     nfev: 14
     njev: 7

But others don't.

get_minima(0.8)    

UnphysicalSolution                        Traceback (most recent call last)
Cell In[16], line 1
----> 1 get_minima(0.8)

Cell In[14], line 8, in get_minima(guess)
      5 sol = minimize(f, guess)
      7 if sol.x > 0.5:
----> 8     raise UnphysicalSolution
      9 else:
     10     return sol

UnphysicalSolution:

This is a case where we can simply rerun with a new random guess. The exception handler below does that.

def try_again(args, kwargs, exc):
    if isinstance(exc, UnphysicalSolution):
        args = (np.random.random(),)
        return args, kwargs
  
@supervisor(exception_funcs=(try_again,), verbose=True)    
def get_minima(guess):
    sol = minimize(f, guess)

    if sol.x > 0.5:
        raise UnphysicalSolution
    else:
        return sol

get_minima(np.random.random())
Proposed fix in try_again: ((0.7574152313004273,), {})
Proposed fix in try_again: ((0.39650554857922415,), {})
  message: Optimization terminated successfully.
  success: True
   status: 0
      fun: -1.0000000000069411
        x: [ 2.500e-01]
      nit: 3
      jac: [ 0.000e+00]
 hess_inv: [[ 1.000e-02]]
     nfev: 14
     njev: 7

You can see it took two retries to find a physical solution this time. Other times it might take zero, one, or more; it depends on where the random guesses fall.
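Because try_again draws from the global NumPy random state, each run tries a different sequence of guesses, so the number of retries varies from run to run. If you want reproducible retries, one option is to give the handler its own seeded generator. The sketch below uses the standard-library random.Random and a hypothetical stub in place of get_minima so it runs without SciPy; run_with_retries is a minimal stand-in for supervisor's retry loop, not its actual code.

```python
import random


class UnphysicalSolution(Exception):
    pass


def get_minimum_stub(guess):
    """Hypothetical stand-in for get_minima: guesses above 0.5 land in the unphysical basin."""
    if guess > 0.5:
        raise UnphysicalSolution(guess)
    return 0.25                          # location of the physical minimum


def make_try_again(seed):
    """Build a try_again-style handler that draws guesses from its own seeded RNG."""
    rng = random.Random(seed)            # private generator: same sequence every run

    def try_again(args, kwargs, exc):
        if isinstance(exc, UnphysicalSolution):
            return (rng.random(),), kwargs
    return try_again


def run_with_retries(func, handler, args, kwargs=None, max_errors=20):
    """Minimal retry loop in the spirit of supervisor's exception handling (not pycse's code)."""
    kwargs = {} if kwargs is None else kwargs
    guesses = []
    for _ in range(max_errors + 1):
        guesses.append(args[0])
        try:
            return func(*args, **kwargs), guesses
        except Exception as exc:
            fix = handler(args, kwargs, exc)
            if fix is None:
                raise                    # not a failure this handler knows how to fix
            args, kwargs = fix
    raise RuntimeError(f'no physical solution after {max_errors} retries')


x1, tried1 = run_with_retries(get_minimum_stub, make_try_again(42), (0.8,))
x2, tried2 = run_with_retries(get_minimum_stub, make_try_again(42), (0.8,))
print(x1, tried1 == tried2)   # 0.25 True
```

With the same seed, both runs try exactly the same guesses, which makes a failing study easier to debug.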

3. Summary

This solution works pretty well, similar to custodian. I think it is a little simpler than custodian, since you can do simple things with plain functions and don't need to make classes for everything. It probably does less than custodian, and there are probably some corner cases I haven't uncovered yet. It was a nice exercise in building a decorator, though, and in thinking through all the ways this can be done.


Using Custodian to help converge an optimization problem

| categories: optimization, programming | tags:

In high-throughput calculations, some fraction of them usually fail for some reason. Sometimes it is easy to fix these calculations and re-run them successfully, for example, you might just need a different initialization, or to increase memory or the number of allowed steps, etc. custodian is a tool that is designed for this purpose.

The idea is to write a function that does what we want and takes arguments that control how it runs. We also need a function that examines the output, determines whether it succeeded, and if not, says what new arguments to try next. Then we run the function in custodian and let it rerun with new arguments until it either succeeds or has tried too many times.
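The run, check, correct cycle described here can be sketched in a few lines. This is a hypothetical illustration of the control flow only, not Custodian's implementation; ToyJob and TooFewIterations mimic the Job and ErrorHandler roles described below, using a fake solver that converges only with at least 4 iterations.

```python
def run_with_handlers(job, handlers, max_errors=5):
    """Run job; while any handler's check() fires, apply its correct() and rerun."""
    n_errors = 0
    while True:
        job.run()
        fired = [h for h in handlers if h.check()]
        if not fired:
            return n_errors                  # success: no handler found a problem
        for h in fired:
            h.correct()                      # handlers edit the shared params in place
        n_errors += len(fired)
        if n_errors >= max_errors:
            raise RuntimeError(f'gave up after {n_errors} errors')


class ToyJob:
    """Hypothetical job: a fake solver that converges only with maxiter >= 4."""
    def __init__(self, params):
        self.params = params

    def run(self):
        self.params['converged'] = self.params['maxiter'] >= 4


class TooFewIterations:
    """Hypothetical handler: fires when the job did not converge, then doubles maxiter."""
    def __init__(self, params):
        self.params = params

    def check(self):
        return not self.params['converged']

    def correct(self):
        self.params['maxiter'] *= 2


params = {'maxiter': 1}
n_errors = run_with_handlers(ToyJob(params), [TooFewIterations(params)])
print(n_errors, params)   # 2 {'maxiter': 4, 'converged': True}
```

Starting from maxiter = 1, the handler fires twice (maxiter goes 1, then 2, then 4) before the fake solver converges, which mirrors the two reruns in the Custodian output later in this post.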

The goal here is to use custodian to fix a failing optimization. The example is a little contrived: we set the number of iterations artificially low so that the minimization fails by reaching the maximum number of iterations. Custodian will catch this and increase the number of iterations until it succeeds. Here is the objective function:

import matplotlib.pyplot as plt
import numpy as np

def objective(x):
    return np.exp(x**2) - 10*np.exp(x)

x = np.linspace(0, 2)
plt.plot(x, objective(x))
plt.xlabel('x')
plt.ylabel('y');

There is clearly a minimum near x = 1.66, but with a bad initial guess and too few iterations, an optimizer fails here. The message tells us it failed, and the fix is to run it again with more iterations.

from scipy.optimize import minimize

minimize(objective, 0.0, options={'maxiter': 2})
  message: Maximum number of iterations has been exceeded.
  success: False
   status: 1
      fun: -36.86289091418059
        x: [ 1.661e+00]
      nit: 2
      jac: [-2.374e-01]
 hess_inv: [[ 6.889e-03]]
     nfev: 20
     njev: 10

With Custodian you define a "Job". This is a class with params that contain the adjustable arguments in a dictionary, and a run method that stores the results in the params attribute. This is an important step, because the error handlers only get the params, so you need the results in there to inspect them.

The error handlers are another class, with a check method that returns True if the job should be rerun, and a correct method that sets the params to new values to try next. The correct method also appears to return a dictionary describing what happened. In our correct method, we double the maximum number of iterations allowed and use the last (failed) solution point as the initial guess for the next run.

from custodian.custodian import Custodian, Job, ErrorHandler

class Minimizer(Job):
    def __init__(self, params=None):
        self.params = params if params else {'maxiter': 2, 'x0': 0}
        
    def run(self):
        sol = minimize(objective,
                       self.params['x0'],
                       options={'maxiter': self.params['maxiter']})
        self.params['sol'] = sol

class MaximumIterationsExceeded(ErrorHandler):
    def __init__(self, params):
        self.params = params

    def check(self):
        return self.params['sol'].message == 'Maximum number of iterations has been exceeded.'

    def correct(self):
        self.params['maxiter'] *= 2
        self.params['x0'] = self.params['sol'].x
        return {'errors': 'MaximumIterations Exceeded',
                'actions': f'maxiter = {self.params["maxiter"]}, x0 = {self.params["x0"]}'}
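It can help to see roughly what Custodian does with these objects. The sketch below is a rough, hypothetical imitation written without Custodian: `ToyJob`, `ToyHandler`, and `run_like_custodian` are made-up names, and the real Custodian does considerably more. The point it shows is that the job and the handler hold the same params dictionary, so results the job stores are visible to the handler, and changes the handler makes are picked up by the next run.

```python
class ToyJob:
    """Hypothetical job that 'succeeds' only when maxiter is at least 3."""
    def __init__(self, params):
        self.params = params  # the shared dict itself, not a copy

    def run(self):
        self.params['success'] = self.params['maxiter'] >= 3


class ToyHandler:
    """Hypothetical handler that only ever looks at the shared params dict."""
    def __init__(self, params):
        self.params = params

    def check(self):
        return not self.params['success']

    def correct(self):
        self.params['maxiter'] *= 2
        return {'errors': ['MaximumIterationsExceeded'],
                'actions': [{'maxiter': self.params['maxiter']}]}


def run_like_custodian(handlers, jobs, max_errors=5):
    """Rough imitation of a rerun loop; not Custodian's actual code."""
    corrections = []
    for job in jobs:
        while True:
            job.run()
            failing = [h for h in handlers if h.check()]
            if not failing:
                break
            if len(corrections) >= max_errors:
                raise RuntimeError('max_errors reached')
            corrections.extend(h.correct() for h in failing)
    return corrections


params = {'maxiter': 1}
log = run_like_custodian([ToyHandler(params)], [ToyJob(params)])
print(params, len(log))  # {'maxiter': 4, 'success': True} 2
```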

Now we set up the initial params to try, create a Custodian object with the handler and the job, and run it. The results and final params end up in the params dictionary.

params = {'maxiter': 1, 'x0': 0}

c = Custodian([MaximumIterationsExceeded(params)],
              [Minimizer(params)],
              max_errors=5)

c.run()
for key in params:
    print(key, params[key])
MaximumIterationsExceeded
MaximumIterationsExceeded
maxiter 4
x0 [1.66250127]
sol   message: Optimization terminated successfully.
  success: True
   status: 0
      fun: -36.86307468296398
        x: [ 1.662e+00]
      nit: 1
      jac: [-9.060e-06]
 hess_inv: [[1]]
     nfev: 6
     njev: 3

Note that params is modified in place, and at the end it holds the maxiter value that worked along with the solution. The problem had to be rerun twice before it succeeded, but after the setup this happened automatically. This example is easy because we can simply increase maxiter; no serious logic is needed. Other use cases might include trying again with another solver, or with a different initial guess.
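Trying another solver or another initial guess amounts to a correct method that applies a different change to params on each failure. A minimal sketch, assuming a hypothetical `FallbackHandler` written without Custodian and a fake failed result built with types.SimpleNamespace. 'Nelder-Mead' is one of the method names scipy.optimize.minimize accepts, but the Minimizer job above would also need to pass `method=self.params['method']` to minimize for that change to have any effect.

```python
from types import SimpleNamespace


class FallbackHandler:
    """Hypothetical handler: on each failure, apply the next fallback update."""
    def __init__(self, params, fallbacks):
        self.params = params
        self.fallbacks = list(fallbacks)  # copy so the caller's list is untouched

    def check(self):
        # Rerun whenever the stored result reports failure.
        return not self.params['sol'].success

    def correct(self):
        if not self.fallbacks:
            raise RuntimeError('no fallback strategies left')
        update = self.fallbacks.pop(0)
        self.params.update(update)
        return {'errors': ['optimization failed'], 'actions': [update]}


params = {'x0': 0.0, 'method': 'BFGS',
          'sol': SimpleNamespace(success=False)}  # fake failed result
handler = FallbackHandler(params, [{'method': 'Nelder-Mead'}, {'x0': 1.5}])

for _ in range(2):  # stand-in for two failed reruns
    if handler.check():
        print(handler.correct())
print(params['method'], params['x0'])  # Nelder-Mead 1.5
```

Raising when the fallbacks run out makes a hopeless run stop loudly; Custodian's max_errors argument also bounds how many corrections are attempted.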

It feels a little heavyweight to define the classes and to store the results in params, but this was under an hour of work overall, starting from scratch with the example on the front page of the Custodian documentation. You can do more sophisticated things, including using multiple error handlers. Overall, for a package designed for molecular simulations, it worked well on a different kind of problem.

Copyright (C) 2023 by John Kitchin. See the License for information about copying.


Org-mode version = 9.7-pre


for-else loops

| categories: programming | tags:

I just learned of for/else loops (http://pyvideo.org/video/1780/transforming-code-into-beautiful-idiomatic-pytho). They are interesting enough to write about. A for loop can have an "else" clause that runs only if the loop completes without executing a break statement. The use case is to avoid using a flag. For example, say we want to loop through a sequence and determine whether a particular number is in it. Here is a typical way you might think to do it:

def f():
    flag = False
    for i in range(10):
        if i == 5:
            flag = True
            break

    return flag

print(f())
True

A for/else loop does this differently. The else clause runs if the loop completes; if the break executes, the else clause is skipped. In this example the break statement executes, so the else clause is skipped.

def f():
    for i in range(10):
        if i == 5:
            break
    else:
        return False

    return True

print(f())
True

In this example no break statement occurs, so the else clause is executed.

def f():
    for i in range(10):
        if i == 15:
            break
    else:
        return False

    return True

print(f())
False
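Two details of the rule are worth checking: the else clause also runs when the loop body never executes (an empty sequence), and continue does not prevent it. The hypothetical `first_negative` function below shows both. The same else clause is also allowed on while loops, with the same meaning.

```python
def first_negative(values):
    """Return the first negative value in values, or None if there is none."""
    for v in values:
        if v >= 0:
            continue  # continue only moves to the next item; it does not cancel the else
        result = v
        break
    else:
        # Runs whenever the loop ends without break, including for an empty list.
        result = None
    return result


print(first_negative([3, -2, -7]))  # -2
print(first_negative([1, 2, 3]))    # None
print(first_negative([]))           # None
```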

It is hard to say whether this is an improvement over the flag. Both versions use the same number of lines of code, and I find it debatable whether the meaning of the else clause is intuitive. It might be more useful when there are several places the loop can break.
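Here is a sketch of the multiple-break case, using a hypothetical `scan` function over made-up readings. Each break records its own reason for stopping, and the else clause handles the single no-break case, so no separate flag is needed.

```python
def scan(readings, limit):
    """Hypothetical check: stop early on missing data or a reading above limit."""
    for r in readings:
        if r is None:
            status = 'missing data'
            break
        if r > limit:
            status = 'limit exceeded'
            break
    else:
        # Only reached if neither break fired (also for an empty list).
        status = 'all readings ok'
    return status


print(scan([1, 2, 3], 10))      # all readings ok
print(scan([1, None, 50], 10))  # missing data
print(scan([1, 50, None], 10))  # limit exceeded
```

With a flag, each break would also need to set the flag, and the code after the loop would have to test it; here that bookkeeping disappears.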

Needless to say, go watch http://pyvideo.org/video/1780/transforming-code-into-beautiful-idiomatic-pytho. You will learn a lot of interesting things!

Copyright (C) 2013 by John Kitchin. See the License for information about copying.

