import os
import argparse
from irregular_sampled_datasets import PersonData,Walker2dImitationData
import torch.utils.data as data
import torch.nn as nn
import torch.optim as optim
from torchmetrics.functional import accuracy
import time
import numpy as np
import torch
import math
from torchdyn.core import NeuralDE,NeuralODE,MultipleShootingLayer
import torch.nn.functional as F
from itertools import chain
from GRUCell import GRU 
from NeuralPDE import NeuralPDE
import torchdyn
import logging


import argparse


parser = argparse.ArgumentParser()
# --model selects the NeuralPDE variant; see MODEL_SOURCES below for the ones
# this script knows how to describe in its log header.
parser.add_argument("--model", default="wavedirect")

args = parser.parse_args()


print("Torch version:",torch.__version__)
print("Torchdyn version:",torchdyn.__version__)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:",device)

# Experiment configuration.
solver = "tsit5"       # appears in the checkpoint filename
size = 64              # hidden dimension reported in the log header
default = 200
lr = 0.005             # Adam learning rate
steps = 2
dataset_name = "walk"
epochs = 200

# The root logger echoes INFO records to stderr (basicConfig without a
# filename creates a stream handler, so no filemode is needed); the file
# handler additionally appends every record to "<dataset_name><model>.log".
logging.basicConfig(level = logging.INFO,format = '%(asctime)s %(message)s')
file_handler = logging.FileHandler(filename = dataset_name + args.model+".log", mode='a', encoding='utf-8', delay=False)
file_handler.setFormatter(logging.Formatter('%(asctime)s %(message)s'))
logger = logging.getLogger(args.model+".log")
logger.addHandler(file_handler)
logger.info("Logging has started")

# Description of the source term used by each known model variant:
#   wave             -> 1D system
#   wavedirect       -> 2D solver + one GRU
#   wavedirectmlp    -> 2D solver + hidden-state MLP + GRU
#   wavedirectzeronn -> 2D solver with no neural source term
MODEL_SOURCES = {
    "wavedirect": "Single GRU + Wave Speed",
    "wavedirectmlp": "Single GRU + Hidden State MLP + Wave Speed",
    "wavedirectzeronn": "Null + Wave Speed",
    "wave": "Single GRU + Wave Speed + 1D Linear System",
}
if args.model in MODEL_SOURCES:
    logger.info(f"Model: {args.model} -- Hidden Dim:{size} -- Dataset:{dataset_name} -- Source:{MODEL_SOURCES[args.model]}")
else:
    # Previously an unknown model produced no header at all; make that visible.
    logger.warning(f"Model: {args.model} -- not a recognised variant ({', '.join(MODEL_SOURCES)})")
    
## declare the input,hidden,output size and the pre and post neural networks ##

def load_dataset(dataset_name, batch_size=256, seq_len=64):
    """Load the Walker2d imitation data and wrap each split in a DataLoader.

    Args:
        dataset_name: Dataset to load. Only ``"walk"`` is implemented.
        batch_size: Batch size used for the train, test and validation loaders
            (previously hard-coded to 256).
        seq_len: Sequence length passed to ``Walker2dImitationData``
            (previously hard-coded to 64).

    Returns:
        ``(trainloader, testloader, validloader, in_features, num_classes,
        return_sequences, train_seq_len)``. Each loader yields ``(x, ts, y)``
        batches. ``in_features`` is the last dimension of the training inputs,
        ``num_classes`` is the per-step output width (17), ``return_sequences``
        is always True, and ``train_seq_len`` is ``train_x.shape[1]``.

    Raises:
        ValueError: if ``dataset_name`` is not ``"walk"`` (it used to be
            silently ignored, always loading Walker2d).
    """
    if dataset_name != "walk":
        raise ValueError(f"load_dataset only supports 'walk', got {dataset_name!r}")

    print("I am in load_dataset class")
    dataset = Walker2dImitationData(seq_len = seq_len)

    # The targets are scored with nn.MSELoss against float predictions, so
    # they are kept as float. torch.LongTensor would truncate every fractional
    # target toward zero before training even starts.
    train_x = torch.Tensor(dataset.train_x)
    train_y = torch.Tensor(dataset.train_y)
    train_ts = torch.Tensor(dataset.train_times)

    test_x = torch.Tensor(dataset.test_x)
    test_y = torch.Tensor(dataset.test_y)
    test_ts = torch.Tensor(dataset.test_times)

    valid_x = torch.Tensor(dataset.valid_x)
    valid_y = torch.Tensor(dataset.valid_y)
    valid_ts = torch.Tensor(dataset.valid_times)

    print(train_y.shape,test_y.shape)

    train = data.TensorDataset(train_x, train_ts, train_y)
    test = data.TensorDataset(test_x, test_ts, test_y)
    valid = data.TensorDataset(valid_x, valid_ts, valid_y)
    return_sequences = True

    # Only the training loader is shuffled; evaluation order is kept fixed.
    trainloader = data.DataLoader(train, batch_size=batch_size, shuffle=True)
    testloader = data.DataLoader(test, batch_size=batch_size, shuffle=False)
    validloader = data.DataLoader(valid, batch_size=batch_size, shuffle=False)
    in_features = train_x.size(-1)
    num_classes = 17

    return trainloader, testloader, validloader , in_features, num_classes, return_sequences,train_x.shape[1]


