from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import *


class MTL(MultitaskLearning):
    """Multi-task model: one shared encoder feeding one linear head per task.

    The encoder output is passed to a separate ``nn.Linear`` head for every
    task, and a single Adam optimizer updates the encoder together with all
    heads.
    """

    def __init__(self,
                 encoder: nn.Module,
                 tasks_name_to_cls_num: dict,
                 lr: float = 0.001,
                 z_dim: int = 512,
                 device='cuda',
                 **kwargs) -> None:
        """Build the per-task heads and the joint optimizer.

        Args:
            encoder: Shared feature extractor; its output is fed to every
                head, so its last dimension must equal ``z_dim``.
            tasks_name_to_cls_num: Mapping from task name to the number of
                classes (output units) of that task's head.
            lr: Learning rate, forwarded to the base class and used for the
                Adam optimizer.
            z_dim: Input feature size of every linear head.
            device: Device the heads are created on.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__(encoder,
                         tasks_name_to_cls_num,
                         lr,
                         cls_output_dim=2)
        # NOTE(review): the heads live in a plain dict, so if the base class
        # is an nn.Module they are NOT registered as submodules (they would be
        # missed by .to(), .state_dict() and .parameters()). That is why they
        # are moved to `device` and handed to the optimizer explicitly below.
        # Confirm whether this is intended before switching to nn.ModuleDict.
        self.predictors = {
            name: nn.Linear(z_dim, num_class).to(device)
            for name, num_class in tasks_name_to_cls_num.items()
        }
        # Live view over the keys of the mapping passed in by the caller.
        self.tasks_name = tasks_name_to_cls_num.keys()

        # self.encoder is expected to have been set by the base constructor.
        trained_parameters = list(self.encoder.parameters())
        for predictor in self.predictors.values():
            trained_parameters.extend(predictor.parameters())

        self.device = device
        # Use the caller-supplied learning rate (it was previously ignored in
        # favour of a hard-coded 0.001).
        self.optimizer = torch.optim.Adam(trained_parameters, lr=lr)

    def forward(self,
                x: torch.Tensor) -> dict:
        """Encode ``x`` once and apply every task head to the encoding.

        Returns:
            Dict mapping each task name to that head's output tensor.
        """
        z = self.encoder(x)
        return {task_name: predictor(z)
                for task_name, predictor in self.predictors.items()}

    def compute_loss(self,
                     inputs: torch.Tensor,
                     labels: dict,
                     loss_func: nn.Module) -> torch.Tensor:
        """Perform one optimizer step on the summed per-task loss.

        Despite its name this method also updates the parameters: it zeroes
        the gradients, sums ``loss_func`` over every task in ``labels``,
        back-propagates and steps the optimizer.

        Args:
            inputs: Batch passed to ``forward``.
            labels: Mapping from task name to that task's targets. Only the
                tasks present here contribute to the loss.
            loss_func: Callable ``loss_func(output, target)`` returning a
                scalar tensor.

        Returns:
            The summed loss tensor that ``backward`` was called on.

        Raises:
            ValueError: If ``labels`` is empty (there is nothing to optimize).
            KeyError: If ``labels`` names a task that has no head.
        """
        if not labels:
            # Without this guard the running total would stay the int 0 and
            # fail later with an opaque "'int' has no attribute 'backward'".
            raise ValueError("labels must contain at least one task")

        self.optimizer.zero_grad()
        tot_loss = 0
        outputs = self.forward(inputs)
        for task_name in labels:
            tot_loss += loss_func(outputs[task_name], labels[task_name])
        tot_loss.backward()
        self.optimizer.step()
        return tot_loss

    def compute_loss_nograd(self,
                            inputs: torch.Tensor,
                            labels: dict,
                            loss_func: nn.Module) -> torch.Tensor:
        """Compute the summed per-task loss without tracking gradients.

        No parameters are updated. Arguments are the same as for
        ``compute_loss``.

        Returns:
            The summed loss tensor, or the int ``0`` if ``labels`` is empty.
        """
        tot_loss = 0
        with torch.no_grad():
            outputs = self.forward(inputs)
            for task_name in labels:
                tot_loss += loss_func(outputs[task_name], labels[task_name])
        return tot_loss

    def calculate_accuraciess(self,
                              valid_loader: torch.utils.data.DataLoader,
                              tasks_name: Tuple[str],
                              device: torch.device) -> dict:
        """Delegate to the base-class ``calculate_accuracies`` with this model's heads.

        The method name (with the doubled trailing "s") is kept as-is for
        compatibility with existing callers.

        Args:
            valid_loader: Loader forwarded unchanged to the base class.
            tasks_name: Task names forwarded unchanged; they are not checked
                against ``self.tasks_name``.
            device: Device forwarded unchanged to the base class.

        Returns:
            Whatever the base-class ``calculate_accuracies`` returns.
        """
        return super().calculate_accuracies(self.predictors,
                                            valid_loader,
                                            tasks_name,
                                            device)
