import os
import yaml
import shutil
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
from glob import glob

import torch
import torch.optim as optim
import torchvision.transforms as transforms

from diffusers import DDIMScheduler
from datasets import load_dataset
from diffusers.utils.torch_utils import randn_tensor

from loss.pytorch_ssim import ssim
from main.wmpatch import GTWatermark, GTWatermarkMulti
from main.wmdiffusion import WMDetectStableDiffusionPipeline
from pq_utils import LossProvider, get_init_latent, binary_search_theta, set_seed, get_img_tensor, save_img
from utils.wm.wm_utils import WmProviders

# --------------------------------------------------------------------------------

parser = argparse.ArgumentParser(description="Configuration for image watermarking and generation.")

# NOTE: the whole namespace is forwarded to the watermark provider via
# **vars(args), so option names (argparse dests) must stay stable.

# --- General Settings ---
parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility.')
parser.add_argument('--prompt', type=str, default='', help='Prompt for image generation.')
parser.add_argument('--invert_prompt', type=str, default='', help='Prompt for image inversion.')
parser.add_argument('--steps', type=int, default=50, help='Number of diffusion steps.')
parser.add_argument('--guidance_scale', type=float, default=1.0, help='Guidance scale value.')

# --- Watermark Type ---
parser.add_argument('--wm_type', type=str, default='PQ', choices=['GS', 'TR', 'PQ'], help='Watermark type to use (GS, TR or PQ).')

# --- GS (Guided Steganography) Settings ---
gs_group = parser.add_argument_group('GS Watermark Options')
gs_group.add_argument('--message_width_in_bytes', type=int, default=32, help='[GS] Message width in bytes.')
gs_group.add_argument('--num_replications', type=int, default=64, help='[GS] Number of message replications.')
gs_group.add_argument('--l', type=int, default=1, help='[GS] Parameter l.')
gs_group.add_argument('--offset', type=int, default=0, help='[GS] Offset value.')
gs_group.add_argument('--message', type=str, default=None, help='[GS] Message to embed.')
gs_group.add_argument('--key', type=str, default=None, help='[GS] Encryption key.')
gs_group.add_argument('--nonce', type=str, default=None, help='[GS] Nonce value.')

# --- TR (Tree-Ring) Settings ---
tr_group = parser.add_argument_group('TR Watermark Options')
tr_group.add_argument('--w_seed', type=int, default=42, help='[TR] Seed for watermark pattern generation.')
tr_group.add_argument('--w_channel', type=int, default=3, help='[TR] Channel to inject the watermark into.')
tr_group.add_argument('--w_pattern', type=str, default='ring', help='[TR] Watermark pattern type.')
tr_group.add_argument('--w_mask_shape', type=str, default='circle', help='[TR] Watermark mask shape.')
tr_group.add_argument('--w_radius', type=int, default=10, help='[TR] Radius of the watermark pattern.')
tr_group.add_argument('--w_measurement', type=str, default='l1_complex', help='[TR] Watermark measurement method.')
tr_group.add_argument('--w_injection', type=str, default='complex', help='[TR] Watermark injection method.')
tr_group.add_argument('--w_pattern_const', type=int, default=0, help='[TR] Watermark pattern constant.')

# --- PQ (PQIM) Settings ---
# These are consumed by the watermark provider (not read directly in this script).
pq_group = parser.add_argument_group('PQIM Watermark Options')
pq_group.add_argument('--payload_bits', default=256, type=int)
pq_group.add_argument('--q_step', default=np.pi, type=float)
pq_group.add_argument('--n_bins', default=20, type=int)
pq_group.add_argument('--r_min_ratio', default=0.1, type=float)
pq_group.add_argument('--r_max_ratio', default=0.7, type=float)
pq_group.add_argument('--amp_threshold_percentile', default=100, type=int)

