# Operator Learning with Domain Decomposition for Geometry Generalization in PDE Solving

This repository is the official implementation of [Operator Learning with Domain Decomposition for Geometry Generalization in PDE Solving](foo). 

![image info](./assets/framework.png)

## Requirements

Two conda environments are required: `fem`, used for data generation with [FEniCS](https://fenicsproject.org/), and `pde`, used for training the neural operator and performing Schwarz Neural Inference (SNI).

```setup
conda env create --name fem --file=environment_fem.yml
conda env create --name pde --file=environment_pde.yml
```

## Data generation

### Training data and validation data

We use the following scripts in the `fem` environment to generate training and validation data:

- `training_data_generation_laplace2d.py`
- `training_data_generation_laplace2d_n.py`
- `training_data_generation_darcy2d.py`
- `training_data_generation_heat2d.py`

For example, to generate training data for Laplace2d-Dirichlet, run the following command. It generates 200 random polygons with 10 boundary conditions for each polygon:

```
python training_data_generation_laplace2d.py --count 200
```
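The other training-data scripts can be run the same way. A minimal Python driver for all four is sketched below. It assumes, without verification, that every script accepts `--count` as `training_data_generation_laplace2d.py` does, and that it is launched from the repository root with the `fem` environment active. `build_commands` is a hypothetical helper, not part of the repository.

```python
"""Hypothetical driver: generate training data for all four problems.

Assumption: every training_data_generation_*.py script accepts --count,
as training_data_generation_laplace2d.py does. Run inside the `fem` env.
"""
import subprocess
import sys

# Script names as listed in this README.
TRAINING_SCRIPTS = [
    "training_data_generation_laplace2d.py",
    "training_data_generation_laplace2d_n.py",
    "training_data_generation_darcy2d.py",
    "training_data_generation_heat2d.py",
]


def build_commands(count: int) -> list[list[str]]:
    """Return one argv list per script (--count support is assumed)."""
    # sys.executable is the Python of the currently active environment.
    return [[sys.executable, script, "--count", str(count)] for script in TRAINING_SCRIPTS]


if __name__ == "__main__":
    for argv in build_commands(200):
        print("running:", " ".join(argv))
        subprocess.run(argv, check=True)  # stop at the first failing script
```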

### Test data

We use the following scripts in the `fem` environment to generate test data:

- `test_data_generation_laplace2d.py`
- `test_data_generation_laplace2d_n.py`
- `test_data_generation_darcy2d.py`
- `test_data_generation_heat2d.py`

For example, to generate test data for Laplace2d-Dirichlet on domain A, run the following command:

```
python test_data_generation_laplace2d.py --domain-type A
```

## Training

We use `train.py` in the `pde` environment to train the neural operator. For example, run the following command to train the model on Laplace2d-Dirichlet:

```train
python train.py --gpu 0 --dataset laplace2d_simple \
    --normalize_y none --normalize_x none --component all \
    --comment rel2 --loss-name rel2 --epochs 500 --batch-size 16 \
    --model-name CGPT --optimizer AdamW --weight-decay 0.00005 \
    --lr 0.001 --lr-method cycle --grad-clip 1000.0 \
    --n-hidden 128 --n-layers 3 --use-tb 1
```
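The `--loss-name rel2` flag selects the training loss. We assume here that `rel2` denotes the relative L2 error commonly used in neural-operator work. The pure-Python sketch below illustrates that definition only and is not the loss code in `train.py`.

```python
import math
from collections.abc import Sequence


def rel_l2(preds: Sequence[Sequence[float]],
           targets: Sequence[Sequence[float]],
           eps: float = 1e-12) -> float:
    """Batch mean of ||pred - target||_2 / ||target||_2.

    Illustrative only: `rel2` is *assumed* to mean this relative L2 error.
    Each sample is a flat sequence of nodal values.
    """
    if len(preds) != len(targets) or not preds:
        raise ValueError("preds and targets must be non-empty and equally long")
    ratios = []
    for pred, target in zip(preds, targets):
        if len(pred) != len(target):
            raise ValueError("sample shapes differ")
        err = math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)))
        ref = math.sqrt(sum(t ** 2 for t in target))
        ratios.append(err / (ref + eps))  # eps guards against an all-zero target
    return sum(ratios) / len(ratios)


if __name__ == "__main__":
    target = [[3.0, 4.0]]  # ||target|| = 5
    pred = [[3.0, 4.5]]    # ||pred - target|| = 0.5
    print(rel_l2(pred, target))  # approximately 0.1
```

Normalizing by the target norm makes the error comparable across samples whose solutions differ in magnitude, for example under different boundary conditions.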

## Evaluation

We use the following scripts in the `pde` environment to perform inference on the test datasets:

- `inference_laplace2d.py`
- `inference_laplace2d_n.py`
- `inference_darcy2d.py`
- `inference_heat2d.py`

For example, to perform inference for Laplace2d-Dirichlet on domain A, run the following command:

```eval
python inference_laplace2d.py --gpu 0 --dataset laplace2d_schwarz \
    --normalize_y none --normalize_x none --component all \
    --comment rel2 --loss-name rel2 \
    --model-name CGPT --n-hidden 128 --n-layers 3 --use-tb 1 \
    --n-parts 20 --depth 2 --tau 0.04 --epochs 5000 \
    --model-path data/chkpt/laplace2d_simple_all_CGPTrel2_0502_15_02_06.pt
```
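SNI combines the trained operator with Schwarz domain decomposition. As a conceptual illustration only, the sketch below runs the classical alternating (multiplicative) Schwarz method for the 1D Laplace equation `-u'' = 0` on `[0, 1]`, with an exact linear local solve standing in for the neural operator. It is not the repository's SNI implementation. The flags `--n-parts`, `--depth`, and `--tau` are not modeled, and `local_solve` and `alternating_schwarz` are hypothetical names.

```python
from collections.abc import Sequence


def local_solve(u: list[float], lo: int, hi: int) -> None:
    """Exact solve of u'' = 0 on grid indices lo..hi of a uniform grid.

    Uses u[lo] and u[hi] as Dirichlet data and overwrites the interior with
    the linear interpolant (a stand-in for a learned local solver).
    """
    a, b = u[lo], u[hi]
    n = hi - lo
    for k in range(1, n):
        u[lo + k] = a + (b - a) * k / n


def alternating_schwarz(n_points: int,
                        subdomains: Sequence[tuple[int, int]],
                        left: float, right: float,
                        tol: float = 1e-10,
                        max_iter: int = 1000) -> tuple[list[float], int]:
    """Multiplicative Schwarz for -u'' = 0 on [0, 1], u(0)=left, u(1)=right.

    `subdomains` are overlapping index ranges (lo, hi) covering
    0..n_points-1; each sweep solves them in order using the latest values.
    """
    u = [0.0] * n_points
    u[0], u[-1] = left, right  # global Dirichlet boundary data
    for it in range(1, max_iter + 1):
        previous = list(u)
        for lo, hi in subdomains:
            local_solve(u, lo, hi)  # interface values come from the neighbours
        change = max(abs(x - y) for x, y in zip(u, previous))
        if change < tol:
            return u, it
    return u, max_iter


if __name__ == "__main__":
    # 11 grid points on [0, 1]; subdomains [0, 0.6] and [0.4, 1] overlap on [0.4, 0.6].
    u, iters = alternating_schwarz(11, [(0, 6), (4, 10)], left=0.0, right=1.0)
    print(iters, max(abs(ui - i / 10) for i, ui in enumerate(u)))
```

With these two subdomains, the interface error shrinks by a constant factor per sweep, so the iterate converges to the exact solution `u(x) = x`. A smaller overlap makes that factor closer to 1 and slows convergence.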

## Results

Our model achieves the following performance:

![image info](./assets/main_result.png)

 

