import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import random
from torch.optim import LBFGS, Adam
from tqdm import tqdm
import scipy.io

from .util import *
from params import *
from torch.func import vmap, jacrev, hessian, jacfwd

class WaveAct(nn.Module):
    """Learnable sinusoidal activation: ``w1 * sin(x) + w2 * cos(x)``.

    ``w1`` and ``w2`` are single-element trainable parameters, both
    initialised to one.
    """

    def __init__(self):
        super().__init__()
        # nn.Parameter is trainable by default; registration order is w1, w2.
        self.w1 = nn.Parameter(torch.ones(1))
        self.w2 = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Apply the activation element-wise to ``x``."""
        sine_term = self.w1 * torch.sin(x)
        cosine_term = self.w2 * torch.cos(x)
        return sine_term + cosine_term

class FeedForward(nn.Module):
    """Position-wise MLP: ``d_model -> d_ff -> d_ff -> d_model``.

    A ``WaveAct`` follows each of the first two linear layers; the last
    linear layer has no activation.

    Args:
        d_model: Size of the input and output feature dimension.
        d_ff: Width of the two hidden layers.
    """

    def __init__(self, d_model, d_ff=32):
        super().__init__()
        # Built in this exact order so the Sequential indices (0..4) and the
        # order of random initialisation stay the same.
        stages = [
            nn.Linear(d_model, d_ff),
            WaveAct(),
            nn.Linear(d_ff, d_ff),
            WaveAct(),
            nn.Linear(d_ff, d_model),
        ]
        self.linear = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP."""
        return self.linear(x)


