"""
OCLAlgorithm base class and baselines:
- FBO
- NBO

Configs for ERM:
    config.batch_size          - Batch size for training (not the size of the data batch)
    config.lr                  - Learning rate
    config.wd                  - Weight decay
    config.device              - Device to be used
    config.loader_kwargs       - Dataloader kwargs
    config.max_grad_norm       - If not None, clip the gradient norm to this value before each optimizer step

Configs for first batch (t = 0):
    config.epochs_first_batch  - Epochs of training the initial model. If not specified, config.epochs will be used.
    config.initial_model_load  - If specified, load this pretrained model instead of training.
    config.initial_model_save  - If specified, save the trained initial model here (only when it was trained, not loaded).
"""

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader

from optimizer import initialize_optimizer_with_model_params
from scheduler import initialize_scheduler
from utils import key_is_none, attr_is_none, load_ckpt, save_ckpt


class OCLAlgorithm:
    """
    Prototype of an OCL algorithm.

    Usage:
        alg = OCLAlgorithm(model, config)
        alg(t, feedback)

    self.mem is a static memory buffer of type dict:
        'batch_{t}'         - the labeled batch received at time t (if any)
        'train_regression'  - the training regression set received at t == 0 (if any)
    By default, when t == 0, the model is trained with ERM on the first labeled batch
    plus the training regression set, unless config.initial_model_load is specified,
    in which case that checkpoint is loaded instead. For t > 0, self.train is called.

    Feedback:
        Specify what are required in the feedback.
        By default, feedback requires:
            batch_labeled      - PDSSubDataset. A labeled data batch. Required at t == 0
                                 unless config.initial_model_load is specified.
            train_regression   - PDSSubDataset. The training regression set. Optional.

    Config:
        Specify what are required in config.
        By default, requires:
            config.loss_function_dummy  - Default criterion used by erm(), called as criterion(outputs, y)
            config.epochs or config.epochs_first_batch
        Optional:
            config.initial_model_load   - Pretrained initial model to load instead of training
            config.initial_model_save   - Where to save the initial model after training it
    """

    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.criterion = self.config.loss_function_dummy
        self.mem = {}

    def train(self, t: int, feedback: dict) -> None:
        """
        Put the algorithm implementation inside this function.
        By default, when t == 0, this function won't be called.

        Inputs:
            t          - Int. If t == 0 then it is the first batch.
            feedback   - dict. The feedback given to the trainer.
        """
        raise NotImplementedError

    def __call__(self, t: int, feedback: dict) -> None:
        """
        Process the feedback received at time step t.

        The labeled batch (if present) is always recorded in self.mem. At t == 0 the
        initial model is trained with ERM or loaded from config.initial_model_load;
        otherwise self.train(t, feedback) is called.

        Raises:
            ValueError - at t == 0 when the initial model has to be trained but the
                         feedback contains no 'batch_labeled'.
        """
        batch = None
        if not key_is_none(feedback, 'batch_labeled'):
            batch = feedback['batch_labeled']
            self.mem['batch_{}'.format(t)] = batch

        if t != 0:
            self.train(t, feedback)
            return

        regression = None
        if not key_is_none(feedback, 'train_regression'):
            regression = feedback['train_regression']
            self.mem['train_regression'] = regression

        if not attr_is_none(self.config, 'initial_model_load'):
            # A pretrained initial model replaces first-batch training, so no labeled
            # data is needed here.
            d = load_ckpt(self.config.initial_model_load)
            self.model.load_state_dict(d['model'])
            return

        if batch is None:
            raise ValueError("feedback['batch_labeled'] is required at t == 0 "
                             "unless config.initial_model_load is specified")

        # Use `+` (not `+=`): `batch` is the same object stored in self.mem, so an
        # in-place __iadd__ on the dataset type would silently corrupt that entry.
        dataset = batch if regression is None else batch + regression

        if attr_is_none(self.config, 'epochs_first_batch'):
            epochs = self.config.epochs
        else:
            epochs = self.config.epochs_first_batch
        self.erm(epochs, dataset)
        if not attr_is_none(self.config, 'initial_model_save'):
            save_ckpt({'model': self.model.state_dict()}, self.config.initial_model_save)

    # Helper functions

    def erm(self, epochs: int, dataset=None, loader=None, n=None, loss_func=None) -> None:
        """
        ERM Training
        Input:
            epochs                     - Number of epochs
            dataset                    - The dataset to train on
            loader                     - Alternatively, the DataLoader to train on
            n                          - Size of the dataset in samples (use it with loader because
                                         len(loader) is highly not recommended). If both n and dataset
                                         are None, the scheduler step count falls back to
                                         len(loader) * epochs, since len(loader) counts batches.
            loss_func                  - Alternative loss function, called as loss_func(x, y, outputs)

        Configs for ERM:
            config.batch_size          - Batch size for training (not the size of the data batch)
            config.lr                  - Learning rate
            config.wd                  - Weight decay
            config.device              - Device to be used
            config.loader_kwargs       - Dataloader kwargs (missing or None means no extra kwargs)
            config.max_grad_norm       - If present and not None, clip the gradient norm to this value

        self.optimizer and self.loader_enum exist only while training and are deleted afterwards.
        """
        if loader is None:
            loader_kwargs = getattr(self.config, 'loader_kwargs', None) or {}
            loader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True,
                                drop_last=True, **loader_kwargs)
        self.optimizer = initialize_optimizer_with_model_params(self.config, self.model.parameters())

        if n is None and dataset is None:
            # len(loader) is already a number of batches; dividing it by the batch
            # size again would undercount the scheduler's total steps.
            n_train_steps = len(loader) * epochs
        else:
            if n is None:
                n = len(dataset)
            # drop_last=True: only full batches are taken per epoch.
            n_train_steps = int(n / self.config.batch_size) * epochs
        scheduler = initialize_scheduler(self.config, self.optimizer, n_train_steps)

        device = self.config.device
        max_grad_norm = getattr(self.config, 'max_grad_norm', None)

        self.model.train()
        for _ in range(epochs):
            # The iterator lives on self and is re-read on every step; kept this way for
            # compatibility (presumably subclasses may access it during training).
            self.loader_enum = enumerate(loader)
            while True:
                try:
                    _, (x, y) = next(self.loader_enum)
                except StopIteration:
                    break
                x, y = x.to(device), y.to(device)
                outputs = self.model(x)
                if loss_func is None:
                    loss = self.criterion(outputs, y)
                else:
                    loss = loss_func(x, y, outputs)
                self.optimizer.zero_grad()
                loss.backward()
                if max_grad_norm is not None:
                    clip_grad_norm_(self.model.parameters(), max_grad_norm)
                self.optimizer.step()
                # The scheduler (may be None) steps per batch or per epoch.
                if scheduler is not None and scheduler.step_every_batch:
                    scheduler.step()

            if scheduler is not None and not scheduler.step_every_batch:
                scheduler.step()
            del self.loader_enum

        del self.optimizer

    def get_kr_datasets(self, t, pop_old=False, concat_recent=False):
        """
        Return datasets that require knowledge retention (recent batches + regression set).

        t              - Time stamp
        pop_old        - If True, remove batch t - w - 1 (the one that just fell out of the
                         window of w = config.recent_batches batches) from mem
        concat_recent  - If True, concatenate the recent datasets into one dataset

        Recent batches are ordered newest first (t - 1, t - 2, ..., t - w); the regression
        set is always the last entry (if there is one).
        Return a list of datasets.
        """
        w = self.config.recent_batches or 0
        recent_keys = ['batch_{}'.format(t - i) for i in range(1, min(t, w) + 1)]
        datasets = [self.mem[k] for k in recent_keys if not key_is_none(self.mem, k)]
        if concat_recent and datasets:
            combined = datasets[0]
            for ds in datasets[1:]:
                combined = combined + ds
            datasets = [combined]
        if not key_is_none(self.mem, 'train_regression'):
            datasets.append(self.mem['train_regression'])
        if pop_old and t > w:
            stale_key = 'batch_{}'.format(t - w - 1)
            if not key_is_none(self.mem, stale_key):
                self.mem.pop(stale_key)
        return datasets

########################################################
# Baselines
class FBO(OCLAlgorithm):
    """
    First Batch Only (FBO): lower-bound baseline.

    The model is set up once at t == 0 by OCLAlgorithm.__call__ (trained or loaded)
    and is never updated on any later batch.

    Config:
        None beyond what OCLAlgorithm requires for the first batch.
    """

    def train(self, t: int, feedback: dict) -> None:
        # Deliberately a no-op: batches after the first are ignored.
        return None


class NBO(OCLAlgorithm):
    """
    New Batch Only (NBO): lower-bound baseline for knowledge retention.

    On every call to train(), the model is fine-tuned with ERM on the newest labeled
    batch alone; no earlier batch and no regression set is revisited.

    Config:
        config.epochs      - Number of ERM epochs per batch
        All configs for ERM (see OCLAlgorithm.erm)
    """

    def train(self, t: int, feedback: dict) -> None:
        # Steps without a labeled batch leave the model untouched.
        has_new_batch = not key_is_none(feedback, 'batch_labeled')
        if has_new_batch:
            self.erm(self.config.epochs, feedback['batch_labeled'])
