import torch
import tqdm
import json
from typing import List, Any
from rdkit import Chem

from data.structures import PackedMolGraph
from data.load_data import dft_mol_positions as get_positions, rdkit_mol_positions as rdkit_positions
from net.conformation.comparers.RMSD import compare_conf
from train.recover_conf import recover_conf

# Destination of the JSON log that find_good_mol() periodically rewrites while scanning molecules.
LOG_PATH: str = 'visualize/derive/log.json'
# Column order of each logged 'list_rmsd': [RDKit, equiv-trunc, 'adj3', 'kabsch'].
# Index 0 is always the RDKit baseline; the remaining entries follow the order of the
# model configs passed to find_good_mol() (presumably these three -- confirm at the call site).


def _load_models(dataset_name: str, list_special_config: List[dict], use_cuda=False) -> list:
    """Restore one model per config via ``recover_conf``, in config order.

    Each checkpoint token is ``'{GENERATE_TYPE}-{DERIVE_TYPE}-{COMPARE_TYPE}'`` built from the config.
    """
    models = []
    for special_config in list_special_config:
        generate = special_config['GENERATE_TYPE']
        derive = special_config['DERIVE_TYPE']
        compare = special_config['COMPARE_TYPE']
        token = f'{generate}-{derive}-{compare}'
        models.append(recover_conf(special_config, dataset_name, token, use_cuda=use_cuda))
    return models


def _dump_log(log: List[dict]) -> None:
    """Overwrite ``LOG_PATH`` with ``log`` serialised as JSON."""
    with open(LOG_PATH, 'w', encoding='utf-8') as fp:
        json.dump(log, fp)


def find_good_mol(dataset_name: str, list_mol: List[Any], list_special_config: List[dict], use_cuda=False):
    """Find molecules on which the first model clearly beats every other conformation method.

    For each molecule, the RMSD against the DFT target positions is computed for the RDKit
    conformation and for the last position entry returned by each model, giving
    ``list_rmsd = [rdkit, model_0, model_1, ...]``. Molecules with any negative RMSD
    (presumably ``compare_conf``'s failure signal -- confirm) are skipped. The others are
    appended to a log. The log is written to ``LOG_PATH`` every 1000 molecules and once more
    when the function exits, whether it returns normally or raises.

    "Profit" is the smallest margin by which ``model_0`` (``list_rmsd[1]``) beats every other
    entry. A molecule is printed when its profit exceeds 0.4 while ``model_0``'s RMSD is below
    0.5, or when its profit sets a new running maximum.

    :param dataset_name: dataset identifier forwarded to ``recover_conf``.
    :param list_mol: RDKit molecules to evaluate.
    :param list_special_config: model configs, each with ``GENERATE_TYPE``, ``DERIVE_TYPE`` and
        ``COMPARE_TYPE`` keys; the first one is the model being assessed.
    :param use_cuda: forwarded to ``recover_conf``.
    :raises ValueError: if ``list_special_config`` is empty (there is no model to assess).
    """
    if not list_special_config:
        raise ValueError('list_special_config must contain at least one model config')
    n_mol = len(list_mol)
    models = _load_models(dataset_name, list_special_config, use_cuda=use_cuda)

    log = []
    best_profit = 0
    try:
        for i, mol in enumerate(list_mol):
            if i % 1000 == 999:
                print(f'\tProcessing: {i + 1}/{n_mol}')
                _dump_log(log)  # periodic checkpoint so long runs leave partial results
            pmg = PackedMolGraph([mol])
            smiles = Chem.MolToSmiles(mol)
            target_pos = get_positions(mol)
            rdkit_pos = rdkit_positions(mol)

            # Index 0 is the RDKit baseline; index k + 1 belongs to models[k].
            list_rmsd = [compare_conf(smiles, rdkit_pos, target_pos)]
            for model in models:
                return_dict = model.get_derive_states(
                    atom_ftr=pmg.atom_ftr,
                    bond_ftr=pmg.bond_ftr,
                    mask_matrices=pmg.mask_matrices,
                    target_pos_ftr=torch.FloatTensor(target_pos),
                    rdkit_pos_ftr=torch.FloatTensor(rdkit_pos)
                )
                # Take the last entry of 'list_pos_ftr' (presumably the final positions).
                # .cpu() is required for CUDA tensors and is a no-op on CPU tensors.
                pos = return_dict['list_pos_ftr'][-1].detach().cpu().numpy()
                list_rmsd.append(compare_conf(smiles, pos, target_pos))
            if min(list_rmsd) < 0:
                continue
            log.append({
                'no': i,
                'smiles': smiles,
                'list_rmsd': list_rmsd
            })
            ours = list_rmsd[1]
            # Smallest margin by which model_0 beats every other method (RDKit and remaining models).
            profit = min(rmsd - ours for j, rmsd in enumerate(list_rmsd) if j != 1)
            if (profit > 0.4 and ours < 0.5) or profit > best_profit:
                best_profit = max(profit, best_profit)
                rmsd_text = ', '.join(f'{rmsd:.3f}' for rmsd in list_rmsd)
                print(f'\t\t{profit:.3f} #{i}-{smiles}: {rmsd_text}')
    finally:
        # Persist everything collected since the last checkpoint, including on error/interrupt.
        _dump_log(log)
