from abc import ABCMeta, abstractmethod

import torch
from torch import nn

from utils import RecordVals, verbose_iterator


class AdaptiveCL(metaclass=ABCMeta):
    """
    Abstract base class for AdaptiveCL learners.

    This base class owns the evaluate-then-update loops (``train`` and
    ``fewshot_test``) and the loss/accuracy bookkeeping. Subclasses must
    implement the abstract methods (``predict``, ``update``, ``get_models``,
    ``mode``, ``before_fewshot_test``) and must also provide a callable
    ``self.base``, which ``train`` applies (without gradients) to every input
    batch before prediction; ``__init__`` does not set it.
    """

    def __init__(self, cl_type, n_classes, n_tasks, device):
        """Initialize metric records and store the learner configuration.

        Args:
            cl_type: Stored as ``self.cl_type``; not interpreted by this class.
            n_classes: Total number of classes. The task index of a batch is
                derived as ``label // (n_classes // n_tasks)``.
            n_tasks: Number of tasks.
            device: Device that inputs and labels are moved to.
        """

        # Per-batch metrics for the regular (``train``) stream.
        self.loss_record = RecordVals()
        self.acc_record = RecordVals()

        # Running sums and the running-average histories derived from them.
        self.cumulative_loss = 0
        self.cumulative_acc = 0
        self.average_loss_record = RecordVals()
        self.average_acc_record = RecordVals()

        # Per-batch metrics for the few-shot (``fewshot_test``) stream.
        self.loss_record_fewshot = RecordVals()
        self.acc_record_fewshot = RecordVals()

        self.cumulative_acc_fewshot = 0
        self.average_acc_record_fewshot = RecordVals()
        self.n_tasks = n_tasks
        self.count = 0
        self.count_fewshot = 0
        self.criterion = nn.CrossEntropyLoss(reduction="mean")
        # NOTE(review): despite its name this holds the *total* class count
        # (a per-task split ``n_classes // n_tasks`` was apparently disabled);
        # kept unchanged for subclasses that may read it.
        self.n_class_each_task = n_classes
        self.n_classes = n_classes
        self.device = device
        self.cl_type = cl_type
        self.avg_acc = 0
        # Initialised here (like ``avg_acc``) so it can be read before any
        # few-shot batch has been processed.
        self.avg_acc_fewshot = 0
        self.moving_average_acc = 0

    @abstractmethod
    def predict(self, inputs, task_index):
        """Return class scores for ``inputs``.

        The result is fed to ``nn.CrossEntropyLoss`` and ``argmax(1)`` against
        the batch labels, so it is expected to be ``(batch, num_classes)``
        logits. ``train`` passes the output of ``self.base``; ``fewshot_test``
        passes the raw (device-moved) inputs.
        """
        pass

    @staticmethod
    def _move_inputs_to_device(inputs, device):
        """Move a batch of inputs to ``device``.

        Dicts are moved value by value; any other object exposing ``.to``
        (e.g. a ``torch.Tensor``) is moved with it; anything else is returned
        unchanged.
        """
        if isinstance(inputs, dict):
            return {k: v.to(device) for k, v in inputs.items()}
        if hasattr(inputs, "to"):
            return inputs.to(device)
        return inputs

    def _task_index_from_labels(self, labels):
        """Return the task index implied by the first label of a batch.

        Only ``labels[0]`` is inspected, so every batch is assumed to come
        from a single task whose class ids form a consecutive block of
        ``n_classes // n_tasks`` ids.

        Raises:
            ValueError: if ``n_classes // n_tasks`` is not positive, which
                would otherwise be an integer division by zero.
        """
        classes_per_task = self.n_classes // self.n_tasks if self.n_tasks > 0 else 0
        if classes_per_task <= 0:
            raise ValueError(
                f"Cannot derive task index: n_classes={self.n_classes} and "
                f"n_tasks={self.n_tasks} give {classes_per_task} classes per task."
            )
        return labels[0] // classes_per_task

    def _batch_loss_and_accuracy(self, predictions, labels):
        """Return ``(loss, accuracy)`` for one batch as Python floats."""
        loss = self.criterion(predictions, labels).item()
        correct = (predictions.argmax(1) == labels).sum().item()
        return loss, correct / len(labels)

    def train(self, dataloader, verbose=True):
        """Run one evaluate-then-update pass over ``dataloader``.

        For every ``(inputs, labels)`` batch the inputs are moved to
        ``self.device`` and passed through ``self.base`` without gradients.
        The learner is then scored on the batch (``mode(False)``, no
        gradients) *before* ``update`` is called on the same batch with
        ``mode(True)``. Per-batch and running-average loss/accuracy, plus an
        exponential moving average of accuracy, are updated on ``self``.

        Args:
            dataloader: Iterable of ``(inputs, labels)`` batches.
            verbose: Forwarded to ``verbose_iterator``.
        """
        for inputs, labels in verbose_iterator(dataloader, verbose):
            inputs = self._move_inputs_to_device(inputs, self.device)
            labels = labels.to(self.device)
            with torch.no_grad():
                inputs = self.base(inputs)
            # Labels are used as global class ids (not remapped per task).
            task_index = self._task_index_from_labels(labels)

            # Score the batch before the model has learned from it.
            with torch.no_grad():
                self.mode(False)
                predictions = self.predict(inputs, task_index)
                loss, acc = self._batch_loss_and_accuracy(predictions, labels)

                self.cumulative_acc += acc
                self.cumulative_loss += loss
                self.count += 1
                self.average_loss_record.add(self.cumulative_loss / self.count)
                self.avg_acc = self.cumulative_acc / self.count
                # EMA with decay 0.99, starting from 0 (no bias correction).
                self.moving_average_acc = 0.99 * self.moving_average_acc + 0.01 * acc
                self.average_acc_record.add(self.avg_acc)
                self.loss_record.add(loss)
                self.acc_record.add(acc)

            # Now learn from the same batch.
            self.mode(True)
            self.update(inputs, labels, task_index)

    @abstractmethod
    def update(self, inputs, labels, task_index, test=False):
        """Update the model parameters from one batch.

        Called after the batch has been scored; ``test`` is ``True`` when
        called from ``fewshot_test`` and ``False`` from ``train``.
        """
        pass

    @abstractmethod
    def get_models(self):
        """Get the models used by this instance."""
        pass

    @abstractmethod
    def mode(self, is_train=True):
        """Switch between evaluation (``False``) and training (``True``) mode.

        The loops call ``mode(False)`` before scoring a batch and
        ``mode(True)`` before ``update``.
        """
        pass

    @abstractmethod
    def before_fewshot_test(self):
        """Hook called once at the start of ``fewshot_test``."""
        pass

    def fewshot_test(self, dataloader, verbose=True):
        """Run the few-shot evaluate-then-update pass over ``dataloader``.

        Calls ``before_fewshot_test`` once, then for every batch scores the
        learner (``mode(False)``, no gradients), records few-shot loss and
        accuracy, and finally calls ``update(..., test=True)`` with
        ``mode(True)``.

        Args:
            dataloader: Iterable of ``(inputs, labels)`` batches.
            verbose: Forwarded to ``verbose_iterator``.
        """
        self.before_fewshot_test()
        for inputs, labels in verbose_iterator(dataloader, verbose):
            # Same input handling as ``train`` (dict inputs are supported).
            inputs = self._move_inputs_to_device(inputs, self.device)
            labels = labels.to(self.device)
            # NOTE(review): unlike ``train``, inputs are NOT passed through
            # ``self.base`` here; confirm ``predict``/``update`` expect raw
            # inputs during the few-shot phase.
            task_index = self._task_index_from_labels(labels)

            # Score the batch before the model has learned from it.
            with torch.no_grad():
                self.mode(False)
                predictions = self.predict(inputs, task_index)
                loss_fewshot, acc_fewshot = self._batch_loss_and_accuracy(
                    predictions, labels
                )

                self.cumulative_acc_fewshot += acc_fewshot
                self.count_fewshot += 1
                self.avg_acc_fewshot = self.cumulative_acc_fewshot / self.count_fewshot
                self.average_acc_record_fewshot.add(self.avg_acc_fewshot)
                self.loss_record_fewshot.add(loss_fewshot)
                self.acc_record_fewshot.add(acc_fewshot)

            # Now learn from the same batch.
            self.mode(True)
            self.update(inputs, labels, task_index, test=True)

    def get_results(self):
        """Return the regular-stream metrics as tensors on ``self.device``.

        Assuming ``RecordVals.record`` is a flat list of the values passed to
        ``add`` (TODO confirm), with ``N`` recorded batches:

        Returns:
            avg_acc: running-average accuracy, shape ``(N, 1)``.
            avg_loss: running-average loss, shape ``(N, 1)``.
            acc_list: per-batch accuracy, shape ``(N,)``.
            loss_list: per-batch loss, shape ``(N,)``.
        """
        avg_acc = (
            torch.tensor(self.average_acc_record.record).unsqueeze(-1).to(self.device)
        )
        avg_loss = (
            torch.tensor(self.average_loss_record.record).unsqueeze(-1).to(self.device)
        )
        acc_list = torch.tensor(self.acc_record.record).to(self.device)
        loss_list = torch.tensor(self.loss_record.record).to(self.device)

        return avg_acc, avg_loss, acc_list, loss_list

    def get_fewshot_results(self):
        """Return the few-shot metrics as tensors on ``self.device``.

        Returns:
            avg_acc_fewshot: running-average few-shot accuracy, a column
                tensor (``unsqueeze(-1)`` of the recorded values).
            acc_list_fewshot: per-batch few-shot accuracy.
        """
        avg_acc_fewshot = (
            torch.tensor(self.average_acc_record_fewshot.record)
            .unsqueeze(-1)
            .to(self.device)
        )
        acc_list_fewshot = torch.tensor(self.acc_record_fewshot.record).to(self.device)

        return avg_acc_fewshot, acc_list_fewshot
