import os
import time
import pickle
import math
import torch

from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log
from utils.loader import load_ckpt, load_data, load_seed, load_device, load_model_from_ckpt, \
                         load_ema_from_ckpt, load_sampling_fn, load_eval_settings
from utils.graph_utils import adjs_to_graphs, init_flags, quantize, quantize_mol
from utils.plot import save_graph_list, plot_graphs_list
from evaluation.stats import eval_graph_list
from utils.mol_utils import gen_mol, mols_to_smiles, load_smiles, canonicalize_smiles, mols_to_nx
import projop
# from projop.utils import satisfies
from projop.project_bisection import satisfies
import networkx as nx

# -------- Sampler for generic graph generation tasks --------
class Sampler(object):
    """Constrained sampler for generic (non-molecular) graph generation.

    On construction it loads a trained checkpoint (a node-feature model and an
    adjacency model), optionally copies in their EMA weights, and builds the
    constrained sampling function. :meth:`sample` then generates as many graphs
    as the test split holds, rounds them, reports constraint validity, saves and
    plots them, and logs MMD statistics against the test graphs.
    """

    def __init__(self, config, constr_config, device=None):
        """Load checkpoint, data, models and the sampling function.

        Args:
            config: sampling config (``config.sample.use_ema`` selects EMA weights).
            constr_config: constraint / projection config, forwarded to the
                sampling function and used to build the log name.
            device: torch device; defaults to ``load_device()``.
        """
        super().__init__()

        self.config = config
        self.constr_config = constr_config
        self.device = load_device() if device is None else device

        # -------- Load checkpoint --------
        self.ckpt_dict = load_ckpt(self.config, self.device)
        # Training-time config stored alongside the weights.
        self.configt = self.ckpt_dict['config']

        load_seed(self.configt.seed)
        self.train_graph_list, self.test_graph_list = load_data(self.configt, get_graph_list=True)

        # -------- Load models --------
        self.model_x = load_model_from_ckpt(self.ckpt_dict['params_x'], self.ckpt_dict['x_state_dict'], self.device)
        self.model_adj = load_model_from_ckpt(self.ckpt_dict['params_adj'], self.ckpt_dict['adj_state_dict'], self.device)

        if self.config.sample.use_ema:
            self.ema_x = load_ema_from_ckpt(self.model_x, self.ckpt_dict['ema_x'], self.configt.train.ema)
            self.ema_adj = load_ema_from_ckpt(self.model_adj, self.ckpt_dict['ema_adj'], self.configt.train.ema)

            # Copy the EMA weights into the models' parameters before sampling.
            self.ema_x.copy_to(self.model_x.parameters())
            self.ema_adj.copy_to(self.model_adj.parameters())

        self.sampling_fn = load_sampling_fn(self.configt, self.config, self.constr_config, self.device)

    def _build_log_name(self):
        """Return the run identifier used for the log / sample file names.

        Encodes the checkpoint, projection method settings, burn-in, extra
        diffusion steps, constraint params, rounding scheme and seed; every '.'
        is replaced by 'p' so the name is filesystem-friendly.
        """
        cc = self.constr_config
        method = cc.method
        if method.op == 'cond_guide':
            parts = [f"{self.config.ckpt}-{method.op}",
                     f"{method.guidance_scale_x}-{method.guidance_scale_adj}"]
        else:
            parts = [f"{self.config.ckpt}-{method.op}-{method.gamma}",
                     "bisect",
                     f"{cc.schedule.gamma}{','.join(map(str, cc.schedule.params))}"]
        parts += [f"{cc.burnin}",
                  f"{cc.add_diff_step}",
                  f"{','.join(map(str, cc.params))}-{cc.rounding}",
                  f"{self.config.seed}"]
        return "-".join(parts).replace(".", "p")

    def _round_samples(self, x, adj):
        """Round one sampled batch to integer adjacencies.

        Uses the projop 'randomized' / 'repeated' rounding when a constraint is
        active and one of those schemes is selected; otherwise plain ``quantize``.
        """
        cc = self.constr_config
        if cc.constraint != 'None':
            if cc.rounding == 'randomized':
                _, samples_int = projop.rounding.random_round(x, adj, cc, adj_vals=self.config.data.adj_vals,
                                                              feat_vals=self.config.data.feat_vals)
                return samples_int
            if cc.rounding == 'repeated':
                _, samples_int = projop.rounding.repeated_round(x, adj, cc, adj_vals=self.config.data.adj_vals,
                                                                feat_vals=self.config.data.feat_vals)
                return samples_int
        return quantize(adj)

    def sample(self):
        """Generate, round, save, plot and evaluate graphs, logging the results.

        Draws ``ceil(len(test_graph_list) / batch_size)`` batches, keeps the
        first ``len(test_graph_list)`` graphs, reports the fraction satisfying
        the constraint, and logs ``MMD_full`` from ``eval_graph_list``.

        Raises:
            ValueError: if the test graph list is empty (nothing to size the
                generated set against).
        """
        num_test = len(self.test_graph_list)
        if num_test == 0:
            raise ValueError("Sampler.sample: test_graph_list is empty; nothing to generate against")

        self.log_folder_name, self.log_dir, _ = set_log(
            self.configt,
            constraint=self.constr_config.constraint + ("-eq" if self.constr_config.eq else ""),
            is_train=False)
        self.log_name = self._build_log_name()
        logger = Logger(str(os.path.join(self.log_dir, f'{self.log_name}.log')), mode='a')

        # Write the run header only when check_log reports it is not already there.
        if not check_log(self.log_folder_name, self.log_name):
            logger.log(f'{self.log_name}')
            start_log(logger, self.configt)
            train_log(logger, self.configt)

        sample_log(logger, self.config)

        # -------- Generate samples --------
        logger.log(f'GEN SEED: {self.config.sample.seed}')
        load_seed(self.config.sample.seed)

        num_sampling_rounds = math.ceil(num_test / self.configt.data.batch_size)
        gen_graph_list = []
        gen_x_list = []
        for r in range(num_sampling_rounds):
            t_start = time.time()

            self.init_flags = init_flags(self.train_graph_list, self.configt).to(self.device)

            x, adj, _ = self.sampling_fn(self.model_x, self.model_adj, self.init_flags)
            samples_int = self._round_samples(x, adj)

            logger.log(f"Round {r} : {time.time()-t_start:.2f}s")
            gen_graph_list.extend(adjs_to_graphs(samples_int, True))
            # Keep every round's features so they can be paired with the graphs below.
            gen_x_list.append(x)

        gen_graph_list = gen_graph_list[:num_test]

        # Dense, zero-padded adjacency tensor (CPU) for the kept graphs.
        max_n = self.configt.data.max_node_num
        graphs_samples = torch.zeros(len(gen_graph_list), max_n, max_n)
        for i, G in enumerate(gen_graph_list):
            nG = G.number_of_nodes()
            graphs_samples[i, :nG, :nG] = torch.from_numpy(nx.to_numpy_array(G))

        # Features from all rounds, truncated like gen_graph_list, so both
        # arguments to satisfies describe the same batch (the last round's x
        # alone has a different batch size than graphs_samples).
        # NOTE(review): assumes adjs_to_graphs yields one graph per batch row in
        # order; if it drops/relabels nodes, per-node rows of gen_x may not line
        # up with graphs_samples — confirm what satisfies expects.
        gen_x = torch.cat(gen_x_list, dim=0)[:num_test]
        constr_val = satisfies(gen_x, graphs_samples, self.constr_config).sum().item()/len(graphs_samples)

        print(f'Constraint Validity: {constr_val}')
        logger.log(f'Constraint Validity: {constr_val}')

        # -------- Save samples --------
        save_dir = save_graph_list(self.log_folder_name, self.log_name, gen_graph_list)
        with open(save_dir, 'rb') as f:
            sample_graph_list = pickle.load(f)
        plot_graphs_list(graphs=sample_graph_list, title=self.log_name, max_num=16, save_dir=self.log_folder_name)

        # -------- Evaluation --------
        methods, kernels = load_eval_settings(self.config.data.data)
        result_dict = eval_graph_list(self.test_graph_list, gen_graph_list, methods=methods, kernels=kernels)
        logger.log(f'MMD_full {result_dict}', verbose=False)
        logger.log('='*100)



