import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

class SparseModification(Function):
    r"""
    Custom autograd function for PEFT_S.
    Implements forward and backward pass as described in the paper.

    The forward pass zeroes every vector of ``modification`` whose relative
    L2 norm (w.r.t. ``original_output``, taken over the last dimension) is
    below ``threshold``. The backward pass propagates the gradient only
    through the kept vectors and, as a side effect, hands a scalar gradient
    for the threshold to ``call_back``.
    """

    @staticmethod
    def forward(ctx, modification, original_output, att_mask, use_anchor, threshold, threshold_update_step, threshold_sparse_strength, call_back):
        r"""
        Mask the modification by its norm ratio against the original output.

        modification: M(x); original_output: W_0x; threshold: \tau ; threshold_update_step: s; threshold_sparse_strength: \lambda

        att_mask is only consumed in backward (it may be None, see backward).
        use_anchor is stored on ctx but is not otherwise used by this class.
        call_back is invoked in backward as
        ``call_back(threshold_update_step, grad_threshold)``.

        Returns ``anchor_mask * modification`` where ``anchor_mask`` is 1.0 at
        positions whose norm ratio is >= threshold and 0.0 elsewhere.

        NOTE(review): the reductions in backward (sum over dim 1, then mean
        over dims (0, 1)) imply tensors with at least three dimensions,
        presumably (batch, seq, hidden) -- confirm against the caller.
        """

        # Compute ||M(x)||_2 / ||W_0x||_2 (Equation 4 in the paper)
        norm = torch.norm(modification, p=2, dim=-1, keepdim=True)
        x_norm = torch.norm(original_output, p=2, dim=-1, keepdim=True)
        # The small constant guards against division by zero for all-zero rows.
        norm_ratio = norm / (1e-10 + x_norm)

        # Compute S (Equation 4 in the paper)
        # NOTE(review): .float() makes the mask float32, so a half-precision
        # `modification` is promoted to float32 by the product below --
        # confirm this is intended.
        anchor_mask = (norm_ratio >= threshold).float()

        # Compute masked modification (Equation 7 in the paper)
        modification_masked = anchor_mask * modification

        # Save for backward pass (save_for_backward accepts None entries)
        ctx.save_for_backward(modification, anchor_mask, att_mask)
        ctx.threshold_update_step = threshold_update_step
        ctx.threshold_sparse_strength = threshold_sparse_strength
        ctx.call_back = call_back
        ctx.use_anchor = use_anchor

        return modification_masked

    @staticmethod
    def backward(ctx, modification_masked_grad):
        r"""
        Back-propagate through the mask and report the threshold gradient.

        modification_masked_grad: gradient w.r.t. the forward output.

        Side effect: calls ``call_back(threshold_update_step, grad_threshold)``
        with the scalar threshold gradient computed here.

        Returns a tuple with one entry per forward input: the gradient for
        ``modification`` followed by None for the remaining seven inputs.
        """

        modification, anchor_mask, att_mask = ctx.saved_tensors

        threshold_update_step = ctx.threshold_update_step
        threshold_sparse_strength = ctx.threshold_sparse_strength
        call_back = ctx.call_back

        # Compute \mu (Equation 10 in the paper)
        final_modification_grad = modification_masked_grad * anchor_mask
        final_anchor_mask_grad = torch.sum(modification_masked_grad * modification, dim=-1, keepdim=True)

        # A missing attention mask is treated as "every position is valid"
        # (assumed intent) instead of failing on `None.float()`.
        if att_mask is None:
            att_mask = torch.ones_like(anchor_mask)
        else:
            att_mask = att_mask.float()

        # The computation of Equation 14 in the paper, excluding the multiplication by s, which is applied in the call_back step.
        # Keeps positive entries at kept positions and non-positive entries at
        # dropped positions; everything else contributes zero.
        grad_from_object = final_anchor_mask_grad
        grad_from_object = (grad_from_object > 0.) * (anchor_mask == 1.) * grad_from_object + (grad_from_object <= 0.) * (anchor_mask == 0.) * grad_from_object

        # The computation of Equation 15 in the paper, excluding the multiplication by s, which is applied in the call_back step.
        # Only kept, attended positions with a non-zero gradient contribute.
        grad_from_strength = (final_anchor_mask_grad != 0.) * att_mask * anchor_mask * threshold_sparse_strength

        # Compute final gradient for \tau (Equation 17)
        grad_threshold = torch.mean(torch.sum(grad_from_object + grad_from_strength, dim=1, keepdim=False), dim=(0, 1), keepdim=False)

        call_back(threshold_update_step, grad_threshold)

        # Only `modification` receives a gradient; the threshold is updated
        # through the callback rather than through autograd.
        return final_modification_grad, None, None, None, None, None, None, None



class ThresholdUpdater(torch.nn.Module):
    """
    Owns the threshold (tau) and adjusts it with an Adam-like rule.

    The threshold is not updated through autograd: `backward_callback` is
    passed to `SparseModification.apply` by `forward`, and applies the update
    directly to `self.threshold.data` whenever it is called.
    """

    def __init__(self, threshold_update_step, threshold_sparse_strength):
        """
        threshold_update_step: step size forwarded to SparseModification.
        threshold_sparse_strength: sparsity strength forwarded to SparseModification.
        """
        super().__init__()

        self.threshold_update_step = threshold_update_step
        self.threshold_sparse_strength = threshold_sparse_strength

        # \tau, a single-element parameter initialised to 0.1
        self.threshold = nn.Parameter(0.1 * torch.ones(1))

        # Decay rates and numerical-stability constant of the Adam-like rule.
        self.beta1, self.beta2 = 0.9, 0.98
        self.epsilon = 1.e-9

        # Running powers of the decay rates, used for bias correction.
        self.beta1_exp, self.beta2_exp = 0.9, 0.98

        # First and second moment estimates.
        self.m1, self.m2 = 0., 0.

    def backward_callback(self, threshold_update_step, grad_threshold):
        """
        Apply one Adam-like update to the threshold.

        Implements Equations 18 to 22 from the paper.

        threshold_update_step: step size multiplying the normalised update.
        grad_threshold: gradient value for the threshold.
        """
        decay1, decay2 = self.beta1, self.beta2

        # Exponential moving averages of the gradient and its square.
        self.m1 = decay1 * self.m1 + (1. - decay1) * grad_threshold
        self.m2 = decay2 * self.m2 + (1. - decay2) * (grad_threshold ** 2)

        # Bias-corrected moments, using the powers accumulated so far.
        first_corrected = self.m1 / (1. - self.beta1_exp)
        second_corrected = self.m2 / (1. - self.beta2_exp)

        # Advance the decay powers for the next call.
        self.beta1_exp *= decay1
        self.beta2_exp *= decay2

        delta = first_corrected / (torch.sqrt(second_corrected) + self.epsilon)

        # The normalised update is added to the threshold, scaled by the step.
        self.threshold.data = self.threshold.data + threshold_update_step * delta

    def forward(self, modification, original_output, att_mask, use_anchor):
        """
        Run SparseModification with this module's threshold and settings.

        Returns whatever SparseModification.apply returns for these inputs.
        """
        apply_args = (
            modification,
            original_output,
            att_mask,
            use_anchor,
            self.threshold,
            self.threshold_update_step,
            self.threshold_sparse_strength,
            self.backward_callback,
        )
        return SparseModification.apply(*apply_args)