# --- ZoDiac settings ---
zodiac_group = parser.add_argument_group('ZoDiac Method Options')
zodiac_group.add_argument('--save_img', type=str, default="", help='Directory to save output images.')
zodiac_group.add_argument('--model_id', type=str, default="stabilityai/stable-diffusion-2-1-base", help='Stable Diffusion model ID from Hugging Face.')
zodiac_group.add_argument('--empty_prompt', action=argparse.BooleanOptionalAction, default=True, help='Use an empty prompt for unconditional generation.')
zodiac_group.add_argument('--start_latents', type=str, default="init_w", help='Initial latents to start from.')
zodiac_group.add_argument('--lr', type=float, default=0.09, help='Learning rate for optimization.')
zodiac_group.add_argument('--iters', type=int, default=100, help='Number of optimization iterations (not read by the training loop of this script; see --training_steps).')
zodiac_group.add_argument('--loss_weights', type=float, nargs='+', default=[10.0, 0.1, 10.0, 0.1], help='Weights for the different loss components.')
zodiac_group.add_argument('--ssim_threshold', type=float, default=0.92, help='SSIM threshold for loss calculation.')
zodiac_group.add_argument('--detect_threshold', type=float, default=0.9, help='Detection threshold for the watermark.')
zodiac_group.add_argument('--milestones', type=int, nargs='+', default=[30, 60, 90], help='Scheduler milestones for learning rate adjustment.')

zodiac_group.add_argument('--eval_steps', type=int, default=20, help='Evaluate bit accuracy and save an intermediate image every N optimization steps.')
zodiac_group.add_argument('--training_steps', type=int, default=100, help='Number of latent-optimization steps per image.')
zodiac_group.add_argument('--eval_dir', type=str, default='real_eval_pqim', help='Directory for intermediate evaluation images.')

# --- Data & runtime ---
parser.add_argument('--datasets', type=str, default='real_images', help='Directory containing the input images (.png/.jpg/.jpeg).')
parser.add_argument('--percent', type=float, default=1, help='Fraction of the sorted image list to process (1 = all images).')
parser.add_argument('--device', type=str, default='cuda:3', help='Device to run on (e.g., "cuda:0", "cpu").')

args = parser.parse_args()


# --------------------------------------------------------------------------------
### 1. initialization (pipe, dataset, wmprovider, lossprovider, text_embedding, ...) ###

set_seed(args.seed)

scheduler = DDIMScheduler.from_pretrained(args.model_id, subfolder="scheduler")
pipe = WMDetectStableDiffusionPipeline.from_pretrained(args.model_id, scheduler=scheduler).to(args.device)
pipe.set_progress_bar_config(disable=True)

# Collect every image file under args.datasets, including sub-directories.
# glob only treats '**' recursively when it is a whole path component
# ('<dir>/**/<name>'); a pattern like '<dir>/**.png' degenerates to '*'.
# Extensions are compared case-insensitively so e.g. 'photo.JPG' is not skipped.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
imagedirs = sorted(
    path
    for path in glob(os.path.join(args.datasets, '**', '*'), recursive=True)
    if os.path.isfile(path) and path.lower().endswith(IMAGE_EXTENSIONS)
)
# Keep only the first `percent` fraction of the sorted list.
imagedirs = imagedirs[:int(len(imagedirs) * args.percent)]
print(f'Found {len(imagedirs)} image(s) to process under {args.datasets!r}.')

text_embedding = pipe.get_text_embedding(args.prompt)
invert_text_embedding = pipe.get_text_embedding(args.invert_prompt)

lossprovider = LossProvider(
    loss_weights=args.loss_weights,
    device=args.device,
)

# (1, 4, 64, 64) latents correspond to 512x512 images.
# NOTE(review): hard-coded; presumably assumes the SD 2.1-base 8x VAE
# downsampling -- confirm before using a different --model_id.
latent_shape = (1, 4, 64, 64)

# Every CLI option is forwarded to the provider constructor.
wm_provider = WmProviders[args.wm_type].value(
    latent_shape=latent_shape,
    **vars(args),
)

# NOTE(review): reads gt_patch['omega_groups']; confirm that every --wm_type
# provider (GS/TR as well as PQ) exposes this key.
wm_mask = lossprovider.omegagroups2wmmask(
    wm_provider.gt_patch['omega_groups'],
    latent_shape=latent_shape,
)


### 1. initialization (pipe, dataset, wmprovider, lossprovider, text_embedding, ...) ###
# --------------------------------------------------------------------------------
### 2. run for loop ###


