import time
from typing import *

import torch
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset

from utils import print_num_params
from approaches.abst_appr import AbstractAppr
from approaches.mtl.model_mtl import ModelMTL


class Appr(AbstractAppr):
    """Joint multi-task approach that trains one ``ModelMTL`` on all tasks at once.

    ``train`` only stores the loaders of each task until the last task
    arrives.  It then merges every stored loader into a single dataset, with
    the task index appended to every sample as one extra channel, and
    delegates the optimisation to the parent class' ``train``.  ``test`` tags
    the test samples the same way before delegating to the parent ``test``.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int, backbone: str, batch_size: int,
                 nhid: int, drop1: float, drop2: float, small_lr: bool):
        """Build the model and the per-task loader slots.

        Args:
            device: Forwarded to the parent constructor.
            list__ncls: Number of classes of each task; its length is the
                number of tasks.
            inputsize: Forwarded to the parent constructor and to ``ModelMTL``.
            lr: Base learning rate.  It is divided by 100 when ``small_lr`` is
                true and by 10 otherwise before being handed to the parent.
            lr_factor: Forwarded to the parent constructor.
            lr_min: Forwarded to the parent constructor.
            epochs_max: Forwarded to the parent constructor.
            patience_max: Forwarded to the parent constructor.
            backbone: Forwarded to ``ModelMTL``.
            batch_size: Batch size of the rebuilt validation/test loaders; the
                merged training loader uses ten times this value.
            nhid: Forwarded to ``ModelMTL``.
            drop1: Forwarded to ``ModelMTL``.
            drop2: Forwarded to ``ModelMTL``.
            small_lr: Selects the learning-rate divisor (see ``lr``).
        """
        lr = lr / 100 if small_lr else lr / 10

        # The parent is always configured with smax=-1 and lamb=0 here, so the
        # ``self.lamb * reg`` term of ``compute_loss`` is scaled by 0
        # (presumably ``lamb`` is stored as ``self.lamb`` by the parent -- confirm).
        super().__init__(device=device, list__ncls=list__ncls, inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=-1, lamb=0)
        self.model = ModelMTL(list__ncls=list__ncls, inputsize=inputsize,
                              backbone=backbone, nhid=nhid, drop1=drop1, drop2=drop2).to(self.device)

        print_num_params(self.model)

        self.num_tasks = len(list__ncls)
        self.batch_size = batch_size
        # One slot per task; ``NotImplemented`` marks a task whose loaders have
        # not been handed to ``train`` yet.
        self.list__dl_train = [NotImplemented] * self.num_tasks
        self.list__dl_val = [NotImplemented] * self.num_tasks
    # enddef

    @staticmethod
    def _append_task_channel(x: Tensor, idx_task: int) -> Tensor:
        """Return ``x`` with one extra channel (dim 1) filled with ``idx_task``.

        ``x`` is indexed with four subscripts, so it needs at least four
        dimensions (presumably ``(batch, channel, height, width)`` -- confirm
        against the data pipeline).  The extra channel has the dtype and device
        of ``x``.
        """
        # Only the single channel that is needed is allocated.
        t = torch.full_like(x[:, :1, :, :], idx_task)
        return torch.cat([x, t], dim=1)
    # enddef

    @staticmethod
    def _gather_tagged_batches(dl: Iterable, idx_task: int) -> Tuple[List[Tensor], List[Tensor]]:
        """Iterate ``dl`` once; return its task-tagged inputs and its targets."""
        list__xt = []
        list__y = []
        for x, y in dl:
            list__xt.append(Appr._append_task_channel(x, idx_task))
            list__y.append(y)
        # endfor

        return list__xt, list__y
    # enddef

    @staticmethod
    def _make_loader(list__xt: List[Tensor], list__y: List[Tensor],
                     batch_size: int, shuffle: bool = False) -> DataLoader:
        """Concatenate the collected batches along dim 0 and wrap them in a loader."""
        dataset = TensorDataset(torch.cat(list__xt, dim=0), torch.cat(list__y, dim=0))
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    # enddef

    def compute_loss(self, output: Tensor, target: Tensor, misc: Dict[str, Any]) -> Tensor:
        """Sum ``self.criterion`` over the tasks present in the batch.

        For every distinct task index ``t`` in ``misc['task_indices']`` the
        samples of that task are selected, the output is restricted to the
        first ``self.list__ncls[t]`` columns, and the criterion of that
        sub-batch is added to the total.

        Args:
            output: Model output; indexed as ``output[sample, class]``.
            target: Targets, indexed per sample.
            misc: Must contain ``'reg'`` (regularisation term) and
                ``'task_indices'`` (task index of every sample).

        Returns:
            The summed per-task criterion plus ``self.lamb * reg``.

        Raises:
            AssertionError: If a task index is not integer-valued or a target
                is not smaller than the class count of its task.
        """
        reg = misc['reg']
        task_indices = misc['task_indices']  # type: Tensor

        loss = 0
        # torch.unique returns sorted values, so the per-task terms are always
        # accumulated in ascending task order (reproducible summation).
        for element_t in torch.unique(task_indices.detach()).tolist():
            assert int(element_t) == element_t, \
                'task index {!r} is not integer-valued'.format(element_t)
            ncls = self.list__ncls[int(element_t)]

            indices = torch.where(task_indices == element_t)[0]
            output_t = output[indices, :ncls]
            target_t = target[indices]
            assert torch.all(target_t < ncls), \
                'target out of range for task {} with {} classes'.format(int(element_t), ncls)

            # Out-of-place add: never mutates a tensor that is part of the graph.
            loss = loss + self.criterion(output_t, target_t)
        # endfor

        return loss + self.lamb * reg
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              **kwargs) -> Dict[str, Any]:
        """Store the loaders of ``idx_task``; train jointly once the last task arrives.

        Args:
            idx_task: Index of the task whose loaders are being supplied.
            dl_train: Training loader of that task.
            dl_val: Validation loader of that task.
            args_on_forward: Forwarded unchanged to the parent ``train``.
            args_on_after_backward: Forwarded unchanged to the parent ``train``.
            **kwargs: Accepted but not used.

        Returns:
            ``{'time_consumed': seconds}``.  The value is ``0`` for every task
            before the last one (nothing is trained then); for the last task
            it is the elapsed time of the parent ``train`` call.

        Raises:
            IndexError: If ``idx_task`` is beyond the loader slots.
            ValueError: If ``idx_task`` is greater than the last task index.
        """
        self.list__dl_train[idx_task] = dl_train
        self.list__dl_val[idx_task] = dl_val

        idx_last = len(self.list__ncls) - 1
        if idx_task < idx_last:
            # Not all tasks are available yet: just remember the loaders.
            return {'time_consumed': 0}
        # endif
        if idx_task > idx_last:
            raise ValueError('idx_task={} is beyond the last task index {}'.format(idx_task, idx_last))
        # endif

        # Last task: merge every stored loader into one tagged dataset.
        list__xt_train = []  # type: List[Tensor]
        list__y_train = []  # type: List[Tensor]
        list__xt_val = []  # type: List[Tensor]
        list__y_val = []  # type: List[Tensor]
        for i in range(idx_task + 1):
            xs, ys = self._gather_tagged_batches(self.list__dl_train[i], i)
            list__xt_train.extend(xs)
            list__y_train.extend(ys)

            xs, ys = self._gather_tagged_batches(self.list__dl_val[i], i)
            list__xt_val.extend(xs)
            list__y_val.extend(ys)
        # endfor

        # The merged training loader uses a 10x larger batch and is shuffled.
        dl_train = self._make_loader(list__xt_train, list__y_train,
                                     batch_size=self.batch_size * 10, shuffle=True)
        dl_val = self._make_loader(list__xt_val, list__y_val,
                                   batch_size=self.batch_size)

        # perf_counter is monotonic, so the duration cannot be skewed by
        # wall-clock adjustments.
        time_start = time.perf_counter()
        super().train(idx_task=-1, dl_train=dl_train, dl_val=dl_val,
                      args_on_forward=args_on_forward, args_on_after_backward=args_on_after_backward)
        time_consumed = time.perf_counter() - time_start

        return {'time_consumed': time_consumed}
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Tag the test samples with ``idx_task`` and delegate to the parent ``test``.

        Args:
            idx_task: Task index written into the extra channel of every sample.
            dl_test: Test loader of that task; it is iterated once and rebuilt
                with ``self.batch_size``.
            args_on_forward: Forwarded unchanged to the parent ``test``.

        Returns:
            Whatever the parent ``test`` returns for the tagged loader (it is
            called with ``idx_task=-1``).
        """
        list__xt_test, list__y_test = self._gather_tagged_batches(dl_test, idx_task)
        dl_test = self._make_loader(list__xt_test, list__y_test, batch_size=self.batch_size)

        return super().test(idx_task=-1, dl_test=dl_test, args_on_forward=args_on_forward)
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """Do nothing; this approach has no per-task post-processing."""
        pass
    # enddef
# endclass
