import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import argparse
import os
import sys
import gc
import numpy as np
from PIL import Image
from tqdm import tqdm

# Resolve this script's own directory so that the sibling ``utils`` package
# (imported below) can be found regardless of the current working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = current_dir  # the project root is the script's directory

for p in (current_dir, project_root):
    # Skip entries that are already importable; new ones go to the front so
    # they take precedence over same-named modules elsewhere on sys.path.
    if p in sys.path:
        continue
    sys.path.insert(0, p)

from utils.utils import (
    WRAPPER_REGISTRY,
    load_wrapper,
    load_image,
    run_anti_attack,
)

def save_tensor_as_image(tensor, path):
    """Save an image tensor with values in ``[-1, 1]`` to ``path``.

    Args:
        tensor: A ``(C, H, W)`` image, a ``(N, C, H, W)`` batch (only the first
            element is saved) or a single-channel ``(H, W)`` image. Values are
            mapped from ``[-1, 1]`` to ``[0, 1]`` and clamped.
        path: Destination file; PIL infers the format from the extension.
    """
    if tensor.ndim == 4:
        tensor = tensor[0]
    if tensor.ndim == 2:
        # Treat a bare (H, W) map as a one-channel image.
        tensor = tensor.unsqueeze(0)
    # .float() so half / bfloat16 tensors can be converted to NumPy.
    tensor = tensor.detach().cpu().float()
    tensor = torch.clamp((tensor + 1.0) / 2.0, 0, 1)
    # Round instead of truncating: astype(uint8) alone maps 0.999 -> 254.
    img = (tensor * 255).round().permute(1, 2, 0).numpy().astype(np.uint8)
    if img.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel axis.
        img = img[:, :, 0]
    Image.fromarray(img).save(path)


def get_target_image_for_wrapper(wrapper_name, device):
    """Return the default target image for ``wrapper_name``, or ``None``.

    Looks up the wrapper in ``WRAPPER_REGISTRY``. An image is returned (the
    first element of ``load_image(path, device)``) only when the entry sets
    ``requires_target`` and provides a ``default_target_path`` that exists on
    disk and loads without raising. A missing file or a load error is
    reported with a message. Every other case silently yields ``None``.
    """
    if wrapper_name not in WRAPPER_REGISTRY:
        return None
    entry = WRAPPER_REGISTRY[wrapper_name]

    if not entry.get('requires_target', False):
        return None

    target_path = entry.get('default_target_path')
    if target_path is None:
        return None

    if not os.path.exists(target_path):
        print(f"Target image not found: {target_path}")
        return None

    try:
        return load_image(target_path, device)[0]
    except Exception as err:
        # Best-effort: the caller skips the wrapper when None comes back.
        print(f"Failed to load target image from {target_path}: {err}")
        return None


def parse_arguments(argv=None):
    """Parse command-line options for the batch anti-forgery attack.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            allows programmatic invocation and testing.

    Returns:
        An ``argparse.Namespace`` with ``image_dir``, ``n_images``,
        ``output_base``, ``epsilon``, ``lr``, ``steps``, ``device``,
        ``wrappers`` and ``seed``.
    """
    parser = argparse.ArgumentParser(
        description="Batch Anti-Forgery Lab Attack on Multiple Wrappers"
    )

    parser.add_argument(
        "--image_dir",
        type=str,
        default="/dataset/",
        help="Directory containing test images"
    )

    parser.add_argument(
        "--n_images",
        type=int,
        default=100,
        help="Number of images to process per wrapper"
    )

    parser.add_argument(
        "--output_base",
        type=str,
        default=None,
        help="Base output directory (default: current_dir/batch_results_anti)"
    )

    parser.add_argument(
        "--epsilon",
        type=float,
        default=0.05,
        help="Perturbation budget in Lab space"
    )

    parser.add_argument(
        "--lr",
        type=float,
        default=1e-4,
        help="Learning rate for Adam optimizer"
    )

    parser.add_argument(
        "--steps",
        type=int,
        default=500,
        help="Number of optimization steps"
    )

    parser.add_argument(
        "--device",
        type=str,
        # Evaluated when the parser is built, i.e. on each call.
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use"
    )

    parser.add_argument(
        "--wrappers",
        nargs="+",
        default=["simswap", "psp_mix", "diffae", "styleclip"],
        help="Wrappers to use (ordered from small to large models)"
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility"
    )

    return parser.parse_args(argv)


