import os
import yaml
import typing
import shutil
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
from glob import glob

import torch
import torch.optim as optim
import torchvision.transforms as transforms

from diffusers import DDIMScheduler
from datasets import load_dataset
from diffusers.utils.torch_utils import randn_tensor

from loss.pytorch_ssim import ssim
from main.wmpatch import GTWatermark, GTWatermarkMulti
from main.wmdiffusion import WMDetectStableDiffusionPipeline
from pq_utils import LossProvider, get_init_latent, binary_search_theta, set_seed, get_img_tensor, save_img
from utils.wm.wm_utils import WmProviders
from utils.image_utils import torch_to_PIL


from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
from cryptography.hazmat.backends import default_backend
import numpy as np
from scipy.stats import norm


def custom_get_wm_latents(self, latents_clean, **kwargs) -> typing.Dict[str, typing.Any]:
    """
    Build watermarked latents (and their barcodes) that stay close to ``latents_clean``.

    Intended to be bound onto a GS watermark provider instance as its
    ``get_wm_latents`` method. For every batch entry it:
      1. replicates the message ``self.num_replications`` times into a 0/1 barcode,
      2. encrypts the replicated message bytes with ChaCha20 (key/nonce per entry),
      3. splits the ciphertext bits into ``self.l``-bit symbols ``y`` (one per latent
         element). Instead of drawing the in-bin offset ``u`` at random, it recovers
         ``u`` from ``latents_clean``. The output element therefore lies in CDF bin ``y``
         at the point closest to the clean value.
    All tensor operations are differentiable w.r.t. ``latents_clean``, except where
    the clamps saturate (zero gradient there).

    @param latents_clean: required tensor indexed as ``latents_clean[i]`` for the i-th
        entry of this batch (0-based, independent of ``self.offset``). Each entry must
        flatten to as many elements as there are ``self.l``-bit symbols.
    @param kwargs: accepted for interface compatibility with the provider; unused here.
    @return: dict with keys "zT_torch", "zT_PIL", "zT", "barcodes_torch",
        "barcodes_PIL", "barcodes", "message_bits_str_list".
    @raise ValueError: if ``latents_clean`` is None.
    """
    if latents_clean is None:
        raise ValueError("custom_get_wm_latents requires `latents_clean`; got None.")

    levels = 2 ** self.l  # number of CDF bins per latent element
    # Loop-invariant constant, hoisted out of the batch loop.
    sqrt_2 = torch.sqrt(torch.tensor(2.0, device=self.device, dtype=self.dtype))

    latents_torch = []
    barcodes_torch = []
    message_bits_str_list = []

    # b_idx indexes the provider's message/key/nonce tables, which are shifted by
    # self.offset. local_idx indexes the caller's batch, which always starts at 0.
    for local_idx, b_idx in enumerate(range(self.offset, self.offset + self.batch_size)):
        # get the message, key and nonce
        message_bytes = self.messages[b_idx]
        key_bytes = self.keys[b_idx]
        nonce_bytes = self.nonces[b_idx]
        message_bits_str_list.append(''.join(format(byte, '08b') for byte in message_bytes))

        # ------------------------- STEP 1: Replicate the message to get a barcode -------------------------
        replicated_message_bytes = message_bytes * self.num_replications
        replicated_message_bits_str = ''.join(format(byte, '08b') for byte in replicated_message_bytes)
        barcode_ints_2d_np = np.array([int(b, 2) for b in replicated_message_bits_str]).reshape(
            self.num_replications, self.message_width_in_bits
        )
        barcodes_torch.append(torch.tensor(barcode_ints_2d_np, dtype=torch.uint8, device=self.device))

        # --------------------------------- STEP 2: Encrypt the barcode ---------------------------------
        # ChaCha20 is a stream cipher, hence mode=None.
        cipher = Cipher(algorithms.ChaCha20(key_bytes, nonce_bytes), mode=None, backend=default_backend())
        encryptor = cipher.encryptor()
        encrypted_bytes = encryptor.update(replicated_message_bytes) + encryptor.finalize()
        encrypted_bits_str = ''.join(format(byte, '08b') for byte in encrypted_bytes)

        # ------------- STEP 3: Embed the encrypted message (vectorized & differentiable) -------------
        # 3a. Split the ciphertext bits into l-bit integer symbols y, one per latent element.
        #     Trailing bits that do not fill a whole symbol are dropped.
        num_pixels = len(encrypted_bits_str) // self.l
        y_list = [int(encrypted_bits_str[j:j + self.l], 2) for j in range(0, num_pixels * self.l, self.l)]
        y = torch.tensor(y_list, dtype=torch.long, device=self.device)

        target_latent_flat = latents_clean[local_idx].flatten().to(self.device)

        # 3b. Standard-normal CDF of the clean latent (torch equivalent of scipy's norm.cdf).
        cdf_val = 0.5 * (1 + torch.erf(target_latent_flat / sqrt_2))

        # 3c. Invert for the in-bin offset u that reproduces the clean value inside bin y.
        #     Values whose CDF falls in another bin are snapped to the nearest edge of bin y.
        u = torch.clamp(cdf_val * levels - y, 0.0, 1.0)

        # 3d. Inverse standard-normal CDF (torch equivalent of norm.ppf).
        #     Keep p strictly inside (0, 1) so erfinv does not return +/-inf.
        p = torch.clamp((u + y) / levels, 1e-7, 1.0 - 1e-7)
        pixel_values = sqrt_2 * torch.special.erfinv(2 * p - 1)

        latents_torch.append(pixel_values.reshape(self.latent_shape[1:]))

    # finalize
    latents_torch = torch.stack(latents_torch, dim=0)
    barcodes_torch = torch.stack(barcodes_torch, dim=0)

    # PIL conversion and result dictionary keep the provider's original output format.
    latents_PIL = torch_to_PIL(latents_torch)
    barcodes_PIL = torch_to_PIL(barcodes_torch)

    results_dict = {"zT_torch": latents_torch,
                    "zT_PIL": latents_PIL,
                    "zT": latents_PIL,
                    "barcodes_torch": barcodes_torch,
                    "barcodes_PIL": barcodes_PIL,
                    "barcodes": barcodes_PIL,
                    "message_bits_str_list": message_bits_str_list
                    }

    return results_dict

