
import torch



class pq_PQ(torch.nn.Module):
    """Encode an input into a pair of latent tensors (P, Q) using two MLP branches.

    Layout when ``num_hidden != 0``: each branch owns ``linear1_<tag>`` ..
    ``linear6_<tag>`` (tag ``P`` or ``Q``).
    - Layers 1-5 are always constructed, so the state_dict keys do not depend
      on ``num_hidden``.
    - The forward pass applies ``tanh(linear1)`` and then
      ``tanh(linear2..linear{min(num_hidden, 5)})``, followed by the linear
      head ``linear6_<tag>``.
    - All six weights are orthogonally initialised.

    Args:
        input_dim: feature width expected by the first layer.
        hidden_dim: width of the hidden layers.
        latent_dim_p: output width of the P branch.
        latent_dim_q: output width of the Q branch.
        num_hidden: ``0`` selects a single ``Linear`` per branch (no tanh, no
            orthogonal init, no momentum overrides). Any other value sets the
            number of tanh layers used, effectively capped at 5.
        momentum: if True (and ``num_hidden != 0``), overwrite ``P[:, 0]`` with
            ``x[:, 4] + x[:, 5]`` and ``P[:, 1]`` with ``x[:, 6] + x[:, 7]``.
        angular_momentum: if True (and ``num_hidden != 0``), overwrite
            ``P[:, 2]`` with ``(x0 - x1) * (x6 - x7) - (x2 - x3) * (x4 - x5)``.
        HPQ_trainable: if False (and ``num_hidden != 0``), no Q branch is
            built, ``forward`` returns ``(P, None)``, and the P head has width
            ``latent_dim_q``.

    Notes:
        - With ``num_hidden == 0`` both branches are built and returned
          whatever ``HPQ_trainable`` is.
        - NOTE(review): the overrides index ``x[:, k]``. They presumably
          expect a 2-D ``(batch, >=8)`` input whose columns 0-3 / 4-7 hold
          the positions / momenta of a two-body system. TODO confirm with the
          caller.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim_p, latent_dim_q, num_hidden=2, momentum=False, angular_momentum=False, HPQ_trainable=True):
        super().__init__()

        if num_hidden == 0:
            # Shallow variant: one affine map per branch, default PyTorch init.
            self.linear1_P = torch.nn.Linear(input_dim, latent_dim_p)
            self.linear1_Q = torch.nn.Linear(input_dim, latent_dim_q)
        else:
            # NOTE(review): without a trainable Q branch the P head is sized by
            # latent_dim_q rather than latent_dim_p -- presumably intentional; confirm.
            head_p = latent_dim_p if HPQ_trainable else latent_dim_q
            self._build_branch("P", input_dim, hidden_dim, head_p)
            if HPQ_trainable:
                self._build_branch("Q", input_dim, hidden_dim, latent_dim_q)

        self.nonlinearity = torch.tanh
        self.num_hidden = num_hidden
        self.momentum = momentum
        self.angular_momentum = angular_momentum
        self.HPQ_trainable = HPQ_trainable

    def _build_branch(self, tag, input_dim, hidden_dim, out_dim):
        """Register ``linear1_<tag>`` .. ``linear6_<tag>`` and orthogonally init their weights.

        All six layers are constructed first and re-initialised afterwards.
        The global RNG is therefore consumed as six default initialisations
        followed by six orthogonal ones.
        """
        widths = [input_dim] + [hidden_dim] * 5 + [out_dim]
        created = []
        for pos, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            layer = torch.nn.Linear(fan_in, fan_out)
            # nn.Module.__setattr__ registers the layer as a named submodule.
            setattr(self, f"linear{pos}_{tag}", layer)
            created.append(layer)
        for layer in created:
            torch.nn.init.orthogonal_(layer.weight)

    def _run_branch(self, tag, x):
        """Apply the tanh hidden stack of branch ``tag`` to ``x``, then its linear head."""
        act = self.nonlinearity(getattr(self, f"linear1_{tag}")(x))
        for pos in range(2, 6):
            # Layer `pos` is used only while num_hidden > pos - 1. The thresholds
            # only grow, so the first failing check ends the stack.
            if not (self.num_hidden > pos - 1):
                break
            act = self.nonlinearity(getattr(self, f"linear{pos}_{tag}")(act))
        return getattr(self, f"linear6_{tag}")(act)

    def forward(self, x):
        """Return ``(P, Q)``, or ``(P, None)`` when the deep model has no Q branch."""
        if self.num_hidden == 0:
            q_latent = self.linear1_Q(x)
            p_latent = self.linear1_P(x)
            return p_latent, q_latent

        p_latent = self._run_branch("P", x)
        q_latent = self._run_branch("Q", x) if self.HPQ_trainable else None

        # Hard-wired overrides of selected P columns, written in place.
        if self.momentum:
            p_latent[:, 0] = x[:, 4] + x[:, 5]
            p_latent[:, 1] = x[:, 6] + x[:, 7]
        if self.angular_momentum:
            d01 = x[:, 0] - x[:, 1]
            d67 = x[:, 6] - x[:, 7]
            d23 = x[:, 2] - x[:, 3]
            d45 = x[:, 4] - x[:, 5]
            p_latent[:, 2] = d01 * d67 - d23 * d45

        return p_latent, q_latent




class P_H(torch.nn.Module):
    """Map a latent P tensor to a single output column (H).

    With ``HPQ_trainable=True`` this is a tanh MLP:
    ``latent_dim_p -> hidden_dim -> hidden_dim -> 1``. The final layer has no
    bias, and all three weights are orthogonally initialised.

    With ``HPQ_trainable=False`` the module has no parameters. ``forward``
    returns column ``test_dim - 1`` of ``x`` (``test_dim`` is 1-based), kept
    2-D with shape ``(batch, 1)``. Note that ``test_dim=0`` yields an empty
    slice.

    Args:
        input_dim: accepted for signature compatibility; not used.
        hidden_dim: width of the two hidden layers.
        latent_dim_p: width of the incoming P tensor (input of ``linear1``).
        test_dim: 1-based column selected in the non-trainable mode.
        HPQ_trainable: whether to build and use the MLP.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim_p, test_dim, HPQ_trainable=True):
        super().__init__()
        if HPQ_trainable:
            self.linear1 = torch.nn.Linear(latent_dim_p, hidden_dim)
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
            # `bias` is a boolean flag; the original passed None, which only
            # worked because it is falsy. The output layer has no bias.
            self.linear3 = torch.nn.Linear(hidden_dim, 1, bias=False)
            for layer in (self.linear1, self.linear2, self.linear3):
                torch.nn.init.orthogonal_(layer.weight)

        self.nonlinearity = torch.tanh
        self.HPQ_trainable = HPQ_trainable
        self.test_dim = test_dim

    def forward(self, x):
        """Return H with shape ``(..., 1)``.

        The MLP output is returned when trainable; otherwise the selected
        column of ``x``. Presumably ``x`` is ``(batch, latent_dim_p)``; the
        pass-through slices dim 1. TODO confirm.
        """
        if not self.HPQ_trainable:
            # Slice (not index) so the result keeps a trailing dim of size 1.
            return x[:, self.test_dim - 1:self.test_dim]
        h = self.nonlinearity(self.linear1(x))
        h = self.nonlinearity(self.linear2(h))
        return self.linear3(h)




