import os
import numpy as np
import torch
from typing import List, Any
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.AllChem import AlignMol

from data.structures import PackedMolGraph
from data.load_data import dft_mol_positions as get_positions, rdkit_mol_positions as rdkit_positions
from train.recover_conf import recover_conf, recover_multi_conf
from .derive.plt_derive import log_pos_json, plt_derive


def align_to(smiles, source, target):
    """Rigidly superimpose ``source`` coordinates onto ``target`` coordinates.

    Both coordinate sets are loaded as conformers of the molecule parsed from
    ``smiles``. ``AlignMol`` then moves the source conformer onto the target one.

    :param smiles: SMILES string of the molecule. Its parsed molecule is
        expected to have one atom per row of ``source`` / ``target``.
        NOTE(review): ``Chem.MolFromSmiles`` removes explicit hydrogens by
        default; confirm the callers' position arrays match that atom count.
    :param source: sequence of (x, y, z) positions to be moved.
    :param target: sequence of (x, y, z) reference positions.
    :return: the aligned source positions as returned by ``get_positions``.
        ``source`` itself is returned unchanged if the SMILES cannot be parsed
        or RDKit raises ``ValueError`` while loading / aligning.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparsable SMILES: same best-effort fallback as the ValueError path.
        return source
    num_atoms = mol.GetNumAtoms()

    def _mol_with_positions(positions):
        # Fill a blank conformer directly instead of running EmbedMolecule.
        # Embedding is slow and stochastic, and it can fail (leaving no
        # conformer, so alignment was silently skipped). Its coordinates were
        # overwritten below anyway.
        copy = Chem.Mol(mol)
        copy.AddConformer(Chem.Conformer(num_atoms), assignId=True)
        conf = copy.GetConformer()
        for idx, p in enumerate(positions):
            conf.SetAtomPosition(idx, [float(p[0]), float(p[1]), float(p[2])])
        return copy

    try:
        source_mol = _mol_with_positions(source)
        target_mol = _mol_with_positions(target)
        # Moves source_mol's conformer in place onto target_mol's.
        AlignMol(source_mol, target_mol)
    except ValueError:
        return source
    return get_positions(source_mol)


def vis_2d_graph(list_mol: List[Any]):
    """Write a 2-D depiction PNG for every molecule.

    Files go to ``visualize/derive/smiles/m{i}.png``, relative to the current
    working directory. ``i`` is the molecule's index in ``list_mol``.

    :param list_mol: RDKit molecules to draw.
    """
    d = 'visualize/derive/smiles'
    # makedirs creates the missing parents ('visualize/derive') as well.
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(d, exist_ok=True)
    for i, mol in enumerate(list_mol):
        Draw.MolToFile(mol, f'{d}/m{i}.png')


def vis_real_rdkit_with_mols(list_mol: List[Any]):
    """Log and plot the dataset and RDKit conformations of each molecule.

    For molecule ``i``:
    - the positions from ``get_positions`` are aligned onto those from
      ``rdkit_positions`` via ``align_to``;
    - the aligned set is written with ``log_pos_json`` / ``plt_derive`` under
      ``m{i}_real``;
    - the RDKit set is written the same way under ``m{i}_rdkit``.

    No momenta are passed.
    """
    print('For real:')
    for idx, molecule in enumerate(list_mol):
        smi = Chem.MolToSmiles(molecule)
        print(f'\tProcessing {smi}...')
        reference = get_positions(molecule)
        generated = rdkit_positions(molecule)
        aligned = align_to(smi, reference, generated)
        # Emit the aligned real positions first, then the RDKit ones.
        for positions, suffix in ((aligned, 'real'), (generated, 'rdkit')):
            tag = f'm{idx}_{suffix}'
            log_pos_json(positions, None, molecule, smi, tag)
            plt_derive(positions, None, molecule, tag)


def vis_derive_with_mols(dataset_name: str, list_mol: List[Any],
                         special_config: dict, use_cuda=False, final_only=False):
    """Log and plot the states returned by a recovered model's ``get_derive_states``.

    The model is restored with ``recover_conf`` using the token
    ``'{GENERATE_TYPE}-{DERIVE_TYPE}-{COMPARE_TYPE}'`` built from
    ``special_config``. For each molecule, the position / momentum states in
    ``list_pos_ftr`` / ``list_mom_ftr`` are written with ``log_pos_json`` and
    ``plt_derive``.

    :param dataset_name: dataset identifier forwarded to ``recover_conf``.
    :param list_mol: RDKit molecules to process.
    :param special_config: config dict; must contain ``GENERATE_TYPE``,
        ``DERIVE_TYPE`` and ``COMPARE_TYPE``.
    :param use_cuda: forwarded to ``recover_conf``.
    :param final_only: if True, only the last position state is emitted, as
        ``m{i}_{compare}`` and without momentum. Otherwise every state ``j`` is
        emitted as ``m{i}_{compare}_derive_{j}``.
    """
    generate = special_config['GENERATE_TYPE']
    derive = special_config['DERIVE_TYPE']
    compare = special_config['COMPARE_TYPE']
    token = f'{generate}-{derive}-{compare}'
    model = recover_conf(special_config, dataset_name, token, use_cuda=use_cuda)
    for i, mol in enumerate(list_mol):
        smiles = Chem.MolToSmiles(mol)
        print(f'\tProcessing {smiles}...')
        pmg = PackedMolGraph([mol])
        # NOTE(review): inputs are built as CPU tensors; presumably the model
        # handles device placement when use_cuda=True. Confirm.
        return_dict = model.get_derive_states(
            atom_ftr=pmg.atom_ftr,
            bond_ftr=pmg.bond_ftr,
            mask_matrices=pmg.mask_matrices,
            target_pos_ftr=torch.FloatTensor(get_positions(mol)),
            rdkit_pos_ftr=torch.FloatTensor(rdkit_positions(mol))
        )
        list_pos, list_mom = return_dict['list_pos_ftr'], return_dict['list_mom_ftr']
        # .cpu() is a no-op for CPU tensors. It is required before .numpy()
        # when the model ran on a GPU (use_cuda=True).
        if final_only:
            title = f'm{i}_{compare}'
            pos = list_pos[-1].detach().cpu().numpy()
            log_pos_json(pos, None, mol, smiles, title)
            plt_derive(pos, None, mol, title)
        else:
            for j, (pos, mom) in enumerate(zip(list_pos, list_mom)):
                title = f'm{i}_{compare}_derive_{j}'
                pos = pos.detach().cpu().numpy()
                mom = mom.detach().cpu().numpy()
                log_pos_json(pos, mom, mol, smiles, title)
                plt_derive(pos, mom, mol, title)


def vis_derive_multi_with_mols(dataset_name: str, list_mol: List[Any],
                               special_config: dict, use_cuda=False):
    """Log and plot every state from a model restored with ``recover_multi_conf``.

    The token ``'{GENERATE_TYPE}-{DERIVE_TYPE}-{COMPARE_TYPE}'`` is built from
    ``special_config``. For each molecule ``i``, every position / momentum state
    ``j`` returned by ``get_derive_states`` is written with ``log_pos_json`` and
    ``plt_derive``, titled ``m{i}_{dataset_name}_{token}_derive_{j}``.

    :param dataset_name: dataset identifier; forwarded to
        ``recover_multi_conf`` and embedded in the output titles.
    :param list_mol: RDKit molecules to process.
    :param special_config: config dict; must contain ``GENERATE_TYPE``,
        ``DERIVE_TYPE`` and ``COMPARE_TYPE``.
    :param use_cuda: forwarded to ``recover_multi_conf``.
    """
    generate = special_config['GENERATE_TYPE']
    derive = special_config['DERIVE_TYPE']
    compare = special_config['COMPARE_TYPE']
    token = f'{generate}-{derive}-{compare}'
    model = recover_multi_conf(special_config, dataset_name, token, use_cuda=use_cuda)
    for i, mol in enumerate(list_mol):
        smiles = Chem.MolToSmiles(mol)
        print(f'\tProcessing {smiles}...')
        pmg = PackedMolGraph([mol])
        # NOTE(review): inputs are built as CPU tensors; presumably the model
        # handles device placement when use_cuda=True. Confirm.
        return_dict = model.get_derive_states(
            atom_ftr=pmg.atom_ftr,
            bond_ftr=pmg.bond_ftr,
            mask_matrices=pmg.mask_matrices,
            target_pos_ftr=torch.FloatTensor(get_positions(mol)),
            rdkit_pos_ftr=torch.FloatTensor(rdkit_positions(mol)),
        )
        list_pos, list_mom = return_dict['list_pos_ftr'], return_dict['list_mom_ftr']
        for j, (pos, mom) in enumerate(zip(list_pos, list_mom)):
            title = f'm{i}_{dataset_name}_{token}_derive_{j}'
            # .cpu() is a no-op for CPU tensors. It is required before
            # .numpy() when the model ran on a GPU (use_cuda=True).
            pos = pos.detach().cpu().numpy()
            mom = mom.detach().cpu().numpy()
            log_pos_json(pos, mom, mol, smiles, title)
            plt_derive(pos, mom, mol, title)
