import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple, Union


class EU:
    """
    Lagrangian-Guided Adaptive Erasure Modulation
    Fast version using Implicit Gradient Surgery.

    Balances a *retain* loss and a *forget* loss with a single learnable
    scalar ``w``.  For ``w >= 0`` the retain task is weighted by
    ``w / (1 + w)`` and the forget task by ``1 / (1 + w)``; for ``w < 0`` the
    retain task is switched off (weight 0) and the forget task keeps weight 1.
    ``w`` itself is trained by :meth:`update` from the observed change of the
    retain loss between two consecutive steps.
    """

    def __init__(
        self,
        device: torch.device,
        gamma: float = 0.01,   # the regularization coefficient
        w_lr: float = 0.3,   # the learning rate of the task lambda
        max_norm: float = 1.0, # the maximum gradient norm
        error: float = 0.003, # the error term
        log_loss: bool = False, # whether to log the loss (set to False for better numerical stability)
    ):
        """
        Args:
            device: device on which ``w`` and the loss statistics live.
            gamma: passed to Adam as ``weight_decay`` for ``w``.
            w_lr: Adam learning rate for ``w``.
            max_norm: gradient-norm bound applied by :meth:`backward`
                (values ``<= 0`` disable clipping).
            error: constant subtracted from the retain-loss improvement in
                :meth:`update`.
            log_loss: if True, the shifted losses are log-transformed.
        """
        self.device = device
        self.dtype = torch.float32

        # Per-task lower bounds subtracted from the losses: [retain, forget].
        self.min_losses = torch.zeros(2, device=device, dtype=self.dtype)
        # Learnable scalar controlling the retain/forget trade-off.
        self.w = torch.tensor(1, device=device, dtype=self.dtype, requires_grad=True)

        self.w_opt = torch.optim.Adam([self.w], lr=w_lr, weight_decay=gamma)
        self.max_norm = max_norm
        self.n_tasks = 2
        self.error = float(error)
        self.log_loss = log_loss
        # Retain loss seen by the latest get_weighted_loss() call; None until then.
        self.prev_ret_loss = None

    def set_min_losses(self, losses):
        """Store detached per-task loss lower bounds (order: retain, forget)."""
        self.min_losses = losses.detach().to(device=self.device, dtype=self.dtype)

    def _detached_task_weights(self):
        """Return ``(retain_weight, forget_weight)`` as plain Python floats.

        Mirrors the weighting used by :meth:`get_weighted_loss` but is cut off
        from autograd, so gradients built with these weights never carry a
        graph back to ``w``.
        """
        w = self.w.detach().item()
        if w < 0:
            # Same convention as get_weighted_loss: retain off, forget unscaled.
            return 0.0, 1.0
        return w / (1.0 + w), 1.0 / (1.0 + w)

    def get_weighted_loss(self, ret_loss, fgt_loss):
        """Combine the retain and forget losses into one scalar.

        Also records the detached retain loss in ``self.prev_ret_loss`` so a
        later :meth:`update` call can measure its change.

        Args:
            ret_loss: scalar tensor, retain-task loss.
            fgt_loss: scalar tensor, forget-task loss.

        Returns:
            Scalar float32 tensor on ``self.device``.  It stays differentiable
            with respect to ``w`` as well as the inputs.
        """
        losses = torch.stack([ret_loss, fgt_loss]).to(device=self.device, dtype=self.dtype)
        self.prev_ret_loss = ret_loss.detach().to(device=self.device, dtype=self.dtype)

        # Shift by the per-task minima; the epsilon keeps log() finite at zero.
        shifted = losses - self.min_losses + 1e-12
        if self.log_loss:
            shifted = shifted.log()

        w = self.w  # fp32
        if w.item() < 0:
            # Retain term is multiplied by zero; forget term passes through.
            return shifted[0] * 0 + shifted[1]
        return shifted[0] * (w / (1 + w)) + shifted[1] * (1 / (1 + w))

    def update(self, curr_ret_loss, curr_lr=None):
        """Take one Adam step on ``w`` from the change in retain loss.

        The gradient assigned to ``w`` is
        ``(prev_ret_loss - curr_ret_loss) [/ curr_lr] - error`` (computed on
        log-transformed shifted losses when ``log_loss`` is set).  Does
        nothing until :meth:`get_weighted_loss` has recorded a previous
        retain loss.

        Args:
            curr_ret_loss: current retain loss, as a number or a tensor.
            curr_lr: optional learning rate used to normalise the loss change.
        """
        if self.prev_ret_loss is None:
            # First update, skip
            return

        # Handle both float and tensor inputs
        if isinstance(curr_ret_loss, (int, float)):
            curr = torch.tensor(curr_ret_loss, device=self.device, dtype=self.dtype)
        else:
            curr = curr_ret_loss.detach().to(device=self.device, dtype=self.dtype)

        prev = self.prev_ret_loss
        if self.log_loss:
            floor = self.min_losses[0]
            prev = (prev - floor + 1e-12).log()
            curr = (curr - floor + 1e-12).log()

        delta = prev - curr
        if curr_lr is not None:
            delta = delta / (float(curr_lr) + 1e-12)
        delta = delta - self.error

        # squeeze() is a no-op on a 0-dim tensor and collapses shape (1,) inputs.
        grad_w = delta.squeeze().to(dtype=self.dtype)

        self.w_opt.zero_grad(set_to_none=True)
        # Seeding backward on the leaf ``w`` simply sets ``w.grad = grad_w``.
        self.w.backward(grad_w)
        self.w_opt.step()

    @staticmethod
    def _collect_grads(loss, parameters):
        """Backpropagate ``loss`` and return detached per-parameter gradient copies.

        Returns an empty list when ``loss`` does not require grad.  The
        ``.grad`` of every parameter is zeroed again afterwards; parameters
        without a gradient contribute zeros.
        """
        if not loss.requires_grad:
            return []
        # The graph is retained because the caller may backpropagate again
        # (second task, or the weighted-loss fallback).
        loss.backward(retain_graph=True)
        grads = []
        for p in parameters:
            if p.grad is None:
                grads.append(torch.zeros_like(p))
            else:
                grads.append(p.grad.detach().clone())
                p.grad.zero_()
        return grads

    def apply_gradient_surgery(self, ret_loss, fgt_loss, shared_parameters):
        """Write a conflict-aware combined gradient into ``shared_parameters``.

        The retain and forget gradients are computed separately.  If their
        inner product is negative, the retain gradient is projected onto the
        plane orthogonal to the forget gradient (PCGrad-style).  The
        (possibly projected) retain gradient and the forget gradient are then
        combined with the current task weights and stored in ``param.grad``,
        replacing whatever was there.

        If either loss does not require grad, the method falls back to
        backpropagating :meth:`get_weighted_loss`.

        NOTE(review): only the fallback path records ``prev_ret_loss``; a
        caller relying on :meth:`update` presumably also calls
        :meth:`get_weighted_loss` or :meth:`backward` - confirm.

        Args:
            ret_loss: scalar retain loss.
            fgt_loss: scalar forget loss.
            shared_parameters: sized sequence of parameters to receive gradients.
        """
        if shared_parameters is None or len(shared_parameters) == 0:
            return

        # Start from clean gradients so each task's contribution is isolated.
        for param in shared_parameters:
            if param.grad is not None:
                param.grad.zero_()

        retain_grads = self._collect_grads(ret_loss, shared_parameters)
        forget_grads = self._collect_grads(fgt_loss, shared_parameters)

        if len(retain_grads) == 0 or len(forget_grads) == 0:
            # If no gradients, use weighted loss
            weighted_loss = self.get_weighted_loss(ret_loss, fgt_loss)
            weighted_loss.backward()
            return

        retain_flat = torch.cat([g.reshape(-1) for g in retain_grads]).to(torch.float32)
        forget_flat = torch.cat([g.reshape(-1) for g in forget_grads]).to(torch.float32)

        inner = torch.dot(retain_flat, forget_flat)
        if inner < 0:  # Gradients conflict
            forget_norm_sq = torch.dot(forget_flat, forget_flat)
            # A (near-)zero forget gradient gives no usable direction to
            # project against; in that case keep the raw retain gradient
            # instead of dropping the whole update.
            if forget_norm_sq > 1e-8:
                # retain_proj = retain - (retain . forget) / ||forget||^2 * forget
                projected = retain_flat - (inner / forget_norm_sq) * forget_flat
                retain_grads = []
                offset = 0
                for param in shared_parameters:
                    count = param.numel()
                    retain_grads.append(projected[offset:offset + count].view(param.shape))
                    offset += count

        # Detached weights: param.grad must not hold an autograd graph to ``w``.
        w_retain, w_forget = self._detached_task_weights()
        for param, g_ret, g_fgt in zip(shared_parameters, retain_grads, forget_grads):
            param.grad = (w_retain * g_ret + w_forget * g_fgt).to(dtype=param.dtype)

    def backward(
        self,
        losses: torch.Tensor,
        shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor, None
        ] = None,
        not_backward: bool = False,
    ) -> Union[torch.Tensor, None]:
        """Build the weighted loss and, unless told otherwise, backpropagate it.

        Args:
            losses: indexable pair ``(retain_loss, forget_loss)``.
            shared_parameters: parameters whose gradient norm is clipped to
                ``max_norm`` after the backward pass.
            not_backward: if True, only return the weighted loss; no backward
                pass and no clipping happen here, so the caller is
                responsible for both.

        Returns:
            The weighted scalar loss.
        """
        loss = self.get_weighted_loss(losses[0], losses[1])
        if not_backward:
            return loss
        loss.backward()
        # Clipping must follow backward(): beforehand the gradients are stale
        # or absent and the freshly computed ones would go unbounded.
        if self.max_norm > 0 and shared_parameters is not None:
            torch.nn.utils.clip_grad_norm_(shared_parameters, self.max_norm)
        return loss

