from typing import *

import torch
from torch import Tensor, nn


class ModelSTL(nn.Module):
    """Shared MLP feature extractor followed by one linear head per task.

    Each sample carries its own task id in the last channel of the input
    tensor (read at spatial position ``[0, 0]``). ``forward`` routes every
    sample through the head of its task and zero-pads the logits up to the
    widest head, so samples from different tasks can share one batch.
    """

    def __init__(self, list__ncls: List[int], inputsize: Tuple[int, ...],
                 drop1: float, drop2: float):
        """Build the shared feature extractor and the per-task heads.

        Args:
            list__ncls: number of output classes per task; head ``i`` has
                ``list__ncls[i]`` outputs.
            inputsize: ``(ch, w, h)`` of the data part of the input. The extra
                task-id channel that ``forward`` expects is NOT counted here.
            drop1: dropout probability applied to the flattened input.
            drop2: dropout probability applied after each hidden layer.
        """
        super().__init__()
        self.ch, self.w, self.h = inputsize
        self.list__ncls = list__ncls

        self.feature = nn.Sequential(
            nn.Dropout(drop1),
            nn.Linear(self.ch * self.w * self.h, 2048),
            nn.ReLU(),
            nn.Dropout(drop2),
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Dropout(drop2),
        )
        # One linear head per task, created in task order (head i <-> task i).
        self.list__fc = nn.ModuleList(nn.Linear(2048, ncls) for ncls in list__ncls)
    # enddef

    def forward(self, idx_task: int, tx: Tensor, s: float,
                args_on_forward: Dict[str, Any]) -> Tuple[Tensor, Dict[str, Any]]:
        """Run every sample through the head of its own task.

        Args:
            idx_task: unused; the head is chosen per sample from ``tx``.
            tx: ``(bs, ch + 1, w, h)`` tensor. The first ``ch`` channels are
                the model input; the last channel holds the sample's task id
                at position ``[0, 0]``.
            s: unused by this model.
            args_on_forward: unused by this model.

        Returns:
            ``(out, {'reg': 0})``. ``out`` has shape
            ``(bs, max(list__ncls))``; row ``i`` holds the logits of head
            ``task_i`` in its first ``list__ncls[task_i]`` columns and zeros in
            the remaining columns.

        Raises:
            ValueError: if a task id is not an integral value.
            IndexError: if a task id is outside ``[0, len(list__ncls))``.
        """
        bs, _, _, _ = tx.shape
        task_ids = tx[:, -1, 0, 0]  # type: Tensor
        x = tx[:, :-1, :, :]  # type: Tensor
        assert task_ids.shape == (bs,)
        assert x.shape == (bs, self.ch, self.w, self.h)

        # reshape, not view: the channel slice above cannot be viewed as
        # (bs, -1) when tx is non-contiguous (e.g. channels_last).
        features = self.feature(x.reshape(bs, -1))

        # Same dtype/device as the features, so no silent downcast.
        out = features.new_zeros(bs, max(self.list__ncls))

        # Group the batch by task id and apply each task's head to its rows.
        for task_value in torch.unique(task_ids.detach()).tolist():
            if not float(task_value).is_integer():
                raise ValueError(f'task id must be integral, got {task_value}')
            task = int(task_value)
            # Reject negatives explicitly: ModuleList would wrap them around.
            if not 0 <= task < len(self.list__fc):
                raise IndexError(
                    f'task id {task} out of range for {len(self.list__fc)} heads')
            rows = torch.nonzero(task_ids == task_value, as_tuple=True)[0]
            ncls = self.list__ncls[task]
            # Row groups are disjoint, so plain assignment into the zero
            # buffer is equivalent to accumulation.
            out[rows, :ncls] = self.list__fc[task](features[rows])
        # endfor

        return out, {'reg': 0}
    # enddef

    def on_after_backward(self, idx_task: int, s: float, args: Dict[str, Any]):
        """Post-backward hook; this model has nothing to do here (no-op)."""
        pass
    # enddef
