import torch
import torch.nn.functional as F
import numpy as np
import sys
import os

# Make modules that live next to this file importable no matter what the
# caller's working directory is; prepend so they take precedence, and only
# once even if this module is re-imported.
current_dir = os.path.split(os.path.abspath(__file__))[0]
if current_dir not in sys.path:
    sys.path[:0] = [current_dir]


class LinfPGDAttack(object):
    """L-infinity PGD attack through a ``ComG -> encode -> decode`` pipeline.

    Each step maximises the MSE between the pipeline output on the perturbed
    input and the pipeline output on the clean input. The perturbation is then
    projected back into the ``epsilon`` L-inf ball around the clean input and
    into the ``[-1, 1]`` value range.

    Args:
        wrapper: object exposing ``encode(x)`` and ``decode(z, ref=...)``.
        ComG: callable applied to the image before encoding, as ``ComG(X)``.
            Presumably a differentiable compression surrogate (TODO confirm).
        device: kept for backward compatibility; ``perturb`` places the random
            start on ``X_nat``'s own device instead.
        epsilon: L-inf radius of the allowed perturbation.
        k: number of PGD steps.
        a: size of each signed-gradient step.
    """

    def __init__(self, wrapper=None, ComG=None, device=None, epsilon=0.05, k=10, a=0.01):
        self.wrapper = wrapper
        self.ComG = ComG
        self.epsilon = epsilon
        self.k = k
        self.a = a
        self.loss_fn = F.mse_loss
        self.device = device
        # Random start inside the epsilon ball. With a deterministic start
        # (rand=False) the first MSE gradient is zero, because the perturbed
        # and clean outputs coincide.
        self.rand = True

    def _pipeline(self, x, ref):
        """Return ``wrapper.decode(wrapper.encode(ComG(x)), ref=ref)``."""
        encoded = self.wrapper.encode(self.ComG(x).contiguous())
        return self.wrapper.decode(encoded, ref=ref)

    def perturb(self, X_nat, ref=None):
        """Craft an adversarial version of ``X_nat``.

        Args:
            X_nat: clean input tensor. Values are expected in ``[-1, 1]``; the
                result is clamped to that range.
            ref: forwarded unchanged to ``wrapper.decode``.

        Returns:
            ``(X_adv, X_adv - X_nat)``. ``X_adv`` is detached, lies within
            ``epsilon`` of ``X_nat`` in L-inf norm, and lies inside ``[-1, 1]``.
        """
        x_nat = X_nat.detach()
        if self.rand:
            # Draw the noise with NumPy so that np.random.seed still controls
            # it. Match X_nat's dtype and device rather than forcing float32 on
            # self.device, which broke half precision and device=None.
            noise = np.random.uniform(-self.epsilon, self.epsilon, tuple(x_nat.shape))
            X = x_nat + torch.from_numpy(noise).to(device=x_nat.device, dtype=x_nat.dtype)
            # Keep the starting point a valid image (it could exceed [-1, 1]).
            X = torch.clamp(X, min=-1.0, max=1.0)
        else:
            X = x_nat.clone()

        if self.k <= 0:
            return X, X - X_nat

        # The clean reference does not depend on X, so compute it once instead
        # of on every step. This assumes ComG/wrapper are deterministic, e.g.
        # models in eval mode (TODO confirm).
        with torch.no_grad():
            target = self._pipeline(x_nat, ref)

        for _ in range(self.k):
            X.requires_grad_(True)
            loss = self.loss_fn(self._pipeline(X, ref), target)
            # Differentiate w.r.t. X only, so model parameters' .grad is
            # never populated or accumulated.
            (grad,) = torch.autograd.grad(loss, X)
            X_adv = X.detach() + self.a * grad.sign()
            eta = torch.clamp(X_adv - x_nat, min=-self.epsilon, max=self.epsilon)
            X = torch.clamp(x_nat + eta, min=-1.0, max=1.0).detach()

        return X, X - X_nat


def df_rap_attack(wrapper, X_nat, epsilon=0.05, alpha=0.01, steps=10, 
                  ref=None, ComG=None):
    """Run the DF-RAP PGD attack on ``X_nat`` and return only the adversarial images.

    A thin convenience layer over ``LinfPGDAttack``: ``steps`` becomes its
    ``k`` and ``alpha`` its step size ``a``. The device passed along is
    ``X_nat.device``. The perturbation returned by ``perturb`` is discarded.

    Raises:
        ValueError: if ``ComG`` is not supplied.
    """
    if ComG is None:
        raise ValueError("ComG is required for DF_RAP attack")

    pgd = LinfPGDAttack(wrapper, ComG, X_nat.device, epsilon, steps, alpha)
    adversarial, _perturbation = pgd.perturb(X_nat, ref=ref)
    return adversarial


