from __future__ import print_function
import numpy as np
import torch
import torch.nn.functional as F
from rdkit import Chem
import rdchiral
from rdchiral.main import rdchiralRunText, rdchiralRun
from rdchiral.initialization import rdchiralReaction, rdchiralReactants
from .mlp_policies import load_parallel_model, load_model, preprocess
from collections import defaultdict, OrderedDict
from tqdm import tqdm
import multiprocessing
from rdkit.Chem import AllChem

def merge(reactant_d):
    """Collapse duplicate reactant entries and rank them by total score.

    Args:
        reactant_d: mapping from a reactant string to a non-empty list of
            ``(score, template)`` pairs collected for that reactant.

    Returns:
        Three parallel lists ``(reactants, scores, templates)`` ordered by
        descending summed score. Each reactant keeps the template of its
        first recorded pair. An empty mapping yields three empty lists
        (previously this raised ``ValueError`` while unpacking ``zip``).
    """
    ranked = [
        (reactant, sum(score for score, _ in pairs), pairs[0][1])
        for reactant, pairs in reactant_d.items()
    ]
    if not ranked:
        return [], [], []

    # list.sort is stable, so ties keep the mapping's iteration order,
    # exactly as the former sorted(..., reverse=True) call did.
    ranked.sort(key=lambda item: item[1], reverse=True)
    reactants = [item[0] for item in ranked]
    scores = [item[1] for item in ranked]
    templates = [item[2] for item in ranked]
    return reactants, scores, templates

def get_atom_map(smi):
    """Return the set of non-zero atom-map numbers found in SMILES ``smi``."""
    mol = Chem.MolFromSmiles(smi)
    labels = {atom.GetAtomMapNum() for atom in mol.GetAtoms()}
    # 0 is the value reported for atoms that carry no map number.
    labels.discard(0)
    return labels

def has_labeled_neighbor(atom):
    """Return True if any neighbor of ``atom`` has a positive atom-map number."""
    return any(nbr.GetAtomMapNum() > 0 for nbr in atom.GetNeighbors())

def fill_atom_map_num(smi, lost_idx):
    """Assign the missing map number ``lost_idx`` to the connection atom.

    A candidate connection atom is an unlabeled atom (map number 0) that
    has at least one labeled neighbor. The label is written only when
    exactly one candidate exists; otherwise a diagnostic is printed and
    the molecule is left unlabeled.

    Args:
        smi: SMILES of the fragment to repair.
        lost_idx: the map number to assign (converted with ``int``).

    Returns:
        The SMILES produced by ``Chem.MolToSmiles`` for the (possibly
        relabeled) molecule.
    """
    m = Chem.MolFromSmiles(smi)
    candidate_atoms = [
        a for a in m.GetAtoms()
        if a.GetAtomMapNum() == 0 and has_labeled_neighbor(a)
    ]

    if len(candidate_atoms) == 1:
        candidate_atoms[0].SetAtomMapNum(int(lost_idx))
    elif candidate_atoms:
        # Covers every count above one; the former ``== 2`` test let three
        # or more candidates slip through without any message.
        print(smi + ' has more than one connection point!')
    else:
        print(smi + ' has no connection point!')

    return Chem.MolToSmiles(m)

def no_duplicated_atom_map_num(smi):
    """Return True when no non-zero atom-map number occurs twice in ``smi``."""
    seen = set()
    for atom in Chem.MolFromSmiles(smi).GetAtoms():
        label = atom.GetAtomMapNum()
        if label == 0:
            # Unlabeled atoms never count as duplicates.
            continue
        if label in seen:
            return False
        seen.add(label)
    return True

