import os
import time
import torch
import pickle
import pdb

from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log
from utils.loader import load_model_from_ckpt, load_data, load_sampling_fn, load_sample, load_seed, load_device, \
                         load_ckpt, load_sampling_fn_conditional
from utils.graph_utils import init_flags
from utils.mol_utils import gen_mol, mols_to_smiles, load_smiles, canonicalize_smiles, mols_to_nx, x_adj_to_nx
from moses.metrics.metrics import get_all_metrics
from evaluation.stats import eval_graph_list

from utils.classifier_utils import load_classifier_from_ckpt, load_classifier_ckpt
from utils.classifier_utils import train_log as train_logc
from utils.classifier_utils import sample_log as sample_logc

from get_pds import get_pds, get_pds_no_ds


class Sampler_mol(object):
    """Unconditional molecule sampler: samples from a score checkpoint and logs metrics.

    ``sample()`` loads the ``'S'`` checkpoint, draws samples with the configured
    sampling function, converts them to molecules, appends the generated SMILES
    to a ``.txt`` file in the log directory and logs MOSES metrics plus NSPDK MMD
    against the pickled test graph list.
    """

    # ``configs.data.max_feat_num`` values for which the bond remapping in
    # ``_to_bond_classes`` is defined; any other value aborts with ValueError.
    _SUPPORTED_MAX_FEAT_NUM = (4, 9)

    def __init__(self, config):
        self.config = config
        self.device = load_device(self.config.gpu)

    @classmethod
    def _to_bond_classes(cls, samples_int, max_feat_num):
        """Remap quantized adjacency values to class indices for one-hot encoding.

        Entries 0, 1, 2, 3 (no, S, D, T) become 3, 0, 1, 2. A new array/tensor is
        returned; ``samples_int`` itself is not modified.

        Raises:
            ValueError: if ``max_feat_num`` is not in ``_SUPPORTED_MAX_FEAT_NUM``.
        """
        if max_feat_num not in cls._SUPPORTED_MAX_FEAT_NUM:
            raise ValueError(
                f'Unsupported max_feat_num={max_feat_num}; '
                f'expected one of {cls._SUPPORTED_MAX_FEAT_NUM}')
        bond_classes = samples_int - 1
        bond_classes[bond_classes == -1] = 3
        return bond_classes

    def _result_filename(self):
        """Return the basename (no extension) for the generated-SMILES output file."""
        s_cfg = self.config.module.S
        if s_cfg.corrector == 'None':
            return f'{self.log_name}'
        return (f'{self.log_name}_snr={int(s_cfg.snr*100):03d}'
                f'_seps={int(s_cfg.scale_eps*100):03d}')

    def _load_test_graphs(self):
        """Load the pickled test graph list used as the NSPDK MMD reference."""
        # NOTE: pickle.load can execute arbitrary code; only point this at trusted local data.
        with open(f'data/{self.configs.data.data.lower()}_test_nx.pkl', 'rb') as f:
            return pickle.load(f)

    def sample(self):
        """Sample molecules from the ``'S'`` checkpoint and log evaluation metrics.

        Side effects: appends to ``<log_dir>/<log_name>.log`` and to
        ``<log_dir>/<filename>.txt`` (one SMILES per line), and sets
        ``self.s``, ``self.configs``, ``self.model_x``, ``self.model_adj``,
        ``self.sampling_fn`` and ``self.test_graph_list``. For models that are
        not ``rgcn_attn`` it also sets ``self.train_graph_list`` / ``self.init_flags``.

        Raises:
            ValueError: if ``configs.data.max_feat_num`` is unsupported.
        """
        self.s = load_ckpt(self.config, self.device)['S']
        self.configs = self.s['config']

        load_seed(self.config.seed)

        self.log_folder_name, self.log_dir, _ = set_log(self.configs, is_train=False, foldername='ood')
        self.log_name = f"{self.config.module['S'].ckpt}_oodsqrt{self.config.sample.ood}"
        if self.config.seed != 42:
            # Seeds other than 42 are tagged in the log name.
            self.log_name += f'_{self.config.seed}'
        print(f'logname: {self.log_name}')
        logger = Logger(str(os.path.join(self.log_dir, f'{self.log_name}.log')), mode='a')

        if not check_log(self.log_folder_name, self.log_name):
            start_log(logger, self.configs)
            train_log(logger, self.configs)
        sample_log(logger, self.config)

        self.model_x = load_model_from_ckpt(self.s['params_x'], self.s['x_state_dict'], self.device)
        self.model_adj = load_model_from_ckpt(self.s['params_adj'], self.s['adj_state_dict'], self.device)

        self.sampling_fn = load_sampling_fn(self.configs, self.config.module.S, self.config.sample, self.device)

        train_smiles, test_smiles = load_smiles(self.configs.data.data)
        train_smiles, test_smiles = canonicalize_smiles(train_smiles), canonicalize_smiles(test_smiles)
        # NSPDK MMD below needs the test graphs for every model type, so load them
        # up front (this also fails fast before sampling if the file is missing).
        self.test_graph_list = self._load_test_graphs()

        t_start = time.time()

        if self.s['params_adj']['model_type'] == 'rgcn_attn':
            x, adj, _ = self.sampling_fn(self.model_x, self.model_adj, None)
        else:
            self.train_graph_list, _ = load_data(self.configs, get_graph_list=True)     # for init_flags
            self.init_flags = init_flags(self.train_graph_list, self.config.sample.num_samples,
                                         self.configs).to(self.device[0])
            x, adj, _ = self.sampling_fn(self.model_x, self.model_adj, self.init_flags)

        if self.config.sample.check:
            pdb.set_trace()

        logger.log(f"{time.time()-t_start:.2f} sec elapsed for sampling")

        samples_int = load_sample(x, adj, self.config.sample.check, mol=True)
        bond_classes = self._to_bond_classes(samples_int, self.configs.data.max_feat_num)

        logger.log(f'[before quantization] mean: {adj.mean().item():.4f} | max: {adj.max().item():.2f} | min: {adj.min().item():.2f}')
        logger.log(f'[after quantization]  mean: {samples_int.mean():.4f} | max: {samples_int.max():.2f} | min: {samples_int.min():.2f}')

        adj = torch.nn.functional.one_hot(torch.tensor(bond_classes), num_classes=4).permute(0, 3, 1, 2)
        x = torch.where(x > 0.5, 1, 0)
        # Append a last channel equal to 1 - (number of active channels),
        # e.g. (32, 9, 4) -> (32, 9, 5).
        x = torch.concat([x, 1 - x.sum(dim=-1, keepdim=True)], dim=-1)

        gen_mols, num_mols_wo_correction, num_mols_wo_fc_correction = gen_mol(x, adj, self.configs.data.data)
        num_mols = len(gen_mols)

        gen_smiles = [smi for smi in mols_to_smiles(gen_mols) if smi]   # drop empty SMILES

        filename = self._result_filename()
        with open(os.path.join(self.log_dir, f'{filename}.txt'), 'a') as f:
            for smiles in gen_smiles:
                f.write(f'{smiles}\n')

        scores = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=self.device[0], n_jobs=8,
                                 test=test_smiles, train=train_smiles)
        scores_nspdk = eval_graph_list(self.test_graph_list, mols_to_nx(gen_mols), methods=['nspdk'])['nspdk']
        scores_nspdk_x_adj = eval_graph_list(
            self.test_graph_list, x_adj_to_nx(x, samples_int, self.configs.data.data), methods=['nspdk'])['nspdk']

        logger.log(f'Number of molecules: {num_mols} | '
                   f'Number of molecules w/o correction: {num_mols_wo_correction} | '
                   f'Number of molecules w/o fc correction: {num_mols_wo_fc_correction}')
        logger.log(f'Number of non-empty molecules: {len(gen_smiles)}')
        for name, value in scores.items():
            logger.log(f'{name}: {value}')
        logger.log(f'NSPDK MMD: {scores_nspdk}')
        logger.log(f'NSPDK MMD (X and A): {scores_nspdk_x_adj}')
        logger.log(f'Total {time.time()-t_start:.2f}s elapsed')
        logger.log('='*100)