def _seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    print(f"Random seed set to: {seed}")


def _list_images(img_dir, n_images):
    """Return the first ``n_images`` ``.jpg`` files of ``img_dir`` in sorted order."""
    # Sorting bare names gives the same order as sorting the joined paths,
    # since every path shares the same ``img_dir`` prefix.
    names = sorted(f for f in os.listdir(img_dir) if f.endswith('.jpg'))
    return [os.path.join(img_dir, f) for f in names[:n_images]]


def _free_memory():
    """Collect unreachable Python objects, then release cached CUDA memory."""
    # Collect first: objects kept alive only by reference cycles still own
    # their CUDA tensors until gc runs, and empty_cache() can only hand back
    # blocks that are already free.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()


def _decode_pair(wrapper, source_tensor, adv_tensor, target_tensor, target_attr):
    """Run ``wrapper`` without gradients on the clean and adversarial inputs.

    A target image (passed as ``ref`` with ``preprocess=False``) takes
    precedence over a target attribute. With neither, the wrapper is called
    on the input alone.
    """
    if target_tensor is not None:
        call_kwargs = {'ref': target_tensor, 'preprocess': False}
    elif target_attr is not None:
        call_kwargs = {'target_attr': target_attr}
    else:
        call_kwargs = {}
    with torch.no_grad():
        decoded_src = wrapper(source_tensor, **call_kwargs)
        decoded_adv = wrapper(adv_tensor, **call_kwargs)
    return decoded_src, decoded_adv


def _save_outputs(folder_path, source_tensor, adv_tensor, decoded_src, decoded_adv):
    """Write the four preview JPEGs and the raw adversarial/decoded tensors."""
    os.makedirs(folder_path, exist_ok=True)
    previews = (
        ("x_src.jpg", source_tensor),
        ("x_adv.jpg", adv_tensor),
        ("decoded_src.jpg", decoded_src),
        ("decoded_adv.jpg", decoded_adv),
    )
    for filename, tensor in previews:
        save_tensor_as_image(tensor, os.path.join(folder_path, filename))
    # The .pt files keep exact values; the JPEG previews are lossy.
    raw = (
        ("x_adv.pt", adv_tensor),
        ("decoded_src.pt", decoded_src),
        ("decoded_adv.pt", decoded_adv),
    )
    for filename, tensor in raw:
        torch.save(tensor, os.path.join(folder_path, filename))


def _run_wrapper(wrapper_name, img_paths, output_dir, target_tensor, args):
    """Load ``wrapper_name`` and run the attack on every path in ``img_paths``.

    Per-image failures are reported and skipped. Returns ``False`` if the
    wrapper could not be loaded and ``True`` otherwise. The model and all
    per-image tensors are locals of this function. Once it returns, nothing
    here references them, so the caller can reclaim GPU memory before the
    next model is loaded.
    """
    import traceback

    device = args.device
    try:
        wrapper, config = load_wrapper(wrapper_name, device=device)
    except Exception as e:
        print(f"Failed to load wrapper: {e}")
        return False

    requires_target = config.get('requires_target', False)
    requires_attr = config.get('requires_attr', False)
    # Keep the two reference kinds separate so a wrapper that needs both
    # never receives the target image as its attribute.
    target_ref = target_tensor if requires_target else None
    attr_ref = config.get('default_attr') if requires_attr else None

    if requires_target:
        if hasattr(wrapper, 'set_target'):
            wrapper.set_target(target_tensor)
        print("   Using target image as reference")
    elif requires_attr:
        print(f"   Using attribute as reference: {attr_ref}")
    else:
        print("   No reference needed (self-reconstruction)")

    for idx, img_path in enumerate(tqdm(img_paths, desc=wrapper_name)):
        try:
            source_tensor, _ = load_image(img_path, device)

            adv_tensor, _ = run_anti_attack(
                wrapper=wrapper,
                source_tensor=source_tensor,
                target_tensor=target_ref,
                target_attr=attr_ref,
                epsilon=args.epsilon,
                lr=args.lr,
                steps=args.steps,
                config=config,
            )
            # tqdm.write prints without corrupting the progress bar.
            tqdm.write(
                f"Adv Tensor Range: {adv_tensor.min().item():.4f} ~ {adv_tensor.max().item():.4f}"
            )

            decoded_src, decoded_adv = _decode_pair(
                wrapper, source_tensor, adv_tensor, target_ref, attr_ref
            )
            _save_outputs(
                os.path.join(output_dir, f"img_{idx + 1:03d}"),
                source_tensor, adv_tensor, decoded_src, decoded_adv,
            )

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            tqdm.write(f"\nError processing image {idx + 1}: {e}")
            traceback.print_exc()
    return True


