import torch
from torch import optim

from methods.base import AdaptiveCL
from utils import RecordVals


class Separate(AdaptiveCL):
    """
    Decomposing Adaptive CL into three separate stages.

    A pool of independent models (each with its own SGD optimizer) is kept.
    For every training batch:

    1. Change-point detection: each model's accuracy on the batch is folded
       into an exponential moving average (EMA). If even the best EMA is below
       ``DRIFT_THRESHOLD * self.avg_acc``, a fresh model is appended and
       selected; otherwise the model with the highest EMA is selected.
    2. Training: one SGD step is taken on the selected model only.
    3. Pruning: when the pool grows beyond ``MAX_MODELS``, the model with the
       lowest lifetime mean accuracy (excluding the selected one) is removed.
    """

    # Largest pool size tolerated before ``prune`` is triggered.
    MAX_MODELS = 20
    # Smoothing factor of the per-model accuracy EMA.
    EMA_RATE = 0.1
    # A change point is declared when the best EMA falls below this fraction
    # of ``self.avg_acc``.
    DRIFT_THRESHOLD = 0.95

    def __init__(
        self,
        cl_type,
        model_class,
        n_tasks: int,
        n_classes,
        lr: float,
        device: torch.device,
        batch_size,
        **kwargs
    ):
        """
        Args:
            cl_type: Forwarded to ``AdaptiveCL``.
            model_class: Callable that builds a model from
                ``self.n_class_each_task`` (an attribute set by the base class).
            n_tasks: Forwarded to ``AdaptiveCL``.
            n_classes: Forwarded to ``AdaptiveCL``.
            lr: Learning rate used for every per-model SGD optimizer.
            device: Device every model is moved to.
            batch_size: Unused by this class.
            **kwargs: Ignored.
        """
        super().__init__(cl_type, n_classes, n_tasks, device)
        self.device = device
        self.model_class = model_class
        self.lr = lr
        self.method_name = "Separate"
        self.model_list = [model_class(self.n_class_each_task).to(device)]
        self.opt_list = [optim.SGD(md.parameters(), lr=lr) for md in self.model_list]
        # Index into ``model_list`` of the model used for prediction/training.
        self.current_index = 0
        # Per-model statistics (CPU tensors), aligned index-by-index with
        # ``model_list``: accuracy EMA, summed accuracy, number of observations.
        self.moving_average_acc_list = torch.zeros(1)
        self.cumulative_acc_list = torch.zeros(1)
        self.count_list = torch.zeros(1)
        self.inferred_index_record = RecordVals()

    def predict(self, inputs: torch.Tensor, task_index) -> torch.Tensor:
        """Return the output of the currently selected model on ``inputs``.

        ``task_index`` is ignored; the model is the one last selected by
        ``CPdetection`` (or kept by ``prune``).
        """
        return self.model_list[self.current_index](inputs)

    def update(self, inputs, labels, task_index, test=False):
        """Run one online training step on a batch.

        Selects (or creates) a model via change-point detection, takes one SGD
        step on it, and prunes the pool if it has grown past ``MAX_MODELS``.
        ``task_index`` and ``test`` are unused.
        """
        self.CPdetection(inputs, labels)
        model = self.model_list[self.current_index]
        optimizer = self.opt_list[self.current_index]
        optimizer.zero_grad()
        loss = self.criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        if len(self.model_list) > self.MAX_MODELS:
            self.prune()

    def CPdetection(self, inputs, labels):
        """Detect a change point via an exponential moving average of accuracy.

        Every model's accuracy on ``(inputs, labels)`` updates its EMA, its
        cumulative accuracy and its observation count. If the best EMA is below
        ``DRIFT_THRESHOLD * self.avg_acc`` a new model is appended and
        selected; otherwise the model with the highest EMA is selected.
        """
        with torch.no_grad():
            # ``.item()`` brings each scalar to the host, so the statistics
            # stay CPU tensors whatever device the models live on.
            current_acc = torch.tensor(
                [
                    (model(inputs).argmax(1) == labels).float().mean().item()
                    for model in self.model_list
                ]
            )
            self.moving_average_acc_list += self.EMA_RATE * (
                current_acc - self.moving_average_acc_list
            )
            self.cumulative_acc_list += current_acc
            self.count_list += 1
            # NOTE(review): ``self.avg_acc`` is not defined in this class;
            # presumably maintained by ``AdaptiveCL`` — confirm.
            best_ema = self.moving_average_acc_list.max()
            if best_ema < self.DRIFT_THRESHOLD * self.avg_acc:
                self._add_model()
            else:
                self.current_index = int(self.moving_average_acc_list.argmax())

    def _add_model(self):
        """Append a fresh model + optimizer with zeroed statistics and select it."""
        self.model_list.append(
            self.model_class(self.n_class_each_task).to(self.device)
        )
        self.opt_list.append(optim.SGD(self.model_list[-1].parameters(), lr=self.lr))
        self.moving_average_acc_list = torch.cat(
            (self.moving_average_acc_list, torch.zeros(1))
        )
        self.cumulative_acc_list = torch.cat(
            (self.cumulative_acc_list, torch.zeros(1))
        )
        # Count starts at 1 with zero cumulative accuracy, so the new model's
        # lifetime mean (used by ``prune``) is 0 rather than 0/0.
        self.count_list = torch.cat((self.count_list, torch.ones(1)))
        self.current_index = len(self.model_list) - 1

    def prune(self):
        """Remove the model with the lowest lifetime mean accuracy.

        The lifetime mean is ``cumulative_acc_list / count_list``. The
        currently selected model is never removed, so a model created by
        ``CPdetection`` earlier in the same ``update`` survives.
        ``current_index`` is shifted so it keeps pointing at the same model.
        Does nothing when the pool holds at most one model.
        """
        if len(self.model_list) <= 1:
            return
        # clamp guards against 0/0 for a model that has never been observed.
        lifetime_acc = self.cumulative_acc_list / self.count_list.clamp(min=1)
        lifetime_acc[self.current_index] = float("inf")
        index = int(lifetime_acc.argmin())

        def drop(stats):
            # Remove entry ``index`` from a per-model statistics tensor.
            return torch.cat((stats[:index], stats[index + 1 :]))

        self.model_list.pop(index)
        self.opt_list.pop(index)
        self.moving_average_acc_list = drop(self.moving_average_acc_list)
        self.cumulative_acc_list = drop(self.cumulative_acc_list)
        self.count_list = drop(self.count_list)
        if index < self.current_index:
            self.current_index -= 1

    def before_fewshot_test(self):
        """Delegate to ``AdaptiveCL.before_fewshot_test``."""
        super().before_fewshot_test()

    def get_models(self):
        """Return the list of models currently in the pool."""
        return self.model_list

    def mode(self, is_train: bool = True):
        """Set every model in the pool to training or evaluation mode."""
        for model in self.model_list:
            # ``Module.train(False)`` is equivalent to ``Module.eval()``.
            model.train(is_train)
