import torch.utils
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from collections import OrderedDict
from torch.func import vmap, jacrev, jacfwd, hessian
import numpy as np

class modSoftplus(nn.Module):
    """Softplus applied to a pre-scaled input, then rescaled.

    Evaluates ``softplus(beta * x; beta, threshold) / beta``. Because the
    softplus itself also uses ``beta``, this equals
    ``log(1 + exp(beta**2 * x)) / beta**2`` while ``beta**2 * x`` stays at or
    below ``threshold``, and ``x`` beyond that (softplus turns linear there).

    Args:
        beta: sharpness used both for the input/output scaling and inside the
            softplus.
        threshold: value above which the softplus reverts to the identity.
    """

    def __init__(self, beta:float=1., threshold:float=20.) -> None:
        super().__init__()
        self.threshold = threshold
        self.beta = beta

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        """Return the rescaled softplus of ``x`` (same shape as ``x``)."""
        scaled_input = x*self.beta
        activated = torch.nn.functional.softplus(
            scaled_input, beta=self.beta, threshold=self.threshold
        )
        return activated/self.beta

class SinActivation(nn.Module):
    """Sine activation with a frequency factor: ``sin(beta * x)``.

    Args:
        beta: multiplier applied to the input before taking the sine.
    """

    def __init__(self, beta:float=1.) -> None:
        super().__init__()
        self.beta = beta

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        """Return ``sin(beta * x)`` element-wise (same shape as ``x``)."""
        phase = self.beta*x
        return phase.sin()

