import time
import torch
import torch.nn as nn
from typing import Tuple, List, Dict, Any
# from rdkit.Chem.rdchem.Mol import Mol as Molecule

from data.structures import MaskMatrices
from data.encode import get_massive_from_atom_features
from net.conformation.layers import NaiveInitializer, GINInitializer, ConformationBuilder, \
    ConformationComparer, MultiConformationComparer


class ConformationModel(nn.Module):
    """Predict one conformation per molecule and score it against a target conformation.

    Pipeline: ``GINInitializer`` (atom/bond features + RDKit positions -> hidden node/edge
    features and ``massive``) -> ``ConformationBuilder`` (hidden features -> positions) ->
    ``ConformationComparer`` (predicted vs. target positions -> loss or metric).
    """

    def __init__(self, atom_dim: int, bond_dim: int, config: dict, dataset_name='qm9', use_cuda=False):
        super(ConformationModel, self).__init__()
        self.use_cuda = use_cuda
        # Remembered so that comparers rebuilt later (see ``adapt``) get the same dataset setting.
        self.dataset_name = dataset_name
        hv_dim = config['HV_DIM']
        he_dim = config['HE_DIM']
        pos_dim = config['POS_DIM']
        p_dropout = config['DROPOUT']
        self.generate_type = generate_type = config['GENERATE_TYPE']
        derive_type = config['DERIVE_TYPE']
        compare_type = config['COMPARE_TYPE']
        step = config['STEP']
        tau = config['TAU']
        self.ff = config['FF']

        # Extra outputs requested from the builder in ``forward``.
        self.builder_return_list = []

        # -1 means ``adapt`` never switches comparers for non-negative epochs.
        self.epoch_switch_comparer_to_equiv_trunc = -1
        # if compare_type == 'equiv-trunc':
        #     compare_type = 'lddt5'
        #     self.epoch_switch_comparer_to_equiv_trunc = int(config['EPOCH'] / 5) + 1
        # Intermediate conformations can only be compared when a derive step is configured.
        self.compare_middle = config['COMPARE_MIDDLE'] and derive_type != ''

        if self.compare_middle:
            self.builder_return_list.append('middle')

        self.initializer = GINInitializer(
            atom_dim=atom_dim,
            bond_dim=bond_dim,
            hv_dim=hv_dim,
            he_dim=he_dim,
            use_cuda=use_cuda
        )
        self.conformation_builder = ConformationBuilder(
            hv_dim=hv_dim,
            he_dim=he_dim,
            pos_dim=pos_dim,
            use_cuda=use_cuda,
            p_dropout=p_dropout,
            generate_type=generate_type,
            derive_type=derive_type,
            step=step,
            tau=tau
        )
        self.conformation_comparer = ConformationComparer(
            compare_type=compare_type,
            dataset_name=dataset_name,
            use_cuda=use_cuda
        )
        # NOTE(review): a plain dict does not register its values as submodules, so if
        # ConformationComparer is an nn.Module these are not moved by .to()/.cuda() and are
        # not part of state_dict -- confirm that is intended.
        self.evaluate_comparers = {
            'd': ConformationComparer(
                compare_type='adj1',
                use_cuda=use_cuda
            ),
            'phi': ConformationComparer(
                compare_type='phi',
                use_cuda=use_cuda
            ),
            'psi': ConformationComparer(
                compare_type='psi',
                use_cuda=use_cuda
            ),
            'lddt-score': ConformationComparer(
                compare_type='lddt-score',
                use_cuda=use_cuda
            ),
            'adj-1': ConformationComparer(
                compare_type='adj-1',
                use_cuda=use_cuda
            ),
            'rmsd': ConformationComparer(
                compare_type='rmsd',
                use_cuda=use_cuda
            )
        }
        # if compare_type == 'equiv':
        #     self.evaluate_comparers.update({
        #         'equiv': ConformationComparer(
        #             compare_type='equiv',
        #             use_cuda=use_cuda
        #         ),
        #     })
        # if compare_type == 'equiv-trunc':
        #     self.evaluate_comparers.update({
        #         'equiv-trunc': ConformationComparer(
        #             compare_type='equiv-trunc',
        #             use_cuda=use_cuda
        #         ),
        #     })

    def forward(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                target_pos_ftr: torch.Tensor, rdkit_pos_ftr: torch.Tensor, extra_dict: Dict[str, torch.Tensor] = None
                ) -> torch.Tensor:
        """Build a conformation and return the training loss against ``target_pos_ftr``.

        When ``compare_middle`` is enabled, the builder's intermediate positions
        (``rd['list_pos_ftr']`` without its first entry) are also passed to the comparer.
        """
        hv_ftr, he_ftr, massive = self.initializer.forward(atom_ftr, bond_ftr, rdkit_pos_ftr, mask_matrices)
        pos_ftr, rd = self.conformation_builder.forward(
            hv_ftr, he_ftr, massive, mask_matrices, self.builder_return_list,
            atom_ftr=atom_ftr, bond_ftr=bond_ftr, rdkit_pos_ftr=rdkit_pos_ftr)
        loss = self.conformation_comparer.compare(
            pos_ftr, target_pos_ftr, mask_matrices,
            list_source_conf=rd['list_pos_ftr'][1:] if 'middle' in self.builder_return_list else None,
            massive=massive, extra_dict=extra_dict)
        return loss

    def _run_evaluate_comparers(self, pos_ftr: torch.Tensor, target_pos_ftr: torch.Tensor,
                                mask_matrices: MaskMatrices, massive: torch.Tensor, smiles_set: list,
                                extra_dict: Dict[str, torch.Tensor] = None) -> Dict[str, float]:
        """Score ``pos_ftr`` with every evaluation comparer and return ``{metric_name: float}``."""
        return {
            name: float(comparer.compare(pos_ftr, target_pos_ftr, mask_matrices,
                                         massive=massive, smiles_set=smiles_set, extra_dict=extra_dict))
            for name, comparer in self.evaluate_comparers.items()
        }

    def evaluate(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                 target_pos_ftr: torch.Tensor, rdkit_pos_ftr: torch.Tensor, smiles_set: list,
                 extra_dict: Dict[str, torch.Tensor] = None
                 ) -> Dict[str, float]:
        """Generate a conformation and return every evaluation metric as a float.

        If force-field refinement is enabled (``self.ff``) and the builder reports refined
        coordinates under ``'pos_ftr_ff'``, those refined coordinates are the ones scored.
        """
        hv_ftr, he_ftr, massive = self.initializer.forward(atom_ftr, bond_ftr, rdkit_pos_ftr, mask_matrices)
        pos_ftr, rd = self.conformation_builder.forward(
            hv_ftr, he_ftr, massive, mask_matrices, [],
            atom_ftr=atom_ftr, bond_ftr=bond_ftr,
            rdkit_pos_ftr=rdkit_pos_ftr, use_ff=self.ff, smiles_set=smiles_set
        )
        # Previously the builder's dict was discarded, so FF-refined coordinates were never evaluated.
        if 'pos_ftr_ff' in rd:
            pos_ftr = rd['pos_ftr_ff']
        return self._run_evaluate_comparers(pos_ftr, target_pos_ftr, mask_matrices,
                                            massive, smiles_set, extra_dict)

    def get_intermediate_dict(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                              target_pos_ftr: torch.Tensor, rdkit_pos_ftr: torch.Tensor, smiles_set: list,
                              extra_dict: Dict[str, torch.Tensor] = None, return_list: List[str] = None
                              ) -> Dict[str, Any]:
        """Run initializer + builder and return (a shallow copy of) the builder's extra-output dict.

        ``return_list`` selects which extra outputs the builder produces. ``target_pos_ftr``,
        ``smiles_set`` and ``extra_dict`` are accepted for signature symmetry but unused here.
        """
        hv_ftr, he_ftr, massive = self.initializer.forward(atom_ftr, bond_ftr, rdkit_pos_ftr, mask_matrices)
        _, builder_dict = self.conformation_builder.forward(hv_ftr, he_ftr, massive, mask_matrices,
                                                            return_list=return_list,
                                                            atom_ftr=atom_ftr, bond_ftr=bond_ftr,
                                                            rdkit_pos_ftr=rdkit_pos_ftr)
        return dict(builder_dict)

    def get_derive_states(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                          target_pos_ftr: torch.Tensor, rdkit_pos_ftr: torch.Tensor) -> Dict[str, Any]:
        """Return the builder's intermediate ('middle') outputs for these inputs."""
        return self.get_intermediate_dict(atom_ftr, bond_ftr, mask_matrices, target_pos_ftr, rdkit_pos_ftr, [],
                                          return_list=['middle'])

    def directly_compare(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                         target_pos_ftr: torch.Tensor, rdkit_pos_ftr: torch.Tensor, smiles_set: list,
                         extra_dict: Dict[str, torch.Tensor] = None
                         ) -> Dict[str, float]:
        """Score the RDKit positions themselves (no network) against the target; baseline metrics."""
        massive = torch.FloatTensor(get_massive_from_atom_features(atom_ftr.cpu().numpy()))
        if self.use_cuda:
            massive = massive.cuda()
        return self._run_evaluate_comparers(rdkit_pos_ftr, target_pos_ftr, mask_matrices,
                                            massive, smiles_set, extra_dict)

    def adapt(self, epoch: int):
        """Per-epoch hook: switch the training comparer to 'equiv-trunc' at the configured epoch."""
        if epoch == self.epoch_switch_comparer_to_equiv_trunc:
            print('\tSwitch to equiv_trunc comparer.')
            self.conformation_comparer = ConformationComparer(
                compare_type='equiv-trunc',
                dataset_name=self.dataset_name,
                use_cuda=self.use_cuda
            )
        # if epoch == 2:
        #     if self.generate_type == 'rdkit':
        #         self.conformation_builder.generator.requires_grad_(False)