# Sizes handed to NeuralPDE: per-step input width, PDE/ODE state width,
# per-step output width and sequence length.
seq_len = 64
output_size = 17  # dimension of the per-step prediction
hidden_size = 64  # dimension of the PDE/ODE state
input_size = 17   # dimension of each input vector

neuralpde = NeuralPDE(
    input_size,
    hidden_size,
    output_size,
    seq_len,
    model=args.model,
)
neuralpde.to(device)

# Mean-squared-error objective, optimised with Adam over every model parameter.
criterion = nn.MSELoss()
params = [*neuralpde.parameters()]
optimizer = optim.Adam(params, lr=lr)

def feed_forward(loader,num_classes,train = True,return_sequences = True,seq_len = 64):
    """Run one pass of the "wave" (1D linear system) model over ``loader``.

    Per batch:
      1. Each step's input is embedded with ``neuralpde.pre_nn`` and written
         time-major into a ``(batch_size*(seq_len+2), hidden_size)`` grid.
         Rows ``[(j+1)*B, (j+2)*B)`` hold step ``j``; the first and last ``B``
         rows repeat the first and last step as boundary padding.
      2. The grid is concatenated along dim 1 with a zero block of the same
         shape and integrated by ``neuralpde`` with per-row times.
      3. The first ``hidden_size`` columns of the interior rows of the second
         element returned by ``neuralpde`` are decoded step by step with
         ``neuralpde.post_nn`` and scored with the global ``criterion``.

    In train mode an optimizer step is taken after every batch.

    Args:
        loader: DataLoader yielding ``(x, ts, y)`` batches.
        num_classes: Per-step output width; only shapes the empty result
            when ``loader`` yields no batches.
        train: If True, backpropagate and step the optimizer; else evaluate.
        return_sequences: Shapes the empty result when ``loader`` is empty.
        seq_len: Sequence length for the empty result. The per-batch length
            is always read from the data.

    Returns:
        Train mode: the MSE over all predictions, as a float.
        Eval mode: ``(mse, Y_hat, Y)``, with predictions and targets on CPU.
    """
    if train:
        print("Entering with train mode")
        neuralpde.train()
    else:
        print("Entering eval mode")
        neuralpde.eval()

    # Collect per-batch results and concatenate once at the end; calling
    # torch.cat on a growing tensor every batch is quadratic in batch count.
    predictions = []
    targets = []
    k = 0
    for batch in loader:
        train_x,train_ts,train_y = batch
        batch_size = train_x.shape[0]
        seq_len = train_x.shape[1]
        seq_size = hidden_size
        hidden_state = torch.zeros((batch_size*seq_len + 2*batch_size,seq_size),device = device)
        Times = torch.zeros((batch_size*seq_len),device = device)
        for j in range(seq_len):
            step_input = neuralpde.pre_nn(train_x[:,j].to(device))
            ts = train_ts[:,j].to(device)
            hidden_state[(j+1)*batch_size:(j+2)*batch_size] = step_input
            # Min-max normalise this step's timestamps across the batch into
            # [1, 2]. A constant column gives 0/0 = NaN, which is reset to 1.0
            # after the loop.
            Times[j*batch_size:(j+1)*batch_size] = ((ts-ts.min())/(ts.max()-ts.min()))[:,0]+1
            # Two independent checks (not if/elif) so that a length-1 sequence
            # fills both the leading and the trailing boundary pad.
            if j == 0:
                hidden_state[:batch_size] = step_input
            if j == seq_len - 1:
                hidden_state[-batch_size:] = step_input

        Times[Times.isnan()] = 1.0
        Times = Times.reshape([Times.shape[0],1])
        # The 1D system integrates a doubled state: [embedding | zeros].
        negative_hidden_state = torch.zeros_like(hidden_state)
        init_hidden_state = torch.cat([hidden_state,negative_hidden_state],axis = 1)
        neuralpde.pdefunc.previous_hidden_state = None
        new_hidden_state = neuralpde(init_hidden_state,batch_size,Times)[1]

        # Drop the boundary-padding rows and keep the first hidden_size columns.
        y_ = new_hidden_state[batch_size:-batch_size,:seq_size]
        outputs = [neuralpde.post_nn(y_[j*batch_size:(j+1)*batch_size]) for j in range(seq_len)]
        y_hat = torch.stack(outputs,dim = 1).to(device)
        train_y = train_y.type(y_hat.type())

        predictions.append(y_hat.detach().cpu())
        targets.append(train_y.detach().cpu())
        loss = criterion(y_hat,train_y.to(device))
        if train:
            if k % 5 == 0:
                print("k = ", k)
                print(loss.item())
            k += 1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    if predictions:
        Y_hat = torch.cat(predictions)
        Y = torch.cat(targets)
    else:
        # Empty loader: fall back to the original empty accumulators.
        if return_sequences:
            Y_hat = torch.zeros((0,seq_len,num_classes))
        else:
            Y_hat = torch.zeros((0,num_classes))
        Y = torch.zeros((0)).type(torch.int)

    loss = criterion(Y_hat,Y)
    print("Y_hat shape:",Y_hat.shape)
    print("Y shape:",Y.shape)
    if train:
        return loss.item()
    return loss.item(),Y_hat,Y