def _simswap_forward(model, x, ref):
    """Run SimSwap on ``x``, using ``x`` itself as the identity source."""
    img_id_downsample = F.interpolate(x, size=(112, 112))
    latent_id = model.netArc(img_id_downsample)
    latent_id = latent_id / torch.norm(latent_id, p=2, dim=1, keepdim=True)
    return model(x, ref, latent_id, latent_id, True)


def _legacy_generate(model, faketype, x, ref):
    """Return the deepfake output of ``model`` for ``x`` under ``faketype``."""
    if faketype == "StarGAN":
        output, _ = model.features(x, ref)
        return output
    return _simswap_forward(model, x, ref)


def df_rap_attack_legacy(wrapper, X_nat, epsilon=0.05, alpha=0.01, steps=10,
                         ref=None, faketype="simswap", model=None, ComG=None,
                         use_comg=True, ComG_woj=None, balance=0.5):
    """Legacy DF-RAP L-inf PGD attack that calls a StarGAN/SimSwap ``model`` directly.

    Maximises the MSE between the model output on the perturbed image and the
    model output on the clean image. The perturbed image is optionally passed
    through ``ComG``, and optionally blended with a second pass through
    ``ComG_woj``. The clean target is computed without ``ComG``, as in the
    original implementation.

    Args:
        wrapper: unused; kept for signature compatibility.
        X_nat: clean input images, expected in ``[-1, 1]``.
        epsilon: L-inf radius of the perturbation.
        alpha: signed-gradient step size.
        steps: number of PGD steps.
        ref: second input forwarded to the model.
        faketype: ``"StarGAN"`` (uses ``model.features``) or ``"simswap"``.
        model: the deepfake model under attack (required).
        ComG: optional callable applied to the perturbed image.
        use_comg: whether to apply ``ComG`` when it is given.
        ComG_woj: optional second callable. When given, the loss uses
            ``balance * out(ComG) + (1 - balance) * out(ComG_woj)``.
        balance: blend weight used only when ``ComG_woj`` is given. The 0.5
            default is a placeholder (TODO confirm the intended value).

    Returns:
        The detached adversarial images, within ``epsilon`` of ``X_nat`` and
        clamped to ``[-1, 1]``.

    Raises:
        ValueError: on an unsupported ``faketype`` or a missing ``model``.
    """
    if faketype not in ("StarGAN", "simswap"):
        raise ValueError(
            f"Unsupported faketype {faketype!r}; expected 'StarGAN' or 'simswap'"
        )
    if model is None:
        raise ValueError("model is required for df_rap_attack_legacy")

    def forward(x):
        # Differentiable pass on the perturbed image.
        if not use_comg or ComG is None:
            return _legacy_generate(model, faketype, x, ref)
        output = _legacy_generate(model, faketype, ComG(x), ref)
        if ComG_woj is None:
            return output
        output_woj = _legacy_generate(model, faketype, ComG_woj(x), ref)
        return balance * output + (1.0 - balance) * output_woj

    x_nat = X_nat.detach()
    # NumPy noise (so np.random.seed still controls it), matched to X_nat's
    # dtype/device and clamped so the start is a valid image.
    noise = np.random.uniform(-epsilon, epsilon, tuple(x_nat.shape))
    X = x_nat + torch.from_numpy(noise).to(device=x_nat.device, dtype=x_nat.dtype)
    X = torch.clamp(X, min=-1.0, max=1.0)
    if steps <= 0:
        return X

    # The clean target does not depend on X, so compute it once, outside the
    # loop. This assumes a deterministic model.
    with torch.no_grad():
        gen_clean = _legacy_generate(model, faketype, x_nat, ref)

    for _ in range(steps):
        X.requires_grad_(True)
        loss = F.mse_loss(forward(X), gen_clean)
        # Gradient w.r.t. X only; the model's parameter .grad is left untouched
        # (backward() used to accumulate into it on every step).
        (grad,) = torch.autograd.grad(loss, X)
        X_adv = X.detach() + alpha * grad.sign()
        eta = torch.clamp(X_adv - x_nat, min=-epsilon, max=epsilon)
        X = torch.clamp(x_nat + eta, min=-1.0, max=1.0).detach()

    return X
