import os
import random
import time
import pickle
import math
import torch

from evaluation.stats import eval_graph_list
from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log
from utils.loader import load_ckpt, load_data, load_seed, load_device, load_model_from_ckpt, \
    load_ema_from_ckpt, load_eval_settings, load_sampling_fn
from utils.graph_utils import adjs_to_graphs, quantize, init_flags, mask_adjs, quantize_mol
from utils.plot import save_graph_list, plot_graphs_list
from utils.mol_utils import gen_mol, mols_to_smiles, load_smiles, canonicalize_smiles, mols_to_nx
from moses.metrics.metrics import get_all_metrics
import numpy as np
# import networkx as nx


# -------- Sampler for generic graph generation tasks --------
class Sampler(object):
    """Generate graphs from a trained checkpoint and evaluate them against the test set.

    Args:
        config: Sampling configuration. This class reads ``ckpt``, ``type``,
            ``sampler``, ``sample`` (``use_ema``, ``seed``) and ``data.data``
            from it; the training configuration is taken from the checkpoint.
        epochs: Only tested against ``None``. When it is not ``None``,
            :meth:`sample` returns the evaluation result immediately after
            scoring and skips saving/plotting the generated graphs.
    """

    def __init__(self, config, epochs=None):
        super().__init__()
        self.is_return = epochs is not None
        self.config = config
        # Indexable collection of devices; element 0 hosts the sampling tensors.
        self.device = load_device()

    def sample(self):
        """Run the full sampling pipeline: load, generate, evaluate, save.

        Side effects: sets ``ckpt_dict``, ``configt``, the graph lists, the
        log names, the models (and EMA objects when enabled), ``sampling_fn2``
        and ``init_flags`` on ``self``; appends to the sample log file.

        Returns:
            The dictionary produced by ``eval_graph_list`` when the sampler
            was constructed with ``epochs`` set, otherwise ``None`` (after the
            generated graphs have been saved and plotted).
        """
        # -------- Load checkpoint --------
        self.ckpt_dict = load_ckpt(self.config, self.device)
        self.configt = self.ckpt_dict['config']

        # A checkpoint config without a "type" entry inherits it from the sampling config.
        if "type" not in self.configt:
            print("self.configt has not type setting")
            self.configt.type = self.config.type

        load_seed(self.configt.seed)
        self.train_graph_list, self.test_graph_list = load_data(self.configt, get_graph_list=True)

        self.log_folder_name, self.log_dir, _ = set_log(self.configt, is_train=False)
        self.log_name = f"{self.config.ckpt}-sample"
        logger = Logger(str(os.path.join(self.log_dir, f'{self.log_name}.log')), mode='a')

        # Write the header (name + training settings) only when check_log reports
        # that this log does not exist yet.
        if not check_log(self.log_folder_name, self.log_name):
            logger.log(f'{self.log_name}')
            start_log(logger, self.configt)
            train_log(logger, self.configt)

        sample_log(logger, self.config)

        # -------- Load models --------
        self.model_x = load_model_from_ckpt(
            self.ckpt_dict['params_x'], self.ckpt_dict['x_state_dict'], self.device)
        self.model_adj = load_model_from_ckpt(
            self.ckpt_dict['params_adj'], self.ckpt_dict['adj_state_dict'], self.device)

        # -------- Load EMA --------
        # When enabled, the exponential-moving-average weights stored in the
        # checkpoint overwrite the raw model parameters before sampling.
        if self.config.sample.use_ema:
            self.ema_x = load_ema_from_ckpt(self.model_x, self.ckpt_dict['ema_x'], self.configt.train.ema)
            self.ema_adj = load_ema_from_ckpt(self.model_adj, self.ckpt_dict['ema_adj'], self.configt.train.ema)

            self.ema_x.copy_to(self.model_x.parameters())
            self.ema_adj.copy_to(self.model_adj.parameters())

        # -------- Load sampling function --------
        self.sampling_fn2 = load_sampling_fn(self.configt, self.config.sampler, self.config.sample, self.device)

        # -------- Generate samples --------
        # Re-seed with the sampling seed so generation is reproducible
        # independently of the seed used for data loading above.
        logger.log(f'GEN SEED: {self.config.sample.seed}')
        load_seed(self.config.sample.seed)

        # Enough batches to produce at least one generated graph per test graph.
        num_sampling_rounds = math.ceil(len(self.test_graph_list) / self.configt.data.batch_size)
        gen_graph_list = []

        for r in range(num_sampling_rounds):
            t_start = time.time()

            # init_flags is called once per round, so each batch gets its own flags.
            self.init_flags, train_tensor = init_flags(self.train_graph_list, self.configt)
            self.init_flags = self.init_flags.to(self.device[0])
            train_tensor = train_tensor.to(self.device[0])

            # Only the adjacency output is used for generic graphs; 'g' selects
            # the generic-graph mode of the sampling function.
            _, adj, _ = self.sampling_fn2(self.model_x, self.model_adj, self.init_flags, train_tensor, 'g')
            adj = mask_adjs(adj, self.init_flags)

            logger.log(f"Round {r} : {time.time() - t_start:.2f}s")

            # Symmetrise: keep the strict upper triangle and mirror it, which
            # also leaves the diagonal at zero.
            upper = adj.triu(1)
            adj = upper + torch.transpose(upper, -1, -2)

            # NOTE(review): quantize presumably thresholds to {0, 1} — confirm in utils.graph_utils.
            samples_int = quantize(adj)
            gen_graph_list.extend(adjs_to_graphs(samples_int, True))

        # The last batch may overshoot; keep exactly as many graphs as the test set has.
        gen_graph_list = gen_graph_list[:len(self.test_graph_list)]

        # -------- Evaluation --------
        methods, kernels = load_eval_settings(self.config.data.data)
        result_dict = eval_graph_list(self.test_graph_list, gen_graph_list, methods=methods, kernels=kernels)
        logger.log(f'MMD_full {result_dict}', verbose=False)
        logger.log('=' * 100)

        if self.is_return:
            return result_dict

        # -------- Save samples --------
        save_dir = save_graph_list(self.log_folder_name, self.log_name, gen_graph_list)
        # Plot from the file that was just written rather than the in-memory
        # list. The pickle is produced by this process, so it is trusted input.
        with open(save_dir, 'rb') as f:
            sample_graph_list = pickle.load(f)
        plot_graphs_list(graphs=sample_graph_list, title=f'{self.config.ckpt}', max_num=16,
                         save_dir=self.log_folder_name)