class NCL(torch.nn.Module):
    """MLP whose first three outputs are divergence-free by construction.

    A single fully connected network maps a 5-dimensional embedding of a
    3-component input point to ``div_mat_dim + 1 = 4`` scalars. The first three
    fill the strict upper triangle of a 3x3 matrix that is antisymmetrised;
    the row-wise divergence of that matrix with respect to the input gives the
    three components called ``(rho, rhou, rhov)`` in this file. The fourth
    scalar is returned unchanged as ``p``.

    Both ``forward`` (batched) and ``forward_single`` (one point) return the
    components in the order ``(rho, rhou, rhov, p)``.

    Args:
        div_hidden_units: widths of the hidden layers.
        mom_weight: weight of the momentum residual in ``loss_fn``.
        init_weight: weight of the initial-condition term in ``loss_fn``.
        inc_weight: weight of the incompressibility residual in ``loss_fn``.
        radius: stored on the instance; not used inside this class.
        lr: learning rate, stored on the instance; no optimizer is built here.
        div_activation: activation module inserted after every hidden linear
            layer. The same instance is reused for every layer (and the default
            instance for every model), so it should be stateless.
        device: device the network is moved to.
        *args, **kwargs: forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self,
                 div_hidden_units:list,
                 mom_weight:float=1.,
                 init_weight:float=1.,
                 inc_weight:float=1.,
                 radius: float=1/20.,
                 lr:float=1e-3,
                 div_activation:nn.Module=nn.Softplus(beta=25., threshold=20.),
                 device: str='cuda:0',
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.device = device

        # Size of the embedded input produced by embed / embed_single.
        self.in_dim = 5
        self.radius = radius
        self.init_weight = init_weight
        self.inc_weight = inc_weight
        self.mom_weight = mom_weight
        # Only the learning rate is stored; the optimizer lives elsewhere.
        self.lr = lr

        self.loss_container = torch.nn.MSELoss(reduction='mean')
        self.alignment_loss_container = torch.nn.MSELoss(reduction='mean')

        # Divergence free network
        div_out_dim = 3
        self.div_out_dim = div_out_dim
        # Number of independent entries of an antisymmetric 3x3 matrix.
        self.div_mat_dim = (div_out_dim*(div_out_dim-1))//2
        self.div_hidden_units = [self.in_dim, *div_hidden_units]

        div_net = nn.Sequential()
        layer_sizes = zip(self.div_hidden_units[:-1], self.div_hidden_units[1:])
        for i, (n_in, n_out) in enumerate(layer_sizes):
            div_net.add_module(f'div_lin{i}', nn.Linear(n_in, n_out))
            div_net.add_module(f'div_act{i}', div_activation)
        # Output: the matrix entries plus one extra scalar (p).
        div_net.add_module(f'div_lin{len(self.div_hidden_units)-1}',
                           nn.Linear(self.div_hidden_units[-1], self.div_mat_dim+1))

        self.div_net = div_net.to(self.device)

        for name, param in self.div_net.named_parameters():
            if 'weight' in name:
                # NOTE(review): an earlier revision multiplied the return value
                # of xavier_normal_ by 0.7 and discarded the product, so the
                # effective gain has always been 1. That behaviour is kept;
                # pass gain=0.7 here if the scaling was actually intended.
                nn.init.xavier_normal_(param)
            if 'bias' in name:
                nn.init.zeros_(param)

    def embed(self, x:torch.Tensor):
        """Embed a batch of points, shape (b, >=3) -> (b, 5).

        Column 0 is passed through; columns 1 and 2 are each replaced by the
        cosine and sine of ``2*pi`` times their value, which makes the
        embedding 1-periodic in those two coordinates.
        """
        c = 2*np.pi
        return torch.cat([x[:,:1], torch.cos(c*x[:,1:2]), torch.sin(c*x[:,1:2]), torch.cos(c*x[:,2:3]), torch.sin(c*x[:,2:3])], dim=1)

    def embed_single(self, x:torch.Tensor):
        """Same embedding as ``embed`` for one un-batched point, (>=3,) -> (5,)."""
        c = 2*np.pi
        return torch.cat([x[:1], torch.cos(c*x[1:2]), torch.sin(c*x[1:2]), torch.cos(c*x[2:3]), torch.sin(c*x[2:3])], dim=0)

    def _antisymmetric_matrix(self, x:torch.Tensor) -> torch.Tensor:
        """Return the antisymmetric (3, 3) matrix the network assigns to one point."""
        net_in = self.embed_single(x)
        # Drop the last network output: it is p, not a matrix entry.
        entries = self.div_net(net_in.reshape((1,-1))).reshape(-1)[:-1]
        # Follow the network's dtype/device so the model still works after
        # being moved or cast with .to().
        mat = torch.zeros((self.div_out_dim, self.div_out_dim),
                          dtype=entries.dtype, device=entries.device)
        triu_indexes = torch.triu_indices(self.div_out_dim, self.div_out_dim,
                                          offset=1, device=entries.device)
        # Out-of-place index_put keeps the construction differentiable.
        mat = mat.index_put(tuple(triu_indexes), entries)
        # Make the matrix antisymmetric
        return mat - torch.transpose(mat, dim0=0, dim1=1)

    def forward_single(self, tx: torch.Tensor, return_final:bool=False):
        """Evaluate the model at one point.

        Args:
            tx: 1-D tensor with three components.
            return_final: accepted for interface compatibility; ignored.

        Returns:
            1-D tensor ``(rho, rhou, rhov, p)`` of length 4, matching one row
            of ``forward``.
        """
        tx_in = self.embed_single(tx)
        # Keep only the last network output as p. Indexing the (1, 4) output
        # with [-1] would select the whole row and leak the matrix entries
        # into the returned vector.
        p = self.div_net(tx_in.reshape((1,-1))).reshape(-1)[-1:]
        # The Jacobian has shape (3, 3, 3); tracing the last two axes gives the
        # divergence of every matrix row, which is divergence-free as a field.
        div_out = torch.einsum('...ii', jacrev(self._antisymmetric_matrix)(tx))
        # Constant shifts do not change any derivative, so the field stays
        # divergence-free (presumably they keep rho away from zero - confirm).
        div_out[0] += 2.
        div_out[1:3] += 1e-1
        return torch.concat([div_out, p], dim=0)

    def forward(self, tx: torch.Tensor, return_final:bool=False):
        """Evaluate the model on a batch.

        Args:
            tx: tensor of shape (b, 3).
            return_final: accepted for interface compatibility; ignored.

        Returns:
            Tensor of shape (b, 4) with columns ``(rho, rhou, rhov, p)``.
        """
        tx_in = self.embed(tx)
        p = self.div_net(tx_in)[:,-1].reshape((-1,1))
        # Batched Jacobian has shape (b, 3, 3, 3); see forward_single.
        div_fun = vmap(jacrev(self._antisymmetric_matrix))
        # This is the divergence free output of the divergence equation
        div_out = torch.einsum('...ii', div_fun(tx))
        div_out[:,0] += 2.
        div_out[:,1:3] += 1e-1

        return torch.column_stack((div_out, p))

    def _state_and_jacobian(self, x_pde:torch.Tensor):
        """Return the batched prediction (b, 4) and its input Jacobian (b, 4, 3)."""
        y_pred = self.forward(x_pde)
        Dy_pred = vmap(jacrev(self.forward_single))(x_pde)
        return y_pred, Dy_pred

    def _momentum_residual(self, y:torch.Tensor, Dy:torch.Tensor,
                           include_pressure:bool=True) -> torch.Tensor:
        """Momentum residual expressed through ``rho`` and ``rhou``, shape (b, 2).

        With ``u = rhou / rho`` the sum of the terms equals
        ``rho**3 * (u_t + (u . grad) u + grad(p) / rho)``; the pressure part is
        dropped when ``include_pressure`` is false.
        """
        rho = y[:,0].unsqueeze(1)
        rhou = y[:,1:3]
        Drho = Dy[:,0]
        Drhou = Dy[:,1:3]

        term_1 = rho**2 * Drhou[:,:,0]
        term_2 = rho * Drho[:,0].unsqueeze(1) * rhou
        term_3 = rho * torch.einsum('bij, bj -> bi', Drhou[:,:,1:], rhou)
        term_4 = torch.einsum('bij, bj -> bi', torch.einsum('bi, bj -> bji', Drho[:,1:], rhou), rhou)

        residual = term_1 - term_2 + term_3 - term_4
        if include_pressure:
            # Spatial gradient of p, taken from the last output component.
            residual = residual + rho**2 * Dy[:,-1,1:]
        return residual

    def _incompressibility_residual(self, y:torch.Tensor, Dy:torch.Tensor) -> torch.Tensor:
        """Return ``rho*rho_t + rhou . grad(rho)`` per sample, shape (b,)."""
        return torch.einsum('bi,bi -> b', Dy[:,0], y[:,:self.div_out_dim])

    def _mse_to_zero(self, residual:torch.Tensor) -> torch.Tensor:
        """Mean squared value of ``residual`` using ``loss_container``."""
        return self.loss_container(residual, torch.zeros_like(residual))

    def loss_fn(self,
                x_pde:torch.Tensor,
                x_init:torch.Tensor, y_init:torch.Tensor
        ) -> torch.Tensor:
        """Weighted training loss.

        Args:
            x_pde: collocation points, shape (b, 3).
            x_init: points where the initial condition is imposed.
            y_init: target ``(rho, rhou, rhov)`` at ``x_init``.

        Returns:
            Scalar ``init_weight*init + mom_weight*momentum + inc_weight*inc``.
        """
        # Get the prediction (rho, rhou, rhov, p) and its derivatives
        y_pred, Dy_pred = self._state_and_jacobian(x_pde)

        mom_loss = self._mse_to_zero(self._momentum_residual(y_pred, Dy_pred))
        inc_loss = self._mse_to_zero(self._incompressibility_residual(y_pred, Dy_pred))

        y_init_pred = self.forward(x_init)
        init_loss = self.loss_container(y_init_pred[:,:-1], y_init)

        return init_loss*self.init_weight + self.mom_weight*mom_loss + self.inc_weight*inc_loss

    def eval_losses(self,
                    x_pde:torch.Tensor,
                    x_init:torch.Tensor, y_init:torch.Tensor,
                    x_sol:torch.Tensor, y_sol:torch.Tensor,
                    step:int) -> tuple:
        """Compute the individual (unweighted) losses for logging.

        Returns:
            ``(step, mom_loss, inc_loss, y_loss, init_loss, tot_loss_val)``
            where ``y_loss`` is the MSE against the reference solution
            ``y_sol`` at ``x_sol`` and ``tot_loss_val`` is the unweighted sum
            of the incompressibility, momentum and initial-condition losses.
        """
        y_pred, Dy_pred = self._state_and_jacobian(x_pde)

        mom_loss = self._mse_to_zero(self._momentum_residual(y_pred, Dy_pred))
        inc_loss = self._mse_to_zero(self._incompressibility_residual(y_pred, Dy_pred))

        y_init_pred = self.forward(x_init)
        init_loss = self.loss_container(y_init_pred[:,:-1], y_init)

        y_sol_pred = self.forward(x_sol)
        y_loss = self.alignment_loss_container(y_sol_pred[:,:-1], y_sol)

        tot_loss_val = inc_loss + mom_loss + init_loss
        return step, mom_loss, inc_loss, y_loss, init_loss, tot_loss_val

    def forward_u_single(self, tx: torch.Tensor):
        """Return ``rhou / rho`` (two components) at a single point."""
        out = self.forward_single(tx)
        return out[1:3]/out[0]

    def forward_single_final(self, tx: torch.Tensor):
        """Call ``forward_single`` with ``return_final=True`` (currently a no-op flag)."""
        return self.forward_single(tx, return_final=True)

    def evaluate_consistency(self, x_pde:torch.Tensor):
        """Return point-wise absolute PDE residuals.

        Returns:
            ``(|momentum|, |divergence|, |incompressibility|)`` with shapes
            (b, 2), (b,) and (b,).
        """
        Dy = vmap(jacrev(self.forward_single_final))(x_pde)
        y = self.forward(x_pde, return_final=True)

        inc_pde = self._incompressibility_residual(y, Dy)
        mom_pde = self._momentum_residual(y, Dy)
        # Divergence of (rho, rhou, rhov) with respect to the three inputs.
        div_pde = vmap(torch.trace)(Dy[:,:3])

        return torch.abs(mom_pde), torch.abs(div_pde), torch.abs(inc_pde)

    def calc_ic_loss(self, x_init:torch.Tensor, y_init:torch.Tensor):
        """MSE between the predicted ``(rho, rhou, rhov)`` at ``x_init`` and ``y_init``."""
        y_init_pred = self.forward(x_init)
        y_init_div = y_init_pred[:,:3]
        init_loss = self.loss_container(y_init_div, y_init)
        return init_loss

    def calc_inc_loss(self, x_pde:torch.Tensor):
        """Mean squared incompressibility residual at ``x_pde``."""
        y_pred, Dy_pred = self._state_and_jacobian(x_pde)
        return self._mse_to_zero(self._incompressibility_residual(y_pred, Dy_pred))

    def calc_mom_loss(self, x_pde:torch.Tensor):
        """Mean squared momentum residual at ``x_pde`` without the pressure term.

        NOTE(review): unlike ``loss_fn`` and ``eval_losses`` this residual
        leaves out the pressure gradient; confirm that this is intended.
        """
        y_pred, Dy_pred = self._state_and_jacobian(x_pde)
        mom_pde = self._momentum_residual(y_pred, Dy_pred, include_pressure=False)
        return self._mse_to_zero(mom_pde)

    def calc_div_loss(self, x_pde:torch.Tensor):
        """Mean squared divergence of ``(rho, rhou, rhov)`` at ``x_pde``."""
        Dy = vmap(jacrev(self.forward_single))(x_pde)
        div_pde = vmap(torch.trace)(Dy[:,:3])
        return self._mse_to_zero(div_pde)

    def calc_align_loss(self, x_pde:torch.Tensor, alignment_mode:str):
        """Alignment loss between ``rhou / rho`` and output components 4 onward.

        Args:
            x_pde: collocation points, shape (b, 3).
            alignment_mode: ``'DERL'`` (Jacobians), ``'OUTL'`` (values) or
                ``'SOB'`` (both).

        Raises:
            ValueError: if ``alignment_mode`` is not one of the above.

        NOTE(review): ``forward`` and ``forward_single`` emit exactly four
        components, so the ``[:, 4:]`` slices below are empty and cannot be
        broadcast against the two velocity components. This method looks like
        it targets a model variant with extra output columns - confirm before
        relying on it.
        """
        # Reject bad modes before paying for the Jacobians.
        if alignment_mode not in ('DERL', 'OUTL', 'SOB'):
            raise ValueError('alignment mode not recognized')

        Du_div = vmap(jacrev(self.forward_u_single))(x_pde)
        Dy_inc = vmap(jacrev(self.forward_single))(x_pde)[:,4:]
        y_pred = self.forward(x_pde)
        u_div = y_pred[:,1:3]/y_pred[:,0].unsqueeze(1)
        y_inc = y_pred[:,4:]

        if alignment_mode == 'DERL':
            alignment_loss = self.loss_container(Du_div - Dy_inc, torch.zeros_like(Du_div))
        elif alignment_mode == 'OUTL':
            alignment_loss = self.loss_container(u_div - y_inc, torch.zeros_like(u_div))
        else:
            alignment_loss = self.loss_container(Du_div - Dy_inc, torch.zeros_like(Du_div)) + self.loss_container(u_div - y_inc, torch.zeros_like(u_div))
        return alignment_loss
