import numpy as np
import torch
import torch.nn as nn

from typing import Tuple, List, Dict, Any

from data.structures import MaskMatrices
from net.property.layers import AtomBondEncoder, GeometryMessagePassing, Readout


class PropertyModel(nn.Module):
    """Property-prediction network: encoder -> repeated message passing -> readout.

    Raw atom and bond features are embedded by ``AtomBondEncoder``. A single
    ``GeometryMessagePassing`` module is then applied ``N_LAYER`` times. It is
    the same module instance on every iteration, so its parameters are shared
    across layers. Finally, ``Readout`` turns the resulting atom features into
    a fingerprint.
    """

    def __init__(self, atom_dim: int, bond_dim: int, config: dict, use_cuda: bool = False):
        """Build the sub-modules from ``config``.

        Args:
            atom_dim: Forwarded to ``AtomBondEncoder`` as ``atom_dim``
                (presumably the raw atom feature width -- confirm).
            bond_dim: Forwarded to ``AtomBondEncoder`` as ``bond_dim``
                (presumably the raw bond feature width -- confirm).
            config: Hyper-parameter mapping. It must contain the keys
                ``HV_DIM``, ``HE_DIM``, ``HM_DIM``, ``MV_DIM``, ``ME_DIM``,
                ``MM_DIM``, ``POS_DIM``, ``N_LAYER``, ``N_HOP``, ``N_GLOBAL``,
                ``MESSAGE_TYPE`` and ``DROPOUT``.
            use_cuda: Stored on the model and forwarded to the message-passing
                and readout layers.

        Raises:
            KeyError: If any required key is missing from ``config``.
        """
        super().__init__()
        hv_dim = config['HV_DIM']
        he_dim = config['HE_DIM']
        hm_dim = config['HM_DIM']
        mv_dim = config['MV_DIM']
        me_dim = config['ME_DIM']
        mm_dim = config['MM_DIM']
        pos_dim = config['POS_DIM']
        self.n_layer = config['N_LAYER']
        n_hop = config['N_HOP']
        n_global = config['N_GLOBAL']
        message_type = config['MESSAGE_TYPE']
        p_dropout = config['DROPOUT']
        self.use_cuda = use_cuda

        self.atom_bond_encoder = AtomBondEncoder(
            atom_dim=atom_dim,
            bond_dim=bond_dim,
            hv_dim=hv_dim,
            he_dim=he_dim,
        )
        # One kernel instance, reused for every layer in ``forward`` (weight sharing).
        self.mp_kernel = GeometryMessagePassing(
            hv_dim=hv_dim,
            he_dim=he_dim,
            mv_dim=mv_dim,
            me_dim=me_dim,
            pos_dim=pos_dim,
            hops=n_hop,
            use_cuda=use_cuda,
            p_dropout=p_dropout,
            message_type=message_type,
        )
        self.readout = Readout(
            hm_dim=hm_dim,
            hv_dim=hv_dim,
            mm_dim=mm_dim,
            iteration=n_global,
            use_cuda=use_cuda,
            p_dropout=p_dropout
        )

    def forward(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor,
                mask_matrices: MaskMatrices, pos_ftr: torch.Tensor,
                return_local_alignment=False, return_global_alignment=False
                ) -> Tuple[torch.Tensor, List[List[Dict[str, Any]]], List[np.ndarray], List[np.ndarray]]:
        """Run the encoder, ``n_layer`` message-passing steps and the readout.

        Args:
            atom_ftr: Raw atom features, passed to the encoder.
            bond_ftr: Raw bond features, passed to the encoder.
            mask_matrices: Graph connectivity masks, passed to every
                message-passing step and to the readout.
            pos_ftr: Positional features, passed unchanged to every
                message-passing step.
            return_local_alignment: If true, each message-passing step is asked
                for its alignments (``'alignment'`` is placed in the request list).
            return_global_alignment: Forwarded to the readout.

        Returns:
            A 4-tuple ``(fingerprint, list_alignments, global_alignments,
            list_he_ftr)``:

            - ``fingerprint`` and ``global_alignments`` are as returned by the
              readout.
            - ``list_alignments`` holds one entry per layer, as returned by the
              message-passing kernel.
            - ``list_he_ftr`` holds one detached NumPy copy of the edge
              features after each layer.
        """
        # Call the sub-modules (not ``.forward``) so that hooks registered on
        # them via register_forward_hook / register_forward_pre_hook run.
        hv_ftr, he_ftr = self.atom_bond_encoder(atom_ftr, bond_ftr)

        return_list = ['alignment'] if return_local_alignment else []

        list_alignments = []
        list_he_ftr = []
        for _ in range(self.n_layer):
            hv_ftr, he_ftr, alignments = self.mp_kernel(hv_ftr, he_ftr, pos_ftr, mask_matrices, return_list)
            list_alignments.append(alignments)
            # Detach before copying so autograd does not record the device
            # transfer. NOTE: this forces a device->host copy per layer on
            # every call, including during training.
            list_he_ftr.append(he_ftr.detach().cpu().numpy())

        fingerprint, global_alignments = self.readout(hv_ftr, mask_matrices, return_global_alignment)
        return fingerprint, list_alignments, global_alignments, list_he_ftr
