"""
Rehearsal based methods:
- ER-FIFO
- ER-FIFO-RW
- MIR
- MaxLoss
"""
import copy

import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, RandomSampler
from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters

from algorithms.ocl import OCLAlgorithm
from optimizer import initialize_optimizer_with_model_params
from scheduler import initialize_scheduler
from utils import key_is_none, attr_is_none
from algorithms.alg_utils import get_balanced_loader, ConcatDataset, ConcatBatchSampler


class ER_FIFO(OCLAlgorithm):
    """
    ER-FIFO: Experience replay with a FIFO buffer.
    Fine-tune the model on the union of the new batch, recent batches and the training regression set.

    Config:
        config.epochs              - Number of epochs for other batches
        All configs for ERM
    """

    def train(self, t: int, feedback: dict) -> None:
        """
        Fine-tune on the new labeled batch together with the knowledge retention datasets.

        Nothing happens when key_is_none reports that feedback has no 'batch_labeled'.
        """
        if key_is_none(feedback, 'batch_labeled'):
            return
        new_batch = feedback['batch_labeled']

        # The new batch goes first, followed by whatever get_kr_datasets returns for step t
        replay_sets = self.get_kr_datasets(t, True)
        merged = ConcatDataset([new_batch] + replay_sets)

        self.erm(self.config.epochs, merged)


class ER_FIFO_RW(OCLAlgorithm):
    """
    ER-FIFO-RW: ER-FIFO with data source reweighting
    Balance the three sources of data: New batch, recent batches, training regression set

    Config:
        config.epochs              - Number of epochs for other batches
        configs for ERM also required
    """

    def train(self, t: int, feedback: dict) -> None:
        """
        Fine-tune on a loader that balances the new batch against the knowledge retention datasets.

        Nothing happens when key_is_none reports that feedback has no 'batch_labeled'.
        """
        if key_is_none(feedback, 'batch_labeled'):
            return
        new_batch = feedback['batch_labeled']

        # One entry per data source; the new batch is always the first source
        sources = [new_batch] + self.get_kr_datasets(t, True, True)

        # Sample count handed to the loader and to erm: new-batch size times the number of sources
        total = len(new_batch) * len(sources)
        balanced_loader = get_balanced_loader(
            sources, total,
            batch_size=self.config.batch_size,
            drop_last=True,
            **self.config.loader_kwargs
        )

        self.erm(self.config.epochs, loader=balanced_loader, n=total)


