import json
import torch.nn.functional as F

from data.load_data import SupportedDatasets
from data.load_multi_data import SupportedMultiDatasets
from .Abstract import *
from net.utils.d_phi_psi_encode import d_phi_psi_encode, d_phi_psi_trunc_encode

HYPER_DIR = 'data/hyper'


class Equivalent(Comparer):
    """Compare conformations through their weighted (d, phi, psi) encodings.

    ``self.func`` (``d_phi_psi_encode`` or ``d_phi_psi_trunc_encode``) turns a
    conformation into three tensors ``d``, ``phi`` and ``psi``. Each one is weighted by
    a per-dataset coefficient (``GAMMA_D``, ``GAMMA_PHI``, ``GAMMA_PSI``) read from
    ``{HYPER_DIR}/<name>.json``.
    """

    def __init__(self, trunc=False, dataset_name: str = SupportedDatasets.QM9,
                 consider_f=True, consider_s=True, *args, **kwargs):
        """
        :param trunc: use ``d_phi_psi_trunc_encode`` instead of ``d_phi_psi_encode``.
        :param dataset_name: dataset whose hyper-parameter JSON is loaded.
            - QM7, QM8 and QM9 share the QM9 file.
            - GEOM_QM9 / GEOM_QM9_SMALL use the GEOM_QM9_SMALL file.
            - GEOM_DRUGS / GEOM_DRUGS_SMALL use the GEOM_DRUGS_SMALL file.
            - Any other name is looked up directly.
        :param consider_f: together with ``consider_s``, selects the JSON keys that are read.
            - both True: ``d``/``phi``/``psi``;
            - only ``consider_f``: ``f_d``/``f_phi``/``f_psi``;
            - only ``consider_s``: ``s_d``/``s_phi``/``s_psi``;
            - neither: every coefficient is 1.0 and no file is read.
        :param consider_s: see ``consider_f``.
        :raises AssertionError: for a dataset outside the known groups whose
            hyper-parameter file is missing, empty/truncated or not valid JSON.
        """
        super().__init__(*args, **kwargs)
        self.func = d_phi_psi_trunc_encode if trunc else d_phi_psi_encode

        def unpack(file_name, factor=1.0):
            """Return the three (scaled) coefficients for ``file_name``."""
            # With neither weight set requested the file contents would be ignored,
            # so do not require the file to exist at all.
            if not consider_f and not consider_s:
                return factor, factor, factor
            if consider_f and consider_s:
                keys = ('d', 'phi', 'psi')
            elif consider_f:
                keys = ('f_d', 'f_phi', 'f_psi')
            else:
                keys = ('s_d', 's_phi', 's_psi')
            with open(f'{HYPER_DIR}/{file_name}.json', encoding='utf-8') as fp:
                hyper = json.load(fp)
            return tuple(factor * hyper[key] for key in keys)

        if dataset_name in [SupportedDatasets.QM7, SupportedDatasets.QM8, SupportedDatasets.QM9]:
            gammas = unpack(SupportedDatasets.QM9)
        elif dataset_name in [SupportedMultiDatasets.GEOM_QM9_SMALL, SupportedMultiDatasets.GEOM_QM9]:
            gammas = unpack(SupportedMultiDatasets.GEOM_QM9_SMALL)
        elif dataset_name in [SupportedMultiDatasets.GEOM_DRUGS_SMALL, SupportedMultiDatasets.GEOM_DRUGS]:
            gammas = unpack(SupportedMultiDatasets.GEOM_DRUGS_SMALL)
        else:
            try:
                gammas = unpack(dataset_name)
            # A tuple is required here: ``except A or B`` only ever catches ``A``.
            # json.load reports an empty/truncated file as JSONDecodeError, not EOFError.
            except (FileNotFoundError, EOFError, json.JSONDecodeError) as err:
                # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
                # The AssertionError type is kept for compatibility with earlier behaviour.
                raise AssertionError(
                    f'Hyper-parameters for dataset {dataset_name} have not been generated yet, '
                    f'please configurate and run `script_conf_feature.py` first.'
                ) from err
        self.GAMMA_D, self.GAMMA_PHI, self.GAMMA_PSI = gammas

        print(f'\tUse lambda parameters from dataset {dataset_name}:')
        print(f'\t\tLambda_D: {self.GAMMA_D}')
        print(f'\t\tLambda_Phi: {self.GAMMA_PHI}')
        print(f'\t\tLambda_Psi: {self.GAMMA_PSI}')

    def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices, target_conf: torch.Tensor = None,
               *args, **kwargs) -> torch.Tensor:
        """Encode one conformation as the weighted concatenation of its d, phi, psi terms.

        :param conf: conformation positions, passed to ``self.func`` as ``pos``.
        :param mask_matrices: passed through to ``self.func``.
        :param target_conf: unused. It is accepted so the signature stays compatible with callers.
        :param kwargs: must contain ``'extra_dict'``, which is forwarded to ``self.func``.
        :return: ``torch.cat`` (along dim 0) of ``GAMMA_D * d``, ``GAMMA_PHI * phi`` and
            ``GAMMA_PSI * psi``.
        """
        extra_dict = kwargs['extra_dict']
        d, trunc_phi, trunc_psi, _ = self.func(
            mask_matrices=mask_matrices,
            pos=conf,
            extra_dict=extra_dict
        )
        return torch.cat([self.GAMMA_D * d, self.GAMMA_PHI * trunc_phi, self.GAMMA_PSI * trunc_psi])

    def compare(self, source_conf: torch.Tensor, target_conf: torch.Tensor, mask_matrices: MaskMatrices, *args,
                **kwargs) -> torch.Tensor:
        """Return the weighted sum of per-term MSE losses between two conformations.

        Each term is scaled by the square of its coefficient. ``mse(g*a, g*b) == g**2 * mse(a, b)``
        for a scalar ``g``, so every term equals the MSE of the ``encode``-weighted tensors. The
        terms are averaged separately, not over the concatenated encoding.

        :param kwargs: must contain ``'extra_dict'``, which is forwarded to ``self.func``.
        """
        extra_dict = kwargs['extra_dict']
        s_d, s_phi, s_psi, _ = self.func(
            mask_matrices=mask_matrices,
            pos=source_conf,
            extra_dict=extra_dict
        )
        t_d, t_phi, t_psi, _ = self.func(
            mask_matrices=mask_matrices,
            pos=target_conf,
            extra_dict=extra_dict
        )
        l_d = F.mse_loss(s_d, t_d)
        l_phi = F.mse_loss(s_phi, t_phi)
        l_psi = F.mse_loss(s_psi, t_psi)
        return l_d * self.GAMMA_D ** 2 + l_phi * self.GAMMA_PHI ** 2 + l_psi * self.GAMMA_PSI ** 2


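# NOTE: MultiEquivalent is disabled. Before re-enabling it, note that it reads class-level
# Equivalent.GAMMA_D / GAMMA_PHI / GAMMA_PSI. Equivalent sets these only as instance attributes
# in __init__ (loaded per dataset), so those lookups will likely fail unless Comparer defines
# them; verify this and pass the coefficients in per instance instead.
# class MultiEquivalent(MultiComparer):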
#     GAMMA_D = Equivalent.GAMMA_D
#     GAMMA_PHI = Equivalent.GAMMA_PHI
#     GAMMA_PSI = Equivalent.GAMMA_PSI
#
#     def __init__(self, trunc=False, *args, **kwargs):
#         super(MultiEquivalent, self).__init__(*args, **kwargs)
#         self.func = d_phi_psi_trunc_encode if trunc else d_phi_psi_encode
#
#     def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices, extra_dict: Dict[str, Any]
#                ) -> torch.Tensor:
#         d, trunc_phi, trunc_psi, _ = self.func(
#             mask_matrices=mask_matrices,
#             pos=conf,
#             extra_dict=extra_dict
#         )
#         return torch.cat([self.GAMMA_D * d, self.GAMMA_PHI * trunc_phi, self.GAMMA_PSI * trunc_psi])
#
#     def compare(self, list_source_conf: List[torch.Tensor], target_conf: torch.Tensor,
#                 mask_matrices: MaskMatrices, *args, **kwargs) -> torch.Tensor:
#         source_encodes = torch.vstack([self.encode(
#             conf=conf,
#             mask_matrices=mask_matrices,
#             extra_dict=kwargs['extra_dict'],
#         ) for conf in list_source_conf])
#         target_encode = self.encode(
#             conf=target_conf,
#             mask_matrices=mask_matrices,
#             extra_dict=kwargs['extra_dict'],
#         )
#         mse = torch.min(torch.sum((source_encodes - target_encode.unsqueeze(dim=0)) ** 2, dim=1))
#         return mse
