import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np





def conv1d(in_planes, out_planes, stride=1, bias=True, kernel_size=5, padding=2, dialation=1) :
    """Build an ``nn.Conv1d`` layer with this file's default kernel/padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        bias: whether the layer has a learnable bias.
        kernel_size: convolution kernel size (default 5).
        padding: zero padding on each side (default 2, which preserves length
            for kernel_size=5, stride=1, dilation=1).
        dialation: convolution dilation (parameter name kept as spelled for
            backward compatibility). Previously accepted but ignored.

    Returns:
        The configured ``nn.Conv1d`` module.
    """
    return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dialation, bias=bias)

def conv2d(in_planes, out_planes, stride=1, bias=True, kernel_size=5, padding=2, dialation=1) :
    """Build an ``nn.Conv2d`` layer with this file's default kernel/padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        bias: whether the layer has a learnable bias.
        kernel_size: convolution kernel size (default 5).
        padding: zero padding on each side (default 2, which preserves spatial
            size for kernel_size=5, stride=1, dilation=1).
        dialation: convolution dilation (parameter name kept as spelled for
            backward compatibility). Previously accepted but ignored.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dialation, bias=bias)


class NetA(nn.Module) :
    """1-D convolutional stack followed by a single dense output layer.

    Pipeline: conv1 -> SiLU -> (conv -> SiLU) x ``blocks`` -> convH (no
    activation) -> flatten -> Linear -> reshape to (batch, 1, d_out).

    The dense layer expects ``filters * d_out`` flattened features, or
    ``filters * (d_out - 1)`` when ``is_bdrylyaer`` is set. With the default
    kernel_size=7 / padding=3 (stride 1) the convolutions preserve sequence
    length, so the input length presumably needs to be d_out (or d_out - 1);
    confirm against callers.
    """
    def __init__(self, d_in, filters, d_out, kernel_size=7, padding=3, blocks=0, is_bdrylyaer=False) :
        super().__init__()
        self.d_in, self.blocks, self.filters, self.d_out = d_in, blocks, filters, d_out
        self.kern, self.pad = kernel_size, padding
        # A single SiLU instance is reused everywhere (it has no parameters).
        self.swish = nn.SiLU()
        # Construction order (conv1, hidden convs, convH, fcH) determines the
        # random initialisation sequence, so it is kept stable.
        self.conv1 = conv1d(d_in, filters, kernel_size=kernel_size, padding=padding)
        hidden = []
        repeats = range(blocks) if blocks != 0 else range(0)
        for _ in repeats:
            hidden.extend((conv1d(filters, filters, kernel_size=kernel_size, padding=padding), self.swish))
        self.conv_list = nn.Sequential(*hidden)
        self.convH = conv1d(filters, filters, kernel_size=kernel_size, padding=padding)
        seq_len = d_out - 1 if is_bdrylyaer else d_out
        self.fcH = nn.Linear(filters * seq_len, d_out, bias=True)

    def forward(self, x):
        """Map ``x`` (presumably (batch, d_in, length)) to (batch, 1, d_out)."""
        h = self.swish(self.conv1(x))
        if self.blocks != 0:
            h = self.conv_list(h)
        h = self.convH(h).flatten(start_dim=1)
        h = self.fcH(h)
        return h.view(h.shape[0], 1, self.d_out)
    


class Net2D(nn.Module) : # Linear
    """2-D convolutional stack followed by a single dense output layer.

    Pipeline: conv1 -> SiLU -> (conv -> SiLU) x ``blocks`` -> convH (no
    activation) -> flatten -> Linear -> reshape to (batch, 1, d_out).

    The dense layer expects ``filters * resol_in**2`` flattened features. With
    the default kernel_size=7 / padding=3 (stride 1) the convolutions preserve
    spatial size, so inputs are presumably (batch, d_in, resol_in, resol_in);
    confirm against callers.
    """
    def __init__(self, resol_in, d_in, filters, d_out, kernel_size=7, padding=3, blocks=0) :
        super().__init__()
        self.resol_in = resol_in
        self.d_in, self.blocks, self.filters, self.d_out = d_in, blocks, filters, d_out
        # A single SiLU instance is reused everywhere (it has no parameters).
        self.swish = nn.SiLU()
        self.kern, self.pad = kernel_size, padding
        # Construction order (conv1, hidden convs, convH, fcH) determines the
        # random initialisation sequence, so it is kept stable.
        self.conv1 = conv2d(d_in, filters, kernel_size=kernel_size, padding=padding)
        hidden = []
        repeats = range(blocks) if blocks != 0 else range(0)
        for _ in repeats:
            hidden.extend((conv2d(filters, filters, kernel_size=kernel_size, padding=padding), self.swish))
        self.conv_list = nn.Sequential(*hidden)
        self.convH = conv2d(filters, filters, kernel_size=kernel_size, padding=padding)
        flat_features = filters * resol_in ** 2
        self.fcH = nn.Linear(flat_features, d_out, bias=True)

    def forward(self, x):
        """Map ``x`` to a tensor of shape (batch, 1, d_out)."""
        h = self.swish(self.conv1(x))
        if self.blocks != 0:
            h = self.conv_list(h)
        h = self.convH(h).flatten(start_dim=1)
        h = self.fcH(h)
        return h.view(h.shape[0], 1, self.d_out)

class PINN(nn.Module):
    """Fully connected network mapping query coordinates to a single output.

    Architecture: Linear(output_domain_dim, width_trunk) -> act, then
    ``depth_trunk`` hidden Linear(width_trunk, width_trunk) -> act layers,
    then Linear(width_trunk, 1).

    Note that one activation instance is shared by every layer. For 'prelu'
    this means a single learnable slope is shared across all layers.

    Args:
        depth_trunk: number of hidden width->width layers.
        width_trunk: hidden layer width (also stored as ``n_basis``).
        act: activation name, one of 'tanh', 'prelu', 'relu'.
        output_domain_dim: number of input features per query point.
        output_dim: stored as ``self.output_dim``; the network itself always
            ends in a single output unit.

    Raises:
        ValueError: if ``act`` is not a supported activation name.
    """
    _ACTIVATIONS = {'tanh': nn.Tanh, 'prelu': nn.PReLU, 'relu': nn.ReLU}

    def __init__(self, depth_trunk, width_trunk, act, output_domain_dim, output_dim):
        super(PINN, self).__init__()
        self.output_dim = output_dim
        self.n_basis = width_trunk

        # Fail fast: previously an unknown name only printed a message and the
        # constructor then died with an unrelated AttributeError.
        if act not in self._ACTIVATIONS:
            raise ValueError(
                f"unknown activation {act!r}; expected one of {sorted(self._ACTIVATIONS)}")
        self.activation = self._ACTIVATIONS[act]()

        ## trunk net
        layers = [nn.Linear(output_domain_dim, width_trunk), self.activation]
        for _ in range(depth_trunk):
            layers += [nn.Linear(width_trunk, width_trunk), self.activation]
        layers.append(nn.Linear(width_trunk, 1))
        self.trunk_list = nn.Sequential(*layers)

    def forward(self, data_grid):
        """Apply the MLP to the last dimension of ``data_grid`` (size output_domain_dim)."""
        return self.trunk_list(data_grid)

class DeepONet(nn.Module):
    """DeepONet: combine a branch net (sensor values) with a trunk net (query points).

    The branch net maps each sensor vector to ``output_dim * n_basis``
    coefficients. The trunk net maps each query point to ``output_dim * n_basis``
    basis values. The output is their inner product over the basis axis, plus
    an optional per-output-channel bias.

    Each net is Linear -> act, (depth - 1) x [Linear -> act], then a final
    Linear to ``output_dim * width``. One activation instance is shared by all
    layers of both nets.

    Args:
        depth_trunk, width_trunk: trunk depth / hidden width.
        depth_branch, width_branch: branch depth / hidden width; the widths
            must match because they define the shared basis size.
        act: activation name, one of 'tanh', 'prelu', 'relu'.
        num_sensor: number of input features of the branch net.
        output_domain_dim: number of input features of the trunk net.
        output_dim: number of output channels.

    Raises:
        ValueError: on an unknown activation name or mismatched widths.
    """
    _ACTIVATIONS = {'tanh': nn.Tanh, 'prelu': nn.PReLU, 'relu': nn.ReLU}

    def __init__(self, depth_trunk, width_trunk, depth_branch, width_branch, act, num_sensor, output_domain_dim, output_dim):
        super(DeepONet, self).__init__()
        self.use_bias = True
        self.output_dim = output_dim
        self.n_basis = width_trunk
        self.num_sensor = num_sensor
        if self.use_bias:
            # One learnable offset per output channel.
            self.b = torch.nn.Parameter(torch.zeros(self.output_dim))

        # Both checks used to only print, leaving a model that failed later.
        if act not in self._ACTIVATIONS:
            raise ValueError(
                f"unknown activation {act!r}; expected one of {sorted(self._ACTIVATIONS)}")
        self.activation = self._ACTIVATIONS[act]()
        if width_trunk != width_branch:
            raise ValueError(
                f"width_trunk ({width_trunk}) and width_branch ({width_branch}) must be equal")

        ## trunk net
        trunk = [nn.Linear(output_domain_dim, width_trunk), self.activation]
        for _ in range(depth_trunk - 1):
            trunk += [nn.Linear(width_trunk, width_trunk), self.activation]
        trunk.append(nn.Linear(width_trunk, self.output_dim * width_trunk))
        self.trunk_list = nn.Sequential(*trunk)

        ## branch net
        branch = [nn.Linear(self.num_sensor, width_branch), self.activation]
        for _ in range(depth_branch - 1):
            branch += [nn.Linear(width_branch, width_branch), self.activation]
        branch.append(nn.Linear(width_branch, self.output_dim * width_branch))
        self.branch_list = nn.Sequential(*branch)

    def forward(self, data_grid, data_sensor):
        """Evaluate the operator for every (sensor sample, query point) pair.

        Args:
            data_grid: query points; presumably (B_grid, output_domain_dim).
            data_sensor: sensor samples; presumably (B_sensor, num_sensor).

        Returns:
            Tensor of shape (B_sensor, output_dim, B_grid).
        """
        B_sensor = data_sensor.shape[0]
        B_grid = data_grid.shape[0]
        coeff = self.branch_list(data_sensor).reshape(B_sensor, self.output_dim, self.n_basis)
        # Keep the grid index outermost, matching the trunk output's memory
        # layout. Reshaping straight to (output_dim, B_grid, n_basis) mixed
        # values from different query points whenever output_dim > 1.
        basis = self.trunk_list(data_grid).reshape(B_grid, self.output_dim, self.n_basis)
        # einsum broadcasts over sensor/grid batches; no need to .repeat copies.
        y = torch.einsum("sok,gok->sog", coeff, basis)
        if self.use_bias:
            # Align the bias with the output-channel axis (dim 1), not B_grid.
            y = y + self.b.view(1, -1, 1)
        return y