import numpy as np
import random
import torch
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.neighbors import KernelDensity
from rdkit import Chem as Chem
from rdkit.Chem import AllChem
import json
from os import listdir
from os.path import isfile, join
import pickle
from collections import defaultdict
from common import args, prepare_mlp
from polymer_lib import form_double_units
from mlp_retrosyn.mlp_inference import preprocess


def same_mol(smi1, smi2):
    """Tell whether two SMILES strings canonicalize to the same string.

    Both inputs are canonicalized with ``useChiral=0`` before the string
    comparison, so chirality is not taken into account.
    """
    canonical_first = Chem.CanonSmiles(smi1, useChiral=0)
    canonical_second = Chem.CanonSmiles(smi2, useChiral=0)
    return canonical_first == canonical_second

def load_condensation_dict(cond_file):
    """Load candidate scores and template ids per repeat unit from JSON.

    Args:
        cond_file: Path to a JSON file containing an object. Only the values
            of that object are used; each must provide the keys
            'repeat_unit', 'candidates', 'valid_template_ids' and
            'gt_template_id'. Every candidate is indexable, with item 0 used
            as the candidate key and item 1 as its score.

    Returns:
        A tuple ``(score_dict, valid_templates_dict, gt_template_id_dict)``,
        each keyed by repeat unit:
          * score_dict[repeat_unit] is a ``defaultdict(float)`` mapping each
            candidate key to the largest score seen for it,
          * valid_templates_dict[repeat_unit] is 'valid_template_ids',
          * gt_template_id_dict[repeat_unit] is 'gt_template_id'.
    """
    # Context manager so the file handle is released deterministically.
    with open(cond_file, 'r') as f:
        data = json.load(f)

    score_dict = {}
    valid_templates_dict = {}
    gt_template_id_dict = {}
    for entry in data.values():
        repeat_unit = entry['repeat_unit']
        valid_templates_dict[repeat_unit] = entry['valid_template_ids']
        gt_template_id_dict[repeat_unit] = entry['gt_template_id']

        # defaultdict(float) starts every candidate at 0.0, so the stored
        # value is max(0.0, best score) for each candidate key.
        best_scores = defaultdict(float)
        for cand in entry['candidates']:
            best_scores[cand[0]] = max(best_scores[cand[0]], cand[1])
        score_dict[repeat_unit] = best_scores

    return score_dict, valid_templates_dict, gt_template_id_dict

def load_templates(template_rule_path):
    """Read template rules (one per line) and return an index -> rule dict.

    Each line is stripped and associated with its 0-based line number. When
    the same rule text occurs on several lines, only the last line number is
    kept for it, so earlier line numbers of a duplicated rule are absent from
    the returned mapping.
    """
    rule_to_index = {}
    with open(template_rule_path, 'r') as f:
        for line_no, line in tqdm(enumerate(f), desc='template rules'):
            rule_to_index[line.strip()] = line_no

    # Invert rule -> index into index -> rule.
    return {line_no: rule for rule, line_no in rule_to_index.items()}

def compute_template_feat(template):
    """Return atom counts of a reaction template as a flat list.

    The template is parsed with ``AllChem.ReactionFromSmarts``. It must have
    exactly one reactant template and two product templates (enforced with an
    assert). The result lists the atom count of the reactant template first,
    followed by the atom counts of the two product templates.
    """
    rxn = AllChem.ReactionFromSmarts(template)
    n_reactants = rxn.GetNumReactantTemplates()
    n_products = rxn.GetNumProductTemplates()
    assert n_reactants == 1 and n_products == 2

    reactant_sizes = [rxn.GetReactants()[pos].GetNumAtoms() for pos in range(n_reactants)]
    product_sizes = [rxn.GetProducts()[pos].GetNumAtoms() for pos in range(n_products)]
    return reactant_sizes + product_sizes

def synthesizability_check(monomer, monomer_dict):
    """Return True when `monomer` has an entry whose 'prob' is not below 0.1.

    A monomer missing from `monomer_dict` is reported as not synthesizable.
    The stored 'prob' is converted with ``float`` before the comparison.
    """
    # Written as `not (p < 0.1)` to keep the exact comparison semantics.
    return monomer in monomer_dict and not (float(monomer_dict[monomer]['prob']) < 0.1)