# --------------------------------------------------------------------------------

parser = argparse.ArgumentParser(description="Configuration for image watermarking and generation.")

# --- General Settings ---
parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility.')
parser.add_argument('--prompt', type=str, default='', help='Prompt for image generation.')
parser.add_argument('--invert_prompt', type=str, default='', help='Prompt for image inversion.')
parser.add_argument('--steps', type=int, default=50, help='Number of diffusion steps.')
parser.add_argument('--guidance_scale', type=float, default=1.0, help='Guidance scale value.')

# --- Watermark Type ---
parser.add_argument('--wm_type', type=str, default='GS', choices=['GS', 'TR', 'PQ'], help='Watermark type to use (GS, TR or PQ).')

# --- GS Settings (replicate -> ChaCha20-encrypt -> embed in Gaussian CDF bins) ---
# With the defaults, 32 bytes * 64 replications * 8 bits = 16384 = 4*64*64 latent
# elements at l=1, i.e. exactly one (1, 4, 64, 64) latent.
gs_group = parser.add_argument_group('GS Watermark Options')
gs_group.add_argument('--message_width_in_bytes', type=int, default=32, help='[GS] Message width in bytes.')
gs_group.add_argument('--num_replications', type=int, default=64, help='[GS] Number of message replications.')
gs_group.add_argument('--l', type=int, default=1, help='[GS] Number of encrypted bits embedded per latent element.')
gs_group.add_argument('--offset', type=int, default=0, help='[GS] Offset value.')
gs_group.add_argument('--message', type=str, default=None, help='[GS] Message to embed.')
gs_group.add_argument('--key', type=str, default=None, help='[GS] Encryption key.')
gs_group.add_argument('--nonce', type=str, default=None, help='[GS] Nonce value.')

# --- TR (Tree-Ring) Settings ---
tr_group = parser.add_argument_group('TR Watermark Options')
tr_group.add_argument('--w_seed', type=int, default=42, help='[TR] Seed for watermark pattern generation.')
tr_group.add_argument('--w_channel', type=int, default=3, help='[TR] Channel to inject the watermark into.')
tr_group.add_argument('--w_pattern', type=str, default='ring', help='[TR] Watermark pattern type.')
tr_group.add_argument('--w_mask_shape', type=str, default='circle', help='[TR] Watermark mask shape.')
tr_group.add_argument('--w_radius', type=int, default=10, help='[TR] Radius of the watermark pattern.')
tr_group.add_argument('--w_measurement', type=str, default='l1_complex', help='[TR] Watermark measurement method.')
tr_group.add_argument('--w_injection', type=str, default='complex', help='[TR] Watermark injection method.')
tr_group.add_argument('--w_pattern_const', type=int, default=0, help='[TR] Watermark pattern constant.')

