from copy import deepcopy

import torch
from torch import optim

from methods.base import AdaptiveCL
from utils import Reservoir


class MIR(AdaptiveCL):
    """
    Class implementing the MIR (Maximally Interfered Retrieval) method for
    Continual Learning.

    For every incoming batch, a "virtual" copy of the model (``model_v``) is
    set to the weights the real model would have after one plain SGD step on
    that batch. A pool of candidates is drawn from the replay buffer and the
    ``k`` candidates whose per-sample loss increases the most under the
    virtual step are replayed together with the incoming batch.

    Attributes read from the base class (not defined here):
        ``self.n_class_each_task`` and ``self.criterion`` are expected to be
        provided by ``AdaptiveCL``.
    """

    def __init__(
        self,
        cl_type,
        model_class,
        batch_size: int,
        n_classes,
        n_tasks: int,
        lr: float,
        buffer_size: int,
        device: torch.device,
        **kwargs
    ):
        """Build the model, its virtual copy, the optimizers and the buffer.

        Args:
            cl_type: Forwarded unchanged to ``AdaptiveCL.__init__``.
            model_class: Callable invoked as ``model_class(n_class_each_task)``
                that returns the network to train.
            batch_size: Size of the incoming batches; also used as the number
                of replay samples retained after MIR scoring (``self.k``).
            n_classes: Forwarded unchanged to ``AdaptiveCL.__init__``.
            n_tasks: Forwarded unchanged to ``AdaptiveCL.__init__``.
            lr: Learning rate of the SGD optimizers and of the virtual step.
            buffer_size: Capacity passed to the replay buffer.
            device: Device on which the models and the buffer live.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__(cl_type, n_classes, n_tasks, device)
        self.model = model_class(self.n_class_each_task).to(device)
        self.model_v = deepcopy(self.model)
        self.optimizer = optim.SGD(self.model.parameters(), lr=lr)
        # Not used by this class (the virtual step is applied manually);
        # kept because external code may reference the attribute.
        self.optimizer_v = optim.SGD(self.model_v.parameters(), lr=lr)
        self.buffer = Reservoir(buffer_size, device)
        self.batch_size = batch_size
        # Number of candidates drawn from the buffer before MIR scoring.
        # Capped at the buffer capacity: with a fixed value of 256 a smaller
        # buffer could never satisfy the replay condition, which silently
        # disabled replay altogether.
        self.sample_batch_size = max(0, min(256, buffer_size))
        self.method_name = "MIR"
        # Per-sample loss (no reduction) is required to rank the candidates.
        self.criterion_each = torch.nn.CrossEntropyLoss(reduction="none")
        self.lr = lr
        self.k = batch_size

    def predict(self, inputs, task_index):
        """Return the model output for ``inputs`` (``task_index`` is unused)."""
        return self.model(inputs)

    def update(self, inputs, labels, task_index, test=False):
        """Perform one MIR training step on the incoming batch.

        Args:
            inputs: Incoming batch of inputs.
            labels: Labels matching ``inputs``.
            task_index: Unused by this method.
            test: Unused by this method.
        """
        self.model.train()
        self._virtual_step(inputs, labels)

        # Retrieve the maximally interfered samples, if the buffer is ready.
        replay = self.mir_sample() if self._can_replay() else None

        # mir_sample toggles eval mode internally; make sure we train again.
        self.model.train()
        self.optimizer.zero_grad()
        loss = self.criterion(self.model(inputs), labels)
        if replay is not None:
            inputs_replay, labels_replay = replay
            loss = loss + self.criterion(self.model(inputs_replay), labels_replay)
        loss.backward()
        self.optimizer.step()
        self.buffer.add(zip(inputs, labels.view(-1, 1)))

    def _can_replay(self):
        """Return True when the buffer holds enough samples for a candidate pool."""
        return (
            self.sample_batch_size > 0
            and len(self.buffer) >= self.sample_batch_size
        )

    def _virtual_step(self, inputs, labels):
        """Set ``model_v`` to ``model`` after one plain SGD step on the batch.

        The real model's weights are left untouched; only its gradients are
        populated (they are cleared again before the real step in ``update``).
        """
        self.optimizer.zero_grad()
        loss = self.criterion(self.model(inputs), labels)
        loss.backward()
        with torch.no_grad():
            source_params = dict(self.model.named_parameters())
            for name, p_v in self.model_v.named_parameters():
                src = source_params[name]
                if src.grad is None:
                    # Frozen / unused parameter: no gradient, so the virtual
                    # step leaves it unchanged (previously an AttributeError).
                    p_v.data = src.data.clone()
                else:
                    p_v.data = src.data - self.lr * src.grad.data
            # Keep non-parameter state (e.g. normalisation running statistics)
            # in sync as well; otherwise the virtual model is evaluated with
            # the stale buffers of the initial deepcopy and the MIR scores
            # reflect that mismatch instead of the virtual update.
            source_buffers = dict(self.model.named_buffers())
            for name, b_v in self.model_v.named_buffers():
                b_v.data = source_buffers[name].data.clone()

    def mir_sample(self):
        """Draw candidates from the buffer and keep the most interfered ones.

        The score of a candidate is its per-sample loss under the virtual
        model minus its per-sample loss under the current model; the ``k``
        highest-scoring candidates are returned.

        Returns:
            Tuple ``(inputs, labels)`` indexed from the sampled candidates.
        """
        # sample from buffer
        inputs_replay, labels_replay = self.buffer.sample(self.sample_batch_size)
        was_training = self.model.training
        self.model_v.eval()
        self.model.eval()
        try:
            with torch.no_grad():
                loss_virtual = self.criterion_each(
                    self.model_v(inputs_replay), labels_replay
                )
                loss_current = self.criterion_each(
                    self.model(inputs_replay), labels_replay
                )
                score = loss_virtual - loss_current
                indices = torch.argsort(score, descending=True)[: self.k]
        finally:
            # Restore the mode the model had on entry (train mode when called
            # from ``update``) instead of unconditionally forcing train mode.
            self.model.train(was_training)
        return inputs_replay[indices], labels_replay[indices]

    def before_fewshot_test(self):
        """Delegate to the base-class hook; MIR adds nothing of its own."""
        super().before_fewshot_test()

    def get_models(self):
        """Get the models used by this instance (the virtual copy is excluded)."""
        return [self.model]

    def mode(self, is_train: bool = True):
        """Switch the main model between training and evaluation mode."""
        if is_train:
            self.model.train()
        else:
            self.model.eval()
