import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List, Dict, Any

from data.structures import MaskMatrices
from data.encode import get_massive_from_atom_features
from data.utils import optimize_pos
from .generators import *
from .derivators import *
from .comparers import *
from net.utils import assemble_hierarchical_losses
from net.utils.components import GraphIsomorphismNetwork


class NaiveInitializer(nn.Module):
    """Embed raw atom and bond features with one linear layer + tanh each.

    Also derives a ``massive`` tensor from the atom features through
    ``get_massive_from_atom_features`` and moves it to CUDA when ``use_cuda``.
    """

    def __init__(self, atom_dim: int, bond_dim: int, hv_dim: int, he_dim: int, use_cuda=False):
        super().__init__()
        self.use_cuda = use_cuda
        # Vertex branch: atom_dim -> hv_dim, squashed by tanh.
        self.v_linear = nn.Linear(atom_dim, hv_dim, bias=True)
        self.v_act = nn.Tanh()
        # Edge branch: bond_dim -> he_dim, squashed by tanh.
        self.e_linear = nn.Linear(bond_dim, he_dim, bias=True)
        self.e_act = nn.Tanh()

    def forward(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(hv_ftr, he_ftr, massive)`` for the given atom/bond features."""
        vertex_hidden = self.v_act(self.v_linear(atom_ftr))
        edge_hidden = self.e_act(self.e_linear(bond_ftr))

        # The helper works on numpy arrays, so round-trip through the CPU.
        atom_array = atom_ftr.cpu().numpy()
        mass_tensor = torch.FloatTensor(get_massive_from_atom_features(atom_array))
        mass_tensor = mass_tensor.cuda() if self.use_cuda else mass_tensor

        return vertex_hidden, edge_hidden, mass_tensor


class GINInitializer(nn.Module):
    """Embed atoms and bonds with a 4-layer GraphIsomorphismNetwork, then project them.

    The graph is first extended via ``self.gin.extend_graph``. Only the first
    ``mask_matrices.vertex_edge_w1.shape[1]`` edge rows of the GIN output are kept.
    A ``massive`` tensor is derived from the raw atom features, as in NaiveInitializer.
    """

    def __init__(self, atom_dim: int, bond_dim: int, hv_dim: int, he_dim: int, h_dim: int = 128, use_cuda=False):
        super().__init__()
        self.use_cuda = use_cuda
        self.gin = GraphIsomorphismNetwork(atom_dim=atom_dim, bond_dim=bond_dim, h_dim=h_dim,
                                           n_layer=4, use_cuda=use_cuda)
        self.v_linear = nn.Linear(h_dim, hv_dim, bias=True)
        self.v_act = nn.Tanh()
        # The edge projection expects 3 * h_dim inputs (the width of the GIN edge output fed to it).
        self.e_linear = nn.Linear(3 * h_dim, he_dim, bias=True)
        self.e_act = nn.Tanh()

    def forward(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor, pos: torch.Tensor, mask_matrices: MaskMatrices
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(hv_ftr, he_ftr, massive)`` computed through the GIN encoder."""
        w1 = mask_matrices.vertex_edge_w1
        w2 = mask_matrices.vertex_edge_w2
        ext_bond_ftr, ext_w1, ext_w2 = self.gin.extend_graph(bond_ftr, pos, w1, w2, self.use_cuda)
        node_emb, edge_emb = self.gin.forward(atom_ftr, ext_bond_ftr, ext_w1, ext_w2)

        # Keep only as many edge rows as the original vertex-edge matrix has columns
        # (presumably the extra rows are edges added by extend_graph).
        n_orig_edges = w1.shape[1]
        # NOTE: tanh is applied *before* the linear layer here, the reverse of NaiveInitializer.
        hv_out = self.v_linear(self.v_act(node_emb))
        he_out = self.e_linear(self.e_act(edge_emb[:n_orig_edges, :]))

        atom_array = atom_ftr.cpu().numpy()
        mass_tensor = torch.FloatTensor(get_massive_from_atom_features(atom_array))
        if self.use_cuda:
            mass_tensor = mass_tensor.cuda()
        return hv_out, he_out, mass_tensor


class ConformationBuilder(nn.Module):
    """Build conformations from vertex/edge hidden features.

    A generator (selected by ``generate_type``) produces initial momentum ``p`` and
    position ``q`` features of width ``pos_dim``. If ``derive_type`` is non-empty, a
    derivator refines them for ``step`` iterations with step size ``tau``. The final
    ``q`` is mapped to 3-D by ``self.subspace_mapping``:
      * identity when ``pos_dim == 3``;
      * multiplication by the transposed pseudo-inverse of ``self.generator.linear.weight``
        for the RDKit generator;
      * a learnable bias-free ``nn.Linear(pos_dim, 3)`` otherwise.

    Args:
        hv_dim: width of vertex hidden features.
        he_dim: width of edge hidden features.
        pos_dim: width of the p/q features.
        use_cuda: forwarded to the generator and derivator.
        p_dropout: dropout probability forwarded to the generator and derivator.
        generate_type: one of ``'lstm'``, ``'rdkit'``, ``'cvgae'``.
        derive_type: one of ``'dissipate-hamilton'``, ``'hamilton'``, ``'newton'``,
            ``'langevin'``, or ``''`` to disable derivation.
        step: number of derivation iterations (must be >= 1 when deriving).
        tau: derivation step size.
        direct_mom: if True, the derivator's ``dp`` is used directly as the next momentum
            and ``q`` advances by ``dp / massive * tau``. Otherwise ``p`` and ``q`` advance
            by ``dp * tau`` and ``dq * tau``.

    Raises:
        ValueError: on an unknown ``generate_type``/``derive_type``, or ``step < 1``
            while deriving. (Raised explicitly so the check survives ``python -O``.)
    """

    def __init__(self, hv_dim: int, he_dim: int, pos_dim: int, use_cuda=False, p_dropout=0.0,
                 generate_type='lstm', derive_type='newton', step=4, tau=0.25, direct_mom=True):
        super(ConformationBuilder, self).__init__()
        self.direct_mom = direct_mom
        self.need_derive = derive_type != ''

        if generate_type == 'lstm':
            self.generator = LSTMGenerator(
                hv_dim=hv_dim, he_dim=he_dim, pos_dim=pos_dim,
                need_momentum=self.need_derive,
                use_cuda=use_cuda,
                p_dropout=p_dropout)
        elif generate_type == 'rdkit':
            self.generator = RDKitGenerator(
                hv_dim=hv_dim, he_dim=he_dim, pos_dim=pos_dim,
                need_momentum=self.need_derive,
                use_cuda=use_cuda,
                p_dropout=p_dropout)
        elif generate_type == 'cvgae':
            self.generator = CVGAEGenerator(
                hv_dim=hv_dim, he_dim=he_dim, pos_dim=pos_dim,
                need_momentum=self.need_derive,
                use_cuda=use_cuda,
                p_dropout=p_dropout
            )
        else:
            raise ValueError(f'Undefined generate type {generate_type} in net.conformation.layers.ConformationBuilder')
        assert isinstance(self.generator, Generator)  # internal invariant, not input validation

        if self.need_derive:
            self.tau = tau
            self.step = step
            if step <= 0:
                raise ValueError(f"derive step should be at least 1, but got {step}")
            if derive_type == 'dissipate-hamilton':
                self.derivator = DissipativeHamiltonianDerivator(
                    dissipate=True,
                    v_dim=hv_dim, e_dim=he_dim, pq_dim=pos_dim,
                    use_cuda=use_cuda,
                    p_dropout=p_dropout)
            elif derive_type == 'hamilton':
                self.derivator = DissipativeHamiltonianDerivator(
                    dissipate=False,
                    v_dim=hv_dim, e_dim=he_dim, pq_dim=pos_dim,
                    use_cuda=use_cuda,
                    p_dropout=p_dropout)
            elif derive_type == 'newton':
                self.derivator = NewtonianDerivator(
                    h_dim=2 * hv_dim,
                    v_dim=hv_dim, e_dim=he_dim, pq_dim=pos_dim,
                    use_cuda=use_cuda,
                    p_dropout=p_dropout)
            elif derive_type == 'langevin':
                self.derivator = LangevinDerivator(
                    v_dim=hv_dim, e_dim=he_dim, pq_dim=pos_dim,
                    use_cuda=use_cuda,
                    p_dropout=p_dropout)
            else:
                raise ValueError(f'Undefined derive type {derive_type} in net.conformation.layers.ConformationBuilder')
            assert isinstance(self.derivator, Derivator)  # internal invariant

        if pos_dim == 3:
            self.subspace_mapping = lambda x: x
        elif isinstance(self.generator, RDKitGenerator):
            # Undo the generator's linear layer through its pseudo-inverse. It is recomputed
            # on every call because the weight may change during training.
            self.subspace_mapping = self._rdkit_inverse_mapping
        else:
            self.subspace_mapping = nn.Linear(pos_dim, 3, bias=False)

    def _rdkit_inverse_mapping(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``x @ pinv(W).T`` where ``W`` is ``self.generator.linear.weight``.

        Raises:
            np.linalg.LinAlgError: if the pseudo-inverse computation fails. The original
                error is chained; before, this printed ``x`` and hit an ``assert`` that
                ``-O`` strips.
        """
        weight = self.generator.linear.weight
        try:
            weight_pinv = np.linalg.pinv(weight.cpu().detach().numpy())
        except np.linalg.LinAlgError as err:
            raise np.linalg.LinAlgError(
                f'pseudo-inverse of RDKit generator weight failed (input shape {tuple(x.shape)})') from err
        return x @ torch.FloatTensor(weight_pinv).t().type_as(weight)

    def forward(self, hv_ftr: torch.Tensor, he_ftr: torch.Tensor, massive: torch.Tensor, mask_matrices: MaskMatrices,
                return_list: List[str],
                atom_ftr: torch.Tensor = None, bond_ftr: torch.Tensor = None, rdkit_pos_ftr: torch.Tensor = None,
                is_training=False, target_conf: torch.Tensor = None, use_ff=False, smiles_set: List[str] = None
                ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Generate (and optionally derive) conformations.

        Args:
            hv_ftr, he_ftr: vertex / edge hidden features.
            massive: divisor for ``dp`` in the ``direct_mom`` update. Assumed to broadcast
                against ``dp`` (TODO confirm shape).
            mask_matrices: batch masks. ``mol_vertex_w[i, :] > 0`` selects the atoms of molecule ``i``.
            return_list: flags. ``'derive'`` adds p/q trajectories and ``pos_ftr`` to the dict;
                ``'middle'`` adds every intermediate step mapped through ``subspace_mapping``.
            atom_ftr, bond_ftr: forwarded to the derivator.
            rdkit_pos_ftr, is_training, target_conf: forwarded to the generator.
            use_ff: when True and ``smiles_set`` is given, run ``optimize_pos`` per molecule
                and store the stacked result under ``'pos_ftr_ff'``.
            smiles_set: one SMILES per molecule, aligned with ``mol_vertex_w`` rows.

        Returns:
            ``(pos_ftr, return_dict)``. ``return_dict`` also includes every entry returned
            by the generator.
        """
        p_ftr, q_ftr, rd1 = self.generator.forward(hv_ftr, he_ftr, mask_matrices, return_list,
                                                   rdkit_pos_ftr=rdkit_pos_ftr, is_training=is_training,
                                                   target_conf=target_conf)
        list_p_ftr, list_q_ftr = [p_ftr], [q_ftr]

        if self.need_derive:
            for _step in range(self.step):
                dp, dq, _ = self.derivator.forward(
                    hv_ftr, he_ftr, massive, list_p_ftr[-1], list_q_ftr[-1], mask_matrices, return_list,
                    atom_ftr=atom_ftr, bond_ftr=bond_ftr)
                if self.direct_mom:
                    # dp is taken as the new momentum; q moves by velocity (p / massive) * tau.
                    list_p_ftr.append(dp)
                    list_q_ftr.append(list_q_ftr[-1] + (dp / massive) * self.tau)
                else:
                    list_p_ftr.append(list_p_ftr[-1] + dp * self.tau)
                    list_q_ftr.append(list_q_ftr[-1] + dq * self.tau)

        pos_ftr = self.subspace_mapping(list_q_ftr[-1])

        return_dict = {}
        if use_ff and smiles_set is not None:
            list_new_pos = []
            for i, smiles in enumerate(smiles_set):
                mask = mask_matrices.mol_vertex_w[i, :]
                list_new_pos.append(optimize_pos(pos_ftr[mask > 0, :], smiles))
            return_dict['pos_ftr_ff'] = torch.vstack(list_new_pos)

        return_dict.update(rd1)
        if 'derive' in return_list:
            return_dict['list_p_ftr'] = list_p_ftr
            return_dict['list_q_ftr'] = list_q_ftr
            return_dict['pos_ftr'] = pos_ftr
        if 'middle' in return_list:
            return_dict['list_pos_ftr'] = [self.subspace_mapping(q) for q in list_q_ftr]
            return_dict['list_mom_ftr'] = [self.subspace_mapping(p) for p in list_p_ftr]
        return pos_ftr, return_dict


class ConformationComparer:
    """Select and wrap a single-conformation comparer by name.

    Supported ``compare_type`` values: ``'adj<hop>'``, ``'lddt-score'``, ``'lddt<truncate>'``,
    ``'naive'``, ``'kabsch'``, ``'equiv'``, ``'equiv-trunc'``, ``'equiv-trunc_nof'``,
    ``'equiv-trunc_nos'``, ``'equiv-trunc_nofs'``, ``'phi'``, ``'psi'``, ``'rmsd'``.

    Raises:
        ValueError: for an unknown ``compare_type``. Raised explicitly rather than via
            ``assert``, so ``python -O`` cannot leave ``self.comparer`` undefined.
    """

    def __init__(self, compare_type='adj3', dataset_name='qm9', use_cuda=False):
        self.use_cuda = use_cuda

        if compare_type.startswith('adj'):
            # Suffix after 'adj' is the hop count, e.g. 'adj3'.
            self.comparer = DistanceComparer(hop=int(compare_type[3:]), use_cuda=self.use_cuda)
        elif compare_type == 'lddt-score':
            # Must be tested before the generic 'lddt' prefix below.
            self.comparer = lDDTScore(truncate=15., use_cuda=self.use_cuda)
        elif compare_type.startswith('lddt'):
            # Suffix after 'lddt' is the truncate value, e.g. 'lddt15'.
            self.comparer = lDDTComparer(truncate=float(compare_type[4:]), use_cuda=self.use_cuda)
        elif compare_type == 'naive':
            self.comparer = NaiveComparer(use_cuda=use_cuda)
        elif compare_type == 'kabsch':
            self.comparer = KabschComparer(use_cuda=use_cuda)
        elif compare_type == 'equiv':
            self.comparer = Equivalent(dataset_name=dataset_name)
        elif compare_type == 'equiv-trunc':
            self.comparer = Equivalent(trunc=True, dataset_name=dataset_name)
        elif compare_type == 'equiv-trunc_nof':
            self.comparer = Equivalent(trunc=True, dataset_name=dataset_name, consider_f=False)
        elif compare_type == 'equiv-trunc_nos':
            self.comparer = Equivalent(trunc=True, dataset_name=dataset_name, consider_s=False)
        elif compare_type == 'equiv-trunc_nofs':
            self.comparer = Equivalent(trunc=True, dataset_name=dataset_name, consider_f=False, consider_s=False)
        elif compare_type == 'phi':
            self.comparer = JustPhi()
        elif compare_type == 'psi':
            self.comparer = JustPsi()
        elif compare_type == 'rmsd':
            self.comparer = RMSDComparer()
        else:
            raise ValueError(f'Undefined compare type {compare_type} in net.conformation.layers.ConformationComparer')

    def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices, target_conf: torch.Tensor = None,
               *args, **kwargs) -> Any:
        """Delegate to the wrapped comparer's ``encode``."""
        # NOTE(review): extra positional *args are bound before the keyword arguments,
        # so a non-empty *args collides with ``conf`` if the inner signature starts with it.
        # Confirm the inner signature before relying on positional extras.
        return self.comparer.encode(
            conf=conf,
            mask_matrices=mask_matrices,
            target_conf=target_conf,
            *args, **kwargs
        )

    def compare(self, source_conf: torch.Tensor, target_conf: torch.Tensor, mask_matrices: MaskMatrices,
                list_source_conf: List[torch.Tensor] = None, *args, **kwargs) -> torch.Tensor:
        """Compare ``source_conf`` against ``target_conf`` with the wrapped comparer.

        If ``list_source_conf`` is given, each entry is compared against ``target_conf``
        instead, and the per-entry losses are combined by ``assemble_hierarchical_losses``
        (second argument 1.6; its meaning is defined by that helper).
        """
        # NOTE(review): same *args-before-keywords caveat as in ``encode``.
        if list_source_conf is not None:
            losses = [self.comparer.compare(
                source_conf=sc,
                target_conf=target_conf,
                mask_matrices=mask_matrices,
                *args, **kwargs
            ) for sc in list_source_conf]
            return assemble_hierarchical_losses(losses, 1.6)
        return self.comparer.compare(
            source_conf=source_conf,
            target_conf=target_conf,
            mask_matrices=mask_matrices,
            *args, **kwargs
        )


class MultiConformationComparer:
    """Select and wrap a multi-conformation comparer by name.

    Supported ``compare_type`` values: ``'adj<hop>'``, ``'cov<threshold>'``, ``'mat'``,
    ``'lddt-score'``.

    Raises:
        ValueError: for an unknown ``compare_type``. Raised explicitly rather than via
            ``assert``, so ``python -O`` cannot leave ``self.comparer`` undefined.
    """

    def __init__(self, compare_type, use_cuda=False):
        self.use_cuda = use_cuda
        if compare_type.startswith('adj'):
            # Suffix after 'adj' is the hop count, e.g. 'adj3'.
            self.comparer = MultiDistanceComparer(hop=int(compare_type[3:]), use_cuda=self.use_cuda)
        elif compare_type.startswith('cov'):
            # Suffix after 'cov' is the threshold, e.g. 'cov1.25'.
            self.comparer = COVComparer(threshold=float(compare_type[3:]), use_cuda=self.use_cuda)
        elif compare_type == 'mat':
            self.comparer = MATComparer(use_cuda=self.use_cuda)
        elif compare_type == 'lddt-score':
            self.comparer = MultilDDTScore(truncate=15., use_cuda=self.use_cuda)
        else:
            raise ValueError(
                f'Undefined compare type {compare_type} in net.conformation.layers.MultiConformationComparer')

    def compare(self, list_source_conf: List[torch.Tensor], list_target_conf: List[torch.Tensor],
                mask_matrices: MaskMatrices, *args, **kwargs) -> torch.Tensor:
        """Delegate to the wrapped comparer's ``compare`` with the source/target conformation lists."""
        # NOTE(review): extra positional *args are bound before the keyword arguments below;
        # confirm the inner signature before passing positional extras.
        return self.comparer.compare(
            list_source_conf=list_source_conf,
            list_target_conf=list_target_conf,
            mask_matrices=mask_matrices,
            *args, **kwargs
        )
