import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.AllChem import AlignMol, GetBestRMS

from .Abstract import *


def compare_conf(smiles: str, source_conf: np.ndarray, target_conf: np.ndarray) -> float:
    """Return the best-alignment RMSD between two conformations of the same molecule.

    Each coordinate set is attached as the only conformer of a fresh molecule parsed from
    ``smiles``. The pair is then scored with RDKit ``GetBestRMS``. It aligns the source (probe)
    onto the target (reference), taking symmetry-equivalent atom mappings into account.

    :param smiles: SMILES string of the molecule.
    :param source_conf: per-atom coordinates indexable as ``conf[i][0..2]``; row ``i`` is assigned
        to atom ``i`` of the molecule parsed from ``smiles``.
    :param target_conf: reference coordinates with the same layout.
    :return: the RMSD, or ``-1`` if the SMILES cannot be parsed or RDKit raises ``ValueError``.
    """
    def _mol_with_coords(coords) -> Chem.Mol:
        # Build the conformer directly rather than running EmbedMolecule and overwriting its output.
        # This is much cheaper, deterministic, and cannot fail for molecules RDKit is unable to embed.
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f'Cannot parse SMILES: {smiles}')
        conformer = Chem.Conformer(mol.GetNumAtoms())
        for i, pos in enumerate(coords):
            conformer.SetAtomPosition(i, [float(pos[0]), float(pos[1]), float(pos[2])])
        mol.AddConformer(conformer, assignId=True)
        return mol

    try:
        source_mol = _mol_with_coords(source_conf)
        target_mol = _mol_with_coords(target_conf)
        return GetBestRMS(source_mol, target_mol)
    except ValueError:
        return -1


class RMSDComparer(Comparer):
    """Mean best-alignment RMSD between predicted and reference conformations of a batch of molecules.

    Molecule ``i`` of the batch is ``kwargs['smiles_set'][i]``. Its atoms are the rows of the stacked
    conformation tensors where ``mask_matrices.mol_vertex_w[i, :] > 0``.
    """

    def __init__(self, *args, **kwargs):
        super(RMSDComparer, self).__init__(*args, **kwargs)

    def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices, target_conf: torch.Tensor = None,
               *args, **kwargs) -> torch.Tensor:
        """Return ``conf`` unchanged; the RMSD is computed directly on coordinates."""
        return conf

    def compare(self, source_conf: torch.Tensor, target_conf: torch.Tensor, mask_matrices: MaskMatrices, *args,
                **kwargs) -> torch.Tensor:
        """Return the mean RMSD over all molecules whose comparison succeeded.

        :param source_conf: stacked coordinates, indexed ``[row, :]`` by a boolean mask built from
            ``mask_matrices.mol_vertex_w``.
        :param target_conf: reference coordinates with the same layout as ``source_conf``.
        :param mask_matrices: must provide ``mol_vertex_w``; row ``i`` selects molecule ``i``'s rows.
        :param kwargs: must contain ``smiles_set``, one SMILES per molecule.
        :return: float32 scalar tensor (moved to CUDA when ``self.use_cuda``). It is ``0.`` when no
            molecule could be compared.
        """
        smiles_set = kwargs['smiles_set']
        mol_vertex_w = mask_matrices.mol_vertex_w
        list_rms = []
        for i, smiles in enumerate(smiles_set):
            atom_mask = mol_vertex_w[i, :] > 0
            p0 = source_conf[atom_mask, :]
            q0 = target_conf[atom_mask, :]
            rms = compare_conf(smiles, p0.cpu().detach(), q0.cpu().detach())
            # compare_conf signals failure with -1; such molecules are left out of the mean.
            if rms > -1e-6:
                list_rms.append(rms)
        if list_rms:
            loss = torch.tensor(list_rms, dtype=torch.float32).mean()
        else:
            # Explicit float32 so both branches return the same dtype (torch.tensor(0) would be int64).
            loss = torch.tensor(0., dtype=torch.float32)
        if self.use_cuda:
            loss = loss.cuda()
        return loss


