import json
from copy import deepcopy
from collections import defaultdict
from rdkit import Chem
from rdkit.Chem import rdqueries
from common import args, prepare_test_polymers, prepare_mlp
from polymer_lib import get_neighbor_atom, get_atom_by_atom_map_num, remove_atom_map
from rexgen_direct.api import ReactionPrediction
from functools import partial


def same_mol(smi1, smi2):
    """Return True if two SMILES strings encode the same molecule, ignoring stereochemistry.

    Both strings are parsed and written back as canonical non-isomeric SMILES,
    which are then compared as text. If either string cannot be parsed by
    RDKit, returns False instead of raising: an unparseable structure cannot
    be shown to match anything.
    """
    mol1 = Chem.MolFromSmiles(smi1)
    mol2 = Chem.MolFromSmiles(smi2)
    if mol1 is None or mol2 is None:
        return False
    # isomericSmiles=False is what Chem.CanonSmiles(..., useChiral=0) used.
    return Chem.MolToSmiles(mol1, isomericSmiles=False) == Chem.MolToSmiles(mol2, isomericSmiles=False)

def find_AT_and_neighbors(mol):
    """Find the astatine ([At], atomic number 85) placeholder atoms in ``mol``.

    Returns:
        ``(at_atoms, neighbors)``. ``at_atoms`` is the sequence returned by
        ``mol.GetAtomsMatchingQuery`` for atomic number 85.
        ``neighbors[i]`` is ``get_neighbor_atom(at_atoms[i])``, presumably
        the atom bonded to that placeholder (see polymer_lib).
    """
    astatine_query = rdqueries.AtomNumEqualsQueryAtom(85)
    placeholders = mol.GetAtomsMatchingQuery(astatine_query)
    attached = list(map(get_neighbor_atom, placeholders))
    return placeholders, attached

def remove_AT(mol):
    """Delete every astatine ([At], atomic number 85) atom from ``mol`` in place.

    ``mol`` must be editable, because ``RemoveAtom`` is called on it. The
    query is re-run after each deletion so that the next atom index is always
    read from the already-modified molecule.
    """
    astatine_query = rdqueries.AtomNumEqualsQueryAtom(85)
    remaining = mol.GetAtomsMatchingQuery(astatine_query)
    while len(remaining) > 0:
        mol.RemoveAtom(remaining[0].GetIdx())
        remaining = mol.GetAtomsMatchingQuery(astatine_query)

def match_substructure(repeat_unit, full_unit):
    """Decide whether ``full_unit`` is a plausible unpolymerised form of ``repeat_unit``.

    ``repeat_unit`` must contain exactly two [At] placeholders. The tests are
    tried in this order:

    * If the two atoms carrying the placeholders are joined by a single bond,
      raise it to a double bond, drop both [At] atoms and compare the result
      with ``full_unit`` (stereo-insensitive). A match gives
      ``'single_to_double'``.
    * If those atoms are not bonded, add a single bond between them, drop
      both [At] atoms and compare. A match gives ``'none_to_single'``.
    * Otherwise, or if the comparison fails, strip the ``([At])`` branches
      from ``repeat_unit``, use the result as a SMARTS pattern and search for
      it in ``full_unit``. A hit gives ``'normal'``.

    Returns:
        ``(True, match_type)`` for the first test that succeeds, otherwise
        ``(False, '')``. An unparseable ``full_unit`` (for example a bad
        reaction-predictor candidate) or SMARTS pattern counts as no match
        instead of raising.
    """
    full_mol = Chem.MolFromSmiles(full_unit)
    if full_mol is None:
        return False, ''

    part_mol = Chem.MolFromSmiles(repeat_unit)
    at_atoms, neighbors = find_AT_and_neighbors(part_mol)
    assert len(at_atoms) == 2
    end0, end1 = neighbors[0].GetIdx(), neighbors[1].GetIdx()
    bond = part_mol.GetBondBetweenAtoms(end0, end1)

    if bond is None:
        new_order, match_type = Chem.BondType.SINGLE, 'none_to_single'
    elif bond.GetBondType() == Chem.BondType.SINGLE:
        new_order, match_type = Chem.BondType.DOUBLE, 'single_to_double'
    else:
        new_order, match_type = None, ''

    # If both placeholders sit on the same atom, the closure would need a
    # self-bond, which RDKit refuses to add. Skip straight to the SMARTS test.
    if new_order is not None and end0 != end1:
        mw = Chem.RWMol(part_mol)
        if bond is not None:
            mw.RemoveBond(end0, end1)
        mw.AddBond(end0, end1, new_order)
        # Delete the highest index first so the other placeholder's index stays valid.
        for at_idx in sorted((atom.GetIdx() for atom in at_atoms), reverse=True):
            mw.RemoveAtom(at_idx)
        if same_mol(Chem.MolToSmiles(mw), full_unit):
            return True, match_type

    # NOTE: only parenthesised "([At])" branches are stripped. An [At] written
    # any other way stays in the pattern and then has to match astatine.
    patt = Chem.MolFromSmarts(repeat_unit.replace('([At])', ''))
    if patt is not None and full_mol.HasSubstructMatch(patt):
        return True, 'normal'

    return False, ''

