from typing import *

import numpy as np
import torch
from torch import Tensor, nn


class ModelMTL(nn.Module):
    """Multi-task model: a shared feature extractor plus one linear head per task.

    Every input sample carries its task id in an extra trailing channel
    (``tx[:, -1, 0, 0]``). ``forward`` runs the shared backbone on the remaining
    channels and routes each sample's features through the head of its own task.

    Args:
        list__ncls: number of output units for each task; head ``i`` is
            ``nn.Linear(nhid, list__ncls[i])``.
        inputsize: ``(channels, width, height)`` of one sample, excluding the
            task-id channel.
        backbone: ``'mlp'`` or ``'alexnet'``; any other value raises
            ``NotImplementedError``.
        nhid: width of the hidden layers and of the feature vector fed to the heads.
        drop1: dropout probability of the first dropout module.
        drop2: dropout probability of the second dropout module.
    """

    def __init__(self, list__ncls: List[int], inputsize: Tuple[int, ...], backbone: str,
                 nhid: int, drop1: float, drop2: float):
        super().__init__()
        self.ch, self.w, self.h = inputsize
        self.list__ncls = list__ncls

        # ReLU and Dropout hold no learnable parameters, so one instance of each is
        # reused at several positions of the Sequential below.
        relu = nn.ReLU()
        drop1 = nn.Dropout(drop1)
        drop2 = nn.Dropout(drop2)

        if backbone == 'mlp':
            fc1 = nn.Linear(inputsize[0] * inputsize[1] * inputsize[2], nhid)
            fc2 = nn.Linear(nhid, nhid)

            self.feature = nn.Sequential(
                Flattener(),
                drop1,
                fc1,
                relu,
                drop2,
                fc2,
                relu,
                drop2,
                )
        elif backbone == 'alexnet':
            maxpool = nn.MaxPool2d(2)
            # Kernel sizes derive from the width only and fc1 uses s ** 2, so this
            # branch presumably expects square inputs (w == h).
            # NOTE(review): widths below 10 give c2 a kernel size of 0 — confirm
            # such inputs are not meant to be supported.
            size = self.w

            # `s` tracks the spatial size through each conv -> maxpool(2) stage so
            # that fc1 can be sized for the flattened feature map.
            c1 = nn.Conv2d(self.ch, 64, kernel_size=size // 8)
            s = compute_conv_output_size(size, size // 8)
            s = s // 2

            c2 = nn.Conv2d(64, 128, kernel_size=size // 10)
            s = compute_conv_output_size(s, size // 10)
            s = s // 2

            c3 = nn.Conv2d(128, 256, kernel_size=2)
            s = compute_conv_output_size(s, 2)
            s = s // 2

            fc1 = nn.Linear(256 * s ** 2, nhid)
            fc2 = nn.Linear(nhid, nhid)

            self.feature = nn.Sequential(
                c1,
                relu,
                drop1,
                maxpool,
                c2,
                relu,
                drop1,
                maxpool,
                c3,
                relu,
                drop2,
                maxpool,
                Flattener(),
                fc1,
                relu,
                drop2,
                fc2,
                relu,
                drop2,
                )
        else:
            raise NotImplementedError(backbone)
        # endif

        # One classification head per task, in the order of list__ncls.
        self.list__fc = nn.ModuleList([nn.Linear(nhid, ncls) for ncls in list__ncls])
    # enddef

    def forward(self, idx_task: int, tx: Tensor, s: float,
                args_on_forward: Dict[str, Any]) -> Tuple[Tensor, Dict[str, Any]]:
        """Run the shared backbone, then apply each sample's own task head.

        Args:
            idx_task: not used by this model.
            tx: input of shape ``(B, ch + 1, w, h)``. Channels ``[:-1]`` are the
                sample; ``tx[:, -1, 0, 0]`` is its (integer-valued) task id.
            s: not used by this model.
            args_on_forward: not used by this model.

        Returns:
            ``(out, misc)``.
            - ``out`` has shape ``(B, max(list__ncls))``. Row ``i`` holds the output
              of head ``t_i`` in its first ``list__ncls[t_i]`` columns and zeros elsewhere.
            - ``misc`` has ``'reg'`` (always 0) and ``'task_indices'`` (a detached copy
              of the per-sample task ids).

        Raises:
            AssertionError: on a shape mismatch, or a non-integer or out-of-range task id.
        """
        bs, _, _, _ = tx.shape
        t = tx[:, -1, 0, 0]  # type: Tensor
        x = tx[:, :-1, :, :]  # type: Tensor
        assert t.shape == (bs,)
        assert x.shape == (bs, self.ch, self.w, self.h)

        x = self.feature(x)

        # Allocate with the features' dtype and device: a hard-coded float32 buffer
        # silently changed the output dtype of a double/half-precision model.
        out = x.new_zeros(bs, max(self.list__ncls))
        for element_t in torch.unique(t.detach()).tolist():
            assert int(element_t) == element_t, f'non-integer task id: {element_t}'
            id_task = int(element_t)
            # Without this check a negative id would silently pick a head counted
            # from the end of list__fc.
            assert 0 <= id_task < len(self.list__fc), f'task id out of range: {id_task}'

            indices = torch.where(t == element_t)[0]
            ncls = self.list__ncls[id_task]
            out[indices, :ncls] += self.list__fc[id_task](x[indices])
        # endfor

        misc = {
            'reg': 0,
            'task_indices': t.detach().clone(),
            }

        return out, misc
    # enddef

    def on_after_backward_emb(self, **kwargs):
        """No-op hook; this model takes no action here (presumably called by the
        training loop after backward — confirm against the caller)."""
        pass
    # enddef

    def on_after_backward_params(self, idx_task: int, **kwargs):
        """No-op hook; this model takes no action here (presumably called by the
        training loop after backward — confirm against the caller)."""
        pass
    # enddef


def compute_conv_output_size(Lin, kernel_size, stride=1, padding=0, dilation=1):
    """Return the output length of a convolution/pooling along one spatial axis.

    Evaluates ``floor((Lin + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``.

    Args:
        Lin: input length along the axis.
        kernel_size: kernel extent along the axis.
        stride: step between kernel applications.
        padding: zero padding added on each side.
        dilation: spacing between kernel elements.

    Returns:
        The output length as a Python ``int``.
    """
    span = Lin + 2 * padding - dilation * (kernel_size - 1) - 1
    steps = span / float(stride) + 1
    return int(np.floor(steps))


class Flattener(nn.Module):
    """Flatten all dimensions after the batch dimension: ``(B, d1, d2, ...) -> (B, d1*d2*...)``."""

    def forward(self, x: Tensor) -> Tensor:
        if x.dim() < 2:
            # 1-D input: keep the previous behaviour of appending a singleton dim.
            return x.reshape(x.shape[0], -1)
        # flatten() copies only when a view is impossible, so it works on
        # non-contiguous input (e.g. permuted or channels_last tensors), where
        # view() raises. It also works on an empty batch, where view(0, -1) is ambiguous.
        return x.flatten(1)
