import os
import torch
import argparse
import numpy as np
from PIL import Image
from glob import glob
from tqdm import tqdm
from torchvision.transforms.functional import pil_to_tensor

from pq_utils import (
    set_seed,
    apply_single_distortion,
    get_init_latent,
    get_img_tensor,
)

from diffusers import DDIMScheduler
from utils.wm.wm_utils import WmProviders
from main.wmdiffusion import WMDetectStableDiffusionPipeline
from sklearn import metrics

from main.attdiffusion import ReSDPipeline
from main.wmpatch import GTWatermark
from main.wmattacker import (
    DiffWMAttacker,
    VAEWMAttacker,
    JPEGAttacker,
    RotateAttacker,
    BrightnessAttacker,
    ContrastAttacker,
    GaussianNoiseAttacker,
    GaussianBlurAttacker,
    BM3DAttacker,
)


def get_metrics(
    w_latents_list,
    nw_latents_list,
    pipe,
    wm_pipe,
    text_embeddings,
    device=None,
):
    """Score watermarked vs. clean inputs and summarise detection quality.

    Every item of both lists is passed to ``watermark_prob`` and scored as
    ``1 - watermark_prob(...)``. Despite the parameter names, the caller in
    this script passes the output of ``apply_single_distortion`` (attacked
    images), not latents.

    Args:
        w_latents_list: inputs that carry the watermark (positive class).
        nw_latents_list: inputs without the watermark (negative class); must
            have the same length as ``w_latents_list``.
        pipe: detection pipeline, forwarded to ``watermark_prob`` as
            ``dect_pipe``.
        wm_pipe: watermark object, forwarded to ``watermark_prob``.
        text_embeddings: prompt embeddings used for the inversion.
        device: device forwarded to ``watermark_prob``. When ``None`` (the
            default) the script-level ``device`` global is used, which is
            what this function always did.

    Returns:
        ``(low, mean_w_acc)``: the TPR at the last ROC point whose FPR is
        below 1% (0.0 if there is none), and the mean score over the
        watermarked inputs.

    Raises:
        ValueError: if the two lists differ in length or are empty.
    """
    # Validate up front (not with ``assert``, which is stripped under -O):
    # zip() would otherwise silently truncate mismatched lists, and empty
    # input would only fail later with a ZeroDivisionError.
    if len(w_latents_list) != len(nw_latents_list):
        raise ValueError(
            "w_latents_list and nw_latents_list must have the same length, "
            f"got {len(w_latents_list)} != {len(nw_latents_list)}"
        )
    if len(w_latents_list) == 0:
        raise ValueError("Could not compute metrics: no samples were provided.")

    if device is None:
        # Backward compatibility: previously this function read the
        # script-level ``device`` global (set from --device) implicitly.
        device = globals()['device']

    w_accs = []
    nw_accs = []
    print('Calculating metrics...')
    for w_latents, nw_latents in tqdm(zip(w_latents_list, nw_latents_list), total=len(w_latents_list)):
        # watermark_prob (tree_ring=True by default) returns
        # wm_pipe.tree_ring_p_value(...); ``1 - ...`` presumably turns that
        # p-value into a score where higher means "watermarked".
        w_accs.append(1 - watermark_prob(w_latents, pipe, wm_pipe, text_embeddings, device=device))
        nw_accs.append(1 - watermark_prob(nw_latents, pipe, wm_pipe, text_embeddings, device=device))

    preds = nw_accs + w_accs
    t_labels = [0] * len(nw_accs) + [1] * len(w_accs)

    fpr, tpr, _ = metrics.roc_curve(t_labels, preds, pos_label=1)
    # roc_curve returns fpr in increasing order, so the last index with
    # fpr < 1% gives the highest TPR achievable at that FPR budget.
    valid_indices = np.where(fpr < 0.01)[0]
    low = tpr[valid_indices[-1]] if len(valid_indices) > 0 else 0.0

    mean_w_acc = sum(w_accs) / len(w_accs)
    return low, mean_w_acc