def feed_forward_direct(loader,num_classes,train = True,return_sequences = True,seq_len = 64):
    """Run one pass of a direct 2D-solver model variant over ``loader``.

    The training script uses this for every ``--model`` other than "wave".

    Per batch:
      1. Each step's input is embedded with ``neuralpde.pre_nn`` and written
         time-major into a ``(batch_size*(seq_len+2), hidden_size)`` grid.
         Rows ``[(j+1)*B, (j+2)*B)`` hold step ``j``; the first and last ``B``
         rows repeat the first and last step as boundary padding.
      2. The grid is transposed to ``(hidden_size, 1, rows)`` and integrated
         by ``neuralpde`` with per-row times.
      3. The second element returned by ``neuralpde`` is squeezed and
         transposed back. Presumably this yields ``(rows, hidden_size)``;
         TODO confirm against NeuralPDE.
      4. The interior rows are decoded step by step with ``neuralpde.post_nn``
         and scored with the global ``criterion``.

    In train mode an optimizer step is taken after every batch.

    Args:
        loader: DataLoader yielding ``(x, ts, y)`` batches.
        num_classes: Per-step output width; only shapes the empty result
            when ``loader`` yields no batches.
        train: If True, backpropagate and step the optimizer; else evaluate.
        return_sequences: Shapes the empty result when ``loader`` is empty.
        seq_len: Sequence length for the empty result. The per-batch length
            is always read from the data.

    Returns:
        Train mode: the MSE over all predictions, as a float.
        Eval mode: ``(mse, Y_hat, Y)``, with predictions and targets on CPU.
    """
    if train:
        print("Entering with train mode")
        neuralpde.train()
    else:
        print("Entering eval mode")
        neuralpde.eval()

    # Collect per-batch results and concatenate once at the end; calling
    # torch.cat on a growing tensor every batch is quadratic in batch count.
    predictions = []
    targets = []
    k = 0
    for batch in loader:
        train_x,train_ts,train_y = batch
        batch_size = train_x.shape[0]
        seq_len = train_x.shape[1]
        seq_size = hidden_size
        hidden_state = torch.zeros((batch_size*seq_len + 2*batch_size,seq_size),device = device)
        Times = torch.zeros((batch_size*seq_len),device = device)
        for j in range(seq_len):
            step_input = neuralpde.pre_nn(train_x[:,j].to(device))
            ts = train_ts[:,j].to(device)
            hidden_state[(j+1)*batch_size:(j+2)*batch_size] = step_input
            # Min-max normalise this step's timestamps across the batch into
            # [1, 2]. A constant column gives 0/0 = NaN, which is reset to 1.0
            # after the loop.
            Times[j*batch_size:(j+1)*batch_size] = ((ts-ts.min())/(ts.max()-ts.min()))[:,0]+1
            # Two independent checks (not if/elif) so that a length-1 sequence
            # fills both the leading and the trailing boundary pad.
            if j == 0:
                hidden_state[:batch_size] = step_input
            if j == seq_len - 1:
                hidden_state[-batch_size:] = step_input

        # (rows, hidden) -> (hidden, 1, rows) for the direct 2D solver.
        hidden_state = hidden_state.T[:,None,:]
        Times[Times.isnan()] = 1.0
        Times = Times.reshape([Times.shape[0],1])
        neuralpde.pdefunc.previous_hidden_state = None
        new_hidden_state = neuralpde(hidden_state,batch_size,Times)[1].squeeze().T

        # Drop the boundary-padding rows and keep the first hidden_size columns.
        y_ = new_hidden_state[batch_size:-batch_size,:seq_size]
        outputs = [neuralpde.post_nn(y_[j*batch_size:(j+1)*batch_size]) for j in range(seq_len)]
        y_hat = torch.stack(outputs,dim = 1).to(device)
        train_y = train_y.type(y_hat.type())

        predictions.append(y_hat.detach().cpu())
        targets.append(train_y.detach().cpu())
        loss = criterion(y_hat,train_y.to(device))
        if train:
            if k % 5 == 0:
                print("k = ", k)
                print(loss.item())
            k += 1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    if predictions:
        Y_hat = torch.cat(predictions)
        Y = torch.cat(targets)
    else:
        # Empty loader: fall back to the original empty accumulators.
        if return_sequences:
            Y_hat = torch.zeros((0,seq_len,num_classes))
        else:
            Y_hat = torch.zeros((0,num_classes))
        Y = torch.zeros((0)).type(torch.int)

    loss = criterion(Y_hat,Y)
    print("Y_hat shape:",Y_hat.shape)
    print("Y shape:",Y.shape)
    if train:
        return loss.item()
    return loss.item(),Y_hat,Y



