import torch


class Generator(torch.nn.Module):
    """Learn a coordinate change x -> (P, Q) together with a function H of P.

    ``differentiable_model_1`` (stored as ``step_1``) is called on ``x``; element
    0 of its result is used as ``P`` and element 1 as ``Q``.
    ``differentiable_model_2`` (stored as ``step_2``) is called on ``P`` and its
    (summed) output is treated as the Hamiltonian ``H``.

    ``get_loss`` penalizes, for the first ``test_dim`` components of P and Q,
    deviations from P_dot = 0, {H, P} = 0, dH/dP = Q_dot and the canonical
    Poisson-bracket relations.

    NOTE(review): inputs ``x`` are assumed to be (batch, input_dim) and must
    have ``requires_grad=True`` because gradients are taken w.r.t. ``x``.
    """

    def __init__(self, input_dim, differentiable_model_1, differentiable_model_2, test_dim):
        super().__init__()

        self.step_1 = differentiable_model_1
        self.step_2 = differentiable_model_2
        self.input_dim = input_dim
        self.test_dim = test_dim
        # Levi-Civita / symplectic permutation matrix. It is created on the CPU in
        # the default dtype and converted to each operand's device/dtype at use
        # time (see _m_like), so GPU and float64 inputs work.
        self.M = self.permutation_tensor(input_dim)

    def _m_like(self, ref):
        """Return ``self.M`` on the same device and with the same dtype as ``ref``."""
        return self.M.to(device=ref.device, dtype=ref.dtype)

    def forward_1(self, x):
        """Apply ``step_1`` to ``x``; element 0 of the result is P, element 1 is Q."""
        return self.step_1(x)

    def forward_2(self, x):
        """Apply ``step_2`` (H as a function of P) to ``x``."""
        return self.step_2(x)

    def forward_12(self, x):
        """Evaluate H in the original coordinates: ``step_2(step_1(x)[0])``."""
        return self.step_2(self.step_1(x)[0])

    def time_derivative(self, x):
        """Return the Hamiltonian vector field ``(dH/dx) @ M^T`` at ``x``.

        With M built by ``permutation_tensor``, a row vector ``v`` maps to
        ``(v[n//2:], -v[:n//2])``. The graph is kept (``create_graph=True``) so
        the result can be used inside a training loss.
        """
        F1 = self.forward_12(x)  # traditional forward pass
        dF1 = torch.autograd.grad(F1.sum(), x, create_graph=True)[0]
        return dF1 @ self._m_like(dF1).t()

    def time_derivative_P(self, x):
        """Return dH/dP evaluated at P = step_1(x)[0].

        P is detached before differentiation, so the returned gradient does
        not carry gradients back into ``step_1`` through P; it stays
        differentiable with respect to ``step_2``.
        """
        P = self.forward_1(x)[0].clone().detach().requires_grad_(True)
        F1 = self.forward_2(P)
        return torch.autograd.grad(F1.sum(), P, create_graph=True)[0]

    def derivative_P_M(self, x, i, t=None):
        """Return ``(dP_i/dx) @ M`` (``t`` is unused; kept for signature parity)."""
        dF1 = self.derivative_P(x, i)
        return dF1 @ self._m_like(dF1)

    def derivative_P(self, x, i, t=None):
        """Return dP_i/dx, with P = step_1(x)[0] (``t`` is unused)."""
        F1 = self.step_1(x)[0][:, i]
        return torch.autograd.grad(F1.sum(), x, create_graph=True)[0]

    def derivative_Q_M(self, x, i, t=None):
        """Return ``(dQ_i/dx) @ M`` (``t`` is unused).

        Mirrors ``derivative_P_M``. The previous version omitted the ``@ M``, so
        the {Q_i, Q_j} term in ``get_loss`` measured a plain dot product rather
        than a Poisson bracket.
        """
        dF1 = self.derivative_Q(x, i)
        return dF1 @ self._m_like(dF1)

    def derivative_Q(self, x, i, t=None):
        """Return dQ_i/dx, with Q = step_1(x)[1] (``t`` is unused)."""
        F1 = self.step_1(x)[1][:, i]
        return torch.autograd.grad(F1.sum(), x, create_graph=True)[0]

    def get_loss(self, x, dx, alpha_hpq_P=0.01, alpha_hpq_Q=0.01, alpha_poisson=0.01):
        """Weighted sum of the conservation, Hamilton-equation and Poisson losses.

        Args:
            x: sample points (assumed (batch, input_dim), requires_grad=True).
            dx: time derivatives of ``x`` at the same points (same number of
                elements as ``x``).
            alpha_hpq_P: weight of the P_dot = 0 and {H, P_i} = 0 terms.
            alpha_hpq_Q: weight of the dH/dP = Q_dot term.
            alpha_poisson: weight of the canonical Poisson-bracket terms.
                Cross brackets are pushed to 0 and |{Q_i, P_i}| to 1.

        A term whose weight is exactly 0 is not computed at all.
        """
        n = x.size(0)
        k = self.test_dim

        def bdot(a, b):
            # Per-sample dot product; equals bmm((n,1,D), (n,D,1)) flattened to (n,).
            return (a.reshape(n, self.input_dim) * b.reshape(n, self.input_dim)).sum(dim=1)

        # Each derivative_* call runs a full step_1 forward pass plus an autograd
        # call, so compute the per-component gradients once and reuse them.
        need_P = alpha_hpq_P != 0 or alpha_poisson != 0
        need_Q = alpha_hpq_Q != 0 or alpha_poisson != 0
        dP = [self.derivative_P(x, i) for i in range(k)] if need_P else []
        dQ = [self.derivative_Q(x, i) for i in range(k)] if need_Q else []

        loss_hpq_P = 0.
        if alpha_hpq_P != 0:
            #####################  P_dot=0  #####################
            for g in dP:
                loss_hpq_P += bdot(g, dx).pow(2).mean()

            #####################  {H,P}=0 explicitly using (p,q) derivatives  #####################
            xdot = self.time_derivative(x)  # loop-invariant; computed once
            for g in dP:
                loss_hpq_P += bdot(xdot, g).pow(2).mean()

        #####################  \partial H/\partial P =Q_dot  #####################
        loss_hpq_Q = 0.
        if alpha_hpq_Q != 0:
            # Built with stack (not an in-place CPU zeros buffer) so it follows
            # x's device and dtype.
            if dQ:
                q_dot = torch.stack([bdot(g, dx) for g in dQ], dim=1)
            else:
                q_dot = x.new_zeros(n, 0)
            loss_hpq_Q += (q_dot - self.time_derivative_P(x)[:, :k]).pow(2).mean()

        loss_poisson = 0.
        if alpha_poisson != 0:
            dPM = [g @ self._m_like(g) for g in dP]
            dQM = [g @ self._m_like(g) for g in dQ]

            # {P_i, P_j} = 0 for i < j
            for i in range(k):
                for j in range(i + 1, k):
                    loss_poisson += bdot(dP[i], dPM[j]).pow(2).mean()

            # {Q_i, Q_j} = 0 for i < j
            for i in range(k):
                for j in range(i + 1, k):
                    loss_poisson += bdot(dQ[i], dQM[j]).pow(2).mean()

            # {Q_i, P_j} = 0 for i != j
            for i in range(k):
                for j in range(k):
                    if i != j:
                        loss_poisson += bdot(dQ[i], dPM[j]).pow(2).mean()

            # |{Q_i, P_i}| = 1; abs() because the M convention can flip the sign
            for i in range(k):
                loss_poisson += (bdot(dQ[i], dPM[i]).abs() - 1).pow(2).mean()

        return alpha_hpq_Q * loss_hpq_Q + loss_hpq_P * alpha_hpq_P + alpha_poisson * loss_poisson

    def permutation_tensor(self, n):
        """Return the n x n matrix whose rows are I[n//2:] followed by -I[:n//2].

        For even n this is the block matrix [[0, I], [-I, 0]].
        NOTE(review): for odd n the result is not antisymmetric; callers are
        presumably expected to use an even input_dim.
        """
        eye = torch.eye(n)
        return torch.cat([eye[n // 2:], -eye[:n // 2]])
    

