import os
import random
import numpy as np
import argparse
import copy
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from diffusers import StableDiffusionPipeline, DDIMScheduler
from dataset import ImageDataset
from adversarial_optimization import Adversarial_Opt
from tests import attack_local_models


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because any non-empty string is truthy. This converter accepts the usual
    spellings of true/false (case-insensitive) and rejects everything else.

    Args:
        value: The raw token from the command line, or an existing ``bool``.

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching what bool('') returned before.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}")


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Boolean options take an explicit value (e.g. ``--is_makeup true``) and are
    parsed with ``_str2bool`` so that ``false`` really yields ``False``.
    ``--test_model_name`` and ``--surrogate_model_names`` accept one or more
    names and always produce a list, consistent with their list defaults.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description="")

    # Make sure the default path is the full root directory path.
    parser.add_argument('--source_dir',
                        default="assets/datasets/LADN_align",
                        type=str,
                        help="dataset")
    parser.add_argument('--test_dir',
                        default="",
                        type=str,
                        help="test images folder path for obfuscation")
    parser.add_argument('--protected_image_dir',
                        default="results",
                        type=str)
    parser.add_argument('--comparison_null_text',
                        default=False,
                        type=_str2bool)
    parser.add_argument('--target_choice',
                        default='1',
                        type=str,
                        help='Choice of target identity')
    # nargs='+' keeps user-supplied values as a list, like the defaults.
    parser.add_argument("--test_model_name",
                        default=['mobile_face'],
                        nargs='+',
                        type=str)
    parser.add_argument("--surrogate_model_names",
                        default=['irse50', 'ir152', 'facenet'],
                        nargs='+',
                        type=str)
    parser.add_argument('--is_makeup',
                        default=False,
                        type=_str2bool)
    parser.add_argument('--source_text',
                        default='face',
                        type=str)
    parser.add_argument('--makeup_prompt',
                        default='red lipstick',
                        type=str)
    parser.add_argument('--MTCNN_cropping',
                        default=True,
                        type=_str2bool)
    parser.add_argument('--is_obfuscation',
                        default=False,
                        type=_str2bool)
    parser.add_argument('--image_size',
                        default=256,
                        type=int)
    parser.add_argument('--prot_steps',
                        default=40,
                        type=int)
    parser.add_argument('--diffusion_steps',
                        default=20,
                        type=int)
    parser.add_argument('--fast_inversion',
                        action='store_true',
                        help='使用 Negative-prompt Inversion 加速')
    parser.add_argument('--start_step',
                        default=17,
                        type=int,
                        help='Which DDIM step to start the protection (20 - 17 = 3)')
    parser.add_argument('--null_optimization_steps',
                        default=20,
                        type=int)

    args = parser.parse_args()
    return args


def initialize_seed(seed):
    """Seed every random number generator this script uses.

    Sets ``PYTHONHASHSEED`` in the environment, seeds Python's ``random``,
    NumPy and PyTorch (CPU and CUDA), and configures cuDNN with
    ``benchmark=False`` and ``deterministic=True``.

    Args:
        seed: Seed value handed to each generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same generators, same order as listed: Python, NumPy, then PyTorch.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


if __name__ == "__main__":

    # Seed first so everything that follows starts from the same RNG state.
    initialize_seed(seed=10)
    args_in = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the Stable Diffusion pipeline from a local checkout and replace its
    # scheduler with a DDIM scheduler built from the pipeline's own config.
    diffusion_path = '../stable-diffusion-2-base'
    diff_model = StableDiffusionPipeline.from_pretrained(diffusion_path)
    diff_model = diff_model.to(device)
    ddim_config = diff_model.scheduler.config
    diff_model.scheduler = DDIMScheduler.from_config(ddim_config)

    # Each entry is (test model names, surrogate model names): every one of the
    # four models is the test model once, with the other three as surrogates.
    model_splits = [
        (['mobile_face'], ['irse50', 'ir152', 'facenet']),
        (['irse50'], ['mobile_face', 'ir152', 'facenet']),
        (['ir152'], ['mobile_face', 'irse50', 'facenet']),
        (['facenet'], ['mobile_face', 'irse50', 'ir152'])
    ]

    banner = "=" * 60

    # Sub-folders "1".."4" of source_dir are processed one after another; the
    # folder number is also used as the target_choice.
    for folder_id in range(1, 5):
        print(banner)
        print(f"当前使用文件夹 {folder_id} 模拟目标 {folder_id}")
        print(banner)

        data_folder = os.path.join(args_in.source_dir, str(folder_id))

        # Resize to a square, convert to a tensor, then normalize each channel
        # with mean 0.5 and std 0.5.
        side = args_in.image_size
        preprocess = transforms.Compose([
            transforms.Resize((side, side)),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ])
        dataset = ImageDataset(data_folder, preprocess)
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

        # target_choice is written onto args_in before the copy is taken, so
        # the per-folder copy inherits it.
        args_in.target_choice = str(folder_id)
        args = copy.deepcopy(args_in)
        args.source_dir = data_folder
        args.dataloader = dataloader
        args.device = device

        for held_out, surrogates in model_splits:
            # Fresh deep copy per split so each run gets its own namespace.
            tmp_args = copy.deepcopy(args)
            tmp_args.test_model_name = held_out
            tmp_args.surrogate_model_names = surrogates

            print(banner)
            print(f"当前测试模型: {tmp_args.test_model_name}")
            print(f"当前替代模型: {tmp_args.surrogate_model_names}")
            print(f"target_choice: {tmp_args.target_choice}")
            print(banner)

            # Run the optimization, then call the attack routine twice: once
            # with protection=False and once with protection=True.
            adversarial_opt = Adversarial_Opt(tmp_args, diff_model)
            adversarial_opt.run()
            for protected in (False, True):
                attack_local_models(tmp_args, protection=protected)