# Scalable Mechanistic Neural Networks (S-MNN)

The S-MNN solver is implemented in `solver/ode_forward.py`. The interface is the `ode_forward` function.

Other files in the `solver` folder are the original [MNN](https://github.com/alpz/mech-nn) solvers (*dense* and *sparse* versions).

In the `package` folder, we provide a Python package for our S-MNN solver built with the PyTorch C++ API.

All other files are for the experiments.

## Building Environments

1. Create a conda environment from the `env.yml` file (`conda env create -f env.yml`).

2. _[Optional]_ Build the C++ package of the solver

```
cd package && ./install.sh
```

## Standalone Validation

Run

```
PYTHONPATH=. python odebench/solve_and_plot.py
```

## Lorenz Discovery

The discovery alternates between optimization and thresholding of the basis weights.
The optimization runs for 400 epochs.
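
The optimize-then-threshold cycle can be illustrated with a toy example. The sketch below is not the code in `discovery/lorenz.py`: it replaces the S-MNN model with a plain linear regression, and the function name `fit_with_thresholding`, the number of cycles, the learning rate, and the threshold value are all assumptions chosen for illustration. It only shows the general pattern of training for a fixed number of epochs and then dropping basis terms with small weights.

```python
import random

def fit_with_thresholding(features, targets, cycles=3, epochs=400, lr=0.1, threshold=0.1):
    """Alternate gradient-descent optimization with hard thresholding of the weights.

    Hypothetical toy helper; all default values are illustrative assumptions.
    """
    n = len(features)
    n_features = len(features[0])
    weights = [0.0] * n_features
    active = [True] * n_features  # mask of basis terms still in the model
    for _ in range(cycles):
        # Optimization phase: gradient descent on the mean squared error.
        for _ in range(epochs):
            grads = [0.0] * n_features
            for x, y in zip(features, targets):
                err = sum(w * xi for w, xi in zip(weights, x)) - y
                for j in range(n_features):
                    grads[j] += 2.0 * err * x[j] / n
            for j in range(n_features):
                if active[j]:  # dropped terms stay at zero
                    weights[j] -= lr * grads[j]
        # Thresholding phase: permanently drop basis terms with small weights.
        for j in range(n_features):
            if abs(weights[j]) < threshold:
                active[j] = False
                weights[j] = 0.0
    return weights, active

# Toy data: the target depends on the first and third basis terms only.
rng = random.Random(0)
features = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(50)]
targets = [2.0 * x[0] - 3.0 * x[2] for x in features]
weights, active = fit_with_thresholding(features, targets)
print(weights, active)
```

In this toy setting the unused middle term is removed after the first thresholding step, and the remaining weights are refined in the later cycles.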

**Select solver**

S-MNN: `solver=LEAST_SQUARES`

MNN *dense*: `solver=DENSE_CHOLESKY`

MNN *sparse*: `solver=SPARSE_INDIRECT_BLOCK_CG`

**Set log directory**

For example:

```
log_dir="./logs_test"
```

**Run experiment**

```
PYTHONPATH=. python discovery/lorenz.py --log-dir=$log_dir --solver=$solver
```
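
To compare all three solvers, the run command above can be repeated once per solver. The following is a minimal launcher sketch, not part of the repository: the helper names `build_command` and `run_all` and the choice of one log subdirectory per solver are assumptions, while the script path, flags, and solver names are the ones listed above. By default it only prints the commands.

```python
import os
import subprocess

# Solver names as listed in this README.
SOLVERS = ["LEAST_SQUARES", "DENSE_CHOLESKY", "SPARSE_INDIRECT_BLOCK_CG"]

def build_command(solver, log_root="./logs_test"):
    """Return the argv list for one Lorenz discovery run (hypothetical helper)."""
    log_dir = f"{log_root}/{solver.lower()}"  # assumed layout: one log directory per solver
    return ["python", "discovery/lorenz.py", f"--log-dir={log_dir}", f"--solver={solver}"]

def run_all(dry_run=True):
    """Run the command for every solver, or only print it when dry_run is True."""
    env = dict(os.environ, PYTHONPATH=".")  # same effect as the PYTHONPATH=. prefix
    for solver in SOLVERS:
        cmd = build_command(solver)
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, env=env, check=True)

if __name__ == "__main__":
    run_all(dry_run=True)
```

Call `run_all(dry_run=False)` from the repository root to actually launch the runs one after another.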


## KdV (PDE)

**Generate dataset**

Generate the 1D KdV *easy* dataset using the scripts at https://github.com/brandstetter-johannes/LPSDA

Modify the data location in `pde/kdv.py` if needed.

**Select solver and set log directory**

Same as in _Lorenz Discovery_.

**Run experiment**

```
PYTHONPATH=. python pde/kdv.py --log-dir=$log_dir --solver=$solver
```

After training, evaluate the model with the `evaluate` and `get_losses` functions in `pde/kdv.py`.

Modify the location of the PyTorch checkpoint if needed.

## Sea Surface Temperature (SST)

**Provided as a separate project, `marrying`.**

Please copy the `solver` folder from this project to the `src` folder of `marrying` and run the experiments under `marrying`.

**Download the dataset**

Download the [SST-V2 dataset](https://psl.noaa.gov/repository/entry/show/PSL+Climate+Data+Repository/Public/PSL+Datasets/NOAA+OI+SST/Weekly+and+Monthly/sst.wkmean.1990-present.nc?entryid=12159560-ab82-48a1-b3e4-88ace20475cd&output=data.cdl)

Modify the data location in `configs/data/sst.yaml` if needed.

**Train the model**

Run

```
python src/train.py experiment=train_sst
```

**Evaluate the model**

Run

```
python src/eval.py +experiment=eval_sst
```
