import torch
from torch import nn, optim
from sklearn.model_selection import train_test_split
import argparse
import json
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import os
from sklearn.model_selection import StratifiedKFold, train_test_split
import itertools
import misc as misc
# import sys
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import math
import torch.utils.data.sampler as Sampler
# from ogb.graphproppred import PygGraphPropPredDataset

# import time
from transformer_scratch_modular import create_autoencoder
# SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# sys.path.append(os.path.dirname(SCRIPT_DIR))
from laplace_data import LaplacianDataset
# from utils.utils import set_seed
from sklearn.metrics import roc_auc_score

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

torch.set_printoptions(precision=2)
# set_seed(42)
cwd = os.getcwd()


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` is a well-known argparse trap: ``bool("False")`` is ``True``
    because any non-empty string is truthy.  This parser maps the usual
    spellings explicitly and rejects everything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='TopNets')
parser.add_argument('--exp_name', type=str, default='exp_1', metavar='N',
                    help='experiment_name')
parser.add_argument('--batch_size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=10000, metavar='N',
                    help='number of epochs to train (default: 10000)')
parser.add_argument('--num_filtrations', type=int, default=8, metavar='nf',
                    help='Number of filtration functions')
parser.add_argument('--nsteps', type=int, default=20, metavar='nf',
                    help='Steps for the ODE solver')
parser.add_argument('--out_ph', type=int, default=64, metavar='nf',
                    help='Out PH embedding dim')
parser.add_argument('--fil_hid', type=int, default=16, metavar='nf',
                    help='Filtration hidden dim')
parser.add_argument('--lr', type=float, default=1e-5, metavar='N',
                    help='learning rate')
# Was type=bool, which turned "--cont False" into True.
parser.add_argument('--cont', metavar='N', type=_str2bool, default=True,
                    help='Continuous Type or Not')
# The default is now listed in `choices` so that passing it explicitly works.
parser.add_argument("--dataset", type=str, default="ogbg-molhiv",
                    choices=["ogbg-molhiv", "PROTEINS_full", "NCI109", "NCI1", "IMDB-BINARY"])
parser.add_argument('--weight_decay', type=float, default=1e-8, metavar='N',
                    help='weight decay')
parser.add_argument("--diagram_type", type=str, default="rephine",
                    choices=["rephine", "standard", "none"])
parser.add_argument("--gnn", type=str, default="gcn", choices=["gcn", "gin"])
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')


parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
                    help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--layer_decay', type=float, default=0.75,
                    help='layer-wise lr decay from ELECTRA/BEiT')

parser.add_argument('--min_lr', type=float, default=1e-8, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0')

parser.add_argument('--num_workers', default=1, type=int)
parser.add_argument('--pin_mem', action='store_true',
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')

parser.add_argument('--warmup_epochs', type=int, default=500, metavar='N',
                    help='epochs to warmup LR')
# NOTE(review): the default is a list while type=int yields a single int when
# the flag is passed; left as-is because consumers of args.gpu are not visible
# here.  Only the copy-pasted help text was corrected.
parser.add_argument('--gpu', type=int, default=[0,1], metavar='N',
                    help='GPU device id(s) to use')
parser.add_argument('--world_size', default=4, type=int,
                    help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
                    help='url used to set up distributed training')
# Previously hard-coded at the dataset construction site; the default keeps
# the original location so existing invocations behave identically.
parser.add_argument('--data_root', type=str,
                    default='/projappl/project_2011438/PhDProjects/DeepHKS/TopNets/RePHINE/datasets/protein_debugged',
                    help='root directory of the Laplacian dataset')

args = parser.parse_args()
args.cuda = torch.cuda.is_available()
# Honour the availability check instead of unconditionally requesting CUDA,
# so the script can still start on CPU-only machines.
device = torch.device("cuda" if args.cuda else "cpu")
dtype = torch.float32
print(args)

full_data = LaplacianDataset(root=args.data_root)
indices = np.arange(0,len(full_data))

# train_size = 0.8
# val_size = 0.1
# test_size = 0.1

import time

class Timer:
    """Accumulating stopwatch that can be started and stopped repeatedly.

    ``elapsed`` holds the total number of seconds spent in completed
    start/stop intervals.  Intervals are measured with
    ``time.perf_counter()``, which is monotonic, so an adjustment of the
    system clock cannot produce a negative or inflated interval (as it could
    with ``time.time()``).  Consequently ``start_time`` is a perf-counter
    reading, not an epoch timestamp.
    """

    def __init__(self):
        self.start_time = None   # perf-counter reading of the current interval
        self.elapsed = 0.0       # seconds accumulated over finished intervals
        self.running = False     # True between start() and stop()

    def start(self):
        """Begin a new interval; does nothing if the timer is already running."""
        if not self.running:
            self.start_time = time.perf_counter()
            self.running = True

    def stop(self):
        """Close the current interval and add it to ``elapsed``; no-op if stopped."""
        if self.running:
            self.elapsed += time.perf_counter() - self.start_time
            self.running = False

    def reset(self):
        """Discard all accumulated time and return to the stopped state."""
        self.start_time = None
        self.elapsed = 0.0
        self.running = False

    def get_elapsed(self):
        """Return total seconds so far, including the interval in progress."""
        if self.running:
            return self.elapsed + (time.perf_counter() - self.start_time)
        return self.elapsed

# train_data, val_data, test_data = torch.utils.data.random_split(full_data, [train_size, val_size, test_size])
# dataset = PygGraphPropPredDataset(name='ogbg-molhiv', root=path)
# split_idx = dataset.get_idx_split()
# train_data, val_data, test_data = full_data[split_idx["train"]], full_data[split_idx["valid"]], full_data[split_idx["test"]]
def get_label_fromTU(dataset):
    """Return a list with the label of every item in ``dataset``.

    Each item is fetched by integer position and its element at index 1 is
    taken as the label, so ``dataset`` must support ``len()`` and integer
    indexing.
    """
    return [dataset[position][1] for position in range(len(dataset))]

# The whole dataset is used; `timer` accumulates the forward-pass time that
# is reported at the end of the script.
dataset = full_data
timer = Timer()
n_instances = len(dataset)

# Unshuffled, stratified 5-fold partition over the item positions.
skf = StratifiedKFold(n_splits=5, shuffle=False, random_state=None)
_positions = torch.tensor(list(range(n_instances)))
_targets = torch.tensor(get_label_fromTU(dataset))
skf_iterator = skf.split(_positions, _targets)

# Only the first fold is consumed.  Its training part is split once more
# (seeded) into train / validation positions.
train_index, test_index = next(skf_iterator)
train_index, val_index = train_test_split(train_index, random_state=42)

# Plain Python lists from here on.
train_index, val_index, test_index = (
    part.tolist() for part in (train_index, val_index, test_index)
)

# NOTE: `val_index` is not referenced again in the code visible here; the
# evaluation sampler is built from the test fold.
array1 = np.asarray(train_index)
array2 = np.asarray(test_index)

print(n_instances)
train_indices = array1.astype(int)
test_indices = array2.astype(int)


def setup(rank, world_size):
    """Create the default process group for this process.

    Sets ``MASTER_ADDR`` / ``MASTER_PORT`` in the environment to
    ``localhost`` / ``12355`` (overwriting any existing values) and then
    initialises ``torch.distributed`` with the "gloo" backend.

    Args:
        rank: rank of the calling process.
        world_size: total number of processes in the group.
    """
    rendezvous = {'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'}
    os.environ.update(rendezvous)

    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

def cleanup():
    """Tear down the process group via ``dist.destroy_process_group()``."""
    dist.destroy_process_group()

def mean(lst):
    """Return the arithmetic mean of ``lst``.

    ``lst`` must support ``len()``; an empty sequence raises
    ``ZeroDivisionError``.
    """
    total = sum(lst)
    count = len(lst)
    return total / count

def collate_point_clouds(batch):
    """Collate variable-length samples into one zero-padded batch.

    Args:
        batch: sequence of items where ``item[0]`` is a tensor whose first
            dimension may differ between items and ``item[1]`` is its label.

    Returns:
        A tuple ``(padded_clouds, labels, lengths)``:
        ``padded_clouds`` stacks the tensors batch-first, padded with zeros
        along the first (length) dimension up to the longest item;
        ``labels`` is ``torch.tensor`` of the collected labels;
        ``lengths`` holds the original first-dimension size of every item, so
        callers can tell real rows from padding.
    """
    point_clouds = [item[0] for item in batch]
    labels = torch.tensor([item[1] for item in batch])
    padded_clouds = pad_sequence(point_clouds, batch_first=True, padding_value=0)
    lengths = torch.tensor([pc.size(0) for pc in point_clouds])
    return padded_clouds, labels, lengths

# train_loader = DataLoader(train_data, batch_size=args.batch_size, collate_fn=collate_point_clouds)
# val_loader = DataLoader(val_data, batch_size=len(val_data), collate_fn=collate_point_clouds)

# Hand the parsed CLI arguments to the project-local `misc` helper to set up
# distributed execution; what exactly it configures is not visible in this file.
misc.init_distributed_mode(args)
# Presumably the number of participating processes and this process's rank —
# TODO confirm against `misc`.  In the code visible here both values are only
# referenced from commented-out sampler construction.
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()

class DistributedSubset:
    """Sampler that yields a fixed collection of dataset indices in random order.

    Adapted from
    https://discuss.pytorch.org/t/how-to-use-my-own-sampler-when-i-already-use-distributedsampler/62143/8

    NOTE: despite its name this class does not shard ``indices`` between
    processes; every instance iterates over all of them.  The shuffle is drawn
    from torch's global RNG and does not depend on the stored epoch, so the
    order is reproducible only if the caller seeds torch.
    """

    def __init__(self, indices):
        """Store ``indices`` (any sequence supporting ``len`` and integer indexing)."""
        self.indices = indices
        # Defined up front so the attribute exists before set_epoch() is called.
        self.epoch = 0

    def __iter__(self):
        """Return an iterator over all stored indices in a fresh random order."""
        order = torch.randperm(len(self.indices)).tolist()
        return iter([self.indices[i] for i in order])

    def __len__(self):
        """Number of indices yielded per pass."""
        return len(self.indices)

    def set_epoch(self, epoch):
        """Record the current epoch (kept for sampler API parity; not used for seeding)."""
        self.epoch = epoch

# sampler_train = DistributedSubset(
#     train_indices, num_replicas=num_tasks, rank=global_rank, shuffle=False
# )

# sampler_val = DistributedSubset(
#                 test_indices, num_replicas=num_tasks, rank=global_rank, shuffle=False)


# Samplers over the train / test positions computed above.
sampler_train = DistributedSubset(train_indices)
sampler_val = DistributedSubset(test_indices)

# `prefetch_factor` is only meaningful when worker processes are used; recent
# PyTorch versions raise a ValueError if it is given together with
# num_workers == 0.  It is therefore added only when workers are requested,
# which keeps the original behaviour for the default --num_workers 1.
_train_loader_kwargs = dict(
    sampler=sampler_train,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=True,
    collate_fn=collate_point_clouds,
)
if args.num_workers > 0:
    _train_loader_kwargs["prefetch_factor"] = 2

train_loader = torch.utils.data.DataLoader(full_data, **_train_loader_kwargs)

# The evaluation split is served as one single batch.
val_loader = torch.utils.data.DataLoader(
    full_data, sampler=sampler_val,
    # batch_size=args.batch_size,
    batch_size=len(test_indices),
    # num_workers=args.num_workers,
    num_workers=1,
    pin_memory=args.pin_mem,
    drop_last=False,
    collate_fn=collate_point_clouds
)

model = create_autoencoder(node_features = 101, dim = 64, M = 64, latent_dim = 16, N = 128, depth = 6, output_dim=2)
model = model.to(device)

# NOTE(review): DistributedDataParallel needs an initialised process group;
# this relies on misc.init_distributed_mode(args) having created one — confirm
# the behaviour for single-process runs.
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=None, find_unused_parameters=True)

# evaluator = Evaluator(args.dataset)
# evaluator = None

# print(model)
def adjust_learning_rate(optimizer, epoch, args):
    """Set the optimizer's learning rate: linear warmup, then half-cycle cosine decay.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts).
        epoch: current (possibly fractional) epoch.
        args: namespace providing ``lr``, ``min_lr``, ``warmup_epochs`` and
            ``epochs``.

    Returns:
        The unscaled learning rate.  Each parameter group receives this value,
        multiplied by its ``"lr_scale"`` entry when it has one.
    """
    warmup = args.warmup_epochs
    if epoch >= warmup:
        # Cosine branch: goes from args.lr at the end of warmup down to
        # args.min_lr at args.epochs.
        angle = math.pi * (epoch - warmup) / (args.epochs - warmup)
        cosine_term = 1. + math.cos(angle)
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * cosine_term
    else:
        # Warmup branch: ramps linearly from 0 to args.lr.
        lr = args.lr * epoch / warmup

    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
    return lr

def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Parameters with ``requires_grad`` set to ``False`` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the number of trainable parameters of the (wrapped) model.
print(count_parameters(model))
# NOTE(review): weight_decay is fixed to 0.0 here, so the --weight_decay CLI
# option has no effect on this optimizer — confirm whether that is intended.
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.0)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
# `best_loss` is initialised but not updated anywhere in the code visible here.
best_loss = float('inf')
# Per-batch training losses, appended to by the training loop.
loss_ls = []
# Binary classification: labels are one-hot encoded with two classes.
num_classes = 2
loss_fn = torch.nn.CrossEntropyLoss()
# loss_fn = torch.nn.L1Loss()

# ---------------------------------------------------------------------------
# Training / timing loop
# ---------------------------------------------------------------------------
# The script currently acts as a short timing run: at most DEBUG_MAX_BATCHES
# batches are processed and the epoch loop is left after the first epoch.
# These limits were previously inline magic values; their effect is unchanged.
# Set DEBUG_SINGLE_EPOCH to False to reach checkpointing and validation.
DEBUG_MAX_BATCHES = 10
DEBUG_SINGLE_EPOCH = True


def _sync_device():
    """Wait for queued CUDA work so wall-clock timings cover the real compute."""
    if device.type == "cuda":
        torch.cuda.synchronize()


for epoch in range(args.epochs):
    total_train_loss = 0
    model.train()
    batches_per_epoch = len(train_loader)
    for batch_idx, batch in enumerate(train_loader):

        if batch_idx >= DEBUG_MAX_BATCHES:
            break

        # The learning rate is refreshed every third batch using a fractional epoch.
        if batch_idx % 3 == 0:
            adjust_learning_rate(optimizer, batch_idx / batches_per_epoch + epoch, args)

        optimizer.zero_grad()
        inputs = batch[0].to(device)
        labels = batch[1].to(device).long()
        # CrossEntropyLoss is fed class probabilities (one-hot, float) here.
        labels = torch.nn.functional.one_hot(labels, num_classes=num_classes).float()

        assert not torch.isnan(inputs).any(), "INPTUT Tensor contains NaN values!"

        # CUDA kernels are launched asynchronously; without synchronising, the
        # timer would mostly measure kernel-launch overhead rather than the
        # forward pass itself.
        _sync_device()
        timer.start()
        pred = model(inputs.float())
        _sync_device()
        timer.stop()

        # NOTE(review): squeeze() also drops the batch dimension when the batch
        # holds a single item — confirm this cannot happen with the chosen
        # batch size.
        loss = loss_fn(pred.squeeze(), labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_train_loss = total_train_loss + batch_loss
        loss_ls.append(batch_loss)
        if batch_idx % 100 == 0:
            print(f"Epoch: {epoch}, Batch: {batch_idx / batches_per_epoch}, Loss: {batch_loss}")

    if DEBUG_SINGLE_EPOCH:
        # Early exit of the timing run; nothing below executes in this mode.
        break

    # NOTE(review): every process writes to the same file name; consider
    # restricting checkpointing to a single rank.
    if epoch % 500 == 0:
        torch.save(model.state_dict(), f"model_state_dict_transformer_scratch_protein_{epoch}.pth")

    if epoch % 10 == 0:
        model.eval()
        # No gradients are needed for evaluation; this avoids building the
        # autograd graph for the single, full-size validation batch.
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch[0].to(device)
                labels = batch[1].to(device).long()
                pred = model(inputs.float())
                labels = torch.nn.functional.one_hot(labels, num_classes=num_classes)
                # This is plain accuracy (fraction of matching argmax), not AUC.
                val_acc = (pred.argmax(dim=-1) == labels.argmax(dim=-1)).float().mean()
                print("Val acc:", val_acc.item())


print(f"Total time: {timer.get_elapsed():.2f} seconds")
