import glob
import json
from collections import defaultdict
from functools import partial
from os import listdir
from os.path import basename, isfile, join

from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm import tqdm

from common import args, prepare_test_polymers, prepare_mlp
from polymer_lib import form_double_units, produce_candidate_monomer, polymerization_is_valid
from test_real_polymers import load_condensation_dict


# Candidate filters applied to search results: a candidate is kept only when
# its route probability is strictly greater than ROUTE_THRES AND its forward
# (reaction prediction) probability is strictly greater than FORWARD_THRES.
ROUTE_THRES = 0.1
FORWARD_THRES = 1e-5

def extract_training_set():
    """Collect high-confidence template candidates from synthetic search results.

    Every ``*.json`` file found recursively under ``args.output_dir`` is read.
    A candidate is kept when it has a ``route_prob`` above ``ROUTE_THRES`` and
    a ``forward_prob`` above ``FORWARD_THRES``.  The integer index of each
    result is parsed from its file name, which is expected to look like
    ``<prefix>_<idx>.json``.

    Returns:
        dict mapping the file index to
        ``{'repeat_unit': ..., 'candidates': [(template_id, forward_prob), ...]}``.
        Candidates are sorted by descending ``forward_prob``; the dict is
        ordered by ascending index.  Files with no kept candidate are omitted.
    """
    result_files = glob.glob(args.output_dir + '/**/*.json', recursive=True)
    print('%d files found' % len(result_files))

    gt_list = {}
    for path in result_files:
        # basename() honours the platform separator, unlike split('/').
        f_idx = int(basename(path).split('.')[0].split('_')[1])
        with open(path, 'r') as fh:
            r = json.load(fh)

        gt_f = [
            (cand['template_id'], float(cand['forward_prob']))
            for cand in r['candidates']
            if 'route_prob' in cand
            and float(cand['route_prob']) > ROUTE_THRES
            and float(cand['forward_prob']) > FORWARD_THRES
        ]
        if gt_f:
            # Stable sort: ties keep their original file order.
            gt_f.sort(key=lambda x: -x[1])
            gt_list[f_idx] = {
                'repeat_unit': r['repeat_unit'],
                'candidates': gt_f,
            }

    return dict(sorted(gt_list.items(), key=lambda x: x[0]))

def same_mol(smi1, smi2):
    """Return True when two SMILES strings denote the same molecule.

    Both strings are canonicalised with RDKit with chirality ignored
    (``useChiral=0``), so stereoisomers compare equal.
    """
    canon_a, canon_b = (Chem.CanonSmiles(smi, useChiral=0) for smi in (smi1, smi2))
    return canon_a == canon_b

