import dgl
import torch
import torch.nn.functional as F
import os
from data import load_dataset, preprocess
from eval_utils import Evaluator
from setup_utils import set_seed

def main(args):
    """Sample graphs from a trained GraphMaker checkpoint and evaluate them.

    The checkpoint at ``args.model_path`` supplies the dataset name, the
    training YAML configuration and the ``pred_X`` / ``pred_E`` weights.
    For each of five evaluator seeds the model is rebuilt, one graph is
    sampled, handed to a fresh ``Evaluator`` and the evaluator summary is
    printed.

    Args:
        args: Parsed command-line namespace. Only ``args.model_path`` is
            read by this function.
    """
    # Kept function-local, as before, but imported once instead of per seed.
    from Model import ModelSync

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # map_location lets a checkpoint that was saved from a GPU run be loaded
    # on a CPU-only machine, matching the CPU fallback chosen above.
    state_dict = torch.load(args.model_path, map_location=device)
    dataset = state_dict["dataset"]

    train_yaml_data = state_dict["train_yaml_data"]
    model_name = train_yaml_data["meta_data"]["variant"]

    print(f"Loaded GraphMaker-{model_name} model trained on {dataset}")
    print(f"Val Nll {state_dict['best_val_nll']}")

    g_real = load_dataset(dataset)
    # E_one_hot_real and X_cond_s_marginals are returned by preprocess but
    # are not needed by anything below.
    (X_one_hot_3d_real, s_real, E_one_hot_real,
     X_marginal, s_marginal, E_marginal, X_cond_s_marginals) = preprocess(g_real)
    s_one_hot_real = F.one_hot(s_real)

    # Loop-invariant values: computed once rather than on every seed.
    X_marginal = X_marginal.to(device)
    s_marginal = s_marginal.to(device)
    E_marginal = E_marginal.to(device)
    num_nodes = s_real.size(0)
    model_dir = os.path.dirname(args.model_path)

    for i in range(5):
        evaluator = Evaluator(dataset,
                              model_dir,
                              g_real,
                              X_one_hot_3d_real,
                              s_one_hot_real,
                              seed=i)

        # The model is rebuilt and its weights reloaded for every seed.
        model = ModelSync(X_marginal=X_marginal,
                          s_marginal=s_marginal,
                          E_marginal=E_marginal,
                          gnn_X_config=train_yaml_data["gnn_X"],
                          gnn_E_config=train_yaml_data["gnn_E"],
                          num_nodes=num_nodes,
                          **train_yaml_data["diffusion"]).to(device)

        model.graph_encoder.pred_X.load_state_dict(state_dict["pred_X_state_dict"])
        model.graph_encoder.pred_E.load_state_dict(state_dict["pred_E_state_dict"])

        model.eval()

        # Set seed for better reproducibility.
        # NOTE(review): set_seed() is called without an argument on every
        # iteration, so only the Evaluator receives the per-iteration seed
        # ``i``; presumably sampling restarts from the same default seed each
        # time -- confirm this is intended.
        set_seed()

        # NOTE(review): exactly one sample is drawn per seed;
        # args.num_samples is not consulted here -- confirm this is intended.
        for _ in range(1):
            # Inference only, so no autograd graph needs to be recorded.
            with torch.no_grad():
                X_0_one_hot, s_0_one_hot, E_0 = model.sample(is_diff_X=True)
            # Non-zero entries of E_0 are used as (src, dst) edge pairs.
            src, dst = E_0.nonzero().T
            g_sample = dgl.graph((src, dst), num_nodes=num_nodes).cpu()

            evaluator.add_sample(g_sample,
                                 X_0_one_hot.cpu(),
                                 s_0_one_hot.cpu(),
                                 Y_one_hot=None)

        print(f'For {i}th seed: ')
        evaluator.summary()
        

if __name__ == '__main__':
    # Command-line entry point: parse the arguments and run main().
    from argparse import ArgumentParser

    parser = ArgumentParser()
    # Required: main() passes model_path straight to torch.load, so a missing
    # value should be reported by argparse rather than fail later as None.
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the model.")
    # NOTE(review): main() does not currently read num_samples; it draws a
    # fixed single sample per seed -- confirm whether this flag should be used.
    parser.add_argument("--num_samples", type=int, default=10,
                        help="Number of samples to generate.")
    args = parser.parse_args()

    main(args)
