import torch
import torch.nn as nn
import torch.optim as optim
from .base import *
from typing import Tuple
from .MTL import MTL

# https://proceedings.neurips.cc/paper/2020/hash/3fe78a8acf5fda99de95303940a2420c-Abstract.html
class PCGrad:
    """Optimizer wrapper that applies PCGrad gradient surgery to shared parameters.

    For every task loss the gradient w.r.t. the shared parameters is computed
    separately.  Whenever two task gradients conflict (negative inner product)
    the conflicting component is projected away, and the sum of the projected
    gradients is written to ``param.grad`` of the shared parameters.

    Parameters handled by the wrapped optimizer that are *not* in
    ``shared_parameters`` (e.g. task-specific heads) receive the plain sum of
    their per-task gradients.
    """

    def __init__(self, optimizer, shared_parameters):
        """
        optimizer: the optimizer being wrapped
        shared_parameters: iterable of parameters shared by all tasks; it is
            materialised into a list so a generator can be passed safely
        """
        self.optimizer = optimizer
        self.shared_parameters = list(shared_parameters)

    def compute_gradients(self, losses):
        """
        Populate ``.grad`` of all optimized parameters from per-task losses.

        losses: a list of scalar loss tensors (one per task) that share a graph

        The graph is retained after every backward pass, so the losses can
        still be used by the caller afterwards.
        """
        assert isinstance(losses, list)
        num_tasks = len(losses)

        # Shuffle the losses so the projection order is random
        order = torch.randperm(num_tasks).tolist()
        losses = [losses[i] for i in order]

        # Parameters known to the optimizer that are not shared.  Their
        # gradients must be accumulated by hand because zero_grad() below
        # clears them before every task's backward pass.
        shared_ids = {id(param) for param in self.shared_parameters}
        task_specific = [
            param
            for group in getattr(self.optimizer, "param_groups", [])
            for param in group["params"]
            if id(param) not in shared_ids
        ]
        specific_grads = [None] * len(task_specific)

        # Tracks which shared parameters received a gradient from any task;
        # the others keep ``grad = None`` so the optimizer skips them.
        shared_has_grad = [False] * len(self.shared_parameters)

        # Compute per-task gradients, flattened over all shared parameters
        flat_task_grads = []
        for loss in losses:
            self.optimizer.zero_grad()
            loss.backward(retain_graph=True)

            pieces = []
            for idx, param in enumerate(self.shared_parameters):
                if param.grad is None:
                    # Keep the flat layout aligned with shared_parameters
                    pieces.append(torch.zeros_like(param).reshape(-1))
                else:
                    shared_has_grad[idx] = True
                    pieces.append(param.grad.detach().clone().reshape(-1))
            flat_task_grads.append(torch.cat(pieces))

            for idx, param in enumerate(task_specific):
                if param.grad is None:
                    continue
                grad = param.grad.detach().clone()
                if specific_grads[idx] is None:
                    specific_grads[idx] = grad
                else:
                    specific_grads[idx] = specific_grads[idx] + grad

        combined = self._project_and_sum(torch.stack(flat_task_grads))

        # Unpack the combined flat gradient back to the parameter shapes
        offset = 0
        for idx, param in enumerate(self.shared_parameters):
            size = param.numel()
            if shared_has_grad[idx]:
                # clone() so every .grad owns its storage instead of being a
                # view into one big buffer
                param.grad = combined[offset:offset + size].view(param.shape).clone()
            else:
                param.grad = None
            offset += size

        # Restore the summed gradients of the task-specific parameters
        for param, grad in zip(task_specific, specific_grads):
            param.grad = grad

    @staticmethod
    def _project_and_sum(task_grads):
        """Project conflicting task gradients and return their sum.

        task_grads: tensor of shape (num_tasks, num_shared_elements)
        """
        num_tasks = task_grads.shape[0]
        projected = []
        for i in range(num_tasks):
            grad = task_grads[i]
            for k in range(num_tasks):
                if k == i:
                    # A task is only projected against the *other* tasks
                    continue
                other = task_grads[k]
                coeff = torch.dot(grad, other) / (torch.dot(other, other) + 1e-10)
                # Only a negative coefficient (a conflict) triggers a projection
                grad = grad - torch.clamp(coeff, max=0.0) * other
            projected.append(grad)
        return torch.stack(projected).sum(dim=0)

    def step(self):
        """Apply gradients using the wrapped optimizer"""
        self.optimizer.step()

class PCGradModel(MTL):
    """Multi-task model whose encoder parameters are updated through PCGrad.

    ``self.encoder``, ``self.predictors`` and ``self.optimizer`` are read here
    but not assigned in this class; presumably the ``MTL`` base class creates
    them -- confirm against ``MTL``.
    """

    def __init__(self,
                 encoder: nn.Module,
                 tasks_name_to_cls_num: dict,
                 lr=0.001,
                 device='cuda',
                 z_dim=512,
                 **kwargs) -> None:
        """Forward all arguments to ``MTL`` and wrap its optimizer in ``PCGrad``.

        The encoder's parameters are registered as the shared parameters.
        """
        super().__init__(encoder,
                         tasks_name_to_cls_num,
                         lr,
                         device=device,
                         z_dim=z_dim,
                         **kwargs)
        shared = self.encoder.parameters()
        self.pcgrad = PCGrad(self.optimizer, shared)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` once and run every predictor on the encoding.

        Despite the annotation, the value returned is a dict mapping each
        task name to the output of that task's predictor.
        """
        features = self.encoder(x)
        predictions = {}
        for name, head in self.predictors.items():
            predictions[name] = head(features)
        return predictions

    def compute_loss(self, inputs: torch.Tensor, labels: dict, loss_func: nn.Module) -> torch.Tensor:
        """Compute one loss per labelled task, run a PCGrad update, return the loss sum.

        NOTE: this method does more than compute a loss -- it zeroes the
        gradients, lets PCGrad populate them and performs an optimizer step
        before returning.
        """
        self.optimizer.zero_grad()

        predictions = self.forward(inputs)
        # One loss per task, in the iteration order of ``labels``
        per_task = [loss_func(predictions[name], labels[name]) for name in labels]

        self.pcgrad.compute_gradients(per_task)
        self.pcgrad.step()

        return sum(per_task)
        
# Example usage:
# model = YourModel()
# base_optimizer = optim.SGD(model.parameters(), lr=0.01)
# optimizer = PCGrad(base_optimizer, shared_parameters)  # e.g. the parameters of the shared encoder

# losses = [loss_task1, loss_task2, loss_task3]  # One loss tensor per task; must be a list
# optimizer.compute_gradients(losses)
# optimizer.step()