def _insert_wm(latents_clean, expected_bits_str=None):
    """Embed the watermark into ``latents_clean`` using the fixed ``args.seed``.

    Returns ``(wm_latents, message_bits_str)``, taken from the provider's
    ``'zT_torch'`` entry and the first ``'message_bits_str_list'`` entry.

    Raises:
        RuntimeError: if ``expected_bits_str`` is given and the embedded
            message differs from it (the message must stay fixed across
            optimization steps).
    """
    results = wm_provider.get_wm_latents(
        latents_clean=latents_clean,
        seed=args.seed,
    )
    bits_str = results['message_bits_str_list'][0]
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if expected_bits_str is not None and bits_str != expected_bits_str:
        raise RuntimeError(
            f'gt_message_bits_str is not consistent, which should {expected_bits_str}, but {bits_str}'
        )
    return results['zT_torch'], bits_str


def _generate(latents):
    """Run the diffusion pipeline starting from ``latents``; return ``.images``."""
    return pipe(
        prompt=args.prompt,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.steps,
        output_type='tensor',
        use_trainable_latents=True,
        init_latents=latents,
    ).images


def _detect(img_tensor):
    """Invert ``img_tensor`` to initial latents and return the provider's accuracy dict."""
    pred_wm_latents = get_init_latent(
        img_tensor=img_tensor.detach(),  # evaluation only: keep it out of the autograd graph
        pipe=pipe,
        text_embeddings=invert_text_embedding,
    )
    return wm_provider.get_accuracies(pred_wm_latents)


# Output directories only need to be created once, not per image.
os.makedirs(args.eval_dir, exist_ok=True)
if args.save_img:
    os.makedirs(args.save_img, exist_ok=True)

for imagedir in tqdm(imagedirs, total=len(imagedirs)):
    # splitext keeps dots inside the stem ('a.b.png' -> 'a.b').
    imagename = os.path.splitext(os.path.basename(imagedir))[0]

    ### init wm latents ###
    gt_img_tensor = get_img_tensor(imagedir, args.device)  # gt_image

    gt_init_latents_approx = get_init_latent(
        img_tensor=gt_img_tensor,
        pipe=pipe,
        text_embeddings=invert_text_embedding,
    )
    # gt_latent: watermarked inversion of the ground-truth image; also fixes
    # the reference message that every later insertion must reproduce.
    gt_wm_init_latents_approx, gt_message_bits_str = _insert_wm(gt_init_latents_approx)

    init_latents = gt_wm_init_latents_approx.detach().clone()  # pred_latents (variable)
    ### init wm latents ###

    ### init training tools ###
    init_latents.requires_grad = True
    optimizer = optim.Adam([init_latents], lr=args.lr)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=0.3)
    ### init training tools ###

    for step in tqdm(range(args.training_steps), total=args.training_steps, leave=False):
        wm_init_latents, _ = _insert_wm(init_latents, gt_message_bits_str)
        pred_img_tensor = _generate(wm_init_latents)  # pred_image (variable)

        loss = lossprovider(
            pred_img_tensor=pred_img_tensor,
            gt_img_tensor=gt_img_tensor,
            pred_inverted_wm_latents=wm_init_latents,
            gt_inverted_wm_latents=gt_wm_init_latents_approx,
            wm_mask=wm_mask,
        )

        ### (option) eval acc & save ###
        # A falsy eval_steps (None or 0) disables evaluation instead of dividing by zero.
        if args.eval_steps and (step + 1) % args.eval_steps == 0:
            bit_acc = _detect(pred_img_tensor)["bit_accuracies"][0]
            save_img(
                os.path.join(args.eval_dir, f'{imagename}-{step+1}-{int(bit_acc*100)}.png'),
                pred_img_tensor.detach(),
                pipe,
            )
            # tqdm.write keeps the nested progress bars intact.
            tqdm.write(f'{step+1} step: [Loss] {loss.item():.2f} | [ACC] {bit_acc:.2f}')
        else:
            tqdm.write(f'{step+1} step: [Loss] {loss.item():.2f}')
        ### (option) eval acc & save ###

        ### update ###
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        ### update ###

    ### insert final wm & generate final img ###
    # No backward pass follows, so skip building the (large) autograd graph.
    with torch.no_grad():
        wm_init_latents, _ = _insert_wm(init_latents, gt_message_bits_str)
        pred_img_tensor = _generate(wm_init_latents)  # pred_image (variable)

    ### detect final wm ###
    bit_acc = _detect(pred_img_tensor)["bit_accuracies"][0]
    if args.save_img:
        save_img(os.path.join(args.save_img, f'{imagename}.png'), pred_img_tensor, pipe)
    tqdm.write(f'FINAL [{imagename}]: [ACC] {bit_acc:.2f}')