import dgl
import torch
import torch.nn.functional as F

from data import load_dataset, preprocess
from eval_utils import Evaluator
from setup_utils import set_seed
import os
from nets import *

# CUDA debugging aids, set once at import time before main() does any GPU work.
# CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous so that errors are
# reported at the call that caused them (at the cost of slower execution).
# NOTE(review): PyTorch refers to TORCH_USE_CUDA_DSA as a build-time option for
# device-side assertions; setting it at runtime may have no effect — confirm.
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': "1",
    "TORCH_USE_CUDA_DSA": "1",
})

def _select_device(preferred_index=1):
    """Return the torch device used for sampling.

    The script originally hard-coded ``cuda:1``, which fails with an invalid
    device ordinal on single-GPU hosts. ``cuda:<preferred_index>`` is still
    used when that GPU exists; otherwise fall back to ``cuda:0``, and to the
    CPU when CUDA is unavailable.

    Args:
        preferred_index: CUDA device ordinal to try first.

    Returns:
        A ``torch.device``.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    if 0 <= preferred_index < torch.cuda.device_count():
        return torch.device(f'cuda:{preferred_index}')
    return torch.device('cuda:0')


def _sample_graphs(model, Y, num_samples, E_one_hot_real):
    """Draw ``num_samples`` graphs from ``model`` conditioned on ``Y``.

    Args:
        model: the ``ModelAsync`` instance to sample from.
        Y: conditioning labels, passed unchanged to ``model.sample``.
        num_samples: number of graphs to draw.
        E_one_hot_real: real-graph edge tensor; only used for the per-sample
            edge-count diagnostic print.

    Returns:
        Three parallel lists of CPU tensors, one entry per sample:
        ``X_0_one_hot[:, :, 1].T``; a ``(2, num_edges)`` edge index built
        from the nonzero entries of ``E_0``; and ``Y_0_one_hot.argmax(1)``.
    """
    features, edge_indices, labels = [], [], []
    for _ in range(num_samples):
        # NOTE(review): deliberately not wrapped in torch.no_grad(); sampling
        # may rely on gradients (e.g. classifier guidance) — confirm before
        # adding it.
        X_0_one_hot, Y_0_one_hot, E_0 = model.sample(Y)
        # Diagnostic: nonzero counts of real edges vs. generated edges.
        print(E_one_hot_real[:, :, 1].nonzero().size(), E_0.nonzero().size())

        src, dst = E_0.nonzero().T
        features.append(X_0_one_hot[:, :, 1].T.cpu())
        edge_indices.append(torch.stack([src, dst], dim=0).cpu())
        # Convert the sampled label rows to integer class indices.
        labels.append(Y_0_one_hot.argmax(1).cpu())
    return features, edge_indices, labels


def main(args):
    """Sample synthetic graphs from a trained GraphMaker-Async checkpoint.

    Loads the diffusion checkpoint and the pickled pseudo-labels selected by
    ``args``, draws ``args.num_samples`` graphs, and pickles
    ``[features_list, edge_index_list, labels_list]`` under
    ``./Generated_graphs/``.

    Args:
        args: namespace with ``dataset``, ``imb_rate``, ``im_class_num``,
            ``num_samples`` and ``pretrain``. An optional integer ``gpu``
            attribute selects the preferred CUDA device (default 1).
            ``args.model_path`` is set as a side effect.
    """
    args.model_path = (
        f'./{args.dataset}_cpts/{args.pretrain}_Async_{args.imb_rate}_'
        f'{args.im_class_num}_{args.pretrain}_pretrain_classificationloss.pth'
    )
    # Load onto the CPU so a checkpoint saved on any GPU loads on any host; the
    # weights are copied onto the model's device by load_state_dict below.
    state_dict = torch.load(args.model_path, map_location='cpu')
    dataset = state_dict["dataset"]

    train_yaml_data = state_dict["train_yaml_data"]
    model_name = train_yaml_data["meta_data"]["variant"]

    print(f"Loaded GraphMaker-{model_name} model trained on {dataset}")
    print(f"Val Nll {state_dict['best_val_nll']}")

    device = _select_device(getattr(args, 'gpu', 1))

    g_real = load_dataset(dataset)
    X_one_hot_3d_real, Y_real, E_one_hot_real, \
        X_marginal, Y_marginal, E_marginal, X_cond_Y_marginals = preprocess(g_real)

    # The GraphENS GAT that used to be built and loaded here was never used
    # (ModelAsync below receives pretrained_model=None), so it is not loaded.

    import pickle
    pseudo_label_path = (
        f'./pretrain/{args.dataset}_{args.imb_rate}_{args.im_class_num}_GraphENS.txt'
    )
    # Pickled (pred_Y, data_train_mask) pair; pred_Y conditions the sampler.
    # NOTE: pickle can execute arbitrary code on load — only use trusted files.
    with open(pseudo_label_path, 'rb') as f:
        Y, data_train_mask = pickle.load(f)

    print(X_marginal.device, device)
    X_marginal = X_marginal.to(device)
    Y_marginal = Y_marginal.to(device)
    E_marginal = E_marginal.to(device)
    X_cond_Y_marginals = X_cond_Y_marginals.to(device)
    num_nodes = Y_real.size(0)

    from model import ModelAsync
    model = ModelAsync(X_marginal=X_marginal,
                       Y_marginal=Y_marginal,
                       E_marginal=E_marginal,
                       X_cond_Y_marginals=X_cond_Y_marginals,
                       data_train_mask=data_train_mask,
                       mlp_X_config=train_yaml_data["mlp_X"],
                       gnn_E_config=train_yaml_data["gnn_E"],
                       num_nodes=num_nodes,
                       pretrained_model=None,
                       **train_yaml_data["diffusion"]).to(device)

    model.graph_encoder.pred_X.load_state_dict(state_dict["pred_X_state_dict"])
    model.graph_encoder.pred_E.load_state_dict(state_dict["pred_E_state_dict"])
    model.eval()

    # Optional evaluation with eval_utils.Evaluator was disabled in this
    # script; re-enable it here if sample quality metrics are needed.

    output_path = (
        f'./Generated_graphs/{args.pretrain}_{args.dataset}_{args.imb_rate}_'
        f'{args.im_class_num}_uniform_ftheta.txt'
    )
    # Create the output folder before sampling so a missing directory cannot
    # discard all samples after the (expensive) generation loop.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Set seed for better reproducibility.
    set_seed()
    generated = _sample_graphs(model, Y, args.num_samples, E_one_hot_real)

    with open(output_path, "wb") as f:
        pickle.dump(list(generated), f)

if __name__ == '__main__':
    from argparse import ArgumentParser

    parser = ArgumentParser(
        description="Sample graphs from a trained GraphMaker-Async checkpoint.")
    parser.add_argument("-d", "--dataset", type=str, default='cora',
                        help="Dataset name; selects checkpoint, pseudo-label "
                             "and output file paths.")
    # imb_rate / im_class_num are intentionally left untyped: they are only
    # stringified into file names, and e.g. type=float would turn "1" into
    # "1.0", breaking paths of existing checkpoints.
    parser.add_argument("--imb_rate", default=0.05,
                        help="imb_rate tag embedded in checkpoint, "
                             "pseudo-label and output file names.")
    parser.add_argument("--im_class_num", default=3,
                        help="im_class_num tag embedded in checkpoint, "
                             "pseudo-label and output file names.")
    parser.add_argument("--num_samples", type=int, default=5, help="Number of samples to generate.")
    parser.add_argument("--pretrain", default='GraphENS',
                        help="Pretraining method name embedded in checkpoint "
                             "and output file names.")
    args = parser.parse_args()
    main(args)