def compare_list_conf(smiles: str, list_source_conf: List[np.ndarray], list_target_conf: List[np.ndarray],
                      cov_threshold=-1.) -> float:
    """Compare two ensembles of conformations of the same molecule.

    For every target conformation, the lowest best-alignment RMSD (RDKit ``GetBestRMS``) to any
    source conformation is computed. These per-target minima are then reduced as follows:

    * if ``cov_threshold > -1e-6``: the fraction of targets whose minimum RMSD is below
      ``cov_threshold + 1e-6`` (coverage);
    * otherwise: the mean of the minima (matching score).

    :param smiles: SMILES string of the molecule all conformations belong to.
    :param list_source_conf: source conformations; row ``i`` of each is assigned to atom ``i``.
    :param list_target_conf: reference conformations with the same layout.
    :param cov_threshold: RMSD threshold for coverage. The default ``-1.`` selects the mean mode.
    :return: coverage fraction or mean minimum RMSD. Returns ``-1.`` if the SMILES cannot be parsed,
        RDKit raises ``ValueError``, or ``list_target_conf`` is empty. If ``list_source_conf`` is empty,
        every target's minimum is ``1e8``.
    """
    no_match_rms = 1e8  # per-target minimum used when there is no source conformation

    def _mol_with_coords(coords) -> Chem.Mol:
        # Build the conformer directly rather than running EmbedMolecule and overwriting its output.
        # This is much cheaper, deterministic, and cannot fail for molecules RDKit is unable to embed.
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f'Cannot parse SMILES: {smiles}')
        conformer = Chem.Conformer(mol.GetNumAtoms())
        for i, pos in enumerate(coords):
            conformer.SetAtomPosition(i, [float(pos[0]), float(pos[1]), float(pos[2])])
        mol.AddConformer(conformer, assignId=True)
        return mol

    if len(list_target_conf) == 0:
        # No reference conformations: both reductions would divide by zero.
        return -1.
    try:
        list_source_mol = [_mol_with_coords(conf) for conf in list_source_conf]
        list_target_mol = [_mol_with_coords(conf) for conf in list_target_conf]
        list_rms = [
            min((GetBestRMS(source_mol, target_mol) for source_mol in list_source_mol), default=no_match_rms)
            for target_mol in list_target_mol
        ]
    except ValueError:
        return -1.

    if cov_threshold > -1e-6:
        n_covered = sum(1. for rms in list_rms if rms < cov_threshold + 1e-6)
        return n_covered / len(list_rms)
    return sum(list_rms) / len(list_rms)


class COVComparer(MultiComparer):
    """Coverage score between two conformer ensembles of one molecule.

    Delegates to ``compare_list_conf`` with ``cov_threshold=self.threshold``. The result is the
    fraction of target conformations that have a source conformation within the RMSD threshold.
    """

    def __init__(self, threshold: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # RMSD cut-off passed to compare_list_conf as cov_threshold.
        self.threshold = threshold

    def compare(self, list_source_conf: List[torch.Tensor], list_target_conf: List[torch.Tensor],
                mask_matrices: MaskMatrices, *args, **kwargs) -> torch.Tensor:
        """Return the coverage as a scalar tensor.

        :param kwargs: must contain ``smiles``, the molecule's SMILES string.
        :return: the coverage, or ``1.`` when ``compare_list_conf`` reports failure (a negative value).
            Moved to CUDA when ``self.use_cuda`` (an attribute inherited from the base class, not set here).
        """
        smiles = kwargs['smiles']
        sources = [conf.cpu().detach() for conf in list_source_conf]
        targets = [conf.cpu().detach() for conf in list_target_conf]
        coverage = compare_list_conf(smiles, sources, targets, cov_threshold=self.threshold)
        result = torch.tensor(1. if coverage < -1e-6 else coverage)
        return result.cuda() if self.use_cuda else result


class MATComparer(MultiComparer):
    """Matching score between two conformer ensembles of one molecule.

    Delegates to ``compare_list_conf`` without a coverage threshold. The result is the mean, over
    target conformations, of the lowest RMSD to any source conformation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compare(self, list_source_conf: List[torch.Tensor], list_target_conf: List[torch.Tensor],
                mask_matrices: MaskMatrices, *args, **kwargs) -> torch.Tensor:
        """Return the matching score as a scalar tensor.

        :param kwargs: must contain ``smiles``, the molecule's SMILES string.
        :return: the score, or ``0.`` when ``compare_list_conf`` reports failure (a negative value).
            Moved to CUDA when ``self.use_cuda`` (an attribute inherited from the base class, not set here).
        """
        smiles = kwargs['smiles']
        sources = [conf.cpu().detach() for conf in list_source_conf]
        targets = [conf.cpu().detach() for conf in list_target_conf]
        score = compare_list_conf(smiles, sources, targets)
        result = torch.tensor(0. if score < -1e-6 else score)
        return result.cuda() if self.use_cuda else result
