import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import argparse
import os
import sys
import gc
import numpy as np
from PIL import Image
from tqdm import tqdm

# Directory containing this script. main() also uses it as the parent of the
# default output directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Currently identical to current_dir, so the loop below adds at most one
# entry; presumably kept separate so the root can be changed independently.
project_root = current_dir

# Prepend these directories to sys.path (skipping ones already present) so the
# project-local `utils` and `disrupting_methods` imports below can be resolved
# relative to this script. Must run before those imports.
for p in [current_dir, project_root]:
    if p not in sys.path:
        sys.path.insert(0, p)

from utils.utils import (
    WRAPPER_REGISTRY,
    load_wrapper,
    load_image,
)

from disrupting_methods.LEAT.leat_ensemble import lpgd_ensemble_attack

def save_tensor_as_image(tensor, path):
    """Save an image tensor whose values lie in [-1, 1] to ``path``.

    Args:
        tensor: (C, H, W) or (N, C, H, W) tensor. For a 4-D input only the
            first batch element is saved. Values are mapped from [-1, 1] to
            [0, 1], clamped, and quantized to uint8 with rounding.
            A single-channel tensor is saved as a grayscale image.
        path: Destination file, written with PIL's ``Image.save``.
    """
    if tensor.ndim == 4:
        tensor = tensor[0]
    # float() so half/bfloat16 tensors can be converted to NumPy below.
    tensor = tensor.detach().cpu().float()
    tensor = torch.clamp((tensor + 1.0) / 2.0, 0, 1)
    # Round to nearest rather than truncate (truncation biases every pixel
    # downward); this matches torchvision.utils.save_image.
    img = (tensor * 255).round().to(torch.uint8).permute(1, 2, 0).contiguous().numpy()
    if img.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array; use (H, W).
        img = img[:, :, 0]
    Image.fromarray(img).save(path)


def get_target_image_for_wrapper(wrapper_name, device):
    """Load the registry's default target image for a wrapper, if it needs one.

    Args:
        wrapper_name: Key into ``WRAPPER_REGISTRY``.
        device: Passed through to ``load_image``.

    Returns:
        The first element returned by ``load_image`` for the wrapper's
        ``default_target_path``. Returns None when the wrapper is not
        registered, its config does not set ``requires_target``, no path is
        configured, the file does not exist, or loading raises. The last two
        cases also print a message.
    """
    if wrapper_name not in WRAPPER_REGISTRY:
        return None
    cfg = WRAPPER_REGISTRY[wrapper_name]

    # Only look up a path for wrappers that declare they need a target.
    path = cfg.get('default_target_path') if cfg.get('requires_target', False) else None
    if path is None:
        return None

    if not os.path.exists(path):
        print(f"Target image not found: {path}")
        return None

    try:
        return load_image(path, device)[0]
    except Exception as exc:
        # Best effort: a broken target file should not abort the whole run.
        print(f"Failed to load target image from {path}: {exc}")
        return None


def parse_arguments(argv=None):
    """Parse command-line options for the batch LEAT ensemble attack.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, preserving
            the original behaviour.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser(
        description="Batch LEAT Attack on Multiple Wrappers"
    )

    parser.add_argument(
        "--image_dir",
        type=str,
        default="/dataset/",
        help="Directory containing test images"
    )

    parser.add_argument(
        "--n_images",
        type=int,
        default=100,
        help="Number of images to process per wrapper"
    )

    # None here; main() substitutes the real default path at run time.
    parser.add_argument(
        "--output_base",
        type=str,
        default=None,
        help="Base output directory (default: current_dir/batch_results_leat_ensemble)"
    )

    parser.add_argument(
        "--epsilon",
        type=float,
        default=0.05,
        help="Perturbation budget"
    )

    parser.add_argument(
        "--alpha",
        type=float,
        default=0.01,
        help="Step size"
    )

    parser.add_argument(
        "--steps",
        type=int,
        default=30,
        help="Number of attack steps"
    )

    # Default is decided when the parser is built: CUDA if visible, else CPU.
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use"
    )

    parser.add_argument(
        "--wrapper1",
        type=str,
        default="diffae",
        help="First wrapper for ensemble attack"
    )

    parser.add_argument(
        "--wrapper2",
        type=str,
        default="simswap",
        help="Second wrapper for ensemble attack"
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility"
    )

    return parser.parse_args(argv)


def _seed_everything(seed):
    """Seed Python's ``random``, NumPy, and torch (CPU and, if available, CUDA) RNGs."""
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def _decode_with_wrapper(wrapper, config, target, x):
    """Run ``wrapper`` on ``x`` with the conditioning its config asks for.

    - ``requires_target`` set and ``target`` loaded: ``wrapper(x, ref=target, preprocess=False)``.
    - otherwise ``requires_attr`` set and ``default_attr`` not None:
      ``wrapper(x, target_attr=default_attr)``.
    - otherwise: ``wrapper(x)``.

    Missing config keys count as False/None instead of raising KeyError,
    consistent with ``get_target_image_for_wrapper``.
    """
    if config.get('requires_target', False) and target is not None:
        return wrapper(x, ref=target, preprocess=False)
    default_attr = config.get('default_attr')
    if config.get('requires_attr', False) and default_attr is not None:
        return wrapper(x, target_attr=default_attr)
    return wrapper(x)


def _release_gpu_memory():
    """Collect garbage, then return cached CUDA memory to the driver if CUDA is available."""
    # gc first so tensors only reachable through cycles are freed before
    # the allocator cache is emptied.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()


