import torch
from torch import optim

from methods.base import AdaptiveCL


class Finetune(AdaptiveCL):
    """
    Plain fine-tuning baseline for continual learning.

    A single model is trained on every incoming batch with vanilla SGD on
    ``self.criterion``. This class adds no replay buffer and no
    regularization term on top of that loss.
    """

    def __init__(
        self,
        cl_type,
        model_class,
        n_classes,
        n_tasks: int,
        lr: float,
        device: torch.device,
        **kwargs
    ):
        """
        Build the model and its SGD optimizer.

        Args:
            cl_type: Continual-learning setting; forwarded to ``AdaptiveCL``.
            model_class: Callable that builds the network. It is called with
                ``self.n_class_each_task``, which is set by the base class
                (presumably the per-task output size — confirm in AdaptiveCL).
            n_classes: Number of classes; forwarded to ``AdaptiveCL``.
            n_tasks: Number of tasks; forwarded to ``AdaptiveCL``.
            lr: Learning rate of the SGD optimizer.
            device: Device the model is moved to.
            **kwargs: Ignored here (presumably accepted so every method can be
                constructed with the same keyword arguments).
        """
        super().__init__(cl_type, n_classes, n_tasks, device)
        self.finetune = model_class(self.n_class_each_task).to(device)
        self.finetune_opt = optim.SGD(self.finetune.parameters(), lr=lr)
        self.method_name = "Finetune"

    def predict(self, inputs: torch.Tensor, task_index: int) -> torch.Tensor:
        """
        Run the model on ``inputs`` and return its raw output.

        ``task_index`` is accepted for interface compatibility but unused:
        the same model is queried regardless of task.
        """
        return self.finetune(inputs)

    def update(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        task_index: int,
        test: bool = False,
    ) -> None:
        """
        Take one SGD step on ``self.criterion(model(inputs), labels)``.

        ``task_index`` and ``test`` are accepted for interface compatibility
        but are not used by this method. ``self.criterion`` is expected to
        be provided by ``AdaptiveCL`` (not defined in this class).
        """
        self.finetune_opt.zero_grad()
        loss = self.criterion(self.finetune(inputs), labels)
        loss.backward()
        self.finetune_opt.step()

    def before_fewshot_test(self):
        """Defer entirely to the base-class hook; no extra preparation here."""
        super().before_fewshot_test()

    def get_models(self) -> list:
        """
        Return the trainable models of this method as a list.

        Finetune owns a single model, so the list has exactly one element.
        """
        return [self.finetune]

    def mode(self, is_train: bool = True):
        """
        Put the model in training mode (``is_train=True``) or eval mode.
        """
        # Explicit statement instead of a conditional expression: these calls
        # are made for their side effect, not for their return value.
        if is_train:
            self.finetune.train()
        else:
            self.finetune.eval()