def load_full_unit_dict(filter='first', only_gt=False):
    """Load full-unit candidates and optionally keep only synthesizable ones.

    Candidate shards are read from
    ``<args.dataset_dir>/monomer_dict_<start>_<end>.json`` and merged into one
    dict. Monomer results are read from every regular file in
    ``args.monomer_result_dir``; each provides 'cand', 'route' and
    'route_prob'. A candidate counts as synthesizable when its first two
    '.'-separated reactants both pass ``synthesizability_check``.

    Args:
        filter: 'first' keeps only the first synthesizable candidate of each
            full unit, 'all' keeps every synthesizable candidate; any other
            value (e.g. None) leaves the candidate lists unfiltered. The name
            shadows the builtin but is kept because callers pass it by
            keyword.
        only_gt: When true, only the (0, 60) shard is loaded; otherwise the
            seven shards (0, 100) ... (600, 700) are loaded.

    Returns:
        Dict mapping each full unit to its list of candidates, plus the extra
        key 'default' holding the mean 'prob' over all loaded candidates
        (computed before any filtering).
    """
    if only_gt:
        intervals = [(0, 60)]
    else:
        intervals = [(start, start + 100) for start in range(0, 700, 100)]

    full_unit_candidates_dict = {}
    for start, end in intervals:
        shard_path = '%s/monomer_dict_%d_%d.json' % (args.dataset_dir, start, end)
        with open(shard_path, 'r') as f:
            full_unit_candidates_dict.update(json.load(f))

    monomer_dict = {}
    # Sorted so that, if two files describe the same 'cand', the entry that
    # wins does not depend on the directory listing order.
    for name in sorted(listdir(args.monomer_result_dir)):
        result_path = join(args.monomer_result_dir, name)
        if not isfile(result_path):
            continue
        with open(result_path, 'r') as f:
            r = json.load(f)
        monomer_dict[r['cand']] = {
            'route': r['route'],
            'prob': float(r['route_prob'])
        }

    probs = []
    for full_unit_cand in full_unit_candidates_dict:
        can_syn = []
        for cand in full_unit_candidates_dict[full_unit_cand]:
            probs.append(cand['prob'])
            reactants = cand['reactants'].split('.')
            # Only the first two reactants are checked.
            if (synthesizability_check(reactants[0], monomer_dict)
                    and synthesizability_check(reactants[1], monomer_dict)):
                can_syn.append(cand)

        # Re-assigning existing keys does not change the dict size, so this
        # is safe while iterating over it.
        if filter == 'first':
            full_unit_candidates_dict[full_unit_cand] = can_syn[:1]
        elif filter == 'all':
            full_unit_candidates_dict[full_unit_cand] = can_syn

        if not can_syn:
            print(full_unit_cand, 'not synthesizable')

    full_unit_candidates_dict['default'] = np.array(probs).mean()

    return full_unit_candidates_dict

