import torch
import torch.nn as nn
from WaveClass import Wave,WaveDirect,WaveDirectMLP,WaveZeroNN,WaveDoubleGating,WaveOnlyMLP
from GRUCell import GRU
import math
import torchdyn
from torchdyn.core import NeuralDE,NeuralODE,MultipleShootingLayer

class PreNeuralNetwork(nn.Module):
    """Single linear layer mapping ``input_size`` features to ``hidden_size``.

    The layer is kept inside an ``nn.Sequential`` stored as ``self.model``,
    so its parameters are registered under the ``model.0.*`` prefix.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size, self.hidden_size = input_size, hidden_size
        projection = nn.Linear(input_size, hidden_size)
        self.model = nn.Sequential(projection)

    def forward(self, input):
        """Apply the linear projection to ``input`` and return the result."""
        projected = self.model(input)
        return projected


class PostNeuralNetwork(nn.Module):
    """Single linear layer mapping ``hidden_size`` features to ``output_size``.

    The layer is kept inside an ``nn.Sequential`` stored as ``self.model``,
    so its parameters are registered under the ``model.0.*`` prefix.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size, self.output_size = hidden_size, output_size
        readout = nn.Linear(hidden_size, output_size)
        self.model = nn.Sequential(readout)

    def forward(self, input):
        """Apply the linear readout to ``input`` and return the result."""
        result = self.model(input)
        return result


