import os
import argparse
import pickle
from irregular_sampled_datasets import PersonData,Walker2dImitationData
import torch.utils.data as data
import torch.nn as nn
import torch.optim as optim
from torchmetrics.functional import accuracy
import time
import numpy as np
import torch
import math
from torchdyn.core import NeuralDE,NeuralODE,MultipleShootingLayer
import torch.nn.functional as F
from itertools import chain
from GRUCell import GRU 
from NeuralPDE import NeuralPDE
import torchdyn
import logging
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score, roc_curve, roc_auc_score


import argparse


parser = argparse.ArgumentParser()
# --model selects the NeuralPDE variant (passed straight to NeuralPDE below):
#   wave             - 1D system, single GRU source + wave speed
#   wavedirect       - 2D solver, single GRU source + wave speed
#   wavedirectmlp    - 2D solver + hidden-state MLP + GRU
#   wavedirectzeronn - 2D solver, no neural source term
parser.add_argument("--model", default="wavedirect")
# --topic is the held-out topic whose test split is used for testing.
parser.add_argument("--topic", default="sydneysiege")

args = parser.parse_args()


print("Torch version:", torch.__version__)
print("Torchdyn version:", torchdyn.__version__)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

# Run-wide hyper-parameters.
solver = "tsit5"   # only used in the checkpoint file name
size = 64          # hidden dimension of the PDE state
default = 200
seed = 0
lr = 0.005
steps = 2
dataset_name = "twitterunseen"
epochs = 200

# Root logger -> stderr.  (No `filemode` here: without `filename`, basicConfig
# ignores it, so passing one only suggested a file that was never written.)
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
file_handler = logging.FileHandler(
    filename=dataset_name + args.topic + args.model + ".log",
    mode='a',
    encoding='utf-8',  # was misspelled 'utf=8'
    delay=False,
)
file_handler.setFormatter(logging.Formatter('%(asctime)s %(message)s'))
# Named logger: writes to the file handler and propagates to the root handler.
logger = logging.getLogger(args.model + ".log")
logger.addHandler(file_handler)
logger.info("Logging has started")

# Human-readable description of each known model's source term, for the banner.
_MODEL_SOURCES = {
    "wavedirect": "Single GRU + Wave Speed",
    "wavedirectmlp": "Single GRU + Hidden State MLP + Wave Speed",
    "wavedirectzeronn": "Null + Wave Speed",
    "wave": "Single GRU + Wave Speed + 1D Linear System",
}
if args.model in _MODEL_SOURCES:
    logger.info(f"Model: {args.model} -- Hidden Dim:{size} -- Dataset:{dataset_name} -- Source:{_MODEL_SOURCES[args.model]}")

## Dataset-loading helpers. (The input/hidden/output sizes and the pre/post networks are declared inside the per-seed training loop further below.) ##

def load_dataset(dataset_name):
    """Build train/test/valid DataLoaders for the Walker2d imitation data.

    ``dataset_name`` is accepted but currently ignored: the Walker2d data with
    ``seq_len=64`` is always loaded.

    Returns:
        ``(trainloader, testloader, validloader, in_features, num_classes,
        return_sequences, seq_len)``. ``num_classes`` is fixed at 17,
        ``return_sequences`` is always True, and ``seq_len`` is
        ``train_x.shape[1]``.
    """
    print("I am in load_dataset class")
    walker = Walker2dImitationData(seq_len=64)

    # Each split becomes a (features, times, labels) triple of tensors.
    split_names = ("train", "test", "valid")
    splits = {}
    for name in split_names:
        splits[name] = (
            torch.Tensor(getattr(walker, name + "_x")),
            torch.Tensor(getattr(walker, name + "_times")),
            torch.LongTensor(getattr(walker, name + "_y")),
        )

    train_features = splits["train"][0]
    print(splits["train"][2].shape, splits["test"][2].shape)

    # Only the training loader is shuffled.
    loaders = tuple(
        data.DataLoader(
            data.TensorDataset(*splits[name]),
            batch_size=256,
            shuffle=(name == "train"),
        )
        for name in split_names
    )
    return (*loaders, train_features.size(-1), 17, True, train_features.shape[1])