# --- PQ (PQIM) Settings ---
# These options are consumed by the watermark provider (all args are forwarded via vars(args)).
pq_group = parser.add_argument_group('PQIM Watermark Options')
pq_group.add_argument('--payload_bits', default=256, type=int)
pq_group.add_argument('--q_step', default=np.pi, type=float)
pq_group.add_argument('--n_bins', default=20, type=int)
pq_group.add_argument('--r_min_ratio', default=0.1, type=float)
pq_group.add_argument('--r_max_ratio', default=0.7, type=float)
pq_group.add_argument('--amp_threshold_percentile', default=100, type=int)

# --- zodiac settings ---
# NOTE(review): the optimisation loop in this script iterates --training_steps, not --iters;
# --iters, --save_img, --start_latents, --ssim_threshold and --detect_threshold are only
# forwarded through vars(args) -- confirm whether the provider actually reads them.
zodiac_group = parser.add_argument_group('ZoDiac Method Options')
zodiac_group.add_argument('--save_img', type=str, default="", help='Directory to save output images.')
zodiac_group.add_argument('--model_id', type=str, default="stabilityai/stable-diffusion-2-1-base", help='Stable Diffusion model ID from Hugging Face.')
zodiac_group.add_argument('--empty_prompt', action=argparse.BooleanOptionalAction, default=True, help='Use an empty prompt for unconditional generation.')
zodiac_group.add_argument('--start_latents', type=str, default="init_w", help='Initial latents to start from.')
zodiac_group.add_argument('--lr', type=float, default=0.1, help='Learning rate for optimization.')
zodiac_group.add_argument('--iters', type=int, default=100, help='Number of optimization iterations.')
zodiac_group.add_argument('--loss_weights', type=float, nargs='+', default=[10.0, 0.1, 10.0, 0.1], help='Weights for the different loss components.')
zodiac_group.add_argument('--ssim_threshold', type=float, default=0.92, help='SSIM threshold for loss calculation.')
zodiac_group.add_argument('--detect_threshold', type=float, default=0.9, help='Detection threshold for the watermark.')
zodiac_group.add_argument('--milestones', type=int, nargs='+', default=[30, 60, 90], help='Scheduler milestones for learning rate adjustment.')

zodiac_group.add_argument('--eval_steps', type=int, default=20, help='Evaluate bit accuracy and save the current image every N optimization steps.')
zodiac_group.add_argument('--training_steps', type=int, default=100, help='Number of latent optimization steps per image.')
zodiac_group.add_argument('--eval_dir', type=str, default='real_eval_gs', help='Directory where intermediate evaluation images are saved.')

parser.add_argument('--datasets', type=str, default='real_images', help='Directory containing the input images (.png/.jpg/.jpeg).')
parser.add_argument('--percent', type=float, default=1, help='Fraction of the sorted image list to process (e.g. 0.5 = first half).')
parser.add_argument('--device', type=str, default='cuda:3', help='Device to run on (e.g., "cuda:0", "cpu").')

args = parser.parse_args()


# --------------------------------------------------------------------------------
### 1. initialization (pipe, dataset, wmprovider, lossprovider, text_embedding, ...) ###

set_seed(args.seed)

scheduler = DDIMScheduler.from_pretrained(args.model_id, subfolder="scheduler")
pipe = WMDetectStableDiffusionPipeline.from_pretrained(args.model_id, scheduler=scheduler).to(args.device)
pipe.set_progress_bar_config(disable=True)

# Collect input images. With recursive=True, "**" only descends into sub-directories
# when it is a whole path component ("<dir>/**/*.png"); the old "<dir>/**.png"
# pattern behaved like "<dir>/*.png" and silently skipped nested files.
imagedirs = []
for ext in ('png', 'jpg', 'jpeg'):
    imagedirs += glob(os.path.join(args.datasets, '**', f'*.{ext}'), recursive=True)
imagedirs.sort()
# Keep only the first `percent` fraction of the sorted list (deterministic subset).
imagedirs = imagedirs[:int(len(imagedirs) * args.percent)]

text_embedding = pipe.get_text_embedding(args.prompt)
invert_text_embedding = pipe.get_text_embedding(args.invert_prompt)

lossprovider = LossProvider(
    loss_weights=args.loss_weights,
    device=args.device,
)

latent_shape = (1, 4, 64, 64)  # SD latent for a 512x512 image

wm_provider = WmProviders[args.wm_type].value(
    latent_shape=latent_shape,
    **vars(args),
)

if args.wm_type == 'GS':
    # Bind the differentiable embedding as an instance method. The descriptor's
    # `owner` argument is the instance's class, not the WmProviders enum.
    wm_provider.get_wm_latents = custom_get_wm_latents.__get__(wm_provider, type(wm_provider))