def make_valid_polymerization(reactants, num_atoms):
    """Check (and lightly repair) the atom-map labels of candidate reactants.

    Two label blocks are built from ``num_atoms``: ``np.arange(num_atoms) + 1``
    and the same values shifted by ``num_atoms``.

    * Any reactant with a duplicated map number is rejected.
    * One reactant: accepted when it contains one whole block and no label
      of the other block.
    * Two reactants: a reactant holding exactly ``num_atoms - 1`` labels
      drawn from a single block gets the missing label restored through
      ``fill_atom_map_num``; the pair is accepted when each reactant then
      covers a different block.
    * Any other number of reactants is rejected.

    Note: ``reactants`` is modified in place when a label is restored.

    Returns:
        True if the reactants pass the checks above, otherwise False.
    """
    if any(not no_duplicated_atom_map_num(smi) for smi in reactants):
        return False

    base_labels = np.arange(num_atoms) + 1
    set_a = set(base_labels)
    set_b = set(base_labels + num_atoms)

    n_reactants = len(reactants)

    if n_reactants == 1:
        labels = get_atom_map(reactants[0])
        if set_a.issubset(labels) and len(set_b.intersection(labels)) == 0:
            return True
        if set_b.issubset(labels) and len(set_a.intersection(labels)) == 0:
            return True
        return False

    if n_reactants != 2:
        return False

    for pos in range(2):
        labels = get_atom_map(reactants[pos])
        if len(labels) != num_atoms - 1:
            continue
        # Both blocks are tested against the labels read before any repair.
        for block in (set_a, set_b):
            if labels.issubset(block):
                missing = (block - labels).pop()
                reactants[pos] = fill_atom_map_num(reactants[pos], missing)

    first_labels = get_atom_map(reactants[0])
    second_labels = get_atom_map(reactants[1])

    if set_a.issubset(first_labels) and set_b.issubset(second_labels):
        return True
    if set_a.issubset(second_labels) and set_b.issubset(first_labels):
        return True
    return False

def make_valid_polymerization_wrapper(input):
    """Pool-friendly adapter around ``make_valid_polymerization``.

    Args:
        input: ``((reactant_smiles, template_id), num_atoms)``.

    Returns:
        ``(smiles, template_id)`` where ``smiles`` is the re-joined (and
        possibly relabeled) reactant string on success, or None when the
        candidate is rejected.
    """
    (reactant, template_id), num_atoms = input
    parts = reactant.split('.')
    # The check may rewrite entries of ``parts``; join only afterwards.
    accepted = make_valid_polymerization(parts, num_atoms)
    return ('.'.join(parts) if accepted else None), template_id

def is_stable_wrapper(input):
    """Pool-friendly filter for the stability-check expansion.

    Args:
        input: ``(reactant_smiles, template_id, product_smiles)``.

    Returns:
        ``(reactant_smiles, template_id)`` when the candidate has exactly
        two components, satisfies ``satisfy_size_constraint`` against the
        product and every component passes ``is_symmetric_mol``; otherwise
        ``(None, template_id)``.
    """
    reactant, template_id, product = input
    parts = reactant.split('.')

    accepted = (
        len(parts) == 2
        and satisfy_size_constraint(parts, product)
        and all(is_symmetric_mol(part) for part in parts)
    )
    return (reactant if accepted else None), template_id

def satisfy_size_constraint(reactant_list, product):
    """Return True if the reactants have at most 3 atoms more than the product.

    Atom counts are taken from ``GetNumAtoms()`` on the molecules parsed
    from the given SMILES strings.
    """
    reactant_atoms = sum(
        Chem.MolFromSmiles(smi).GetNumAtoms() for smi in reactant_list
    )
    allowed = Chem.MolFromSmiles(product).GetNumAtoms() + 3
    return reactant_atoms <= allowed

def is_symmetric_mol(smi):
    """Return True if some pair of self-matches of ``smi`` passes ``check_sym``.

    The molecule is matched against itself with ``uniquify=False``; every
    unordered pair of the resulting atom-index tuples is tested.
    """
    mol = Chem.MolFromSmiles(smi)
    self_matches = mol.GetSubstructMatches(mol, uniquify=False)
    count = len(self_matches)
    if count <= 1:
        return False
    return any(
        check_sym(mol, self_matches[first], self_matches[second])
        for first in range(count)
        for second in range(first + 1, count)
    )

