import numpy as np
import torch
import torch.nn as nn

from data.structures import MaskMatrices
from .messages import *

from typing import List, Tuple, Dict, Any


class GRUUnion(nn.Module):
    """Fuse an incoming message into a hidden state.

    The update is ``LeakyReLU(GRUCell(message, hidden))``.

    :param h_dim: size of the hidden state (the GRU cell's hidden size)
    :param m_dim: size of the incoming message (the GRU cell's input size)
    :param use_cuda: accepted for signature compatibility; not used by this module
    :param bias: whether the GRU cell carries bias terms
    """

    def __init__(self, h_dim: int, m_dim: int,
                 use_cuda=False, bias=True):
        super().__init__()
        # Attribute names double as state_dict keys -- keep them stable.
        self.gru_cell = nn.GRUCell(input_size=m_dim, hidden_size=h_dim, bias=bias)
        self.relu = nn.LeakyReLU()

    def forward(self, h_ftr: torch.Tensor, m_ftr: torch.Tensor) -> torch.Tensor:
        """Return the updated hidden state.

        :param h_ftr: current hidden state, last dimension ``h_dim``
        :param m_ftr: incoming message, last dimension ``m_dim``
        :return: new hidden state with the same shape as ``h_ftr``
        """
        # nn.GRUCell takes (input, hidden), i.e. the message comes first.
        return self.relu(self.gru_cell(m_ftr, h_ftr))


class GlobalReadout(nn.Module):
    """Attention readout that aggregates vertex features into one message per molecule.

    How it works:

    * Each vertex gets a value vector from its own features (``attend``).
    * Each vertex gets a scalar score from its features concatenated with its
      molecule's features (``align``).
    * Scores are placed on the molecule/vertex mask, softmax-normalised over
      the vertex axis, and used to take a weighted sum of the value vectors.

    :param hm_dim: molecule feature size
    :param hv_dim: vertex feature size
    :param mm_dim: size of the produced molecule message
    :param use_cuda: stored on the instance; not otherwise used here
    :param dropout: dropout probability applied to the inputs of both linear maps
    """

    def __init__(self, hm_dim: int, hv_dim: int, mm_dim: int,
                 use_cuda=False, dropout=0.0):
        super(GlobalReadout, self).__init__()
        self.use_cuda = use_cuda
        self.dropout = nn.Dropout(dropout)

        self.attend = nn.Linear(hv_dim, mm_dim)
        self.at_act = nn.LeakyReLU()
        self.align = nn.Linear(hm_dim + hv_dim, 1)
        self.al_act = nn.Softmax(dim=-1)
        self.ag_act = nn.ELU()

    def forward(self, hm_ftr: torch.Tensor, hv_ftr: torch.Tensor,
                mask_matrices: MaskMatrices,
                return_alignment=False) -> Tuple[torch.Tensor, np.ndarray]:
        """
        molecule message readout with global attention and dynamic properties
        :param hm_ftr: molecule features with shape [n_mol, hm_dim]
        :param hv_ftr: vertex features with shape [n_vertex, hv_dim]
        :param mask_matrices: mask matrices; ``mol_vertex_w`` and ``mol_vertex_b`` are read
        :param return_alignment: if True, also return the attention weights
        :return: ``(mm_ftr, alignment)``. ``mm_ftr`` has shape [n_mol, mm_dim].
            ``alignment`` is a detached numpy copy of the [n_mol, n_vertex]
            attention weights when ``return_alignment`` is True, else None.
        """
        mvw = mask_matrices.mol_vertex_w  # shape [n_mol, n_vertex]
        # NOTE(review): presumably a masking bias (e.g. large negative outside
        # each molecule) added before the softmax -- confirm against MaskMatrices.
        mvb = mask_matrices.mol_vertex_b  # shape [n_mol, n_vertex]
        # Molecule features gathered onto vertices through the mask
        # (each vertex gets its molecule's features if mvw is a 0/1 membership matrix).
        hm_v_ftr = mvw.t() @ hm_ftr  # shape [n_vertex, hm_dim]

        attend_ftr = self.attend(self.dropout(hv_ftr))  # shape [n_vertex, mm_dim]
        attend_ftr = self.at_act(attend_ftr)
        align_ftr = self.align(self.dropout(torch.cat([hm_v_ftr, hv_ftr], dim=1)))  # shape [n_vertex, 1]
        # Scale column j of mvw by vertex j's score. For finite scores this equals
        # mvw @ diag(scores), but it avoids materialising an [n_vertex, n_vertex]
        # matrix and an O(n_mol * n_vertex^2) matmul.
        align_ftr = mvw * align_ftr.reshape(1, -1) + mvb  # shape [n_mol, n_vertex]
        align_ftr = self.al_act(align_ftr)  # softmax over vertices, per molecule
        mm_ftr = self.ag_act(align_ftr @ attend_ftr)  # shape [n_mol, mm_dim]

        alignment = align_ftr.cpu().detach().numpy() if return_alignment else None
        return mm_ftr, alignment


