import os
import argparse
from irregular_sampled_datasets import PersonData,Walker2dImitationData
import torch.utils.data as data
import torch.nn as nn
import torch.optim as optim
from torchmetrics.functional import accuracy
import time
import numpy as np
import torch
import math
from torchdyn.core import NeuralDE,NeuralODE,MultipleShootingLayer
import torch.nn.functional as F
from itertools import chain
from GRUCell import GRU 
from NeuralPDE import NeuralPDE
import torchdyn
import logging
import argparse


parser = argparse.ArgumentParser()
parser.add_argument("--model", default="wavedirect")
args = parser.parse_args()

print("Torch version:", torch.__version__)
print("Torchdyn version:", torchdyn.__version__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)
solver = "tsit5"
size = 64
default = 200
lr = 0.005
steps = 2
dataset_name = "person"
epochs = 200

# The root logger (configured here) echoes records to stderr. The named logger
# below additionally appends them to "<dataset><model>.log".
# `filemode` is omitted: basicConfig ignores it when no filename is given.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
file_handler = logging.FileHandler(
    filename=dataset_name + args.model + ".log", mode='a', encoding='utf-8', delay=False
)
file_handler.setFormatter(logging.Formatter('%(asctime)s %(message)s'))
logger = logging.getLogger(args.model + ".log")
logger.addHandler(file_handler)
logger.info("Logging has started")

# Description of the source term for each known --model value (used only in the log line).
MODEL_SOURCES = {
    "wavedirect": "Single GRU + Wave Speed",
    "wavedirectmlp": "Single GRU + Hidden State MLP + Wave Speed",
    "wavedirectzeronn": "Null + Wave Speed",
    "wave": "Single GRU + Wave Speed + 1D Linear System",
}
if args.model in MODEL_SOURCES:
    logger.info(f"Model: {args.model} -- Hidden Dim:{size} -- Dataset:{dataset_name} -- Source:{MODEL_SOURCES[args.model]}")
else:
    # Previously an unrecognised model produced no description line at all.
    logger.info(f"Model: {args.model} -- Hidden Dim:{size} -- Dataset:{dataset_name}")
    
## declare the input,hidden,output size and the pre and post neural networks ##