class NeuralPDE(nn.Module):
    """Neural differential-equation wrapper with a selectable dynamics function.

    The constructor builds a ``GRU`` cell, a dynamics function chosen by
    ``model`` (one of the classes imported from ``WaveClass``) and a torchdyn
    ``NeuralODE`` that integrates that function.  ``forward`` hands the
    per-call context to the dynamics function and returns the trajectory
    computed by the ``NeuralODE``.

    Args:
        input_size: Feature size passed to the GRU cell(s) and ``pre_nn``.
        hidden_size: Hidden size shared by every sub-module.
        output_size: Output size of ``post_nn``.
        seq_len: Value used to fill the initial ``c`` parameter.
        model: Name of the dynamics function; must be one of
            ``_SUPPORTED_MODELS``.

    Raises:
        ValueError: If ``model`` is not a supported name.
    """

    # Values accepted for the ``model`` constructor argument.
    _SUPPORTED_MODELS = (
        "wave",
        "wavedirect",
        "wavedirectmlp",
        "wavedirectzeronn",
        "wavedoublegating",
        "waveonlymlp",
    )

    def __init__(self,input_size,hidden_size,output_size,seq_len,model = "wave"):
        super().__init__()
        # Fail fast: an unknown name would otherwise leave ``pdefunc`` and
        # ``NODE`` undefined and only surface as an AttributeError in forward().
        if model not in self._SUPPORTED_MODELS:
            raise ValueError(
                f"Unknown model {model!r}; expected one of: "
                + ", ".join(self._SUPPORTED_MODELS)
            )
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.seq_len = seq_len
        # Learnable tensor of shape (1, hidden_size), filled with seq_len here.
        # NOTE(review): reset_parameters() below re-draws every parameter,
        # including this one, so the seq_len fill does not survive
        # construction -- confirm that this is intended.
        self.c = nn.Parameter(torch.Tensor([self.seq_len]).repeat(1,self.hidden_size))
        self.gru = GRU(self.input_size,self.hidden_size)
        print("Initializing the neural DE")

        # Sub-modules are assigned in a fixed order on purpose: it determines
        # the order of self.parameters() and therefore the seeded random init.
        if model == "wave":
            self.pdefunc = Wave(self.gru,self.hidden_size,self.c)
        elif model == "wavedirect":
            self.pdefunc = WaveDirect(self.gru,self.hidden_size,self.c)
        elif model == "wavedirectmlp":
            self.mlp = self._build_mlp(3*self.hidden_size)
            self.pdefunc = WaveDirectMLP(self.gru,self.mlp,self.hidden_size,self.c)
        elif model == "wavedirectzeronn":
            self.pdefunc = WaveZeroNN(self.hidden_size,self.c)
        elif model == "wavedoublegating":
            self.gru_past = GRU(self.input_size,self.hidden_size)
            self.gru_future = GRU(self.input_size,self.hidden_size)
            self.mlp = self._build_mlp(2*self.hidden_size)
            self.pdefunc = WaveDoubleGating(self.gru_past,self.gru_future,self.mlp,self.hidden_size,self.c)
        elif model == "waveonlymlp":
            self.mlp = self._build_mlp(3*self.hidden_size)
            self.pdefunc = WaveOnlyMLP(self.mlp,self.hidden_size,self.c)

        # "wave" uses NeuralODE's default order; every other variant is
        # integrated with order = 2.
        node_kwargs = {} if model == "wave" else {"order": 2}
        self.NODE = self._build_node(**node_kwargs)

        # Input/output projections; they are created here but not called by
        # this class's forward().
        self.pre_nn = PreNeuralNetwork(self.input_size,self.hidden_size)
        self.post_nn = PostNeuralNetwork(self.hidden_size,self.output_size)
        self.ivpnet = None
        self.batch_size = None
        self.h = None
        self.reset_parameters()

    def _build_mlp(self, in_features):
        """Return a Linear -> Tanh -> Linear network ending at ``hidden_size``."""
        return nn.Sequential(
            nn.Linear(in_features,self.hidden_size),
            nn.Tanh(),
            nn.Linear(self.hidden_size,self.hidden_size),
        )

    def _build_node(self, **extra):
        """Wrap ``self.pdefunc`` in a NeuralODE with the shared solver settings."""
        return NeuralODE(
            self.pdefunc,
            sensitivity="interpolated_adjoint",
            solver="tsit5",
            atol_adjoint=1e-3,
            rtol_adjoint=1e-3,
            **extra,
        )

    def reset_parameters(self):
        """Re-draw every parameter uniformly from ``[-k, k]``, ``k = 1/sqrt(hidden_size)``.

        This applies to everything returned by ``self.parameters()``,
        including ``c`` and the weights of all registered sub-modules.
        """
        std = 1.0 / math.sqrt(self.hidden_size)
        # In-place update under no_grad instead of going through ``.data``.
        with torch.no_grad():
            for w in self.parameters():
                w.uniform_(-std, std)

    def forward(self,x0,batch_size,times,time_steps = torch.Tensor([0.,1.]),solver = "dopri5"):
        """Integrate from ``x0`` and return the NeuralODE trajectory.

        Args:
            x0: Initial state passed to ``NODE.trajectory``.
            batch_size: Stored on the dynamics function as ``batch_size``.
            times: Stored on the dynamics function as ``ts``.
            time_steps: Integration span passed to ``NODE.trajectory``.
            solver: Unused; kept so existing callers keep working.

        Returns:
            Whatever ``self.NODE.trajectory(x0, time_steps)`` returns.
        """
        # The dynamics function reads its per-call context from attributes.
        self.pdefunc.ts = times
        self.pdefunc.batch_size = batch_size
        return self.NODE.trajectory(x0,time_steps)
    
    
