from typing import *

import torch
from torch import Tensor, nn

from utils import assert_shape, assert_type


class HAT(nn.Module):
    """Per-task learnable gate that rescales the features of a layer.

    Every task index owns one row of an embedding table of width ``dim_emb``.
    The gate for a task is ``sigmoid(s * embedding[task])``; ``forward``
    multiplies its input by that gate.  ``freeze_masks`` snapshots the gate of
    a task into ``history_mask`` and folds it into ``mask_pre``, the running
    element-wise maximum over all snapshots.

    NOTE(review): the class name suggests the "Hard Attention to the Task"
    scheme -- confirm against the training code that drives this module.

    Attributes:
        dim_emb: Width of each task embedding (and of each mask).
        emb: ``nn.Embedding(num_tasks, dim_emb)`` holding the task embeddings.
        gate: Sigmoid applied to the scaled embedding.
        mask_pre: Element-wise max of all frozen masks, ``None`` before the
            first ``freeze_masks`` call.
        history_mask: Frozen mask per task index.
    """

    # Embedding values are clamped into [-THRES_EMB, THRES_EMB] before masks
    # are frozen and at the start of every forward pass.
    THRES_EMB = 6

    def __init__(self, num_tasks: int, dim_emb: int):
        """Create the embedding table and empty mask bookkeeping.

        Args:
            num_tasks: Number of rows in the task embedding table.
            dim_emb: Width of each task embedding.
        """
        super().__init__()
        self.dim_emb = dim_emb
        self.emb = nn.Embedding(num_tasks, dim_emb)
        self.gate = nn.Sigmoid()

        self.mask_pre = None  # type: Optional[Tensor]
        self.history_mask = dict()  # type: Dict[int, Tensor]
    # enddef

    def _clamp_embedding(self) -> None:
        """Clamp all embedding values into ``[-THRES_EMB, THRES_EMB]``."""
        # Rebinding ``.data`` swaps the parameter's storage without recording
        # the operation in the autograd graph.
        self.emb.weight.data = torch.clamp(self.emb.weight.data,
                                           -self.THRES_EMB, self.THRES_EMB)
    # enddef

    def mask(self, idx_task: int, s: float) -> Tensor:
        """Return the gate ``sigmoid(s * embedding[idx_task])``.

        Args:
            idx_task: Row of the embedding table to use.
            s: Scale applied to the embedding before the sigmoid.

        Returns:
            The gate tensor; it is passed through the project helper
            ``assert_shape(m, self.dim_emb)`` (presumably a 1-D length check
            -- verify in ``utils``).
        """
        m = self._mask1(idx_task, s)
        assert_shape(m, self.dim_emb)

        return m
    # enddef

    def _mask1(self, idx_task: int, s: float) -> Tensor:
        """Compute ``sigmoid(s * embedding[idx_task])`` without shape checks."""
        # Build the index on the embedding's device so the lookup does not
        # mix devices.
        index = torch.tensor(idx_task, device=self.emb.weight.device)
        e_t = self.emb(index)
        a_t = self.gate(s * e_t)

        return a_t
    # enddef

    def freeze_masks(self, idx_task: int, smax: float):
        """Snapshot the mask of ``idx_task`` at scale ``smax``.

        The embedding is clamped first.  The detached mask is stored under
        ``history_mask[idx_task]`` and merged into ``mask_pre`` with an
        element-wise maximum.

        Args:
            idx_task: Task whose mask is frozen.
            smax: Scale used to compute the frozen mask.
        """
        self._clamp_embedding()

        mask = self.mask(idx_task, s=smax).detach().clone()
        if self.mask_pre is None:
            self.mask_pre = mask
        else:
            self.mask_pre = torch.max(self.mask_pre, mask)
        # endif

        # Separate copy so ``history_mask`` never aliases ``mask_pre``.
        self.history_mask[idx_task] = mask.clone()
    # enddef

    def forward(self, idx_task: int, x: Tensor, s: float, args_on_forward: Dict[str, Any]) -> Tensor:
        """Multiply ``x`` by the gate of task ``idx_task``.

        Args:
            idx_task: Task index selecting the embedding row.
            x: Input tensor with 2 or 4 dimensions.  For 4-D input the mask is
                viewed as ``(1, -1, 1, 1)``, i.e. applied along dimension 1;
                for 2-D input it is broadcast along the last dimension.
            s: Scale applied to the embedding before the sigmoid.
            args_on_forward: Options; a truthy ``'freeze_mask'`` entry makes
                the mask a detached copy so no gradient reaches the embedding.

        Returns:
            ``x * mask`` with the same shape as ``x``.

        Raises:
            NotImplementedError: If ``x`` is neither 2-D nor 4-D.
        """
        assert_type(idx_task, int)
        assert_type(x, Tensor)

        batch_size = x.shape[0]
        dims_emb = x.shape[1:]
        assert_shape(x,
                     batch_size, *dims_emb)

        self._clamp_embedding()

        a_t = self.mask(idx_task, s)
        assert_shape(a_t, -1)

        if args_on_forward.get('freeze_mask'):
            a_t = a_t.detach().clone()
        # endif

        if x.dim() == 4:
            a_t = a_t.view(1, -1, 1, 1).expand_as(x)
        elif x.dim() == 2:
            a_t = a_t.expand_as(x)
        else:
            raise NotImplementedError(
                'HAT.forward supports 2-D or 4-D input, got {}-D'.format(x.dim()))
        # endif

        out = x * a_t

        return out
    # enddef

    def a_max(self) -> Optional[Tensor]:
        """Return a copy of the cumulative mask ``mask_pre``.

        The maximum is also recomputed from ``history_mask`` and compared with
        ``mask_pre`` as a consistency check.

        Returns:
            ``None`` if no mask has been frozen yet, otherwise a detached copy
            of ``mask_pre``.

        Raises:
            AssertionError: If ``mask_pre`` disagrees with the maximum over
                ``history_mask`` (including when ``history_mask`` is empty).
        """
        if self.mask_pre is None:
            return None
        # endif

        ret = self.mask_pre.detach().clone()

        recomputed = None
        for m in self.history_mask.values():
            frozen = m.detach()
            if recomputed is None:
                recomputed = frozen.clone()
            else:
                recomputed = torch.max(recomputed, frozen)
            # endif
        # endfor

        # Guard the empty-history case explicitly; comparing against None
        # would otherwise fail with an unrelated AttributeError.
        assert recomputed is not None and bool((ret == recomputed).all()), \
            'mask_pre is out of sync with history_mask'

        return ret
    # enddef

    def selective_a_max(self, idx: int, list__dissimilar_task: List[int]) -> Optional[Tensor]:
        """Element-wise max of frozen masks of selected earlier tasks.

        Only tasks ``i`` with ``0 <= i < idx`` that also appear in
        ``list__dissimilar_task`` contribute.

        Args:
            idx: Exclusive upper bound on the task indices considered.
            list__dissimilar_task: Task indices allowed to contribute.

        Returns:
            A new tensor with the maximum, or ``None`` if no task qualifies.

        Raises:
            KeyError: If a qualifying task has no entry in ``history_mask``.
        """
        m = None
        for i in range(idx):
            if i not in list__dissimilar_task:
                continue
            # endif

            frozen = self.history_mask[i].detach()
            if m is None:
                m = frozen.clone()
            else:
                m = torch.max(m, frozen)
            # endif
        # endfor

        return m
    # enddef
# endclass