def recover_full_unit():
    """Pair every test repeat unit with a "full unit" (its unpolymerised form).

    Sources, checked in this order:
      1. Two specific repeat units whose full units are written out by hand
         (recorded with 'from': 'reaction').
      2. Single-component monomers (no '.' in the monomer SMILES): the
         monomer itself ('from': 'database'). If ``match_substructure``
         rejects it, it is still kept, with type 'multiple_bond_change'.
      3. Multi-component monomers: the first ``ReactionPrediction.predict``
         candidate accepted by ``match_substructure`` ('from': 'reaction').
         If no candidate matches, a message is printed and the repeat unit
         is left out.

    Returns:
        A dict keyed by repeat-unit SMILES. Each value is a dict with the
        keys 'full_unit', 'monomer', 'from' and 'type'.
    """
    repeat_units, monomers, _, _ = prepare_test_polymers(args.test_real_polymers)

    predictor = ReactionPrediction()

    hand_written = {
        'C([At])(=O)c1ccc(cc1)C(=O)Oc1ccc(cc1)C(C)(C)c2ccc(cc2)O([At])':
            'C(O)(=O)c1ccc(cc1)C(=O)Oc1ccc(cc1)C(C)(C)c2ccc(cc2)O',
        'C([At])(=O)c1cc(ccc1)C(=O)Oc1ccc(cc1)C(C)(C)c2ccc(cc2)O([At])':
            'C(O)(=O)c1cc(ccc1)C(=O)Oc1ccc(cc1)C(C)(C)c2ccc(cc2)O',
    }

    def record(full_unit, monomer, source, match_type):
        return {
            'full_unit': full_unit,
            'monomer': monomer,
            'from': source,
            'type': match_type,
        }

    gt_info = dict()

    for idx, repeat_unit in enumerate(repeat_units):
        monomer = monomers[idx]

        if repeat_unit in hand_written:
            full_unit = hand_written[repeat_unit]
            ok, match_type = match_substructure(repeat_unit, full_unit)
            assert ok
            gt_info[repeat_unit] = record(full_unit, monomer, 'reaction', match_type)
            continue

        if '.' not in monomer:
            ok, match_type = match_substructure(repeat_unit, monomer)
            kind = match_type if ok else 'multiple_bond_change'
            gt_info[repeat_unit] = record(monomer, monomer, 'database', kind)
            continue

        ok = False
        first_candidate = None
        for candidate, _prob in predictor.predict(monomer):
            if first_candidate is None:
                first_candidate = candidate
            ok, match_type = match_substructure(repeat_unit, candidate)
            if ok:
                gt_info[repeat_unit] = record(candidate, monomer, 'reaction', match_type)
                break

        if not ok:
            print('Cannot match', repeat_unit, monomer, first_candidate)

    return gt_info

