import torch
import torch.nn as nn

from typing import Optional, Tuple
from .base import *
from .MTL import MTL
# https://github.com/AvivNavon/nash-mtl/blob/main/methods/weight_methods.py


class IMTLG(WeightMethod):
    """IMTL-G loss weighting from "Towards Impartial Multi-Task Learning".

    Paper: https://openreview.net/pdf?id=IMPnRXEWpvr

    The task losses are combined with weights ``alpha`` (summing to one) that
    are obtained in closed form from the per-task gradients with respect to
    the shared parameters.
    """

    def __init__(self, n_tasks, device: torch.device):
        super().__init__(n_tasks, device=device)

    def get_weighted_loss(
        self,
        losses,
        shared_parameters,
        **kwargs,
    ):
        """Combine the task losses using the IMTL-G closed-form weights.

        Args:
            losses: 1-D tensor holding one scalar loss per task; each entry is
                differentiated separately and the tensor is multiplied
                element-wise by the weights.
            shared_parameters: parameters the per-task gradients are taken
                with respect to (passed as ``inputs`` to
                ``torch.autograd.grad``).
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            Tuple ``(loss, extra)`` where ``loss`` is ``sum(losses * alpha)``
            and ``extra`` is ``dict(weights=alpha)``.
        """
        grads = []
        unit_grads = []

        for task_loss in losses:
            # retain_graph=True keeps the graph alive for the remaining tasks
            # and for the backward pass on the combined loss.
            per_param = torch.autograd.grad(
                task_loss,
                shared_parameters,
                retain_graph=True,
            )
            flat = torch.cat([torch.flatten(g) for g in per_param])
            grads.append(flat)
            unit_grads.append(flat / torch.norm(flat))

        G = torch.stack(grads)
        # Differences between the first task and every other task, for the raw
        # gradients (D) and for the unit-norm gradients (U).
        D = G[0] - G[1:]
        U = torch.stack(unit_grads)
        U = U[0] - U[1:]

        first_element = torch.matmul(G[0], U.t())
        gram = torch.matmul(D, U.t())
        try:
            second_element = torch.inverse(gram)
        except RuntimeError:
            # torch.inverse raises a RuntimeError subclass when the matrix is
            # singular; retry with a small diagonal jitter. Anything else
            # (e.g. KeyboardInterrupt) is no longer swallowed.
            jitter = torch.eye(
                gram.shape[0], device=gram.device, dtype=gram.dtype
            ) * 1e-8
            second_element = torch.inverse(jitter + gram)

        # alpha_ = g_0 U^T (D U^T)^{-1}: weights of tasks 1..n-1.
        alpha_ = torch.matmul(first_element, second_element)
        # The first task takes the remainder so the weights sum to one. Built
        # with tensor ops instead of torch.tensor(<tensor>), which warns about
        # copy-constructing from a tensor on every call.
        alpha = torch.cat(((1 - alpha_.sum()).unsqueeze(-1), alpha_))

        loss = torch.sum(losses * alpha)

        return loss, dict(weights=alpha)

class IMTLModel(MTL):
    """Multi-task model (built on ``MTL``) trained with IMTL-G loss weighting."""

    def __init__(self,
                 encoder: nn.Module,
                 tasks_name_to_cls_num: dict,
                 lr=0.001,
                 z_dim=512,
                 device='cuda',
                 **kwargs) -> None:
        """Initialise the ``MTL`` base, then create the IMTL-G weighter.

        All arguments are forwarded to ``MTL.__init__``; the weighter is sized
        from ``self.tasks_name`` and placed on ``self.device``.
        """
        super().__init__(encoder,
                         tasks_name_to_cls_num,
                         lr,
                         device=device,
                         z_dim=z_dim,
                         **kwargs)

        n_tasks = len(self.tasks_name)
        self.imtl = IMTLG(n_tasks, device=self.device)

    def compute_loss(self,
                     inputs: torch.Tensor,
                     labels: dict,
                     loss_func: nn.Module) -> torch.Tensor:
        """Run one optimisation step and return the unweighted loss sum.

        Args:
            inputs: batch passed to ``self.forward``.
            labels: mapping from task name to that task's targets; the same
                keys are used to index the forward outputs.
            loss_func: callable applied as ``loss_func(output, target)`` for
                each task.

        Returns:
            The plain sum of the per-task losses (not the IMTL-weighted loss).
        """
        self.optimizer.zero_grad()

        outputs = self.forward(inputs)
        per_task = [
            loss_func(outputs[name], labels[name])
            for name in labels
        ]

        # The encoder parameters are the shared parameters whose gradients
        # drive the task weighting; the weighter's backward return value is
        # not needed here.
        shared = list(self.encoder.parameters())
        self.imtl.backward(torch.stack(per_task),
                           shared,
                           task_specific_parameters=None,
                           last_shared_parameters=None,
                           representation=None)
        self.optimizer.step()

        return sum(per_task)