# Pickle path of the requested topic. The per-topic loading code below
# reassigns this name for every topic it reads.
datafile = "./data/rumoureval2019/train/%s_new.pkl" % (args.topic)

# DataLoader batch size and the sequence length used by create_data_new and
# as the NeuralPDE / feed-forward default.
batch_size, seq_len = 32, 10

def create_data_new(data_x, data_t, data_y, window=None):
    """Cut flat per-step arrays into non-overlapping windows.

    Args:
        data_x: per-step features indexed along axis 0; expected 2-D ``(N, F)``.
        data_t: per-step timestamps, expected 1-D ``(N,)``.
        data_y: per-step labels, expected 1-D ``(N,)``.
        window: steps per window. Defaults to the module-level ``seq_len``,
            which is what the original implementation always used.

    Returns:
        ``(x, t, y)``: freshly allocated float64 arrays of shapes
        ``(W, window, F)``, ``(W, window)`` and ``(W, window)``, where
        ``W = N // window``. Trailing steps that do not fill a whole window
        are dropped.
    """
    if window is None:
        window = seq_len
    n_windows = data_x.shape[0] // window
    usable = n_windows * window
    # np.array(..., dtype=float64) copies, so the outputs never alias the inputs
    # (matching the original's np.zeros buffers).
    x = np.array(data_x[:usable], dtype=np.float64).reshape(n_windows, window, data_x.shape[-1])
    t = np.array(data_t[:usable], dtype=np.float64).reshape(n_windows, window)
    y = np.array(data_y[:usable], dtype=np.float64).reshape(n_windows, window)
    return x, t, y

# Leave-one-topic-out split. The train/valid splits of every training-pool
# topic except --topic are concatenated for training/validation. Only the test
# split of --topic is used for testing.
Topics = ['charliehebdo', 'sydneysiege', 'ottawashooting', 'ferguson']
# NOTE(review): the training pool is the first three topics only, mirroring the
# original ``for i in range(3)`` loop, so 'ferguson' is never trained on. It is
# kept unchanged so existing results stay reproducible; confirm whether that
# exclusion is intentional.
_TRAIN_POOL = Topics[:3]

if args.topic not in Topics:
    raise ValueError(f"Unknown --topic {args.topic!r}; expected one of {Topics}")


def _read_topic(topic):
    """Return ``(path, (train_X, valid_X, test_X, train_Y, valid_Y, test_Y))`` for ``topic``.

    The pickles are expected to be local, trusted dataset files.
    ``pickle.load`` can execute arbitrary code, so never point this at
    untrusted input.
    """
    path = "./data/rumoureval2019/train/%s_new.pkl" % (topic)
    print("Reading from:", path, "seed:", seed)
    with open(path, 'rb') as fh:  # closed even if unpickling fails
        return path, pickle.load(fh)


def _split_columns(X, Y):
    """Split a raw (samples, steps, columns) array into (features, times, labels).

    Column 1 holds the timestamp and columns ``2:`` the features. Column 0 is
    not used.
    """
    return X[:, :, 2:], X[:, :, 1], Y


train_parts, valid_parts = [], []
test_x = test_t = test_y = None
for topic in _TRAIN_POOL:
    datafile, (train_X, valid_X, test_X, train_Y, valid_Y, test_Y) = _read_topic(topic)
    if topic == args.topic:
        test_x, test_t, test_y = _split_columns(test_X, test_Y)
    else:
        print(train_X.shape)
        train_parts.append(_split_columns(train_X, train_Y))
        valid_parts.append(_split_columns(valid_X, valid_Y))