def watermark_prob(img, dect_pipe, wm_pipe, text_embeddings, tree_ring=True, device=torch.device('cuda')):
    """Invert an image back to its diffusion latent and test it for the watermark.

    Args:
        img: a file path (``str`` or ``os.PathLike``), a PIL image, or a
            ``torch.Tensor``. Paths and PIL images are converted to RGB,
            scaled by 1/255, given a leading batch dimension and moved to
            ``device``. Tensors are passed through unchanged, so presumably
            they must already be batched, scaled and on the pipeline's
            device — TODO confirm with callers.
        dect_pipe: pipeline providing ``get_image_latents`` and
            ``forward_diffusion`` (used to invert the image).
        wm_pipe: watermark object providing ``tree_ring_p_value`` and
            ``one_minus_p_value``.
        text_embeddings: prompt embeddings used during the inversion.
        tree_ring: if True use ``wm_pipe.tree_ring_p_value``, otherwise
            ``wm_pipe.one_minus_p_value``.
        device: device for images loaded from a path or PIL image.

    Returns:
        Whatever the selected ``wm_pipe`` method returns for the reversed
        latents.

    Raises:
        TypeError: if ``img`` is not a path, PIL image or tensor.
    """
    # Tensor is checked first; the three accepted types are disjoint, so the
    # order does not change which branch is taken.
    if isinstance(img, torch.Tensor):
        img_tensor = img
    elif isinstance(img, (str, os.PathLike)):
        # ``with`` closes the file handle; convert() loads the pixel data
        # before the file is closed.
        with Image.open(img) as pil_img:
            img_tensor = pil_to_tensor(pil_img.convert("RGB")) / 255
        img_tensor = img_tensor.unsqueeze(0).to(device)
    elif isinstance(img, Image.Image):
        # Normalise the mode like the path branch does. Otherwise RGBA/L/P
        # images would yield a tensor with the wrong number of channels.
        img_tensor = pil_to_tensor(img.convert("RGB")) / 255
        img_tensor = img_tensor.unsqueeze(0).to(device)
    else:
        raise TypeError(
            f"img must be a path, PIL.Image.Image or torch.Tensor, got {type(img).__name__}"
        )

    img_latents = dect_pipe.get_image_latents(img_tensor, sample=False)
    # Deterministic inversion with the given (typically empty-prompt)
    # embeddings and no classifier-free guidance.
    reversed_latents = dect_pipe.forward_diffusion(
        latents=img_latents,
        text_embeddings=text_embeddings,
        guidance_scale=1.0,
        num_inference_steps=50,
    )
    if tree_ring:
        return wm_pipe.tree_ring_p_value(reversed_latents)
    return wm_pipe.one_minus_p_value(reversed_latents)


parser = argparse.ArgumentParser(description="Configuration for image watermarking and generation.")
# NOTE(review): --wm_type is not read by the active code in this script; only
# the commented-out WmProviders call further down used it (via vars(args)).
parser.add_argument('--wm_type', type=str, default='PQ', choices=['GS', 'TR', 'PQ'], help='Watermark type to use (GS, TR or PQ).')
parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility.')
parser.add_argument('--device', type=str, default='cuda:0', help='Device to run on (e.g., "cuda:0", "cpu").')
# The accepted names are defined by pq_utils.apply_single_distortion. The
# list in the help text mirrors the alternatives this script was used with.
parser.add_argument(
    '--attack_name',
    type=str,
    default='JPEG',
    help='Distortion applied to both watermarked and clean images before detection, '
         'e.g. JPEG, GaussianBlur, GaussianNoise, Brightness, Resize or SPNoise (default: JPEG).',
)
args = parser.parse_args()

# Run configuration, mirroring a ZoDiac-style config (``method``). This
# evaluation script reads only part of it, e.g. the ``w_*`` watermark key
# parameters used when building ``wm_pipe``.
cfgs = dict(
    method="ZoDiac",
    model_id="stabilityai/stable-diffusion-2-1-base",
    gen_seed=0,
    empty_prompt=True,

    # Watermark key parameters.
    w_type="single",
    w_channel=3,
    w_radius=10,
    w_seed=10,

    # Optimisation settings (not read by this evaluation script).
    start_latents="init_w",
    iters=100,
    save_iters=[100],
    loss_weights=[10.0, 0.1, 1.0, 0.0],
    ssim_threshold=0.92,
    detect_threshold=0.9,

    # Data selection.
    datasets="output_images_wo_wm",
    percent=1.0,
)

set_seed(args.seed)

device = args.device
# Watermarked images (source) and the matching clean originals (gt).
source_dir = 'gen_zod'
gt_dir = 'output_images_wo_wm'
# Without recursive=True, glob treats ``**`` like ``*`` (it is only special as
# a whole path component). The previous ``**.png`` pattern therefore already
# meant "PNGs directly inside the directory"; spell it plainly.
source_files = sorted(glob(f'{source_dir}/*.png'))
gt_files = sorted(glob(f'{gt_dir}/*.png'))