def extract_test_gt():
    """Build the ground-truth template list for the real test polymers.

    Two sources are combined:

    1. Per-polymer result files in ``args.test_output_dir`` (file names like
       ``<prefix>_<idx>.<ext>``).  Candidates whose ``route_prob`` exceeds
       ``ROUTE_THRES`` and whose ``forward_prob`` exceeds ``FORWARD_THRES``
       are collected as ``(monomer, forward_prob)`` pairs.
    2. A one-step r0 template expansion of each test repeat unit returned by
       ``prepare_test_polymers``.  It maps the collected monomers back to
       template ids and finds the template whose monomer matches the
       ``full_unit`` recorded in ``<dataset_dir>/gt_info.json``.

    Returns:
        dict keyed by polymer index (ascending).  Each value has these keys:
        ``'candidates'`` is a list of ``(template_id, forward_prob)`` with each
        template at most once, in non-increasing ``forward_prob`` order.
        ``'repeat_unit'`` is the repeat unit.  ``'gt_template_id'`` is -1 when
        no template reproduced the full unit.  ``'valid_template_ids'`` lists
        the templates that produced a monomer.

    Raises:
        KeyError: presumably when a test polymer has no result file in
            ``args.test_output_dir``, or when its repeat unit is missing from
            ``gt_info.json``.
    """
    with open(join(args.dataset_dir, 'gt_info.json'), 'r') as fh:
        gt_info = json.load(fh)

    result_files = [f for f in listdir(args.test_output_dir)
                    if isfile(join(args.test_output_dir, f))]
    gt_list = {}
    for fname in result_files:
        idx = int(fname.split('.')[0].split('_')[1])
        with open(join(args.test_output_dir, fname), 'r') as fh:
            r = json.load(fh)
        gt_f = [
            (cand['monomer'], float(cand['forward_prob']))
            for cand in r['candidates']
            if 'route_prob' in cand
            and float(cand['route_prob']) > ROUTE_THRES
            and float(cand['forward_prob']) > FORWARD_THRES
        ]
        gt_f.sort(key=lambda x: -x[1])
        gt_list[idx] = {'candidates': gt_f}

    gt_list = dict(sorted(gt_list.items(), key=lambda x: x[0]))

    repeat_units, _, _, _ = prepare_test_polymers(args.test_real_polymers)
    device_id = 0 if args.gpu >= 0 else -1
    r0_one_step = prepare_mlp(args.mlp_templates_r0, args.mlp_model_dump_r0,
                              device_id, polymer=True)
    r0_expand_fn = partial(r0_one_step.run, n_cores=args.n_cores)

    for polyidx, repeat_unit in enumerate(repeat_units):
        entry = gt_list[polyidx]  # KeyError if this polymer has no result file
        entry['repeat_unit'] = repeat_unit
        entry['gt_template_id'] = -1

        double_unit, double_idxs = form_double_units(repeat_unit)
        out = r0_expand_fn(double_unit)

        # NOTE(review): the "- 2" presumably discounts the two end-cap
        # placeholder atoms of the repeat unit; confirm against
        # polymerization_is_valid.
        num_atoms = len(Chem.MolFromSmiles(repeat_unit).GetAtoms()) - 2

        # Keep reactant sets that form a valid polymerization and remember
        # which template produced each of them.
        reactants_candidates = []
        template_dict = {}
        if out is not None and len(out['reactants']) > 0:
            for reactants, template in zip(out['reactants'], out['template']):
                if polymerization_is_valid(reactants, num_atoms):
                    reactants_candidates.append(reactants)
                    template_dict[reactants] = template

        # Turn each reactant set into a candidate monomer, and flag the
        # template whose monomer matches the recorded ground-truth full unit.
        forward_data = []
        for reactants in reactants_candidates:
            m = produce_candidate_monomer(reactants, double_idxs)
            if m is None:
                continue
            forward_data.append((reactants, m))
            if same_mol(gt_info[repeat_unit]['full_unit'], m):
                print('found ground truth:', template_dict[reactants], m)
                entry['gt_template_id'] = template_dict[reactants]

        entry['valid_template_ids'] = list(
            {template_dict[reactants] for reactants, _ in forward_data})

        # Map every collected monomer back to its template id(s).  The first
        # (highest forward_prob) occurrence of each template wins.
        # NOTE(review): this uses exact string equality, not same_mol(), so
        # it relies on both sides using the same SMILES form.
        updated_gt_f = []
        seen_ids = set()
        for monomer, forward_prob in entry['candidates']:
            found = False
            for reactants, m in forward_data:
                if monomer != m:
                    continue
                found = True
                template_id = template_dict[reactants]
                if template_id in seen_ids:
                    continue
                seen_ids.add(template_id)
                updated_gt_f.append((template_id, forward_prob))
            if not found:
                print('%s not found!' % monomer)
        entry['candidates'] = updated_gt_f

    return gt_list

