import os
import argparse
from irregular_sampled_datasets import PersonData,Walker2dImitationData
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, TensorDataset
import torch.utils.data as data
import torch.nn as nn
import torch.optim as optim
from torchmetrics.functional import accuracy
import time
import numpy as np
import torch
import math
from torchdyn.core import NeuralDE,NeuralODE,MultipleShootingLayer
import torch.nn.functional as F
from itertools import chain
from GRUCell import GRU 
from NeuralPDE import NeuralPDE
import torchdyn
import logging
import sys 
from datasets.sepsis import get_data
from tqdm import tqdm
import time

import argparse


parser = argparse.ArgumentParser()

# Model variant to train. Variants named in this script:
#   wave             - 1d system
#   wavedirect       - 2d solver, 1 GRU
#   wavedirectmlp    - 2d solver + MLP + GRU
#   wavedirectzeronn - 2d solver, no NN
parser.add_argument("--model", default= "wavedirect")

args = parser.parse_args()

print("Torch version:",torch.__version__)
print("Torchdyn version:",torchdyn.__version__)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:",device)

# Experiment configuration.
solver = "tsit5"
size = 64  # logged as "Hidden Dim" and used as the hidden size of the model
default = 200
lr = 0.005
steps = 2
dataset_name = "sepsiswithOI"
epochs = 200

# Positive-target weight handed to BCEWithLogitsLoss further down.
pos_weight = torch.tensor(10)

# Root logger writes to the console; the named logger additionally appends
# to "<dataset><model>.log". `filemode` was dropped from basicConfig because
# it has no effect unless `filename` is given as well.
logging.basicConfig(level = logging.INFO,format = '%(asctime)s %(message)s')
# 'utf-8' is the canonical codec name (the previous 'utf=8' only resolved
# through codec-name normalisation).
file_handler = logging.FileHandler(filename = dataset_name + args.model+".log", mode='a', encoding='utf-8', delay=False)
file_handler.setFormatter(logging.Formatter('%(asctime)s %(message)s'))
logger = logging.getLogger(args.model+".log")
logger.addHandler(file_handler)
logger.info("Logging has started")

# Human-readable description of the source term for each known variant.
# Unknown variants are still allowed to run; they just skip this banner line.
_SOURCE_DESCRIPTIONS = {
    "wavedirect": "Single GRU + Wave Speed",
    "wavedirectmlp": "Single GRU + Hidden State MLP + Wave Speed",
    "wavedirectzeronn": "Null + Wave Speed",
    "wave": "Single GRU + Wave Speed + 1D Linear System",
}
_source_description = _SOURCE_DESCRIPTIONS.get(args.model)
if _source_description is not None:
    logger.info(f"Model: {args.model} -- Hidden Dim:{size} -- Dataset:{dataset_name} -- Source:{_source_description}")
    
## declare the input,hidden,output size and the pre and post neural networks ##

