from typing import *

import numpy as np
from torch import Tensor, nn

from approaches.spgfi.spgfi import SPGFI
from utils import assert_type

# Module-wide switch passed as ``bias=`` to the nn.Linear layers built in this file.
use_bias = True


class ModelSPGFIFeatureAlexNet(nn.Module):
    """Thin wrapper that owns an ``Alexnet`` instance and forwards every call to it.

    The wrapped network is stored in ``self.model``; ``self.last_dim`` records
    ``nhid`` (the output width of the wrapped network's last linear layer).
    """

    def __init__(self, list__ncls: List[int], inputsize: Tuple[int, ...], batch_size: int,
                 nhid: int, drop1: float, drop2: float):
        """Build the wrapped ``Alexnet``.

        Args:
            list__ncls: Not referenced by this class; kept so existing callers
                can keep passing it.
            inputsize: Forwarded unchanged to ``Alexnet``.
            batch_size: Forwarded unchanged to ``Alexnet``.
            nhid: Hidden width; forwarded to ``Alexnet`` and stored as ``last_dim``.
            drop1: Forwarded unchanged to ``Alexnet``.
            drop2: Forwarded unchanged to ``Alexnet``.
        """
        super().__init__()
        self.model = Alexnet(inputsize=inputsize, batch_size=batch_size,
                             nhid=nhid, drop1=drop1, drop2=drop2)

        self.last_dim = nhid
    # enddef

    def set_fc(self, fc: nn.Module):
        """Attach ``fc`` to the wrapped network as its ``fc`` attribute."""
        self.model.fc = fc
    # enddef

    def freeze_ewc_masks(self, idx_task: int, fisher: Dict[str, Tensor], **kwargs):
        """Forward ``idx_task``, ``fisher`` and any extra keyword arguments to the wrapped network."""
        self.model.freeze_ewc_masks(idx_task=idx_task, fisher=fisher, **kwargs)
    # enddef

    def forward(self, x: Tensor, idx_task: int, s: float, args_on_forward: Dict[str, Any]) -> Tuple[Tensor, Dict]:
        """Type-check the inputs and run the wrapped network.

        Args:
            x: Input tensor (checked with ``assert_type``).
            idx_task: Task index (checked with ``assert_type`` to be an ``int``).
            s: Forwarded to the wrapped network as ``s=``.
            args_on_forward: Forwarded to the wrapped network unchanged.

        Returns:
            The ``(output, misc)`` pair produced by the wrapped network.
        """
        assert_type(idx_task, int)
        assert_type(x, Tensor)

        ret, misc = self.model(x, idx_task, s=s, args_on_forward=args_on_forward)

        return ret, misc
    # enddef

    def on_after_backward_params(self, idx_task: int, s: float, args: Dict[str, Any]) -> None:
        """Invoke the wrapped network's post-backward hook and pass its result through.

        ``smax`` is supplied as the ``NotImplemented`` placeholder because this
        wrapper has no value for it; ``Alexnet.on_after_backward_params`` does
        not read that argument.
        """
        return self.model.on_after_backward_params(idx_task=idx_task, s=s, smax=NotImplemented, args=args)
    # enddef


# endclass

def compute_conv_output_size(Lin, kernel_size, stride=1, padding=0, dilation=1):
    """Return the output length of a convolution along a single dimension.

    Evaluates ``floor((Lin + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``
    and converts the result to ``int``.

    Args:
        Lin: Input length along the dimension.
        kernel_size: Kernel extent along the dimension.
        stride: Step between kernel applications.
        padding: Padding added to each side of the input.
        dilation: Spacing between kernel elements.
    """
    span = Lin + 2 * padding - dilation * (kernel_size - 1) - 1
    steps = span / float(stride)
    return int(np.floor(steps + 1))