def test(k_list, gt_info, full_unit_candidates_dict, valid_unit_polymer):
    """Evaluate top-k template and monomer recovery over five folds.

    For every fold index 0..4, a Gaussian kernel density estimate is fitted on
    the template features (``compute_template_feat``) of the ground-truth
    templates of repeat units whose ``gt_info[unit]['group']`` equals the fold
    index; all other repeat units are scored. Per-fold and overall results
    are printed and the collected results are pickled to
    ``<args.final_result_dir>/<args.alg>_real.pkl``.

    Args:
        k_list: Values of k for the top-k evaluation.
        gt_info: Dict keyed by repeat unit; the keys 'group', 'full_unit' and
            'monomer_match' of each entry are read here.
        full_unit_candidates_dict: Maps a full unit to its candidate list
            (first candidate's 'prob' is used) and 'default' to a fallback
            probability.
        valid_unit_polymer: Maps repeat unit -> {template id as string ->
            full unit}; its keys define the valid templates per repeat unit.

    Raises:
        ValueError: If ``args.alg`` is not 'uspto', 'polyretro' or 'random'.
    """
    device_id = 0 if args.gpu >= 0 else -1
    r0_one_step = prepare_mlp(args.mlp_templates_r0, args.mlp_model_dump_r0,
                              device_id, polymer=True)

    _, _, gt_template_id_dict = load_condensation_dict(args.real_dataset)

    # Valid template ids are taken from valid_unit_polymer (string keys
    # converted to int), not from the condensation file.
    cond_valid_templates_dict = {
        repeat_unit: [int(s) for s in valid_unit_polymer[repeat_unit].keys()]
        for repeat_unit in valid_unit_polymer
    }

    templates = load_templates(args.mlp_templates_r0)

    result_groups = {
        'summary': {k: {'recovered': {}, 'monomer_recovered': {}} for k in k_list}
    }

    for group_idx in range(5):
        print('=================================== fold %d ===================================' % group_idx)
        train_repeat_units = [repeat_unit for repeat_unit in gt_info if gt_info[repeat_unit]['group'] == group_idx]
        test_repeat_units = [repeat_unit for repeat_unit in gt_info if gt_info[repeat_unit]['group'] != group_idx]

        # kernel density estimation over features of known ground-truth
        # templates (negative ids mean "no template" and are skipped)
        train_feats = []
        for repeat_unit in train_repeat_units:
            known_id = gt_template_id_dict[repeat_unit]
            if known_id >= 0:
                train_feats.append(compute_template_feat(templates[known_id]))
        kde = KernelDensity(kernel='gaussian', bandwidth=1).fit(np.array(train_feats))

        results = {k: {'recovered': {}, 'monomer_recovered': {}} for k in k_list}

        for repeat_unit in test_repeat_units:
            template_cands = cond_valid_templates_dict[repeat_unit]
            gt_template_id = gt_template_id_dict[repeat_unit]

            double_unit, _ = form_double_units(repeat_unit)
            arr = preprocess(double_unit, fp_dim=r0_one_step.fp_dim)
            arr = np.reshape(arr, [-1, arr.shape[0]])
            arr = torch.tensor(arr, dtype=torch.float32)
            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                preds_backward = F.softmax(r0_one_step.net(arr), dim=1)[0]

            # Prior over valid templates from the KDE, normalized to sum to 1.
            preds_prior = torch.zeros(len(templates))
            for template_id in template_cands:
                feat = compute_template_feat(templates[template_id])
                sample = np.array(feat).reshape(1, -1)
                preds_prior[template_id] = np.exp(kde.score(sample))
            prior_total = preds_prior.sum()
            # Skip the division when the prior is all zeros (no valid
            # templates or underflow); dividing would turn every score to NaN.
            if prior_total > 0:
                preds_prior /= prior_total

            if args.alg == 'uspto':
                preds = preds_backward
            elif args.alg == 'polyretro':
                preds = preds_backward + preds_prior * 0.001
            elif args.alg == 'random':
                preds = torch.rand(len(templates))
            else:
                raise ValueError('unknown algorithm: %s' % args.alg)

            # Push every template that is not valid for this unit to -inf.
            valid_ids = set(template_cands)
            mask = torch.tensor([0 if idx in valid_ids else -np.inf
                                 for idx in range(preds.shape[-1])]).reshape(preds.shape)
            preds += mask

            for k in k_list:
                _, idx_list = torch.topk(preds, k=k)
                results[k]['recovered'][repeat_unit] = gt_template_id in idx_list.tolist()

            # monomer recover: re-weight each valid template by the
            # probability of its full unit's first candidate (or the default)
            monomer_match = bool(gt_info[repeat_unit]['monomer_match'])
            for template_id in template_cands:
                stab_prob = full_unit_candidates_dict['default']
                if str(template_id) in valid_unit_polymer[repeat_unit]:
                    full_unit = valid_unit_polymer[repeat_unit][str(template_id)]
                    # Use the ground-truth spelling when it is the same
                    # molecule, so the dictionary lookup below can hit.
                    if same_mol(full_unit, gt_info[repeat_unit]['full_unit']):
                        full_unit = gt_info[repeat_unit]['full_unit']

                    if full_unit in full_unit_candidates_dict:
                        if len(full_unit_candidates_dict[full_unit]) > 0:
                            stab_prob = full_unit_candidates_dict[full_unit][0]['prob']

                # Diagnostic: the ground-truth template is about to be
                # (almost) zeroed out by its candidate probability.
                if monomer_match and template_id == gt_template_id and stab_prob < 1e-8:
                    print(repeat_unit, stab_prob)

                preds[template_id] *= stab_prob

            for k in k_list:
                _, idx_list = torch.topk(preds, k=k)
                results[k]['monomer_recovered'][repeat_unit] = (
                    monomer_match and gt_template_id in idx_list.tolist())

        result_groups[group_idx] = results
        for k in k_list:
            print('k:', k)
            recovered = np.array(list(results[k]['recovered'].values()), dtype='float').mean()
            print('average unit polymer recovered:', recovered)
            monomer_recovered = np.array(list(results[k]['monomer_recovered'].values()), dtype='float').mean()
            print('average monomer recovered:', monomer_recovered)

            result_groups['summary'][k]['recovered'][group_idx] = recovered
            result_groups['summary'][k]['monomer_recovered'][group_idx] = monomer_recovered

    for k in k_list:
        print('final summary:')
        print('k:', k)
        fold_recovered = np.array(list(result_groups['summary'][k]['recovered'].values()), dtype='float')
        print('average recovered:', fold_recovered.mean(), fold_recovered.std())
        fold_monomer = np.array(list(result_groups['summary'][k]['monomer_recovered'].values()), dtype='float')
        print('average monomer_recovered:', fold_monomer.mean(), fold_monomer.std())

    result_file = args.final_result_dir + '/%s_real.pkl' % args.alg
    print('saving results to %s' % result_file)
    # Context manager so the output is flushed and closed deterministically.
    with open(result_file, 'wb') as f:
        pickle.dump(result_groups, f)

