import torch
from torch import nn
from collections import OrderedDict
from torch.func import jacrev, hessian, vmap, jacfwd    
import numpy as np

# Coefficient of the Laplacian term in the momentum residual.
mu: float = 1e-3
# Density; not referenced by the residuals below.
rho: float = 1.
# Channel extent: walls at y = 0 and y = y_max, inflow at x = 0, outlet at x = x_max.
y_max: float = 0.41
x_max: float = 2.2
# Cylinder obstacle: centre (x, y) and radius used for the no-slip condition.
center = torch.tensor([1 / 5., 1 / 5.])
radius = 1 / 20.
# x-offset used by NSRestrictedPotentialNet to shrink the domain.
restricted_x = 0.5
# Upper bound of the sampled time interval [0, t_max].
t_max = 2.0

# Generate points uniformly in the ball
def ball_uniform(n: int, radius: float=0.05, dim:int=2, center:torch.Tensor=center):
    """Sample ``n`` points uniformly from the solid ``dim``-dimensional ball.

    Directions are normalised standard-normal vectors (isotropic), and the
    distance from the centre is drawn as ``U ** (1 / dim)`` so the density is
    uniform in volume.

    Args:
        n: number of points to draw.
        radius: radius of the ball.
        dim: ambient dimension; must match the number of entries of ``center``.
        center: centre of the ball. Defaults to the module-level ``center``.

    Returns:
        Tensor of shape ``(n, dim)`` on ``center``'s device.
    """
    direction = torch.distributions.Normal(0., 1.).sample((n, dim))
    direction = direction / torch.norm(direction, p=2., dim=1, keepdim=True)
    rad = torch.distributions.Uniform(0., 1.).sample((n, 1)) ** (1. / dim)

    # Samples are drawn on the CPU; move them next to ``center`` so that a
    # GPU-resident centre does not cause a device mismatch. Broadcasting
    # replaces the explicit tile of the centre.
    offsets = (radius * rad * direction).to(device=center.device)
    return center + offsets

# Generate points uniformly in the ball boundary
def ball_boundary_uniform(n: int, radius: float=1/20., dim:int=2, center:torch.tensor=center):
    """Sample ``n`` points uniformly on the sphere of ``radius`` around ``center``.

    Normalising standard-normal vectors gives directions uniformly
    distributed on the unit sphere; they are then scaled by ``radius``.

    Args:
        n: number of points to draw.
        radius: radius of the sphere (defaults to the cylinder radius 1/20).
        dim: ambient dimension; must match the number of entries of ``center``.
        center: centre of the sphere. Defaults to the module-level ``center``.

    Returns:
        Tensor of shape ``(n, dim)`` on ``center``'s device.
    """
    direction = torch.distributions.Normal(0., 1.).sample((n, dim))
    direction = direction / torch.norm(direction, p=2., dim=1, keepdim=True)
    # Move the CPU samples to the centre's device; broadcasting replaces tile.
    return center + (radius * direction).to(device=center.device)

# Ramp function for the boundary conditions
def ramp_function(x:torch.Tensor, min_=0., max_=1., c_=4., r_=5.):
    """Smooth logistic ramp from ``min_`` to ``max_`` centred at ``c_``.

    Evaluates

        (min_ * e^(c_ r_) + max_ * e^(r_ x)) / (e^(c_ r_) + e^(r_ x))

    in the algebraically identical form ``min_ + (max_ - min_) * sigmoid(r_ (x - c_))``.
    That form never overflows (the exponential form returns nan for large x),
    stays in torch so it works on any device and with autograd, and keeps the
    shape of ``x``.

    Args:
        x: input tensor of any shape.
        min_: value approached as x -> -inf.
        max_: value approached as x -> +inf.
        c_: centre of the ramp; the output is (min_ + max_) / 2 there.
        r_: steepness of the ramp.

    Returns:
        Tensor with the same shape as ``x``.
    """
    return min_ + (max_ - min_) * torch.sigmoid(r_ * (x - c_))