class Sampler_conditional(object):
    """Classifier-conditioned molecule sampler with property / docking evaluation.

    ``sample()`` loads the ``'S'`` score checkpoint and the ``'C'`` classifier
    checkpoint, samples with the conditional sampling function (which receives
    the classifier), scores the molecules with ``get_pds`` (when a docking target
    is recognised in the property name) or ``get_pds_no_ds``, and logs MOSES
    metrics plus NSPDK MMD.
    """

    # Substrings of ``configc.train.prop`` that select a docking target; they are
    # checked in this order and the first match wins.
    _PROTEINS = ('parp1', 'fa7', '5ht1b', 'braf', 'jak2', 'tgfr1')

    # ``configs.data.max_feat_num`` values for which the bond remapping in
    # ``_to_bond_classes`` is defined; any other value aborts with ValueError.
    _SUPPORTED_MAX_FEAT_NUM = (4, 9)

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.device = load_device(self.config.gpu)

    def get_total_model(self, model, final_linear):
        """Return a callable computing ``final_linear(model(x, adj, flags))``."""
        def total_model(x, adj, flags):
            out = model(x, adj, flags)
            return final_linear(out)
        return total_model

    @classmethod
    def _to_bond_classes(cls, samples_int, max_feat_num):
        """Remap quantized adjacency values to class indices for one-hot encoding.

        Entries 0, 1, 2, 3 (no, S, D, T) become 3, 0, 1, 2. A new array/tensor is
        returned; ``samples_int`` itself is not modified.

        Raises:
            ValueError: if ``max_feat_num`` is not in ``_SUPPORTED_MAX_FEAT_NUM``.
        """
        if max_feat_num not in cls._SUPPORTED_MAX_FEAT_NUM:
            raise ValueError(
                f'Unsupported max_feat_num={max_feat_num}; '
                f'expected one of {cls._SUPPORTED_MAX_FEAT_NUM}')
        bond_classes = samples_int - 1
        bond_classes[bond_classes == -1] = 3
        return bond_classes

    @classmethod
    def _protein_for_prop(cls, prop):
        """Return the first entry of ``_PROTEINS`` contained in ``prop``, or None."""
        return next((protein for protein in cls._PROTEINS if protein in prop), None)

    def _result_filename(self):
        """Return the basename (no extension) used for the scoring output path."""
        s_cfg = self.config.module.S
        if s_cfg.corrector == 'None':
            return f'{self.log_name}'
        return (f'{self.log_name}_snr={int(s_cfg.snr*100):03d}'
                f'_seps={int(s_cfg.scale_eps*100):03d}')

    def _load_test_graphs(self):
        """Load the pickled test graph list used as the NSPDK MMD reference."""
        # NOTE: pickle.load can execute arbitrary code; only point this at trusted local data.
        with open(f'data/{self.configs.data.data.lower()}_test_nx.pkl', 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def _log_validity_novelty(logger, result):
        """Log the validity / uniqueness / novelty entries common to both scoring paths."""
        logger.log(f'Validity: {result["validity"]}')
        logger.log(f'Uniqueness: {result["uniqueness"]}')
        logger.log(f'Novelty: {result["novelty"]}')
        for key, sim in (('novelty_02', '0.2'), ('novelty_03', '0.3'), ('novelty_04', '0.4')):
            logger.log(f'Novelty (sim. < {sim}): {result[key]}')

    def sample(self):
        """Sample molecules with the classifier-conditioned sampler and log metrics.

        Side effects: appends to ``<log_dir>/<log_name>.log``, passes
        ``<log_dir>/<filename>`` to ``get_pds`` / ``get_pds_no_ds``, and sets
        ``self.s``, ``self.configs``, ``self.cmodules``, ``self.c``,
        ``self.configc``, the loaded models, ``self.sampling_fn``,
        ``self.train_graph_list``, ``self.test_graph_list`` and ``self.init_flags``.

        Raises:
            AssertionError: from ``check_config`` if the two configs disagree.
            ValueError: if ``configs.data.max_feat_num`` is unsupported.
        """
        self.s = load_ckpt(self.config, self.device)['S']
        log_ckpt = f"{self.config.module['S'].ckpt}"
        self.configs = self.s['config']

        self.cmodules = load_classifier_ckpt(self.config, self.device)
        self.c = self.cmodules['C']
        self.configc = self.c['config']

        self.check_config(self.configs, self.configc)
        load_seed(self.config.seed)

        foldername = 'low_ood_prop' if self.config.sample.low else 'ood_prop'
        foldername = f'{foldername}/{self.configc.train.prop}_3000mols'

        self.log_folder_name, self.log_dir, _ = set_log(self.configs, is_train=False, foldername=foldername)
        self.log_name = f"{log_ckpt}-{self.config.module['C'].ckpt}_" \
                        f"oodsqrt{self.config.sample.ood}_" \
                        f"Xw{self.config.module.C.weight_x}_Aw{self.config.module.C.weight_adj}_" \
                        f"os{self.config.sample.ood_scheduling}_ws{self.config.sample.weight_scheduling}"
        print(f'logname: {self.log_name}')
        logger = Logger(str(os.path.join(self.log_dir, f'{self.log_name}.log')), mode='a')

        if not check_log(self.log_folder_name, self.log_name):
            start_log(logger, self.configs)
            train_log(logger, self.configs)
            train_logc(logger, self.configc)
            sample_logc(logger, self.config.module['C'])
        logger.log(f'snr={self.config.module.S.snr} seps={self.config.module.S.scale_eps} n_steps={self.config.module.S.n_steps}')

        self.model_x = load_model_from_ckpt(self.s['params_x'], self.s['x_state_dict'], self.device)
        self.model_adj = load_model_from_ckpt(self.s['params_adj'], self.s['adj_state_dict'], self.device)
        self.classifier = load_classifier_from_ckpt(self.c['params'], self.c['state_dict'], self.device)
        self.classifier.eval()

        self.sampling_fn = load_sampling_fn_conditional(self.configs, self.config.module.S, self.config.module.C,
                                                        self.config.sample, self.configc, self.device, logger)

        t_start = time.time()

        self.train_graph_list, _ = load_data(self.configc, get_graph_list=True)     # for init_flags
        self.test_graph_list = self._load_test_graphs()                              # for NSPDK MMD

        self.init_flags = init_flags(self.train_graph_list, self.config.sample.num_samples,
                                     self.configs).to(self.device[0])

        x, adj, _ = self.sampling_fn(self.model_x, self.model_adj, self.init_flags, self.classifier)

        logger.log(f"{time.time()-t_start:.2f} sec elapsed for sampling")

        samples_int = load_sample(x, adj, self.config.sample.check, mol=True)
        bond_classes = self._to_bond_classes(samples_int, self.configs.data.max_feat_num)

        logger.log(f'[before quantization] mean: {adj.mean().item():.4f} | max: {adj.max().item():.2f} | min: {adj.min().item():.2f}')
        logger.log(f'[after quantization]  mean: {samples_int.mean():.4f} | max: {samples_int.max():.2f} | min: {samples_int.min():.2f}')

        adj = torch.nn.functional.one_hot(torch.tensor(bond_classes), num_classes=4).permute(0, 3, 1, 2)
        x = torch.where(x > 0.5, 1, 0)
        # Append a last channel equal to 1 - (number of active channels),
        # e.g. (32, 9, 4) -> (32, 9, 5).
        x = torch.concat([x, 1 - x.sum(dim=-1, keepdim=True)], dim=-1)

        gen_mols, num_mols_wo_correction, num_mols_wo_fc_correction = gen_mol(x, adj, self.configs.data.data)
        num_mols = len(gen_mols)

        # Keep empty SMILES here: they are passed to the scoring functions as-is
        # and only filtered out before the MOSES metrics below.
        gen_smiles = mols_to_smiles(gen_mols)

        logger.log(f'Number of molecules: {num_mols} | '
                   f'Number of molecules w/o correction: {num_mols_wo_correction} | '
                   f'Number of molecules w/o fc correction: {num_mols_wo_fc_correction}')

        result_path = os.path.join(self.log_dir, self._result_filename())
        thrs = [0.5, 5]     # QED, SA
        protein = self._protein_for_prop(self.configc.train.prop)

        if protein is None:
            result = get_pds_no_ds(result_path, gen_smiles, gen_mols, thrs)
            self._log_validity_novelty(logger, result)
            logger.log(f'pass rate (QED > {thrs[0]}, SA < {thrs[1]}): {result["pass_rate"]}')
        else:
            print(f'Calculating docking scores w.r.t. {protein}...')
            result = get_pds(protein, result_path, gen_smiles, gen_mols, thrs)
            self._log_validity_novelty(logger, result)
            logger.log(f'DS: {result["ds"][0]:.4f} ± {result["ds"][1]:.4f}')
            logger.log(f'top 5% DS: {result["top_ds"][0]:.4f} ± {result["top_ds"][1]:.4f}')
            logger.log(f'pass rate (QED > {thrs[0]}, SA < {thrs[1]}): {result["pass_rate"]}')
            logger.log(f'top 5% DS (QED > {thrs[0]}, SA < {thrs[1]}): '
                       f'{result["top_pass_ds"][0]:.4f} ± {result["top_pass_ds"][1]:.4f}')
            logger.log(f'novel top 5% DS (QED > {thrs[0]}, SA < {thrs[1]}, sim. < 0.3): '
                       f'{result["top_pass_ds_novel_03"][0]:.4f} ± {result["top_pass_ds_novel_03"][1]:.4f}')
            logger.log(f'novel top 5% DS (QED > {thrs[0]}, SA < {thrs[1]}, sim. < 0.4): '
                       f'{result["top_pass_ds_novel_04"][0]:.4f} ± {result["top_pass_ds_novel_04"][1]:.4f}')
            logger.log(f'hit ratio: {result["hit"] * 100:.4f} %')
            logger.log(f'hit ratio (QED > {thrs[0]}, SA < {thrs[1]}): {result["hit_pass"] * 100:.4f} %')
            logger.log(f'novel hit ratio (QED > {thrs[0]}, SA < {thrs[1]}, sim. < 0.3): {result["hit_novel_03"] * 100:.4f} %')
            logger.log(f'novel hit ratio (QED > {thrs[0]}, SA < {thrs[1]}, sim. < 0.4): {result["hit_novel_04"] * 100:.4f} %')

            # Measured from t_start, so this also includes sampling time.
            logger.log(f"{time.time()-t_start:.2f} sec elapsed for docking simulation")

        train_smiles, test_smiles = load_smiles(self.configc.data.data)
        train_smiles, test_smiles = canonicalize_smiles(train_smiles), canonicalize_smiles(test_smiles)
        gen_smiles = [smi for smi in gen_smiles if smi]

        scores = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=self.device[0], n_jobs=8,
                                 test=test_smiles, train=train_smiles)
        scores_nspdk = eval_graph_list(self.test_graph_list, mols_to_nx(gen_mols), methods=['nspdk'])['nspdk']
        for name, value in scores.items():
            logger.log(f'{name}: {value}')
        logger.log(f'NSPDK MMD: {scores_nspdk}')
        logger.log(f'Total {time.time()-t_start:.2f}s elapsed')
        logger.log('='*100)

    def check_config(self, config1, config2):
        """Assert that the score-model and classifier configs share data and SDE settings.

        Raises:
            AssertionError: naming the first mismatching field.
        """
        assert config1.data.batch_size == config2.data.batch_size, 'Batch size Mismatch'
        assert config1.data.max_node_num == config2.data.max_node_num, 'Max node num Mismatch'
        assert config1.data.max_feat_num == config2.data.max_feat_num, 'Max feat. num Mismatch'
        assert config1.sde.x == config2.sde.x, 'SDE Mismatch: X'
        assert config1.sde.adj == config2.sde.adj, 'SDE Mismatch: Adj'