def extract_monomer_candidate_dict(max_candidates=200):
    """Expand the ground-truth full units of condensation polymers into monomer candidates.

    Result files in ``args.test_output_dir`` are scanned in sorted file-name
    order.  A file's recorded ``full_unit`` is queued when its
    ``repeat_unit`` appears in ``<dataset_dir>/condensation_gt_info.json``
    with ``'from' == 'reaction'``.  Only queue positions in
    ``[args.start, args.end)`` are expanded with the r0 one-step model, so
    the work can be sharded across several runs.

    Args:
        max_candidates: maximum number of expansion results kept per full
            unit (default 200).

    Returns:
        dict mapping each expanded full-unit SMILES to a list of
        ``{'reactants', 'prob', 'template'}`` dicts, in the order the
        expansion returned them.  Full units with no expansion result are
        omitted.
    """
    with open(join(args.dataset_dir, 'condensation_gt_info.json'), 'r') as fh:
        gt_info = json.load(fh)

    # os.listdir() returns files in arbitrary order.  Sort them so the
    # [start, end) shards cover the same items on every run.
    result_files = sorted(f for f in listdir(args.test_output_dir)
                          if isfile(join(args.test_output_dir, f)))
    full_unit_candidate_list = []
    for fname in result_files:
        with open(join(args.test_output_dir, fname), 'r') as fh:
            repeat_unit = json.load(fh)['repeat_unit']
        info = gt_info.get(repeat_unit)
        if info is None or info['from'] != 'reaction':
            continue
        full_unit_candidate_list.append(info['full_unit'])

    device_id = 0 if args.gpu >= 0 else -1
    one_step = prepare_mlp(args.mlp_templates_r0, args.mlp_model_dump_r0,
                           device_id, use_load_model=True)
    expand_fn = partial(one_step.run, topk=args.expansion_topk,
                        n_cores=args.n_cores, stability_check=True)

    monomer_candidate_dict = {}
    print('total', len(full_unit_candidate_list))
    print('from %d to %d' % (args.start, args.end))
    for idx, full_unit in enumerate(full_unit_candidate_list):
        if idx < args.start or idx >= args.end:
            continue

        out = expand_fn(full_unit)
        print('---------------', idx, '--------------------')
        print(full_unit)
        if out is None or len(out['reactants']) == 0:
            continue

        print('length:', len(out['reactants']))
        monomer_candidate_dict[full_unit] = [
            {
                'reactants': out['reactants'][j],
                'prob': out['scores'][j],
                'template': out['template'][j],
            }
            for j in range(min(max_candidates, len(out['reactants'])))
        ]

    return monomer_candidate_dict

def prepare_valid_unit_polymer():
    """Map each condensation repeat unit to the monomers of its valid templates.

    For every repeat unit in ``<dataset_dir>/condensation_gt_info.json``, the
    double unit is expanded with the r0 one-step model.  Each returned
    template that the condensation dictionary (loaded from
    ``args.real_dataset``) lists as valid for that repeat unit is turned
    into a candidate monomer.

    Returns:
        dict ``{repeat_unit: {template_id: monomer_smiles}}``.  Every repeat
        unit gets an entry, possibly empty.  Templates for which no monomer
        can be produced are skipped.
    """
    _, cond_valid_templates_dict, _ = load_condensation_dict(args.real_dataset)
    with open(join(args.dataset_dir, 'condensation_gt_info.json'), 'r') as fh:
        gt_info = json.load(fh)

    device_id = 0 if args.gpu >= 0 else -1
    r0_one_step = prepare_mlp(args.mlp_templates_r0, args.mlp_model_dump_r0,
                              device_id, polymer=True)
    r0_expand_fn = partial(r0_one_step.run, n_cores=args.n_cores)

    valid_unit_polymer = {}
    for repeat_unit in gt_info:
        unit_monomers = {}
        valid_unit_polymer[repeat_unit] = unit_monomers

        # Look the valid templates up once per unit.  A unit that is absent
        # from the condensation dictionary has no valid templates.
        valid_templates = cond_valid_templates_dict.get(repeat_unit)
        if valid_templates is None:
            continue

        double_unit, double_idxs = form_double_units(repeat_unit)
        out = r0_expand_fn(double_unit)
        if out is None:
            continue
        for reactants, template in zip(out['reactants'], out['template']):
            if template not in valid_templates:
                continue
            m = produce_candidate_monomer(reactants, double_idxs)
            # produce_candidate_monomer may return None.  Skip it instead of
            # writing a null monomer into the output.
            if m is not None:
                unit_monomers[template] = m

    return valid_unit_polymer


if __name__ == '__main__':
    # Other modes, enabled by hand as needed:
    #   extract_training_set()       -> json written to args.synthetic_dataset
    #   extract_test_gt()            -> json written to args.real_dataset
    #   prepare_valid_unit_polymer() -> '%s/unit_polymer_dict.json' % args.dataset_dir

    monomer_candidate_dict = extract_monomer_candidate_dict()
    out_path = '%s/monomer_dict_%d_%d.json' % (args.dataset_dir, args.start, args.end)
    # Use a context manager so the output is flushed and closed at a known
    # point, rather than whenever the file object is garbage-collected.
    with open(out_path, 'w') as fh:
        json.dump(monomer_candidate_dict, fh)