class NSPotentialNet(torch.nn.Module):
    """Stream-function network for the 2D incompressible Navier-Stokes problem.

    The MLP maps a point (t, x, y) to two outputs (psi, p). The velocity is
    derived from the stream function psi as u = dpsi/dy and v = -dpsi/dx, so
    ``forward`` returns the columns (u, v, p). Such a velocity field is
    divergence-free by construction.

    Modes accepted by ``loss_fn`` / ``print_losses``:
        -1: PINN learning (momentum residual + output supervision)
         0: derivative learning (Jacobian, optionally Hessian diagonal)
         1: output learning
        10: output + Jacobian learning
    """

    def __init__(self,
                 init_weight: float,
                 mom_weight: float,
                 div_weight: float,
                 out_weight: float,
                 der_weight: float,
                 bc_weight: float,
                 hidden_units: list,
                 lr_init: float,
                 device: str,
                 activation: nn.Module=nn.Tanh,
                 last_activation: bool=False,
                 t_init = 8.,
                 *args,
                 **kwargs) -> None:
        """Build the MLP, the Adam optimizer and the MSE criterion.

        Args:
            init_weight: weight of the initial-condition loss.
            mom_weight: weight of the momentum residual (mode -1).
            div_weight: stored for interface parity; the divergence is zero by
                construction and is not penalised by ``loss_fn``.
            out_weight: weight of the supervised output loss.
            der_weight: weight of the supervised derivative losses.
            bc_weight: weight of the boundary-condition loss.
            hidden_units: widths of the hidden layers.
            lr_init: Adam learning rate.
            device: device the network is moved to.
            activation: an activation module instance, or an ``nn.Module``
                class (such as the default ``nn.Tanh``), which is instantiated once.
            last_activation: apply the activation after the output layer too.
            t_init: time offset added before evaluating the inflow ramp.
        """
        super().__init__(*args, **kwargs)

        self.t_init = t_init
        self.mom_weight = mom_weight
        self.div_weight = div_weight
        self.out_weight = out_weight
        self.der_weight = der_weight
        # Legacy attribute, equal to der_weight; subclasses still read it.
        self.sys_weight = der_weight
        self.bc_weight = bc_weight
        self.init_weight = init_weight
        self.hidden_units = hidden_units
        self.lr_init = lr_init
        self.device = device

        # nn.Sequential only accepts Module instances, so instantiate a class
        # (e.g. the default nn.Tanh). The single stateless instance is shared by all layers.
        if isinstance(activation, type):
            activation = activation()

        # Layer names (lin*/act*) are kept stable so state_dict keys do not change.
        n_hidden = len(hidden_units)
        layers = OrderedDict()
        layers['lin0'] = nn.Linear(3, hidden_units[0])
        layers['act0'] = activation
        for i in range(1, n_hidden):
            layers[f'lin{i}'] = nn.Linear(in_features=hidden_units[i-1], out_features=hidden_units[i])
            layers[f'act{i}'] = activation
        # Output layer: (psi, p)
        layers[f'lin{n_hidden}'] = nn.Linear(in_features=hidden_units[-1], out_features=2)
        if last_activation:
            layers[f'act{n_hidden}'] = activation
        self.net = nn.Sequential(layers).to(self.device)

        self.opt = torch.optim.Adam(self.net.parameters(), lr=lr_init)
        self.loss_container = nn.MSELoss(reduction='mean')

    # Return the potential for single samples
    def psi_single(self, x: torch.Tensor) -> torch.Tensor:
        """Return the stream function psi (scalar) for one input point (t, x, y)."""
        out = self.net(x.reshape((1, -1))).reshape((-1))
        return out[0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the columns (u, v, p) for a batch ``x`` of shape (batch, 3)."""
        out = self.net(x)
        p = out[:, 1].reshape((-1, 1))
        # Gradient of psi w.r.t. (x, y); column 0 of the full gradient is d/dt.
        grad_psi = vmap(jacrev(self.psi_single))(x)[:, 1:]
        u = grad_psi[:, 1].reshape((-1, 1))
        v = -grad_psi[:, 0].reshape((-1, 1))
        return torch.column_stack((u, v, p))

    def forward_single(self, x: torch.Tensor) -> torch.Tensor:
        """Return (u, v, p) as a 3-vector for one point; used under vmap/jacrev/hessian."""
        out = self.net(x.reshape((1, -1))).reshape((-1))
        p = out[1].reshape((-1))
        grad_psi = jacrev(self.psi_single)(x)[1:]
        u = grad_psi[1].reshape((-1))
        v = -grad_psi[0].reshape((-1))
        return torch.concat([u, v, p])

    @staticmethod
    def _validate_mode(mode: int) -> None:
        """Raise ValueError unless ``mode`` is one of -1, 0, 1 or 10."""
        if mode not in (-1, 0, 1, 10):
            raise ValueError(
                'mode should be either 0 for derivative learning, 1 for output learning, '
                f'10 for output + derivative learning or -1 for PINN learning, but found {mode}')

    def _jacobian(self, x_pde: torch.Tensor) -> torch.Tensor:
        """Per-sample Jacobian of (u, v, p) w.r.t. (t, x, y), shape (batch, 3, 3)."""
        return vmap(jacrev(self.forward_single))(x_pde)

    def _hessian(self, x_pde: torch.Tensor) -> torch.Tensor:
        """Per-sample Hessian of (u, v, p) w.r.t. (t, x, y), shape (batch, 3, 3, 3)."""
        return vmap(hessian(self.forward_single))(x_pde)

    @staticmethod
    def _momentum_residual(out_pred, out_X, out_XX) -> torch.Tensor:
        """Momentum residual u_t - mu * lap(u) + (u . grad) u + grad p, shape (batch, 2)."""
        lapl_u = torch.diagonal(out_XX[:, :2, 1:, 1:], dim1=2, dim2=3).sum(dim=2)
        # (u . grad) u: spatial velocity Jacobian applied to the velocity.
        convection = torch.einsum('bij,bj->bi', out_X[:, :2, 1:], out_pred[:, :2])
        return out_X[:, :2, 0] - mu * lapl_u + convection + out_X[:, -1, 1:]

    @staticmethod
    def _divergence_residual(out_X) -> torch.Tensor:
        """du/dx + dv/dy for each sample, shape (batch,)."""
        return torch.einsum('bii->b', out_X[:, :2, 1:3])

    def _spec_loss(self, mode, mom_loss, out_loss, jac_loss, hess_loss):
        """Weight the mode-specific terms exactly as used for training.

        Shared by ``loss_fn`` and ``print_losses`` so that the printed total
        matches the optimised objective. Terms the mode does not use may be None.
        """
        if mode == -1:
            # The divergence is zero by construction, so it is not penalised.
            return self.mom_weight*mom_loss + self.out_weight*out_loss
        if mode == 0:
            spec_loss = self.der_weight*jac_loss
            if hess_loss is not None:
                spec_loss = spec_loss + self.der_weight*hess_loss
            return spec_loss
        if mode == 10:
            return self.out_weight*out_loss + self.der_weight*jac_loss
        return self.out_weight*out_loss

    # Boundary condition losses
    def get_bc_loss(self, batch_size: int=64):
        """Sum of the MSE boundary losses on randomly sampled boundary points.

        The terms are:
        - inflow x = 0: parabolic u profile (peak 1.5) scaled by the time ramp, v = 0;
        - walls y = 0 and y = y_max: u = v = 0;
        - outlet x = x_max: p = 0;
        - cylinder surface: u = v = 0.
        """
        # Condition at x=0
        y_bc1 = torch.distributions.Uniform(0., y_max).sample((batch_size,1))
        t_bc1 = torch.distributions.Uniform(0., t_max).sample((batch_size,1))
        pts_bc1 = torch.column_stack((t_bc1, torch.zeros_like(t_bc1), y_bc1))
        # The first part of the boundary conditions is only on u
        pred_bc1 = self.forward(pts_bc1.to(self.device))[:,:2]
        # The true velocity is 0 along y and given by the ramp function along x
        y_in = pts_bc1[:,2].reshape((-1))
        true_bc1 = ramp_function(t_bc1.reshape((-1))+self.t_init)*4*1.5*y_in*(y_max - y_in)/y_max**2
        true_bc1 = true_bc1.reshape((-1,1))
        bc1 = self.loss_container(pred_bc1, torch.column_stack((true_bc1, torch.zeros_like(true_bc1))).to(self.device))

        # Condition at y = 0 or y = y_max
        t_bc2 = torch.distributions.Uniform(0., t_max).sample((2*batch_size,1))
        x_bc2 = torch.distributions.Uniform(0., x_max).sample((2*batch_size,1))
        y_bc2 = torch.concatenate((torch.zeros((batch_size)), y_max*torch.ones((batch_size))))
        pts_bc2 = torch.column_stack((t_bc2, x_bc2, y_bc2))
        pred_bc2 = self.forward(pts_bc2.to(self.device))[:,:2]
        bc2 = self.loss_container(pred_bc2, torch.zeros_like(pred_bc2))

        # Boundary conditions at x = x_max: the pressure must be 0
        y_bc3 = torch.distributions.Uniform(0., y_max).sample((batch_size,1))
        t_bc3 = torch.distributions.Uniform(0., t_max).sample((batch_size,1))
        x_bc3 = x_max*torch.ones_like(t_bc3)
        pts_bc3 = torch.column_stack((t_bc3, x_bc3, y_bc3))
        pred_bc3 = self.forward(pts_bc3.to(self.device))[:,2]
        bc3 = self.loss_container(pred_bc3, torch.zeros_like(pred_bc3))

        # Boundary conditions on the cylinder: velocity should be 0
        t_bc4 = torch.distributions.Uniform(0., t_max).sample((batch_size,1))
        x_bc4 = ball_boundary_uniform(n=batch_size, radius=radius, dim=2, center=center)
        pts_bc4 = torch.column_stack((t_bc4, x_bc4))
        pred_bc4 = self.forward(pts_bc4.to(self.device))[:,:2]
        bc4 = self.loss_container(pred_bc4, torch.zeros_like(pred_bc4))

        return bc1 + bc2 + bc3 + bc4

    # Function for the initial condition loss
    def get_init_loss(self, batch_size: int=128):
        """MSE between the prediction at t = 0 and zero on random points of the full domain."""
        x_init = torch.distributions.Uniform(0., x_max).sample((batch_size,1))
        y_init = torch.distributions.Uniform(0., y_max).sample((batch_size,1))
        t_init = torch.zeros((batch_size, 1))
        pts_init = torch.column_stack((t_init, x_init, y_init)).to(self.device)

        pred_init = self.forward(pts_init)
        return self.loss_container(pred_init, torch.zeros_like(pred_init))

    def loss_fn(self,
        x_pde:torch.Tensor,
        y_pde:torch.Tensor,
        D_pde:torch.Tensor,
        x_init:torch.Tensor,
        y_init:torch.Tensor,
        mode:int,
        H_pde=None,
        x_bc=None,
        y_bc=None,
    ) -> torch.Tensor:
        """Total training loss for ``mode``.

        Args:
            x_pde: collocation points, shape (batch, 3).
            y_pde: targets compared with ``forward(x_pde)`` (modes -1, 1, 10).
            D_pde: targets compared with the (batch, 3, 3) Jacobian (modes 0, 10).
            x_init, y_init: initial-condition points and targets.
            mode: -1, 0, 1 or 10 (see the class docstring).
            H_pde: optional targets for the Hessian diagonal (mode 0 only).
            x_bc, y_bc: accepted for interface parity with the other networks; unused.

        Raises:
            ValueError: if ``mode`` is not supported.
        """
        self._validate_mode(mode)

        out_pred = self.forward(x_pde)
        mom_loss = out_loss = jac_loss = hess_loss = None
        if mode != 0:
            out_loss = self.loss_container(out_pred, y_pde)
        # Jacobians and especially Hessians are expensive; build only what the mode needs.
        if mode != 1:
            out_X = self._jacobian(x_pde)
            if mode == -1:
                mom_pde = self._momentum_residual(out_pred, out_X, self._hessian(x_pde))
                mom_loss = self.loss_container(mom_pde, torch.zeros_like(mom_pde))
            else:
                jac_loss = self.loss_container(out_X, D_pde)
                if mode == 0 and H_pde is not None:
                    out_XX = self._hessian(x_pde)
                    hess_loss = self.loss_container(torch.diagonal(out_XX, dim1=2, dim2=3), H_pde)
        spec_loss = self._spec_loss(mode, mom_loss, out_loss, jac_loss, hess_loss)

        init_loss = self.loss_container(self.forward(x_init), y_init)
        bc_loss = self.get_bc_loss()
        return spec_loss + self.init_weight*init_loss + self.bc_weight*bc_loss

    def print_losses(self, step:int,
        x_pde:torch.Tensor,
        y_pde:torch.Tensor,
        D_pde:torch.Tensor,
        x_init:torch.Tensor,
        y_init:torch.Tensor,
        mode:int,
        H_pde = None,
        x_bc = None,
        y_bc = None,
    ):
        """Print and return every (unweighted) loss component plus the weighted total.

        The total is combined by the same rule as ``loss_fn``. Because the
        boundary points are re-sampled, the total differs from the last
        ``loss_fn`` value only through that sampling.

        Returns:
            (step, mom_loss, div_loss, out_loss, der_loss, init_loss, bc_loss, tot_loss)

        Raises:
            ValueError: if ``mode`` is not supported.
        """
        self._validate_mode(mode)

        out_pred = self.forward(x_pde)
        out_X = self._jacobian(x_pde)
        out_XX = self._hessian(x_pde)

        mom_pde = self._momentum_residual(out_pred, out_X, out_XX)
        mom_loss = self.loss_container(mom_pde, torch.zeros_like(mom_pde))
        # Diagnostic only: zero up to round-off for a stream-function velocity.
        div_pde = self._divergence_residual(out_X)
        div_loss = self.loss_container(div_pde, torch.zeros_like(div_pde))
        out_loss = self.loss_container(out_pred, y_pde)
        jac_loss = self.loss_container(out_X, D_pde)
        hess_loss = None
        der_loss = jac_loss
        if H_pde is not None:
            hess_loss = self.loss_container(torch.diagonal(out_XX, dim1=2, dim2=3), H_pde)
            der_loss = jac_loss + hess_loss

        init_loss = self.loss_container(self.forward(x_init), y_init)
        bc_loss = self.get_bc_loss()

        spec_loss = self._spec_loss(mode, mom_loss, out_loss, jac_loss, hess_loss)
        tot_loss = spec_loss + self.init_weight*init_loss + self.bc_weight*bc_loss

        print(f'Step: {step}, total loss: {tot_loss}, init loss: {init_loss}, bc loss: {bc_loss}')
        print(f'mom loss: {mom_loss}, div loss: {div_loss}, out loss {out_loss}, der loss: {der_loss}')

        return step, mom_loss, div_loss, out_loss, der_loss, init_loss, bc_loss, tot_loss

    def get_consistencies(self, x_pde):
        """Per-sample L2 norms of the momentum residual and of the divergence."""
        out_pred = self.forward(x_pde)
        out_X = self._jacobian(x_pde)
        out_XX = self._hessian(x_pde)

        mom_pde = self._momentum_residual(out_pred, out_X, out_XX)
        div_pde = self._divergence_residual(out_X).reshape((-1, 1))

        return torch.norm(mom_pde, p=2, dim=1), torch.norm(div_pde, p=2, dim=1)
    
class NSNet(torch.nn.Module):
    """Direct network for the 2D Navier-Stokes problem: (t, x, y) -> (u, v, p).

    Modes accepted by ``loss_fn`` / ``print_losses``:
        -1: PINN learning (momentum + divergence residuals + output supervision)
         0: derivative learning (Jacobian, optionally Hessian diagonal)
         1: output learning
        10: output + Jacobian learning
    """

    def __init__(self,
                 init_weight: float,
                 mom_weight: float,
                 div_weight: float,
                 sys_weight: float,
                 bc_weight: float,
                 hidden_units: list,
                 lr_init: float,
                 device: str,
                 activation: nn.Module=nn.Tanh,
                 last_activation: bool=True,
                 t_init = 8.,
                 *args,
                 **kwargs) -> None:
        """Build the MLP, the Adam optimizer and the MSE criterion.

        Args:
            init_weight: weight of the initial-condition loss.
            mom_weight: weight of the momentum residual (mode -1).
            div_weight: weight of the divergence residual (mode -1).
            sys_weight: weight of the supervised output/derivative losses.
            bc_weight: weight of the boundary-condition loss.
            hidden_units: widths of the hidden layers.
            lr_init: Adam learning rate.
            device: device the network is moved to.
            activation: an activation module instance, or an ``nn.Module``
                class (such as the default ``nn.Tanh``), which is instantiated once.
            last_activation: apply the activation after the output layer too.
            t_init: time offset added before evaluating the inflow ramp.
        """
        super().__init__(*args, **kwargs)

        self.t_init = t_init
        self.mom_weight = mom_weight
        self.div_weight = div_weight
        self.sys_weight = sys_weight
        self.bc_weight = bc_weight
        self.init_weight = init_weight
        self.hidden_units = hidden_units
        self.lr_init = lr_init
        self.device = device

        # nn.Sequential only accepts Module instances, so instantiate a class
        # (e.g. the default nn.Tanh). The single stateless instance is shared by all layers.
        if isinstance(activation, type):
            activation = activation()

        # Layer names (lin*/act*) are kept stable so state_dict keys do not change.
        n_hidden = len(hidden_units)
        layers = OrderedDict()
        layers['lin0'] = nn.Linear(3, hidden_units[0])
        layers['act0'] = activation
        for i in range(1, n_hidden):
            layers[f'lin{i}'] = nn.Linear(in_features=hidden_units[i-1], out_features=hidden_units[i])
            layers[f'act{i}'] = activation
        # Output layer: (u, v, p)
        layers[f'lin{n_hidden}'] = nn.Linear(in_features=hidden_units[-1], out_features=3)
        if last_activation:
            layers[f'act{n_hidden}'] = activation
        self.net = nn.Sequential(layers).to(self.device)

        self.opt = torch.optim.Adam(self.net.parameters(), lr=lr_init)
        self.loss_container = nn.MSELoss(reduction='mean')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network output (u, v, p) for a batch ``x`` of shape (batch, 3)."""
        return self.net(x)

    def forward_single(self, x: torch.Tensor) -> torch.Tensor:
        """Return (u, v, p) as a 3-vector for one point; used under vmap/jacrev/hessian."""
        return self.net(x.reshape((1, -1))).reshape((-1))

    @staticmethod
    def _validate_mode(mode: int) -> None:
        """Raise ValueError unless ``mode`` is one of -1, 0, 1 or 10."""
        if mode not in (-1, 0, 1, 10):
            raise ValueError(
                'mode should be either 0 for derivative learning, 1 for output learning, '
                f'10 for output + derivative learning or -1 for PINN learning, but found {mode}')

    def _jacobian(self, x_pde: torch.Tensor) -> torch.Tensor:
        """Per-sample Jacobian of (u, v, p) w.r.t. (t, x, y), shape (batch, 3, 3)."""
        return vmap(jacrev(self.forward_single))(x_pde)

    def _hessian(self, x_pde: torch.Tensor) -> torch.Tensor:
        """Per-sample Hessian of (u, v, p) w.r.t. (t, x, y), shape (batch, 3, 3, 3)."""
        return vmap(hessian(self.forward_single))(x_pde)

    @staticmethod
    def _momentum_residual(out_pred, out_X, out_XX) -> torch.Tensor:
        """Momentum residual u_t - mu * lap(u) + (u . grad) u + grad p, shape (batch, 2)."""
        lapl_u = torch.diagonal(out_XX[:, :2, 1:, 1:], dim1=2, dim2=3).sum(dim=2)
        # (u . grad) u: spatial velocity Jacobian applied to the velocity.
        convection = torch.einsum('bij,bj->bi', out_X[:, :2, 1:], out_pred[:, :2])
        return out_X[:, :2, 0] - mu * lapl_u + convection + out_X[:, -1, 1:]

    @staticmethod
    def _divergence_residual(out_X) -> torch.Tensor:
        """du/dx + dv/dy for each sample, shape (batch,)."""
        return torch.einsum('bii->b', out_X[:, :2, 1:3])

    def _spec_loss(self, mode, mom_loss, div_loss, out_loss, jac_loss, hess_loss):
        """Weight the mode-specific terms exactly as used for training.

        Shared by ``loss_fn`` and ``print_losses`` so that the printed total
        matches the optimised objective. Terms the mode does not use may be None.
        """
        if mode == -1:
            return self.mom_weight*mom_loss + self.div_weight*div_loss + self.sys_weight*out_loss
        if mode == 0:
            spec_loss = self.sys_weight*jac_loss
            if hess_loss is not None:
                spec_loss = spec_loss + self.sys_weight*hess_loss
            return spec_loss
        if mode == 10:
            return self.sys_weight*out_loss + self.sys_weight*jac_loss
        return self.sys_weight*out_loss

    # Boundary condition losses
    def get_bc_loss(self, batch_size: int=64):
        """Sum of the MSE boundary losses on randomly sampled boundary points.

        The terms are:
        - inflow x = 0: parabolic u profile (peak 1.5) scaled by the time ramp, v = 0;
        - walls y = 0 and y = y_max: u = v = 0;
        - outlet x = x_max: p = 0;
        - cylinder surface: u = v = 0.
        """
        # Condition at x=0
        y_bc1 = torch.distributions.Uniform(0., y_max).sample((batch_size,1))
        t_bc1 = torch.distributions.Uniform(0., t_max).sample((batch_size,1))
        pts_bc1 = torch.column_stack((t_bc1, torch.zeros_like(t_bc1), y_bc1))
        # The first part of the boundary conditions is only on u
        pred_bc1 = self.forward(pts_bc1.to(self.device))[:,:2]
        # The true velocity is 0 along y and given by the ramp function along x
        y_in = pts_bc1[:,2].reshape((-1))
        true_bc1 = ramp_function(t_bc1.reshape((-1))+self.t_init)*4*1.5*y_in*(y_max - y_in)/y_max**2
        true_bc1 = true_bc1.reshape((-1,1))
        bc1 = self.loss_container(pred_bc1, torch.column_stack((true_bc1, torch.zeros_like(true_bc1))).to(self.device))

        # Condition at y = 0 or y = y_max
        t_bc2 = torch.distributions.Uniform(0., t_max).sample((2*batch_size,1))
        x_bc2 = torch.distributions.Uniform(0., x_max).sample((2*batch_size,1))
        y_bc2 = torch.concatenate((torch.zeros((batch_size)), y_max*torch.ones((batch_size))))
        pts_bc2 = torch.column_stack((t_bc2, x_bc2, y_bc2))
        pred_bc2 = self.forward(pts_bc2.to(self.device))[:,:2]
        bc2 = self.loss_container(pred_bc2, torch.zeros_like(pred_bc2))

        # Boundary conditions at x = x_max: the pressure must be 0
        y_bc3 = torch.distributions.Uniform(0., y_max).sample((batch_size,1))
        t_bc3 = torch.distributions.Uniform(0., t_max).sample((batch_size,1))
        x_bc3 = x_max*torch.ones_like(t_bc3)
        pts_bc3 = torch.column_stack((t_bc3, x_bc3, y_bc3))
        pred_bc3 = self.forward(pts_bc3.to(self.device))[:,2]
        bc3 = self.loss_container(pred_bc3, torch.zeros_like(pred_bc3))

        # Boundary conditions on the cylinder: velocity should be 0
        t_bc4 = torch.distributions.Uniform(0., t_max).sample((batch_size,1))
        x_bc4 = ball_boundary_uniform(n=batch_size, radius=radius, dim=2, center=center)
        pts_bc4 = torch.column_stack((t_bc4, x_bc4))
        pred_bc4 = self.forward(pts_bc4.to(self.device))[:,:2]
        bc4 = self.loss_container(pred_bc4, torch.zeros_like(pred_bc4))

        return bc1 + bc2 + bc3 + bc4

    # Function for the initial condition loss
    def get_init_loss(self, batch_size: int=128):
        """MSE between the prediction at t = 0 and zero on random points of the full domain."""
        x_init = torch.distributions.Uniform(0., x_max).sample((batch_size,1))
        y_init = torch.distributions.Uniform(0., y_max).sample((batch_size,1))
        t_init = torch.zeros((batch_size, 1))
        pts_init = torch.column_stack((t_init, x_init, y_init)).to(self.device)

        pred_init = self.forward(pts_init)
        return self.loss_container(pred_init, torch.zeros_like(pred_init))

    def loss_fn(self,
        x_pde:torch.Tensor,
        y_pde:torch.Tensor,
        D_pde:torch.Tensor,
        x_init:torch.Tensor,
        y_init:torch.Tensor,
        mode:int,
        H_pde=None,
        x_bc = None,
        y_bc = None
    ) -> torch.Tensor:
        """Total training loss for ``mode``.

        Args:
            x_pde: collocation points, shape (batch, 3).
            y_pde: targets compared with ``forward(x_pde)`` (modes -1, 1, 10).
            D_pde: targets compared with the (batch, 3, 3) Jacobian (modes 0, 10).
            x_init, y_init: initial-condition points and targets.
            mode: -1, 0, 1 or 10 (see the class docstring).
            H_pde: optional targets for the Hessian diagonal (mode 0 only).
            x_bc, y_bc: accepted for interface parity; unused here.

        Raises:
            ValueError: if ``mode`` is not supported.
        """
        self._validate_mode(mode)

        out_pred = self.forward(x_pde)
        mom_loss = div_loss = out_loss = jac_loss = hess_loss = None
        if mode != 0:
            out_loss = self.loss_container(out_pred, y_pde)
        # Jacobians and especially Hessians are expensive; build only what the mode needs.
        if mode != 1:
            out_X = self._jacobian(x_pde)
            if mode == -1:
                mom_pde = self._momentum_residual(out_pred, out_X, self._hessian(x_pde))
                mom_loss = self.loss_container(mom_pde, torch.zeros_like(mom_pde))
                div_pde = self._divergence_residual(out_X)
                div_loss = self.loss_container(div_pde, torch.zeros_like(div_pde))
            else:
                jac_loss = self.loss_container(out_X, D_pde)
                if mode == 0 and H_pde is not None:
                    out_XX = self._hessian(x_pde)
                    hess_loss = self.loss_container(torch.diagonal(out_XX, dim1=2, dim2=3), H_pde)
        spec_loss = self._spec_loss(mode, mom_loss, div_loss, out_loss, jac_loss, hess_loss)

        init_loss = self.loss_container(self.forward(x_init), y_init)
        bc_loss = self.get_bc_loss()
        return spec_loss + self.init_weight*init_loss + self.bc_weight*bc_loss

    def print_losses(self, step:int,
        x_pde:torch.Tensor,
        y_pde:torch.Tensor,
        D_pde:torch.Tensor,
        x_init:torch.Tensor,
        y_init:torch.Tensor,
        mode:int,
        H_pde = None,
        x_bc = None,
        y_bc = None
    ):
        """Print and return every (unweighted) loss component plus the weighted total.

        The total is combined by the same rule as ``loss_fn``. Because the
        boundary points are re-sampled, the total differs from the last
        ``loss_fn`` value only through that sampling.

        Returns:
            (step, mom_loss, div_loss, out_loss, der_loss, init_loss, bc_loss, tot_loss)

        Raises:
            ValueError: if ``mode`` is not supported.
        """
        self._validate_mode(mode)

        out_pred = self.forward(x_pde)
        out_X = self._jacobian(x_pde)
        out_XX = self._hessian(x_pde)

        mom_pde = self._momentum_residual(out_pred, out_X, out_XX)
        mom_loss = self.loss_container(mom_pde, torch.zeros_like(mom_pde))
        div_pde = self._divergence_residual(out_X)
        div_loss = self.loss_container(div_pde, torch.zeros_like(div_pde))
        out_loss = self.loss_container(out_pred, y_pde)
        jac_loss = self.loss_container(out_X, D_pde)
        hess_loss = None
        der_loss = jac_loss
        if H_pde is not None:
            hess_loss = self.loss_container(torch.diagonal(out_XX, dim1=2, dim2=3), H_pde)
            der_loss = jac_loss + hess_loss

        init_loss = self.loss_container(self.forward(x_init), y_init)
        bc_loss = self.get_bc_loss()

        spec_loss = self._spec_loss(mode, mom_loss, div_loss, out_loss, jac_loss, hess_loss)
        tot_loss = spec_loss + self.init_weight*init_loss + self.bc_weight*bc_loss

        print(f'Step: {step}, total loss: {tot_loss}, init loss: {init_loss}, bc loss: {bc_loss}')
        print(f'mom loss: {mom_loss}, div loss: {div_loss}, out loss {out_loss}, der loss: {der_loss}')

        return step, mom_loss, div_loss, out_loss, der_loss, init_loss, bc_loss, tot_loss

    def get_consistencies(self, x_pde):
        """Per-sample L2 norms of the momentum residual and of the divergence."""
        out_pred = self.forward(x_pde)
        out_X = self._jacobian(x_pde)
        out_XX = self._hessian(x_pde)

        mom_pde = self._momentum_residual(out_pred, out_X, out_XX)
        div_pde = self._divergence_residual(out_X).reshape((-1, 1))

        return torch.norm(mom_pde, p=2, dim=1), torch.norm(div_pde, p=2, dim=1)
    
class NSRestrictedPotentialNet(NSPotentialNet):
    def __init__(self,
                 init_weight: float,
                 mom_weight: float,
                 div_weight: float,
                 out_weight: float,
                 der_weight: float,
                 bc_weight: float,
                 hidden_units: list,
                 lr_init: float,
                 device: str,
                 activation: nn.Module=nn.Tanh,
                 last_activation: bool=True,
                 t_init = 8.,
                 *args,
                 **kwargs) -> None:
        """Construct the restricted-domain potential network.

        Every argument is forwarded unchanged to ``NSPotentialNet.__init__``.
        The only difference from the parent's signature is the default
        ``last_activation=True``.
        """
        parent_config = {
            'init_weight': init_weight,
            'mom_weight': mom_weight,
            'div_weight': div_weight,
            'der_weight': der_weight,
            'out_weight': out_weight,
            'bc_weight': bc_weight,
            'hidden_units': hidden_units,
            'lr_init': lr_init,
            'device': device,
            'activation': activation,
            'last_activation': last_activation,
            't_init': t_init,
        }
        super().__init__(*args, **parent_config, **kwargs)
    
    def get_bc_loss(self, x_bc:torch.Tensor, y_bc:torch.Tensor, batch_size: int=128):
        """Boundary loss for the restricted domain.

        Sum of three MSE terms:
        - a supervised fit of ``forward(x_bc)`` to ``y_bc``;
        - no-slip (u = v = 0) on the walls y = 0 and y = y_max;
        - zero pressure on the outlet x = x_max - restricted_x.

        NOTE: the ``batch_size`` argument is ignored; the number of sampled
        wall/outlet points is taken from ``x_bc.shape[0]``.
        NOTE(review): wall and outlet samples use x in [0, x_max - restricted_x],
        whereas ``get_init_loss`` samples x in [restricted_x, x_max]. Confirm
        which coordinate range the restricted domain is meant to use.
        """
        n = x_bc.shape[0]
        dev = self.device

        # Supervised boundary data
        data_loss = self.loss_container(self.forward(x_bc.to(dev)), y_bc.to(dev))

        # No-slip on the walls: first n points at y = 0, next n at y = y_max
        x_right = x_max - restricted_x
        wall_t = torch.distributions.Uniform(0., t_max).sample((2*n, 1))
        wall_x = torch.distributions.Uniform(0., x_right).sample((2*n, 1))
        wall_y = torch.concatenate((torch.zeros(n), y_max*torch.ones(n)))
        wall_vel = self.forward(torch.column_stack((wall_t, wall_x, wall_y)).to(dev))[:, :2]
        wall_loss = self.loss_container(wall_vel, torch.zeros_like(wall_vel))

        # Zero pressure on the outlet x = x_max - restricted_x
        out_y = torch.distributions.Uniform(0., y_max).sample((n, 1))
        out_t = torch.distributions.Uniform(0., t_max).sample((n, 1))
        out_x = x_right*torch.ones_like(out_t)
        out_p = self.forward(torch.column_stack((out_t, out_x, out_y)).to(dev))[:, 2]
        outlet_loss = self.loss_container(out_p, torch.zeros_like(out_p))

        return data_loss + wall_loss + outlet_loss
    
    def get_init_loss(self, batch_size: int=128):
        """MSE between the prediction at t = 0 and zero on random points.

        Samples ``batch_size`` points with x ~ U(restricted_x, x_max) and
        y ~ U(0, y_max).
        """
        xs = torch.distributions.Uniform(restricted_x, x_max).sample((batch_size, 1))
        ys = torch.distributions.Uniform(0., y_max).sample((batch_size, 1))
        ts = torch.zeros((batch_size, 1))
        prediction = self.forward(torch.column_stack((ts, xs, ys)).to(self.device))
        return self.loss_container(prediction, torch.zeros_like(prediction))
    
    def loss_fn(self,
        x_pde:torch.Tensor,
        y_pde:torch.Tensor,
        D_pde:torch.Tensor,
        x_init:torch.Tensor,
        y_init:torch.Tensor,
        mode:int,
        H_pde=None,
        x_bc = None,
        y_bc = None
    ) -> torch.Tensor:
        """Total training loss on the restricted domain.

        Modes: -1 PINN (momentum + divergence residuals + output supervision),
        0 derivative supervision (Jacobian, plus Hessian diagonal if ``H_pde``),
        1 output supervision, 10 output + Jacobian supervision. All supervised
        terms use ``self.sys_weight``.

        Args:
            x_pde: collocation points.
            y_pde: targets compared with ``forward(x_pde)``.
            D_pde: targets compared with the per-sample Jacobian (modes 0, 10).
            x_init, y_init: initial-condition points and targets.
            mode: -1, 0, 1 or 10.
            H_pde: optional targets for the Hessian diagonal (mode 0 only).
            x_bc, y_bc: boundary data forwarded to ``get_bc_loss`` (required there).

        Raises:
            ValueError: if ``mode`` is not supported.
        """
        if mode not in [-1, 0, 1, 10]:
            raise ValueError(
                'mode should be either 0 for derivative learning, 1 for output learning, '
                f'10 for output + derivative learning or -1 for PINN learning, but found {mode}')

        out_pred = self.forward(x_pde)

        # Jacobians and especially Hessians are expensive; each branch builds
        # only what it needs.
        if mode == -1:
            out_X = vmap(jacrev(self.forward_single))(x_pde)
            out_XX = vmap(hessian(self.forward_single))(x_pde)
            lapl_u = torch.diagonal(out_XX[:,:2,1:,1:], dim1=2, dim2=3).sum(dim=2)
            mom_pde = out_X[:,:2,0] - mu*lapl_u + torch.einsum('bij,bj->bi', out_X[:,:2,1:], out_pred[:,:2]) + out_X[:,-1,1:]
            div_pde = torch.einsum('bii->b', out_X[:,:2,1:3])

            mom_loss = self.loss_container(mom_pde, torch.zeros_like(mom_pde))
            div_loss = self.loss_container(div_pde, torch.zeros_like(div_pde))
            spec_loss = self.mom_weight*mom_loss + self.div_weight*div_loss
            # Output supervision is added on top of the residuals
            spec_loss = spec_loss + self.sys_weight*self.loss_container(out_pred, y_pde)

        elif mode == 0:
            out_X = vmap(jacrev(self.forward_single))(x_pde)
            spec_loss = self.sys_weight*self.loss_container(out_X, D_pde)
            if H_pde is not None:
                out_XX = vmap(hessian(self.forward_single))(x_pde)
                spec_loss = spec_loss + self.sys_weight*self.loss_container(torch.diagonal(out_XX, dim1=2, dim2=3), H_pde)

        elif mode == 10:
            out_X = vmap(jacrev(self.forward_single))(x_pde)
            spec_loss = self.sys_weight*self.loss_container(out_pred, y_pde) + self.sys_weight*self.loss_container(out_X, D_pde)

        else:
            spec_loss = self.sys_weight*self.loss_container(out_pred, y_pde)

        init_loss = self.loss_container(self.forward(x_init), y_init)
        bc_loss = self.get_bc_loss(x_bc=x_bc, y_bc=y_bc)
        return spec_loss + self.init_weight*init_loss + self.bc_weight*bc_loss
    
    def print_losses(self, step:int,
        x_pde:torch.Tensor,
        y_pde:torch.Tensor,
        D_pde:torch.Tensor,
        x_init:torch.Tensor,
        y_init:torch.Tensor,
        mode:int,
        H_pde = None,
        x_bc = None,
        y_bc = None
    ):
        """Evaluate every loss component, print them, and return them.

        All components are computed regardless of ``mode``; ``mode`` only
        selects which of them make up ``tot_loss``, combined with the same
        weights ``loss_fn`` uses for that mode:

        * ``-1`` (PINN): ``mom_weight*mom + div_weight*div + sys_weight*out``
        * ``0`` (derivative): ``sys_weight*der`` (Jacobian term, plus the
          Hessian-diagonal term when ``H_pde`` is given)
        * ``10`` (derivative + output): ``sys_weight*out + sys_weight*jac``
          (Jacobian term only)
        * ``1`` (output): ``sys_weight*out``

        ``init_weight*init_loss + bc_weight*bc_loss`` is added in every mode.

        Args:
            step: Training step, only echoed in the printout and the result.
            x_pde: Collocation points.
            y_pde: Target outputs at ``x_pde``.
            D_pde: Target Jacobians at ``x_pde`` (compared to the vmapped
                ``jacrev`` of ``forward_single``).
            x_init: Initial-condition points.
            y_init: Target outputs at ``x_init``.
            mode: One of -1, 0, 1, 10 (see above).
            H_pde: Optional targets for the Hessian diagonal.
            x_bc, y_bc: Forwarded to ``self.get_bc_loss``.

        Returns:
            ``(step, mom_loss, div_loss, out_loss, der_loss, init_loss,
            bc_loss, tot_loss)``. The component losses are unweighted;
            ``der_loss`` includes the Hessian-diagonal term when ``H_pde`` is
            given.

        Raises:
            ValueError: If ``mode`` is not one of -1, 0, 1, 10.
        """
        # Check that the mode parameter is correct
        if mode not in [-1,0,1,10]:
            raise ValueError(f'mode should be either 0 for derivative learning,\
                1 for output learning or -1 for PINN learning, but found {mode}')

        # Network output and its input derivatives (Jacobian and Hessian)
        out_pred = self.forward(x_pde)
        out_X = vmap(jacrev(self.forward_single))(x_pde)
        out_XX = vmap(hessian(self.forward_single))(x_pde)

        # Momentum residual: du/dt - mu*lapl(u) + (grad u) u + grad p.
        # Input index 0 is treated as time and 1: as space (presumably
        # inputs are ordered (t, x, y) -- confirm); output columns :2 are the
        # velocity and the last column the pressure.
        lapl_u = torch.diagonal(out_XX[:,:2,1:,1:], dim1=2, dim2=3).sum(dim=2)
        mom_pde = out_X[:,:2,0] - mu*lapl_u + torch.einsum('bij,bj->bi', out_X[:,:2,1:], out_pred[:,:2]) + out_X[:,-1,1:]
        mom_loss = self.loss_container(mom_pde, torch.zeros_like(mom_pde))

        # Divergence residual: trace of the spatial velocity Jacobian
        div_pde = torch.einsum('bii->b', out_X[:,:2,1:3])
        div_loss = self.loss_container(div_pde, torch.zeros_like(div_pde))

        # Supervised losses, kept unweighted so sys_weight is applied once below
        out_loss = self.loss_container(out_pred, y_pde)
        jac_loss = self.loss_container(out_X, D_pde)
        der_loss = jac_loss
        if H_pde is not None:
            # New tensor (not `+=`) so jac_loss stays Jacobian-only for mode 10
            der_loss = jac_loss + self.loss_container(torch.diagonal(out_XX, dim1=2, dim2=3), H_pde)

        # Initial-condition and boundary-condition losses
        y_init_pred = self.forward(x_init)
        init_loss = self.loss_container(y_init_pred, y_init)
        bc_loss = self.get_bc_loss(x_bc=x_bc, y_bc=y_bc)

        # Combine the components exactly as loss_fn does for this mode
        if mode == -1:
            spec_loss = self.mom_weight*mom_loss + self.div_weight*div_loss + self.sys_weight*out_loss
        elif mode == 0:
            spec_loss = self.sys_weight*der_loss
        elif mode == 10:
            spec_loss = self.sys_weight*out_loss + self.sys_weight*jac_loss
        else:
            spec_loss = self.sys_weight*out_loss

        # Total loss
        tot_loss = spec_loss + self.init_weight*init_loss + self.bc_weight*bc_loss

        print(f'Step: {step}, total loss: {tot_loss}, init loss: {init_loss}, bc loss: {bc_loss}')
        print(f'mom loss: {mom_loss}, div loss: {div_loss}, out loss {out_loss}, der loss: {der_loss}')

        return step, mom_loss, div_loss, out_loss, der_loss, init_loss, bc_loss, tot_loss
    
    
    
class NSRestrictedNet(NSNet):
    def __init__(self,
                 init_weight: float,
                 mom_weight: float,
                 div_weight: float,
                 sys_weight: float,
                 bc_weight: float,
                 hidden_units: list,
                 lr_init: float,
                 device: str,
                 activation: nn.Module=nn.Tanh,
                 last_activation: bool=True,
                 t_init = 8.,
                 *args,
                 **kwargs) -> None:
        """Create the restricted-domain network by delegating to ``NSNet``.

        Every named setting is forwarded to the parent constructor by keyword;
        extra positional and keyword arguments are passed through unchanged.
        ``last_activation`` defaults to True in this subclass.
        """
        parent_settings = dict(
            init_weight=init_weight,
            mom_weight=mom_weight,
            div_weight=div_weight,
            sys_weight=sys_weight,
            bc_weight=bc_weight,
            hidden_units=hidden_units,
            lr_init=lr_init,
            device=device,
            activation=activation,
            last_activation=last_activation,
            t_init=t_init,
        )
        # Positional extras are bound first, exactly as in `f(k=v, *args)`
        super().__init__(*args, **parent_settings, **kwargs)
        
    def get_bc_loss(self, x_bc:torch.Tensor, y_bc:torch.Tensor, batch_size=None):
        """Boundary-condition loss on the restricted channel.

        Returns the sum of three unweighted terms:

        1. fit of the network output to ``y_bc`` at the given points ``x_bc``;
        2. output columns 0 and 1 driven to 0 on the walls ``y = 0`` and
           ``y = y_max``, at random ``t`` in ``[0, t_max)`` and ``x`` in
           ``[0, x_max - restricted_x)``;
        3. output column 2 (the pressure) driven to 0 on the outlet
           ``x = x_max - restricted_x``, at random ``t`` and ``y``.

        The points for terms 2 and 3 are redrawn on every call, with columns
        ordered ``(t, x, y)``.

        Args:
            x_bc: Input points for term 1.
            y_bc: Target outputs at ``x_bc``.
            batch_size: Number of random points on the outlet and on each
                wall. ``None`` (default) uses ``x_bc.shape[0]``.

        Returns:
            Scalar loss tensor ``bc1 + bc2 + bc3``.
        """
        if batch_size is None:
            batch_size = x_bc.shape[0]

        # 1. Supervised fit on the provided boundary points
        pred_bc1 = self.forward(x_bc.to(self.device))
        bc1 = self.loss_container(pred_bc1, y_bc.to(self.device))

        # 2. Walls: first batch_size points at y = 0, the next batch_size at y = y_max
        t_bc2 = torch.distributions.Uniform(0., t_max).sample((2*batch_size,1))
        x_bc2 = torch.distributions.Uniform(0., x_max-restricted_x).sample((2*batch_size,1))
        y_bc2 = torch.concatenate((torch.zeros((batch_size)), y_max*torch.ones((batch_size))))
        pts_bc2 = torch.column_stack((t_bc2, x_bc2, y_bc2))
        pred_bc2 = self.forward(pts_bc2.to(self.device))[:,:2]
        # Velocity components must vanish on the walls
        bc2 = self.loss_container(pred_bc2, torch.zeros_like(pred_bc2))

        # 3. Outlet at x = x_max - restricted_x
        y_bc3 = torch.distributions.Uniform(0., y_max).sample((batch_size,1))
        t_bc3 = torch.distributions.Uniform(0., t_max).sample((batch_size,1))
        x_bc3 = (x_max-restricted_x)*torch.ones_like(t_bc3)
        pts_bc3 = torch.column_stack((t_bc3, x_bc3, y_bc3))
        pred_bc3 = self.forward(pts_bc3.to(self.device))[:,2]
        # The pressure must be 0
        bc3 = self.loss_container(pred_bc3, torch.zeros_like(pred_bc3))

        return bc1 + bc2 + bc3
    
    def multioutput_sobolev_error(self, x_pde:torch.Tensor, D_pde:torch.Tensor, n_proj: int=100):
        """Mean error of randomly projected Jacobians.

        Draws ``n_proj`` random unit vectors ``r`` in output space. At every
        point of ``x_pde`` it compares the Jacobian of ``r . f(x)`` (computed
        with ``jacfwd`` through ``forward_single``) with ``r^T D``. It then
        takes the Euclidean norm of the difference over the input coordinates
        and averages over points and directions.

        The directions are redrawn on every call, so the result is stochastic.

        Args:
            x_pde: Batch of input points; the computation is vmapped over dim 0.
            D_pde: Target Jacobians indexed ``[point, output, input]``, per the
                einsum below. The output dimension is read from ``D_pde.shape[1]``.
            n_proj: Number of random projection directions.

        Returns:
            Scalar tensor with the mean projected-derivative error.
        """
        out_dim = D_pde.shape[1]
        # Unit vectors in output space, one per row: (n_proj, out_dim)
        rand_vec = ball_boundary_uniform(n=n_proj, radius=1., dim=out_dim, center=torch.zeros(out_dim)).to(self.device)

        def rand_proj(x):
            # Project the network output onto every random direction: (n_proj,)
            return torch.einsum('pj,j->p', rand_vec, self.forward_single(x))

        rand_proj_der_pred = vmap(jacfwd(rand_proj))(x_pde)
        rand_proj_der_true = torch.einsum('bij,pi->bpj', D_pde, rand_vec)
        return torch.norm(rand_proj_der_pred - rand_proj_der_true, p=2, dim=2).mean()
    
    def loss_fn(self,
        x_pde:torch.Tensor,
        y_pde:torch.Tensor,
        D_pde:torch.Tensor,
        x_init:torch.Tensor,
        y_init:torch.Tensor,
        mode:str,
        H_pde=None,
        x_bc = None,
        y_bc = None
    ) -> torch.Tensor:
        """Training objective for the selected learning mode.

        The mode-specific term is:

        * ``'PINN'``: ``mom_weight*mom + div_weight*div``, computed from the
          Navier-Stokes momentum and divergence residuals of the network.
        * ``'Derivative'``: ``sys_weight`` times the Jacobian fit to
          ``D_pde``, plus ``sys_weight`` times the Hessian-diagonal fit to
          ``H_pde`` when it is given.
        * ``'Sobolev'``: ``sys_weight*out + sys_weight*jac``, i.e. the output
          fit plus the Jacobian fit.
        * ``'Output'``: ``sys_weight`` times the output fit to ``y_pde``.

        ``init_weight*init_loss + bc_weight*bc_loss`` is always added. Here
        ``bc_loss`` is a plain fit of ``forward(x_bc)`` to ``y_bc``.

        Args:
            x_pde: Collocation points.
            y_pde: Target outputs at ``x_pde``.
            D_pde: Target Jacobians at ``x_pde``.
            x_init: Initial-condition points.
            y_init: Target outputs at ``x_init``.
            mode: One of ``'PINN'``, ``'Derivative'``, ``'Output'``, ``'Sobolev'``.
            H_pde: Optional targets for the Hessian diagonal (``'Derivative'`` only).
            x_bc: Boundary points; required, since they are passed to ``forward``.
            y_bc: Target outputs at ``x_bc``; required.

        Returns:
            Scalar total loss tensor.

        Raises:
            ValueError: If ``mode`` is not a recognised mode.
        """
        # Check that the mode parameter is correct
        modes = ['PINN', 'Derivative', 'Output', 'Sobolev']
        if mode not in modes:
            raise ValueError(f'mode should be in {modes}, but found {mode}')

        if mode == 'PINN':
            out_pred = self.forward(x_pde)
            out_X = vmap(jacrev(self.forward_single))(x_pde)
            out_XX = vmap(hessian(self.forward_single))(x_pde)
            # Momentum residual: du/dt - mu*lapl(u) + (grad u) u + grad p.
            # Input 0 is time and inputs 1: are space; output columns :2 are
            # the velocity and the last column the pressure.
            lapl_u = torch.diagonal(out_XX[:,:2,1:,1:], dim1=2, dim2=3).sum(dim=2)
            mom_pde = out_X[:,:2,0] - mu*lapl_u + torch.einsum('bij,bj->bi', out_X[:,:2,1:], out_pred[:,:2]) + out_X[:,-1,1:]
            # Divergence residual: trace of the spatial velocity Jacobian
            div_pde = torch.einsum('bii->b', out_X[:,:2,1:3])

            mom_loss = self.loss_container(mom_pde, torch.zeros_like(mom_pde))
            div_loss = self.loss_container(div_pde, torch.zeros_like(div_pde))
            spec_loss = self.mom_weight*mom_loss + self.div_weight*div_loss

        elif mode == 'Derivative':
            # Supervise the network Jacobian (and optionally the Hessian diagonal)
            out_X = vmap(jacrev(self.forward_single))(x_pde)
            spec_loss = self.sys_weight*self.loss_container(out_X, D_pde)
            if H_pde is not None:
                out_XX = vmap(hessian(self.forward_single))(x_pde)
                spec_loss += self.sys_weight*self.loss_container(torch.diagonal(out_XX, dim1=2, dim2=3), H_pde)

        elif mode == 'Sobolev':
            # Output fit + Jacobian fit
            out_pred = self.forward(x_pde)
            out_X = vmap(jacrev(self.forward_single))(x_pde)
            out_loss = self.loss_container(out_pred, y_pde)
            der_loss = self.loss_container(out_X, D_pde)
            # Each term is weighted exactly once by sys_weight
            spec_loss = self.sys_weight*out_loss + self.sys_weight*der_loss

        else:
            # 'Output': plain supervised fit of the prediction
            out_pred = self.forward(x_pde)
            spec_loss = self.sys_weight*self.loss_container(out_pred, y_pde)

        # Initial-condition loss
        y_init_pred = self.forward(x_init)
        init_loss = self.loss_container(y_init_pred, y_init)
        # Boundary loss: direct fit on the provided boundary points
        y_bc_pred = self.forward(x_bc)
        bc_loss = self.loss_container(y_bc_pred, y_bc)

        tot_loss = spec_loss + self.init_weight*init_loss + self.bc_weight*bc_loss
        return tot_loss
    
    def print_losses(self, step:int,
        x_pde:torch.Tensor,
        y_pde:torch.Tensor,
        D_pde:torch.Tensor,
        x_init:torch.Tensor,
        y_init:torch.Tensor,
        mode:str,
        H_pde = None,
        x_bc = None,
        y_bc = None,
        print_to_screen: bool=True
    ):
        """Evaluate every loss component, optionally print them, and return them.

        All components are computed regardless of ``mode``; ``mode`` only
        selects which of them make up ``tot_loss``, combined with the same
        weights as the training objective:

        * ``'PINN'``: ``mom_weight*mom + div_weight*div``
        * ``'Derivative'``: ``sys_weight*der`` (Jacobian term, plus the
          Hessian-diagonal term when ``H_pde`` is given)
        * ``'Sobolev'``: ``sys_weight*out + sys_weight*jac`` (Jacobian term
          only, as in ``loss_fn``)
        * ``'Output'``: ``sys_weight*out``

        ``init_weight*init_loss + bc_weight*bc_loss`` is added in every mode.

        Args:
            step: Training step, only echoed in the printout and the result.
            x_pde, y_pde, D_pde: Collocation points, target outputs, target Jacobians.
            x_init, y_init: Initial-condition points and targets.
            mode: One of ``'PINN'``, ``'Derivative'``, ``'Output'``, ``'Sobolev'``.
            H_pde: Optional targets for the Hessian diagonal.
            x_bc, y_bc: Boundary points and targets (fit directly).
            print_to_screen: Print the losses when True.

        Returns:
            ``(step, mom_loss, div_loss, out_loss, der_loss, init_loss,
            bc_loss, tot_loss)``. The component losses are unweighted;
            ``der_loss`` includes the Hessian-diagonal term when ``H_pde`` is
            given.

        Raises:
            ValueError: If ``mode`` is not a recognised mode.
        """
        # Check that the mode parameter is correct
        modes = ['PINN', 'Derivative', 'Output', 'Sobolev']
        if mode not in modes:
            raise ValueError(f'mode should be in {modes}, but found {mode}')

        # Network output and its input derivatives (Jacobian and Hessian)
        out_pred = self.forward(x_pde)
        out_X = vmap(jacrev(self.forward_single))(x_pde)
        out_XX = vmap(hessian(self.forward_single))(x_pde)

        # Momentum residual: du/dt - mu*lapl(u) + (grad u) u + grad p
        lapl_u = torch.diagonal(out_XX[:,:2,1:,1:], dim1=2, dim2=3).sum(dim=2)
        mom_pde = out_X[:,:2,0] - mu*lapl_u + torch.einsum('bij,bj->bi', out_X[:,:2,1:], out_pred[:,:2]) + out_X[:,-1,1:]
        mom_loss = self.loss_container(mom_pde, torch.zeros_like(mom_pde))

        # Divergence residual: trace of the spatial velocity Jacobian
        div_pde = torch.einsum('bii->b', out_X[:,:2,1:3])
        div_loss = self.loss_container(div_pde, torch.zeros_like(div_pde))

        # Supervised derivative loss; the Jacobian part is kept separately
        # because the Sobolev objective does not include the Hessian term
        jac_loss = self.loss_container(out_X, D_pde)
        der_loss = jac_loss
        if H_pde is not None:
            der_loss = jac_loss + self.loss_container(torch.diagonal(out_XX, dim1=2, dim2=3), H_pde)

        # Supervised output loss
        out_loss = self.loss_container(out_pred, y_pde)

        # Initial-condition loss
        y_init_pred = self.forward(x_init)
        init_loss = self.loss_container(y_init_pred, y_init)
        # Boundary loss: direct fit on the provided boundary points
        y_bc_pred = self.forward(x_bc)
        bc_loss = self.loss_container(y_bc_pred, y_bc)

        if mode == 'PINN':
            spec_loss = self.mom_weight*mom_loss + self.div_weight*div_loss
        elif mode == 'Derivative':
            spec_loss = self.sys_weight*der_loss
        elif mode == 'Sobolev':
            spec_loss = self.sys_weight*out_loss + self.sys_weight*jac_loss
        else:
            spec_loss = self.sys_weight*out_loss

        # Total loss
        tot_loss = spec_loss + self.init_weight*init_loss + self.bc_weight*bc_loss
        if print_to_screen:
            print(f'Step: {step}, total loss: {tot_loss}, init loss: {init_loss}, bc loss: {bc_loss}')
            print(f'mom loss: {mom_loss}, div loss: {div_loss}, out loss {out_loss}, der loss: {der_loss}')

        return step, mom_loss, div_loss, out_loss, der_loss, init_loss, bc_loss, tot_loss
    