import time
from typing import *

import numpy as np
import torch
from torch import Tensor, nn

from approaches.hat.hat import HAT
from utils import assert_type, myprint as print

# Whether the fully connected layers of Alexnet (fc1 and fc2) are created with a bias term.
use_bias = True


class ModelHATFeatureAlexNet(nn.Module):
    """Wrapper around an `Alexnet` feature extractor that also stores `smax`.

    Every method delegates to the wrapped `self.model`; the hooks add the
    stored `smax` to the call so callers only have to supply `s`.
    """

    def __init__(self, list__ncls: List[int], inputsize: Tuple[int, ...],
                 smax: float, hat_enabled: bool,
                 nhid: int, drop1: float, drop2: float,
                 eq_ncls: bool = False):
        """Build the wrapped feature extractor.

        Args:
            list__ncls: One entry per task; only its length (the number of
                tasks) is used here.
            inputsize: Forwarded unchanged to `Alexnet`.
            smax: Stored and later passed to the mask/gradient hooks.
            hat_enabled: Forwarded unchanged to `Alexnet`.
            nhid: Hidden width of the fully connected layers; also exposed as
                `last_dim`.
            drop1: Forwarded unchanged to `Alexnet`.
            drop2: Forwarded unchanged to `Alexnet`.
            eq_ncls: Accepted for interface compatibility; not used by this class.
        """
        super().__init__()
        self.model = Alexnet(num_tasks=len(list__ncls), inputsize=inputsize, hat_enabled=hat_enabled,
                             nhid=nhid, drop1=drop1, drop2=drop2)

        self.smax = smax
        self.last_dim = nhid
    # enddef

    def freeze_masks(self, idx: int) -> None:
        """Call `freeze_masks(idx, smax)` on every HAT gate of the wrapped model.

        Does nothing when HAT is disabled, consistent with the other hooks:
        the wrapped model only creates its gates when `hat_enabled` is true.
        """
        if not self.model.hat_enabled:
            return
        # endif

        for hat in self.model.hats:
            hat.freeze_masks(idx, self.smax)
        # endfor
    # enddef

    def forward(self, idx_task: int, x: Tensor, s: float,
                args_on_forward: Dict[str, Any]) -> Tuple[Tensor, Dict[str, Any]]:
        """Validate the argument types and return the wrapped model's output.

        Returns:
            Whatever `self.model(...)` returns: a `(features, misc)` pair where
            `misc` carries the regularisation term under the key 'reg'.
        """
        assert_type(idx_task, int)
        assert_type(x, Tensor)

        return self.model(idx_task, x, s=s, args_on_forward=args_on_forward)
    # enddef

    def on_after_backward_emb(self, s: float) -> None:
        """Delegate to the wrapped model's embedding-gradient hook, adding `smax`."""
        self.model.on_after_backward_emb(s=s, smax=self.smax)
    # enddef

    def on_after_backward_params(self, idx_task: int, s: float,
                                 args: Dict[str, Any]) -> Optional[Dict[str, Tensor]]:
        """Delegate to the wrapped model's parameter-gradient hook, adding `smax`.

        Returns:
            The value returned by the wrapped model: a dict of per-parameter
            gradient multipliers, or None when HAT is disabled.
        """
        return self.model.on_after_backward_params(idx_task=idx_task, s=s, smax=self.smax, args=args)
    # enddef


# endclass

def compute_conv_output_size(Lin, kernel_size, stride=1, padding=0, dilation=1):
    """Return the output length of a convolution along one spatial dimension.

    Evaluates
    floor((Lin + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)
    and returns the result as a Python int.
    """
    padded_length = Lin + 2 * padding
    kernel_extent = dilation * (kernel_size - 1)
    strided_steps = (padded_length - kernel_extent - 1) / float(stride)
    return int(np.floor(strided_steps + 1))