class AtomBondEncoder(nn.Module):
    """Project raw atom and bond features into hidden spaces.

    Each projection is a biased linear map followed by tanh.

    :param atom_dim: raw atom (vertex) feature size
    :param bond_dim: raw bond (edge) feature size
    :param hv_dim: hidden vertex feature size
    :param he_dim: hidden edge feature size
    """

    def __init__(self, atom_dim: int, bond_dim: int, hv_dim: int, he_dim: int):
        super().__init__()
        # Vertex branch is built first, then the edge branch. This keeps the
        # parameter-initialisation order and the state_dict keys unchanged.
        self.v_linear = nn.Linear(in_features=atom_dim, out_features=hv_dim, bias=True)
        self.v_act = nn.Tanh()
        self.e_linear = nn.Linear(in_features=bond_dim, out_features=he_dim, bias=True)
        self.e_act = nn.Tanh()

    def forward(self, atom_ftr: torch.Tensor, bond_ftr: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode atoms and bonds.

        :param atom_ftr: atom features, last dimension ``atom_dim``
        :param bond_ftr: bond features, last dimension ``bond_dim``
        :return: ``(hv_ftr, he_ftr)``, the tanh-activated projections
        """
        vertex_hidden = self.v_linear(atom_ftr)
        vertex_hidden = self.v_act(vertex_hidden)
        edge_hidden = self.e_linear(bond_ftr)
        edge_hidden = self.e_act(edge_hidden)
        return vertex_hidden, edge_hidden


class GeometryMessagePassing(nn.Module):
    """Stack of ``hops`` message-passing rounds over vertex and edge features.

    Each hop has its own message module, whose class is chosen by
    ``message_type``, and its own pair of :class:`GRUUnion` updaters for
    vertices and edges. No parameters are shared between hops.

    :param hv_dim: hidden vertex feature size
    :param he_dim: hidden edge feature size
    :param mv_dim: vertex message size
    :param me_dim: edge message size
    :param pos_dim: position feature size, forwarded to the message modules
    :param hops: number of message-passing rounds
    :param use_cuda: stored and forwarded to sub-modules
    :param p_dropout: dropout probability forwarded to the message modules
    :param message_type: one of 'naive', 'naive_position', 'distance', 'distance_angle'
    :raises ValueError: if ``message_type`` is not one of the above
    """

    def __init__(self, hv_dim: int, he_dim: int, mv_dim: int, me_dim: int, pos_dim: int, hops: int,
                 use_cuda=False, p_dropout=0.0, message_type='naive'):
        super(GeometryMessagePassing, self).__init__()
        self.use_cuda = use_cuda
        self.message_type = message_type
        self.hops = hops

        message_classes = {
            'naive': NaiveMessage,
            'naive_position': NaivePositionMessage,
            'distance': DistanceMessage,
            'distance_angle': DistanceAngleMessage,
        }
        # An explicit raise, unlike `assert`, is not stripped under `python -O`.
        if message_type not in message_classes:
            raise ValueError(f'Undefined message type {message_type} in net.property.layers.GeometryMessagePassing')
        message_cls = message_classes[message_type]
        self.messages = nn.ModuleList([
            message_cls(hv_dim, he_dim, mv_dim, me_dim, pos_dim, use_cuda, p_dropout)
            for _ in range(hops)
        ])

        self.unions_v = nn.ModuleList([GRUUnion(hv_dim, mv_dim, use_cuda) for _ in range(hops)])
        self.unions_e = nn.ModuleList([GRUUnion(he_dim, me_dim, use_cuda) for _ in range(hops)])

    def forward(self, hv_ftr: torch.Tensor, he_ftr: torch.Tensor, pos_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                return_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
        """Run all hops and return the final vertex/edge features.

        :param hv_ftr: vertex features
        :param he_ftr: edge features
        :param pos_ftr: position features, passed unchanged to every message module
        :param mask_matrices: mask matrices, passed unchanged to every message module
        :param return_list: forwarded to each message module
        :return: ``(hv_ftr, he_ftr, list_return_dict)``, where
            ``list_return_dict[i]`` is the dict returned by hop ``i``'s message module
        """
        list_return_dict = []
        for message, union_v, union_e in zip(self.messages, self.unions_v, self.unions_e):
            assert isinstance(message, Message)
            # Call the module (not .forward) so registered hooks run.
            mv_ftr, me_ftr, return_dict = message(hv_ftr, he_ftr, pos_ftr, mask_matrices, return_list)
            hv_ftr = union_v(hv_ftr, mv_ftr)
            he_ftr = union_e(he_ftr, me_ftr)
            list_return_dict.append(return_dict)
        return hv_ftr, he_ftr, list_return_dict


class Readout(nn.Module):
    """Iterative molecule readout.

    Molecule features are initialised as the mask-weighted mean of projected
    vertex features. They are then refined ``iteration`` times: each round runs
    a :class:`GlobalReadout` and folds its message in with a :class:`GRUUnion`.

    :param hm_dim: molecule feature size
    :param hv_dim: vertex feature size
    :param mm_dim: molecule message size
    :param iteration: number of readout/update rounds
    :param use_cuda: stored and forwarded to sub-modules
    :param p_dropout: dropout probability forwarded to the global readout
    """

    def __init__(self, hm_dim: int, hv_dim: int, mm_dim: int, iteration: int,
                 use_cuda=False, p_dropout=0.0):
        super(Readout, self).__init__()
        self.use_cuda = use_cuda

        self.vertex2mol = nn.Linear(hv_dim, hm_dim, bias=True)
        self.vm_act = nn.LeakyReLU()
        self.readout = GlobalReadout(hm_dim, hv_dim, mm_dim, use_cuda, p_dropout)
        self.union = GRUUnion(hm_dim, mm_dim, use_cuda)
        self.iteration = iteration

    def forward(self, hv_ftr: torch.Tensor,
                mask_matrices: MaskMatrices,
                return_alignment=False) -> Tuple[torch.Tensor, List[np.ndarray]]:
        """Compute molecule features from vertex features.

        :param hv_ftr: vertex features with shape [n_vertex, hv_dim]
        :param mask_matrices: mask matrices; ``mol_vertex_w`` is read here, and
            the whole object is passed to the global readout
        :param return_alignment: if True, collect attention weights from every round
        :return: ``(hm_ftr, alignments)``. ``alignments`` has one entry per
            round, and every entry is None unless ``return_alignment`` is True.
        """
        mvw = mask_matrices.mol_vertex_w  # shape [n_mol, n_vertex]

        # Initialise molecule features with the mean of projected vertex features.
        # A row whose weights sum to zero (an empty molecule) would otherwise
        # divide by zero. The resulting NaN row would then spread to every
        # vertex through mvw.t() @ hm_ftr, because 0 * NaN = NaN.
        # Such rows now get zeros; every other row is divided exactly as before.
        row_sums = torch.sum(mvw, dim=-1, keepdim=True)
        row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums)
        norm_mvw = mvw / row_sums
        hm_ftr = norm_mvw @ self.vm_act(self.vertex2mol(hv_ftr))

        # Iterate; call sub-modules directly (not .forward) so hooks run.
        alignments = []
        for _ in range(self.iteration):
            mm_ftr, alignment = self.readout(hm_ftr, hv_ftr, mask_matrices, return_alignment)
            hm_ftr = self.union(hm_ftr, mm_ftr)
            alignments.append(alignment)

        return hm_ftr, alignments
