from typing import Dict, List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import CalibrationError, F1Score, Precision, Recall
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

class ContinualLearning(nn.Module):
    """Base class for continual-learning models built around one encoder.

    The encoder's output (dim 1) is treated as a concatenation of per-task
    heads, each ``cls_output_dim`` columns wide: head ``i`` occupies columns
    ``[i * cls_output_dim, (i + 1) * cls_output_dim)``. Subclasses implement
    :meth:`compute_loss` and may override the :meth:`begin_task` /
    :meth:`end_task` hooks.
    """

    def __init__(
        self,
        encoder: nn.Module,
        lr: float = 0.001,
        num_tasks: int = 10,
        cls_output_dim: int = 2,
    ) -> None:
        """
        Args:
            encoder: network whose output holds the per-task heads.
            lr: learning rate of the Adam optimizer over ``encoder``'s
                parameters (stored as ``self.optimizer``).
            num_tasks: number of tasks; stored but not used by this class.
            cls_output_dim: width of each task head in the encoder output.
        """
        super().__init__()
        self.encoder = encoder
        self.num_tasks = num_tasks
        self.cls_output_dim = cls_output_dim
        # NOTE: only the encoder's parameters are optimized; modules that a
        # subclass adds later are not part of this optimizer.
        self.optimizer = torch.optim.Adam(self.encoder.parameters(), lr=lr)
        self.model_name = self.__class__.__name__

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder on ``x`` and return its raw output."""
        return self.encoder(x)

    def compute_loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Compute the loss for the model. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def _task_slice(self, outputs: torch.Tensor, task_id: int) -> torch.Tensor:
        """Return the ``cls_output_dim`` columns of ``outputs`` for head ``task_id``."""
        start = task_id * self.cls_output_dim
        return outputs[:, start : start + self.cls_output_dim]

    def compute_loss_on_task_id(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        loss_func: nn.Module,
        task_id: int,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward ``inputs`` and apply ``loss_func`` to head ``task_id`` only.

        Returns:
            ``(outputs, outputs_sliced, loss)``: the full encoder output, the
            columns belonging to ``task_id``, and ``loss_func(outputs_sliced,
            labels)``. Extra ``kwargs`` are accepted for subclass
            compatibility and ignored here.
        """
        outputs = self.forward(inputs)
        outputs_sliced = self._task_slice(outputs, task_id)
        return outputs, outputs_sliced, loss_func(outputs_sliced, labels)

    def end_task(self, dataloader, task_name, task_id, **kwargs):
        """
        Hook called at the end of a task. No-op by default; not all models need it.
        """
        pass

    def begin_task(self, dataloader, task_name, task_id, **kwargs):
        """
        Hook called at the beginning of a task. No-op by default; not all models need it.
        """
        pass

    def calculate_accuracy(
        self,
        predictor: nn.Module,
        valid_loader: torch.utils.data.DataLoader,
        task_name: str,
        device: torch.device,
    ) -> float:
        """Top-1 accuracy of ``predictor`` applied to encoder features.

        Each batch must provide ``sample["image"]`` and integer-like labels in
        ``sample[task_name]``. Does not change train/eval mode. Raises
        ``ZeroDivisionError`` if ``valid_loader`` yields no samples.
        """
        correct = 0
        total = 0
        with torch.no_grad():
            for sample in valid_loader:
                images = sample["image"].to(device)
                targets = sample[task_name].to(device, dtype=torch.long)
                predicted = predictor(self.encoder(images)).argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        return correct / total

    @staticmethod
    def _new_task_metrics(device: torch.device) -> Dict[str, nn.Module]:
        """Create one task's fresh two-class metrics, keyed by result-name suffix."""
        # Every metric is moved to ``device`` so its state lives next to the
        # tensors it is updated with.
        return {
            "_ece": CalibrationError(
                task="multiclass", n_bins=15, num_classes=2
            ).to(device),
            "_f1": F1Score(task="binary", num_classes=2).to(device),
            "_recall": Recall(task="binary", num_classes=2).to(device),
            "_precision": Precision(task="binary", num_classes=2).to(device),
        }

    @staticmethod
    def _update_task_metrics(
        metrics: Dict[str, nn.Module],
        logits: torch.Tensor,
        targets: torch.Tensor,
    ) -> int:
        """Feed one batch into ``metrics``; return the count of correct top-1 predictions."""
        predicted = logits.argmax(dim=1)
        # Calibration needs class probabilities; the others take hard labels.
        metrics["_ece"].update(F.softmax(logits, dim=1), targets)
        for suffix in ("_f1", "_recall", "_precision"):
            metrics[suffix].update(predicted, targets)
        return (predicted == targets).sum().item()

    @staticmethod
    def _summarize(
        tasks_name: Tuple[str, ...],
        correct: List[int],
        metrics: List[Dict[str, nn.Module]],
        total: int,
    ) -> Dict[str, float]:
        """Build ``{task: accuracy, task + suffix: metric value, ...}`` in task order."""
        result = dict()
        for task_name, n_correct, task_metrics in zip(tasks_name, correct, metrics):
            result[task_name] = n_correct / total
            for suffix, metric in task_metrics.items():
                result[task_name + suffix] = metric.compute().item()
        return result

    def calculate_accuraciess(
        self,
        valid_loader: torch.utils.data.DataLoader,
        tasks_name: Tuple[str],
        device: torch.device,
    ) -> Tuple[Dict[str, float], np.ndarray, np.ndarray]:
        """Evaluate the encoder's own task heads (no external predictors).

        Head ``idx`` (see class docstring) is scored against
        ``sample[tasks_name[idx]]``. ``tasks_name`` may be any ordered
        iterable of names, including ``dict_keys``.

        Returns:
            ``(result, zs, ys)``. ``result`` maps each task to its accuracy
            and ``task + "_ece" / "_f1" / "_recall" / "_precision"`` to those
            metrics. ``zs`` stacks the full encoder outputs row by row. ``ys``
            holds the labels of the LAST task in ``tasks_name``.

        Side effect: puts ``self.encoder`` in eval mode and leaves it there.
        Raises ``ZeroDivisionError`` if tasks are given but no samples are seen.
        """
        # Materialize so that dict_keys and other iterables can be indexed.
        tasks_name = tuple(tasks_name)
        last_idx = len(tasks_name) - 1
        metrics = [self._new_task_metrics(device) for _ in tasks_name]
        correct = [0] * len(tasks_name)
        # Count samples actually evaluated (robust to drop_last / samplers).
        total = 0
        zs = []
        ys = []
        self.encoder.eval()
        with torch.no_grad():
            for sample in valid_loader:
                images = sample["image"].to(device)
                outputs = self.forward(images)
                total += images.size(0)
                zs.extend(outputs.cpu().numpy())
                for idx, task_name in enumerate(tasks_name):
                    targets = sample[task_name].to(device, dtype=torch.long)
                    correct[idx] += self._update_task_metrics(
                        metrics[idx], self._task_slice(outputs, idx), targets
                    )
                    if idx == last_idx:
                        ys.extend(targets.cpu().numpy())

        result = self._summarize(tasks_name, correct, metrics, total)
        return result, np.array(zs), np.array(ys)

    def calculate_accuracies(
        self,
        predictors: dict,
        valid_loader: torch.utils.data.DataLoader,
        tasks_name: Tuple[str],
        device: torch.device,
    ) -> dict:
        """Evaluate one external predictor per task on top of encoder features.

        Args:
            predictors: maps task name -> module applied to the encoder output.
                If empty (any empty container), ``{}`` is returned immediately.
            valid_loader: yields dicts with ``"image"`` and one label per task.
            tasks_name: ordered task names to evaluate.
            device: device for inputs, labels and metrics.

        Returns:
            ``{task: accuracy, task + "_ece"/"_f1"/"_recall"/"_precision": value}``.

        Raises:
            KeyError: if a task in ``tasks_name`` has no (or a ``None``) predictor.

        Side effect: puts ``self.encoder`` in eval mode and leaves it there.
        """
        if not predictors:
            return dict()
        tasks_name = tuple(tasks_name)
        missing = [name for name in tasks_name if predictors.get(name) is None]
        if missing:
            raise KeyError(f"no predictor for task(s): {missing}")

        metrics = [self._new_task_metrics(device) for _ in tasks_name]
        correct = [0] * len(tasks_name)
        total = 0
        self.encoder.eval()
        with torch.no_grad():
            for sample in valid_loader:
                images = sample["image"].to(device)
                total += images.size(0)
                # Encoder features do not depend on the task: compute once per batch.
                z = self.encoder(images)
                for idx, task_name in enumerate(tasks_name):
                    targets = sample[task_name].to(device, dtype=torch.long)
                    correct[idx] += self._update_task_metrics(
                        metrics[idx], predictors[task_name](z), targets
                    )
        return self._summarize(tasks_name, correct, metrics, total)

    def train_test_predictor(
        self,
        train_loader: torch.utils.data.DataLoader,
        valid_loader: torch.utils.data.DataLoader,
        task_name: str,
        device: torch.device,
        latent_dim: int = 512,
        num_classes: int = 2,
    ) -> Tuple[torch.nn.Module, float]:
        """Linear-probe the frozen encoder for ``task_name``.

        Trains a fresh ``Linear(latent_dim, num_classes)`` for one pass over
        ``train_loader``. It keeps the weights that achieved the lowest
        training-batch loss and reports their accuracy on ``valid_loader``.
        The encoder is not updated and receives no gradients.

        Returns:
            ``(linear_model, accuracy)``, with the best weights loaded.

        Raises:
            ValueError: if no training batch produced a finite loss
                (e.g. ``train_loader`` is empty).

        Side effect: puts the whole model in eval mode and leaves it there.
        """
        linear_model = torch.nn.Linear(latent_dim, num_classes).to(device)
        optimizer = torch.optim.Adam(linear_model.parameters(), lr=0.001)
        criterion = torch.nn.CrossEntropyLoss()
        self.eval()
        best_loss = float("inf")
        best_state = None
        for sample in train_loader:
            images = sample["image"].to(device)
            targets = sample[task_name].to(device, dtype=torch.long)
            # Frozen features: without no_grad, backward() would also compute
            # (and accumulate) gradients for every encoder parameter.
            with torch.no_grad():
                z = self.encoder(images)
            loss = criterion(linear_model(z), targets)
            if loss.item() < best_loss:
                best_loss = loss.item()
                # Snapshot a copy *before* step(): this loss was produced by
                # the current weights, and step() mutates them in place.
                best_state = {
                    key: value.detach().clone()
                    for key, value in linear_model.state_dict().items()
                }
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if best_state is None:
            raise ValueError(
                "train_loader produced no batch with a finite loss (is it empty?)"
            )
        linear_model.load_state_dict(best_state)
        # assuming classification task
        return linear_model, self.calculate_accuracy(
            linear_model, valid_loader, task_name, device
        )

    def get_parameters(self):
        """
        Returns the parameters of the encoder (the only optimized parameters).
        """
        return self.encoder.parameters()


