* Kolmogorov-Arnold Networks (KANs) and Lennard Jones
:PROPERTIES:
:date: 2024/05/05 11:06:22
:updated: 2024/05/05 11:06:22
:org-url: https://kitchingroup.cheme.cmu.edu/org/2024/05/05/Kolmogorov-Arnold-Networks-(KANs)-and-Lennard-Jones.org
:permalink: https://kitchingroup.cheme.cmu.edu/blog/2024/05/05/Kolmogorov-Arnold-Networks-(KANs)-and-Lennard-Jones/index.html
:END:
KANs have been a hot topic of discussion recently (https://arxiv.org/abs/2404.19756). Here I explore using them as an alternative to a conventional multilayer perceptron (MLP) for a simple atomistic potential, using Lennard-Jones data. I adapted this code from https://github.com/KindXiaoming/pykan/blob/master/hellokan.ipynb.
TL;DR: It was easy to make the model, and it fit this simple data very well. It does not extrapolate well in this example, and it is not obvious what the extrapolation behavior should be.
** Create a dataset
We use the ~create_dataset~ function from pykan to generate the dataset here. I chose a range that includes the minimum of the potential and some modest nonlinearity.
#+BEGIN_SRC jupyter-python
import matplotlib.pyplot as plt
import torch
from kan import create_dataset, KAN


def LJ(r):
    """Reduced Lennard-Jones energy, 1/r^12 - 1/r^6."""
    r6 = r**6
    return 1 / r6**2 - 1 / r6


# One input variable (r), sampled on [0.95, 2.0], with 50 training points.
dataset = create_dataset(LJ, n_var=1, ranges=[0.95, 2.0],
                         train_num=50)

plt.plot(dataset['train_input'], dataset['train_label'], 'b.')
plt.xlabel('r')
plt.ylabel('E');
#+END_SRC
#+RESULTS:
:RESULTS:
[[file:./.ob-jupyter/0db7627856ef3cacbeb19cba9e64a53fb49bf422.png]]
:END:
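As a quick sanity check on the range chosen above: the reduced Lennard-Jones form used here, \(E(r) = 1/r^{12} - 1/r^{6}\), has its minimum where \(dE/dr = 0\), which gives \(r^6 = 2\) and \(E = -1/4\). The minimal sketch below confirms this in plain Python and checks that the minimum falls inside the sampled range. The helper names ~lj~ and ~dlj_dr~ are mine for illustration and are not part of pykan.

#+BEGIN_SRC jupyter-python
import math


def lj(r):
    """Reduced Lennard-Jones energy, 1/r^12 - 1/r^6 (same form as LJ above)."""
    r6 = r**6
    return 1 / r6**2 - 1 / r6


def dlj_dr(r):
    """Analytical derivative dE/dr = -12/r^13 + 6/r^7."""
    return -12 / r**13 + 6 / r**7


# Setting dE/dr = 0 gives r^6 = 2.
r_min = 2 ** (1 / 6)
e_min = lj(r_min)

print(f"r_min = {r_min:.4f}, E_min = {e_min:.4f}")
print(f"dE/dr at r_min = {dlj_dr(r_min):.2e}")
print(f"inside sampled range: {0.95 <= r_min <= 2.0}")
#+END_SRC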
** Create and train the model
We start by making the model. It has one input, the distance between two atoms, and one output, the Lennard-Jones energy. The hidden layer has a width of 2 "neurons".
#+BEGIN_SRC jupyter-python
model = KAN(width=[1, 2, 1])
#+END_SRC
#+RESULTS:
Training is easy. You can even run this cell several times to keep training the same model.
#+BEGIN_SRC jupyter-python :async yes
model.train(dataset, opt="LBFGS", steps=20);
model.plot()
#+END_SRC
#+RESULTS:
:RESULTS:
: train loss: 1.64e-04 | test loss: 1.46e-02 | reg: 6.72e+00 : 100%|██| 20/20 [00:03<00:00, 5.61it/s]
:
[[file:./.ob-jupyter/0cea2b134045cc964f990ac28b524c32d441976b.png]]
:END:
The fit looks very good across the training range.
#+BEGIN_SRC jupyter-python
X = torch.linspace(dataset['train_input'].min(),
                   dataset['train_input'].max(), 100)[:, None]
plt.plot(dataset['train_input'], dataset['train_label'], 'b.', label='data')
plt.plot(X, model(X).detach().numpy(), 'r-', label='fit')
plt.legend()
plt.xlabel('r')
plt.ylabel('E');
#+END_SRC
#+RESULTS:
:RESULTS:
[[file:./.ob-jupyter/24eddff0ce69063a1aaabc80060e78b56ecef0b5.png]]
:END:
KANs do not save us from extrapolation issues, though. I think a downside of KANs is that it is not obvious what extrapolation behavior to expect. I suspect it is related to how the spline representation of the learned functions behaves outside the range it was fit on. Eventually those have to extrapolate too.
#+BEGIN_SRC jupyter-python
# Evaluate the model well outside the training range [0.95, 2.0].
X = torch.linspace(0, 5, 1000)[:, None]
plt.plot(dataset['train_input'], dataset['train_label'], 'b.')
plt.plot(X, model(X).detach().numpy(), 'r-');
#+END_SRC
#+RESULTS:
:RESULTS:
[[file:./.ob-jupyter/a16818596b6a60ea026406808143fcddcfae54f9.png]]
:END:
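One piece of the spline picture can be reasoned about. In the KAN paper, each learned 1D function is a base function (SiLU in the paper) plus a B-spline expansion on a finite grid. B-spline basis functions are zero outside their knot span, so far enough outside the grid the spline part contributes nothing and whatever remains has to come from the base function. The sketch below illustrates only that compact-support property with a plain-Python Cox-de Boor recursion. It is not pykan's implementation, and the knots, coefficients, and the names ~bspline_basis~ and ~spline~ are made up for illustration; pykan's actual grid handling may differ.

#+BEGIN_SRC jupyter-python
def bspline_basis(i, k, knots, x):
    """Cox-de Boor recursion: i-th B-spline basis function of degree k at x."""
    if k == 0:
        return 1.0 if knots[i] <= x < knots[i + 1] else 0.0
    left = right = 0.0
    d1 = knots[i + k] - knots[i]
    if d1 > 0:
        left = (x - knots[i]) / d1 * bspline_basis(i, k - 1, knots, x)
    d2 = knots[i + k + 1] - knots[i + 1]
    if d2 > 0:
        right = (knots[i + k + 1] - x) / d2 * bspline_basis(i + 1, k - 1, knots, x)
    return left + right


def spline(x, coefs, knots, k=3):
    """Weighted sum of B-spline basis functions."""
    return sum(c * bspline_basis(i, k, knots, x) for i, c in enumerate(coefs))


k = 3
# Hypothetical uniform knots spanning roughly the training range used above.
knots = [0.95 + 0.15 * j for j in range(8)]
# Arbitrary illustrative coefficients, one per cubic basis function.
coefs = [1.0, -0.5, 0.3, 0.1]
assert len(coefs) == len(knots) - k - 1

for x in [0.5, 1.2, 1.5, 1.9, 3.0, 5.0]:
    print(f"x = {x:3.1f}  spline(x) = {spline(x, coefs, knots, k): .4f}")
#+END_SRC

Inside the knot span the spline takes nonzero values; at 0.5, 3.0, and 5.0 it is identically zero. If this picture carries over, the extrapolation seen above would be controlled mostly by the base function and the outer layers rather than by the fitted spline coefficients.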
It is early days for KANs, so many things we know about MLPs are still unknown for KANs. For example, we know that MLPs extrapolate like their activation functions: a ReLU network becomes linear far from the data, and a tanh network saturates to a constant. There is probably a similar insight to be had here, but it still needs to be uncovered. There are also many ways to regularize MLPs toward a desired behavior. That is probably true here too, and will be discovered. Similarly, the many approaches people have taken to uncertainty quantification in MLPs probably have some analog in KANs.
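To make the MLP statement above concrete, here is a minimal sketch of a one-hidden-layer MLP evaluated far from the origin with two different activations. The weights are arbitrary, hypothetical values that were not fitted to anything, and the helper names ~mlp~ and ~relu~ are mine. With tanh the output levels off at a constant, and with ReLU it grows with a constant slope.

#+BEGIN_SRC jupyter-python
import math


def mlp(x, activation, w1, b1, w2, b2):
    """One-hidden-layer MLP with scalar input and scalar output."""
    hidden = [activation(w * x + b) for w, b in zip(w1, b1)]
    return sum(v * h for v, h in zip(w2, hidden)) + b2


def relu(z):
    return max(0.0, z)


# Arbitrary, hypothetical weights; not fitted to any data.
w1, b1 = [1.0, -2.0, 0.5], [-1.0, 3.0, 0.2]
w2, b2 = [0.7, -0.4, 1.1], 0.1

for x in [10.0, 20.0, 30.0, 40.0]:
    print(f"x = {x:4.1f}"
          f"  tanh: {mlp(x, math.tanh, w1, b1, w2, b2): .6f}"
          f"  relu: {mlp(x, relu, w1, b1, w2, b2): .4f}")

# Far to the right, each tanh unit saturates to the sign of its input weight.
tanh_limit = sum(v * math.copysign(1.0, w) for v, w in zip(w2, w1)) + b2
# Far to the right, only ReLU units with positive input weight stay active.
relu_slope = sum(v * w for v, w in zip(w2, w1) if w > 0)
print(f"tanh limit as x -> +inf: {tanh_limit:.4f}")
print(f"relu slope as x -> +inf: {relu_slope:.4f}")
#+END_SRC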
Still, the ease of use suggests KANs could be promising for some applications.