from copy import deepcopy

import torch
from torch import optim

from methods.base import AdaptiveCL
from utils import Reservoir


class Oracle(AdaptiveCL):
    """Shared feature extractor with one classifier head per task, trained with replay.

    The backbone produced by ``model_class`` has its ``classifier`` replaced by an
    identity, so it returns features. A copy of the original classifier is cloned
    once per task to form ``self.oracle``, and each head has its own SGD optimizer.
    During ``update`` the current batch is combined with a batch sampled from a
    reservoir buffer once the buffer holds at least ``batch_size`` items.
    """

    def __init__(
        self,
        cl_type,
        model_class,
        n_tasks: int,
        n_classes,
        lr: float,
        device: torch.device,
        buffer_size: int,
        batch_size: int,
        **kwargs
    ):
        """Build the backbone, the per-task heads, the optimizers and the replay buffer.

        Args:
            cl_type: Forwarded to ``AdaptiveCL.__init__``.
            model_class: Callable taking the per-task class count and returning a
                module that exposes a ``classifier`` attribute.
            n_tasks: Number of task-specific classifier heads to create.
            n_classes: Forwarded to ``AdaptiveCL.__init__``.
            lr: Learning rate for every SGD optimizer.
            device: Device the backbone (and the heads copied from it) live on.
            buffer_size: Capacity passed to the ``Reservoir`` replay buffer.
            batch_size: Number of samples drawn from the buffer per update.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__(cl_type, n_classes, n_tasks, device)
        # ``n_class_each_task`` is presumably set by AdaptiveCL.__init__ — not visible here.
        self.feature = model_class(self.n_class_each_task).to(device)
        # Detach the classifier from the backbone so ``self.feature`` yields features.
        self.fc = deepcopy(self.feature.classifier)
        self.feature.classifier = torch.nn.Identity()
        # Created after the swap, so it only optimizes backbone parameters.
        self.feature_opt = optim.SGD(self.feature.parameters(), lr=lr)
        # One independent head per task, all starting from the same weights.
        self.oracle = [deepcopy(self.fc) for _ in range(n_tasks)]

        self.oracle_opt = [optim.SGD(md.parameters(), lr=lr) for md in self.oracle]
        self.method_name = "Oracle"
        self.buffer = Reservoir(buffer_size, device)
        self.batch_size = batch_size
        self.seen_tasks = []

    def predict(self, inputs: torch.Tensor, task_index) -> torch.Tensor:
        """Return the logits of head ``task_index`` on ``inputs``.

        Gradient tracking is not disabled here; callers wanting inference-only
        behaviour presumably wrap this in ``torch.no_grad()`` — verify at call sites.
        """
        inputs = self.feature(inputs)
        return self.oracle[task_index](inputs)

    def update(self, inputs, labels, task_index, test=False):
        """Run one SGD step on the current batch plus a replay batch, then store the batch.

        Args:
            inputs: Batch of inputs for the backbone.
            labels: Targets for ``self.criterion``. They are reshaped to a column
                (``view(-1, 1)``) before being stored in the buffer.
            task_index: Selects which head (and head optimizer) is trained.
            test: Unused; kept for interface compatibility.
        """
        head = self.oracle[task_index]
        head_opt = self.oracle_opt[task_index]

        head_opt.zero_grad()
        self.feature_opt.zero_grad()
        feature = self.feature(inputs)
        loss = self.criterion(head(feature), labels)
        if len(self.buffer) >= self.batch_size:
            with torch.no_grad():
                inputs_replay, labels_replay = self.buffer.sample(self.batch_size)
            # NOTE(review): replayed samples may originate from earlier tasks, but
            # they are scored with the *current* task's head. The buffer does not
            # store task ids. Confirm this is intended.
            preds_r = head(self.feature(inputs_replay))
            # Out-of-place add: avoids mutating an autograd-tracked tensor.
            loss = loss + self.criterion(preds_r, labels_replay)
        loss.backward()
        head_opt.step()
        self.feature_opt.step()

        self.buffer.add(zip(inputs, labels.view(-1, 1)))

    def before_fewshot_test(self):
        """Delegate to the base-class hook; Oracle adds no extra preparation."""
        super().before_fewshot_test()

    def get_models(self):
        """Return the trainable modules: the list of per-task heads and the backbone."""
        return {"classifier": self.oracle, "feature": self.feature}

    def mode(self, is_train: bool = True):
        """Set the backbone and every task head to training (truthy) or evaluation mode."""
        # nn.Module.train(False) is equivalent to eval(); bool() preserves the
        # original truthiness check, because newer PyTorch rejects non-bool modes.
        flag = bool(is_train)
        self.feature.train(flag)
        for oracle in self.oracle:
            oracle.train(flag)