def check_sym(mol, match_a, match_b):
    """Decide whether two self-matches describe an accepted symmetry.

    The position-wise mapping ``match_a[i] -> match_b[i]`` is accepted when:

    * applying it twice returns every atom to itself,
    * at least half of the atoms are moved by it,
    * it is not a move of exactly four atoms that are all aromatic, and
    * no atom with atom-map number 1 is left in place.

    Returns:
        True if all conditions hold, otherwise False.
    """
    mapping = {src: match_b[pos] for pos, src in enumerate(match_a)}

    moved = []
    for src, dst in mapping.items():
        if mapping[dst] != src:
            return False
        if dst != src:
            moved.append(src)

    if 2 * len(moved) < len(match_a):
        return False

    if len(moved) == 4 and all(
        mol.GetAtomWithIdx(idx).GetIsAromatic() for idx in moved
    ):
        # Named ``benzene_ring`` in the earlier version of this check.
        return False

    for src, dst in mapping.items():
        if mol.GetAtomWithIdx(src).GetAtomMapNum() == 1 and dst == src:
            return False

    return True

def runtextwrapper(input):
    """Apply one reaction template to one molecule (pool worker).

    Args:
        input: ``(template_id, rule_smarts, smiles)``.

    Returns:
        ``(outcomes, template_id)``. ``outcomes`` is a list of dot-joined
        SMILES, one per product set that parses back into a molecule. It is
        None when a ``ValueError``/``RuntimeError`` is raised before the
        reaction has been run, and holds whatever was collected so far when
        one is raised later.
    """
    # An rdchiralRunText-based variant used to live here (disabled upstream);
    # the plain RDKit reaction runner is used instead.
    template_id, rule, smiles = input
    outcomes = None

    try:
        reaction = AllChem.ReactionFromSmarts(rule)
        product_sets = reaction.RunReactants([Chem.MolFromSmiles(smiles)])
        outcomes = []
        for product_set in product_sets:
            joined = '.'.join(Chem.MolToSmiles(mol) for mol in product_set)
            if Chem.MolFromSmiles(joined) is not None:
                outcomes.append(joined)
    except (ValueError, RuntimeError):
        pass

    return outcomes, template_id