def main_experiment():
    """Flag which repeat units have a recoverable monomer, then run `test`.

    Loads ``unit_polymer_dict.json`` and ``cond2monomer_gt_info.json`` from
    ``args.dataset_dir`` and the candidate dictionary from
    ``load_full_unit_dict()`` (default filtering). For every repeat unit,
    ``gt_info[unit]['monomer_match']`` is set to True when some valid template
    yields the ground-truth full unit and one of that full unit's candidates
    has reactants matching the ground-truth monomer (per ``same_mol``);
    otherwise it is False. Finally `test` is called with a fixed list of k.
    """
    with open('%s/unit_polymer_dict.json' % (args.dataset_dir), 'r') as f:
        valid_unit_polymer = json.load(f)

    full_unit_candidates_dict = load_full_unit_dict()
    with open(args.dataset_dir + '/cond2monomer_gt_info.json', 'r') as f:
        gt_info = json.load(f)

    for repeat_unit, info in gt_info.items():
        info['monomer_match'] = False
        gt_full_unit = info['full_unit']

        for full_unit in valid_unit_polymer[repeat_unit].values():
            # Only templates that reproduce the ground-truth full unit can
            # contribute a match; candidates are looked up under the
            # ground-truth spelling of that full unit.
            if not same_mol(full_unit, gt_full_unit):
                continue

            if any(same_mol(cand['reactants'], info['monomer'])
                   for cand in full_unit_candidates_dict.get(gt_full_unit, ())):
                info['monomer_match'] = True
                break

    k_list = [1, 2, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
    test(k_list, gt_info, full_unit_candidates_dict, valid_unit_polymer)

def compute_upperbound():
    """Compute, print and pickle recovery upper bounds.

    Three fractions over all repeat units in ``cond2monomer_gt_info.json``
    are reported and saved to ``<args.final_result_dir>/upperbound_real.pkl``
    under the keys 'full_units', 'monomers' and 'synthesizability':
      * full units: the ground-truth template id is non-negative,
      * monomers: additionally, some candidate of the ground-truth full unit
        has reactants matching the ground-truth monomer,
      * synthesizable monomers: same, restricted to candidates that passed
        the synthesizability filter of ``load_full_unit_dict``.
    """
    _, _, gt_template_id_dict = load_condensation_dict(args.real_dataset)
    with open(args.dataset_dir + '/cond2monomer_gt_info.json', 'r') as f:
        gt_info = json.load(f)
    # filter='all' keeps only synthesizable candidates; filter=None leaves
    # the candidate lists unfiltered.
    synthesizable_candidates = load_full_unit_dict(filter='all', only_gt=True)
    unfiltered_candidates = load_full_unit_dict(filter=None, only_gt=True)

    full_unit_upperbound = {}
    monomer_upperbound = {}
    monomer_upperbound_with_synthesizability = {}

    for repeat_unit in gt_info:
        full_unit = gt_info[repeat_unit]['full_unit']
        gt_monomer = gt_info[repeat_unit]['monomer']
        has_template = gt_template_id_dict[repeat_unit] >= 0
        full_unit_upperbound[repeat_unit] = has_template

        monomer_upperbound[repeat_unit] = False
        monomer_upperbound_with_synthesizability[repeat_unit] = False
        if has_template:
            # NOTE: the plain monomer bound uses the unfiltered candidates
            # and the synthesizability bound uses the filtered ones; the two
            # dictionaries were previously used the other way round, which
            # made the constrained bound the larger of the two.
            monomer_upperbound[repeat_unit] = any(
                same_mol(cand['reactants'], gt_monomer)
                for cand in unfiltered_candidates[full_unit])
            monomer_upperbound_with_synthesizability[repeat_unit] = any(
                same_mol(cand['reactants'], gt_monomer)
                for cand in synthesizable_candidates[full_unit])

    n_units = len(gt_info)
    upperbounds = {
        'full_units': np.array(list(full_unit_upperbound.values())).sum() * 1.0 / n_units,
        'monomers': np.array(list(monomer_upperbound.values())).sum() * 1.0 / n_units,
        'synthesizability': np.array(list(monomer_upperbound_with_synthesizability.values())).sum() * 1.0 / n_units
    }

    print('upperbound on full units:', upperbounds['full_units'])
    print('upperbound on monomers:', upperbounds['monomers'])
    print('upperbound on synthesizable monomers:', upperbounds['synthesizability'])

    result_file = args.final_result_dir + '/upperbound_real.pkl'
    print('saving results to %s' % result_file)
    with open(result_file, 'wb') as f:
        pickle.dump(upperbounds, f)


if __name__ == '__main__':
    # Seed numpy, torch and the stdlib RNG with the same configured seed.
    for seed_rng in (np.random.seed, torch.manual_seed, random.seed):
        seed_rng(args.seed)

    main_experiment()
    # The upper bounds are only computed for the 'polyretro' algorithm.
    if args.alg == 'polyretro':
        compute_upperbound()

