import torch
import torch.nn as nn
from einops import rearrange, einsum
import numpy as np


class Swish(nn.Module):
    """Swish activation: multiplies the input element-wise by its sigmoid."""

    def __init__(self):
        super().__init__()
        # Held as a submodule so it shows up in the module tree / repr.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x * sigmoid(x)`` with the same shape as ``x``."""
        gate = self.sigmoid(x)
        return x * gate


class DeepONet2d(nn.Module):
    """DeepONet-style operator network evaluated on a fixed 2-D grid.

    A "function" MLP encodes the flattened input field into a
    ``hidden_dim`` vector, a "location" MLP encodes every grid coordinate
    into a ``hidden_dim`` vector, and the prediction at a pixel is the inner
    product of the two plus a learnable bias.

    All output channels share the same inner product; they differ only
    through their own slice of ``param_bias``.

    Args:
        in_channels: Number of channels of the input field.
        out_channels: Number of channels of the predicted field.
        hidden_dim: Width of both MLPs and of the inner-product space.
        x_size: Grid size along the first spatial axis.
        y_size: Grid size along the second spatial axis.
        act_type: One of ``"swish"``, ``"relu"`` or ``"tanh"``.
        n_layers: Stored on the module only.
            NOTE: both MLPs are currently fixed at four ``nn.Linear``
            layers regardless of this value.
        x_span: Upper end of the coordinate range along the first axis
            (coordinates run over ``[0, x_span]``).
        y_span: Upper end of the coordinate range along the second axis
            (coordinates run over ``[0, y_span]``).

    Raises:
        NotImplementedError: If ``act_type`` is not a supported name.
    """

    def __init__(self, in_channels, out_channels, hidden_dim, x_size, y_size,
                 act_type="swish", n_layers=4, x_span=1, y_span=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_dim = hidden_dim
        self.x_size = x_size
        self.y_size = y_size
        self.n_layers = n_layers
        self.x_span = x_span
        self.y_span = y_span

        if act_type == "swish":
            self.act = Swish()
        elif act_type == "relu":
            self.act = nn.ReLU()
        elif act_type == "tanh":
            self.act = nn.Tanh()
        else:
            raise NotImplementedError(
                f"unsupported act_type: {act_type!r}")

        # Creation order (function net, location net, bias) and the
        # Sequential indices are kept stable so seeds and state_dict keys
        # stay compatible.
        self.func_linear_layers = self._make_mlp(
            self.in_channels * self.x_size * self.y_size)
        self.loc_linear_layers = self._make_mlp(2)

        # One bias per (channel, x, y) output value, stored flat in
        # channel-major order.
        self.param_bias = nn.Parameter(
            torch.randn((self.x_size * self.y_size * self.out_channels, )),
            requires_grad=True)

    def _make_mlp(self, in_features):
        """Build the 4-layer MLP mapping ``in_features`` to ``hidden_dim``."""
        return nn.Sequential(
            nn.Linear(in_features, self.hidden_dim),
            self.act,
            nn.Linear(self.hidden_dim, self.hidden_dim),
            self.act,
            nn.Linear(self.hidden_dim, self.hidden_dim),
            self.act,
            nn.Linear(self.hidden_dim, self.hidden_dim))

    def forward(self, x):  # (B, in_channels, x_size, y_size)
        """Predict the output field for a batch of input fields.

        Args:
            x: Tensor of shape ``(B, in_channels, x_size, y_size)``.

        Returns:
            Tensor of shape ``(B, out_channels, x_size, y_size)``.
        """
        batch = x.shape[0]
        func_out = self.func_linear_layers(x.flatten(start_dim=1))  # (B, hidden_dim)

        # The coordinates are identical for every sample and every output
        # channel, so the location net is evaluated once on a single grid.
        grid = self.get_grid((1, x.shape[1], x.shape[2], x.shape[3]), x.device)
        loc_inp = grid.reshape(-1, 2).to(self.param_bias.dtype)  # (x_size*y_size, 2)
        loc_out = self.loc_linear_layers(loc_inp)  # (x_size*y_size, hidden_dim)

        # Inner product over hidden_dim for every (sample, pixel) pair.
        field = func_out @ loc_out.t()  # (B, x_size*y_size)
        field = field.reshape(batch, 1, self.x_size, self.y_size)

        # Channel-major bias layout keeps pixel (h, w) of every channel
        # paired with the coordinate (h, w) that produced it.
        bias = self.param_bias.reshape(self.out_channels, self.x_size, self.y_size)
        return field + bias  # (B, out_channels, x_size, y_size)

    def get_grid(self, shape, device):
        """Return the coordinate grid for an input of the given shape.

        Args:
            shape: Indexable ``(B, C, x_size, y_size)``; only entries 0, 2
                and 3 are read.
            device: Device on which the grid is created.

        Returns:
            Float tensor of shape ``(B, x_size, y_size, 2)`` whose last axis
            holds ``(x, y)`` with ``x`` in ``[0, x_span]`` and ``y`` in
            ``[0, y_span]``.
        """
        batchsize, x_size, y_size = shape[0], shape[2], shape[3]
        gridx = torch.tensor(np.linspace(0, self.x_span, x_size),
                             dtype=torch.float, device=device)
        gridx = gridx.reshape(1, x_size, 1, 1).expand(batchsize, x_size, y_size, 1)
        gridy = torch.tensor(np.linspace(0, self.y_span, y_size),
                             dtype=torch.float, device=device)
        gridy = gridy.reshape(1, 1, y_size, 1).expand(batchsize, x_size, y_size, 1)
        # cat materialises a fresh tensor, so callers never see expanded views.
        return torch.cat((gridx, gridy), dim=-1)