import torch
from torch import optim

from methods.base import AdaptiveCL


class ExpVAE(AdaptiveCL):
    """
    Continual-learning method that keeps one classifier and one generative model per task.

    For every task index the instance owns:

    * an "oracle" built from ``model_class`` (the task's classifier), and
    * a generative model built from ``gen_class`` (named ``vae`` here),

    each with its own SGD optimizer. During training, ``update`` trains only
    the pair belonging to the given task. During inference, ``predict`` ignores
    the task index: it evaluates every generative model's ``loss`` on the
    inputs and routes the whole batch to the oracle whose generative model
    reports the lowest loss.

    NOTE(review): ``self.n_class_each_task`` and ``self.criterion`` are not
    defined in this class; they are assumed to be provided by ``AdaptiveCL``.
    """

    def __init__(
        self,
        cl_type,
        model_class,
        gen_class,
        n_classes,
        n_tasks,
        lr: float,
        device: torch.device,
        **kwargs
    ):
        """
        Build the per-task oracles, generative models and their optimizers.

        Args:
            cl_type: Forwarded unchanged to ``AdaptiveCL.__init__``.
            model_class: Callable invoked as ``model_class(self.n_class_each_task)``
                to create one oracle per task.
            gen_class: Callable invoked with no arguments to create one
                generative model per task. The returned object must expose
                ``parameters()``, ``to(device)`` and ``loss(inputs)``.
            n_classes: Forwarded unchanged to ``AdaptiveCL.__init__``.
            n_tasks: Number of tasks; one oracle/generative-model pair is
                created for each.
            lr: Learning rate shared by every SGD optimizer created here.
            device: Device every model is moved to.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        super().__init__(cl_type, n_classes, n_tasks, device)
        self.oracle = [
            model_class(self.n_class_each_task).to(device) for _ in range(n_tasks)
        ]
        self.oracle_optimizers = [
            optim.SGD(oracle.parameters(), lr=lr) for oracle in self.oracle
        ]
        self.vae = [gen_class().to(device) for _ in range(n_tasks)]
        self.vae_optimizers = [optim.SGD(vae.parameters(), lr=lr) for vae in self.vae]
        self.method_name = "ExpVAE"

    def predict(self, inputs: torch.Tensor, task_index) -> torch.Tensor:
        """
        Return log-probabilities for ``inputs`` from the best-matching oracle.

        The oracle is selected once for the whole batch: every generative model
        scores ``inputs`` through its ``loss`` method and the index with the
        smallest loss wins.

        Args:
            inputs: Batch handed to every generative model and to the selected
                oracle.
            task_index: Unused; the task is inferred from the generative
                models' losses instead.

        Returns:
            Log-softmax of the selected oracle's output, taken over dimension 1.
        """
        with torch.no_grad():
            # Each loss is expected to be a single scalar value so the list
            # can be packed into a 1-D tensor and reduced with argmin.
            losses = torch.tensor([vae.loss(inputs) for vae in self.vae])
            # Convert to a plain int before indexing the Python list of oracles.
            inference_index = int(losses.argmin())
            logits = self.oracle[inference_index](inputs)
            # Fused log-softmax: taking log() of a softmax output yields -inf
            # whenever a probability underflows to zero.
            return logits.log_softmax(dim=1)

    def update(
        self, inputs: torch.Tensor, labels: torch.Tensor, task_index: int, test=False
    ):
        """
        Run one optimization step on the oracle and generative model of a task.

        Args:
            inputs: Training batch.
            labels: Targets passed to ``self.criterion`` together with the
                oracle's output.
            task_index: Selects which oracle/generative-model pair (and their
                optimizers) is trained.
            test: Unused in this method.
        """
        # Step 1: classifier of the current task.
        oracle_optimizer = self.oracle_optimizers[task_index]
        oracle_optimizer.zero_grad()
        oracle_loss = self.criterion(self.oracle[task_index](inputs), labels)
        oracle_loss.backward()
        oracle_optimizer.step()

        # Step 2: generative model of the current task, trained on its own
        # loss over the same inputs (labels are not used here).
        vae_optimizer = self.vae_optimizers[task_index]
        vae_optimizer.zero_grad()
        vae_loss = self.vae[task_index].loss(inputs)
        vae_loss.backward()
        vae_optimizer.step()

    def before_fewshot_test(self):
        """Delegate to the base-class hook; no extra work is done here."""
        super().before_fewshot_test()

    def get_models(self):
        """Return the list of per-task oracles (generative models are not included)."""
        return self.oracle

    def mode(self, is_train: bool = True):
        """
        Put every oracle and generative model into training or evaluation mode.

        Args:
            is_train: When truthy, ``train()`` is called on each model;
                otherwise ``eval()`` is called.
        """
        for model in self.oracle + self.vae:
            if is_train:
                model.train()
            else:
                model.eval()