# -------- Sampler for molecule generation tasks --------
class Sampler_mol(object):
    """Generate molecules from a trained checkpoint and report generation metrics.

    Args:
        config: Sampling configuration. This class reads ``ckpt``, ``seed``,
            ``sampler`` and ``sample`` (``seed``) from it; the training
            configuration is taken from the checkpoint.
    """

    def __init__(self, config):
        self.config = config
        # Indexable collection of devices; element 0 hosts the sampling tensors.
        self.device = load_device()

    def sample(self):
        """Run the molecule sampling pipeline: load, generate, save, evaluate.

        Generated SMILES are appended to ``<log_dir>/<ckpt>-sample.txt`` and
        the metrics are written to the sample log. If no non-empty SMILES
        string is produced, evaluation is skipped (it would divide by zero
        and request ``k=0`` metrics) and the situation is logged instead.

        Returns:
            None.
        """
        # -------- Load checkpoint --------
        self.ckpt_dict = load_ckpt(self.config, self.device)
        self.configt = self.ckpt_dict['config']

        load_seed(self.config.seed)

        self.log_folder_name, self.log_dir, _ = set_log(self.configt, is_train=False)
        self.log_name = f"{self.config.ckpt}-sample"
        logger = Logger(str(os.path.join(self.log_dir, f'{self.log_name}.log')), mode='a')

        # Write the training-settings header only when check_log reports that
        # this log does not exist yet.
        if not check_log(self.log_folder_name, self.log_name):
            start_log(logger, self.configt)
            train_log(logger, self.configt)
        sample_log(logger, self.config)

        # -------- Load models --------
        self.model_x = load_model_from_ckpt(
            self.ckpt_dict['params_x'], self.ckpt_dict['x_state_dict'], self.device)
        self.model_adj = load_model_from_ckpt(
            self.ckpt_dict['params_adj'], self.ckpt_dict['adj_state_dict'], self.device)

        self.sampling_fn2 = load_sampling_fn(self.configt, self.config.sampler, self.config.sample, self.device)

        # -------- Generate samples --------
        # Re-seed with the sampling seed so generation is reproducible.
        logger.log(f'GEN SEED: {self.config.sample.seed}')
        load_seed(self.config.sample.seed)

        train_smiles, test_smiles = load_smiles(self.configt.data.data)
        train_smiles, test_smiles = canonicalize_smiles(train_smiles), canonicalize_smiles(test_smiles)

        self.train_graph_list, _ = load_data(self.configt, get_graph_list=True)  # for init_flags
        # Pre-built networkx test graphs, used only for the NSPDK MMD score.
        # The pickle is a local project data file, i.e. trusted input.
        with open(f'data/{self.configt.data.data.lower()}_test_nx.pkl', 'rb') as f:
            self.test_graph_list = pickle.load(f)

        # NOTE(review): 10000 is presumably the number of molecules to draw — confirm in init_flags.
        self.init_flags, train_tensor = init_flags(self.train_graph_list, self.configt, 10000)
        self.init_flags = self.init_flags.to(self.device[0])
        train_tensor = train_tensor.to(self.device[0])

        # 'm' selects the molecule mode of the sampling function.
        x, adj, _ = self.sampling_fn2(self.model_x, self.model_adj, self.init_flags, train_tensor, 'm')

        samples_int = quantize_mol(adj)

        # Shift labels down by one and wrap the former 0 to 3:
        # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2
        samples_int = samples_int - 1
        samples_int[samples_int == -1] = 3
        # One-hot encode the 4 bond classes, then move the class axis to dim 1.
        adj = torch.nn.functional.one_hot(torch.tensor(samples_int), num_classes=4).permute(0, 3, 1, 2)

        # Binarise node features and append one extra channel that is 1 where
        # no existing channel fired (e.g. 32, 9, 4 -> 32, 9, 5).
        x = torch.where(x > 0.5, 1, 0)
        x = torch.concat([x, 1 - x.sum(dim=-1, keepdim=True)], dim=-1)

        gen_mols, num_mols_wo_correction = gen_mol(x, adj, self.configt.data.data)
        num_mols = len(gen_mols)

        # Drop molecules whose SMILES string came back empty.
        gen_smiles = [smi for smi in mols_to_smiles(gen_mols) if smi]

        # -------- Save generated molecules --------
        with open(os.path.join(self.log_dir, f'{self.log_name}.txt'), 'a') as f:
            f.writelines(f'{smiles}\n' for smiles in gen_smiles)

        # With nothing to score, the metric calls below would run with k=0 and
        # the validity ratio would divide by zero when num_mols is 0; report
        # the empty result instead of crashing.
        if not gen_smiles:
            logger.log(f'Number of molecules: {num_mols}')
            logger.log('No valid SMILES were generated; skipping evaluation.')
            logger.log('=' * 100)
            return

        # -------- Evaluation --------
        num_smiles = len(gen_smiles)
        scores = get_all_metrics(gen=gen_smiles, k=num_smiles, device=self.device[0], n_jobs=8,
                                 test=test_smiles, train=train_smiles)
        scores_nspdk = eval_graph_list(self.test_graph_list, mols_to_nx(gen_mols), methods=['nspdk'])['nspdk']

        logger.log(f'Number of molecules: {num_mols}')
        logger.log(f'validity w/o correction: {num_mols_wo_correction / num_mols}')
        for metric in ['valid', f'unique@{num_smiles}', 'FCD/Test', 'Novelty']:
            logger.log(f'{metric}: {scores[metric]}')
        logger.log(f'NSPDK MMD: {scores_nspdk}')
        logger.log('=' * 100)