class EncoderLayer(nn.Module):
    """Self-attention block with ``WaveAct`` pre-activations and residuals.

    Computes ``x <- x + attn(a, a, a)`` with ``a = act1(x)``, followed by
    ``x <- x + ff(act2(x))``.

    Args:
        d_model: Embedding size used by the attention and feed-forward parts.
        heads: Number of attention heads.
    """

    def __init__(self, d_model, heads):
        super().__init__()

        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=heads,
            batch_first=True,
        )
        self.ff = FeedForward(d_model)
        self.act1 = WaveAct()
        self.act2 = WaveAct()

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual steps."""
        activated = self.act1(x)
        # MultiheadAttention returns (output, weights); only the output is used.
        attended, _ = self.attn(activated, activated, activated)
        x = x + attended

        return x + self.ff(self.act2(x))


class DecoderLayer(nn.Module):
    """Cross-attention block with ``WaveAct`` pre-activations and residuals.

    The query is ``act1(x)``; keys and values are both ``e_outputs``. The
    attention result is added to ``x``, then ``ff(act2(x))`` is added on top.

    Args:
        d_model: Embedding size used by the attention and feed-forward parts.
        heads: Number of attention heads.
    """

    def __init__(self, d_model, heads):
        super().__init__()

        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=heads,
            batch_first=True,
        )
        self.ff = FeedForward(d_model)
        self.act1 = WaveAct()
        self.act2 = WaveAct()

    def forward(self, x, e_outputs):
        """Return ``x`` after attending to ``e_outputs`` and the feed-forward step."""
        query = self.act1(x)
        # MultiheadAttention returns (output, weights); only the output is used.
        attended, _ = self.attn(query, e_outputs, e_outputs)
        x = x + attended

        return x + self.ff(self.act2(x))


class Encoder(nn.Module):
    """Stack of ``N`` encoder layers followed by a final ``WaveAct``.

    Args:
        d_model: Embedding size passed to every ``EncoderLayer``.
        N: Number of layers applied in ``forward``.
        heads: Number of attention heads per layer.
    """

    def __init__(self, d_model, N, heads):
        super().__init__()
        self.N = N
        # get_clones is not defined in this file; presumably it returns an
        # indexable container of N copies of the prototype -- TODO confirm.
        prototype = EncoderLayer(d_model, heads)
        self.layers = get_clones(prototype, N)
        self.act = WaveAct()

    def forward(self, x):
        """Apply layers ``0 .. N-1`` in order, then the output activation."""
        hidden = x
        for depth in range(self.N):
            layer = self.layers[depth]
            hidden = layer(hidden)
        return self.act(hidden)


class Decoder(nn.Module):
    """Stack of ``N`` decoder layers followed by a final ``WaveAct``.

    Args:
        d_model: Embedding size passed to every ``DecoderLayer``.
        N: Number of layers applied in ``forward``.
        heads: Number of attention heads per layer.
    """

    def __init__(self, d_model, N, heads):
        super().__init__()
        self.N = N
        # get_clones is not defined in this file; presumably it returns an
        # indexable container of N copies of the prototype -- TODO confirm.
        prototype = DecoderLayer(d_model, heads)
        self.layers = get_clones(prototype, N)
        self.act = WaveAct()

    def forward(self, x, e_outputs):
        """Apply layers ``0 .. N-1`` in order, each attending to ``e_outputs``."""
        hidden = x
        for depth in range(self.N):
            layer = self.layers[depth]
            hidden = layer(hidden, e_outputs)
        return self.act(hidden)



class PINNsformer(nn.Module):
    """Encoder-decoder transformer used as a physics-informed network.

    The residual methods index the per-sample Jacobian as
    ``J[:, output, input]`` and implement the acoustic system

        p_t + K (u_x + v_y) = 0
        u_t + p_x / rho     = 0
        v_t + p_y / rho     = 0

    which fixes the input columns as ``(t, x, y)`` and the output columns as
    ``(p, u, v)``. For reference: sound speed ``c = sqrt(K / rho)``,
    impedance ``Z = rho * c``, hence ``rho = Z / c`` and ``K = Z * c``.

    ``bulk`` (used as K) and ``rho`` are module-level names that are not
    defined in this file; presumably they come from ``from params import *``
    -- TODO confirm.
    """

    def __init__(self, d_out=3, d_model=32, d_hidden=32, N=1, heads=2,
                 pres_weight=1, velx_weight=1, vely_weight=1, bc_weight=1, ic_weight=1,
                 *args, **kwargs):
        """Build the network and store the loss weights.

        Args:
            d_out: Number of output features of the final linear layer.
            d_model: Embedding size of the encoder/decoder.
            d_hidden: Width of the hidden layers in the output head.
            N: Number of layers in both the encoder and the decoder.
            heads: Number of attention heads per layer.
            pres_weight: Weight of the pressure-equation loss in ``loss_fn``.
            velx_weight: Weight of the x-velocity-equation loss in ``loss_fn``.
            vely_weight: Weight of the y-velocity-equation loss in ``loss_fn``.
            bc_weight: Weight of the boundary-condition loss in ``loss_fn``.
            ic_weight: Weight of the initial-condition loss in ``loss_fn``.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.pres_weight = pres_weight
        self.velx_weight = velx_weight
        self.vely_weight = vely_weight
        self.bc_weight = bc_weight
        self.ic_weight = ic_weight
        self.loss_container = nn.MSELoss()

        # The input dimension is fixed at 3 (see the reshape in ``forward``).
        self.linear_emb = nn.Linear(3, d_model)

        self.encoder = Encoder(d_model, N, heads)
        self.decoder = Decoder(d_model, N, heads)

        self.linear_out = nn.Sequential(
            nn.Linear(d_model, d_hidden),
            WaveAct(),
            nn.Linear(d_hidden, d_hidden),
            WaveAct(),
            nn.Linear(d_hidden, d_out),
        )
        total_params = sum(p.numel() for p in self.parameters())
        print(f"Total number of parameters: {total_params}")

    def forward(self, x):
        """Map coordinates to the predicted fields.

        ``x`` is flattened to rows of 3 values before embedding. The result
        has its leading dimension removed only when that dimension is 1.
        """
        # NOTE(review): the embedded rows reach the attention layers as one
        # 2-D tensor. Presumably nn.MultiheadAttention then treats the rows as
        # a single unbatched sequence, so points in a batch attend to each
        # other, while the vmap'd Jacobian below evaluates each point alone --
        # confirm this difference is intended.
        src = self.linear_emb(x.reshape((-1, 3)))
        e_outputs = self.encoder(src)
        d_output = self.decoder(src, e_outputs)
        output = self.linear_out(d_output)
        # squeeze(0) turns the (1, d_out) result for a single point (as seen
        # inside vmap) into (d_out,); larger batches pass through unchanged.
        return output.squeeze(0)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _data_loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the MSE between ``forward(x)`` and the target ``y``."""
        return self.loss_container(self.forward(x), y)

    def _residual_loss(self, residual: torch.Tensor) -> torch.Tensor:
        """Return the MSE of ``residual`` against an all-zero target."""
        return self.loss_container(residual, torch.zeros_like(residual))

    def _jacobian(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-sample Jacobian of ``forward``, indexed ``[:, output, input]``."""
        return vmap(jacrev(self.forward))(x)

    @staticmethod
    def _pres_residual(jac: torch.Tensor) -> torch.Tensor:
        """Pressure equation residual: ``p_t + K (u_x + v_y)``."""
        return jac[:, 0, 0] + bulk * (jac[:, 1, 1] + jac[:, 2, 2])

    @staticmethod
    def _velx_residual(jac: torch.Tensor) -> torch.Tensor:
        """x-velocity equation residual: ``u_t + p_x / rho``."""
        return jac[:, 1, 0] + jac[:, 0, 1] / rho

    @staticmethod
    def _vely_residual(jac: torch.Tensor) -> torch.Tensor:
        """y-velocity equation residual: ``v_t + p_y / rho``."""
        return jac[:, 2, 0] + jac[:, 0, 2] / rho

    def _pde_residuals(self, x: torch.Tensor):
        """Return ``(pres, velx, vely)`` residuals from a single Jacobian.

        Building the Jacobian is the expensive step, so callers that need all
        three residuals go through here instead of computing it three times.
        """
        jac = self._jacobian(x)
        return (
            self._pres_residual(jac),
            self._velx_residual(jac),
            self._vely_residual(jac),
        )

    # ------------------------------------------------------------------
    # Initial condition
    # ------------------------------------------------------------------
    def calc_ic_loss(self,
                     x_ic:torch.Tensor, y_ic:torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the MSE between the prediction at ``x_ic`` and ``y_ic``."""
        return self._data_loss(x_ic, y_ic)

    def calc_ic(self, x_ic:torch.Tensor, y_ic:torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the prediction at ``x_ic`` (``y_ic`` is accepted but unused)."""
        return self.forward(x_ic)

    def eval_ic_loss(self,
                     x_ic:torch.Tensor, y_ic:torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the same value as ``calc_ic_loss``."""
        return self._data_loss(x_ic, y_ic)

    # ------------------------------------------------------------------
    # Boundary condition
    # ------------------------------------------------------------------
    def calc_bc_loss(self,
                     x_bc:torch.Tensor, y_bc:torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the MSE between the prediction at ``x_bc`` and ``y_bc``."""
        return self._data_loss(x_bc, y_bc)

    def calc_bc(self, x_bc:torch.Tensor, y_bc:torch.Tensor=None, **kwargs) -> torch.Tensor:
        """Return the prediction at ``x_bc`` (``y_bc`` is accepted but unused)."""
        return self.forward(x_bc)

    def eval_bc_loss(self,
                     x_bc:torch.Tensor, y_bc:torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the same value as ``calc_bc_loss``."""
        return self._data_loss(x_bc, y_bc)

    # ------------------------------------------------------------------
    # Pressure equation
    # ------------------------------------------------------------------
    def calc_pres(
        self, x_pde:torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the point-wise pressure-equation residual at ``x_pde``."""
        return self._pres_residual(self._jacobian(x_pde))

    def calc_pres_loss(self,
        x_pde:torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the mean squared pressure-equation residual at ``x_pde``."""
        return self._residual_loss(self.calc_pres(x_pde))

    def eval_pres_loss(self,
        x_pde:torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the same value as ``calc_pres_loss``."""
        return self._residual_loss(self.calc_pres(x_pde))

    # ------------------------------------------------------------------
    # x-velocity equation
    # ------------------------------------------------------------------
    def calc_velx_loss(self, x_pde: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the mean squared x-velocity-equation residual at ``x_pde``."""
        return self._residual_loss(self.calc_velx(x_pde))

    def calc_velx(self, x_pde: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the point-wise x-velocity-equation residual at ``x_pde``."""
        return self._velx_residual(self._jacobian(x_pde))

    def eval_velx_loss(self, x_pde: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the same value as ``calc_velx_loss``."""
        return self._residual_loss(self.calc_velx(x_pde))

    # ------------------------------------------------------------------
    # y-velocity equation
    # ------------------------------------------------------------------
    def calc_vely_loss(self, x_pde: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the mean squared y-velocity-equation residual at ``x_pde``."""
        return self._residual_loss(self.calc_vely(x_pde))

    def calc_vely(self, x_pde: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the point-wise y-velocity-equation residual at ``x_pde``."""
        return self._vely_residual(self._jacobian(x_pde))

    def eval_vely_loss(self, x_pde: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the same value as ``calc_vely_loss``."""
        return self._residual_loss(self.calc_vely(x_pde))

    # ------------------------------------------------------------------
    # Aggregated losses
    # ------------------------------------------------------------------
    def loss_fn(self,
                x_pde:torch.Tensor,
                x_ic:torch.Tensor, y_ic:torch.Tensor,
                x_bc:torch.Tensor, y_bc:torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Return the weighted training loss.

        The result is ``pres_weight * pres + velx_weight * velx +
        vely_weight * vely + bc_weight * bc + ic_weight * ic``.

        Args:
            x_pde: Collocation points for the PDE residuals.
            x_ic: Initial-condition points.
            y_ic: Target values at ``x_ic``.
            x_bc: Boundary-condition points.
            y_bc: Target values at ``x_bc``.
            **kwargs: Accepted and ignored.
        """
        # One Jacobian serves all three PDE residuals.
        pres_pde, velx_pde, vely_pde = self._pde_residuals(x_pde)
        pres_pde_loss = self._residual_loss(pres_pde)
        velx_pde_loss = self._residual_loss(velx_pde)
        vely_pde_loss = self._residual_loss(vely_pde)

        bc_loss = self.calc_bc_loss(x_bc, y_bc)
        ic_loss = self.calc_ic_loss(x_ic, y_ic)

        tot_loss = (
            self.pres_weight * pres_pde_loss
            + self.velx_weight * velx_pde_loss
            + self.vely_weight * vely_pde_loss
            + self.bc_weight * bc_loss
            + self.ic_weight * ic_loss
        )

        return tot_loss

    def eval_losses(self,
                    x_pde:torch.Tensor, y_pde:torch.Tensor,
                    x_ic:torch.Tensor, y_ic:torch.Tensor,
                    x_bc:torch.Tensor, y_bc:torch.Tensor, **kwargs):
        """Return the individual, unweighted losses for monitoring.

        Returns:
            An 8-tuple ``(pres_pde_loss, velx_pde_loss, vely_pde_loss, y_loss,
            bc_loss, ic_loss, zero, tot_loss)``. ``y_loss`` is the MSE between
            the prediction at ``x_pde`` and the reference ``y_pde``; ``zero``
            is a zero tensor shaped like ``y_loss``; ``tot_loss`` is the
            unweighted sum ``bc + ic + pres + velx + vely`` (``y_loss`` is not
            included).
        """
        bc_loss = self.eval_bc_loss(x_bc, y_bc)
        ic_loss = self.eval_ic_loss(x_ic, y_ic)

        # Error with respect to the true solution at the collocation points.
        y_loss = self._data_loss(x_pde, y_pde)

        # One Jacobian serves all three PDE residuals.
        pres_pde, velx_pde, vely_pde = self._pde_residuals(x_pde)
        pres_pde_loss = self._residual_loss(pres_pde)
        velx_pde_loss = self._residual_loss(velx_pde)
        vely_pde_loss = self._residual_loss(vely_pde)

        tot_loss = bc_loss + ic_loss + pres_pde_loss + velx_pde_loss + vely_pde_loss

        return pres_pde_loss, velx_pde_loss, vely_pde_loss, y_loss, bc_loss, ic_loss, torch.zeros_like(y_loss), tot_loss

    def evaluate_consistency(self, x):
        """Return the absolute point-wise PDE residuals at ``x``.

        Returns:
            A 3-tuple ``(|pres_residual|, |velx_residual|, |vely_residual|)``.
        """
        pres_pde, velx_pde, vely_pde = self._pde_residuals(x)
        return torch.abs(pres_pde), torch.abs(velx_pde), torch.abs(vely_pde)
        
        

    

def init_weights(m):
    """Initialise an ``nn.Linear`` module in place; ignore any other module.

    The weight is drawn with Xavier-uniform initialisation and the bias, when
    the layer has one, is set to the constant 0.01. Intended to be passed to
    ``nn.Module.apply``.

    Args:
        m: The module to initialise.
    """
    if not isinstance(m, nn.Linear):
        return
    # The in-place ``xavier_uniform_`` replaces the deprecated ``xavier_uniform``.
    nn.init.xavier_uniform_(m.weight)
    # Layers created with ``bias=False`` have ``m.bias is None``.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.01)