import torch
from vlm_eval.attacks.utils import project_perturbation, normalize_grad
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
@torch.no_grad()
def pgd(
        model,
        loss_fn,
        data_clean,
        targets,
        norm,
        eps,
        iterations,
        stepsize,
        output_normalize,
        perturbation=None,
        mode='min',
        momentum=0.9,
        verbose=False
):
    """
    Minimize or maximize ``loss_fn`` with momentum projected gradient steps.

    For ``iterations`` steps the perturbation is moved along the normalized,
    momentum-averaged gradient of ``loss_fn(model(...), targets)``. After each
    step it is projected with ``project_perturbation(perturbation, eps, norm)``
    and clipped so that ``data_clean + perturbation`` stays inside ``[0, 1]``.

    Args:
        model: Callable invoked as
            ``model(pixel_values=..., output_normalize=output_normalize)``.
        loss_fn: Callable ``loss_fn(out, targets)`` returning a scalar loss.
        data_clean: Clean input batch; all values must lie in ``[0, 1]``.
        targets: Forwarded unchanged to ``loss_fn``.
        norm: Norm identifier forwarded to ``normalize_grad`` and
            ``project_perturbation``; accepted values are defined by those
            helpers.
        eps: Radius forwarded to ``project_perturbation``.
        iterations: Number of update steps.
        stepsize: Multiplier applied to the normalized velocity in each step.
        output_normalize: Forwarded unchanged to ``model``.
        perturbation: Optional starting perturbation. It is only read; the
            caller's tensor (including its ``requires_grad`` flag) is left
            untouched. Defaults to all zeros.
        mode: ``'min'`` to descend on the loss, ``'max'`` to ascend.
        momentum: Decay factor applied to the running velocity.
        verbose: If true, print the loss at every iteration.

    Returns:
        ``data_clean + perturbation`` after the last step, detached from the
        autograd graph.

    Raises:
        ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
        AssertionError: If the input or an intermediate adversarial example
            leaves ``[0, 1]``, or the perturbation contains NaN.

    NOTE(review): ``project_perturbation`` is looked up at call time; this
    module also defines a function of that name after the import, so that
    local definition is the one used here.
    """
    # make sure data is in image space
    assert torch.max(data_clean) < 1. + 1e-6 and torch.min(data_clean) > -1e-6
    # Fail before spending a forward/backward pass on an invalid request.
    if mode not in ('min', 'max'):
        raise ValueError(f'Unknown mode: {mode}')

    if perturbation is None:
        perturbation = torch.zeros_like(data_clean)
    else:
        # Work on a detached alias: toggling requires_grad below must not
        # change the flag on the tensor owned by the caller.
        perturbation = perturbation.detach()
    velocity = torch.zeros_like(data_clean)

    for i in range(iterations):
        perturbation.requires_grad_(True)
        # The decorator disables autograd for the whole function, so it is
        # re-enabled only for the forward pass that needs a graph.
        with torch.enable_grad():
            ptb_data = data_clean + perturbation
            out = model(pixel_values=ptb_data, output_normalize=output_normalize)
            loss = loss_fn(out, targets)
            if verbose:
                print(f'[{i}] {loss.item():.5f}')

        gradient = torch.autograd.grad(loss, perturbation)[0]
        if gradient.isnan().any():
            print(f'attention: nan in gradient ({gradient.isnan().sum()})')
            # Zero out NaN entries so one bad element does not poison the step.
            gradient[gradient.isnan()] = 0.
        # normalize
        gradient = normalize_grad(gradient, p=norm)
        # momentum
        velocity = momentum * velocity + gradient
        velocity = normalize_grad(velocity, p=norm)
        # update
        if mode == 'min':
            perturbation = perturbation - stepsize * velocity
        else:
            perturbation = perturbation + stepsize * velocity
        # project
        perturbation = project_perturbation(perturbation, eps, norm)
        # clamp to image space
        perturbation = torch.clamp(data_clean + perturbation, 0, 1) - data_clean
        assert not perturbation.isnan().any()
        assert torch.max(data_clean + perturbation) < 1. + 1e-6 and torch.min(
            data_clean + perturbation
        ) > -1e-6

    # TODO: return the best perturbation seen instead of the last one; this
    # needs a per-sample (unreduced) loss, which is not available here.
    return data_clean + perturbation.detach()



def mim_attack(
        model,
        loss_fn,
        data_clean,
        targets,
        eps,
        iterations,
        stepsize,
        output_normalize,
        norm='Linf',  # any other value is delegated to project_perturbation
        perturbation=None,
        momentum=0.9,   # decay factor mu of the accumulated gradient
        verbose=False
):
    """
    MIM (Momentum Iterative Method) attack that maximizes ``loss_fn``.

    Each iteration divides the gradient by its per-sample L1 norm, adds it to
    a decayed running velocity, and steps by ``stepsize * sign(velocity)``.
    The perturbation is then projected back into the ``eps``-ball and clipped
    so that ``data_clean + perturbation`` stays inside ``[0, 1]``.

    Args:
        model: Callable invoked as
            ``model(pixel_values=..., output_normalize=output_normalize)``.
        loss_fn: Callable ``loss_fn(out, targets)`` returning a scalar loss.
        data_clean: Clean input batch with the batch dimension first and
            values in ``[0, 1]``. Any number of trailing dimensions is
            supported.
        targets: Forwarded unchanged to ``loss_fn``.
        eps: Radius of the perturbation ball.
        iterations: Number of update steps.
        stepsize: Step length applied to ``sign(velocity)``.
        output_normalize: Forwarded unchanged to ``model``.
        norm: ``'Linf'`` clamps the perturbation to ``[-eps, eps]``; any other
            value is passed to ``project_perturbation``. The step itself
            always uses ``sign(velocity)``.
        perturbation: Optional starting perturbation. It is only read; the
            caller's tensor (including its ``requires_grad`` flag) is left
            untouched. Defaults to all zeros.
        momentum: Decay factor of the running velocity.
        verbose: If true, print the loss every iteration and NaN warnings.

    Returns:
        ``data_clean + perturbation`` after the last step, detached from the
        autograd graph and with the same shape as ``data_clean``.

    Raises:
        AssertionError: If the input or an intermediate adversarial example
            leaves ``[0, 1]``, or the perturbation contains NaN.
    """
    # make sure data is in image space
    assert torch.max(data_clean) < 1. + 1e-6 and torch.min(data_clean) > -1e-6

    if perturbation is None:
        perturbation = torch.zeros_like(data_clean)
    else:
        # Work on a detached alias: toggling requires_grad below must not
        # change the flag on the tensor owned by the caller.
        perturbation = perturbation.detach()
    velocity = torch.zeros_like(data_clean)  # accumulated (momentum) gradient

    # Shape that broadcasts one scalar per sample over all trailing
    # dimensions, whatever the rank of the input is.
    per_sample_shape = (-1,) + (1,) * (data_clean.dim() - 1)

    for i in range(iterations):
        perturbation.requires_grad_(True)
        with torch.enable_grad():
            ptb_data = data_clean + perturbation
            out = model(pixel_values=ptb_data, output_normalize=output_normalize)
            loss = loss_fn(out, targets)
            if verbose:
                print(f'[MIM {i}] loss: {loss.item():.5f}')

        gradient = torch.autograd.grad(loss, perturbation)[0].detach()

        # The update is pure bookkeeping; keep it out of the autograd graph.
        with torch.no_grad():
            if gradient.isnan().any():
                if verbose:
                    print(f'Warning: NaN in gradient ({gradient.isnan().sum()})')
                gradient[gradient.isnan()] = 0.

            # Per-sample L1 normalization of the gradient. reshape (not view)
            # also accepts non-contiguous gradients.
            l1_norm = torch.norm(
                gradient.reshape(gradient.size(0), -1), p=1, dim=1
            ).view(per_sample_shape)
            l1_norm = torch.clamp(l1_norm, min=1e-12)  # avoid division by zero
            gradient_normalized = gradient / l1_norm

            # Momentum accumulation.
            velocity = momentum * velocity + gradient_normalized

            # Always ascend on the loss, hence '+'.
            perturbation = perturbation + stepsize * torch.sign(velocity)

            # Project back into the eps-ball.
            if norm == 'Linf':
                perturbation = torch.clamp(perturbation, -eps, eps)
            else:
                perturbation = project_perturbation(perturbation, eps, norm)

            # Keep the adversarial example a valid image in [0, 1].
            perturbation = torch.clamp(data_clean + perturbation, 0, 1) - data_clean

            assert not perturbation.isnan().any()
            assert torch.max(data_clean + perturbation) < 1. + 1e-6
            assert torch.min(data_clean + perturbation) > -1e-6

    return data_clean + perturbation.detach()


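# Local projection helper for norms other than Linf.
# NOTE: this definition shadows the `project_perturbation` imported from
# vlm_eval.attacks.utils at the top of the file, so every attack in this
# module calls this version.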
def project_perturbation(perturbation, eps, norm):
    """
    Project each perturbation in a batch onto the ``eps``-ball of ``norm``.

    Args:
        perturbation: Tensor with the batch dimension first; any number of
            trailing dimensions is supported.
        eps: Radius of the ball.
        norm: Norm identifier. ``'L2'``, ``'l2'``, ``'2'``, ``2`` and ``2.0``
            select the L2 ball; ``'Linf'``, ``'linf'``, ``'inf'`` and
            ``float('inf')`` select the Linf ball; ``'L1'``, ``'l1'``,
            ``'1'`` and ``1`` name the (unimplemented) L1 ball.

    Returns:
        A tensor with the same shape as ``perturbation``. For L2, samples
        whose norm exceeds ``eps`` are rescaled to norm ``eps`` and the rest
        are returned unchanged; for Linf, every element is clamped to
        ``[-eps, eps]``.

    Raises:
        NotImplementedError: For the L1 norm.
        ValueError: For an unrecognized ``norm``. Falling back to another
            norm silently would evaluate a different threat model than the
            caller asked for.
    """
    if norm in ('L2', 'l2', '2', 2, 2.0):
        # reshape (not view) also accepts non-contiguous input.
        norm_val = torch.norm(
            perturbation.reshape(perturbation.size(0), -1), p=2, dim=1
        )
        # Shrink only the samples that lie outside the ball. The division
        # result for in-ball samples (possibly inf/nan at norm 0) is discarded.
        scale = torch.where(
            norm_val > eps, eps / norm_val, torch.ones_like(norm_val)
        )
        # One factor per sample, broadcast over all trailing dimensions.
        scale = scale.view((-1,) + (1,) * (perturbation.dim() - 1))
        return perturbation * scale
    if norm in ('L1', 'l1', '1', 1):
        # An exact L1 projection needs a dedicated (sorting-based) algorithm.
        raise NotImplementedError("L1 projection is more complex")
    if norm in ('Linf', 'linf', 'inf', float('inf')):
        return torch.clamp(perturbation, -eps, eps)
    raise ValueError(f'Unknown norm: {norm}')

def _cw_class_scores(model, ptb_data, targets, output_normalize):
    """Return the true-class logit and the largest other-class logit per sample."""
    out = model(pixel_values=ptb_data, output_normalize=output_normalize)
    logits = out.logits if hasattr(out, 'logits') else out
    rows = torch.arange(len(targets))
    real = logits[rows, targets]
    # Exclude the true class with -inf instead of multiplying by a 0/1 mask:
    # a zeroed entry would win the max whenever all other logits are negative.
    is_true_class = torch.zeros_like(logits, dtype=torch.bool)
    is_true_class[rows, targets] = True
    other = logits.masked_fill(is_true_class, float('-inf')).max(dim=1)[0]
    return real, other


@torch.no_grad()
def cw_attack(model, data_clean, targets, confidence=0, 
             learning_rate=0.01, iterations=100, output_normalize=False, 
             verbose=False):
    """
    Carlini-Wagner style untargeted attack using normalized gradient steps.

    Each iteration minimizes
    ``clamp(real - other + confidence, min=0) + 0.001 * ||perturbation||_2``
    averaged over the batch, where ``real`` is the logit of the target class
    and ``other`` the largest logit among the remaining classes. The
    perturbation is not projected onto a norm ball; its size is controlled
    only by the L2 penalty and by clipping the example to ``[0, 1]``.

    Args:
        model: Callable invoked as
            ``model(pixel_values=..., output_normalize=output_normalize)``.
            Its output, or its ``.logits`` attribute when present, is indexed
            as ``[sample, class]``.
        data_clean: Clean input batch; all values must lie in ``[0, 1]``.
        targets: True class index of each sample.
        confidence: Margin added to ``real - other`` before clamping.
        learning_rate: Step length applied to the normalized gradient.
        iterations: Number of update steps.
        output_normalize: Forwarded unchanged to ``model``.
        verbose: If true, print loss and success rate every iteration.

    Returns:
        For every sample, the evaluated adversarial example with the smallest
        L2 perturbation among those for which ``other > real``; the clean
        input for samples that were never misclassified. The result is
        detached from the autograd graph.

    Raises:
        AssertionError: If the input or an intermediate adversarial example
            leaves ``[0, 1]``, or the perturbation contains NaN.
    """
    # Make sure data is in image space
    assert torch.max(data_clean) < 1. + 1e-6 and torch.min(data_clean) > -1e-6

    perturbation = torch.zeros_like(data_clean)
    best_adv = data_clean.clone()
    # Same dtype as the input so masked assignment of the norms below works
    # for non-float32 inputs as well.
    best_loss = torch.full(
        (len(data_clean),), float('inf'),
        dtype=data_clean.dtype, device=data_clean.device,
    )

    for i in range(iterations):
        perturbation.requires_grad_(True)

        # The decorator disables autograd for the whole function, so it is
        # re-enabled only for the forward pass that needs a graph.
        with torch.enable_grad():
            ptb_data = data_clean + perturbation
            real, other = _cw_class_scores(model, ptb_data, targets, output_normalize)

            # CW margin loss: positive while the true class still wins.
            cw_loss = torch.clamp(real - other + confidence, min=0)

            # Total loss with L2 regularization
            l2_norm = torch.norm(
                perturbation.reshape(len(perturbation), -1), p=2, dim=1
            )
            total_loss = cw_loss + 0.001 * l2_norm
            loss = total_loss.mean()

            if verbose:
                success_rate = (other > real).float().mean()
                print(f'[{i}] Loss: {loss.item():.5f}, Success: {success_rate:.3f}')

        gradient = torch.autograd.grad(loss, perturbation)[0]

        if gradient.isnan().any():
            print(f'Warning: NaN in gradient ({gradient.isnan().sum()})')
            gradient[gradient.isnan()] = 0.

        # Track the best example. Success and norm were measured on ptb_data,
        # so that is the example stored, not the one produced by the update
        # below, which has not been evaluated yet.
        improved_mask = (other > real) & (l2_norm < best_loss)
        if improved_mask.any():
            best_adv[improved_mask] = ptb_data.detach()[improved_mask]
            best_loss[improved_mask] = l2_norm.detach()[improved_mask]

        # Normalize gradient
        gradient = normalize_grad(gradient, p=2)

        # Update perturbation (minimize the CW loss)
        perturbation = perturbation - learning_rate * gradient

        # Clamp to image space [0, 1]
        adv_data = torch.clamp(data_clean + perturbation, 0, 1)
        perturbation = adv_data - data_clean

        assert not perturbation.isnan().any()
        assert torch.max(adv_data) < 1. + 1e-6 and torch.min(adv_data) > -1e-6

    # The example produced by the last update has not been scored yet;
    # evaluate it once so that it can also become the best candidate.
    adv_data = data_clean + perturbation
    real, other = _cw_class_scores(model, adv_data, targets, output_normalize)
    l2_norm = torch.norm(perturbation.reshape(len(perturbation), -1), p=2, dim=1)
    improved_mask = (other > real) & (l2_norm < best_loss)
    if improved_mask.any():
        best_adv[improved_mask] = adv_data[improved_mask]
        best_loss[improved_mask] = l2_norm[improved_mask]

    return best_adv.detach()

def fgsm(model, loss_fn, data_clean, targets, eps, output_normalize=False, verbose=False):
    """
    Single-step Fast Gradient Sign Method (FGSM) attack.

    Takes the gradient of ``loss_fn(model(...), targets)`` with respect to the
    input at ``data_clean``, moves every element by ``eps`` in the direction
    of the gradient's sign, and clips the result to ``[0, 1]``.

    Args:
        model: Callable invoked as
            ``model(pixel_values=..., output_normalize=output_normalize)``.
        loss_fn: Callable ``loss_fn(out, targets)`` returning a scalar loss.
        data_clean: Clean input batch; all values must lie in ``[0, 1]``.
        targets: Forwarded unchanged to ``loss_fn``.
        eps: Step size applied to the gradient sign.
        output_normalize: Forwarded unchanged to ``model``.
        verbose: If true, print the loss of the clean input.

    Returns:
        The adversarial batch, detached from the autograd graph.

    Raises:
        AssertionError: If the input or the adversarial example leaves
            ``[0, 1]``, or the adversarial example contains NaN.
    """
    # Inputs must already be in image space.
    assert torch.max(data_clean) < 1. + 1e-6 and torch.min(data_clean) > -1e-6

    # A zero offset serves as the leaf the gradient is taken with respect to.
    delta = torch.zeros_like(data_clean).requires_grad_(True)

    with torch.enable_grad():
        loss = loss_fn(
            model(pixel_values=data_clean + delta, output_normalize=output_normalize),
            targets,
        )
        if verbose:
            print(f'FGSM Loss: {loss.item():.5f}')

    with torch.no_grad():
        (gradient,) = torch.autograd.grad(loss, delta)

        nan_entries = gradient.isnan()
        if nan_entries.any():
            print(f'Warning: NaN in gradient ({gradient.isnan().sum()})')
            gradient[nan_entries] = 0.

        # One signed step of size eps, then back into the valid image range.
        adv_data = torch.clamp(data_clean + eps * gradient.sign(), 0, 1)

        assert not adv_data.isnan().any()
        assert torch.max(adv_data) < 1. + 1e-6 and torch.min(adv_data) > -1e-6

    return adv_data.detach()