import time
from typing import *

import torch
from torch.utils.data import DataLoader, TensorDataset

from approaches.abst_appr import AbstractAppr
from approaches.stl.model_stl import ModelSTL


class Appr(AbstractAppr):
    """Joint-training baseline: buffer every task's data, then train once on all of it.

    ``train`` only stores the per-task loaders until the last task (index
    ``len(list__ncls) - 1``) arrives. It then rebuilds one train and one val loader from all
    buffered tasks and runs the base-class training loop on them. Every input batch gets an
    extra constant channel holding its task index, presumably so that ``ModelSTL`` can tell
    tasks apart.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int,
                 drop1: float, drop2: float):
        # smax / lamb are fixed to placeholder values; presumably unused by this approach — confirm
        # against AbstractAppr.
        super().__init__(device=device, list__ncls=list__ncls, inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=-1, lamb=0)
        # NOTE(review): inputs reaching the model have one more channel than `inputsize` describes
        # (the task-id channel); presumably ModelSTL accounts for this — verify.
        self.model = ModelSTL(list__ncls=list__ncls, inputsize=inputsize,
                              drop1=drop1, drop2=drop2).to(self.device)

        # Loaders received so far, one per call to train(); the list position is used as task id.
        self.list__dl_train = []
        self.list__dl_val = []
    # enddef

    @staticmethod
    def _collect_with_task_channel(dl: DataLoader, idx_task: int,
                                   list__xt: List[torch.Tensor],
                                   list__y: List[torch.Tensor]) -> Optional[int]:
        """Append every batch of ``dl`` to the output lists, with a task-id channel added.

        Each input batch ``x`` must be 4-D (the unpack below enforces it; assumed NCHW). It is
        concatenated along dim 1 with a constant channel filled with ``idx_task``.

        Returns the size of the first batch seen, or ``None`` if ``dl`` yielded nothing.
        """
        bs_first = None
        for x, y in dl:
            bs, ch, h, w = x.shape
            if bs_first is None:
                bs_first = bs
            # endif

            # Create the channel on x's device so torch.cat also works for non-CPU batches;
            # dtype stays float32 as in the original implementation.
            t = torch.full((bs, 1, h, w), float(idx_task), dtype=torch.float, device=x.device)
            xt = torch.cat([x, t], dim=1)
            assert xt.shape == (bs, ch + 1, h, w)

            list__xt.append(xt)
            list__y.append(y)
        # endfor
        return bs_first
    # enddef

    @staticmethod
    def _build_loader(list__xt: List[torch.Tensor], list__y: List[torch.Tensor],
                      batch_size: Optional[int], split: str) -> DataLoader:
        """Concatenate collected batches into a single, unshuffled DataLoader.

        Raises:
            ValueError: if no batches were collected (torch.cat would otherwise fail obscurely).
        """
        if not list__xt:
            raise ValueError(f'no {split} batches were collected; cannot build a {split} loader')
        # endif
        dataset = TensorDataset(torch.cat(list__xt, dim=0), torch.cat(list__y, dim=0))
        return DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              ) -> float:
        """Buffer this task's loaders; on the last task, train jointly on all buffered tasks.

        Returns:
            0.0 for every task before the last one (nothing is trained yet). For the last task,
            the elapsed seconds spent in the base-class training call.

        Raises:
            ValueError: if ``idx_task`` is greater than the last task index.
        """
        self.list__dl_train.append(dl_train)
        self.list__dl_val.append(dl_val)

        idx_last = len(self.list__ncls) - 1
        if idx_task < idx_last:
            return 0.0
        # endif
        if idx_task > idx_last:
            raise ValueError(f'idx_task={idx_task} exceeds the last task index {idx_last}')
        # endif

        list__xt_train, list__y_train = [], []
        list__xt_val, list__y_val = [], []
        bs_train = None
        bs_val = None

        # The buffer position i is stamped as the task id. This assumes train() was called
        # exactly once per task, in order 0..idx_task.
        for i in range(idx_task + 1):
            bs = self._collect_with_task_channel(self.list__dl_train[i], i, list__xt_train, list__y_train)
            if bs_train is None:
                bs_train = bs
            # endif
            bs = self._collect_with_task_channel(self.list__dl_val[i], i, list__xt_val, list__y_val)
            if bs_val is None:
                bs_val = bs
            # endif
        # endfor

        # NOTE(review): the original loaders are iterated once here, so their shuffling and any
        # random augmentation are frozen. The joint loader is also unshuffled, so each epoch visits
        # task 0's samples, then task 1's, and so on. If the base class does not reshuffle,
        # consider shuffle=True for genuine joint training.
        dl_train_all = self._build_loader(list__xt_train, list__y_train, bs_train, 'train')
        dl_val_all = self._build_loader(list__xt_val, list__y_val, bs_val, 'val')

        time_start = time.perf_counter()
        super().train(idx_task=-1, dl_train=dl_train_all, dl_val=dl_val_all,
                      args_on_forward=args_on_forward, args_on_after_backward=args_on_after_backward)
        return time.perf_counter() - time_start
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate on ``dl_test`` after stamping every batch with ``idx_task`` as an extra channel.

        Returns whatever the base-class ``test`` returns for the rebuilt loader.

        Raises:
            ValueError: if ``dl_test`` yields no batches.
        """
        list__xt_test, list__y_test = [], []
        bs_test = self._collect_with_task_channel(dl_test, idx_task, list__xt_test, list__y_test)
        dl_test_all = self._build_loader(list__xt_test, list__y_test, bs_test, 'test')
        return super().test(idx_task=-1, dl_test=dl_test_all, args_on_forward=args_on_forward)
    # enddef

    def complete_learning(self, idx_task: int) -> None:
        """No post-task work is needed; all training happens on the last task's train() call."""
        pass
    # enddef
# endclass
