from typing import *

import numpy as np
import torch
from torch import Tensor, nn

from approaches.hat.hat import HAT
from approaches.prmwo2so.constants import ConstantsPRM
from utils import assert_type, myprint as print


class ModelHATFeature(nn.Module):
    """Thin wrapper around :class:`MLP` that remembers ``smax``.

    The wrapper builds the inner ``MLP`` (one task slot per entry of
    ``list__ncls``) and forwards ``forward`` / ``freeze_masks`` /
    ``on_after_backward`` to it, supplying the stored ``smax`` where the
    inner model or its HAT modules need it.
    """

    def __init__(self, list__ncls: List[int], inputsize: Tuple[int, ...],
                 smax: float, hat_enabled: bool,
                 drop1: float, drop2: float):
        """Build the inner MLP.

        Args:
            list__ncls: One entry per task; only its length is used here.
            inputsize: Input shape; the inner MLP multiplies its first three entries.
            smax: Value passed on to ``freeze_masks`` and ``on_after_backward``.
            hat_enabled: Whether the inner MLP creates and applies HAT modules.
            drop1: Dropout probability applied to the flattened input.
            drop2: Dropout probability applied after each hidden layer.
        """
        super().__init__()
        self.model = MLP(num_tasks=len(list__ncls), inputsize=inputsize, hat_enabled=hat_enabled,
                         drop1=drop1, drop2=drop2)

        self.smax = smax
    # enddef

    def freeze_masks(self, idx: int) -> None:
        """Call ``freeze_masks(idx, smax)`` on every HAT module of the inner model.

        Does nothing when the inner model was built with ``hat_enabled=False``:
        in that configuration no HAT modules exist, and the other entry points
        of this class already tolerate it.
        """
        if not self.model.hat_enabled:
            return
        # endif

        for hat in self.model.hats:
            hat.freeze_masks(idx, self.smax)
        # endfor
    # enddef

    def forward(self, idx_task: int, x: Tensor, s: float,
                args_on_forward: Dict[str, Any]) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Run the inner model and return its result unchanged.

        Returns:
            Whatever the inner MLP returns: normally ``(out, reg)``, or the
            output tensor alone when the inner model is asked to skip the
            regularisation term.
        """
        assert_type(idx_task, int)
        assert_type(x, Tensor)

        return self.model(idx_task, x, s=s, args_on_forward=args_on_forward)  # out, reg
    # enddef

    def on_after_backward(self, idx_task: int, s: float, args: Dict[str, Any]) -> None:
        """Let the inner model post-process gradients, then clip the gradient norm."""
        self.model.on_after_backward(idx_task=idx_task, s=s, smax=self.smax, args=args)
        torch.nn.utils.clip_grad_norm_(self.parameters(), 10000)
    # enddef


# endclass


class MLP(nn.Module):
    """Two-hidden-layer fully connected network with optional HAT gating.

    Both hidden layers have 2048 units. When ``hat_enabled`` is true, one
    ``HAT`` module is applied to the output of each hidden layer and
    ``on_after_backward`` rescales gradients using the masks those modules
    expose. When it is false, no HAT modules are created and the network is
    a plain MLP.
    """

    def __init__(self, num_tasks: int, inputsize: Tuple[int, ...], hat_enabled: bool,
                 drop1: float, drop2: float):
        """Create the layers.

        Args:
            num_tasks: Passed to each ``HAT`` module.
            inputsize: Input shape; the product of its first three entries is
                the input width of the first linear layer.
            hat_enabled: Whether to create and apply the two ``HAT`` modules.
            drop1: Dropout probability applied to the flattened input.
            drop2: Dropout probability applied after each hidden layer.
        """
        super().__init__()

        self.fc1 = nn.Linear(inputsize[0] * inputsize[1] * inputsize[2], 2048)
        self.fc2 = nn.Linear(2048, 2048)

        self.relu = nn.ReLU()
        self.drop1 = nn.Dropout(drop1)
        self.drop2 = nn.Dropout(drop2)

        self.hat_enabled = hat_enabled
        if self.hat_enabled:
            self.hat1 = HAT(num_tasks=num_tasks, dim_emb=2048)
            self.hat2 = HAT(num_tasks=num_tasks, dim_emb=2048)
        # endif
    # enddef

    def forward(self, idx_task: int, x: Tensor, s: float,
                args_on_forward: Dict[str, Any]) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Compute the hidden features and, unless suppressed, a mask regulariser.

        Args:
            idx_task: Task index handed to the HAT modules.
            x: Input batch; flattened to ``(x.shape[0], -1)`` before ``fc1``.
            s: Scale value handed to the HAT modules.
            args_on_forward: Extra options. Keys read here:
                ``ConstantsPRM.NO_NEED_REG`` (presence only) to skip the
                regulariser, ``ConstantsPRM.KEY_DISSIMILARS`` (indexable by
                layer index) to select which previous masks are used, and
                optional ``'epoch'`` / ``'idx_batch'`` for one-off logging.

        Returns:
            ``h`` alone when ``ConstantsPRM.NO_NEED_REG`` is present, otherwise
            ``(h, reg)``. With HAT disabled ``reg`` is a zero scalar tensor on
            the device of ``x``.
        """
        assert_type(idx_task, int)
        assert_type(x, Tensor)
        assert_type(s, [float, int])

        self.device = x.device

        # Gated
        h = x.view(x.shape[0], -1)
        h = self.drop1(h)

        h = self.drop2(self.relu(self.fc1(h)))
        if self.hat_enabled:
            h = self.hat1(idx_task, h, s=s, args_on_forward=args_on_forward)
        # endif

        h = self.drop2(self.relu(self.fc2(h)))
        if self.hat_enabled:
            h = self.hat2(idx_task, h, s=s, args_on_forward=args_on_forward)
        # endif

        if ConstantsPRM.NO_NEED_REG in args_on_forward:
            return h
        # endif

        if not self.hat_enabled:
            return h, torch.tensor(0, device=self.device)
        # endif

        hats = self.hats
        masks = [hat.mask(idx_task, s) for hat in hats]

        if ConstantsPRM.KEY_DISSIMILARS in args_on_forward:
            dict__idx_layer__dissimilars = args_on_forward[ConstantsPRM.KEY_DISSIMILARS]
            # Log once, on the first batch of the first epoch; a missing key simply disables the log line.
            if args_on_forward.get('epoch') == 0 and args_on_forward.get('idx_batch') == 0:
                print(f'mask_pre on {dict__idx_layer__dissimilars}')
            # endif
            masks_pre = [hat.selective_a_max(idx_task, dict__idx_layer__dissimilars[idx_layer])
                         for idx_layer, hat in enumerate(hats)]
        else:
            masks_pre = [hat.a_max() for hat in hats]
        # endif

        reg = 0
        count = 1e-5  # small offset so the division below never hits zero
        if masks_pre[0] is not None:
            # Weight each mask by (1 - previous mask) and normalise by the total weight.
            for m, mp in zip(masks, masks_pre):
                aux = 1 - mp
                reg += (m * aux).sum()
                count += aux.sum()
            # endfor
        else:
            # No previous masks: plain mean of the current mask values.
            for m in masks:
                reg += m.sum()
                count += m.numel()
            # endfor
        # endif
        reg /= count

        return h, reg
    # enddef

    @property
    def hats(self) -> List["HAT"]:
        """The HAT modules in layer order; empty when HAT is disabled."""
        if not self.hat_enabled:
            return []
        # endif
        return [self.hat1, self.hat2]
    # enddef

    @staticmethod
    def _mask_layer_grads(layer: nn.Module, red_weight: Tensor, red_bias: Tensor,
                          tag: str, log: Optional[List[str]]) -> None:
        """Multiply the gradients of ``layer`` by the given factors; optionally log dead-parameter counts."""
        for n, p in layer.named_parameters():
            if 'weight' in n:
                red = red_weight
            elif 'bias' in n:
                red = red_bias
            else:
                raise NotImplementedError(n)
            # endif

            if p.grad is not None:
                p.grad.data *= red
            # endif

            if log is not None:
                num_0 = red[red == 0].numel()
                num_all = red.numel()
                log.append(f'[{tag}/{n}] dead parameters: {num_0}/{num_all} ({num_0 / num_all * 100:.1f}%)')
            # endif
        # endfor
    # enddef

    def on_after_backward(self, idx_task: int, s: float, smax: float, args: Dict[str, Any]) -> None:
        """Post-process gradients after ``backward()``; no-op when HAT is disabled.

        Steps, in order:
          1. Scale the gradients of ``fc1`` / ``fc2`` by ``1 - mask`` where the
             mask comes from ``a_max()`` or, if ``args`` contains
             ``'dict__idx_layer__dissimilars'``, from ``selective_a_max``.
             Skipped when the first mask is ``None``.
          2. Rescale the gradients of the HAT modules' parameters by
             ``smax / s * (cosh(clamp(s * p)) + 1) / (cosh(p) + 1)``.
          3. Clip the total gradient norm of this module at 10000.

        A summary is printed when ``s == 1 / smax`` and ``args.get('epoch') == 0``.

        Args:
            idx_task: Index of the task being learned.
            s: Current scale value.
            smax: Maximum scale value.
            args: Extra options; ``'dict__idx_layer__dissimilars'`` and
                ``'epoch'`` are read if present.
        """
        if not self.hat_enabled:
            return
        # endif

        log = [f'[on_after_backward] In learning {idx_task}:']

        assert_type(idx_task, int)

        is_first_step = (s == 1 / smax)
        # 'epoch' is only needed for logging, so a missing key disables the log instead of raising.
        should_log = is_first_step and args.get('epoch') == 0

        if 'dict__idx_layer__dissimilars' in args:
            dict__idx_layer__dissimilars = args['dict__idx_layer__dissimilars']
            list__dissimilar_task_0 = dict__idx_layer__dissimilars[0]
            list__dissimilar_task_1 = dict__idx_layer__dissimilars[1]

            if is_first_step:
                relax_0 = [t for t in range(idx_task) if t not in list__dissimilar_task_0]
                relax_1 = [t for t in range(idx_task) if t not in list__dissimilar_task_1]
                log.append(f'- [1]: block: {list__dissimilar_task_0}, relax: {relax_0}')
                log.append(f'- [2]: block: {list__dissimilar_task_1}, relax: {relax_1}')
            # endif

            pre_1 = self.hat1.selective_a_max(idx_task, list__dissimilar_task_0)
            pre_2 = self.hat2.selective_a_max(idx_task, list__dissimilar_task_1)
        else:
            pre_1, pre_2 = self.hat1.a_max(), self.hat2.a_max()
        # endif

        if pre_1 is not None:
            stats = log if should_log else None  # counting zeros is only worth it when printed

            # fc1: each output unit's row (and bias entry) is scaled by 1 - mask of layer 1.
            red_w1 = 1 - pre_1.data.view(-1, 1).expand_as(self.fc1.weight)
            red_b1 = 1 - pre_1.data.view(-1)
            self._mask_layer_grads(self.fc1, red_w1, red_b1, 'mask1', stats)

            # fc2: a weight is scaled by 1 - min(mask of its output unit, mask of its input unit).
            post = pre_2.data.view(-1, 1).expand_as(self.fc2.weight)
            pre = pre_1.data.view(1, -1).expand_as(self.fc2.weight)
            red_w2 = 1 - torch.min(post, pre)
            red_b2 = 1 - pre_2.data.view(-1)
            self._mask_layer_grads(self.fc2, red_w2, red_b2, 'mask2', stats)
        # endif

        # log writing
        if should_log:
            print('\n'.join(log))
        # endif

        # Compensate embedding gradients
        thres_cosh = 50  # clamp bound that keeps cosh() of the scaled value finite
        for hat in self.hats:
            for n, p in hat.named_parameters():
                if 'weight' not in n:
                    raise NotImplementedError(n)
                # endif
                if p.grad is None:
                    continue
                # endif
                num = torch.cosh(torch.clamp(s * p.data, -thres_cosh, thres_cosh)) + 1
                den = torch.cosh(p.data) + 1
                p.grad.data *= smax / s * num / den
            # endfor
        # endfor

        torch.nn.utils.clip_grad_norm_(self.parameters(), 10000)
    # enddef

# endclass