def load_dataset(dataset_name, batch_size=256):
    """Load a dataset and wrap its splits in DataLoaders.

    Args:
        dataset_name: ``"person"`` (``PersonData``) or ``"walk"``
            (``Walker2dImitationData`` with ``seq_len=64``).
        batch_size: batch size used for every loader. Defaults to 256, the
            value that was previously hard-coded.

    Returns:
        For ``"person"``: ``(trainloader, testloader, in_features,
        num_classes, return_sequences, seq_len)``. ``num_classes`` is
        ``max(train_y) + 1``.

        For ``"walk"``: ``(trainloader, testloader, validloader,
        in_features, num_classes, return_sequences, seq_len)``, with
        ``num_classes`` fixed at 17.

        In both cases ``in_features`` is the size of the last axis of the
        training inputs, ``seq_len`` is ``train_x.shape[1]``, and
        ``return_sequences`` is always True. Only the train loader shuffles.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == "person":
        dataset = PersonData()
        train_x = torch.Tensor(dataset.train_x)
        train_y = torch.LongTensor(dataset.train_y)
        train_ts = torch.Tensor(dataset.train_t)
        test_x = torch.Tensor(dataset.test_x)
        test_y = torch.LongTensor(dataset.test_y)
        test_ts = torch.Tensor(dataset.test_t)
        train = data.TensorDataset(train_x, train_ts, train_y)
        test = data.TensorDataset(test_x, test_ts, test_y)
        print(train_y.shape, test_y.shape)
    elif dataset_name == "walk":
        dataset = Walker2dImitationData(seq_len=64)
        train_x = torch.Tensor(dataset.train_x)
        train_y = torch.LongTensor(dataset.train_y)
        train_ts = torch.Tensor(dataset.train_times)
        test_x = torch.Tensor(dataset.test_x)
        test_y = torch.LongTensor(dataset.test_y)
        test_ts = torch.Tensor(dataset.test_times)
        valid_x = torch.Tensor(dataset.valid_x)
        valid_y = torch.LongTensor(dataset.valid_y)
        valid_ts = torch.Tensor(dataset.valid_times)
        train = data.TensorDataset(train_x, train_ts, train_y)
        test = data.TensorDataset(test_x, test_ts, test_y)
        valid = data.TensorDataset(valid_x, valid_ts, valid_y)
        print(train_y.shape, test_y.shape, valid_y.shape)
    else:
        # Without this guard an unknown name surfaced later as an UnboundLocalError on `train`.
        raise ValueError(f"Unknown dataset_name {dataset_name!r}; expected 'person' or 'walk'")

    return_sequences = True
    trainloader = data.DataLoader(train, batch_size=batch_size, shuffle=True)
    testloader = data.DataLoader(test, batch_size=batch_size, shuffle=False)
    in_features = train_x.size(-1)
    if dataset_name == "walk":
        print("Entering load_dataset")
        validloader = data.DataLoader(valid, batch_size=batch_size, shuffle=False)
        num_classes = 17
        return trainloader, testloader, validloader, in_features, num_classes, return_sequences, train_x.shape[1]
    num_classes = int(torch.max(train_y).item() + 1)
    return trainloader, testloader, in_features, num_classes, return_sequences, train_x.shape[1]


seeds = [8, 10, 35, 69, 96]

for s in seeds:
    # Free cached CUDA memory from the previous run, then reseed torch (CPU and
    # current CUDA device) and numpy for this run.
    torch.cuda.empty_cache()
    torch.manual_seed(s)
    torch.cuda.manual_seed(s)
    np.random.seed(s)
    logger.info(f"Seed: {s}")
    input_size = 7  # input feature dimension -- presumably matches the person data; TODO confirm vs. in_features
    hidden_size = size  # dimension of the PDE/ODE state; tied to `size` so the "Hidden Dim" log line stays accurate
    output_size = 7  # output dimension -- presumably the number of classes; TODO confirm vs. num_classes
    seq_len = 32  # provisional; reassigned from load_dataset's return value before training

    neuralpde = NeuralPDE(input_size, hidden_size, output_size, seq_len, model=args.model)
    neuralpde.to(device)

    criterion = nn.CrossEntropyLoss()
    params = list(neuralpde.parameters())
    optimizer = optim.Adam(params, lr=lr)

    def feed_forward(loader, num_classes, train, return_sequences, seq_len):
        """Run one pass of the "wave" model over ``loader``.

        For each batch, every time step is encoded with ``neuralpde.pre_nn``.
        The encodings are written into a flat state of
        ``batch_size * (seq_len + 2)`` rows: row block ``j + 1`` holds step
        ``j``, and the first/last blocks are copies of the first/last step.
        An all-zero block of equal width is concatenated along dim 1, and the
        result is passed to ``neuralpde`` together with per-sample normalised
        times. Interior rows of element ``[1]`` of the model output are then
        decoded per step with ``neuralpde.post_nn``.

        Args:
            loader: iterable of ``(x, ts, y)`` batches.
            num_classes: last-axis width of the accumulated predictions.
            train: if True, use train mode and take an optimizer step per
                batch; otherwise use eval mode and only accumulate.
            return_sequences: if True, predictions are accumulated per step.
            seq_len: initial sequence length; each batch overrides it with
                ``x.shape[1]``.

        Returns:
            ``(acc, Loss)``: torchmetrics accuracy over every batch, and the
            sum (not mean) of per-batch losses.
        """
        if return_sequences:
            Y_hat = torch.zeros((0, seq_len, num_classes))
        else:
            Y_hat = torch.zeros((0, num_classes))
        Y = torch.zeros((0)).type(torch.int)

        if train:
            print("Entering with train mode")
            neuralpde.train()
        else:
            print("Entering eval mode")
            neuralpde.eval()

        Loss = 0
        k = 0
        for batch in loader:
            train_x, train_ts, train_y = batch
            optimizer.zero_grad()
            batch_size = train_x.shape[0]
            seq_len = train_x.shape[1]
            seq_size = hidden_size
            n_rows = batch_size * seq_len + 2 * batch_size
            hidden_state = torch.zeros((n_rows, seq_size), device=device)
            Times = torch.zeros((batch_size * seq_len), device=device)
            for j in range(seq_len):
                encoded = neuralpde.pre_nn(train_x[:, j].to(device))
                ts = train_ts[:, j].to(device)
                hidden_state[(j + 1) * batch_size:(j + 2) * batch_size] = encoded
                # Min-max normalise this step's times across the batch and shift into [1, 2].
                # A constant column gives 0/0 = NaN; those entries are reset to 1.0 below.
                Times[j * batch_size:(j + 1) * batch_size] = ((ts - ts.min()) / (ts.max() - ts.min()))[:, 0] + 1
                # Independent checks (not elif) so that seq_len == 1 fills both boundary blocks.
                if j == 0:
                    hidden_state[:batch_size] = encoded
                if j == seq_len - 1:
                    hidden_state[-batch_size:] = encoded

            Times[Times.isnan()] = 1.0
            Times = Times.reshape([Times.shape[0], 1])
            negative_hidden_state = torch.zeros((n_rows, seq_size), device=device)
            init_hidden_state = torch.cat([hidden_state, negative_hidden_state], axis=1)
            # Clear state the PDE function may have kept from the previous call.
            neuralpde.pdefunc.previous_hidden_state = None
            new_hidden_state_trajectory = neuralpde(init_hidden_state, batch_size, Times.to(device))
            new_hidden_state = new_hidden_state_trajectory[1]

            # Drop the two boundary blocks and keep only the first seq_size columns.
            y_ = new_hidden_state[batch_size:-batch_size, :seq_size]
            outputs = [neuralpde.post_nn(y_[j * batch_size:(j + 1) * batch_size]) for j in range(seq_len)]
            y_hat = torch.stack(outputs, dim=1)
            Y_hat = torch.cat((Y_hat, y_hat.detach().cpu()))
            y_hat = y_hat.view(-1, y_hat.size(-1))
            train_y = train_y.view(-1)
            Y = torch.cat((Y, train_y))

            loss = criterion(y_hat, train_y.to(device))
            if train:
                if k % 500 == 0:
                    print("k = ", k)
                    print(loss.item())
                k += 1
                loss.backward()
                optimizer.step()

            Loss += loss.item()

        preds = torch.argmax(Y_hat, dim=-1)
        Y = Y.view(-1, seq_len)
        acc = accuracy(preds, Y.type(torch.int))
        return acc, Loss


    def feed_forward_direct(loader, num_classes, train, return_sequences, seq_len):
        """Run one pass of the direct (non-"wave") model over ``loader``.

        The per-step encodings from ``neuralpde.pre_nn`` are laid out as in
        ``feed_forward``: ``batch_size * (seq_len + 2)`` rows, with boundary
        copies of the first and last step. No zero half is appended. The
        state is transposed to ``(features, 1, rows)`` before it is passed to
        ``neuralpde``.

        Args:
            loader: iterable of ``(x, ts, y)`` batches.
            num_classes: last-axis width of the accumulated predictions.
            train: if True, use train mode and take an optimizer step per
                batch; otherwise use eval mode and only accumulate.
            return_sequences: if True, predictions are accumulated per step.
            seq_len: number of time steps read from every batch. Unlike
                ``feed_forward``, this is not re-read from the batch.

        Returns:
            ``(acc, Loss)``: torchmetrics accuracy over every batch, and the
            sum (not mean) of per-batch losses.
        """
        if return_sequences:
            Y_hat = torch.zeros((0, seq_len, num_classes))
        else:
            Y_hat = torch.zeros((0, num_classes))
        Y = torch.zeros((0)).type(torch.int)

        if train:
            print("Entering with train mode")
            neuralpde.train()
        else:
            print("Entering eval mode")
            neuralpde.eval()

        Loss = 0
        k = 0
        for batch in loader:
            train_x, train_ts, train_y = batch
            optimizer.zero_grad()
            batch_size = train_x.shape[0]
            seq_size = hidden_size
            hidden_state = torch.zeros((batch_size * seq_len + 2 * batch_size, seq_size), device=device)
            Times = torch.zeros((batch_size * seq_len), device=device)
            for j in range(seq_len):
                encoded = neuralpde.pre_nn(train_x[:, j].to(device))
                ts = train_ts[:, j].to(device)
                hidden_state[(j + 1) * batch_size:(j + 2) * batch_size] = encoded
                # Min-max normalise this step's times across the batch and shift into [1, 2].
                # A constant column gives 0/0 = NaN; those entries are reset to 1.0 below.
                Times[j * batch_size:(j + 1) * batch_size] = ((ts - ts.min()) / (ts.max() - ts.min()))[:, 0] + 1
                # Independent checks (not elif) so that seq_len == 1 fills both boundary blocks.
                if j == 0:
                    hidden_state[:batch_size] = encoded
                if j == seq_len - 1:
                    hidden_state[-batch_size:] = encoded

            # (rows, features) -> (features, 1, rows)
            hidden_state = hidden_state.T[:, None, :]
            Times[Times.isnan()] = 1.0
            Times = Times.reshape([Times.shape[0], 1])
            # Clear state the PDE function may have kept from the previous call.
            neuralpde.pdefunc.previous_hidden_state = None
            new_hidden_state_trajectory = neuralpde(hidden_state, batch_size, Times.to(device))
            # Back to (rows, features); assumes element [1] keeps the (features, 1, rows)
            # layout -- TODO confirm against NeuralPDE's return value.
            new_hidden_state = new_hidden_state_trajectory[1].squeeze().T
            # Drop the two boundary blocks.
            y_ = new_hidden_state[batch_size:-batch_size, :seq_size]

            outputs = [neuralpde.post_nn(y_[j * batch_size:(j + 1) * batch_size]) for j in range(seq_len)]
            y_hat = torch.stack(outputs, dim=1)
            Y_hat = torch.cat((Y_hat, y_hat.detach().cpu()))
            y_hat = y_hat.view(-1, y_hat.size(-1))
            train_y = train_y.view(-1)
            Y = torch.cat((Y, train_y))

            loss = criterion(y_hat, train_y.to(device))
            if train:
                if k % 500 == 0:
                    print("k = ", k)
                    print(loss.item())
                k += 1
                loss.backward()
                optimizer.step()

            Loss += loss.item()

        preds = torch.argmax(Y_hat, dim=-1)
        Y = Y.view(-1, seq_len)
        acc = accuracy(preds, Y.type(torch.int))
        return acc, Loss



    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []
    Best_loss = 10000
    best_acc = 0
    Best_loss_train = 10000
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100], gamma=0.1)

    if dataset_name != "person":
        # Only the "person" setup is wired up below. load_dataset returns an extra
        # validation loader for "walk", and `acc` would otherwise never be bound.
        raise ValueError(f"Training loop only supports dataset 'person', got {dataset_name!r}")

    print("Original place where train and test and valid loaders are being called")
    trainloader, testloader, in_features, num_classes, return_sequences, seq_len = load_dataset(dataset_name)

    # "wave" uses the concatenated two-block state; every other model uses the direct path.
    run_epoch = feed_forward if args.model == "wave" else feed_forward_direct

    for epoch in range(epochs):
        print("\nEpoch:", epoch, seq_len)
        acc, loss = run_epoch(trainloader, num_classes, True, return_sequences, seq_len)
        print("Train Accuracy:", acc, " Train loss:", loss)
        train_loss.append(loss)
        train_acc.append(acc)
        acc, loss = run_epoch(testloader, num_classes, False, return_sequences, seq_len)
        print("Test Accuracy:", acc, " Test loss:", loss)
        test_loss.append(loss)
        test_acc.append(acc)
        print("*" * 20)
        # Update the running best before logging so the logged best includes this epoch.
        if acc > best_acc:
            best_acc = acc
        logger.info(f"Epoch: {epoch} - Train Loss:{train_loss[-1]}; Test Loss: {test_loss[-1]}")
        logger.info(f"Accuracy : {acc}, Best Accuracy :{best_acc}")
        scheduler.step()
 
 