import torch
from torch import optim

from methods.base import AdaptiveCL
from utils import Reservoir


class ExperienceReplay(AdaptiveCL):
    """
    Experience Replay (ER) learner for Continual Learning.

    Holds a single model trained with plain SGD. Each training step adds the
    loss on the incoming batch to the loss on a batch drawn from a
    ``Reservoir`` memory buffer (once the buffer holds enough samples). The
    incoming samples are then pushed into that buffer.
    """

    def __init__(
        self,
        cl_type,
        model_class,
        batch_size: int,
        n_classes,
        n_tasks: int,
        lr: float,
        buffer_size: int,
        device: torch.device,
        **kwargs
    ):
        """
        Create the model, its SGD optimizer and the replay buffer.

        Args:
            cl_type: Forwarded unchanged to ``AdaptiveCL.__init__``.
            model_class: Callable invoked with ``self.n_class_each_task``
                (presumably set by the base-class constructor -- verify in
                ``AdaptiveCL``) to build the network.
            batch_size: Number of samples drawn from the buffer per update.
                Replay is skipped until the buffer holds at least this many.
            n_classes: Forwarded unchanged to ``AdaptiveCL.__init__``.
            n_tasks: Forwarded unchanged to ``AdaptiveCL.__init__``.
            lr: Learning rate of the SGD optimizer.
            buffer_size: First argument passed to ``Reservoir``.
            device: Device the model is moved to; also passed to ``Reservoir``.
            **kwargs: Accepted and ignored.
        """
        super().__init__(cl_type, n_classes, n_tasks, device)

        # Keep the construction order (model, then optimizer, then buffer).
        # Building the network may consume the global RNG, so reordering
        # could change the initial weights.
        network = model_class(self.n_class_each_task)
        self.model = network.to(device)
        self.optimizer = optim.SGD(params=self.model.parameters(), lr=lr)
        self.buffer = Reservoir(buffer_size, device)

        self.batch_size = batch_size
        self.method_name = "ER"

    def predict(self, inputs, task_index):
        """
        Return the model's forward output for ``inputs``.

        ``task_index`` is part of the shared interface but is not used here.
        """
        outputs = self.model(inputs)
        return outputs

    def update(self, inputs, labels, task_index, test=False):
        """
        Take one SGD step on the incoming batch, plus a replayed batch when available.

        The loss is ``self.criterion`` on (``inputs``, ``labels``). If the
        buffer already holds at least ``self.batch_size`` items, the criterion
        on ``self.batch_size`` samples drawn from the buffer is added to it.
        After the optimizer step, the incoming (input, label) pairs are handed
        to the buffer.

        ``task_index`` and ``test`` are accepted but not used by this method.
        """
        self.optimizer.zero_grad()
        loss = self.criterion(self.model(inputs), labels)

        replay_ready = len(self.buffer) >= self.batch_size
        if replay_ready:
            # Only the draw from memory runs under no_grad. The replay forward
            # pass below still records a graph, so its loss contributes
            # gradients.
            with torch.no_grad():
                memory_x, memory_y = self.buffer.sample(self.batch_size)
            replay_loss = self.criterion(self.model(memory_x), memory_y)
            loss += replay_loss

        loss.backward()
        self.optimizer.step()

        # The current batch is stored only after the step, so it cannot be
        # replayed against itself in the same update. view(-1, 1) makes the
        # labels an (N, 1) tensor, so zip pairs each input row with a
        # one-element label tensor.
        label_column = labels.view(-1, 1)
        self.buffer.add(zip(inputs, label_column))

    def before_fewshot_test(self):
        """Run the base-class hook; ER adds no preparation of its own."""
        super().before_fewshot_test()

    def get_models(self):
        """Return a one-element list containing the trained model."""
        models = [self.model]
        return models

    def mode(self, is_train: bool = True):
        """
        Put the model in training mode if ``is_train`` is truthy, otherwise evaluation mode.
        """
        if not is_train:
            self.model.eval()
            return
        self.model.train()