class Alexnet(nn.Module):
    def __init__(self, num_tasks: int, inputsize: Tuple[int, ...], hat_enabled: bool,
                 nhid: int, drop1: float, drop2: float):
        """Create the three conv layers, two FC layers and (optionally) five HAT gates.

        Args:
            num_tasks: Passed to every HAT gate.
            inputsize: Only `inputsize[0]` (input channels) and `inputsize[1]`
                (spatial size) are read; kernels and `smid ** 2` are square.
            hat_enabled: When true, one HAT gate is created per layer.
            nhid: Output width of fc1 and fc2.
            drop1: Dropout probability of `self.drop1`.
            drop2: Dropout probability of `self.drop2`.
        """
        super().__init__()

        in_channels = inputsize[0]
        side = inputsize[1]
        kernel_c1 = side // 8
        kernel_c2 = side // 10
        kernel_c3 = 2

        self.c1 = nn.Conv2d(in_channels, 64, kernel_size=kernel_c1)
        self.c2 = nn.Conv2d(64, 128, kernel_size=kernel_c2)
        self.c3 = nn.Conv2d(128, 256, kernel_size=kernel_c3)

        # Each conv is followed by a 2x2 max-pool in forward(), hence the halving.
        spatial = side
        for kernel in (kernel_c1, kernel_c2, kernel_c3):
            spatial = compute_conv_output_size(spatial, kernel) // 2
        # endfor
        self.smid = spatial

        self.maxpool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.drop1 = nn.Dropout(drop1)
        self.drop2 = nn.Dropout(drop2)

        flattened = 256 * self.smid ** 2
        self.fc1 = nn.Linear(flattened, nhid, bias=use_bias)
        self.fc2 = nn.Linear(nhid, nhid, bias=use_bias)

        self.hat_enabled = hat_enabled
        if self.hat_enabled:
            # (attribute name, embedding width, module_name handed to HAT)
            gate_specs = (
                ('hat_c1', 64, 'c1'),
                ('hat_c2', 128, 'c2'),
                ('hat_c3', 256, 'c3'),
                ('hat_f1', nhid, 'fc1'),
                ('hat_f2', nhid, 'fc2'),
            )
            for attr_name, width, label in gate_specs:
                setattr(self, attr_name, HAT(num_tasks=num_tasks, dim_emb=width, module_name=label))
            # endfor
        # endif

        print(self)
    # enddef

    def forward(self, idx_task: int, x: Tensor, s: float,
                args_on_forward: Dict[str, Any]) -> Tuple[Tensor, Dict[str, Any]]:
        """Run the network for one task and compute the mask regularisation term.

        The input goes through three conv blocks (conv, ReLU, dropout, 2x2
        max-pool), is flattened, and then goes through two fully connected
        blocks (linear, ReLU, dropout). When HAT is enabled the matching gate
        is applied after every block.

        Args:
            idx_task: Task index; forwarded to every HAT gate.
            x: Input batch. NOTE(review): presumably shaped
                (batch, channels, size, size) to match `inputsize` - confirm.
            s: Scale value forwarded to the HAT gates and to `HAT.mask`.
            args_on_forward: Forwarded unchanged to every HAT gate.

        Returns:
            `(h, misc)` where `h` is the output of the second fully connected
            block and `misc` is `{'reg': reg}`. With HAT enabled, `reg` is
            `sum(m * (1 - mp)) / (1e-5 + sum(1 - mp))` over the gates when the
            first gate's `a_max()` is not None, otherwise
            `sum(m) / (1e-5 + number of mask elements)`. With HAT disabled it
            is a zero tensor on the input's device.
        """
        assert_type(idx_task, int)
        assert_type(x, Tensor)

        # Remembered so the HAT-disabled branch can create `reg` on the same device.
        self.device = x.device

        def gate(hat_name: str, activation: Tensor) -> Tensor:
            """Apply the named HAT gate, or pass through when HAT is disabled."""
            if not self.hat_enabled:
                return activation
            # endif
            return getattr(self, hat_name)(activation, idx_task, s=s, args_on_forward=args_on_forward)
        # enddef

        # Convolutional blocks (the third one uses drop2 instead of drop1).
        h = gate('hat_c1', self.maxpool(self.drop1(self.relu(self.c1(x)))))
        h = gate('hat_c2', self.maxpool(self.drop1(self.relu(self.c2(h)))))
        h = gate('hat_c3', self.maxpool(self.drop2(self.relu(self.c3(h)))))

        # Fully connected blocks on the flattened feature map.
        h = h.view(h.shape[0], -1)
        h = gate('hat_f1', self.drop2(self.relu(self.fc1(h))))
        h = gate('hat_f2', self.drop2(self.relu(self.fc2(h))))

        if self.hat_enabled:
            hats = self.hats
            masks = [hat.mask(idx_task, s) for hat in hats]
            masks_pre = [hat.a_max() for hat in hats]

            reg = 0
            count = 1e-5  # avoids division by zero when every weight is zero
            if masks_pre[0] is not None:
                # Only mask entries not already covered by `mp` contribute.
                for m, mp in zip(masks, masks_pre):
                    free = 1 - mp
                    reg += (m * free).sum()
                    count += free.sum()
                # endfor
            else:
                for m in masks:
                    reg += m.sum()
                    count += m.numel()
                # endfor
            # endif
            reg /= count
        else:
            reg = torch.tensor(0, device=self.device)
        # endif

        misc = {
            'reg': reg,
            }

        return h, misc
    # enddef

    @property
    def hats(self) -> List[HAT]:
        return [self.hat_c1, self.hat_c2, self.hat_c3, self.hat_f1, self.hat_f2]
    # enddef

    def on_after_backward_emb(self, s: float, smax: float):
        if not self.hat_enabled:
            return
        # endif

        # Compensate embedding gradients
        thres_cosh = 50
        for hat in self.hats:
            for n, p in hat.named_parameters():
                if 'weight' in n:
                    num = torch.cosh(torch.clamp(s * p.data, -thres_cosh, thres_cosh)) + 1
                    den = torch.cosh(p.data) + 1
                    if p.grad is not None:
                        p.grad.data *= smax / s * num / den
                    # endif
                else:
                    raise NotImplementedError(n)
                # endif
            # endfor
        # endfor

        torch.nn.utils.clip_grad_norm_(self.parameters(), 10000)
    # enddef

    def on_after_backward_params(self, idx_task: int, s: float, smax: float, args: Dict[str, Any]):
        if not self.hat_enabled:
            return
        # endif

        log = []
        log.append(f'[on_after_backward] In learning {idx_task}:')

        pre_c1, pre_c2, pre_c3 = self.hat_c1.a_max(), self.hat_c2.a_max(), self.hat_c3.a_max()
        pre_f1, pre_f2 = self.hat_f1.a_max(), self.hat_f2.a_max()

        blocking = {}
        if pre_f1 is not None:
            for n, p in self.fc1.named_parameters():
                if 'weight' in n:
                    post = pre_f1.data.view(-1, 1).expand_as(self.fc1.weight)
                    pre = pre_c3.data.view(-1, 1, 1) \
                        .expand((self.hat_c3.emb.weight.shape[1], self.smid, self.smid)) \
                        .contiguous().view(1, -1).expand_as(self.fc1.weight)
                    red = (1 - torch.min(post, pre))
                elif 'bias' in n:
                    post = pre_f1.data.view(-1)
                    red = (1 - post)
                else:
                    raise NotImplementedError(n)
                # endif

                p.grad.data *= red
                blocking[f'fc1.{n}'] = red.data.clone()

                num_0 = red[red == 0].numel()
                num_all = red.numel()
                log.append(f'[fc1/{n}] dead parameters: {num_0}/{num_all} ({num_0 / num_all * 100:.1f}%)')
            # endfor

            for n, p in self.fc2.named_parameters():
                if 'weight' in n:
                    post = pre_f2.data.view(-1, 1).expand_as(self.fc2.weight)
                    pre = pre_f1.data.view(1, -1).expand_as(self.fc2.weight)
                    red = (1 - torch.min(post, pre))
                elif 'bias' in n:
                    post = pre_f2.data.view(-1)
                    red = (1 - post)
                else:
                    raise NotImplementedError(n)
                # endif

                p.grad.data *= red
                blocking[f'fc2.{n}'] = red.data.clone()

                num_0 = red[red == 0].numel()
                num_all = red.numel()
                log.append(f'[fc2/{n}] dead parameters: {num_0}/{num_all} ({num_0 / num_all * 100:.1f}%)')
            # endfor

            for n, p in self.c1.named_parameters():
                if 'weight' in n:
                    post = pre_c1.data.view(-1, 1, 1, 1).expand_as(self.c1.weight)
                    red = (1 - post)
                elif 'bias' in n:
                    post = pre_c1.data.view(-1)
                    red = (1 - post)
                else:
                    raise NotImplementedError(n)
                # endif

                p.grad.data *= red
                blocking[f'c1.{n}'] = red.data.clone()

                num_0 = red[red == 0].numel()
                num_all = red.numel()
                log.append(f'[c1/{n}] dead parameters: {num_0}/{num_all} ({num_0 / num_all * 100:.1f}%)')
            # endfor

            for n, p in self.c2.named_parameters():
                if 'weight' in n:
                    post = pre_c2.data.view(-1, 1, 1, 1).expand_as(self.c2.weight)
                    pre = pre_c1.data.view(1, -1, 1, 1).expand_as(self.c2.weight)
                    red = (1 - torch.min(post, pre))
                elif 'bias' in n:
                    post = pre_c2.data.view(-1)
                    red = (1 - post)
                else:
                    raise NotImplementedError(n)
                # endif

                p.grad.data *= red
                blocking[f'c2.{n}'] = red.data.clone()

                num_0 = red[red == 0].numel()
                num_all = red.numel()
                log.append(f'[c2/{n}] dead parameters: {num_0}/{num_all} ({num_0 / num_all * 100:.1f}%)')
            # endfor

            for n, p in self.c3.named_parameters():
                if 'weight' in n:
                    post = pre_c3.data.view(-1, 1, 1, 1).expand_as(self.c3.weight)
                    pre = pre_c2.data.view(1, -1, 1, 1).expand_as(self.c3.weight)
                    red = (1 - torch.min(post, pre))
                elif 'bias' in n:
                    post = pre_c3.data.view(-1)
                    red = (1 - post)
                else:
                    raise NotImplementedError(n)
                # endif

                p.grad.data *= red
                blocking[f'c3.{n}'] = red.data.clone()

                num_0 = red[red == 0].numel()
                num_all = red.numel()
                log.append(f'[c3/{n}] dead parameters: {num_0}/{num_all} ({num_0 / num_all * 100:.1f}%)')
            # endfor
        # endif

        # log writing
        if s == 1 / smax and args['epoch'] == 0:
            log = '\n'.join(log)
            print(log)
        # endif

        return blocking
    # enddef

    def compute_param_consumed(self, idx_task: int) -> float:
        # if idx_task == 0:
        #    return 0
        # endif

        pre_c1, pre_c2, pre_c3 = self.hat_c1.a_max(), self.hat_c2.a_max(), self.hat_c3.a_max()
        pre_f1, pre_f2 = self.hat_f1.a_max(), self.hat_f2.a_max()

        num_all = 0
        num_blocked = 0

        if pre_f1 is not None:
            for n, p in self.fc1.named_parameters():
                if 'weight' in n:
                    post = pre_f1.data.view(-1, 1).expand_as(self.fc1.weight)
                    pre = pre_c3.data.view(-1, 1, 1) \
                        .expand((self.hat_c3.emb.weight.shape[1], self.smid, self.smid)) \
                        .contiguous().view(1, -1).expand_as(self.fc1.weight)
                    red = (1 - torch.min(post, pre))
                elif 'bias' in n:
                    post = pre_f1.data.view(-1)
                    red = (1 - post)
                else:
                    raise NotImplementedError(n)
                # endif

                num_all += red.numel()
                num_blocked += red[red == 0].numel()
            # endfor

            for n, p in self.fc2.named_parameters():
                if 'weight' in n:
                    post = pre_f2.data.view(-1, 1).expand_as(self.fc2.weight)
                    pre = pre_f1.data.view(1, -1).expand_as(self.fc2.weight)
                    red = (1 - torch.min(post, pre))
                elif 'bias' in n:
                    post = pre_f2.data.view(-1)
                    red = (1 - post)
                else:
                    raise NotImplementedError(n)
                # endif

                num_all += red.numel()
                num_blocked += red[red == 0].numel()
            # endfor

            for n, p in self.c1.named_parameters():
                if 'weight' in n:
                    post = pre_c1.data.view(-1, 1, 1, 1).expand_as(self.c1.weight)
                    red = (1 - post)
                elif 'bias' in n:
                    post = pre_c1.data.view(-1)
                    red = (1 - post)
                else:
                    raise NotImplementedError(n)
                # endif

                num_all += red.numel()
                num_blocked += red[red == 0].numel()
            # endfor

            for n, p in self.c2.named_parameters():
                if 'weight' in n:
                    post = pre_c2.data.view(-1, 1, 1, 1).expand_as(self.c2.weight)
                    pre = pre_c1.data.view(1, -1, 1, 1).expand_as(self.c2.weight)
                    red = (1 - torch.min(post, pre))
                elif 'bias' in n:
                    post = pre_c2.data.view(-1)
                    red = (1 - post)
                else:
                    raise NotImplementedError(n)
                # endif

                num_all += red.numel()
                num_blocked += red[red == 0].numel()
            # endfor

            for n, p in self.c3.named_parameters():
                if 'weight' in n:
                    post = pre_c3.data.view(-1, 1, 1, 1).expand_as(self.c3.weight)
                    pre = pre_c2.data.view(1, -1, 1, 1).expand_as(self.c3.weight)
                    red = (1 - torch.min(post, pre))
                elif 'bias' in n:
                    post = pre_c3.data.view(-1)
                    red = (1 - post)
                else:
                    raise NotImplementedError(n)
                # endif

                num_all += red.numel()
                num_blocked += red[red == 0].numel()
            # endfor
        # endif

        return num_blocked / num_all
    # enddef

# endclass
