import torch
import torch.nn as nn


class Wave(nn.Module):
    """Time derivative of a split ``(x1, x2)`` state for a wave equation plus a learned term.

    ``forward`` splits its input column-wise into halves ``x1`` and ``x2`` and
    returns ``cat([x2, d2], dim=1)``. On interior rows,
    ``d2 = c**2 * (x1[next] - 2*x1[here] + x1[prev]) / times**2 + gru(...)``.
    Rows are grouped into consecutive blocks of ``batch_size``, and the stencil
    steps one whole block at a time. The first and last blocks are therefore
    boundary rows; they stay zero unless ``neumann`` is set.

    The module is stateful: ``previous_hidden_state`` holds the (detached)
    ``x1`` from the previous call and starts as zeros on the first call. It is
    only reset by assigning ``None`` to it.

    Args:
        gru: callable invoked as ``gru(previous_x1_interior, x1_left_neighbours)``.
            Its result is added to the interior of ``d2`` and must broadcast to
            that slice's shape.
        hidden_size: stored on the module; not used by ``forward``.
        c: wave speed; the second difference is scaled by ``c**2``.
        times: spacing used in the ``times**2`` denominator.
        batch_size: number of rows per stencil block.
        gru_layers: stored on the module; not used by ``forward``.
        neumann: when True, each boundary block of ``d2`` copies the adjacent
            interior block.
    """

    def __init__(self, gru, hidden_size, c, times=1, batch_size=256, gru_layers=1, neumann=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.gru_layers = gru_layers
        self.gru = gru
        self.ts = times
        self.batch_size = batch_size
        self.previous_hidden_state = None
        # Honour the caller's choice (this was previously hard-coded to False).
        self.neumann = neumann
        self.c = c

    def forward(self, hidden_state):
        """Return ``cat([x2, d2], dim=1)`` for the column-split ``hidden_state``."""
        bs = self.batch_size
        col_count = hidden_state.shape[1] // 2
        x1 = hidden_state[:, :col_count]
        x2 = hidden_state[:, col_count:]

        if self.previous_hidden_state is None:
            # First call: no history yet, so the learned term sees zeros.
            self.previous_hidden_state = torch.zeros_like(x1)

        d1 = x2
        d2 = torch.zeros_like(x1)
        d2[bs:-bs] = (
            (self.c ** 2) * (x1[2 * bs:] - 2 * x1[bs:-bs] + x1[:-2 * bs]) / (self.ts ** 2)
            + self.gru(self.previous_hidden_state[bs:-bs], x1[:-2 * bs])
        )

        # Neumann boundary conditions: each boundary block copies its
        # neighbouring interior block (the whole block, not a single row).
        if self.neumann:
            d2[:bs] = d2[bs:2 * bs]
            d2[-bs:] = d2[-2 * bs:-bs]

        # Detached so no gradient flows back through earlier calls.
        self.previous_hidden_state = x1.clone().detach()
        return torch.cat([d1, d2], dim=1)
    
    

class WaveDirect(nn.Module):
    """Wave-equation right-hand side on ``u`` with an additive learned term.

    ``forward`` applies ``u.squeeze().T`` so that rows index the stencil. On
    interior rows it computes
    ``dudt = c**2 * (u[next] - 2*u[here] + u[prev]) / times**2 + gru(...)``.
    It returns ``dudt.T`` with a new axis inserted at position 1. Rows are
    grouped into consecutive blocks of ``batch_size``, and the stencil steps
    one block at a time. The first and last blocks are therefore boundary rows;
    they stay zero unless ``neumann`` is set.

    The module is stateful: ``previous_hidden_state`` holds the (detached)
    transposed ``u`` from the previous call and starts as zeros on the first
    call.

    Args:
        gru: callable invoked as ``gru(previous_u_interior, u_left_neighbours)``.
            Its result is added to the interior rows of ``dudt``.
        hidden_size: stored on the module; not used by ``forward``.
        c: wave speed; the second difference is scaled by ``c**2``.
        times: spacing used in the ``times**2`` denominator.
        batch_size: number of rows per stencil block.
        gru_layers: stored on the module; not used by ``forward``.
        neumann: when True, each boundary block of ``dudt`` copies the adjacent
            interior block.
    """

    def __init__(self, gru, hidden_size, c, times=1, batch_size=256, gru_layers=1, neumann=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.gru_layers = gru_layers
        self.gru = gru
        self.ts = times
        self.batch_size = batch_size
        self.previous_hidden_state = None
        # Honour the caller's choice (this was previously hard-coded to False).
        self.neumann = neumann
        self.c = c

    def forward(self, u):
        """Return the time derivative of ``u``, with the same layout as the input."""
        bs = self.batch_size
        # squeeze() drops every size-1 dim; .T puts stencil rows first.
        # NOTE(review): assumes the input squeezes to 2-D — TODO confirm with callers.
        u = u.squeeze().T
        dudt = torch.zeros_like(u)
        if self.previous_hidden_state is None:
            self.previous_hidden_state = torch.zeros_like(u)

        dudt[bs:-bs] = (
            (self.c ** 2) * (u[2 * bs:] - 2 * u[bs:-bs] + u[:-2 * bs]) / (self.ts ** 2)
            + self.gru(self.previous_hidden_state[bs:-bs], u[:-2 * bs])
        )

        # Neumann boundary conditions: copy whole neighbouring interior blocks.
        if self.neumann:
            dudt[:bs] = dudt[bs:2 * bs]
            dudt[-bs:] = dudt[-2 * bs:-bs]

        self.previous_hidden_state = u.clone().detach()
        # Undo the transpose and re-insert a singleton middle axis.
        return dudt.T[:, None, :]
    
class WaveDirectMLP(nn.Module):
    """Wave-equation right-hand side on ``u`` with an MLP-then-GRU learned term.

    On interior rows, ``forward`` concatenates ``[u[next], u[here], u[prev]]``
    column-wise and passes the result through ``mlp``. It then adds
    ``gru(mlp_output, previous_u_interior)`` to the scaled second difference
    ``c**2 * (u[next] - 2*u[here] + u[prev]) / times**2``. Rows are grouped
    into consecutive blocks of ``batch_size``, and the stencil steps one block
    at a time. Input and output layouts match ``WaveDirect``: the input is
    squeezed and transposed, and the output is ``dudt.T`` with a new axis at
    position 1.

    Args:
        gru: callable invoked as ``gru(mlp_features, previous_u_interior)``.
        mlp: callable applied to the concatenated neighbour states.
        hidden_size: stored on the module; not used by ``forward``.
        c: wave speed; the second difference is scaled by ``c**2``.
        times: spacing used in the ``times**2`` denominator.
        batch_size: number of rows per stencil block.
        gru_layers: stored on the module; not used by ``forward``.
        neumann: when True, each boundary block of ``dudt`` copies the adjacent
            interior block.
    """

    def __init__(self, gru, mlp, hidden_size, c, times=1, batch_size=256, gru_layers=1, neumann=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.gru_layers = gru_layers
        self.mlp = mlp
        self.gru = gru
        self.ts = times
        self.batch_size = batch_size
        self.previous_hidden_state = None
        # Honour the caller's choice (this was previously hard-coded to False).
        self.neumann = neumann
        self.c = c

    def forward(self, u):
        """Return the time derivative of ``u``, with the same layout as the input."""
        bs = self.batch_size
        u = u.squeeze().T
        dudt = torch.zeros_like(u)
        if self.previous_hidden_state is None:
            self.previous_hidden_state = torch.zeros_like(u)

        # Right, centre and left neighbours of every interior row, side by side.
        joint_states = torch.cat([u[2 * bs:], u[bs:-bs], u[:-2 * bs]], dim=1)
        mlp_joint_states = self.mlp(joint_states)

        dudt[bs:-bs] = (
            (self.c ** 2) * (u[2 * bs:] - 2 * u[bs:-bs] + u[:-2 * bs]) / (self.ts ** 2)
            + self.gru(mlp_joint_states, self.previous_hidden_state[bs:-bs])
        )

        # Neumann boundary conditions: copy whole neighbouring interior blocks.
        if self.neumann:
            dudt[:bs] = dudt[bs:2 * bs]
            dudt[-bs:] = dudt[-2 * bs:-bs]

        self.previous_hidden_state = u.clone().detach()
        return dudt.T[:, None, :]

class WaveZeroNN(nn.Module):
    """Plain wave-equation right-hand side on ``u`` with no learned term.

    On interior rows, ``forward`` computes
    ``c**2 * (u[next] - 2*u[here] + u[prev]) / times**2``. Rows are grouped
    into consecutive blocks of ``batch_size``, and the stencil steps one block
    at a time. Input and output layouts match ``WaveDirect``: the input is
    squeezed and transposed, and the output is ``dudt.T`` with a new axis at
    position 1. ``previous_hidden_state`` is still updated on every call,
    although the computation here does not read it.

    Args:
        hidden_size: stored on the module; not used by ``forward``.
        c: wave speed; the second difference is scaled by ``c**2``.
        times: spacing used in the ``times**2`` denominator.
        batch_size: number of rows per stencil block.
        gru_layers: stored on the module; not used by ``forward``.
        neumann: when True, each boundary block of ``dudt`` copies the adjacent
            interior block.
    """

    def __init__(self, hidden_size, c, times=1, batch_size=256, gru_layers=1, neumann=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.gru_layers = gru_layers
        self.ts = times
        self.batch_size = batch_size
        self.previous_hidden_state = None
        # Honour the caller's choice (this was previously hard-coded to False).
        self.neumann = neumann
        self.c = c

    def forward(self, u):
        """Return the time derivative of ``u``, with the same layout as the input."""
        bs = self.batch_size
        u = u.squeeze().T
        dudt = torch.zeros_like(u)
        if self.previous_hidden_state is None:
            self.previous_hidden_state = torch.zeros_like(u)

        dudt[bs:-bs] = (
            (self.c ** 2) * (u[2 * bs:] - 2 * u[bs:-bs] + u[:-2 * bs]) / (self.ts ** 2)
        )

        # Neumann boundary conditions: copy whole neighbouring interior blocks.
        if self.neumann:
            dudt[:bs] = dudt[bs:2 * bs]
            dudt[-bs:] = dudt[-2 * bs:-bs]

        self.previous_hidden_state = u.clone().detach()
        return dudt.T[:, None, :]

class WaveDoubleGating(nn.Module):
    """Wave-equation right-hand side on ``u`` with two gated learned terms.

    On interior rows, ``forward`` concatenates the previous call's interior
    state with the current interior state, column-wise, and passes the result
    through ``mlp``. It then adds
    ``gru_past(mlp_output, u[prev]) + gru_future(mlp_output, u[next])`` to the
    scaled second difference ``c**2 * (u[next] - 2*u[here] + u[prev]) / times**2``.
    Rows are grouped into consecutive blocks of ``batch_size``. Input and
    output layouts match ``WaveDirect``: the input is squeezed and transposed,
    and the output is ``dudt.T`` with a new axis at position 1.

    Args:
        gru_past: callable invoked as ``gru_past(mlp_features, u_left_neighbours)``.
        gru_future: callable invoked as ``gru_future(mlp_features, u_right_neighbours)``.
        mlp: callable applied to ``cat([previous_interior, current_interior], dim=1)``.
        hidden_size: stored on the module; not used by ``forward``.
        c: wave speed; the second difference is scaled by ``c**2``.
        times: spacing used in the ``times**2`` denominator.
        batch_size: number of rows per stencil block.
        gru_layers: stored on the module; not used by ``forward``.
        neumann: when True, each boundary block of ``dudt`` copies the adjacent
            interior block.
    """

    def __init__(self, gru_past, gru_future, mlp, hidden_size, c, times=1, batch_size=256, gru_layers=1, neumann=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.gru_layers = gru_layers
        self.ts = times
        self.batch_size = batch_size
        self.previous_hidden_state = None
        # Honour the caller's choice (this was previously hard-coded to False).
        self.neumann = neumann
        self.c = c
        self.mlp = mlp
        self.gru_past = gru_past
        self.gru_future = gru_future

    def forward(self, u):
        """Return the time derivative of ``u``, with the same layout as the input."""
        bs = self.batch_size
        u = u.squeeze().T
        dudt = torch.zeros_like(u)
        if self.previous_hidden_state is None:
            self.previous_hidden_state = torch.zeros_like(u)

        # Interior rows from the previous call next to the current interior rows.
        joint_states = torch.cat([self.previous_hidden_state[bs:-bs], u[bs:-bs]], dim=1)
        mlp_joint_states = self.mlp(joint_states)

        dudt[bs:-bs] = (
            (self.c ** 2) * (u[2 * bs:] - 2 * u[bs:-bs] + u[:-2 * bs]) / (self.ts ** 2)
            + self.gru_past(mlp_joint_states, u[:-2 * bs])
            + self.gru_future(mlp_joint_states, u[2 * bs:])
        )

        # Neumann boundary conditions: copy whole neighbouring interior blocks.
        if self.neumann:
            dudt[:bs] = dudt[bs:2 * bs]
            dudt[-bs:] = dudt[-2 * bs:-bs]

        self.previous_hidden_state = u.clone().detach()
        return dudt.T[:, None, :]
    
class WaveOnlyMLP(nn.Module):
    """Wave-equation right-hand side on ``u`` with an additive MLP term.

    On interior rows, ``forward`` concatenates ``[u[next], u[here], u[prev]]``
    column-wise, passes the result through ``mlp``, and adds the output directly
    to ``c**2 * (u[next] - 2*u[here] + u[prev]) / times**2``. Rows are grouped
    into consecutive blocks of ``batch_size``. Input and output layouts match
    ``WaveDirect``: the input is squeezed and transposed, and the output is
    ``dudt.T`` with a new axis at position 1. ``previous_hidden_state`` is
    updated on every call but not read here.

    Args:
        mlp: callable applied to the concatenated neighbour states. Its output
            must broadcast to the interior rows of ``dudt``.
        hidden_size: stored on the module; not used by ``forward``.
        c: wave speed; the second difference is scaled by ``c**2``.
        times: spacing used in the ``times**2`` denominator.
        batch_size: number of rows per stencil block.
        gru_layers: stored on the module; not used by ``forward``.
        neumann: when True, each boundary block of ``dudt`` copies the adjacent
            interior block.
    """

    def __init__(self, mlp, hidden_size, c, times=1, batch_size=256, gru_layers=1, neumann=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.gru_layers = gru_layers
        self.ts = times
        self.batch_size = batch_size
        self.previous_hidden_state = None
        # Honour the caller's choice (this was previously hard-coded to False).
        self.neumann = neumann
        self.c = c
        self.mlp = mlp

    def forward(self, u):
        """Return the time derivative of ``u``, with the same layout as the input."""
        bs = self.batch_size
        u = u.squeeze().T
        dudt = torch.zeros_like(u)

        if self.previous_hidden_state is None:
            self.previous_hidden_state = torch.zeros_like(u)

        # Right, centre and left neighbours of every interior row, side by side.
        joint_states = torch.cat([u[2 * bs:], u[bs:-bs], u[:-2 * bs]], dim=1)
        mlp_joint_states = self.mlp(joint_states)

        dudt[bs:-bs] = (
            (self.c ** 2) * (u[2 * bs:] - 2 * u[bs:-bs] + u[:-2 * bs]) / (self.ts ** 2)
            + mlp_joint_states
        )

        # Neumann boundary conditions: copy whole neighbouring interior blocks.
        if self.neumann:
            dudt[:bs] = dudt[bs:2 * bs]
            dudt[-bs:] = dudt[-2 * bs:-bs]

        self.previous_hidden_state = u.clone().detach()
        return dudt.T[:, None, :]
            
            


        
    