import copy
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .Buffer import Buffer
from .base import *
from .utils_model import backbone


class Single(ContinualLearning):
    """Baseline that uses no continual-learning technique.

    A separate model is built for every task in ``begin_task`` from a deep
    copy of ``self.encoder`` and stored in ``self.encoders_sav`` under the
    task name, so training one task never touches another task's weights.
    """

    def __init__(
        self,
        encoder: nn.Module,
        lr=0.001,
        cls_output_dim: int = 2,
        num_tasks: int = 10,
        z_dim: int = 512,
        dataset_name: str = "celeba",
        device="cuda",
        **kwargs
    ) -> None:
        """Store the hyper-parameters; per-task models are created lazily.

        Args:
            encoder: Template encoder handed to the parent class; it is
                deep-copied for each task in ``begin_task``.
            lr: Learning rate of the per-task Adam optimizer.
            cls_output_dim: Forwarded to ``backbone`` as ``cls_output_dim``.
            num_tasks: Forwarded to the parent constructor.
            z_dim: Forwarded to ``backbone`` as ``z_dim``.
            dataset_name: Accepted for interface compatibility; unused here.
            device: Device each per-task model is moved to.
            **kwargs: Ignored.
        """
        super().__init__(encoder, lr, num_tasks, cls_output_dim)
        self.cls_output_dim = cls_output_dim
        self.z_dim = z_dim
        # task name -> model trained exclusively on that task
        self.encoders_sav = nn.ModuleDict({})
        self.device = device
        # Set by begin_task; forward() and compute_loss() depend on it.
        self.cur_task_name = None
        self.lr = lr

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the model of the current task.

        Raises:
            KeyError: If ``begin_task`` has not registered a model for
                ``self.cur_task_name`` yet.
        """
        return self.encoders_sav[self.cur_task_name](x)

    def begin_task(self, dataloader, task_name, task_id, **kwargs):
        """Create a fresh model and optimizer for ``task_name``.

        The new model wraps a deep copy of ``self.encoder`` and the optimizer
        only sees that model's parameters. The parent ``begin_task`` is then
        called and its result returned.
        """
        self.cur_task_name = task_name
        task_model = backbone(
            copy.deepcopy(self.encoder),
            cls_output_dim=self.cls_output_dim,
            z_dim=self.z_dim,
        ).to(self.device)
        self.encoders_sav[task_name] = task_model
        self.optimizer = torch.optim.Adam(
            task_model.parameters(), lr=self.lr
        )
        return super().begin_task(dataloader, task_name, task_id, **kwargs)

    def compute_loss_on_task_id(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        loss_func: nn.Module,
        task_id: int,
        task_name: str,
        **kwargs
    ) -> Tuple[None, None, torch.Tensor]:
        """Evaluate ``loss_func`` on ``task_name``'s model without gradients.

        The model is switched to eval mode for the forward pass and put back
        into whatever mode it was in before the call, even if the forward
        pass raises.

        Returns:
            ``(None, None, loss)``; the first two slots are always ``None``.
        """
        task_model = self.encoders_sav[task_name]
        was_training = task_model.training
        task_model.eval()
        try:
            with torch.no_grad():
                loss = loss_func(task_model(inputs), labels)
        finally:
            task_model.train(was_training)
        return None, None, loss

    def compute_loss(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        not_aug_inputs: torch.Tensor,
        loss_func: nn.Module,
        custom_transform,
        task_id,
    ) -> float:
        """Perform one optimization step on the current task's model.

        ``not_aug_inputs``, ``custom_transform`` and ``task_id`` are accepted
        for interface compatibility and are not used.

        Returns:
            The loss of this step as a Python float.
        """
        self.encoders_sav.train()
        self.optimizer.zero_grad()
        outputs = self.encoders_sav[self.cur_task_name](inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def calculate_accuraciess(
        self,
        valid_loader: torch.utils.data.DataLoader,
        tasks_name: Tuple[str, ...],
        device: torch.device,
    ) -> dict:
        """Compute accuracy, ECE, F1, recall and precision for each task.

        Each batch of ``valid_loader`` is read as a mapping with an
        ``"image"`` entry and one label entry per task name. Every task is
        scored with its own model from ``self.encoders_sav``.

        Returns:
            A dict with, per task, the keys ``<task>`` (accuracy),
            ``<task>_ece``, ``<task>_f1``, ``<task>_recall`` and
            ``<task>_precision``.
        """
        n_tasks = len(tasks_name)
        correct = [0] * n_tasks
        # All metrics live on ``device`` so their state matches the inputs.
        eces = [
            CalibrationError(task="multiclass", n_bins=15, num_classes=2).to(
                device
            )
            for _ in tasks_name
        ]
        f1s = [
            F1Score(task="binary", num_classes=2).to(device)
            for _ in tasks_name
        ]
        recalls = [
            Recall(task="binary", num_classes=2).to(device)
            for _ in tasks_name
        ]
        precision = [
            Precision(task="binary", num_classes=2).to(device)
            for _ in tasks_name
        ]
        total = len(valid_loader.dataset)

        self.encoder.eval()
        # The forward passes below go through the per-task models, not
        # ``self.encoder``, so those are the ones that must be in eval mode.
        # Their previous mode is remembered and restored afterwards.
        task_models = {name: self.encoders_sav[name] for name in tasks_name}
        previous_modes = {
            name: model.training for name, model in task_models.items()
        }
        for model in task_models.values():
            model.eval()

        try:
            with torch.no_grad():
                for sample in valid_loader:
                    images = sample["image"].to(device)
                    for idx, task_name in enumerate(tasks_name):
                        cur_task_y = sample[task_name].long().to(device)
                        outputs = task_models[task_name](images)
                        predicted = outputs.argmax(dim=1)
                        correct[idx] += (
                            (predicted == cur_task_y).sum().item()
                        )
                        probabilities = F.softmax(outputs, dim=1)
                        eces[idx].update(probabilities, cur_task_y)
                        f1s[idx].update(predicted, cur_task_y)
                        recalls[idx].update(predicted, cur_task_y)
                        precision[idx].update(predicted, cur_task_y)
        finally:
            for name, model in task_models.items():
                model.train(previous_modes[name])

        result = dict()
        for idx, task_name in enumerate(tasks_name):
            result[task_name] = correct[idx] / total
            result[task_name + "_ece"] = eces[idx].compute().item()
            result[task_name + "_f1"] = f1s[idx].compute().item()
            result[task_name + "_recall"] = recalls[idx].compute().item()
            result[task_name + "_precision"] = precision[idx].compute().item()
        return result

    def state_dict(self, **kwargs):
        """Return the parent ``state_dict`` after calling ``end_task``.

        NOTE(review): ``end_task`` is invoked with ``dataloader=None`` and
        ``task_id=None`` on every call, presumably to finalize the current
        task before saving; confirm against the parent implementation.
        """
        self.end_task(
            dataloader=None, task_name=self.cur_task_name, task_id=None
        )
        return super().state_dict(**kwargs)