class MultitaskLearning(ContinualLearning):
    """Base class for multitask models sharing one encoder.

    Same as :class:`ContinualLearning`, with three additions: ``num_tasks`` is
    derived from the task mapping, the ordered task names are exposed as
    ``self.tasks_name``, and subclasses must implement
    :meth:`compute_loss_nograd`. Only the mapping's keys are used by this
    class; the per-task class counts (its values) are not read here, and
    every head has width ``cls_output_dim``.
    """

    def __init__(
        self,
        encoder: nn.Module,
        tasks_name_to_cls_num: dict,
        lr: float = 0.001,
        cls_output_dim: int = 2,
    ) -> None:
        # Snapshot the names as an ordered tuple so that head i <-> tasks_name[i]
        # can be indexed. A dict_keys view is not indexable and would also
        # track later mutations of the caller's dict.
        tasks_name = tuple(tasks_name_to_cls_num)
        super().__init__(encoder, lr, len(tasks_name), cls_output_dim)
        # Assigned after nn.Module.__init__ has run, as nn.Module expects.
        self.tasks_name = tasks_name

    def compute_loss_nograd(
        self, inputs: torch.Tensor, labels: dict, loss_func: nn.Module
    ) -> torch.Tensor:
        """
        Compute the loss for the model. Must be implemented by subclasses.
        """
        raise NotImplementedError