if test_x is None:
    # The held-out topic lies outside the training pool (e.g. 'ferguson'), so
    # read only its test split. The original code crashed on None here.
    datafile, (_, _, test_X, _, _, test_Y) = _read_topic(args.topic)
    test_x, test_t, test_y = _split_columns(test_X, test_Y)

# Concatenate the per-topic parts along the sample axis. Everything stays NumPy;
# the original mixed torch and NumPy, which broke with a single training topic.
train_x, train_t, train_y = (np.concatenate(part, 0) for part in zip(*train_parts))
valid_x, valid_t, valid_y = (np.concatenate(part, 0) for part in zip(*valid_parts))

print(train_x.shape, train_t.shape, train_y.shape)

# Trailing singleton axis: timestamps become (samples, steps, 1).
train_t = train_t[:, :, None]
test_t = test_t[:, :, None]
valid_t = valid_t[:, :, None]


def _to_tensor_dataset(x, t, y):
    """Wrap NumPy features/times (as float32) and labels (as int64) in a TensorDataset."""
    return data.TensorDataset(
        torch.from_numpy(x).type(torch.FloatTensor),
        torch.from_numpy(t).type(torch.FloatTensor),
        torch.from_numpy(y).type(torch.LongTensor),
    )


train = _to_tensor_dataset(train_x, train_t, train_y)
test = _to_tensor_dataset(test_x, test_t, test_y)
valid = _to_tensor_dataset(valid_x, valid_t, valid_y)

return_sequences = True

# NOTE(review): num_workers > 0 without an ``if __name__ == "__main__"`` guard
# means that on platforms whose workers start via "spawn" (Windows, macOS) each
# worker re-executes this whole script. This is fine under Linux's "fork".
trainloader = data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4)
testloader = data.DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=4)
validloader = data.DataLoader(valid, batch_size=batch_size, shuffle=False, num_workers=4)

in_features = train_x.shape[-1]

num_classes = 4

# Training / evaluation. For every seed a fresh NeuralPDE is trained for
# `epochs` epochs. The weights with the lowest summed validation loss are
# checkpointed, and that checkpoint is then scored on the held-out test topic.
seeds = [8, 10, 35, 69, 96]


def _build_initial_state(model, batch_x, batch_ts, hidden_size):
    """Encode one batch into the flattened PDE grid and its normalised times.

    Step ``j`` of the batch is passed through ``model.pre_nn`` and written to
    rows ``(j+1)*B:(j+2)*B`` of a ``(B*seq_len + 2*B, hidden_size)`` grid, where
    ``B`` is the batch size. The first and last ``B`` rows replicate the first
    and last step (edge padding). Each step's timestamps are min-max
    normalised across the batch and shifted by +1.

    Returns:
        ``(hidden_state, times)``, with ``times`` of shape ``(B*seq_len, 1)``.
    """
    batch_size, n_steps = batch_x.shape[0], batch_x.shape[1]
    hidden_state = torch.zeros((batch_size * n_steps + 2 * batch_size, hidden_size), device=device)
    times = torch.zeros((batch_size * n_steps), device=device)
    for j in range(n_steps):
        encoded = model.pre_nn(batch_x[:, j].to(device))
        ts = batch_ts[:, j].to(device)
        hidden_state[(j + 1) * batch_size:(j + 2) * batch_size] = encoded
        times[j * batch_size:(j + 1) * batch_size] = ((ts - ts.min()) / (ts.max() - ts.min()))[:, 0] + 1
        # Two independent checks (not if/elif) so that a length-1 sequence
        # still fills both padding blocks.
        if j == 0:
            hidden_state[:batch_size] = encoded
        if j == n_steps - 1:
            hidden_state[-batch_size:] = encoded
    # When all timestamps of a step are equal, the normalisation is 0/0 = NaN.
    times[times.isnan()] = 1.0
    return hidden_state, times.reshape([times.shape[0], 1])


