"""
GEM based methods
- GEM-PDS
"""
import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, RandomSampler
from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters

from algorithms.ocl import OCLAlgorithm
from optimizer import initialize_optimizer_with_model_params
from scheduler import initialize_scheduler
from utils import key_is_none, attr_is_none
from algorithms.alg_utils import ConcatDataset, ConcatBatchSampler


class GEM_PDS(OCLAlgorithm):
    """
    GEM-PDS: A combination of GEM and A-GEM

    For every new-data batch, the gradient g0 of that batch is computed. Then,
    for each KR dataset (at most two), the gradients of `kr_size` KR batches
    are averaged into g1 (and g2). Before the optimizer step, g0 is replaced
    by the vector closest to it (L2) whose dot product with every KR gradient
    is non-negative (see `solve_gem`).

    Config:
        config.epochs              - Number of passes over each newly labeled batch
        config.kr_size             - Number of KR batches for each new batch
                                     (defaults to 1 when unset)
        All configs for ERM (this class additionally reads batch_size,
        loader_kwargs, device and max_grad_norm)
    """

    def __init__(self, model, config):
        super().__init__(model, config)
        # Cumulative counters: steps where g0 was used unchanged ('grad_g0')
        # versus steps where it was replaced by a projection ('grad_other').
        self.mem['grad_g0'] = 0
        self.mem['grad_other'] = 0

    def train(self, t: int, feedback: dict) -> None:
        """Run GEM-PDS on the newly labeled batch available at time step `t`.

        Returns immediately when `key_is_none(feedback, 'batch_labeled')`.
        Afterwards prints the cumulative gradient-projection counters.
        """
        if key_is_none(feedback, 'batch_labeled'):
            return
        batch_labeled = feedback['batch_labeled']
        datasets_kr = self.get_kr_datasets(t, True, True)

        self.run_gem_pds(batch_labeled, datasets_kr, self.config.epochs)
        n_g0 = self.mem['grad_g0']
        n_other = self.mem['grad_other']
        print('Grad count:')
        print(f'  g0:        {n_g0}')
        print(f'  other:     {n_other}')
        # max(..., 1) guards against division by zero before any step ran.
        print(f'  g0 / all = {n_g0 / max(n_g0 + n_other, 1)}')

    def run_gem_pds(self, dataset, datasets_kr, epochs):
        """Main routine of GEM-PDS.

        Args:
            dataset: newly labeled dataset. Each epoch takes
                `len(dataset) // config.batch_size` batches from it.
            datasets_kr: list of at most two KR datasets. Each is sampled with
                replacement, `kr_size` batches per new-data batch.
            epochs: number of passes over `dataset`.

        Raises:
            ValueError: if more than two KR datasets are given.
        """
        krs = 1 if attr_is_none(self.config, 'kr_size') else self.config.kr_size
        n_kr = len(datasets_kr)
        if n_kr > 2:
            # solve_gem handles at most two KR constraints (g1, g2). An assert
            # here would be stripped under `python -O`.
            raise ValueError(f'GEM-PDS supports at most 2 KR datasets, got {n_kr}')
        datasets = [dataset] + datasets_kr
        n_batches = len(dataset) // self.config.batch_size
        # Each KR dataset supplies `krs` samples per new-data sample.
        samplers = [RandomSampler(dataset)] + [
            RandomSampler(d, num_samples=len(dataset) * krs, replacement=True)
            for d in datasets_kr
        ]
        # Batches drawn per dataset for one training step. gem_main consumes
        # them as: 1 new batch, then `krs` batches of each KR dataset in order
        # (assumes ConcatBatchSampler interleaves by these counts - TODO confirm).
        index_change = [1] + [krs] * n_kr
        concat_dataset = ConcatDataset(datasets)
        concat_batch_sampler = ConcatBatchSampler(samplers, self.config.batch_size,
                                                  concat_dataset.offsets, index_change)
        concat_loader = DataLoader(concat_dataset, batch_sampler=concat_batch_sampler,
                                   **self.config.loader_kwargs)

        optimizer = initialize_optimizer_with_model_params(self.config, self.model.parameters())
        n_train_steps = n_batches * epochs
        scheduler = initialize_scheduler(self.config, optimizer, n_train_steps)

        self.model.train()
        for _epoch in range(epochs):
            # Shared with gem_main, which pulls the KR batches from it.
            self.loader_enum = enumerate(concat_loader)
            for _step in range(n_batches):
                _, (x, y) = next(self.loader_enum)
                x, y = x.to(self.config.device), y.to(self.config.device)
                outputs = self.model(x)
                loss = self.criterion(outputs, y)
                optimizer.zero_grad()
                loss.backward()

                # GEM: replace .grad by the projected gradient
                self.gem_main(n_kr, krs, self.criterion, optimizer)

                # Update model
                if self.config.max_grad_norm is not None:
                    clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                optimizer.step()
                if scheduler is not None and scheduler.step_every_batch:
                    scheduler.step()

            if scheduler is not None and not scheduler.step_every_batch:
                scheduler.step()
            del self.loader_enum

    def gem_main(self, n_kr, krs, criterion, optimizer):
        """Replace the current gradients by the GEM-PDS projected gradient.

        Expects `.grad` to hold the new-batch gradient g0. Also expects the
        next `n_kr * krs` items of `self.loader_enum` to be the KR batches,
        `krs` per KR dataset. Consumes those batches and averages each
        dataset's gradients. Then writes `solve_gem(g0, g1, g2)` back into the
        parameters' `.grad`. Does nothing when `n_kr == 0`.

        Only parameters with `requires_grad` take part. A missing `.grad`
        (frozen, unused, or cleared by `zero_grad`) counts as a zero gradient
        instead of crashing the flattening.
        """
        if n_kr <= 0:
            return
        params = [p for p in self.model.parameters() if p.requires_grad]
        g0 = self._flat_grad(params)
        grads = []
        for _k in range(n_kr):
            gk = torch.zeros_like(g0)
            for _j in range(krs):
                _, (x1, y1) = next(self.loader_enum)
                x1, y1 = x1.to(self.config.device), y1.to(self.config.device)
                loss = criterion(self.model(x1), y1)
                optimizer.zero_grad()
                loss.backward()
                gk += self._flat_grad(params)
            gk /= krs
            grads.append(gk)
        g1 = grads[0]
        g2 = grads[1] if n_kr > 1 else None
        g = self.solve_gem(g0, g1, g2)
        self._set_flat_grad(params, g)

    @staticmethod
    def _flat_grad(params):
        """Concatenate the `.grad` of `params` into one detached 1-D vector.

        A parameter whose `.grad` is None contributes zeros, so the vector
        layout always matches `params`.
        """
        return torch.cat([
            p.grad.detach().reshape(-1) if p.grad is not None
            else torch.zeros(p.numel(), dtype=p.dtype, device=p.device)
            for p in params
        ])

    @staticmethod
    def _set_flat_grad(params, vec):
        """Write consecutive slices of `vec` back into the `.grad` of `params`.

        Existing gradient tensors are overwritten in place. A parameter without
        a gradient only gets one when its slice is non-zero. Parameters that no
        forward pass reached therefore keep `.grad is None`, and the optimizer
        keeps skipping them.
        """
        offset = 0
        for p in params:
            n = p.numel()
            chunk = vec[offset:offset + n].view_as(p)
            offset += n
            if p.grad is not None:
                p.grad.copy_(chunk)
            elif chunk.any():
                p.grad = torch.empty_like(p).copy_(chunk)

    def solve_gem(self, g0, g1, g2):
        """Solve the GEM-PDS optimization problem
        This procedure is described in Appendix D

        Returns the vector g closest to g0 (L2) subject to g . g1 >= 0 and,
        if g2 is given, g . g2 >= 0. The minimizer has the form
        g = g0 + v1*g1 + v2*g2 with v1, v2 >= 0. The branches enumerate which
        constraints are active. With a = g0.g1, b = g1.g1, c = g1.g2,
        d = g0.g2, e = g2.g2, p = c*d - a*e, q = a*c - b*d and
        r = b*e - c*c (Gram determinant):
          * no constraint violated -> g0 unchanged ('grad_g0' += 1)
          * only g1 active: v1 = -a/b, valid when a <= 0 and q <= 0
          * only g2 active: v2 = -d/e, valid when d <= 0 and p <= 0
          * both active: v1 = p/r, v2 = q/r
        Every projected case increments 'grad_other'. Denominators are
        clamped to `eps`. If g1 is None, g0 is returned without counting.
        """
        eps = 1e-6

        if g1 is None:
            return g0
        a = torch.dot(g0, g1).item()
        b = torch.dot(g1, g1).item()
        if g2 is None:
            # Single constraint: A-GEM style projection.
            if a >= 0:
                self.mem['grad_g0'] += 1
                return g0
            self.mem['grad_other'] += 1
            return g0 - g1 * a / max(b, eps)

        c = torch.dot(g1, g2).item()
        d = torch.dot(g0, g2).item()
        e = torch.dot(g2, g2).item()
        if a >= 0 and d >= 0:
            self.mem['grad_g0'] += 1
            return g0

        self.mem['grad_other'] += 1
        p = c * d - a * e
        q = a * c - b * d
        r = b * e - c * c

        if a <= 0 and q <= 0:
            return g0 - g1 * a / max(b, eps)
        if d <= 0 and p <= 0:
            return g0 - g2 * d / max(e, eps)

        r = max(r, eps)
        return g0 + p / r * g1 + q / r * g2