def _stack_split(dataloader, times):
    """Flatten every batch of one data split into three aligned tensors.

    Each batch is unpacked as ``(*features, target, final_index)``. The last
    entry of ``features`` is treated as the static (per-sample) features and
    is repeated along dimension 1 so it can be concatenated with the
    remaining feature tensors along dimension 2. ``final_index`` is not used.

    Args:
        dataloader: iterable yielding batches with the layout above.
        times: tensor holding the shared time grid; all but its last entry
            are broadcast to every sample.

    Returns:
        ``(features, ts, targets)``, each concatenated over all batches
        along dimension 0.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    feature_chunks, time_chunks, target_chunks = [], [], []
    # Same grid for every sample: shape (1, len(times) - 1, 1).
    time_grid = times.view(1, -1, 1)[:, :-1, :]

    for batch in dataloader:
        *features, target, _final_index = batch

        static_features = features[-1]
        sequence_features = features[:-1]
        max_seq_len = max(feature.size(1) for feature in sequence_features)

        # Repeat the static features along dimension 1 to match max_seq_len.
        static_features = static_features.unsqueeze(1).expand(-1, max_seq_len, -1)
        combined = torch.cat([*sequence_features, static_features], dim=2)

        feature_chunks.append(combined)
        time_chunks.append(time_grid.expand(combined.size(0), -1, 1))
        target_chunks.append(target)

    if not feature_chunks:
        raise ValueError("dataloader yielded no batches")

    return (
        torch.cat(feature_chunks, dim=0),
        torch.cat(time_chunks, dim=0),
        torch.cat(target_chunks, dim=0),
    )


def load_physionet_dataset(batch_size):
    """Build train/validation/test loaders of ``(features, ts, target)`` batches.

    The splits come from ``get_data`` (called with both intensity flags set
    to True) and are re-packed by ``_stack_split`` into plain
    ``TensorDataset`` objects.

    Args:
        batch_size: forwarded to ``get_data`` and used for the returned loaders.

    Returns:
        A 7-tuple, in this order:
        ``(train_loader, val_loader, test_loader, in_features, num_classes,
        return_sequences, seqlen)``. Note that the validation loader comes
        before the test loader. Only the training loader shuffles.
        ``in_features`` and ``seqlen`` are the last and second dimension of
        the training features, ``num_classes`` is the largest training target
        plus one, and ``return_sequences`` is always False.
    """
    static_intensity = True
    time_intensity = True

    times, train_dataloader, val_dataloader, test_dataloader = get_data(static_intensity, time_intensity, batch_size)

    train_tensors = _stack_split(train_dataloader, times)
    val_tensors = _stack_split(val_dataloader, times)
    test_tensors = _stack_split(test_dataloader, times)

    # Sanity check: every split must have one ts row and one target per sample.
    for split_features, split_ts, split_targets in (train_tensors, val_tensors, test_tensors):
        assert split_features.size(0) == split_ts.size(0) == split_targets.size(0)

    modified_trainloader = DataLoader(TensorDataset(*train_tensors), batch_size=batch_size, shuffle=True)
    modified_valloader = DataLoader(TensorDataset(*val_tensors), batch_size=batch_size, shuffle=False)
    modified_testloader = DataLoader(TensorDataset(*test_tensors), batch_size=batch_size, shuffle=False)

    train_features, _, train_targets = train_tensors
    in_features = train_features.size(-1)
    num_classes = int(torch.max(train_targets).item() + 1)
    return_sequences = False
    seqlen = train_features.size(1)

    return modified_trainloader, modified_valloader, modified_testloader, in_features, num_classes, return_sequences, seqlen

 

def load_dataset(dataset_name):
    """Build train/test/validation loaders for the Walker2d imitation data.

    ``dataset_name`` is accepted but not consulted: the Walker2d data with
    ``seq_len = 64`` is always loaded. All loaders use a batch size of 256
    and only the training loader shuffles.

    Returns:
        ``(trainloader, testloader, validloader, in_features, num_classes,
        return_sequences, seq_len)`` where ``num_classes`` is fixed to 17,
        ``return_sequences`` is True, and ``in_features`` / ``seq_len`` are
        the last / second dimension of the training inputs.
    """
    def to_tensors(x, y, times):
        # Inputs and time stamps become float tensors, labels become long tensors.
        x_t = torch.Tensor(x)
        y_t = torch.LongTensor(y)
        times_t = torch.Tensor(times)
        return x_t, times_t, y_t

    print("I am in load_dataset class")
    walker = Walker2dImitationData(seq_len = 64)

    train_x, train_ts, train_y = to_tensors(walker.train_x, walker.train_y, walker.train_times)
    test_x, test_ts, test_y = to_tensors(walker.test_x, walker.test_y, walker.test_times)
    valid_x, valid_ts, valid_y = to_tensors(walker.valid_x, walker.valid_y, walker.valid_times)

    print(train_y.shape,test_y.shape)

    # Each dataset yields (inputs, time stamps, labels).
    train_set = data.TensorDataset(train_x, train_ts, train_y)
    test_set = data.TensorDataset(test_x, test_ts, test_y)
    valid_set = data.TensorDataset(valid_x, valid_ts, valid_y)

    loader_batch = 256
    trainloader = data.DataLoader(train_set, batch_size=loader_batch, shuffle=True)
    testloader = data.DataLoader(test_set, batch_size=loader_batch, shuffle=False)
    validloader = data.DataLoader(valid_set, batch_size=loader_batch, shuffle=False)

    in_features = train_x.size(-1)
    num_classes = 17
    return_sequences = True

    return trainloader, testloader, validloader , in_features, num_classes, return_sequences,train_x.shape[1]

seeds = [8, 10, 35, 69, 96]

# One independent training run per seed.
for s in seeds:
    # Release cached GPU memory from the previous run, then seed the torch
    # (CPU and CUDA) and numpy generators before the model is constructed.
    torch.cuda.empty_cache()
    for seed_rng in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
        seed_rng(s)
    logger.info(f"Seed: {s}")

    # Model dimensions: input vector, PDE/ODE state, output, and sequence
    # length (seq_len is re-assigned later from the dataset loader).
    # NOTE(review): 283 / 71 presumably match the sepsis features - confirm.
    input_size, hidden_size, output_size, seq_len = 283, size, 2, 71

    neuralpde = NeuralPDE(
        input_size,
        hidden_size,
        output_size,
        seq_len,
        model=args.model,
    )
    neuralpde.to(device)

    # Weighted binary cross-entropy on logits, optimised with Adam over all
    # model parameters.
    params = list(neuralpde.parameters())
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.Adam(params, lr=lr)

    def calculate_auc(Y_hat, Y):
        """Return ``roc_auc_score`` of the scores ``Y_hat`` against the labels ``Y``.

        Both tensors are detached and moved to the CPU before being handed
        to scikit-learn as numpy arrays.
        """
        labels = Y.detach().cpu().numpy()
        scores = Y_hat.detach().cpu().numpy()
        return roc_auc_score(labels, scores)

    def feed_forward(loader,num_classes,train = True,return_sequences = True,seq_len = 64):
        """Run one full pass over ``loader`` with the stacked state layout and score it.

        Uses ``neuralpde``, ``optimizer``, ``criterion``, ``hidden_size``,
        ``device`` and ``calculate_auc`` from the enclosing per-seed scope.

        Args:
            loader: iterable of ``(x, ts, y)`` batches; ``x`` and ``ts`` are
                indexed per step as ``x[:, j]`` / ``ts[:, j]``.
            num_classes: width of the prediction buffer and of the one-hot
                targets.
            train: when True the model is put in train mode and an optimizer
                step is taken per batch; otherwise the model is put in eval
                mode and the forward pass runs without autograd tracking.
            return_sequences: only selects the shape of the initially empty
                prediction buffer. NOTE(review): the loop always produces
                last-step predictions, so True looks unsupported - confirm.
            seq_len: only used for that buffer shape; the actual sequence
                length is read from every batch.

        Returns:
            ``(auc, Loss)``: ROC-AUC over all samples seen (softmax
            probability of class 1 against the raw targets) and the sum of
            the per-batch losses.
        """
        if return_sequences:
            Y_hat = torch.zeros((0,seq_len,num_classes))
        else:
            Y_hat = torch.zeros((0,num_classes))

        Y = torch.zeros((0)).type(torch.int)

        if train:
            print("Entering with train mode")
            neuralpde.train()
        else:
            print("Entering eval mode")
            neuralpde.eval()

        Loss = 0
        k = 0
        for train_x, train_ts, train_y in loader:
            batch_size = train_x.shape[0]
            seq_len = train_x.shape[1]
            seq_size = hidden_size
            n_rows = batch_size*seq_len + 2*batch_size

            if train:
                optimizer.zero_grad()

            # Autograd bookkeeping is only needed when backward() will be called.
            with torch.set_grad_enabled(train):
                # Row layout: one block of batch_size rows per step, preceded
                # and followed by one extra block that duplicates the first /
                # last step (presumably the boundary of the PDE grid).
                hidden_state = torch.zeros((n_rows,seq_size),device = device)
                Times = torch.zeros((batch_size*seq_len),device = device)
                for j in range(seq_len):
                    encoded = neuralpde.pre_nn(train_x[:,j].to(device))
                    ts = train_ts[:,j].to(device)
                    hidden_state[(j+1)*batch_size:(j+2)*batch_size] = encoded
                    # Min-max normalise the step's time stamps across the batch, shifted by +1.
                    Times[j*batch_size:(j+1)*batch_size] = ((ts-ts.min())/(ts.max()-ts.min()))[:,0]+1
                    if j == 0:
                        hidden_state[:batch_size] = encoded
                    elif j == seq_len - 1:
                        hidden_state[-batch_size:] = encoded

                # ts.max() == ts.min() gives 0/0 = NaN above; use 1.0 there.
                Times[Times.isnan()] = 1.0
                Times = Times.reshape([Times.shape[0],1])

                # Second half of the state starts at zero and is stacked next
                # to the encoded inputs along the feature dimension.
                negative_hidden_state = torch.zeros_like(hidden_state)
                init_hidden_state = torch.cat([hidden_state,negative_hidden_state],dim = 1)

                # Clear the state stored on the PDE function before this solve.
                neuralpde.pdefunc.previous_hidden_state = None
                new_hidden_state_tajectory = neuralpde(init_hidden_state,batch_size,Times)
                # NOTE(review): index 1 is taken as the solved state - confirm
                # against NeuralPDE.forward.
                new_hidden_state = new_hidden_state_tajectory[1]

                # Drop the two boundary blocks and keep the first seq_size features.
                y_ = new_hidden_state[1*batch_size:-1*batch_size,:seq_size]

                # Only the last step's output was ever used, so post_nn is
                # applied to that block alone instead of to every step.
                # NOTE(review): assumes post_nn keeps no per-call state.
                y_hat = neuralpde.post_nn(y_[(seq_len-1)*batch_size:seq_len*batch_size])

                target = train_y
                one_hot_target = F.one_hot(train_y.to(torch.long), num_classes=num_classes)
                one_hot_target = one_hot_target.to(device=y_hat.device, dtype=y_hat.dtype)

                loss = criterion(y_hat,one_hot_target)

            Y = torch.cat((Y,target.detach().cpu()))
            Y_hat = torch.cat((Y_hat,y_hat.detach().cpu()))

            if train:
                if k % 500 == 0:
                    print("k = ", k)
                    print(loss.item())
                k += 1
                loss.backward()
                optimizer.step()

            Loss += loss.item()

        preds = F.softmax(Y_hat, dim=1)[:, 1].unsqueeze(1)
        Y = Y.unsqueeze(1)
        auc = calculate_auc(preds, Y)
        return auc,Loss

            

    def feed_forward_direct(loader,num_classes,train = True,return_sequences = True,seq_len = 64):
        """Run one full pass over ``loader`` with the transposed ("direct") state layout and score it.

        Uses ``neuralpde``, ``optimizer``, ``criterion``, ``hidden_size``,
        ``device`` and ``calculate_auc`` from the enclosing per-seed scope.

        Args:
            loader: iterable of ``(x, ts, y)`` batches; ``x`` and ``ts`` are
                indexed per step as ``x[:, j]`` / ``ts[:, j]``.
            num_classes: width of the prediction buffer and of the one-hot
                targets.
            train: when True the model is put in train mode and an optimizer
                step is taken per batch; otherwise the model is put in eval
                mode and the forward pass runs without autograd tracking.
            return_sequences: only selects the shape of the initially empty
                prediction buffer. NOTE(review): the loop always produces
                last-step predictions, so True looks unsupported - confirm.
            seq_len: only used for that buffer shape; the actual sequence
                length is read from every batch.

        Returns:
            ``(auc, Loss)``: ROC-AUC over all samples seen (softmax
            probability of class 1 against the raw targets) and the sum of
            the per-batch losses.
        """
        if return_sequences:
            Y_hat = torch.zeros((0,seq_len,num_classes))
        else:
            Y_hat = torch.zeros((0,num_classes))

        Y = torch.zeros((0)).type(torch.int)

        if train:
            print("Entering with train mode")
            neuralpde.train()
        else:
            print("Entering eval mode")
            neuralpde.eval()

        Loss = 0
        k = 0

        for train_x, train_ts, train_y in loader:
            batch_size = train_x.shape[0]
            seq_len = train_x.shape[1]
            seq_size = hidden_size
            n_rows = batch_size*seq_len + 2*batch_size

            if train:
                optimizer.zero_grad()

            # Autograd bookkeeping is only needed when backward() will be called.
            with torch.set_grad_enabled(train):
                # Row layout: one block of batch_size rows per step, preceded
                # and followed by one extra block that duplicates the first /
                # last step (presumably the boundary of the PDE grid).
                hidden_state = torch.zeros((n_rows,seq_size),device = device)
                Times = torch.zeros((batch_size*seq_len),device = device)
                for j in range(seq_len):
                    encoded = neuralpde.pre_nn(train_x[:,j].to(device))
                    ts = train_ts[:,j].to(device)
                    hidden_state[(j+1)*batch_size:(j+2)*batch_size] = encoded
                    # Min-max normalise the step's time stamps across the batch, shifted by +1.
                    Times[j*batch_size:(j+1)*batch_size] = ((ts-ts.min())/(ts.max()-ts.min()))[:,0]+1
                    if j == 0:
                        hidden_state[:batch_size] = encoded
                    elif j == seq_len - 1:
                        hidden_state[-batch_size:] = encoded

                # (rows, features) -> (features, 1, rows) for the solver.
                hidden_state = hidden_state.T[:,None,:]

                # ts.max() == ts.min() gives 0/0 = NaN above; use 1.0 there.
                Times[Times.isnan()] = 1.0
                Times = Times.reshape([Times.shape[0],1])

                # Clear the state stored on the PDE function before this solve.
                neuralpde.pdefunc.previous_hidden_state = None
                new_hidden_state_tajectory = neuralpde(hidden_state,batch_size,Times)
                # NOTE(review): index 1 is taken as the solved state - confirm
                # against NeuralPDE.forward. Squeeze + transpose undoes the
                # layout change made before the solve.
                new_hidden_state = new_hidden_state_tajectory[1].squeeze().T

                # Drop the two boundary blocks and keep the first seq_size features.
                y_ = new_hidden_state[1*batch_size:-1*batch_size,:seq_size]

                # Only the last step's output was ever used, so post_nn is
                # applied to that block alone instead of to every step.
                # NOTE(review): assumes post_nn keeps no per-call state.
                y_hat = neuralpde.post_nn(y_[(seq_len-1)*batch_size:seq_len*batch_size])

                target = train_y
                one_hot_target = F.one_hot(train_y.to(torch.long), num_classes=num_classes)
                one_hot_target = one_hot_target.to(device=y_hat.device, dtype=y_hat.dtype)

                loss = criterion(y_hat,one_hot_target)

            Y = torch.cat((Y,target.detach().cpu()))
            Y_hat = torch.cat((Y_hat,y_hat.detach().cpu()))

            if train:
                if k % 500 == 0:
                    print("k = ", k)
                    print(loss.item())
                k += 1
                loss.backward()
                optimizer.step()

            Loss += loss.item()

        preds = F.softmax(Y_hat, dim=1)[:, 1].unsqueeze(1)
        Y = Y.unsqueeze(1)
        auc = calculate_auc(preds, Y)
        return auc,Loss


    # Per-epoch history for the current seed.
    train_loss = []
    train_acc = []
    test_loss = []
    test_auc = []
    train_auc = []

    Best_loss = 10000
    Best_loss_train = 10000
    best_auc = 0
    # Learning rate is multiplied by 0.1 once, after epoch 100.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones = [100],gamma = 0.1)

    # Only the sepsis pipeline is wired up below; fail early with a clear
    # message instead of reaching an undefined `auc` inside the epoch loop.
    if dataset_name != "sepsiswithOI":
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    print("Original place where train and test and valid loaders are being called")
    # load_physionet_dataset returns its loaders as (train, validation, test).
    # Unpacking them in that order makes `testloader` the real test split;
    # previously the validation split was evaluated and logged as "Test".
    trainloader, validloader, testloader, in_features, num_classes, return_sequences,seq_len = load_physionet_dataset(256)

    # "wave" uses the stacked state layout, every other variant the direct one.
    run_epoch = feed_forward if args.model == "wave" else feed_forward_direct

    for epoch in range(epochs):
        print("\nEpoch: and  Seq Len",epoch,seq_len)

        auc,loss = run_epoch(trainloader, num_classes, train=True,return_sequences = False, seq_len = seq_len)
        print("Train Auc:",auc," Train loss:",loss)
        train_loss.append(loss)
        train_auc.append(auc)

        auc,loss = run_epoch(testloader,num_classes, train=False,return_sequences = False, seq_len = seq_len)
        print("Test Auc:",auc," Test loss:",loss)
        test_loss.append(loss)
        test_auc.append(auc)

        print("*"*20)
        logger.info(f"Epoch: {epoch} - Train Loss:{train_loss[-1]}; Test Loss: {test_loss[-1]}")
        logger.info(f"Auc: {auc}")

        # Track the best evaluation AUC seen so far for this seed.
        if auc > best_auc:
            best_auc = auc
        logger.info(f"Best Auc: {best_auc}")
        scheduler.step()
    