import torch
import torch.nn as nn
from .utils import *
import torchvision.transforms as T

def _lab_perturbed_image(X_lab, pert):
    """Add ``pert`` to the a/b channels of ``X_lab`` and return the result via ``lab2rgb`` + ``norm``.

    Channel 0 of ``X_lab`` (L) is left untouched; ``pert`` is added to
    channels ``1:``. Non-finite values produced by ``lab2rgb`` are replaced
    (NaN -> 0, +Inf -> 1, -Inf -> 0) before ``norm`` is applied.
    """
    X_lab_pert = torch.cat([X_lab[:, :1], X_lab[:, 1:] + pert], dim=1)
    X_new_raw = lab2rgb(X_lab_pert)
    # nan_to_num is the identity on finite values (forward and backward), so
    # applying it unconditionally matches a guarded call while avoiding the
    # device sync that `.any()` inside an `if` would force.
    X_new_raw = torch.nan_to_num(X_new_raw, nan=0.0, posinf=1.0, neginf=0.0)
    return norm(X_new_raw)


def lab_attack(wrapper, X_nat, epsilon=0.05, lr=1e-4, steps=500, ref=None):
    """Craft an adversarial image by perturbing the a/b channels in Lab space.

    The perturbation (2 channels, same spatial size as ``X_nat``) is
    optimised with Adam to *maximise* the MSE between
    ``wrapper.decode(wrapper.encode(x_adv), ref=ref)`` and the same output
    for the clean input ``X_nat``. The applied perturbation is clamped
    element-wise to ``[-epsilon, epsilon]``.

    Args:
        wrapper: Model exposing ``encode(x)`` and ``decode(z, ref=...)`` and
            the ``nn.Module`` API (``eval``, ``modules``).
        X_nat: Clean input batch. Presumably (N, 3, H, W) in the range that
            ``denorm`` maps to RGB (likely [-1, 1]) -- TODO confirm.
        epsilon: Bound on the absolute a/b perturbation.
        lr: Adam learning rate.
        steps: Number of optimisation steps (0 returns the Lab round-trip of
            the clean image).
        ref: Passed through unchanged to ``wrapper.decode``.

    Returns:
        Detached adversarial batch built from the final perturbation,
        clamped to [-1, 1].

    Side effects:
        Puts ``wrapper`` in eval mode. For wrappers whose class name contains
        'StarGAN', also permanently sets ``track_running_stats = False`` on
        every BatchNorm2d / InstanceNorm2d module.

    Note:
        Because the perturbation is applied through ``torch.clamp``, entries
        of the raw parameter outside [-epsilon, epsilon] receive zero gradient.
    """
    device = X_nat.device
    criterion = nn.MSELoss().to(device)

    wrapper.eval()
    if 'StarGAN' in wrapper.__class__.__name__:
        for module in wrapper.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                module.eval()
                module.track_running_stats = False

    X = denorm(X_nat.clone()).detach()
    # X is a constant, so its Lab representation is computed once.
    X_lab = rgb2lab(X).to(device)

    # Allocate on the input's device (a hard-coded .cuda() breaks CPU runs).
    pert_a = torch.zeros(X.shape[0], 2, X.shape[2], X.shape[3], device=device, requires_grad=True)
    optimizer = torch.optim.Adam([pert_a], lr=lr, betas=(0.9, 0.999))

    with torch.no_grad():
        decoded_src = wrapper.decode(wrapper.encode(X_nat), ref=ref)

    for _ in range(steps):
        pert = torch.clamp(pert_a, min=-epsilon, max=epsilon)
        X_new = _lab_perturbed_image(X_lab, pert)
        decoded_adv = wrapper.decode(wrapper.encode(X_new), ref=ref)

        # Negated MSE: minimising it pushes the adversarial output away
        # from the clean output.
        loss = -criterion(decoded_adv, decoded_src)

        # Differentiate w.r.t. the perturbation only; loss.backward() would
        # also accumulate .grad on every wrapper parameter across all steps.
        (grad,) = torch.autograd.grad(loss, pert_a)
        pert_a.grad = torch.nan_to_num(grad, nan=0.0, posinf=1.0, neginf=-1.0)
        torch.nn.utils.clip_grad_norm_([pert_a], max_norm=1.0)
        optimizer.step()

    # Rebuild the image from the *final* perturbation so the last optimizer
    # step is not discarded (and steps == 0 does not hit an unbound name).
    with torch.no_grad():
        X_adv = _lab_perturbed_image(X_lab, torch.clamp(pert_a, min=-epsilon, max=epsilon))
    return X_adv.clamp(-1, 1).detach()