# -------- Sampler for molecule generation tasks --------
class Sampler_mol(object):
    """Constrained sampler for molecule generation.

    On construction it loads a trained checkpoint and the sampling function.
    :meth:`sample` draws one large batch, rounds it to discrete node features
    and bond orders, reports constraint validity before and after rounding,
    converts the result to molecules, and writes their SMILES next to the log.
    """

    def __init__(self, config, constr_config, device=None):
        """Load checkpoint, models, sampling function and graph lists.

        Args:
            config: sampling config.
            constr_config: constraint / projection config, forwarded to the
                sampling function and used to build the log name.
            device: torch device; defaults to ``load_device()``.
        """
        self.config = config
        self.constr_config = constr_config
        self.device = load_device() if device is None else device

        # -------- Load checkpoint --------
        self.ckpt_dict = load_ckpt(self.config, self.device)
        # Training-time config stored alongside the weights.
        self.configt = self.ckpt_dict['config']

        load_seed(self.config.seed)

        # -------- Load models --------
        # NOTE(review): unlike Sampler, config.sample.use_ema is not honored
        # here (no EMA weights are copied in) — confirm this is intended.
        self.model_x = load_model_from_ckpt(self.ckpt_dict['params_x'], self.ckpt_dict['x_state_dict'], self.device)
        self.model_adj = load_model_from_ckpt(self.ckpt_dict['params_adj'], self.ckpt_dict['adj_state_dict'], self.device)

        self.sampling_fn = load_sampling_fn(self.configt, self.config, self.constr_config, self.device)

        self.train_graph_list, _ = load_data(self.configt, get_graph_list=True)     # for init_flags
        with open(f'data/{self.configt.data.data.lower()}_test_nx.pkl', 'rb') as f:
            self.test_graph_list = pickle.load(f)                                   # for NSPDK MMD

    def _build_log_name(self):
        """Return the run identifier used for the log / SMILES file names.

        Encodes the checkpoint, projection method settings, burn-in, extra
        diffusion steps, constraint params, rounding scheme and seed; every '.'
        is replaced by 'p' so the name is filesystem-friendly.
        """
        cc = self.constr_config
        method = cc.method
        if method.op == 'cond_guide':
            parts = [f"{self.config.ckpt}-{method.op}",
                     f"{method.guidance_scale_x}-{method.guidance_scale_adj}"]
        else:
            parts = [f"{self.config.ckpt}-{method.op}-{method.gamma}",
                     "bisect",
                     f"{cc.schedule.gamma}{','.join(map(str, cc.schedule.params))}"]
        parts += [f"{cc.burnin}",
                  f"{cc.add_diff_step}",
                  f"{','.join(map(str, cc.params))}-{cc.rounding}",
                  f"{self.config.seed}"]
        return "-".join(parts).replace(".", "p")

    def _round_samples(self, x, adj):
        """Round sampled features and adjacency; return ``(x, samples_int)``.

        With an active constraint and ``rounding == 'heurval'``, both come from
        ``heur_val_round``. Every other path first binarizes ``x`` at 0.5, then
        rounds the adjacency with projop 'randomized' / 'repeated' rounding or
        plain ``quantize_mol``.
        """
        cc = self.constr_config
        if cc.constraint != 'None' and cc.rounding == 'heurval':
            # heur_val_round returns its own rounded features, so x is not thresholded here.
            x_round, samples_int = projop.rounding.heur_val_round(x, adj, cc, adj_vals=self.config.data.adj_vals)
            return x_round, samples_int

        # torch.where with a boolean mask already yields a tensor on x's device.
        x = torch.where(x > 0.5, 1., 0.)
        if cc.constraint == 'None':
            samples_int = quantize_mol(adj)
        elif cc.rounding == 'randomized':
            _, samples_int = projop.rounding.random_round(x, adj, cc, adj_vals=self.config.data.adj_vals,
                                                          feat_vals=self.config.data.feat_vals)
        elif cc.rounding == 'repeated':
            _, samples_int = projop.rounding.repeated_round(x, adj, cc, adj_vals=self.config.data.adj_vals,
                                                            feat_vals=self.config.data.feat_vals)
        else:
            samples_int = quantize_mol(adj)
        return x, samples_int

    def sample(self, num_samples=10000):
        """Generate molecules, log constraint validity, and save their SMILES.

        Args:
            num_samples: number of molecules to draw, forwarded to
                ``init_flags`` (default 10000, the previously fixed value).
        """
        self.log_folder_name, self.log_dir, _ = set_log(
            self.configt,
            constraint=self.constr_config.constraint + ("-eq" if self.constr_config.eq else ""),
            is_train=False)
        self.log_name = self._build_log_name()
        logger = Logger(str(os.path.join(self.log_dir, f'{self.log_name}.log')), mode='a')

        logger.log(f'GEN SEED: {self.config.sample.seed}')
        load_seed(self.config.sample.seed)

        # Write the run header only when check_log reports it is not already there.
        if not check_log(self.log_folder_name, self.log_name):
            start_log(logger, self.configt)
            train_log(logger, self.configt)
        sample_log(logger, self.config)

        self.init_flags = init_flags(self.train_graph_list, self.configt, num_samples).to(self.device)
        x, adj, _ = self.sampling_fn(self.model_x, self.model_adj, self.init_flags)

        constr_val = satisfies(x, adj, self.constr_config).sum().item()/len(adj)
        logger.log(f'Constraint Validity before round: {constr_val}')

        x, samples_int = self._round_samples(x, adj)

        constr_val = satisfies(x, samples_int, self.constr_config).sum().item()/len(samples_int)
        print(f'Constraint Validity: {constr_val}')
        logger.log(f'Constraint Validity: {constr_val}')

        samples_int = samples_int - 1
        samples_int[samples_int == -1] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2

        # as_tensor accepts either a numpy array or a tensor without a copy warning;
        # one_hot requires int64 indices, whatever dtype the rounding returned.
        adj = torch.nn.functional.one_hot(torch.as_tensor(samples_int).long(), num_classes=4).permute(0, 3, 1, 2)
        # Append a channel equal to 1 - sum of the feature channels
        # (original note: 32, 9, 4 -> 32, 9, 5).
        x = torch.cat([x, 1 - x.sum(dim=-1, keepdim=True)], dim=-1)

        # -------- Evaluation --------
        # num_mols_wo_correction is only used by the commented-out metrics below.
        gen_mols, num_mols_wo_correction = gen_mol(x, adj, self.configt.data.data, to_correct=False, largest_connected_comp=False)
        num_mols = len(gen_mols)
        print(num_mols)

        gen_smiles = mols_to_smiles(gen_mols)
        gen_smiles = [smi for smi in gen_smiles if len(smi)]

        # -------- Save generated molecules --------
        with open(os.path.join(self.log_dir, f'{self.log_name}.txt'), 'w') as f:
            f.writelines(f'{smiles}\n' for smiles in gen_smiles)

        # -------- Evaluation (currently disabled) --------
        # train_smiles, test_smiles = load_smiles(self.configt.data.data)
        # train_smiles, test_smiles = canonicalize_smiles(train_smiles), canonicalize_smiles(test_smiles)

        # from moses.metrics.metrics import get_all_metrics
        # scores = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=self.device, n_jobs=8, test=test_smiles, train=train_smiles)
        # print ([f'{metric}: {scores[metric]}' for metric in ['valid', f'unique@{len(gen_smiles)}', 'FCD/Test', 'Novelty']])
        # scores_nspdk = eval_graph_list(self.test_graph_list, mols_to_nx(gen_mols), methods=['nspdk'])['nspdk']

        # logger.log(f'Number of molecules: {num_mols}')
        # logger.log(f'validity w/o correction: {num_mols_wo_correction / num_mols}')
        # for metric in ['valid', f'unique@{len(gen_smiles)}', 'FCD/Test', 'Novelty']:
        #     logger.log(f'{metric}: {scores[metric]}')
        # logger.log(f'NSPDK MMD: {scores_nspdk}')
        # logger.log('='*100)

        # -------- Same metrics with gen_mol's default arguments (disabled) --------
        # print ("========= Original =========")
        # gen_mols, num_mols_wo_correction = gen_mol(x, adj, self.configt.data.data)
        # num_mols = len(gen_mols)
        # print (num_mols)
        # gen_smiles = mols_to_smiles(gen_mols)
        # gen_smiles = [smi for smi in gen_smiles if len(smi)]
        # ... then repeat the load_smiles / get_all_metrics / NSPDK logging above.
        # print ("==================")