class MLPModel(object):
    """Template-scoring network plus the logic that applies its templates.

    The network and the ``idx2rules`` table come from ``load_model`` (when
    ``polymer`` or ``use_load_model`` is set) or ``load_parallel_model``.
    ``run`` scores templates for a molecule and expands it in one of three
    modes: polymer, stability-check, or plain top-k.
    """

    def __init__(self,state_path, template_path, device=-1, fp_dim=2048, polymer=False, use_load_model=False):
        """Load the network and template table.

        Args:
            state_path: checkpoint path forwarded to the loader.
            template_path: template-rule file forwarded to the loader.
            device: device index; negative values leave the net where the
                loader put it, non-negative values call ``net.to(device)``.
            fp_dim: fingerprint size passed to the loader and ``preprocess``.
            polymer: enables the polymer expansion mode in ``run``.
            use_load_model: force ``load_model`` even when not in polymer mode.
        """
        super(MLPModel, self).__init__()
        self.fp_dim = fp_dim
        if polymer or use_load_model:
            self.net, self.idx2rules = load_model(state_path, template_path, fp_dim)
        else:
            self.net, self.idx2rules = load_parallel_model(state_path, template_path, fp_dim)
        self.net.eval()
        self.device = device
        self.polymer = polymer
        if device >= 0:
            self.net.to(device)

    def run(self, x, topk=10, n_cores=1, enum=True, stability_check=False):
        """Propose reactants for molecule ``x``.

        Args:
            x: SMILES of the target molecule.
            topk: number of best-scoring templates to try (plain mode, or
                polymer mode with ``enum=False``).
            n_cores: worker processes used by the polymer/stability modes.
            enum: in polymer mode, try every template instead of the top-k.
            stability_check: when not in polymer mode, try every template and
                keep only candidates accepted by ``is_stable_wrapper``; also
                switches score normalisation to the top-100 probability mass.

        Returns:
            ``{'reactants': [...], 'scores': [...], 'template': [...]}``
            sorted by descending score, or None when nothing was produced
            (non-polymer input longer than 300 characters, unparsable
            SMILES in the pool-based modes, or no surviving candidate).
        """
        if not self.polymer and len(x) > 300:
            return None

        preds = self._predict(x)

        if self.polymer:
            expanded = self._expand_polymer(x, preds, topk, n_cores, enum)
        elif stability_check:
            expanded = self._expand_stable(x, preds, n_cores)
        else:
            expanded = self._expand_topk(x, preds, topk)

        if expanded is None:
            return None
        reactants, scores, templates = expanded
        if len(reactants) == 0:
            return None

        # Group identical reactant sets: multi-component strings are keyed by
        # their sorted components so component order does not matter.
        reactants_d = defaultdict(list)
        for r, s, t in zip(reactants, scores, templates):
            if '.' in r:
                key = '.'.join(sorted(r.strip().split('.')))
            else:
                key = r
            reactants_d[key].append((s, t))

        reactants, scores, templates = merge(reactants_d)
        if stability_check:
            probs, _ = self._topk(preds, 100)
            total = probs.sum().item()
        else:
            total = sum(scores)
        scores = [s / total for s in scores]
        return {'reactants':reactants,
                'scores' : scores,
                'template' : templates}

    def _predict(self, x):
        """Return the softmax template distribution for ``x`` as a CPU tensor."""
        arr = preprocess(x, self.fp_dim)
        arr = np.reshape(arr, [-1, arr.shape[0]])
        arr = torch.tensor(arr, dtype=torch.float32)
        if self.device >= 0:
            arr = arr.to(self.device)
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            preds = F.softmax(self.net(arr), dim=1)
        if self.device >= 0:
            preds = preds.cpu()
        return preds

    @staticmethod
    def _topk(preds, k):
        """``torch.topk`` with ``k`` clamped to the number of scored templates."""
        return torch.topk(preds, k=min(k, preds.shape[1]))

    @staticmethod
    def _apply_templates(pool, rule_k, x):
        """Run every ``(id, rule)`` of ``rule_k`` on ``x``; return ``(outcome, id)`` pairs."""
        print("trying all %d templates" % len(rule_k))
        tasks = ((rule_id, rule, x) for rule_id, rule in rule_k)
        pbar = tqdm(pool.imap_unordered(runtextwrapper, tasks),
                    total=len(rule_k))

        suc_cnt = 0
        hits = []
        for outcomes, rule_id in pbar:
            if not outcomes:
                # None (template failed) or an empty list (no match).
                continue
            suc_cnt += len(outcomes)
            hits.extend((outcome, rule_id) for outcome in sorted(outcomes))
            pbar.set_description('%d' % (suc_cnt))
        return hits

    @staticmethod
    def _filter_candidates(pool, worker, tasks):
        """Map ``worker`` over ``tasks``; keep ``(reactant, id)`` results whose reactant is not None."""
        print("filtering all %d fitted templates" % len(tasks))
        pbar = tqdm(pool.imap_unordered(worker, tasks),
                    total=len(tasks))
        kept = []
        for reactant, rule_id in pbar:
            if reactant is not None:
                kept.append((reactant, rule_id))
            pbar.set_description('%d' % (len(kept)))
        return kept

    def _expand_polymer(self, x, preds, topk, n_cores, enum):
        """Polymer mode: keep outcomes accepted by ``make_valid_polymerization_wrapper``."""
        mol = Chem.MolFromSmiles(x)
        if mol is None:
            return None
        # NOTE(review): true division yields a float on Python 3; an odd atom
        # count gives a fractional value -- confirm that this is intended.
        num_atoms = len(mol.GetAtoms()) / 2

        if enum:
            rule_k = list(self.idx2rules.items())
        else:
            _, idx = self._topk(preds, topk)
            rule_k = [(rule_id, self.idx2rules[rule_id])
                      for rule_id in idx[0].numpy().tolist()]

        pool = multiprocessing.Pool(n_cores)
        try:
            hits = self._apply_templates(pool, rule_k, x)
            kept = self._filter_candidates(
                pool, make_valid_polymerization_wrapper,
                [(hit, num_atoms) for hit in hits])
        finally:
            # All results are consumed by now; release the workers even when
            # an exception escaped (mirrors ``with Pool(...)`` semantics).
            pool.terminate()
            pool.join()

        reactants = [reactant for reactant, _ in kept]
        scores = [1.] * len(kept)
        templates = [rule_id for _, rule_id in kept]
        return reactants, scores, templates

    def _expand_stable(self, x, preds, n_cores):
        """Stability mode: keep outcomes accepted by ``is_stable_wrapper``."""
        if Chem.MolFromSmiles(x) is None:
            return None
        rule_k = list(self.idx2rules.items())

        pool = multiprocessing.Pool(n_cores)
        try:
            hits = self._apply_templates(pool, rule_k, x)
            kept = self._filter_candidates(
                pool, is_stable_wrapper,
                [(outcome, rule_id, x) for outcome, rule_id in hits])
        finally:
            pool.terminate()
            pool.join()

        reactants = [reactant for reactant, _ in kept]
        scores = [preds[0][rule_id].item() for _, rule_id in kept]
        templates = [rule_id for _, rule_id in kept]
        return reactants, scores, templates

    def _expand_topk(self, x, preds, topk):
        """Plain mode: apply the top-k templates with ``rdchiralRunText``."""
        probs, idx = self._topk(preds, topk)
        rule_k = [self.idx2rules[rule_id] for rule_id in idx[0].numpy().tolist()]
        reactants = []
        scores = []
        templates = []
        for i, rule in enumerate(rule_k):
            try:
                out1, mapped_outcomes = rdchiralRunText(rule, x, return_mapped=True)
            except Exception:
                # Best effort: a template that fails to apply is skipped.
                continue
            if len(out1) == 0:
                continue
            out1 = sorted(out1)
            # The template's probability is shared equally by its outcomes.
            share = probs[0][i].item() / len(out1)
            for reactant in out1:
                reactants.append(reactant)
                scores.append(share)
                templates.append(rule)
        return reactants, scores, templates