def find_end_neighbor(atom, mapping):
    """Return the first neighbour of ``atom`` whose index does not appear in ``mapping``.

    In this file ``mapping`` is the tuple from ``GetSubstructMatch``, so the
    result is a bonded atom lying outside the matched substructure. Returns
    None when every neighbour's index is present in ``mapping``.
    """
    outside = (nbr for nbr in atom.GetNeighbors() if nbr.GetIdx() not in mapping)
    return next(outside, None)

def copy_neighbor(mol, atom, atom_model, mapping):
    """Make ``atom`` keep only the bonds its pattern counterpart ``atom_model`` has.

    ``mapping[i]`` is the index in ``mol`` of pattern atom ``i``. Every bond
    from ``atom`` to a neighbour that is not the image of one of
    ``atom_model``'s neighbours is removed from ``mol`` in place.
    """
    allowed = {mapping[nbr.GetIdx()] for nbr in atom_model.GetNeighbors()}
    # Collect first, then remove, so the neighbour list is not mutated while it is iterated.
    doomed = [nbr.GetIdx() for nbr in atom.GetNeighbors() if nbr.GetIdx() not in allowed]
    for other_idx in doomed:
        mol.RemoveBond(atom.GetIdx(), other_idx)

def create_double_full_unit(gt_info):
    """Add a two-unit ("double full unit") SMILES to every entry of ``gt_info``.

    Mutates ``gt_info`` in place and returns None. The result is written to
    ``gt_info[repeat_unit]['double_full_unit']`` with atom-map numbers removed
    by ``remove_atom_map``. Prints debug/progress output.

    The construction depends on ``gt_info[repeat_unit]['type']``:

    * ``'normal'``: the [At]-stripped repeat unit (``patt``) is matched inside
      the full unit (``mol``). The "end atoms" are the atoms returned as
      ``neighbors`` by ``find_AT_and_neighbors``.
      - If an end atom has a neighbour outside the match, that bond is cut.
        A second copy of ``patt`` is then spliced in between the end atom
        and that neighbour.
      - Otherwise the full unit must have exactly the pattern's atom count.
        Two copies of it are trimmed so each end atom keeps only the
        pattern's bonds (``copy_neighbor``), and the copies are joined by
        one single bond.
    * ``'single_to_double'``, ``'none_to_single'``, ``'multiple_bond_change'``:
      two copies of ``patt`` are joined by two single bonds (end 0 of each
      copy to end 1 of the other). Assuming each copy is connected, this
      closes a ring through both copies.
    * Any other type fails ``assert False``.
    """
    for repeat_unit in gt_info:
        # Debug output: the entry's type.
        print(gt_info[repeat_unit]['type'])
        if gt_info[repeat_unit]['type'] == 'normal':

            full_unit = gt_info[repeat_unit]['full_unit']
            # print(repeat_unit, full_unit)

            patt = Chem.RWMol(Chem.MolFromSmiles(repeat_unit))
            mol = Chem.RWMol(Chem.MolFromSmiles(full_unit))

            # Record the end atoms before the [At] placeholders are deleted.
            # The atom objects remain usable afterwards and are read via GetIdx().
            at_atoms, neighbors = find_AT_and_neighbors(patt)
            remove_AT(patt)

            # mapping[i] is the index in mol of atom i of the [At]-stripped pattern.
            # NOTE(review): if patt does not match mol, mapping is an empty tuple
            # and the indexing below raises IndexError.
            mapping = mol.GetSubstructMatch(patt)
            print(mol.GetNumAtoms(), patt.GetNumAtoms(), len(mapping))
            # print(repeat_unit, full_unit, mapping)

            # Case A: end atom 1 has a neighbour outside the match. Cut that bond
            # and insert the pattern copy between end atom 1 and that neighbour.
            neighbor1_mol = mol.GetAtomWithIdx(mapping[neighbors[1].GetIdx()])
            end_start = find_end_neighbor(neighbor1_mol, mapping)
            if end_start is not None:
                # Tag the four atoms with map numbers so they can be found again
                # after CombineMols, which builds a new molecule with new indices.
                neighbors[0].SetAtomMapNum(1)
                neighbors[1].SetAtomMapNum(2)
                neighbor1_mol.SetAtomMapNum(3)
                end_start.SetAtomMapNum(4)

                mol.RemoveBond(neighbor1_mol.GetIdx(), end_start.GetIdx())
                mol = Chem.RWMol(Chem.CombineMols(mol, patt))
                neighbors = [get_atom_by_atom_map_num(mol, 1), get_atom_by_atom_map_num(mol, 2)]
                neighbor1_mol = get_atom_by_atom_map_num(mol, 3)
                end_start = get_atom_by_atom_map_num(mol, 4)

                # full-unit end 1 -> copy end 0; copy end 1 -> detached neighbour.
                mol.AddBond(neighbor1_mol.GetIdx(), neighbors[0].GetIdx(), Chem.BondType.SINGLE)
                mol.AddBond(end_start.GetIdx(), neighbors[1].GetIdx(), Chem.BondType.SINGLE)
                print('success!', remove_atom_map(Chem.MolToSmiles(mol)))
                gt_info[repeat_unit]['double_full_unit'] = remove_atom_map(Chem.MolToSmiles(mol))
                continue

            # Case B: same as case A, but on end atom 0 (the copy is inserted
            # with its ends swapped).
            neighbor0_mol = mol.GetAtomWithIdx(mapping[neighbors[0].GetIdx()])
            end_start = find_end_neighbor(neighbor0_mol, mapping)
            if end_start is not None:
                neighbors[0].SetAtomMapNum(1)
                neighbors[1].SetAtomMapNum(2)
                neighbor0_mol.SetAtomMapNum(3)
                end_start.SetAtomMapNum(4)

                mol.RemoveBond(neighbor0_mol.GetIdx(), end_start.GetIdx())
                mol = Chem.RWMol(Chem.CombineMols(mol, patt))
                neighbors = [get_atom_by_atom_map_num(mol, 1), get_atom_by_atom_map_num(mol, 2)]
                neighbor0_mol = get_atom_by_atom_map_num(mol, 3)
                end_start = get_atom_by_atom_map_num(mol, 4)

                # full-unit end 0 -> copy end 1; copy end 0 -> detached neighbour.
                mol.AddBond(neighbor0_mol.GetIdx(), neighbors[1].GetIdx(), Chem.BondType.SINGLE)
                mol.AddBond(end_start.GetIdx(), neighbors[0].GetIdx(), Chem.BondType.SINGLE)
                print('success!', remove_atom_map(Chem.MolToSmiles(mol)))
                gt_info[repeat_unit]['double_full_unit'] = remove_atom_map(Chem.MolToSmiles(mol))
                continue

            # Case C: neither end atom has an unmatched neighbour, so the full unit
            # must consist of exactly the pattern's atoms.
            assert patt.GetNumAtoms() == mol.GetNumAtoms()
            # Second copy, taken before the first copy is edited.
            mol2 = deepcopy(mol)
            # Copy 1: end atom 1 keeps only the bonds it has in the pattern.
            neighbor1_mol = mol.GetAtomWithIdx(mapping[neighbors[1].GetIdx()])
            copy_neighbor(mol, neighbor1_mol, neighbors[1], mapping)
            neighbor1_mol.SetAtomMapNum(1)

            # Copy 2: end atom 0 keeps only the bonds it has in the pattern.
            neighbor0_mol2 = mol2.GetAtomWithIdx(mapping[neighbors[0].GetIdx()])
            copy_neighbor(mol2, neighbor0_mol2, neighbors[0], mapping)
            neighbor0_mol2.SetAtomMapNum(2)

            # Join copy 2's end 0 to copy 1's end 1 with one single bond.
            mol = Chem.RWMol(Chem.CombineMols(mol, mol2))
            neighbor1_mol = get_atom_by_atom_map_num(mol, 1)
            neighbor0_mol2 = get_atom_by_atom_map_num(mol, 2)
            mol.AddBond(neighbor0_mol2.GetIdx(), neighbor1_mol.GetIdx(), Chem.BondType.SINGLE)
            print('success!', remove_atom_map(Chem.MolToSmiles(mol)))
            gt_info[repeat_unit]['double_full_unit'] = remove_atom_map(Chem.MolToSmiles(mol))

        elif gt_info[repeat_unit]['type'] in ['single_to_double', 'none_to_single', 'multiple_bond_change']:
            # full_unit is read but not used in this branch.
            full_unit = gt_info[repeat_unit]['full_unit']

            patt = Chem.RWMol(Chem.MolFromSmiles(repeat_unit))

            at_atoms, neighbors = find_AT_and_neighbors(patt)
            remove_AT(patt)

            # Copy 1 end atoms: map 1 (end 0) and map 2 (end 1).
            neighbors[0].SetAtomMapNum(1)
            neighbors[1].SetAtomMapNum(2)

            # Copy 2 end atoms are renumbered to map 3 (end 0) and map 4 (end 1).
            patt2 = deepcopy(patt)
            get_atom_by_atom_map_num(patt2, 1).SetAtomMapNum(3)
            get_atom_by_atom_map_num(patt2, 2).SetAtomMapNum(4)

            mol = Chem.RWMol(Chem.CombineMols(patt, patt2))
            print(Chem.MolToSmiles(mol))
            # Two bonds between the copies: 1-4 and 2-3.
            mol.AddBond(get_atom_by_atom_map_num(mol, 1).GetIdx(), get_atom_by_atom_map_num(mol, 4).GetIdx(), Chem.BondType.SINGLE)
            mol.AddBond(get_atom_by_atom_map_num(mol, 2).GetIdx(), get_atom_by_atom_map_num(mol, 3).GetIdx(), Chem.BondType.SINGLE)
            print('success!', remove_atom_map(Chem.MolToSmiles(mol)))
            gt_info[repeat_unit]['double_full_unit'] = remove_atom_map(Chem.MolToSmiles(mol))

        else:
            # Unknown entry type.
            assert False

