'''
A modified convolution-based 2D DeepONet implementation.

Built from PINO:
https://github.com/devzhk/PINO
https://arxiv.org/abs/2111.03794

See also:
https://github.com/Zhengyu-Huang/Operator-Learning/blob/main/nn/mynn.py

See also the official implementation:
https://github.com/lululxvi/deepxde/blob/master/deepxde/nn/pytorch/deeponet.py
'''

import matplotlib.pyplot as plt
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial


class BaseNet(nn.Module):
    """Sequential stack of linear or 3x3 convolutional layers.

    The stack is ``layer -> [norm] -> activation`` for every layer except
    the last one, which is followed only by ``out_activation`` (if given).

    Args:
        layers: sequence of widths ``[in, hidden..., out]``; consecutive
            entries give the in/out size of each layer, so at least two
            entries are required.
        activation: ``'tanh'``, ``'relu'``, or any object usable as a
            module; the same instance is reused after every hidden layer.
        out_activation: optional module appended after the last layer.
        layer_type: ``'linear'``/``'dense'`` for ``nn.Linear`` or
            ``'conv2d'`` for ``nn.Conv2d`` with ``kernel_size=3, padding=1``.
        normalize: if true, insert ``nn.BatchNorm2d`` (conv2d) or
            ``nn.LayerNorm`` (otherwise) before each hidden activation.

    Raises:
        ValueError: if ``activation`` is an unsupported string or
            ``layer_type`` is not one of the supported values.
    """

    def __init__(self, layers, 
                       activation='relu', 
                       out_activation=None,
                       layer_type='linear', # or conv2d
                       normalize=False):
        super(BaseNet, self).__init__()

        self.n_layers = len(layers) - 1
        assert self.n_layers >= 1, \
            'layers must hold at least an input and an output width'
        if isinstance(activation, str):
            if activation == 'tanh':
                self.activation = nn.Tanh()
            elif activation == 'relu':
                self.activation = nn.ReLU()
            else:
                raise ValueError(f'{activation} is not supported')
        else:
            self.activation = activation
        self.layers = nn.ModuleList()
        if layer_type == 'linear' or layer_type == 'dense':
            module = nn.Linear
        elif layer_type == "conv2d":
            module = partial(nn.Conv2d,
                             kernel_size=3,
                             padding=1)
        else:
            # Without this branch `module` would be unbound below and the
            # constructor would fail with an unrelated UnboundLocalError.
            raise ValueError(f'{layer_type} is not supported')

        for j in range(self.n_layers):
            self.layers.append(module(layers[j], layers[j+1]))

            # Hidden layers get (optional) normalization plus activation;
            # the final layer is left linear.
            if j != self.n_layers - 1:
                if normalize:
                    norm = nn.BatchNorm2d if layer_type == 'conv2d' else nn.LayerNorm
                    self.layers.append(norm(layers[j+1]))

                self.layers.append(self.activation)

        if out_activation is not None:
            self.layers.append(out_activation)

    def forward(self, x):
        """Apply every module in ``self.layers`` to ``x`` in order."""
        for layer in self.layers:
            x = layer(x)

        return x