class MIR(OCLAlgorithm):
    """
    MIR: Maximally Interfered Retrieval
    1. Virtual update: Train the model on new samples only
    2. Select the previous samples with the largest loss increase before and after virtual update
    3. Real update: Recover the model, and train it on new + selected previous samples

    Config:
        config.epochs              - Number of epochs for other batches
        config.kr_size             - Int. Number of KR batches for each batch of new data
        config.lbd                 - Float. lambda: Weight of KR data. Optional.
        All configs for ERM
    """

    def train(self, t: int, feedback: dict) -> None:
        """
        Run the virtual-update / retrieval / real-update loop on the new labeled batch.

        Nothing happens when key_is_none reports that feedback has no 'batch_labeled'.
        When get_kr_datasets returns nothing, this falls back to plain ERM on the new batch.
        """
        if key_is_none(feedback, 'batch_labeled'):
            return
        batch_labeled = feedback['batch_labeled']

        assert self.config.kr_size > 1
        lbd = 1.0 if attr_is_none(self.config, 'lbd') else self.config.lbd
        assert lbd >= 0

        # Knowledge retention datasets
        d = self.get_kr_datasets(t, True)
        if not d:
            self.erm(self.config.epochs, batch_labeled)
            return
        dataset_kr = d[0]
        for extra in d[1:]:
            dataset_kr = dataset_kr + extra

        # One loader serves both sources. The training loop below reads one batch of new data
        # and then kr_size batches of KR data per step.
        # NOTE(review): this presumes ConcatBatchSampler interleaves the two samplers in the
        # 1 : kr_size pattern given by index_change - confirm against algorithms.alg_utils.
        batch_size = self.config.batch_size
        datasets = [batch_labeled, dataset_kr]
        samplers = [RandomSampler(batch_labeled),
                    RandomSampler(dataset_kr, replacement=True,
                                  num_samples=(len(batch_labeled) * self.config.kr_size))]
        index_change = [1, self.config.kr_size]
        concat_dataset = ConcatDataset(datasets)
        concat_batch_sampler = ConcatBatchSampler(samplers, batch_size,
                                                  concat_dataset.offsets, index_change)
        concat_loader = DataLoader(concat_dataset, batch_sampler=concat_batch_sampler,
                                   **self.config.loader_kwargs)

        optimizer = initialize_optimizer_with_model_params(self.config, self.model.parameters())
        # Only full batches of new data are counted as training steps
        n_batches = len(batch_labeled) // batch_size
        n_train_steps = n_batches * self.config.epochs
        scheduler = initialize_scheduler(self.config, optimizer, n_train_steps)

        self.model.train()
        for e in range(self.config.epochs):
            # Stored on self so that pre_/post_virtual_update can pull KR batches from it
            self.loader_enum = enumerate(concat_loader)
            for b in range(n_batches):
                _, (x, y) = next(self.loader_enum)
                x, y = x.to(self.config.device), y.to(self.config.device)

                results_pre_virtual = self.pre_virtual_update()

                # Virtual update step
                outputs = self.model(x)
                optimizer.zero_grad()
                # Snapshot the weights by value, without dragging the autograd graph along
                p0 = parameters_to_vector(self.model.parameters()).detach().clone()
                # state_dict() hands back references to the live per-parameter state
                # (e.g. momentum buffers), which optimizer.step() updates in place.
                # Copy it by value so that the virtual step can really be rolled back.
                opt0 = copy.deepcopy(optimizer.state_dict())
                loss = self.criterion(outputs, y)
                loss.backward()
                if self.config.max_grad_norm is not None:
                    clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                optimizer.step()

                # Find KR samples with greatest loss increases
                x1, y1 = self.post_virtual_update(results_pre_virtual)

                # Recover model and optimizer, and recompute the loss
                optimizer.zero_grad()
                optimizer.load_state_dict(opt0)
                vector_to_parameters(p0, self.model.parameters())
                x1, y1 = x1.to(self.config.device), y1.to(self.config.device)
                outputs = self.model(x)
                outputs1 = self.model(x1)

                # Real update step: weighted average of the new-data loss and the KR loss
                loss = (self.criterion(outputs, y) + lbd * self.criterion(outputs1, y1)) / (lbd + 1)
                loss.backward()
                if self.config.max_grad_norm is not None:
                    clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                optimizer.step()
                if scheduler is not None and scheduler.step_every_batch:
                    scheduler.step()

            if scheduler is not None and not scheduler.step_every_batch:
                scheduler.step()
            del self.loader_enum

    def pre_virtual_update(self):
        """
        Pull config.kr_size batches from self.loader_enum and record their loss under the
        current (pre virtual update) weights.

        Returns:
            (x_kr, y_kr, loss0_kr): concatenated inputs, concatenated labels, and the
            concatenated flattened losses (moved to CPU).
        """
        with torch.no_grad():
            x_kr = None
            y_kr = None
            loss0_kr = None
            for k in range(self.config.kr_size):
                _, (x1, y1) = next(self.loader_enum)
                # Accumulate the batches as delivered by the loader; only the copies used
                # for the forward pass are moved to the training device
                x_kr = x1 if x_kr is None else torch.cat((x_kr, x1))
                y_kr = y1 if y_kr is None else torch.cat((y_kr, y1))
                x1, y1 = x1.to(self.config.device), y1.to(self.config.device)
                outputs1 = self.model(x1)
                loss = self.criterion(outputs1, y1)
                loss = loss.detach().cpu().flatten()
                loss0_kr = loss if loss0_kr is None else torch.cat((loss0_kr, loss))
            return x_kr, y_kr, loss0_kr

    def post_virtual_update(self, results_pre_virtual):
        """
        Re-evaluate the KR candidates after the virtual update and keep the ones whose loss
        grew the most.

        Args:
            results_pre_virtual: the (x_kr, y_kr, loss0_kr) tuple from pre_virtual_update.

        Returns:
            (x1, y1): inputs and labels of at most config.batch_size selected candidates.
        """
        # NOTE(review): ranking candidates one by one assumes self.criterion returns one
        # loss value per sample - confirm against OCLAlgorithm.
        with torch.no_grad():
            x_kr, y_kr, loss0_kr = results_pre_virtual
            batch_size = self.config.batch_size
            loss1_kr = None
            k = 0
            # Walk over the candidates in the same batch_size chunks and order as before,
            # so that loss1_kr lines up with loss0_kr
            for i in range(self.config.kr_size):
                x1 = x_kr[k:k + batch_size]
                y1 = y_kr[k:k + batch_size]
                x1, y1 = x1.to(self.config.device), y1.to(self.config.device)
                outputs1 = self.model(x1)
                loss = self.criterion(outputs1, y1)
                loss = loss.detach().cpu().flatten()
                loss1_kr = loss if loss1_kr is None else torch.cat((loss1_kr, loss))
                k += batch_size
            # Largest loss increase first
            idx = torch.argsort(loss1_kr - loss0_kr, descending=True)
            idx = idx[:batch_size]
            x1, y1 = x_kr[idx], y_kr[idx]
            return x1, y1


class MaxLoss(MIR):
    """
    MaxLoss: A variant of MIR
    1. Virtual update: Train the model on new samples only
    2. Select the previous samples with the largest loss after virtual update
    3. Real update: Recover the model, and train it on new + selected previous samples

    Config:
        config.epochs              - Number of epochs for other batches
        config.kr_size             - Int. Number of KR batches for each batch of new data
        config.lbd                 - Float. lambda: Weight of KR data. Optional.
        All configs for ERM
    """

    def pre_virtual_update(self):
        """Do nothing: MaxLoss takes no measurement before the virtual update."""
        return None

    def post_virtual_update(self, results_pre_virtual):
        """
        Pick the KR candidates with the largest loss under the virtually updated weights.

        results_pre_virtual is ignored. The candidates and their losses come from
        MIR.pre_virtual_update, which is invoked here, i.e. after the virtual update.

        Returns:
            (x1, y1): inputs and labels of at most config.batch_size selected candidates.
        """
        with torch.no_grad():
            candidates_x, candidates_y, losses = super().pre_virtual_update()
            # Highest loss first, truncated to one batch worth of candidates
            ranking = torch.argsort(losses, descending=True)
            chosen = ranking[:self.config.batch_size]
            return candidates_x[chosen], candidates_y[chosen]