def main(args):
    """Run the Lab-space anti-forgery attack for every wrapper in ``args.wrappers``.

    Images are the first ``args.n_images`` sorted ``.jpg`` files of
    ``args.image_dir``. For each usable wrapper, results are written to
    ``<output_base>/result_<wrapper>_anti/img_NNN/``. ``args.output_base``
    defaults to ``<script dir>/batch_results_anti`` and is updated in place on
    ``args``. Wrappers that are unknown, or lack a required target image or
    default attribute, are skipped. GPU memory is reclaimed between wrappers.
    """
    device = args.device
    print(f"Device: {device}")

    if args.seed is not None:
        _seed_everything(args.seed)

    if args.output_base is None:
        args.output_base = os.path.join(current_dir, "batch_results_anti")
    print(f"Output base directory: {args.output_base}")

    img_dir = args.image_dir
    if not os.path.exists(img_dir):
        print(f"Image directory not found: {img_dir}")
        return

    img_paths = _list_images(img_dir, args.n_images)
    print(f"Found {len(img_paths)} images to process")

    for wrapper_name in args.wrappers:
        if wrapper_name not in WRAPPER_REGISTRY:
            print(f"Wrapper '{wrapper_name}' not found, skipping")
            continue

        print(f"\n{'='*60}")
        print(f"Processing {wrapper_name.upper()} - Anti-Forgery Lab Attack")
        print(f"{'='*60}")

        config = WRAPPER_REGISTRY[wrapper_name]

        target_tensor = None
        if config.get('requires_target', False):
            target_tensor = get_target_image_for_wrapper(wrapper_name, device)
            if target_tensor is None:
                print(f"{wrapper_name} requires target image but none available, skipping")
                continue
            print(f"   Target image loaded from: {config.get('default_target_path', 'unknown')}")

        if config.get('requires_attr', False) and config.get('default_attr') is None:
            print(f"{wrapper_name} requires target attribute but no default set, skipping")
            continue

        output_dir = os.path.join(args.output_base, f"result_{wrapper_name}_anti")
        os.makedirs(output_dir, exist_ok=True)

        loaded = _run_wrapper(wrapper_name, img_paths, output_dir, target_tensor, args)
        # Drop the last per-wrapper reference held here, then reclaim memory
        # before the next (typically larger) model is loaded.
        target_tensor = None
        _free_memory()
        if not loaded:
            continue

        print(f"{wrapper_name} completed!")
        if torch.cuda.is_available():
            print(f"   GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB / {torch.cuda.max_memory_allocated()/1e9:.2f}GB")

    print(f"\n{'='*60}")
    print("All done!")
    print(f"{'='*60}")


if __name__ == "__main__":
    args = parse_arguments()
    try:
        main(args)
    except Exception as e:
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()
        # Exit non-zero so shell scripts and job schedulers see the failure
        # instead of a silent success.
        sys.exit(1)