class DeepONet2d(nn.Module):
    """DeepONet on a 2d grid with a dense or convolutional branch net.

    The branch net encodes the input function (optionally concatenated with
    two gradient channels); the trunk net is a dense ``BaseNet`` applied to
    the 2d grid coordinates. Both outputs are contracted over the feature
    axis, so ``branch_dim`` and ``trunk_dim`` must be equal for ``forward``
    to succeed.

    Args:
        input_size: ``(height, width)`` the branch input is resampled to
            (not a channel size). Required for the dense branch, whose
            input width is derived from it; if falsy with ``'conv2d'`` no
            resampling is done.
        branch_dim: width of every branch layer.
        trunk_dim: width of every trunk layer.
        n_layers: number of layers in both the branch and the trunk.
        activation: activation passed to both ``BaseNet`` instances.
        layer_type: ``'linear'``/``'dense'`` or ``'conv2d'`` for the branch.
        normalize: forwarded to the branch ``BaseNet``.
        add_grad_channel: if true the branch expects 3 input channels
            (the input plus a 2-channel ``gradx``) instead of 1.

    Raises:
        ValueError: if a dense branch is requested without ``input_size``.
    """

    def __init__(self, 
            input_size, # not channel size
            branch_dim,
            trunk_dim,
            n_layers,
            activation='tanh',
            layer_type='linear', # or conv2d
            normalize=False,
            add_grad_channel=False):
        super(DeepONet2d, self).__init__()
        
        in_dim = 1+2*add_grad_channel
        if layer_type == 'linear' or layer_type == 'dense':
            if not input_size:
                raise ValueError(
                    'input_size is required when layer_type is linear/dense')
            # The dense branch consumes the flattened, resampled image.
            in_dim *= (input_size[0]*input_size[1])
        

        branch_layer = [in_dim] + [branch_dim for _ in range(n_layers)]
        trunk_layer = [2] + [trunk_dim for _ in range(n_layers)]
        self.branch = BaseNet(branch_layer, activation, 
                              layer_type=layer_type, 
                              normalize=normalize)
        self.trunk = BaseNet(trunk_layer, activation)
        self.input_size = input_size
        self.add_grad_channel = add_grad_channel
        self.layer_type = layer_type

    def forward(self, x, gradx, grid=None, **inputs):
        """
        input: (bsz, n, n, 1) 
        # only single 1 channel for branch
        gradx: (bsz, n, n, 2) or None; concatenated to the input as extra
        channels only when add_grad_channel is set
        grid: (bsz, n, n, 2); required, only grid[0] is used

        returns dict(preds=tensor of shape (bsz, n, n, 1))
        """
        if grid is None:
            raise ValueError('grid is required to evaluate the trunk net')

        bsz, *mesh_size, _ = x.size()
        if self.add_grad_channel and gradx is not None:
            x = torch.cat([x, gradx], dim=-1)

        is_dense = self.layer_type == 'linear' or self.layer_type == 'dense'

        # Channels-first is needed by both F.interpolate and nn.Conv2d, so
        # the permute must not depend on whether resampling is enabled.
        x = x.permute(0, 3, 1, 2)
        if self.input_size:
            x = F.interpolate(x, size=self.input_size,
                              mode='bilinear',
                              align_corners=True)

        if is_dense:
            # reshape (not view): the permuted/interpolated tensor is not
            # guaranteed to be contiguous in (bsz, C, H, W) order.
            x = x.reshape(bsz, -1)

        a = self.branch(x)
        
        grid = grid[0] # same grid for every sample
        b = self.trunk(grid).permute(2, 0, 1).unsqueeze(0)
        # The conv branch lives on the input_size grid, so the trunk has to
        # be resampled to match it. The dense branch output has no spatial
        # axes, so its trunk stays on the original mesh and the product is
        # already at the output resolution.
        if self.input_size and not is_dense:
            b = F.interpolate(b, size=self.input_size,
                                    mode='bilinear',
                                    align_corners=True)
        # Drop only the dummy batch axis; a bare squeeze() would also
        # remove a width or mesh axis of size 1.
        b = b.squeeze(0)
        # (width, *meshsize)

        if is_dense:
            x = torch.einsum('bi,ixy->bxy', a, b)
        elif self.layer_type == "conv2d":
            x = torch.einsum('bixy,ixy->bxy', a, b)
            if self.input_size:
                # Resample the prediction back onto the caller's mesh.
                x = F.interpolate(x.unsqueeze(1), size=mesh_size,
                                    mode='bilinear',
                                    align_corners=True).squeeze(1)

        x = x.reshape(bsz, *mesh_size, 1)

        return dict(preds=x)



if __name__ == "__main__":
    # Smoke test: build a conv2d DeepONet2d and print its layer summary for
    # a batch whose mesh (201 x 201) differs from the model's input_size.
    from torchinfo import summary

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = {
        'input_size': (128, 128),  # not channel size
        'branch_dim': 384,
        'trunk_dim': 384,
        'n_layers': 8,
        'add_grad_channel': True,
        'layer_type': "conv2d",
    }
    model = DeepONet2d(**config).to(device)

    batch_size = 10
    n_grid = 201
    # Positional inputs of forward: x (1 channel), gradx (2), grid (2).
    mesh_shape = (batch_size, n_grid, n_grid)
    input_shapes = [mesh_shape + (n_channels,) for n_channels in (1, 2, 2)]
    summary(model, input_size=input_shapes, device=device)
