import torch
import torch.nn.functional as F
from typing import Dict, Any, Optional
from PIL import Image

class AttackMethods:
    
    def __init__(self, pipe, device, num_inference_steps, verbose=True):
        """Keep a handle to the pipeline plus the per-run settings.

        Args:
            pipe: Pipeline object; this class uses its ``tokenizer``,
                ``encode_prompt`` and ``__call__``.
            device: Anything accepted by ``torch.device`` (e.g. ``"cpu"``).
            num_inference_steps: Forwarded to the pipeline call.
            verbose: Whether progress messages are printed.
        """
        self.pipe, self.device = pipe, device
        self.num_inference_steps = num_inference_steps
        self.verbose = verbose

        # Default strengths; _bypass_defense_attack uses their arithmetic mean.
        self.direction_attack_strength, self.magnitude_attack_strength = 8.0, 5.0
    
    def _compute_embedding(self, prompt: str):
        """Tokenize *prompt* and encode it with the wrapped pipeline.

        Args:
            prompt: Text to encode. It is passed to ``encode_prompt`` as
                ``prompt``, while ``prompt_2`` is always the empty string.

        Returns:
            A 4-tuple ``(text_input_ids, prompt_embeds, pooled_prompt_embeds,
            text_ids)``. The first item is the tokenizer's ``input_ids`` moved
            to ``self.device``. The remaining three are exactly what
            ``pipe.encode_prompt`` returns.
        """
        device = torch.device(self.device)
        max_sequence_length = 256

        # Raw token ids, padded/truncated to the same length used for encoding,
        # so they line up with the sequence embeddings returned below.
        text_inputs = self.pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)

        prompt_embeds, pooled_prompt_embeds, text_ids = self.pipe.encode_prompt(
            prompt=prompt,
            prompt_2="",
            device=device,
            num_images_per_prompt=1,
            max_sequence_length=max_sequence_length,
        )
        return text_input_ids, prompt_embeds, pooled_prompt_embeds, text_ids
    
    def _bypass_defense_attack(self, prompt: str, seed: int, guidance_scale: float):
        """Run the pipeline once on transformed prompt embeddings.

        The steps are:

        1. Encode the prompt with ``_compute_embedding``.
        2. Transform the sequence embeddings with ``_reverse_hybrid_defense``,
           using the mean of the two configured strengths.
        3. Temporarily replace ``pipe.encode_prompt`` so the pipeline call
           consumes those tensors instead of re-encoding the prompt.

        Args:
            prompt: Text prompt; also forwarded to the pipeline call.
            seed: Seed for the ``torch.Generator`` passed to the pipeline.
            guidance_scale: Forwarded to the pipeline call.

        Returns:
            ``images[0]`` from the pipeline output (512x512 is requested).
        """
        device = torch.device(self.device)
        _, defended_embeds, defended_pooled, defended_text_ids = self._compute_embedding(prompt)

        strength_factor = (self.direction_attack_strength + self.magnitude_attack_strength) / 2
        attacked_embeds = self._reverse_hybrid_defense(defended_embeds, strength_factor)

        generator = torch.Generator(device=device).manual_seed(seed)

        def patched_encode_prompt(*args, **kwargs):
            return attacked_embeds, defended_pooled, defended_text_ids

        # Restore the original in `finally` so the patch never outlives this
        # call. Otherwise, every later encode_prompt caller (including
        # _compute_embedding for a different prompt) would receive these
        # stale tensors, and an exception in the pipeline would leave the
        # pipe permanently patched.
        original_encode_prompt = self.pipe.encode_prompt
        self.pipe.encode_prompt = patched_encode_prompt
        try:
            attacked_image = self.pipe(
                prompt=prompt,
                generator=generator,
                height=512,
                width=512,
                guidance_scale=guidance_scale,
                num_inference_steps=self.num_inference_steps,
                max_sequence_length=256,
            ).images[0]
        finally:
            self.pipe.encode_prompt = original_encode_prompt

        return attacked_image

    def _reverse_direction_defense(self, embeds, strength):
        """Shift every position of *embeds* by one random vector of length *strength*.

        A single random direction is drawn per leading-dim entry, with the
        same shape as ``embeds.mean(dim=1, keepdim=True)``. It is unit-
        normalized over the last dim, scaled by *strength*, and added to all
        positions along dim 1. The result is clipped to [-50, 50].

        Note: the noise comes from ``torch.randn_like`` and therefore from
        torch's global RNG, not from any generator seeded by the caller.
        """
        pooled = embeds.mean(dim=1, keepdim=True)
        unit_dir = F.normalize(torch.randn_like(pooled), dim=-1)
        offset = (unit_dir * strength).expand_as(embeds)
        return torch.clamp(embeds + offset, min=-50.0, max=50.0)
    
    def _reverse_magnitude_defense(self, embeds, strength):
        """Multiply *embeds* by ``1 + strength`` and clip the result to [-50, 50]."""
        scaled = embeds * (1.0 + strength)
        return torch.clamp(scaled, min=-50.0, max=50.0)
    
    def _reverse_hybrid_defense(self, embeds, strength):
        """Apply the direction transform, then the magnitude transform.

        Both transforms run at 70% of *strength*. The output is clipped to
        [-50, 50].
        """
        damped = strength * 0.7
        pushed = self._reverse_direction_defense(embeds, damped)
        rescaled = self._reverse_magnitude_defense(pushed, damped)
        return torch.clamp(rescaled, min=-50.0, max=50.0)

    def _magnitude_attack(
        self, 
        prompt: str, 
        strength: float, 
        seed: int, 
        guidance_scale: float
    ) -> Optional[Image.Image]:
        """Delegate to ``_direct_counter_attack(prompt, seed, guidance_scale)``.

        ``strength`` is accepted for signature compatibility but is not used.

        NOTE(review): ``_direct_counter_attack`` is not defined in the visible
        part of this class. Confirm it exists; otherwise this raises
        AttributeError.
        """
        counter_attack = self._direct_counter_attack
        return counter_attack(prompt, seed, guidance_scale)