# source_files = source_files[:10]
# gt_files = gt_files[:10]

### verify sorting ###
# Fail fast, before the slow model loading, when there is nothing to evaluate.
# Explicit raises are used instead of ``assert``, which is stripped under -O.
if not source_files:
    raise FileNotFoundError(f'no .png files found in {source_dir!r}')
if len(source_files) != len(gt_files):
    raise ValueError(
        f'length of source_files and gt_files mismatched, {len(source_files)} != {len(gt_files)}'
    )

# Pairs are matched purely by sort order. Check that each source file
# '<id>-....png' lines up with the gt file '<id>.png'.
for source_file, gt_file in zip(source_files, gt_files):
    source_name = os.path.basename(source_file).split('-')[0]
    gt_name = os.path.basename(gt_file).split('.')[0]
    if source_name != gt_name:
        raise ValueError(f'source and gt sorting mismatched, {source_name} != {gt_name}')
### verify sorting ###


# One entry per evaluation run. Each entry is itself a list of attack names,
# and that inner list is what gets passed to ``apply_single_distortion``.
# NOTE(review): confirm apply_single_distortion expects a list here rather
# than a single name string (the printed label is e.g. "['JPEG']").
attack_names = [[args.attack_name]]

# Strength settings for each supported distortion; presumably looked up by
# key inside pq_utils.apply_single_distortion — confirm there.
attack_params = dict(
    jpeg_ratio=25,
    gaussian_blur_r=5,
    gaussian_std=0.1,
    brightness_factor=2,
    resize_ratio=0.5,
    sp_prob=0.2,
)

# Detection pipeline: a DDIM scheduler plus the SD pipeline that
# ``watermark_prob`` uses to invert images back to latents. Both are loaded
# from ``cfgs['model_id']`` so the checkpoint is configured in one place
# (it used to be hard-coded here, making the config entry ineffective).
scheduler = DDIMScheduler.from_pretrained(
    cfgs['model_id'],
    subfolder="scheduler",
)
pipe = WMDetectStableDiffusionPipeline.from_pretrained(
    cfgs['model_id'],
    scheduler=scheduler,
).to(device)
pipe.set_progress_bar_config(disable=True)

# Watermark object queried by ``watermark_prob``. Seeding the generator with
# ``w_seed`` presumably regenerates the same watermark pattern used at
# embedding time — confirm against main.wmpatch.GTWatermark.
wm_pipe = GTWatermark(
    torch.device(device),
    w_channel=cfgs['w_channel'],
    w_radius=cfgs['w_radius'],
    generator=torch.Generator(device).manual_seed(cfgs['w_seed']),
)

# NOTE(review): not read anywhere in this script; kept in case other code
# relies on the global.
invert_text_embedding = pipe.get_text_embedding('')

# Only referenced by the commented-out WmProviders call below.
latent_shape = (1, 4, 64, 64)
# Empty-prompt embedding passed through get_metrics to the inversion in
# ``watermark_prob``.
tester_prompt = ''
text_embeddings = pipe.get_text_embedding(tester_prompt)
# wm_provider = WmProviders[args.wm_type].value(
#     latent_shape=latent_shape,
#     **vars(args),
# )

for attack_name in attack_names:
    print(f'Running {attack_name} ...')

    # Apply the same distortion to every watermarked image and to its clean
    # counterpart. Despite their names, these lists hold the attacked images
    # returned by ``apply_single_distortion``, not latents.
    w_latents, nw_latents = [], []
    file_pairs = zip(gt_files, source_files)
    for gt_file, source_file in tqdm(file_pairs, total=len(source_files)):
        wm_image = Image.open(source_file)
        gt_image = Image.open(gt_file)
        w_latents.append(apply_single_distortion(wm_image, attack_name, attack_params))
        nw_latents.append(apply_single_distortion(gt_image, attack_name, attack_params))

    # ``tpr`` is the TPR below 1% FPR; ``mean_w_acc`` is the mean score over
    # the watermarked images (both computed by ``get_metrics``).
    tpr, mean_w_acc = get_metrics(w_latents, nw_latents, pipe, wm_pipe, text_embeddings)

    rule = '-' * 50
    print(rule)
    print(f'[{attack_name}] | TPR: {tpr} | ACC: {mean_w_acc}')
    print(rule)

# Completion marker printed once every attack has been evaluated.
print(0)

