import numpy as np

import os
import pandas as pd
import torch
import torch.nn as nn
import wandb
import os.path as osp

from copy import deepcopy
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm import tqdm

from data import load_dataset, preprocess
from model import ModelAsync
from setup_utils import load_train_yaml, set_seed
from torch_geometric.utils import negative_sampling
from nets import *

def _pick_device():
    """Return the device to train on.

    The script historically pins training to the second GPU (``cuda:1``).
    That index does not exist on single-GPU hosts, so fall back to ``cuda:0``
    there, and to the CPU when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    return torch.device(f'cuda:{gpu_index}')


def _sample_hard_labels(Y):
    """Draw one class index per row of ``Y`` from ``softmax(Y, dim=1)``.

    Returns a 1-D tensor with one entry per row. ``squeeze(1)`` (rather than a
    bare ``squeeze()``) keeps the row axis even when ``Y`` has a single row.
    """
    probs = torch.softmax(Y, dim=1)
    return torch.multinomial(probs, num_samples=1).squeeze(1)


def _run_validation(model, val_data_loader, X_one_hot_3d, E_one_hot, Y,
                    X_one_hot_2d):
    """Run ``model.val_step`` over every batch of node pairs and average.

    Returns:
        Tuple ``(denoise_match_X, denoise_match_E, log_p_0_X, log_p_0_E)``,
        each the ``np.mean`` of the per-batch values returned by
        ``model.val_step``.
    """
    denoise_match_X = []
    denoise_match_E = []
    log_p_0_X = []
    log_p_0_E = []

    # No parameter update happens here, so skip building autograd graphs.
    with torch.no_grad():
        for batch_edge_index in tqdm(val_data_loader):
            # (B), (B)
            batch_dst, batch_src = batch_edge_index.T

            # Fresh label sample for every batch, as in training.
            Y_tmp = _sample_hard_labels(Y)

            batch_denoise_match_E, batch_denoise_match_X,\
                batch_log_p_0_E, batch_log_p_0_X = model.val_step(
                    X_one_hot_3d,
                    E_one_hot,
                    Y,
                    Y_tmp,
                    X_one_hot_2d,
                    batch_src,
                    batch_dst,
                    E_one_hot[batch_dst, batch_src])

            denoise_match_E.append(batch_denoise_match_E)
            denoise_match_X.append(batch_denoise_match_X)
            log_p_0_E.append(batch_log_p_0_E)
            log_p_0_X.append(batch_log_p_0_X)

    return (np.mean(denoise_match_X), np.mean(denoise_match_E),
            np.mean(log_p_0_X), np.mean(log_p_0_E))


def main(args):
    """Train the "Async" model on a graph whose labels come from a pretrained classifier.

    Steps:
      1. Load the training YAML, the dataset and its preprocessed tensors.
      2. Replace ``Y`` with the pseudo-label scores unpickled from
         ``./pretrain/<dataset>_<imb_rate>_<im_class_num>_<pretrain>.txt``.
      3. Load the pretrained GAT from ``./pretrain/GraphENS_<dataset>.pth``
         and hand it to ``ModelAsync``.
      4. Train ``pred_X`` and ``pred_E`` with separate AdamW optimizers on all
         positive edges plus freshly sampled negative edges.
      5. Every ``val_every_epochs`` epochs: validate over all upper-triangular
         node pairs, checkpoint whenever the X or E validation NLL improves,
         and stop after ``patient_epochs`` validation rounds in which none of
         the four tracked metrics improved.

    Args:
        args: Namespace with attributes ``dataset``, ``pretrain``,
            ``imb_rate`` and ``im_class_num``. The last three are only used
            to build the pseudo-label and checkpoint file names.
    """
    model_name = "Async"
    # Optimisation steps per epoch; every step uses all positive edges.
    steps_per_epoch = 100
    # Negative edges sampled per positive edge at each step.
    neg_per_pos = 3
    # The regulariser is currently switched off, but it is still multiplied in
    # (instead of dropped) so the autograd graph is the same as before.
    reg_loss_weight = 0

    yaml_data = load_train_yaml(args.dataset, model_name)
    train_config = yaml_data["train"]

    device = _pick_device()

    g = load_dataset(args.dataset)
    X_one_hot_3d, Y, E_one_hot,\
        X_marginal, Y_marginal, E_marginal, X_cond_Y_marginals = preprocess(g)

    import pickle
    from utils import refine_label_order
    # NOTE(review): the return value is unused and Y is replaced by the
    # pseudo-labels right below. The call is kept only in case the helper has
    # side effects -- confirm and delete.
    refine_label_order(Y)

    pseudo_label = (f'./pretrain/{args.dataset}_{args.imb_rate}'
                    f'_{args.im_class_num}_{args.pretrain}.txt')
    # SECURITY: pickle.load executes arbitrary code embedded in the file; only
    # point this at pseudo-label files produced by a trusted pretraining run.
    with open(pseudo_label, 'rb') as f:
        pred_Y, data_train_mask = pickle.load(f)
    Y = pred_Y.to(device)

    pretrained_path = './pretrain/GraphENS_' + args.dataset + '.pth'
    pretrained_model = GAT(2, X_one_hot_3d.size()[0], 64, Y.size()[1], 4,
                           is_add_self_loops=True)
    # map_location lets a checkpoint saved on any GPU index be restored on
    # hosts with fewer GPUs or none at all; load_state_dict copies the values
    # into the freshly built module either way.
    pretrained_state = torch.load(pretrained_path, map_location='cpu')
    load_result = pretrained_model.load_state_dict(pretrained_state,
                                                   strict=False)
    # strict=False tolerates key mismatches; report them instead of silently
    # continuing with partly random weights.
    if load_result.missing_keys or load_result.unexpected_keys:
        print("Warning: pretrained checkpoint only partially matched the GAT "
              f"(missing={load_result.missing_keys}, "
              f"unexpected={load_result.unexpected_keys})")

    # (F, |V|, 2)
    X_one_hot_3d = X_one_hot_3d.to(device)
    # (|V|, F, 2)
    X_one_hot_2d = torch.transpose(X_one_hot_3d, 0, 1)
    # (|V|, 2 * F)
    X_one_hot_2d = X_one_hot_2d.reshape(X_one_hot_2d.size(0), -1)

    E_one_hot = E_one_hot.to(device)

    X_marginal = X_marginal.to(device)
    Y_marginal = Y_marginal.to(device)
    E_marginal = E_marginal.to(device)
    X_cond_Y_marginals = X_cond_Y_marginals.to(device)

    N = g.num_nodes()
    dst, src = torch.triu_indices(N, N, offset=1, device=device)
    # (|V| * (|V| - 1) / 2, 2): every unordered node pair, used for validation.
    all_pairs = torch.stack([dst, src], dim=1)

    # Set seed for better reproducibility.
    # NOTE(review): the GAT above is built before seeding, so any weights the
    # checkpoint does not cover are not reproducible. Moving this call earlier
    # would change the RNG stream of existing runs, so it is left in place.
    set_seed()

    # (2, |E|): index pairs (i, j) with E_one_hot[i, j, 1] != 0, i.e. the
    # positive edges used for training.
    pos_edges = E_one_hot[:, :, 1].nonzero().T

    val_data_loader = DataLoader(all_pairs,
                                 batch_size=train_config["val_batch_size"],
                                 shuffle=False)

    model = ModelAsync(X_marginal=X_marginal,
                       Y_marginal=Y_marginal,
                       E_marginal=E_marginal,
                       X_cond_Y_marginals=X_cond_Y_marginals,
                       data_train_mask=data_train_mask,
                       num_nodes=N,
                       mlp_X_config=yaml_data["mlp_X"],
                       gnn_E_config=yaml_data["gnn_E"],
                       pretrained_model=pretrained_model,
                       **yaml_data["diffusion"]).to(device)

    pred_X = model.graph_encoder.pred_X
    pred_E = model.graph_encoder.pred_E

    optimizer_X = torch.optim.AdamW(pred_X.parameters(),
                                    **yaml_data["optimizer_X"])
    optimizer_E = torch.optim.AdamW(pred_E.parameters(),
                                    **yaml_data["optimizer_E"])

    lr_scheduler_X = ReduceLROnPlateau(optimizer_X, mode='min',
                                       **yaml_data["lr_scheduler"])
    lr_scheduler_E = ReduceLROnPlateau(optimizer_E, mode='min',
                                       **yaml_data["lr_scheduler"])

    best_epoch_X = 0
    best_state_dict_X = deepcopy(pred_X.state_dict())
    best_val_nll_X = float('inf')
    best_log_p_0_X = float('inf')
    best_denoise_match_X = float('inf')

    best_epoch_E = 0
    best_state_dict_E = deepcopy(pred_E.state_dict())
    best_val_nll_E = float('inf')
    best_log_p_0_E = float('inf')
    best_denoise_match_E = float('inf')

    # Create the directory for saving model checkpoints.
    model_cpt_dir = f"{args.dataset}_cpts"
    os.makedirs(model_cpt_dir, exist_ok=True)
    cpt_path = f"{model_cpt_dir}/{args.pretrain}_{model_name}_{str(args.imb_rate)}_{str(args.im_class_num)}_{args.pretrain}_pretrain_classificationloss.pth"

    # Validation rounds since any tracked metric last improved.
    num_patient_epochs = 0

    for epoch in range(train_config["num_epochs"]):
        model.train()

        for _ in tqdm(range(steps_per_epoch), mininterval=0.01):
            neg_edges = negative_sampling(
                pos_edges, num_nodes=N,
                num_neg_samples=pos_edges.size()[1] * neg_per_pos)
            batch_src = torch.cat([pos_edges[0], neg_edges[0]])
            batch_dst = torch.cat([pos_edges[1], neg_edges[1]])

            # Y holds per-class scores; draw a concrete label per node.
            Y_tmp = _sample_hard_labels(Y)

            loss_X, loss_E, reg_loss = model.log_p_t(
                X_one_hot_3d,
                E_one_hot,
                Y,
                Y_tmp,
                X_one_hot_2d,
                batch_src,
                batch_dst,
                E_one_hot[batch_dst, batch_src])
            loss = loss_X + loss_E + reg_loss * reg_loss_weight

            optimizer_X.zero_grad()
            optimizer_E.zero_grad()
            loss.backward()

            nn.utils.clip_grad_norm_(
                pred_X.parameters(), train_config["max_grad_norm"])
            nn.utils.clip_grad_norm_(
                pred_E.parameters(), train_config["max_grad_norm"])

            optimizer_X.step()
            optimizer_E.step()

        if (epoch + 1) % train_config["val_every_epochs"] != 0:
            continue
        model.eval()

        # Counted up front; reset below if any metric improves this round.
        num_patient_epochs += 1

        denoise_match_X, denoise_match_E, log_p_0_X, log_p_0_E = \
            _run_validation(model, val_data_loader, X_one_hot_3d, E_one_hot,
                            Y, X_one_hot_2d)

        val_X = denoise_match_X + log_p_0_X
        val_E = denoise_match_E + log_p_0_E
        print("denoise_match_X :", denoise_match_X, "denoise_match_E :", denoise_match_E)

        to_save_cpt = False
        if val_X < best_val_nll_X:
            best_val_nll_X = val_X
            best_epoch_X = epoch
            best_state_dict_X = deepcopy(pred_X.state_dict())
            to_save_cpt = True

        if val_E < best_val_nll_E:
            best_val_nll_E = val_E
            best_epoch_E = epoch
            best_state_dict_E = deepcopy(pred_E.state_dict())
            to_save_cpt = True

        if to_save_cpt:
            # The two best states may come from different epochs; the
            # checkpoint always pairs the best X with the best E so far.
            best_val_nll = best_val_nll_X + best_val_nll_E
            torch.save({
                "dataset": args.dataset,
                "train_yaml_data": yaml_data,
                "best_val_nll": best_val_nll,
                "best_epoch_X": best_epoch_X,
                "best_epoch_E": best_epoch_E,
                "pred_X_state_dict": best_state_dict_X,
                "pred_E_state_dict": best_state_dict_E
            }, cpt_path)
            print('model saved')

        # An improvement in any of the four components resets the patience.
        if log_p_0_X < best_log_p_0_X:
            best_log_p_0_X = log_p_0_X
            num_patient_epochs = 0

        if denoise_match_X < best_denoise_match_X:
            best_denoise_match_X = denoise_match_X
            num_patient_epochs = 0

        if log_p_0_E < best_log_p_0_E:
            best_log_p_0_E = log_p_0_E
            num_patient_epochs = 0

        if denoise_match_E < best_denoise_match_E:
            best_denoise_match_E = denoise_match_E
            num_patient_epochs = 0

        print("Epoch {} | best val X nll {:.7f} | best val E nll {:.7f} | patience {}/{}".format(
            epoch, best_val_nll_X, best_val_nll_E, num_patient_epochs, train_config["patient_epochs"]))

        if num_patient_epochs == train_config["patient_epochs"]:
            break

        lr_scheduler_X.step(log_p_0_X)
        lr_scheduler_E.step(log_p_0_E)

if __name__ == '__main__':
    # Command-line entry point: parse the run options and start training.
    from argparse import ArgumentParser

    parser = ArgumentParser(
        description="Train the Async model on pseudo-labelled graph data.")
    parser.add_argument("-d", "--dataset", type=str, default='cora',
                        help="dataset name; also used in the pretrain and "
                             "checkpoint file names")
    parser.add_argument("--pretrain", type=str, default='GraphENS',
                        help="tag of the pretraining run; used in the "
                             "pseudo-label and checkpoint file names")
    # imb_rate is only ever interpolated into file names, so the raw
    # command-line token is kept as-is: converting it to float would turn
    # e.g. "1" into "1.0" and change which pseudo-label file is looked up.
    parser.add_argument("--imb_rate", default=0.05,
                        help="imbalance-rate tag used in the pseudo-label "
                             "and checkpoint file names")
    # Typed so that a value given on the command line has the same type as
    # the default and non-integer input is rejected up front.
    parser.add_argument("--im_class_num", type=int, default=3,
                        help="integer tag used in the pseudo-label and "
                             "checkpoint file names")
    args = parser.parse_args()

    main(args)
