import torch
import argparse
import os
import sys
import gc
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm


# Absolute path of the directory that contains this script.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Alias of current_dir; not referenced again in the visible part of this file.
project_root = current_dir

# Make sure the script's directory is on sys.path (at the front) before the
# project-local imports below are attempted.
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

try:
    # Preferred layout: `wrappers` and `utils` are importable as packages.
    from wrappers.null_swap_wrapper import NullSwap
    from utils.utils import WRAPPER_REGISTRY, load_wrapper, load_image
except ImportError:
    # Fallback: put the wrappers/ directory itself on sys.path and import the
    # NullSwap module by its bare name.
    # NOTE(review): `utils.utils` is imported exactly as in the `try` branch,
    # so if that import was the one that failed it will fail again here.
    sys.path.append(os.path.join(current_dir, 'wrappers'))
    from null_swap_wrapper import NullSwap
    from utils.utils import WRAPPER_REGISTRY, load_wrapper, load_image

def save_tensor_as_image(tensor, path):
    """Write an image tensor to ``path`` as an 8-bit image.

    Values are mapped linearly from [-1, 1] to [0, 255]; anything outside that
    range is clamped. A 4-D input is treated as a batch and only its first
    element is saved. A 3-D input is treated as channel-first (C, H, W) and is
    converted to channel-last before being handed to PIL. A 2-D input is saved
    as a grayscale image.

    Args:
        tensor: Image tensor with 2, 3 or 4 dimensions.
        path: Destination file path; the format is chosen by PIL from the
            file extension.
    """
    if tensor.ndim == 4:
        tensor = tensor[0]
    # Cast to float32 so half-precision tensors (e.g. bfloat16, which NumPy
    # cannot represent) convert cleanly.
    tensor = tensor.detach().float().cpu()
    tensor = torch.clamp((tensor + 1.0) / 2.0, 0, 1)
    # Round to the nearest level instead of truncating: a plain uint8 cast
    # turns e.g. 127.99999 into 127, so a load/save round-trip could darken
    # pixels by one level.
    pixels = torch.round(tensor * 255).to(torch.uint8)
    if pixels.ndim == 3:
        pixels = pixels.permute(1, 2, 0)
        if pixels.shape[-1] == 1:
            # PIL rejects (H, W, 1) arrays; grayscale must be plain (H, W).
            pixels = pixels[..., 0]
    img = np.ascontiguousarray(pixels.numpy())
    Image.fromarray(img).save(path)

def get_target_image_for_wrapper(wrapper_name, device):
    """Load the default reference image registered for ``wrapper_name``.

    Returns the first element of ``load_image(path, device)`` when the wrapper
    is registered, its registry entry has a truthy ``requires_target`` flag,
    and its ``default_target_path`` points at an existing file. In every other
    case the result is ``None``.
    """
    reference = None
    if wrapper_name in WRAPPER_REGISTRY:
        entry = WRAPPER_REGISTRY[wrapper_name]
        if entry.get('requires_target', False):
            candidate = entry.get('default_target_path')
            if candidate and os.path.exists(candidate):
                reference = load_image(candidate, device)[0]
    return reference

def parse_arguments():
    """Parse the command-line options of the batch NullSwap attack script.

    Options are declared in a table and registered in order, so the generated
    ``--help`` output lists them in the same sequence as the table.

    Returns:
        argparse.Namespace with ``image_dir``, ``n_images``, ``output_base``,
        ``device``, ``ckpt`` (required) and ``wrappers``.
    """
    victim_defaults = ["simswap", "psp_mix", "diffae", "styleclip", "stargan", "blendface"]
    option_table = (
        ("--image_dir", {"type": str, "default": "/dataset/"}),
        ("--n_images", {"type": int, "default": 100}),
        ("--output_base", {"type": str, "default": None}),
        ("--device", {"type": str, "default": "cuda"}),
        ("--ckpt", {"type": str, "required": True,
                    "help": "Path to NullSwap Generator .pth"}),
        ("--wrappers", {"nargs": "+", "default": victim_defaults}),
    )

    cli = argparse.ArgumentParser(description="Batch NullSwap Attack")
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def _run_victim(wrapper, image, config, target_ref_img, target_ref_attr):
    """Call the victim wrapper on ``image`` using the call form its config selects."""
    if config.get('requires_target', False):
        return wrapper(image, ref=target_ref_img, preprocess=False)
    if config.get('requires_attr', False):
        return wrapper(image, target_attr=target_ref_attr)
    return wrapper(image)


