
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .Buffer import Buffer
from .base import *
from .utils_model import backbone

class SGD(ContinualLearning):
    """
    Naive fine-tuning baseline for continual learning.

    Every call to ``compute_loss`` performs one optimisation step on the batch
    it is given, with no countermeasure against forgetting earlier tasks.
    """

    def __init__(
        self,
        encoder: nn.Module,
        lr: float = 0.001,
        cls_output_dim: int = 2,
        num_tasks: int = 10,
        z_dim: int = 512,
        **kwargs
    ) -> None:
        """
        Wrap ``encoder`` with ``backbone`` and initialise the base class.

        Args:
            encoder: Module handed to ``backbone`` before being passed on to
                the base class.
            lr: Forwarded to the base class (presumably the optimiser
                learning rate -- confirm against ``ContinualLearning``).
            cls_output_dim: Number of output columns reserved per task.
            num_tasks: Number of tasks; ``backbone`` is asked for
                ``cls_output_dim * num_tasks`` classification outputs.
            z_dim: Forwarded to ``backbone`` unchanged.
            **kwargs: Accepted and ignored, so callers can pass a shared
                configuration without filtering it first.
        """
        encoder = backbone(
            encoder, cls_output_dim=cls_output_dim * num_tasks, z_dim=z_dim
        )
        super().__init__(encoder, lr, num_tasks, cls_output_dim)
        # Kept locally because compute_loss needs it to locate a task's columns.
        self.cls_output_dim = cls_output_dim

    def compute_loss(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        not_aug_inputs: torch.Tensor,
        loss_func: nn.Module,
        transform,
        task_id,
    ) -> float:
        """
        Run one training step on the current task and return the loss value.

        SGD trains on the data provided with no countermeasures to avoid
        forgetting. This method is not a pure loss computation: it zeroes the
        gradients, back-propagates and calls ``self.optimizer.step()``.

        Args:
            inputs: Batch fed to ``self.forward``.
            labels: Targets passed to ``loss_func`` together with the sliced
                outputs (presumably indexed relative to the task's own block
                of columns -- confirm against the caller).
            not_aug_inputs: Unused by this baseline.
            loss_func: Callable invoked as ``loss_func(outputs, labels)``; its
                result must support ``backward()`` and ``item()``.
            transform: Unused by this baseline.
            task_id: Index of the current task; selects which block of
                ``cls_output_dim`` output columns is trained.

        Returns:
            The loss as a plain Python number (``loss.item()``), not a tensor.
        """
        self.optimizer.zero_grad()
        outputs = self.forward(inputs)

        # Only the block of columns belonging to ``task_id`` enters the loss.
        first_col = task_id * self.cls_output_dim
        last_col = first_col + self.cls_output_dim
        task_outputs = outputs[:, first_col:last_col]

        loss = loss_func(task_outputs, labels)
        loss.backward()
        self.optimizer.step()

        return loss.item()