def compute_forward_prob(gt_info):
    """Score each full-unit self-reaction with the forward reaction predictor.

    For every entry, ``ReactionPrediction.compute_prob`` is asked for the
    probability of ``full_unit.full_unit -> double_full_unit``. The value is
    stored as a string under ``'forward_prob'``. Per-``type`` and per-``from``
    counts of entries with a positive probability are printed, followed by
    the total counts.
    """
    hits = defaultdict(int)
    totals = defaultdict(int)
    predictor = ReactionPrediction()

    for info in gt_info.values():
        monomer_pair = '{0}.{0}'.format(info['full_unit'])
        dimer = info['double_full_unit']

        prob = predictor.compute_prob(monomer_pair, dimer)
        info['forward_prob'] = str(prob)
        print(prob, monomer_pair, dimer)

        for label in (info['type'], info['from']):
            if prob > 0:
                hits[label] += 1
            totals[label] += 1

    print(hits, len(gt_info))
    print(totals)

def compute_backward_prob(gt_info):
    """Score each ground-truth self-reaction with the one-step retrosynthesis model.

    The model expands every entry's ``double_full_unit``. The score of the
    proposal equal to ``full_unit.full_unit`` is stored as a string under
    ``'backward_prob'``, or 0 when no proposal matches. Per-``type`` and
    per-``from`` counts of entries with a positive score are printed,
    followed by the total counts.

    A proposal matches if it is the identical string or has the same
    canonical SMILES. ``full_unit`` is not necessarily canonical: it can be
    a hand-written SMILES or the raw database monomer, so a plain string
    comparison can miss an identical molecule. If several proposals match,
    the highest score is used.
    """
    r0_one_step = prepare_mlp(args.mlp_templates_r0, args.mlp_model_dump_r0,
                              -1, polymer=False, use_load_model=True)
    r0_expand_fn = partial(r0_one_step.run,
                           topk=args.expansion_topk, n_cores=args.n_cores)

    def canonical(smiles):
        # Canonical SMILES, or None if ``smiles`` is not a parseable SMILES string.
        if not isinstance(smiles, str):
            return None
        mol = Chem.MolFromSmiles(smiles)
        return None if mol is None else Chem.MolToSmiles(mol)

    cnt = defaultdict(int)
    tot = defaultdict(int)
    for info in gt_info.values():
        full_unit = info['full_unit']
        reactants = '%s.%s' % (full_unit, full_unit)
        product = info['double_full_unit']
        target = canonical(reactants)

        out = r0_expand_fn(product)
        matched_scores = [
            score for cand, score in zip(out['reactants'], out['scores'])
            if cand == reactants or (target is not None and canonical(cand) == target)
        ]
        prob = max(matched_scores) if matched_scores else 0

        info['backward_prob'] = str(prob)

        print(prob, reactants, product)
        if prob > 0:
            cnt[info['type']] += 1
            cnt[info['from']] += 1
        tot[info['type']] += 1
        tot[info['from']] += 1

    print(cnt, len(gt_info))
    print(tot)

