from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import *
from loss import DynamicWeightingPDLoss, DynamicIRDLoss, DynamicWeightingRKDLoss, DynamicPreserveCosineLoss, DynamicWeightingRBFLoss
import copy


class LwP(ContinualLearning):
    def __init__(
        self,
        encoder: nn.Module,
        lr: float = 0.001,
        num_tasks: int = 10,
        cls_output_dim: int = 2,
        device="cuda",
        lam_pres=0.01,
        lam_pred=1,
        z_dim=512,
        enable_dynamic=True,
        dist_method="orig",
        **kwargs
    ) -> None:
        """Set up the encoder, the (initially empty) task heads and the losses.

        Args:
            encoder: Backbone module; forwarded to the parent constructor.
            lr: Learning rate, forwarded to the parent and used for Adam.
            num_tasks: Forwarded to the parent constructor.
            cls_output_dim: Forwarded to the parent constructor.
            device: Device that heads and frozen copies are moved to when a
                task begins.
            lam_pres: Weight of the preservation loss in the total loss.
            lam_pred: Weight of the prediction loss in the total loss.
            z_dim: Input width of every linear task head.
            enable_dynamic: Forwarded to the preservation-loss constructor.
            dist_method: Selects the preservation loss; one of ``"orig"``,
                ``"co2l"``, ``"rkd"``, ``"cos"`` or ``"rbf"``.
            **kwargs: Accepted and ignored.

        Raises:
            KeyError: If ``dist_method`` is not one of the known keys.
        """
        super().__init__(
            encoder, lr, num_tasks=num_tasks, cls_output_dim=cls_output_dim
        )

        # Plain configuration values.
        self.device = device
        self.z_dim = z_dim
        self.lam_pred = lam_pred
        self.lam_pres = lam_pres

        # One head per task, plus the frozen snapshot of the heads that
        # existed before the current task started.
        self.predictors = nn.ModuleDict()
        self.past_predictors = nn.ModuleDict()

        # Map the configuration string onto the preservation-loss class.
        preservation_losses = {
            "orig": DynamicWeightingPDLoss,
            "co2l": DynamicIRDLoss,
            "rkd": DynamicWeightingRKDLoss,
            "cos": DynamicPreserveCosineLoss,
            "rbf": DynamicWeightingRBFLoss,
        }
        loss_cls = preservation_losses[dist_method]
        self.pres_loss = loss_cls(enable_dynamic=enable_dynamic)

        # `self.predictors` is still empty here, so this list initially
        # contains only the encoder parameters.
        trainable = [
            *self.encoder.parameters(),
            *self.predictors.parameters(),
        ]
        self.optimizer = torch.optim.Adam(trainable, lr=lr)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """Encode ``x`` and apply every registered task head to the result.

        Args:
            x: Input batch passed to ``self.encoder``.

        Returns:
            A pair ``(z, ys)`` where ``z`` is the encoder output and ``ys``
            maps each task name in ``self.predictors`` to that head's output
            on ``z``. ``ys`` is empty when no head has been registered yet.
        """
        z = self.encoder(x)
        ys = {
            task_name: predictor(z)
            for task_name, predictor in self.predictors.items()
        }
        return z, ys

    def compute_loss(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        not_aug_inputs: torch.Tensor,
        loss_func: nn.Module,
        transform,
        task_index,
    ) -> float:
        """Run one optimisation step on a batch and return the prediction loss.

        The current task's head is trained against ``labels``. Every other
        head is trained against the softmax output of its frozen snapshot
        applied to the frozen encoder's features. A preservation loss
        between the frozen and the live features is added with weight
        ``self.lam_pres``.

        Args:
            inputs: Batch fed to both the live and the frozen encoder.
            labels: Targets for the current task; also passed to the
                preservation loss.
            not_aug_inputs: Unused by this method.
            loss_func: Called as ``loss_func(logits, target)``. For past
                tasks ``target`` is a probability tensor, so the loss must
                accept soft targets.
            transform: Unused by this method.
            task_index: Unused by this method.

        Returns:
            The prediction loss averaged over all heads, as a Python float.
            The preservation term is optimised but not included in the
            returned value.
        """
        self.optimizer.zero_grad()
        self.past_encoder.eval()
        self.past_predictors.eval()
        self.encoder.train()
        self.predictors.train()

        z_new, outputs = self.forward(inputs)
        # The frozen encoder is only a reference: its parameters are never
        # stepped, so skip building an autograd graph through it.
        with torch.no_grad():
            z_old = self.past_encoder(inputs)

        pred_loss = 0
        for task_name in self.predictors:
            if task_name == self.cur_task_name:
                # Current task: plain supervised loss on the true labels.
                loss = loss_func(outputs[task_name], labels)
            else:
                # Past task: use the frozen head's prediction as a
                # pseudo-label.
                with torch.no_grad():
                    pseudolabel = self.past_predictors[task_name](z_old)
                    pseudolabel = F.softmax(pseudolabel, dim=1)
                loss = loss_func(outputs[task_name], pseudolabel)

            pred_loss += loss
        # Average over heads so the scale does not grow with the task count.
        pred_loss = pred_loss / len(self.predictors)

        pres_loss = self.pres_loss(z_old, z_new, labels)
        tot_loss = self.lam_pred * pred_loss + self.lam_pres * pres_loss

        tot_loss.backward()
        self.optimizer.step()
        return pred_loss.item()

    def begin_task(
        self,
        dataloader: torch.utils.data.DataLoader,
        task_name: str,
        task_id: int,
        **kwargs
    ) -> None:
        """Snapshot the current model and create the head for a new task.

        Args:
            dataloader: Unused by this method.
            task_name: Key under which the new head is stored in
                ``self.predictors``; also becomes ``self.cur_task_name``.
            task_id: Unused by this method.
            **kwargs: Accepted and ignored.
        """
        # Freeze copies of the encoder and of the heads that exist so far;
        # the snapshot is taken before the new head is added.
        self.past_encoder = copy.deepcopy(self.encoder).to(self.device)
        self.past_predictors = copy.deepcopy(self.predictors).to(self.device)

        head = nn.Linear(self.z_dim, self.cls_output_dim).to(self.device)
        self.predictors[task_name] = head
        # The optimizer was built before this head existed, so its
        # parameters have to be registered explicitly or `step()` would
        # never update them. The new group inherits the optimizer defaults.
        self.optimizer.add_param_group({"params": list(head.parameters())})

        self.cur_task_name = task_name
        return

    def compute_loss_on_task_id(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        loss_func: nn.Module,
        task_id: int,
        task_name: str,
        **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate ``loss_func`` for a single task head without gradients.

        Args:
            inputs: Batch passed through ``self.forward``.
            labels: Targets handed to ``loss_func``.
            loss_func: Called as ``loss_func(logits, labels)``.
            task_id: Unused by this method.
            task_name: Selects which head's output is scored.
            **kwargs: Accepted and ignored.

        Returns:
            ``(None, None, loss)``; only the third slot is populated.
        """
        with torch.no_grad():
            per_task_logits = self.forward(inputs)[1]
            task_loss = loss_func(per_task_logits[task_name], labels)
        return None, None, task_loss

    def calculate_accuraciess(
        self,
        valid_loader: torch.utils.data.DataLoader,
        tasks_name: Tuple[str],
        device: torch.device,
    ) -> tuple:
        """Score every listed task on ``valid_loader`` and embed the features.

        Args:
            valid_loader: Yields dict-like samples with an ``"image"`` entry
                and one label entry per task name.
            tasks_name: Task names to evaluate; each must be a key of
                ``self.predictors`` and of every sample.
            device: Device on which inputs, labels and metrics are placed.

        Returns:
            A tuple ``(result, zs, ys)``:

            * ``result`` maps ``task`` to accuracy (correct predictions
              divided by ``len(valid_loader.dataset)``) and ``task_ece``,
              ``task_f1``, ``task_recall``, ``task_precision`` to the
              corresponding metric value.
            * ``zs`` is the output of ``TSNE.fit_transform`` on the encoder
              features of all samples.
            * ``ys`` holds the labels of the last task in ``tasks_name``.
        """
        # NOTE(review): all metrics are configured for two classes
        # regardless of `self.cls_output_dim` -- confirm this is intended.
        correct = [0 for _ in tasks_name]
        eces = [
            CalibrationError(task="multiclass", n_bins=15, num_classes=2).to(
                device
            )
            for _ in tasks_name
        ]
        f1s = [
            F1Score(task="binary", num_classes=2).to(device)
            for _ in tasks_name
        ]
        recalls = [
            Recall(task="binary", num_classes=2).to(device)
            for _ in tasks_name
        ]
        precision = [
            Precision(task="binary", num_classes=2).to(device)
            for _ in tasks_name
        ]

        total = len(valid_loader.dataset)
        self.encoder.eval()
        self.predictors.eval()
        result = dict()
        zs = []
        ys = []
        with torch.no_grad():
            for sample in valid_loader:
                images = sample["image"].to(device)
                # Encode once per batch and reuse the features for each head.
                z = self.encoder(images)
                zs.extend(z.cpu().numpy())
                for idx, task_name in enumerate(tasks_name):
                    cur_task_y = (
                        sample[task_name].type(torch.LongTensor).to(device)
                    )
                    outputs = self.predictors[task_name](z)
                    _, predicted = torch.max(outputs, 1)
                    correct[idx] += (predicted == cur_task_y).sum().item()

                    # Calibration needs probabilities; the other metrics
                    # take the hard predictions.
                    probabilities = F.softmax(outputs, dim=1)
                    eces[idx].update(probabilities, cur_task_y)
                    f1s[idx].update(predicted, cur_task_y)
                    recalls[idx].update(predicted, cur_task_y)
                    precision[idx].update(predicted, cur_task_y)

                    # Only the last task's labels are returned alongside
                    # the embedding.
                    if task_name == tasks_name[-1]:
                        ys.extend(cur_task_y.cpu().numpy())

        for idx, task_name in enumerate(tasks_name):
            result[task_name] = correct[idx] / total
            result[task_name + "_ece"] = eces[idx].compute().item()
            result[task_name + "_f1"] = f1s[idx].compute().item()
            result[task_name + "_recall"] = recalls[idx].compute().item()
            result[task_name + "_precision"] = precision[idx].compute().item()

        zs = np.array(zs)
        ys = np.array(ys)
        # Fixed seed so repeated evaluations give the same 2-D embedding.
        tsne_reducer = TSNE(n_components=2, random_state=325235)
        zs = tsne_reducer.fit_transform(zs)
        print(result)
        return result, zs, ys
    
    def load_state_dict(self, state_dict, strict=False, **kwargs):
        """Load a checkpoint, skipping the frozen past-model entries.

        Keys starting with ``past_predictors`` or ``past_encoder`` are
        dropped before delegating to the parent implementation.

        Args:
            state_dict: Mapping of parameter names to values.
            strict: Accepted for signature compatibility with
                ``nn.Module.load_state_dict`` but ignored; loading is always
                non-strict because the filtered mapping is intentionally
                partial.
            **kwargs: Forwarded to the parent ``load_state_dict``.

        Returns:
            Whatever the parent ``load_state_dict`` returns.
        """
        frozen_prefixes = ("past_predictors", "past_encoder")
        filtered = {
            key: value
            for key, value in state_dict.items()
            if not key.startswith(frozen_prefixes)
        }
        # NOTE: with non-strict loading, entries for heads that have not
        # been created yet (heads are only added in `begin_task`) are
        # skipped without an error.
        return super().load_state_dict(filtered, strict=False, **kwargs)
    
    def state_dict(self, **kwargs):
        """Return the module state without the frozen past-model entries.

        Entries belonging to ``past_predictors`` and ``past_encoder`` are
        removed in place from the mapping produced by the parent
        implementation. Filtering in place keeps a caller-supplied
        ``destination`` consistent with the returned mapping, and the
        ``prefix`` keyword is honoured so the filter also works when this
        module is nested inside another one.

        Args:
            **kwargs: Forwarded to the parent ``state_dict``.

        Returns:
            The mapping returned by the parent, minus the frozen entries.
        """
        full_state = super().state_dict(**kwargs)

        prefix = kwargs.get("prefix", "")
        frozen_prefixes = (
            prefix + "past_predictors",
            prefix + "past_encoder",
        )
        # Collect the keys first; a mapping cannot be resized while it is
        # being iterated.
        stale_keys = [
            key for key in full_state if key.startswith(frozen_prefixes)
        ]
        for key in stale_keys:
            del full_state[key]

        return full_state