import random
import wandb
from typing import List
from copy import deepcopy
import torch
import torch.nn.functional as F

from pytorch_optimizer import PCGrad
from pytorch_optimizer.base.types import OPTIMIZER


class PCGradOptimizer(PCGrad):
    """PCGrad subclass that records gradient-conflict diagnostics.

    Every call to :meth:`project_conflicting` stores a dict of scalar metrics
    in ``self.grad_info``. The attribute does not exist until the first call.
    The pairwise metrics are keyed ``il``/``rl`` and describe ``grads[0]`` and
    ``grads[1]`` respectively.
    NOTE(review): ``il``/``rl`` presumably stand for imitation-learning and
    reinforcement-learning objectives -- confirm against the caller.
    """

    def project_conflicting(self, grads: List[torch.Tensor], has_grads: List[torch.Tensor]) -> torch.Tensor:
        """Project away pairwise conflicts between gradients and merge them.

        For each gradient ``g_i`` the other gradients are visited in a random
        order; whenever the dot product with ``g_j`` is negative, the
        component of ``g_i`` along ``g_j`` is subtracted. The projected
        gradients are then combined according to ``self.reduction`` on
        entries every objective contributes to, and summed elsewhere.

        :param grads: one 1-D gradient vector per objective (``torch.dot``
            requires 1-D inputs). The input tensors are not modified.
        :param has_grads: one mask per objective, same shape as the gradients;
            an entry counts as shared only if it is non-zero in every mask.
        :return: the merged gradient, same shape as ``grads[0]``.

        Side effect: ``self.grad_info`` is replaced with fresh metrics.
        """
        shared: torch.Tensor = torch.stack(has_grads).prod(0).bool()

        # Project on private copies so the caller's tensors stay untouched and
        # every projection uses the *original* g_j as its reference direction.
        pc_grad = [g.clone() for g in grads]
        for g_i in pc_grad:
            random_order = list(range(len(grads)))
            random.shuffle(random_order)
            for j in random_order:
                g_j = grads[j]
                g_i_g_j = torch.dot(g_i, g_j)
                if g_i_g_j < 0:
                    # The epsilon guards against division by zero when g_j is all zeros.
                    g_i.sub_((g_i_g_j / (g_j.norm().square() + 1e-8)) * g_j)

        merged_grad = torch.zeros_like(grads[0])
        shared_pc_gradients = torch.stack([g[shared] for g in pc_grad])

        if self.reduction == 'mean':
            merged_grad[shared] = shared_pc_gradients.mean(dim=0)
        else:
            merged_grad[shared] = shared_pc_gradients.sum(dim=0)
        # Non-shared entries are always summed, regardless of the reduction.
        merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)

        self.grad_info = self._collect_grad_info(grads, merged_grad, shared)

        return merged_grad

    @staticmethod
    def _collect_grad_info(grads: List[torch.Tensor], merged_grad: torch.Tensor, shared: torch.Tensor) -> dict:
        """Build the scalar diagnostics dict for one projection step.

        With fewer than two gradients there is no pair to compare, so only the
        metrics of the merged gradient are returned (instead of failing with
        an ``IndexError`` on ``grads[1]``).
        """
        final_norm = merged_grad.norm().item()
        shared_ratio = shared.float().mean().item()

        if len(grads) < 2:
            return {
                'pcgrad/final_grad_norm': final_norm,
                'pcgrad/shared_param_ratio': shared_ratio,
            }

        il_grads, rl_grads = grads[0].detach(), grads[1].detach()
        il_norm, rl_norm = il_grads.norm(), rl_grads.norm()

        cosine_sim = F.cosine_similarity(il_grads.unsqueeze(0), rl_grads.unsqueeze(0))[0].item()
        grad_dot_product = torch.dot(il_grads, rl_grads)
        # Conflict score is the negated cosine, reported only when the
        # gradients actually point in opposing directions; otherwise 0.
        if grad_dot_product < 0:
            conflict_score = (-grad_dot_product / (il_norm * rl_norm + 1e-8)).item()
        else:
            conflict_score = 0.0

        return {
            'pcgrad/il_grad_norm': il_norm.item(),
            'pcgrad/rl_grad_norm': rl_norm.item(),
            'pcgrad/final_grad_norm': final_norm,
            'pcgrad/grad_cosine_similarity': cosine_sim,
            'pcgrad/conflict_score': conflict_score,
            # Distance between the merged gradient and each raw objective gradient.
            'pcgrad/il_modification': torch.norm(merged_grad - il_grads).item(),
            'pcgrad/rl_modification': torch.norm(merged_grad - rl_grads).item(),
            'pcgrad/shared_param_ratio': shared_ratio,
        }
