import dgl
import torch
import torch.nn.functional as F
import os
from data import load_dataset, preprocess, load_datasets_nc
from eval_utils import Evaluator
from setup_utils import set_seed

def main(args):
    """Sample graphs from a trained GraphMaker checkpoint and evaluate them.

    Loads the checkpoint at ``args.model_path`` and reloads the real dataset
    it was trained on. It then rebuilds ``ModelSync`` from the stored training
    config and weights, draws ``args.num_samples`` graphs and passes each one
    to ``Evaluator.add_sample``. It finishes by calling ``evaluator.summary()``.

    Args:
        args: Parsed command-line namespace with ``model_path`` (str, path to
            the checkpoint), ``num_samples`` (int, number of graphs to draw)
            and ``gpu`` (int, CUDA device index, used only when CUDA is
            available).
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    state_dict = torch.load(args.model_path, map_location='cpu')
    dataset = state_dict["dataset"]

    train_yaml_data = state_dict["train_yaml_data"]
    model_name = train_yaml_data["meta_data"]["variant"]

    print(f"Loaded GraphMaker-{model_name} model trained on {dataset}")
    print(f"Val Nll {state_dict['best_val_nll']}")

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu')

    # These four datasets are loaded with load_dataset; any other name goes
    # through load_datasets_nc.
    if dataset in ['cora', 'citeseer', 'amazon_photo', 'amazon_computer']:
        g_real = load_dataset(dataset)
    else:
        g_real = load_datasets_nc(dataset)
    # Several of these outputs (E_one_hot_real, X_cond_s_marginals,
    # X_cond_y_marginals) are not used below; they are unpacked only to match
    # preprocess()'s return arity.
    (X_one_hot_3d_real, s_real, y_real, E_one_hot_real,
     X_marginal, s_marginal, y_marginal, E_marginal,
     X_cond_s_marginals, X_cond_y_marginals, y_cond_s_marginals,
     p_values) = preprocess(g_real)

    s_one_hot_real = F.one_hot(s_real)
    # Label-free datasets yield y_real=None; the evaluator then gets no labels.
    Y_one_hot_3d_real = F.one_hot(y_real) if y_real is not None else None
    evaluator = Evaluator(dataset,
                          os.path.dirname(args.model_path),
                          g_real,
                          X_one_hot_3d_real,
                          s_one_hot_real,
                          Y_one_hot_3d_real)

    # Move the marginals consumed by the model onto the sampling device. The
    # label marginals are only moved when the dataset actually has labels.
    if y_real is not None:
        y_marginal = y_marginal.to(device)
        y_cond_s_marginals = y_cond_s_marginals.to(device)
    X_marginal = X_marginal.to(device)
    s_marginal = s_marginal.to(device)
    E_marginal = E_marginal.to(device)
    num_nodes = s_real.size(0)

    from Model import ModelSync

    model = ModelSync(X_marginal=X_marginal,
                      s_marginal=s_marginal,
                      y_marginal=y_marginal,
                      E_marginal=E_marginal,
                      num_nodes=num_nodes,
                      p_values=p_values,
                      y_cond_s_marginal=y_cond_s_marginals,
                      gnn_X_config=train_yaml_data["gnn_X"],
                      gnn_E_config=train_yaml_data["gnn_E"],
                      **train_yaml_data["diffusion"]).to(device)

    # load_state_dict copies into the existing (already on-device) parameters,
    # so no second .to(device) is needed afterwards.
    model.graph_encoder.pred_X.load_state_dict(state_dict["pred_X_state_dict"])
    model.graph_encoder.pred_E.load_state_dict(state_dict["pred_E_state_dict"])
    model.eval()

    # Set seed for better reproducibility.
    set_seed()

    for _ in range(args.num_samples):
        X_0_one_hot, s_0_one_hot, y_0_one_hot, E_0 = model.sample(is_diff_X=True)
        # Each nonzero entry (i, j) of the 2-D E_0 becomes an edge i -> j.
        src, dst = E_0.nonzero().T
        g_sample = dgl.graph((src, dst), num_nodes=num_nodes).cpu()

        evaluator.add_sample(g_sample,
                             X_0_one_hot.cpu(),
                             s_0_one_hot.cpu(),
                             y_0_one_hot.cpu() if y_0_one_hot is not None else None)

    evaluator.summary()

if __name__ == '__main__':
    from argparse import ArgumentParser

    parser = ArgumentParser()
    # main() passes this straight to torch.load; without it torch.load(None)
    # fails with an unhelpful error, so argparse rejects a missing value upfront.
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the model.")
    parser.add_argument("--num_samples", type=int, default=10,
                        help="Number of samples to generate.")
    parser.add_argument("--gpu", type=int, default=0, choices=[0, 1, 2, 3],
                        help="CUDA device index (ignored when CUDA is unavailable).")
    args = parser.parse_args()

    main(args)