def _run_epoch(model, optimizer, criterion, loader, num_classes, hidden_size,
               train=True, return_sequences=True, seq_len=10, direct=False):
    """Run one pass over ``loader``, optionally taking optimizer steps.

    Args:
        direct: False for the 1D ``wave`` model, whose state is
            ``[hidden | zeros]`` of shape ``(rows, 2*hidden_size)``. True for
            the 2D ``wavedirect*`` models, whose state is the transposed grid
            of shape ``(hidden_size, 1, rows)``.

    Returns:
        ``(accuracy, summed per-batch loss, flat int labels, flat predictions)``.
    """
    if return_sequences:
        y_hat_parts = [torch.zeros((0, seq_len, num_classes))]
    else:
        y_hat_parts = [torch.zeros((0, num_classes))]
    y_parts = [torch.zeros((0)).type(torch.int)]

    if direct:
        print("Num of classes", num_classes)
    if train:
        print("Entering with train mode")
        model.train()
    else:
        print("Entering eval mode")
        model.eval()
    # NOTE(review): eval passes still build autograd graphs. Wrapping them in
    # torch.no_grad() would save memory if NeuralPDE's forward does not need
    # gradients internally; confirm before changing.

    total_loss = 0
    k = 0
    for batch_x, batch_ts, batch_y in loader:
        optimizer.zero_grad()
        batch_size = batch_x.shape[0]
        seq_len = batch_x.shape[1]
        hidden_state, times = _build_initial_state(model, batch_x, batch_ts, hidden_size)
        model.pdefunc.previous_hidden_state = None
        if direct:
            trajectory = model(hidden_state.T[:, None, :], batch_size, times)
            final_state = trajectory[1].squeeze().T
        else:
            zeros = torch.zeros((hidden_state.shape[0], hidden_size), device=device)
            trajectory = model(torch.cat([hidden_state, zeros], dim=1), batch_size, times)
            final_state = trajectory[1]
        # Drop the two padding blocks and keep the first hidden_size channels.
        y_ = final_state[batch_size:-batch_size, :hidden_size]
        outputs = [model.post_nn(y_[j * batch_size:(j + 1) * batch_size]) for j in range(seq_len)]
        y_hat = torch.stack(outputs, dim=1).to(device)
        y_hat_parts.append(y_hat.detach().cpu())
        batch_y = batch_y.view(-1)
        y_parts.append(batch_y)
        loss = criterion(y_hat.view(-1, y_hat.size(-1)), batch_y.to(device))

        if train:
            if k % 5 == 0:
                print("k = ", k)
                print(loss.item())
            k += 1
            loss.backward()
            optimizer.step()

        total_loss += loss.item()

    # One concatenation at the end instead of a growing torch.cat per batch.
    preds = torch.argmax(torch.cat(y_hat_parts), dim=-1)
    labels = torch.cat(y_parts).view(-1, seq_len).type(torch.int)
    acc = accuracy(preds, labels)
    return acc, total_loss, labels.flatten(), preds.flatten()


def multiclass_roc_auc_score(y_test, y_pred, average):
    """One-vs-rest ROC-AUC from hard labels.

    Both label vectors are one-hot encoded by a LabelBinarizer fitted on
    ``y_test``. ``y_pred`` holds argmax labels rather than scores, so each
    per-class AUC is computed from a single operating point.
    """
    from sklearn.preprocessing import LabelBinarizer

    lb = LabelBinarizer()
    lb.fit(y_test)
    return roc_auc_score(lb.transform(y_test), lb.transform(y_pred), average=average)