def gen_full_unit_addition(repeat_unit):
    """Build the addition-polymer "full unit" SMILES for a repeat unit.

    The two [At] placeholders are removed and the atoms they were attached
    to (as returned by ``find_AT_and_neighbors``) are joined:
      * with a new single bond if they were not bonded, or
      * by raising their existing single bond to a double bond.

    Raises:
        ValueError: the end atoms are already joined by a bond that is not
            single. The original ``assert False`` was stripped under
            ``python -O``, which made the function silently return an
            unmodified SMILES.
    """
    patt = Chem.RWMol(Chem.MolFromSmiles(repeat_unit))

    _, neighbors = find_AT_and_neighbors(patt)
    remove_AT(patt)

    end0, end1 = neighbors[0].GetIdx(), neighbors[1].GetIdx()
    bond = patt.GetBondBetweenAtoms(end0, end1)

    if bond is None:
        patt.AddBond(end0, end1, Chem.BondType.SINGLE)
    elif bond.GetBondType() == Chem.BondType.SINGLE:
        bond.SetBondType(Chem.BondType.DOUBLE)
    else:
        raise ValueError(
            'cannot build an addition full unit from %s: end atoms are joined by a %s bond'
            % (repeat_unit, bond.GetBondType())
        )

    chain_full_unit = Chem.MolToSmiles(patt)

    return chain_full_unit