### 1. initialization (pipe, dataset, wmprovider, lossprovider, text_embedding, ...) ###
# --------------------------------------------------------------------------------
### 2. run for loop ###

# The evaluation directory only needs to be created once, not per image.
os.makedirs(args.eval_dir, exist_ok=True)


def _embed_and_generate(latents, expected_bits):
    """Embed the watermark into `latents` and run the pipeline on the result.

    Checks that the embedded message matches `expected_bits`.
    Returns a tuple (watermarked latents, generated image tensor).
    """
    wm_results = wm_provider.get_wm_latents(
        latents_clean=latents,
        seed=args.seed,
    )
    wm_latents = wm_results["zT_torch"]
    embedded_bits = wm_results["message_bits_str_list"][0]
    assert expected_bits == embedded_bits, \
        f'gt_message_bits_str is not consistent, which should {expected_bits}, but {embedded_bits}'
    img_tensor = pipe(
        prompt=args.prompt,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.steps,
        output_type='tensor',
        use_trainable_latents=True,
        init_latents=wm_latents,
    ).images
    return wm_latents, img_tensor


def _detect_watermark(img_tensor):
    """Invert an image back to latents and return `wm_provider.get_accuracies` on them.

    Runs under torch.no_grad() on a detached input, so evaluation never records an
    autograd graph through the inversion.
    """
    with torch.no_grad():
        inverted_latents = get_init_latent(
            img_tensor=img_tensor.detach(),
            pipe=pipe,
            text_embeddings=invert_text_embedding,
        )
        return wm_provider.get_accuracies(inverted_latents)


for imagedir in tqdm(imagedirs, total=len(imagedirs)):
    # splitext keeps dotted stems intact ("a.b.png" -> "a.b").
    imagename = os.path.splitext(os.path.basename(imagedir))[0]

    ### init wm latents ###
    gt_img_tensor = get_img_tensor(imagedir, args.device)  # gt_image

    gt_init_latents_approx = get_init_latent(
        img_tensor=gt_img_tensor,
        pipe=pipe,
        text_embeddings=invert_text_embedding,
    )

    gt_wm_initial_results = wm_provider.get_wm_latents(
        latents_clean=gt_init_latents_approx,
        seed=args.seed,
    )
    gt_wm_init_latents_approx = gt_wm_initial_results['zT_torch']  # gt_latent
    gt_message_bits_str = gt_wm_initial_results['message_bits_str_list'][0]

    # Trainable copy of the watermarked GT latents; this is what the optimizer updates.
    init_latents = gt_wm_init_latents_approx.detach().clone()
    init_latents.requires_grad_(True)
    ### init wm latents ###

    ### init training tools ###
    optimizer = optim.Adam([init_latents], lr=args.lr)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=0.3)
    ### init training tools ###

    for step in tqdm(range(args.training_steps), total=args.training_steps):
        ### insert wm + gen img ###
        wm_init_latents, pred_img_tensor = _embed_and_generate(init_latents, gt_message_bits_str)

        ### calculate loss ###
        loss = lossprovider(
            pred_img_tensor=pred_img_tensor,
            gt_img_tensor=gt_img_tensor,
            pred_inverted_wm_latents=wm_init_latents,
            gt_inverted_wm_latents=gt_wm_init_latents_approx,
            wm_mask=None,
        )

        ### (option) eval acc & save ###
        # A falsy eval_steps (None or 0) disables evaluation instead of dividing by zero.
        if args.eval_steps and (step + 1) % args.eval_steps == 0:
            acc_results = _detect_watermark(pred_img_tensor)

            save_img(
                os.path.join(args.eval_dir, f'{imagename}-{step+1}-{int(acc_results["bit_accuracies"][0]*100)}.png'),
                pred_img_tensor,
                pipe,
            )

            # tqdm.write keeps the progress bars intact, unlike print.
            tqdm.write(f'{step+1} step: [Loss] {loss.item():.2f} | [ACC] {acc_results["bit_accuracies"][0]:.2f}')
        else:
            tqdm.write(f'{step+1} step: [Loss] {loss.item():.2f}')

        ### update ###
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

    ### insert final wm + generate final img (inference only, no autograd needed) ###
    with torch.no_grad():
        _, pred_img_tensor = _embed_and_generate(init_latents, gt_message_bits_str)

    ### detect final wm ###
    acc_results = _detect_watermark(pred_img_tensor)
    tqdm.write(f'FINAL: [ACC] {acc_results["bit_accuracies"][0]:.2f}')