for s in seeds:
    torch.cuda.empty_cache()
    torch.manual_seed(s)
    torch.cuda.manual_seed(s)
    np.random.seed(s)
    logger.info(f"Seed: {s}")

    input_size = in_features  # input feature dimension
    hidden_size = size        # dimension of the PDE/ODE state
    output_size = num_classes

    neuralpde = NeuralPDE(input_size, hidden_size, output_size, seq_len, model=args.model)
    neuralpde.to(device)

    criterion = nn.CrossEntropyLoss()
    params = list(neuralpde.parameters())
    optimizer = optim.Adam(params, lr=lr)

    def feed_forward(loader, num_classes, train=True, return_sequences=True, seq_len=10):
        """One pass of the 1D ``wave`` model over ``loader``."""
        return _run_epoch(neuralpde, optimizer, criterion, loader, num_classes, hidden_size,
                          train, return_sequences, seq_len, direct=False)

    def feed_forward_direct(loader, num_classes, train=True, return_sequences=True, seq_len=10):
        """One pass of a 2D ``wavedirect*`` model over ``loader``."""
        return _run_epoch(neuralpde, optimizer, criterion, loader, num_classes, hidden_size,
                          train, return_sequences, seq_len, direct=True)

    run = feed_forward if args.model == "wave" else feed_forward_direct

    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []
    best_acc = 0
    Best_loss = 10000
    Best_loss_train = 10000
    saved_checkpoint = False
    ckpt_path = "./" + dataset_name + args.topic + "_" + args.model + "%s_model.pth" % (solver)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100], gamma=0.1)

    for epoch in range(epochs):
        print("\nEpoch:", epoch, seq_len)
        if dataset_name == "twitterunseen":
            acc, loss, _, _ = run(trainloader, num_classes, True, True, seq_len)
            print("Train Accuracy:", acc, " Train loss:", loss)
            train_loss.append(loss)
            train_acc.append(acc)
            acc, loss, _, _ = run(validloader, num_classes, False, True, seq_len)
            print("valid Accuracy:", acc, " Test loss:", loss)
            test_loss.append(loss)
            test_acc.append(acc)
            print("*" * 20)
            logger.info(f"Epoch: {epoch} - Train Loss:{train_loss[-1]}; Valid Loss: {test_loss[-1]}")
            # `acc` is the *validation* accuracy here; the log label is historical.
            logger.info(f"Test Accuracy: {acc}")
            logger.info(f"Best Validtaion loss: {Best_loss}, Best train loss:{Best_loss_train}")

            if loss < Best_loss:
                Best_loss = loss
                Best_loss_train = train_loss[-1]
                torch.save(neuralpde.state_dict(), ckpt_path)
                saved_checkpoint = True
                print("Best validation loss:", Best_loss, " then train loss:", Best_loss_train, "\n")
            # Use the run's logger (was the root `logging.info`, which skipped the log file).
            logger.info(f"Best validation loss:: {Best_loss} - Train Loss:{Best_loss_train};")
        scheduler.step()

    # The checkpoint path is shared across seeds. Only load it if this seed
    # wrote it, otherwise a previous seed's weights would be evaluated silently.
    if saved_checkpoint:
        neuralpde.load_state_dict(torch.load(ckpt_path, map_location=device))
    else:
        logger.warning(f"No checkpoint was saved for seed {s}; evaluating the final-epoch weights")

    acc, loss, true_y, pred_y = run(testloader, num_classes, False, True, seq_len)

    print("Test loss:", loss)
    print("test Accuracy:", acc, " test loss:", loss)
    logger.info(f"Test Loss: {loss}, Test Acc, {acc}")
    true_y = true_y.numpy()
    pred_y = pred_y.numpy()
    acc_score = accuracy_score(true_y, pred_y)
    averageMethod = 'weighted'

    fscore = f1_score(true_y, pred_y, average=averageMethod)
    recall = recall_score(true_y, pred_y, average=averageMethod)
    precision = precision_score(true_y, pred_y, average=averageMethod)
    auc_score = multiclass_roc_auc_score(true_y, pred_y, average=averageMethod)

    print("auc_score : %f, fscore : %f, recall : %f, precision : %f" % (auc_score, fscore, recall, precision))
    logger.info(f"auc_score: {auc_score}, fscore: {fscore}, recall: {recall}, precision: {precision}")