if __name__ == '__main__':
    # Command-line smoke test: load a model and print the proposals for one
    # hard-coded molecule.
    import argparse
    from pprint import pprint

    cli = argparse.ArgumentParser(description="Policies for retrosynthesis Planner")
    cli.add_argument('--template_rule_path', default='../data/uspto_all/template_rules_1.dat',
                     type=str, help='Specify the path of all template rules.')
    cli.add_argument('--model_path', default='../model/saved_rollout_state_1_2048.ckpt',
                     type=str, help='specify where the trained model is')
    opts = cli.parse_args()

    model = MLPModel(opts.model_path, opts.template_rule_path, device=-1)

    # Other molecules that have been tried here:
    # x = '[F-:1]'
    # x = '[CH2:10]([S:14]([O:3][CH2:2][CH2:1][Cl:4])(=[O:16])=[O:15])[CH:11]([CH3:13])[CH3:12]'
    # x = '[S:3](=[O:4])(=[O:5])([O:6][CH2:7][CH:8]([CH2:9][CH2:10][CH2:11][CH3:12])[CH2:13][CH3:14])[OH:15]'
    # x = 'OCC(=O)OCCCO'
    # x = 'CC(=O)NC1=CC=C(O)C=C1'
    # x = 'S=C(Cl)(Cl)'
    # x = "NCCNC(=O)c1ccc(/C=N/Nc2ncnc3c2cnn3-c2ccccc2)cc1"
    # x = 'CCOC(=O)c1cnc2c(F)cc(Br)cc2c1O'
    # x = 'COc1cc2ncnc(Oc3cc(NC(=O)Nc4cc(C(C)(C)C(F)(F)F)on4)ccc3F)c2cc1OC'
    # x = 'COC(=O)c1ccc(CN2C(=O)C3(COc4cc5c(cc43)OCCO5)c3ccccc32)o1'
    # x = 'CO[C@H](CC(=O)O)C(=O)O'
    # x = 'O=C(O)c1cc(OCC(F)(F)F)c(C2CC2)cn1'
    x = 'O=C1Nc2ccccc2C12COc1cc3c(cc12)OCCO3'
    pprint(model.run(x, 10))