class MultiConformationModel(nn.Module):
    """Generate several conformations per molecule (one per RDKit seed) and score the ensemble.

    Training (``forward``) compares a single generated conformation to one target. Evaluation
    compares the generated list to a list of targets with ``MultiConformationComparer`` metrics.
    """

    # Weights (kld_z, kld_0) of the KL terms the builder reports for the CVGAE variant.
    _KLD_WEIGHTS_EQUIV = (50.0, 5e-2)
    _KLD_WEIGHTS_DEFAULT = (1.0, 1e-3)

    def __init__(self, atom_dim: int, bond_dim: int, config: dict, dataset_name='geom_qm9', use_cuda=False):
        super(MultiConformationModel, self).__init__()
        self.use_cuda = use_cuda
        # Remembered so that comparers rebuilt later (see ``adapt``) get the same dataset setting.
        self.dataset_name = dataset_name
        hv_dim = config['HV_DIM']
        he_dim = config['HE_DIM']
        pos_dim = config['POS_DIM']
        p_dropout = config['DROPOUT']
        self.generate_type = generate_type = config['GENERATE_TYPE']
        derive_type = config['DERIVE_TYPE']
        self.compare_type = compare_type = config['COMPARE_TYPE']
        step = config['STEP']
        tau = config['TAU']
        self.ff = config['FF']
        # When true, the RDKit seed conformations are also included in the evaluated ensemble.
        self.include_rdkit = config['INCLUDE_RDKIT']

        # Extra outputs requested from the builder in ``forward``.
        self.builder_return_list = []

        # -1 means ``adapt`` never switches comparers for non-negative epochs.
        self.epoch_switch_comparer_to_equiv_trunc = -1
        # if compare_type == 'equiv-trunc':
        #     compare_type = 'lddt5'
        #     self.epoch_switch_comparer_to_equiv_trunc = int(config['EPOCH'] / 5) + 1
        # Intermediate conformations can only be compared when a derive step is configured.
        self.compare_middle = config['COMPARE_MIDDLE'] and derive_type != ''

        if self.compare_middle:
            self.builder_return_list.append('middle')

        self.initializer = GINInitializer(
            atom_dim=atom_dim,
            bond_dim=bond_dim,
            hv_dim=hv_dim,
            he_dim=he_dim,
            use_cuda=use_cuda
        )
        self.conformation_builder = ConformationBuilder(
            hv_dim=hv_dim,
            he_dim=he_dim,
            pos_dim=pos_dim,
            use_cuda=use_cuda,
            p_dropout=p_dropout,
            generate_type=generate_type,
            derive_type=derive_type,
            step=step,
            tau=tau
        )
        self.conformation_comparer = ConformationComparer(
            compare_type=compare_type,
            dataset_name=dataset_name,
            use_cuda=use_cuda
        )
        # NOTE(review): a plain dict does not register its values as submodules, so if
        # MultiConformationComparer is an nn.Module these are not moved by .to()/.cuda() and
        # are not part of state_dict -- confirm that is intended.
        self.evaluate_comparers = {
            'adj3': MultiConformationComparer(
                compare_type='adj3',
                use_cuda=use_cuda
            ),
            'cov0.5': MultiConformationComparer(
                compare_type='cov0.5',
                use_cuda=use_cuda
            ),
            'cov1.25': MultiConformationComparer(
                compare_type='cov1.25',
                use_cuda=use_cuda
            ),
            'mat': MultiConformationComparer(
                compare_type='mat',
                use_cuda=use_cuda
            ),
            'lddt-score': MultiConformationComparer(
                compare_type='lddt-score',
                use_cuda=use_cuda
            )
        }

    def forward(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                target_conf: torch.Tensor, rdkit_conf: torch.Tensor,
                extra_dict: Dict[str, torch.Tensor] = None) -> torch.Tensor:
        """Build one conformation from ``rdkit_conf`` and return its training loss vs. ``target_conf``.

        If the builder reports both ``'kld_z_loss'`` and ``'kld_0_loss'`` (CVGAE variant), the
        weighted KL terms are added. Heavier weights are used with the 'equiv'/'equiv-trunc'
        comparers.
        """
        hv_ftr, he_ftr, massive = self.initializer.forward(atom_ftr, bond_ftr, rdkit_conf, mask_matrices)
        pos_ftr, rd = self.conformation_builder.forward(
            hv_ftr, he_ftr, massive, mask_matrices, self.builder_return_list,
            atom_ftr=atom_ftr, bond_ftr=bond_ftr, rdkit_pos_ftr=rdkit_conf,
            is_training=True, target_conf=target_conf
        )
        loss = self.conformation_comparer.compare(
            pos_ftr, target_conf, mask_matrices,
            list_source_conf=rd['list_pos_ftr'] if 'middle' in self.builder_return_list else None,
            massive=massive, extra_dict=extra_dict
        )
        if 'kld_z_loss' in rd and 'kld_0_loss' in rd:  # CVGAE
            if self.compare_type in ['equiv', 'equiv-trunc']:
                w_z, w_0 = self._KLD_WEIGHTS_EQUIV
            else:
                w_z, w_0 = self._KLD_WEIGHTS_DEFAULT
            return loss + rd['kld_z_loss'] * w_z + rd['kld_0_loss'] * w_0
        return loss

    def _run_evaluate_comparers(self, list_pos_ftr: List[torch.Tensor], list_target_conf: List[torch.Tensor],
                                mask_matrices: MaskMatrices, massive: torch.Tensor, smiles: str,
                                extra_dict: Dict[str, List[torch.Tensor]] = None) -> Dict[str, float]:
        """Score the conformation list with every evaluation comparer; return ``{metric_name: float}``."""
        return {
            name: float(comparer.compare(list_pos_ftr, list_target_conf, mask_matrices,
                                         massive=massive, smiles=smiles, extra_dict=extra_dict))
            for name, comparer in self.evaluate_comparers.items()
        }

    def evaluate(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                 list_target_conf: List[torch.Tensor], list_rdkit_conf: List[torch.Tensor], smiles: str,
                 extra_dict: Dict[str, List[torch.Tensor]] = None) -> Dict[str, float]:
        """Generate one conformation per RDKit seed and return every ensemble metric as a float.

        FF-refined coordinates (``rd['pos_ftr_ff']``) are used when the builder provides them.
        Otherwise the detached builder output is used.

        Raises:
            ValueError: if ``list_rdkit_conf`` is empty. No conformation can be generated then,
                and no ``massive`` is available for the comparers.
        """
        # Explicit check instead of ``assert`` so it still fires under ``python -O``.
        if len(list_rdkit_conf) == 0:
            raise ValueError('list_rdkit_conf must contain at least one conformation')
        list_pos_ftr = list(list_rdkit_conf) if self.include_rdkit else []
        massive = None
        for rdkit_conf in list_rdkit_conf:
            hv_ftr, he_ftr, massive = self.initializer.forward(atom_ftr, bond_ftr, rdkit_conf, mask_matrices)
            pos_ftr, rd = self.conformation_builder.forward(
                hv_ftr, he_ftr, massive, mask_matrices, [],
                atom_ftr=atom_ftr, bond_ftr=bond_ftr,
                rdkit_pos_ftr=rdkit_conf, use_ff=self.ff, smiles_set=[smiles]
            )
            list_pos_ftr.append(rd['pos_ftr_ff'] if 'pos_ftr_ff' in rd else pos_ftr.detach())

        return self._run_evaluate_comparers(list_pos_ftr, list_target_conf, mask_matrices,
                                            massive, smiles, extra_dict)

    def get_intermediate_dict(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                              target_pos_ftr: torch.Tensor, rdkit_pos_ftr: torch.Tensor, smiles_set: list,
                              extra_dict: Dict[str, torch.Tensor] = None, return_list: List[str] = None
                              ) -> Dict[str, Any]:
        """Run initializer + builder and return (a shallow copy of) the builder's extra-output dict.

        ``return_list`` selects which extra outputs the builder produces. ``target_pos_ftr``,
        ``smiles_set`` and ``extra_dict`` are accepted for signature symmetry but unused here.
        """
        hv_ftr, he_ftr, massive = self.initializer.forward(atom_ftr, bond_ftr, rdkit_pos_ftr, mask_matrices)
        _, builder_dict = self.conformation_builder.forward(hv_ftr, he_ftr, massive, mask_matrices,
                                                            return_list=return_list,
                                                            atom_ftr=atom_ftr, bond_ftr=bond_ftr,
                                                            rdkit_pos_ftr=rdkit_pos_ftr)
        return dict(builder_dict)

    def get_derive_states(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                          target_pos_ftr: torch.Tensor, rdkit_pos_ftr: torch.Tensor) -> Dict[str, Any]:
        """Return the builder's intermediate ('middle') outputs for these inputs."""
        return self.get_intermediate_dict(atom_ftr, bond_ftr, mask_matrices, target_pos_ftr, rdkit_pos_ftr, [],
                                          return_list=['middle'])

    def directly_compare(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                         list_target_conf: List[torch.Tensor], list_rdkit_conf: List[torch.Tensor], smiles: str,
                         extra_dict: Dict[str, List[torch.Tensor]] = None) -> Dict[str, float]:
        """Score the RDKit conformations themselves (no network) against the targets; baseline metrics."""
        massive = torch.FloatTensor(get_massive_from_atom_features(atom_ftr.cpu().numpy()))
        if self.use_cuda:
            massive = massive.cuda()
        return self._run_evaluate_comparers(list_rdkit_conf, list_target_conf, mask_matrices,
                                            massive, smiles, extra_dict)

    def adapt(self, epoch: int):
        """Per-epoch hook: switch the training comparer to 'equiv-trunc' at the configured epoch."""
        if epoch == self.epoch_switch_comparer_to_equiv_trunc:
            print('\tSwitch to equiv_trunc comparer.')
            self.conformation_comparer = ConformationComparer(
                compare_type='equiv-trunc',
                dataset_name=self.dataset_name,
                use_cuda=self.use_cuda
            )
        # if epoch == 2:
        #     if self.generate_type == 'rdkit':
        #         print(f'\t\tFreeze rdkit linear.')
        #         self.conformation_builder.generator.requires_grad_(False)