def _attack_victim(wrapper_name, wrapper, config, attacker_net, img_paths,
                   output_base, device):
    """Run the attack on every image against one loaded victim and save the results."""
    output_dir = os.path.join(output_base, f"result_{wrapper_name}_nullswap")
    os.makedirs(output_dir, exist_ok=True)

    target_ref_img = None
    target_ref_attr = None

    # `.get(..., False)` mirrors get_target_image_for_wrapper: a config that
    # omits a flag is treated as not needing it rather than raising KeyError.
    if config.get('requires_target', False):
        target_ref_img = get_target_image_for_wrapper(wrapper_name, device)
        if target_ref_img is None:
            print(f"Skipping {wrapper_name}: Target image missing.")
            return
    elif config.get('requires_attr', False):
        target_ref_attr = config.get('default_attr')

    for idx, img_path in enumerate(tqdm(img_paths, desc=wrapper_name)):
        img_id = f"img_{idx+1:03d}"
        folder_path = os.path.join(output_dir, img_id)
        os.makedirs(folder_path, exist_ok=True)

        try:
            source_tensor, _ = load_image(img_path, device)
            with torch.no_grad():
                x_adv = attacker_net(source_tensor)
                # Victim output for the clean input and for the attacked input.
                decoded_src = _run_victim(wrapper, source_tensor, config,
                                          target_ref_img, target_ref_attr)
                decoded_adv = _run_victim(wrapper, x_adv, config,
                                          target_ref_img, target_ref_attr)

            save_tensor_as_image(source_tensor, os.path.join(folder_path, "x_src.jpg"))
            save_tensor_as_image(x_adv, os.path.join(folder_path, "x_adv.jpg"))
            save_tensor_as_image(decoded_src, os.path.join(folder_path, "decoded_src.jpg"))
            save_tensor_as_image(decoded_adv, os.path.join(folder_path, "decoded_adv.jpg"))

            # Raw tensors are stored next to the JPEGs.
            torch.save(x_adv, os.path.join(folder_path, "x_adv.pt"))
            torch.save(decoded_src, os.path.join(folder_path, "decoded_src.pt"))
            torch.save(decoded_adv, os.path.join(folder_path, "decoded_adv.pt"))

        except Exception as e:
            # Best effort: one failing image must not abort the batch.
            print(f"Error on {img_id}: {e}")


def main(args):
    """Run the NullSwap attack on a batch of images against each victim wrapper.

    Loads the NullSwap generator from ``args.ckpt``, collects up to
    ``args.n_images`` ``.jpg`` files from ``args.image_dir`` (sorted by path),
    and for every wrapper named in ``args.wrappers`` writes, per image, the
    source, the adversarial image and the wrapper's output for both, as JPEG
    and ``.pt`` files under ``args.output_base``.

    Args:
        args: Namespace with ``device``, ``output_base``, ``ckpt``,
            ``image_dir``, ``n_images`` and ``wrappers``. ``output_base`` is
            filled in with a default next to this script when it is ``None``.

    Returns:
        None. Failures to load the generator, a missing image directory or an
        empty image list are reported on stdout and end the run early.
    """
    device = args.device
    print(f"Device: {device}")

    if args.output_base is None:
        args.output_base = os.path.join(current_dir, "batch_results_nullswap")
    os.makedirs(args.output_base, exist_ok=True)

    print(f"\n{'='*60}")
    print(f"⚡ Loading NullSwap Attacker from: {os.path.basename(args.ckpt)}")
    print(f"{'='*60}")

    try:
        attacker_net = NullSwap().to(device)
        state_dict = torch.load(args.ckpt, map_location=device)
        attacker_net.load_state_dict(state_dict, strict=True)
        attacker_net.eval()
        print("NullSwap Generator loaded successfully!")
    except Exception as e:
        print(f"Failed to load NullSwap model: {e}")
        return

    img_dir = args.image_dir
    if not os.path.isdir(img_dir):
        # os.listdir would raise here; report it like the other load failures.
        print(f"Image directory not found: {img_dir}")
        return

    img_paths = sorted(
        os.path.join(img_dir, f) for f in os.listdir(img_dir) if f.endswith('.jpg')
    )
    img_paths = img_paths[:args.n_images]
    print(f"Found {len(img_paths)} images to process")
    if not img_paths:
        # Avoid loading every victim model when there is nothing to attack.
        print("No images to process, nothing to do.")
        return

    for wrapper_name in args.wrappers:
        if wrapper_name not in WRAPPER_REGISTRY:
            print(f"Wrapper '{wrapper_name}' not found, skipping")
            continue

        print(f"\n{'='*60}")
        print(f"Processing Victim: {wrapper_name.upper()}")
        print(f"{'='*60}")

        try:
            wrapper, config = load_wrapper(wrapper_name, device=device)
        except Exception as e:
            print(f"Failed to load wrapper {wrapper_name}: {e}")
            continue

        try:
            _attack_victim(wrapper_name, wrapper, config, attacker_net,
                           img_paths, args.output_base, device)
        finally:
            # Release the victim on every exit path, including the
            # "target image missing" skip. Collect garbage first so the freed
            # tensors can actually be returned by empty_cache().
            del wrapper
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    print(f"\n{'='*60}")
    print("All NullSwap experiments done.")

# Script entry point: parse the command line, then run the batch attack.
if __name__ == "__main__":
    cli_args = parse_arguments()
    main(cli_args)