from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import *
from loss import DynamicWeightingPDLoss
from .MTL import MTL
import copy

class DynamicDistMTL(MTL):
    """Multi-task model trained with an extra feature-preservation term.

    A deep copy of the encoder is stored as ``past_encoder`` when the model is
    built. Each training step minimises::

        lam_pred * sum_t loss_func(outputs[t], labels[t])
        + lam_pres * pres_loss(z_old, z_new, concatenated_labels)

    Here ``z_new`` comes from the current encoder and ``z_old`` from
    ``past_encoder``.
    """

    def __init__(self,
                 encoder: nn.Module,
                 tasks_name_to_cls_num: dict,
                 lr: float = 0.001,
                 device='cuda',
                 lam_pres=0.01,
                 lam_pred=1,
                 **kwargs) -> None:
        """Build the model and snapshot the encoder.

        Args:
            encoder: shared feature extractor. It is also deep-copied into
                ``self.past_encoder``.
            tasks_name_to_cls_num: forwarded to ``MTL.__init__``. Presumably
                maps task name -> number of classes (see MTL).
            lr: learning rate forwarded to ``MTL.__init__``.
            device: device forwarded to ``MTL.__init__``.
            lam_pres: weight of the preservation loss term.
            lam_pred: weight of the summed per-task prediction losses.
            **kwargs: forwarded unchanged to ``MTL.__init__``.
        """
        super().__init__(encoder,
                         tasks_name_to_cls_num,
                         lr,
                         device,
                         **kwargs)
        # The copy is taken after MTL.__init__ has run, so it reflects any
        # setup the base class applied to ``encoder`` in place.
        self.past_encoder = copy.deepcopy(encoder)

        self.pres_loss = DynamicWeightingPDLoss()
        self.lam_pred = lam_pred
        self.lam_pres = lam_pres

    def forward(self,
                x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """Encode ``x`` and run every task head on the shared features.

        Returns:
            ``(z, ys)``. ``z`` is the encoder output. ``ys`` maps each task
            name in ``self.predictors`` to that predictor's output on ``z``.
        """
        z = self.encoder(x)
        ys = {task_name: predictor(z)
              for task_name, predictor in self.predictors.items()}
        return z, ys

    def compute_loss(self,
                     inputs: torch.Tensor,
                     labels: dict,
                     loss_func: nn.Module) -> torch.Tensor:
        """Run one optimisation step and return the combined loss.

        Args:
            inputs: batch fed to both the current and the past encoder.
            labels: task name -> label tensor. Only tasks present here
                contribute to the loss.
            loss_func: per-task criterion, called as
                ``loss_func(outputs[task], labels[task])``.

        Returns:
            The weighted total loss tensor after ``backward()`` and
            ``optimizer.step()`` have been applied. It is still attached to
            its autograd graph.
        """
        self.optimizer.zero_grad()
        # Re-applied every step, presumably because a ``.train()`` call on
        # the parent model would switch this submodule back to training mode.
        self.past_encoder.eval()

        z_new, outputs = self.forward(inputs)
        # The past encoder is a fixed reference. Computing its features
        # without autograd means no gradient ever reaches (or is accumulated
        # on) its parameters, and no graph is built through it.
        with torch.no_grad():
            z_old = self.past_encoder(inputs)

        pred_loss = 0
        for task_name in labels:
            pred_loss += loss_func(outputs[task_name], labels[task_name])

        # To translate the preservation loss to the MTL setting, stack every
        # task's labels as columns of a single long tensor (assuming each
        # label tensor is 1-D of shape (N,), the result is (N, num_tasks)).
        # The intent is that "intra-class" means sharing the same class on
        # every task; how the matrix is interpreted is up to
        # DynamicWeightingPDLoss.
        concatenated_label = torch.cat(
            [labels[task_name].unsqueeze(-1).type(torch.long)
             for task_name in labels],
            dim=1,
        )

        pres_loss = self.pres_loss(z_old, z_new, concatenated_label)
        tot_loss = self.lam_pred * pred_loss + self.lam_pres * pres_loss

        tot_loss.backward()
        self.optimizer.step()
        return tot_loss
