import random
import time
from argparse import Namespace
from copy import deepcopy
from typing import *

import hydra.utils
import torch
import torch.cuda
from torch.utils.data import DataLoader, TensorDataset

from approaches.abst_appr import AbstractAppr
from approaches.acl import model_acl
from approaches.acl.appr_acl_orig import ACL as OrigAppr
from approaches.acl.model_acl import Net


class Appr(AbstractAppr):
    """Adapter exposing the original ACL approach (``OrigAppr``) through ``AbstractAppr``.

    Data loaders handed to this class yield ``(x, y)`` batches; they are re-built into
    loaders yielding ``(x, y, tt, td)`` where ``tt = idx_task`` and ``td = idx_task + 1``
    for every sample (presumably the task id and the ACL discriminator target — confirm
    against ``OrigAppr``). After each task, ``complete_learning`` stores a small
    class-balanced replay memory that is prepended to the training data of later tasks.
    """

    # Field order of the tagged datasets / memory tuples.
    _KEYS = ('x', 'y', 'tt', 'td')

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int,
                 drop1: float, drop2: float,
                 batch_size: int, checkpoint: str):
        """Create the ACL network and wrap it in ``OrigAppr``.

        Args:
            device: Torch device string the model is moved to.
            list__ncls: Number of classes of each task, indexed by task id.
            inputsize: Input shape forwarded to the network.
            lr, lr_factor, lr_min, epochs_max, patience_max: Passed to ``AbstractAppr``;
                all but ``lr`` are also forwarded to ``OrigAppr``.
            drop1, drop2: Dropout rates forwarded to the network.
            batch_size: Batch size of the loaders rebuilt by this class.
            checkpoint: Checkpoint location forwarded to ``OrigAppr``.
        """
        super().__init__(device=device, list__ncls=list__ncls, inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=None, lamb=None)

        taskcla = [(t, ncls) for t, ncls in enumerate(list__ncls)]

        # Replay exemplars stored per class by complete_learning.
        self.num_samples = 1
        num_tasks = len(list__ncls)
        self.batch_size = batch_size

        # One list of (x, y, tt, td) tuples per finished task.
        self.list__train_memory = list()

        # NOTE(review): ``lr`` is not forwarded; OrigAppr uses its own e_lr / d_lr below.
        args = Namespace(
            # common
            inputsize=inputsize,
            taskcla=taskcla,
            ntasks=num_tasks,
            device=device,
            experiment='others',
            samples=1,
            nepochs=epochs_max,
            lr_min=lr_min,
            lr_factor=lr_factor,
            lr_patience=patience_max,
            checkpoint=checkpoint,
            batch_size=batch_size,
            adv=0.05,
            orth=0.1,
            e_lr=0.01,
            e_wd=0.01,
            s_step=5,
            d_lr=0.001,
            d_wd=0.01,
            d_step=1,
            diff='yes',
            lam=1,
            mom=0.9,
            # shared (smaller alternative once used: units=175, latent_dim=128)
            units=2048,
            latent_dim=2048,
            nlayers=2,
            drop1=drop1,
            drop2=drop2,
            )
        model = Net(args=args).to(device)
        self.appr = OrigAppr(model=model, args=args, network=model_acl)
    # enddef

    @staticmethod
    def _tag_batches(dl: DataLoader, idx_task: int) -> Dict[str, List[torch.Tensor]]:
        """Collect the ``(x, y)`` batches of ``dl`` and attach per-sample task tensors.

        Returns a dict with keys ``'x', 'y', 'tt', 'td'``, each a list of per-batch
        tensors; ``tt`` is filled with ``idx_task`` and ``td`` with ``idx_task + 1``.
        """
        d = {'x': [], 'y': [], 'tt': [], 'td': []}
        for x, y in dl:
            bs = x.shape[0]
            d['x'].append(x)
            d['y'].append(y)
            # Explicit long dtype so an empty batch does not produce a float tensor.
            d['tt'].append(torch.full((bs,), idx_task, dtype=torch.long))
            d['td'].append(torch.full((bs,), idx_task + 1, dtype=torch.long))
        # endfor
        return d
    # enddef

    @classmethod
    def _to_loader(cls, d: Dict[str, List[torch.Tensor]], batch_size: int) -> DataLoader:
        """Concatenate the per-batch tensors in ``d`` into a non-shuffled ``DataLoader``."""
        tensors = [torch.cat(d[k], dim=0) for k in cls._KEYS]
        return DataLoader(TensorDataset(*tensors), batch_size=batch_size)
    # enddef

    def update_dataloader(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader):
        """Build task-tagged train/validation loaders, prepending replay memory to training.

        Args:
            idx_task: Index of the task about to be trained.
            dl_train: Loader of ``(x, y)`` batches for the current task.
            dl_val: Loader of ``(x, y)`` batches for the current task.

        Returns:
            ``(dl_train, dl_val)`` yielding ``(x, y, tt, td)`` batches of ``self.batch_size``.
            The training loader holds the memory of tasks ``0..idx_task-1`` followed by the
            current-task data, in that order (no shuffling).
        """
        # complete_learning must have stored exactly one memory per finished task.
        assert len(self.list__train_memory) == idx_task

        d__train = {k: [] for k in self._KEYS}
        for t, train_memory in enumerate(self.list__train_memory):
            # torch.stack([]) raises, so an empty memory contributes nothing.
            if train_memory:
                # Transpose the list of (x, y, tt, td) tuples into one column per field.
                for k, column in zip(self._KEYS, zip(*train_memory)):
                    d__train[k].append(torch.stack(column, dim=0))
                # endfor
            # endif
            print(f'[add memory at {idx_task}] {len(train_memory)} samples from task {t}')
        # endfor

        d__current = self._tag_batches(dl_train, idx_task)
        for k in self._KEYS:
            d__train[k].extend(d__current[k])
        # endfor

        dl_train = self._to_loader(d__train, self.batch_size)
        dl_val = self._to_loader(self._tag_batches(dl_val, idx_task), self.batch_size)

        return dl_train, dl_val
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              ) -> float:
        """Train ``OrigAppr`` on task ``idx_task`` (plus replay memory).

        ``args_on_forward`` / ``args_on_after_backward`` are accepted for interface
        compatibility and are not used.

        Returns:
            Elapsed seconds spent inside ``self.appr.train``.
        """
        # Snapshot of the raw (untagged) loader; complete_learning draws memory from it.
        self.dl_train_pre = deepcopy(dl_train)
        dl_train, dl_val = self.update_dataloader(idx_task, dl_train, dl_val)

        dataset = {'train': dl_train, 'valid': dl_val}
        # perf_counter is monotonic, unlike time.time(), so it is safe for durations.
        time_start = time.perf_counter()
        self.appr.train(task_id=idx_task, dataset=dataset)
        time_consumed = time.perf_counter() - time_start

        return time_consumed
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate the model saved for task ``idx_task`` on ``dl_test``.

        Returns:
            ``{'loss_test': ..., 'acc_test': ...}``, with accuracy as a fraction in [0, 1]
            (``OrigAppr.test`` appears to report it in percent, hence the ``/ 100``).
        """
        # A loader built from a batch_sampler has batch_size None, and passing None
        # would disable batching entirely; fall back to the configured size.
        batch_size = dl_test.batch_size if dl_test.batch_size is not None else self.batch_size
        dl_test = self._to_loader(self._tag_batches(dl_test, idx_task), batch_size)

        test_model = self.appr.load_model(idx_task)
        res = self.appr.test(dl_test, idx_task, model=test_model)

        return {'loss_test': res['loss_t'],
                'acc_test': res['acc_t'] / 100,
                }
    # enddef

    def complete_learning(self, idx_task: int) -> None:
        """Store a class-balanced replay memory for task ``idx_task``.

        Draws up to ``self.num_samples`` random samples per class from the loader
        saved by ``train`` and appends them, as ``(x, y, tt, td)`` tuples with
        ``tt = idx_task`` and ``td = idx_task + 1``, to ``self.list__train_memory``.
        Raises ``KeyError`` if a label falls outside ``range(list__ncls[idx_task])``.
        """
        num_samples_per_class = self.num_samples
        dth = {c: 0 for c in range(self.list__ncls[idx_task])}
        # Classes whose count has not yet reached the target; stop when none remain.
        num_unfilled = sum(1 for n in dth.values() if n != num_samples_per_class)

        samples = [
            (x[i], y[i], torch.tensor(idx_task), torch.tensor(idx_task + 1))
            for x, y in self.dl_train_pre
            for i in range(x.shape[0])
        ]
        # random.sample(..., len) is a full shuffle; kept (rather than random.shuffle)
        # so seeded runs reproduce the same memory.
        samples = random.sample(samples, len(samples))

        memory = list()
        for x, y, tt, td in samples:
            label = y.item()
            if dth[label] < num_samples_per_class:
                # x[i] / y[i] are views; clone so memory does not pin whole batches.
                memory.append((x.clone(), y.clone(), tt, td))
                dth[label] += 1
                if dth[label] == num_samples_per_class:
                    num_unfilled -= 1
                # endif
            # endif
            if num_unfilled == 0:
                break
            # endif
        # endfor
        print(f'length of current task: {len(samples)} -> memory: {len(memory)}')
        self.list__train_memory.append(memory)
    # enddef

# endclass