class Flattener(nn.Module):
    """Parameter-free module that collapses every dimension after the first."""

    def __init__(self):
        super(Flattener, self).__init__()
    # enddef

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` viewed as ``(x.shape[0], -1)``."""
        leading = x.shape[0]
        flat = x.view(leading, -1)
        return flat
    # enddef


class Alexnet(nn.Module):
    """AlexNet-style network whose conv and linear layers are each wrapped in ``SPGFI``.

    Layout: three convolution stages (each conv -> ReLU -> dropout -> 2x2
    max-pool) followed by two linear stages (each linear -> ReLU -> dropout).
    """

    def __init__(self, inputsize: Tuple[int, ...], batch_size: int,
                 nhid: int, drop1: float, drop2: float):
        """Create the layers.

        Args:
            inputsize: ``inputsize[0]`` is used as the number of input channels
                and ``inputsize[1]`` as the spatial side length.
            batch_size: Stored on the instance; not otherwise used in this class.
            nhid: Width of both linear layers' outputs.
            drop1: Dropout probability applied after the first two conv stages.
            drop2: Dropout probability applied after the third conv stage and
                after each linear stage.
        """
        super().__init__()

        nch, size = inputsize[0], inputsize[1]
        self.batch_size = batch_size

        # Kernel sizes of the first two convolutions are derived from the input
        # side length; the third is fixed.
        kernel1, kernel2, kernel3 = size // 8, size // 10, 2

        # After every conv the side length is halved to mirror the 2x2 max-pool
        # applied in forward().
        self.c1 = SPGFI(nn.Conv2d(nch, 64, kernel_size=kernel1), target_name='c1')
        side = compute_conv_output_size(size, kernel1) // 2

        self.c2 = SPGFI(nn.Conv2d(64, 128, kernel_size=kernel2), target_name='c2')
        side = compute_conv_output_size(side, kernel2) // 2

        self.c3 = SPGFI(nn.Conv2d(128, 256, kernel_size=kernel3), target_name='c3')
        side = compute_conv_output_size(side, kernel3) // 2

        # Spatial side length of the feature map entering the first linear layer.
        self.smid = side
        self.maxpool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.drop1 = nn.Dropout(drop1)
        self.drop2 = nn.Dropout(drop2)

        self.fc1 = SPGFI(nn.Linear(256 * self.smid ** 2, nhid, bias=use_bias), target_name='fc1')
        self.fc2 = SPGFI(nn.Linear(nhid, nhid, bias=use_bias), target_name='fc2')
    # enddef

    def _spgfi_modules(self) -> Iterator[Tuple[str, nn.Module]]:
        """Yield ``(name, module)`` for every ``SPGFI`` sub-module, in ``named_modules`` order."""
        for name, module in self.named_modules():
            if isinstance(module, SPGFI):
                yield name, module
            # endif
        # endfor
    # enddef

    def forward(self, x: Tensor, idx_task: int, s: float, args_on_forward: Dict[str, Any]) -> Tuple[Tensor, Dict]:
        """Run the network on ``x``.

        Args:
            x: Input tensor (checked with ``assert_type``).
            idx_task: Task index (checked with ``assert_type`` to be an ``int``);
                not otherwise used here.
            s: Accepted for interface compatibility; not used here.
            args_on_forward: Accepted for interface compatibility; not used here.

        Returns:
            ``(h, misc)`` where ``h`` is the output of the second linear stage
            and ``misc`` is ``{'reg': 0}``.
        """
        assert_type(idx_task, int)
        assert_type(x, Tensor)

        # Remember the device of the most recent input.
        self.device = x.device

        # 1st conv stage
        h = self.maxpool(self.drop1(self.relu(self.c1(x))))

        # 2nd conv stage
        h = self.maxpool(self.drop1(self.relu(self.c2(h))))

        # 3rd conv stage
        h = self.maxpool(self.drop2(self.relu(self.c3(h))))

        # Flatten everything after the batch dimension, then the linear stages.
        h = h.view(h.shape[0], -1)
        h = self.drop2(self.relu(self.fc1(h)))
        h = self.drop2(self.relu(self.fc2(h)))

        misc = {
            'reg': 0
            }

        return h, misc
    # enddef

    def freeze_ewc_masks(self, idx_task: int, fisher: Dict[str, Tensor], **kwargs):
        """Register per-layer Fisher values on every ``SPGFI`` module, then freeze their masks.

        Args:
            idx_task: Task index passed to ``register_fisher`` (as both
                ``idx_task`` and ``t``) and to ``freeze_masks``.
            fisher: Mapping keyed by
                ``'feature.model.<target_name>.target_module.<param name>'``;
                it must contain an entry for every parameter of every wrapped
                layer.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If ``fisher`` lacks an expected key.
        """
        # Phase 1: hand each SPGFI module the Fisher entries of its own parameters.
        for _, module in self._spgfi_modules():
            subfisher = {}
            for nparam, _ in module.target_module.named_parameters():
                key = f'feature.model.{module.target_name}.target_module.{nparam}'

                # Direct lookup instead of scanning every item of ``fisher``.
                assert key in fisher, f'missing Fisher entry: {key}'
                subfisher[nparam] = fisher[key]
            # endfor

            module.register_fisher(idx_task=idx_task, t=idx_task, h=subfisher)
        # endfor

        # Phase 2: freeze masks only after every module has been registered.
        for _, module in self._spgfi_modules():
            module.freeze_masks(idx_task)
        # endfor
    # enddef

    def on_after_backward_params(self, idx_task: int, s: float, smax: float, args: Dict[str, Any]):
        """Call ``block`` on every ``SPGFI`` module; a no-op for the first task.

        Args:
            idx_task: Task index; when ``0`` the method returns immediately
                without reading ``args``.
            s: Accepted for interface compatibility; not used here.
            smax: Accepted for interface compatibility; not used here.
            args: Must contain ``'epoch'`` and ``'idx_batch'`` when
                ``idx_task != 0``.
        """
        if idx_task == 0:
            return
        # endif

        # ``show`` is true only on the first batch of the first epoch.
        show = args['epoch'] == 0 and args['idx_batch'] == 0

        for name, module in self._spgfi_modules():
            module.block(idx_task, name=name, show=show)
        # endfor
    # enddef

# endclass
