import os
import torch
import argparse
import numpy as np
from PIL import Image
from glob import glob
from tqdm import tqdm
import shutil
from torchvision.transforms.functional import pil_to_tensor

from pq_utils import (
    set_seed,
    # apply_single_distortion,
    get_init_latent,
    get_img_tensor,
)
from torchvision import transforms
from PIL import Image, ImageFilter

from diffusers import DDIMScheduler
from utils.wm.wm_utils import WmProviders
from main.wmdiffusion import WMDetectStableDiffusionPipeline

from sklearn import metrics

from main.attdiffusion import ReSDPipeline
from main.wmattacker import (
    DiffWMAttacker,
    VAEWMAttacker,
    JPEGAttacker,
    RotateAttacker,
    BrightnessAttacker,
    ContrastAttacker,
    GaussianNoiseAttacker,
    GaussianBlurAttacker,
    BM3DAttacker,
)

def get_metrics(
    w_latents_list,
    nw_latents_list,
    wm_provider,
):
    """Score watermark detection on paired watermarked / clean latents.

    For each pair, the first entry of
    ``wm_provider.get_accuracies(latents)["bit_accuracies"]`` is taken as the
    detection score. Watermarked scores are labelled 1 and non-watermarked
    scores 0, and a ROC curve is computed over all scores.

    Args:
        w_latents_list: latents recovered from watermarked images.
        nw_latents_list: latents recovered from non-watermarked images. The
            two lists are walked in lockstep with ``zip``, so extra entries
            in the longer list are ignored.
        wm_provider: object exposing ``get_accuracies(latents)`` that returns
            a mapping with a ``"bit_accuracies"`` sequence.

    Returns:
        A tuple ``(low, mean_w_acc)``: ``low`` is the TPR at the last ROC
        point whose FPR is below 0.01 (``0.0`` if there is no such point) and
        ``mean_w_acc`` is the mean score over the watermarked samples.

    Raises:
        ValueError: if fewer than two classes are present, which here means
            no (watermarked, non-watermarked) pair was scored.
    """
    w_accs = []
    nw_accs = []
    print('Calculating metrics...')
    for w_latents, nw_latents in tqdm(zip(w_latents_list, nw_latents_list), total=len(w_latents_list)):
        w_accs.append(wm_provider.get_accuracies(w_latents)["bit_accuracies"][0])
        nw_accs.append(wm_provider.get_accuracies(nw_latents)["bit_accuracies"][0])

    preds = nw_accs + w_accs
    t_labels = [0] * len(nw_accs) + [1] * len(w_accs)

    # Explicit check instead of `assert` so it still runs under `python -O`;
    # it also guards the division below against empty input.
    if len(set(t_labels)) < 2:
        raise ValueError("Could not compute AUC (only one class exists).")

    fpr, tpr, _ = metrics.roc_curve(t_labels, preds, pos_label=1)
    # roc_curve returns fpr in increasing order, so the last index with
    # fpr < 0.01 is the operating point closest to the 1% FPR budget.
    valid_indices = np.where(fpr < 0.01)[0]
    low = tpr[valid_indices[-1]] if len(valid_indices) > 0 else 0.0

    mean_w_acc = sum(w_accs) / len(w_accs)

    return low, mean_w_acc


# Attack identifiers accepted by --attack_name. Keep this in sync with the keys
# of the `attackers` dict built further down in this script; validating here
# makes a typo fail at argument parsing instead of after the models are loaded.
ATTACK_NAME_CHOICES = (
    'diff_attacker_60',
    'cheng2020-anchor_3',
    'bmshj2018-factorized_3',
    'jpeg_attacker_50',
    'rotate_90',
    'brightness_0.5',
    'contrast_0.5',
    'Gaussian_noise',
    'Gaussian_blur',
    'bm3d',
)

parser = argparse.ArgumentParser(description="Configuration for image watermarking and generation.")
parser.add_argument('--wm_type', type=str, default='GS', choices=['GS', 'TR', 'PQ'], help='Watermark type to use (GS, TR or PQ).')

# Options marked [GS] in their help text.
parser.add_argument('--message_width_in_bytes', type=int, default=32, help='[GS] Message width in bytes.')
parser.add_argument('--num_replications', type=int, default=64, help='[GS] Number of message replications.')
parser.add_argument('--l', type=int, default=1, help='[GS] Parameter l.')
parser.add_argument('--offset', type=int, default=0, help='[GS] Offset value.')
parser.add_argument('--message', type=str, default=None, help='[GS] Message to embed.')
parser.add_argument('--key', type=str, default=None, help='[GS] Encryption key.')
parser.add_argument('--nonce', type=str, default=None, help='[GS] Nonce value.')

parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility.')
parser.add_argument('--device', type=str, default='cuda:2', help='Device to run on (e.g., "cuda:0", "cpu").')
parser.add_argument(
    '--attack_name',
    type=str,
    default='diff_attacker_60',
    choices=ATTACK_NAME_CHOICES,
    help='Attack applied to the images before watermark detection.',
)
args = parser.parse_args()

# Seed before anything else is constructed.
set_seed(args.seed)

#####################################################################################################
# Half-precision pipeline handed to DiffWMAttacker below.
att_pipe = ReSDPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=torch.float16,
    revision="fp16",
)
att_pipe.set_progress_bar_config(disable=True)
att_pipe.to(args.device)

# Settings shared by both VAEWMAttacker entries.
_vae_kwargs = dict(quality=3, metric='mse', device=args.device)