train_loss = []
train_acc = []
test_loss = []
test_acc = []

# The best validation loss and the train loss recorded in the same epoch.
Best_loss = 10000
Best_loss_train = 10000
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones = [100],gamma = 0.1)

print("Original place where train and test and valid loaders are being called")
trainloader, testloader, validloader, in_features, num_classes, return_sequences,seq_len = load_dataset(dataset_name)

# "wave" uses the 1D-system driver; every other variant uses the direct driver.
run_epoch = feed_forward if args.model == "wave" else feed_forward_direct
checkpoint_path = "./"+dataset_name+"_"+args.model+"%s_model.pth"%(solver)
checkpoint_saved = False

for epoch in range(epochs):
    print("\nEpoch:",epoch,seq_len)
    loss = run_epoch(trainloader,num_classes,train = True)
    print("Train loss:",loss)
    train_loss.append(loss)
    loss,_,_ = run_epoch(validloader,num_classes,train = False)
    print("Valid loss:",loss)
    test_loss.append(loss)
    print("*"*20)
    logger.info(f"Epoch: {epoch} - Train Loss:{train_loss[-1]}; Valid Loss: {test_loss[-1]}")

    # `loss` now holds the validation loss; checkpoint whenever it improves.
    if loss < Best_loss:
        Best_loss = loss
        Best_loss_train = train_loss[-1]
        torch.save(neuralpde.state_dict(), checkpoint_path)
        checkpoint_saved = True
        print("Best validation loss:",Best_loss," then train loss:",Best_loss_train, "\n")
    # Logged after the update so the reported best includes this epoch.
    logger.info(f"Best Validtaion loss: {Best_loss}, Best train loss:{Best_loss_train}")
    scheduler.step()

if checkpoint_saved:
    neuralpde.load_state_dict(torch.load(checkpoint_path))
else:
    # No epoch beat the initial Best_loss (e.g. every validation loss was NaN),
    # so nothing was written this run; do not load a stale file from an earlier run.
    logger.warning("No checkpoint saved this run; evaluating the final weights")
loss,_,_ = run_epoch(testloader,num_classes,train = False)
logger.info(f"Test Loss: {loss}")
print("Test Loss:",loss)