# https://github.com/AvivNavon/nash-mtl/blob/main/methods/weight_methods.py


class WeightMethod:
    """Interface for strategies that combine per-task losses into one loss.

    Subclasses override :meth:`get_weighted_loss`, which must return a
    ``(loss, extra_outputs)`` pair. :meth:`backward` calls it and runs
    ``loss.backward()``; calling the instance is a shortcut for
    :meth:`backward`.
    """

    def __init__(self, n_tasks: int, device: torch.device):
        super().__init__()
        self.n_tasks = n_tasks
        self.device = device

    def get_weighted_loss(
        self,
        losses: torch.Tensor,
        shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ],
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ],
        last_shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ],
        representation: Union[torch.nn.parameter.Parameter, torch.Tensor],
        **kwargs,
    ):
        """Abstract: return ``(combined_loss, extra_outputs)`` for ``losses``."""
        message = "get_weighted_loss method must be implemented"
        raise NotImplementedError(message)

    def backward(
        self,
        losses: torch.Tensor,
        shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        last_shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        representation: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        **kwargs,
    ) -> Tuple[Union[torch.Tensor, None], Union[dict, None]]:
        """Combine ``losses`` via :meth:`get_weighted_loss` and back-propagate.

        Parameters
        ----------
        losses : per-task losses
        shared_parameters : parameters shared across tasks
        task_specific_parameters : parameters owned by individual tasks
        last_shared_parameters : parameters of the last shared layer/block
        representation : shared representation
        kwargs : forwarded unchanged to ``get_weighted_loss``

        Returns
        -------
        The combined loss (after ``.backward()`` has been called on it) and
        whatever extra outputs ``get_weighted_loss`` produced.
        """
        named_inputs = dict(
            losses=losses,
            shared_parameters=shared_parameters,
            task_specific_parameters=task_specific_parameters,
            last_shared_parameters=last_shared_parameters,
            representation=representation,
        )
        combined, extras = self.get_weighted_loss(**named_inputs, **kwargs)
        combined.backward()
        return combined, extras

    def __call__(
        self,
        losses: torch.Tensor,
        shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        **kwargs,
    ):
        """Shortcut for :meth:`backward`; any other arguments go through ``kwargs``."""
        kwargs["task_specific_parameters"] = task_specific_parameters
        kwargs["shared_parameters"] = shared_parameters
        return self.backward(losses=losses, **kwargs)

    def parameters(self) -> List[torch.Tensor]:
        """Return the method's own learnable parameters (none by default)."""
        return list()