# Registry mapping an attack identifier to an attacker instance. Every
# attacker is built up front, in the order listed here.
attackers = {}
attackers['diff_attacker_60'] = DiffWMAttacker(att_pipe, batch_size=5, noise_step=60, captions={})
attackers['cheng2020-anchor_3'] = VAEWMAttacker('cheng2020-anchor', **_vae_kwargs)
attackers['bmshj2018-factorized_3'] = VAEWMAttacker('bmshj2018-factorized', **_vae_kwargs)
attackers['jpeg_attacker_50'] = JPEGAttacker(quality=50)
attackers['rotate_90'] = RotateAttacker(degree=90)
attackers['brightness_0.5'] = BrightnessAttacker(brightness=0.5)
attackers['contrast_0.5'] = ContrastAttacker(contrast=0.5)
attackers['Gaussian_noise'] = GaussianNoiseAttacker(std=0.05)
attackers['Gaussian_blur'] = GaussianBlurAttacker(kernel_size=5, sigma=1)
attackers['bm3d'] = BM3DAttacker()
#####################################################################################################



device = args.device
source_dir = 'gen_gs'
gt_dir = 'output_images_wo_wm'
# Without recursive=True, `**` behaves like `*`: only PNGs directly inside
# each directory are matched. Sorting puts both lists in the same order so
# they can be paired by position.
source_files = sorted(glob(f'{source_dir}/**.png'))
gt_files = sorted(glob(f'{gt_dir}/**.png'))

# Last underscore-separated token of the source directory name.
method = source_dir.split('_')[-1]

# source_files = source_files[:5]
# gt_files = gt_files[:5]

### verify sorting ###
# Explicit raises instead of `assert` so the checks still run under `python -O`.
if len(source_files) != len(gt_files):
    raise ValueError(
        f'length of source_files and gt_files mismatched, {len(source_files)} != {len(gt_files)}'
    )

# Fail fast on an empty dataset: otherwise the detection pipeline is loaded
# first and the run only fails later, in get_metrics, with an unrelated message.
if not source_files:
    raise FileNotFoundError(f'no PNG images found in {source_dir!r} and {gt_dir!r}')

for source_file, gt_file in zip(source_files, gt_files):
    # The part of the source basename before the first '-' must equal the
    # part of the ground-truth basename before the first '.'.
    source_name = os.path.basename(source_file).split('-')[0]
    gt_name = os.path.basename(gt_file).split('.')[0]
    if source_name != gt_name:
        raise ValueError(f'source and gt sorting dismatched, {source_name} != {gt_name}')
### verify sorting ###


# Only the attack chosen on the command line is evaluated.
attack_names = [args.attack_name]

# The scheduler and the detection pipeline are loaded from the same repository.
_SD_BASE_REPO = 'stabilityai/stable-diffusion-2-1-base'
scheduler = DDIMScheduler.from_pretrained(_SD_BASE_REPO, subfolder="scheduler")
pipe = WMDetectStableDiffusionPipeline.from_pretrained(_SD_BASE_REPO, scheduler=scheduler)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=True)

# Embedding of the empty prompt, reused for every image passed to get_init_latent.
invert_text_embedding = pipe.get_text_embedding('')

latent_shape = (1, 4, 64, 64)
# The provider class is looked up by watermark type; every parsed CLI option
# is forwarded to it as a keyword argument.
_provider_cls = WmProviders[args.wm_type].value
wm_provider = _provider_cls(latent_shape=latent_shape, **vars(args))

# Output directories for the attacked images: original directory name plus a
# suffix built from the method and the attack name.
_save_suffix = f'_{method}_{args.attack_name}'
source_save_dir = source_dir + _save_suffix
gt_save_dir = gt_dir + _save_suffix
os.makedirs(source_save_dir, exist_ok=True)
os.makedirs(gt_save_dir, exist_ok=True)
def _load_image_tensor(path, target_device):
    """Read the image at *path* into a tensor on *target_device*.

    Pixel values are divided by 255 and a leading dimension is added. The
    image file is closed before returning.
    """
    with Image.open(path) as img:
        return (pil_to_tensor(img) / 255).unsqueeze(0).to(target_device)


# Attack every image pair, recover latents from the attacked images and report
# the detection metrics. The attacked images are temporary, so both output
# directories are removed even if the run fails part-way through.
try:
    for attack_name in attack_names:
        print(f'Running {attack_name} ...')
        attacker = attackers[attack_name]

        w_latents = []
        nw_latents = []
        for gt_file, source_file in tqdm(zip(gt_files, source_files), total=len(source_files)):
            att_wm_path = os.path.join(source_save_dir, os.path.basename(source_file))
            att_gt_path = os.path.join(gt_save_dir, os.path.basename(gt_file))

            # The attacker reads each input path and writes the attacked image
            # to the matching output path.
            attacker.attack([source_file], [att_wm_path])
            attacker.attack([gt_file], [att_gt_path])

            att_wm_tensor = _load_image_tensor(att_wm_path, device)
            att_gt_tensor = _load_image_tensor(att_gt_path, device)

            att_wm_latents = get_init_latent(
                img_tensor=att_wm_tensor,
                pipe=pipe,
                text_embeddings=invert_text_embedding,
            )
            att_gt_latents = get_init_latent(
                img_tensor=att_gt_tensor,
                pipe=pipe,
                text_embeddings=invert_text_embedding,
            )

            w_latents.append(att_wm_latents)
            nw_latents.append(att_gt_latents)

        # `tpr` is the first value returned by get_metrics, `mean_w_acc` the second.
        tpr, mean_w_acc = get_metrics(w_latents, nw_latents, wm_provider)

        print('-'*50)
        print(f'[{attack_name}] | TPR: {tpr} | ACC: {mean_w_acc}')
        print('-'*50)
finally:
    shutil.rmtree(source_save_dir)
    shutil.rmtree(gt_save_dir)

print(0)