def compute_chain_growth_prob(gt_info):
    """Score each repeat unit as if it were made by chain-growth (addition) polymerisation.

    For every entry of ``gt_info``, updated in place:
      * The chain double unit joins two [At]-stripped copies of the repeat
        unit with two single bonds (end 0 of each copy to end 1 of the other).
      * The chain full unit is ``gen_full_unit_addition(repeat_unit)``.
      * ``'chain_forward_prob'`` is the forward predictor's probability of
        ``full.full -> double``.
      * ``'chain_backward_prob'`` is the best one-step retrosynthesis score
        among proposals equal to ``full.full``, or 0 if none matches.
        "Equal" means the identical string or the same canonical SMILES.

    Both probabilities are stored as strings, and the updated entry is printed.
    """
    rp = ReactionPrediction()
    r0_one_step = prepare_mlp(args.mlp_templates_r0, args.mlp_model_dump_r0,
                              -1, polymer=False, use_load_model=True)
    r0_expand_fn = partial(r0_one_step.run,
                           topk=args.expansion_topk, n_cores=args.n_cores)

    def canonical(smiles):
        # Canonical SMILES, or None if ``smiles`` is not a parseable SMILES string.
        if not isinstance(smiles, str):
            return None
        mol = Chem.MolFromSmiles(smiles)
        return None if mol is None else Chem.MolToSmiles(mol)

    for repeat_unit in gt_info:
        # gen double unit. End atoms are tagged 1/2 in copy A and 3/4 in copy B
        # so they can be located again after CombineMols.
        patt = Chem.RWMol(Chem.MolFromSmiles(repeat_unit))

        _, neighbors = find_AT_and_neighbors(patt)
        remove_AT(patt)

        neighbors[0].SetAtomMapNum(1)
        neighbors[1].SetAtomMapNum(2)

        patt2 = deepcopy(patt)
        get_atom_by_atom_map_num(patt2, 1).SetAtomMapNum(3)
        get_atom_by_atom_map_num(patt2, 2).SetAtomMapNum(4)

        mol = Chem.RWMol(Chem.CombineMols(patt, patt2))
        mol.AddBond(get_atom_by_atom_map_num(mol, 1).GetIdx(), get_atom_by_atom_map_num(mol, 4).GetIdx(),
                    Chem.BondType.SINGLE)
        mol.AddBond(get_atom_by_atom_map_num(mol, 2).GetIdx(), get_atom_by_atom_map_num(mol, 3).GetIdx(),
                    Chem.BondType.SINGLE)

        chain_double_unit = remove_atom_map(Chem.MolToSmiles(mol))

        # gen full unit (shared with gen_full_unit_addition instead of duplicated here)
        chain_full_unit = gen_full_unit_addition(repeat_unit)

        reactants = '%s.%s' % (chain_full_unit, chain_full_unit)
        product = chain_double_unit
        info = gt_info[repeat_unit]

        # forward prob
        info['chain_forward_prob'] = str(rp.compute_prob(reactants, product))

        # backward prob
        target = canonical(reactants)
        out = r0_expand_fn(product)
        matched_scores = [
            score for cand, score in zip(out['reactants'], out['scores'])
            if cand == reactants or (target is not None and canonical(cand) == target)
        ]
        info['chain_backward_prob'] = str(max(matched_scores) if matched_scores else 0)

        print(info)