class NeuralPDECls(nn.Module):
    """Neural differential-equation wrapper with a selectable dynamics function.

    The constructor builds a ``GRU`` cell, a dynamics function chosen by
    ``model`` (one of the classes imported from ``WaveClass``) and a torchdyn
    ``NeuralODE`` that integrates that function.  ``forward`` hands the
    per-call context to the dynamics function and returns the trajectory
    computed by the ``NeuralODE``.

    Args:
        input_size: Feature size passed to the GRU cell and ``pre_nn``.
        hidden_size: Hidden size shared by every sub-module.
        output_size: Output size of ``post_nn``.
        seq_len: Value used to fill the initial ``c`` parameter.
        model: Name of the dynamics function; must be one of
            ``_SUPPORTED_MODELS``.

    Raises:
        ValueError: If ``model`` is not a supported name.
    """

    # Values accepted for the ``model`` constructor argument.
    _SUPPORTED_MODELS = (
        "wave",
        "wavedirect",
        "wavedirectmlp",
        "wavedirectzeronn",
    )

    def __init__(self,input_size,hidden_size,output_size,seq_len,model = "wave"):
        super().__init__()
        # Fail fast: an unknown name would otherwise leave ``pdefunc`` and
        # ``NODE`` undefined and only surface as an AttributeError in forward().
        if model not in self._SUPPORTED_MODELS:
            raise ValueError(
                f"Unknown model {model!r}; expected one of: "
                + ", ".join(self._SUPPORTED_MODELS)
            )
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.seq_len = seq_len
        # Learnable tensor of shape (1, hidden_size), filled with seq_len here.
        # NOTE(review): reset_parameters() below re-draws every parameter,
        # including this one, so the seq_len fill does not survive
        # construction -- confirm that this is intended.
        self.c = nn.Parameter(torch.Tensor([self.seq_len]).repeat(1,self.hidden_size))
        self.gru = GRU(self.input_size,self.hidden_size)
        print("Initializing the neural DE")

        # Sub-modules are assigned in a fixed order on purpose: it determines
        # the order of self.parameters() and therefore the seeded random init.
        if model == "wave":
            self.pdefunc = Wave(self.gru,self.hidden_size,self.c)
        elif model == "wavedirect":
            self.pdefunc = WaveDirect(self.gru,self.hidden_size,self.c)
        elif model == "wavedirectmlp":
            self.mlp = nn.Sequential(
                nn.Linear(3*self.hidden_size,self.hidden_size),
                nn.Tanh(),
                nn.Linear(self.hidden_size,self.hidden_size),
            )
            self.pdefunc = WaveDirectMLP(self.gru,self.mlp,self.hidden_size,self.c)
        elif model == "wavedirectzeronn":
            self.pdefunc = WaveZeroNN(self.hidden_size,self.c)

        # "wave" uses NeuralODE's default order; every other variant is
        # integrated with order = 2.
        node_kwargs = {} if model == "wave" else {"order": 2}
        self.NODE = self._build_node(**node_kwargs)

        # Input/output projections; they are created here but not called by
        # this class's forward().
        self.pre_nn = PreNeuralNetwork(self.input_size,self.hidden_size)
        self.post_nn = PostNeuralNetwork(self.hidden_size,self.output_size)
        self.ivpnet = None
        self.batch_size = None
        self.h = None
        self.reset_parameters()

    def _build_node(self, **extra):
        """Wrap ``self.pdefunc`` in a NeuralODE with the shared solver settings."""
        return NeuralODE(
            self.pdefunc,
            sensitivity="interpolated_adjoint",
            solver="tsit5",
            atol_adjoint=1e-3,
            rtol_adjoint=1e-3,
            **extra,
        )

    def reset_parameters(self):
        """Re-draw every parameter uniformly from ``[-k, k]``, ``k = 1/sqrt(hidden_size)``.

        This applies to everything returned by ``self.parameters()``,
        including ``c`` and the weights of all registered sub-modules.
        """
        std = 1.0 / math.sqrt(self.hidden_size)
        # In-place update under no_grad instead of going through ``.data``.
        with torch.no_grad():
            for w in self.parameters():
                w.uniform_(-std, std)

    def forward(self,x0,batch_size,times,time_steps = torch.Tensor([0.,1.]),solver = "dopri5"):
        """Integrate from ``x0`` and return the NeuralODE trajectory.

        Args:
            x0: Initial state passed to ``NODE.trajectory``.
            batch_size: Stored on the dynamics function as ``batch_size``.
            times: Stored on the dynamics function as ``ts``.
            time_steps: Integration span passed to ``NODE.trajectory``.
            solver: Unused; kept so existing callers keep working.

        Returns:
            Whatever ``self.NODE.trajectory(x0, time_steps)`` returns.
        """
        # NOTE(review): debug output emitted on every call -- consider
        # removing it or routing it through logging.
        print("x shape inside pde:",x0.shape)
        # The dynamics function reads its per-call context from attributes.
        self.pdefunc.ts = times
        self.pdefunc.batch_size = batch_size
        return self.NODE.trajectory(x0,time_steps)
    