def main(args):
    """Run the LEAT ensemble attack over the images in ``args.image_dir``.

    The first ``args.n_images`` ``.jpg`` files (sorted by path) are attacked
    jointly against ``args.wrapper1`` and ``args.wrapper2`` with
    ``lpgd_ensemble_attack``. For each image, the source and adversarial
    images and both wrappers' outputs on each are written as ``.jpg`` (and,
    except the source, ``.pt``) files under
    ``<output_base>/result_<wrapper1>_<wrapper2>_ensemble/img_NNN/``.

    Setup problems (missing image directory, no images, unknown wrapper,
    wrapper load failure) are printed and end the run early. A failure on
    one image is printed with its traceback, and that image is skipped.

    Args:
        args: Namespace from ``parse_arguments``. ``args.output_base`` is
            filled in with a default when it is None.
    """
    device = args.device
    print(f"Device: {device}")

    if args.seed is not None:
        _seed_everything(args.seed)
        print(f"Random seed set to: {args.seed}")

    if args.output_base is None:
        args.output_base = os.path.join(current_dir, "batch_results_leat_ensemble")

    print(f"Output base directory: {args.output_base}")

    img_dir = args.image_dir
    if not os.path.exists(img_dir):
        print(f"Image directory not found: {img_dir}")
        return

    img_paths = sorted(
        os.path.join(img_dir, f) for f in os.listdir(img_dir) if f.endswith('.jpg')
    )
    img_paths = img_paths[:args.n_images]

    print(f"Found {len(img_paths)} images to process")
    if not img_paths:
        # Nothing to attack: skip loading (potentially large) models.
        return

    for wrapper_name in (args.wrapper1, args.wrapper2):
        if wrapper_name not in WRAPPER_REGISTRY:
            print(f"Wrapper '{wrapper_name}' not found")
            return

    print(f"\n{'='*60}")
    print(f"Ensemble Attack: {args.wrapper1.upper()} + {args.wrapper2.upper()}")
    print(f"{'='*60}")

    print(f"\nLoading {args.wrapper1}...")
    try:
        wrapper1, config1 = load_wrapper(args.wrapper1, device=device)
    except Exception as e:
        print(f"Failed to load wrapper1: {e}")
        return

    print(f"Loading {args.wrapper2}...")
    try:
        wrapper2, config2 = load_wrapper(args.wrapper2, device=device)
    except Exception as e:
        print(f"Failed to load wrapper2: {e}")
        del wrapper1
        _release_gpu_memory()
        return

    # The loaded config decides whether a target is needed; the helper then
    # consults the registry entry for the actual path.
    target1 = get_target_image_for_wrapper(args.wrapper1, device) if config1.get('requires_target', False) else None
    target2 = get_target_image_for_wrapper(args.wrapper2, device) if config2.get('requires_target', False) else None

    # target is only non-None when requires_target was set above.
    if target1 is not None and hasattr(wrapper1, 'set_target'):
        wrapper1.set_target(target1)
    if target2 is not None and hasattr(wrapper2, 'set_target'):
        wrapper2.set_target(target2)

    output_dir = os.path.join(args.output_base, f"result_{args.wrapper1}_{args.wrapper2}_ensemble")
    os.makedirs(output_dir, exist_ok=True)

    for idx, img_path in enumerate(tqdm(img_paths, desc="Ensemble")):
        try:
            source_tensor, _ = load_image(img_path, device)

            adv_tensor = lpgd_ensemble_attack(
                wrapper1=wrapper1,
                wrapper2=wrapper2,
                x_source=source_tensor,
                epsilon=args.epsilon,
                alpha=args.alpha,
                steps=args.steps,
            )

            with torch.no_grad():
                decoded_src1 = _decode_with_wrapper(wrapper1, config1, target1, source_tensor)
                decoded_adv1 = _decode_with_wrapper(wrapper1, config1, target1, adv_tensor)
                decoded_src2 = _decode_with_wrapper(wrapper2, config2, target2, source_tensor)
                decoded_adv2 = _decode_with_wrapper(wrapper2, config2, target2, adv_tensor)

            folder_path = os.path.join(output_dir, f"img_{idx+1:03d}")
            os.makedirs(folder_path, exist_ok=True)

            # x_src is saved only as an image; everything else also gets a .pt.
            results = [
                ("x_adv", adv_tensor),
                (f"decoded_src_{args.wrapper1}", decoded_src1),
                (f"decoded_adv_{args.wrapper1}", decoded_adv1),
                (f"decoded_src_{args.wrapper2}", decoded_src2),
                (f"decoded_adv_{args.wrapper2}", decoded_adv2),
            ]
            save_tensor_as_image(source_tensor, os.path.join(folder_path, "x_src.jpg"))
            for stem, tensor in results:
                save_tensor_as_image(tensor, os.path.join(folder_path, f"{stem}.jpg"))
            for stem, tensor in results:
                torch.save(tensor, os.path.join(folder_path, f"{stem}.pt"))

        except Exception as e:
            # Keep going: one bad image should not abort the batch.
            print(f"\nError processing image {idx+1}: {e}")
            import traceback
            traceback.print_exc()

    del wrapper1, wrapper2
    _release_gpu_memory()

    print("\nEnsemble attack completed!")
    if torch.cuda.is_available():
        print(f"   GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB / {torch.cuda.max_memory_allocated()/1e9:.2f}GB")

    print(f"\n{'='*60}")
    print("All done!")
    print(f"{'='*60}")


if __name__ == "__main__":
    args = parse_arguments()
    try:
        main(args)
    except Exception as e:
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()
        # Report the failure to the calling shell / job scheduler; previously
        # the script exited with status 0 even after an unhandled error.
        sys.exit(1)