if __name__ == '__main__':
    # Earlier pipeline stages, currently disabled. Together they:
    #   1. recover full units,
    #   2. build double units,
    #   3. score them with the forward, backward and chain-growth models,
    #   4. save gt_info.json,
    #   5. split out the 'normal' (condensation) entries.
    # gt_info = recover_full_unit()
    # create_double_full_unit(gt_info)
    #
    # compute_forward_prob(gt_info)
    # compute_backward_prob(gt_info)
    #
    # compute_chain_growth_prob(gt_info)
    # json.dump(gt_info, open(args.dataset_dir + '/gt_info.json', 'w'))

    # gt_info = json.load(open(args.dataset_dir + '/gt_info.json', 'r'))
    #
    # cond_repeat_units = {k: v for k, v in gt_info.items() if gt_info[k]['type'] == 'normal'}
    # json.dump(cond_repeat_units, open(args.dataset_dir + '/condensation_gt_info.json', 'w'))

    # Current stage: keep the condensation entries whose full unit came from a
    # reaction ('from' == 'reaction') and tag each with a group id. Files are
    # opened with ``with`` so the output is flushed and closed deterministically.
    with open(args.dataset_dir + '/condensation_gt_info.json', 'r') as fin:
        gt_info = json.load(fin)
    two_monomer_gt_info = {k: v for k, v in gt_info.items() if v['from'] == 'reaction'}

    # Groups 0-3 get ten entries each, and every later entry lands in group 4.
    # This is the same as a counter that stops increasing at 49, divided by 10.
    for position, info in enumerate(two_monomer_gt_info.values()):
        info['group'] = min(position // 10, 4)

    with open(args.dataset_dir + '/cond2monomer_gt_info.json', 'w') as fout:
        json.dump(two_monomer_gt_info, fout)
