import argparse
import yaml
import os
import logging
import shutil
import numpy as np
from PIL import Image 
# Attach a stream handler that passes INFO and above to the root logger.
# NOTE(review): the root logger's own level is not set here, so with Python's
# default root level (WARNING) the logging.info(...) calls in this script are
# filtered out unless the level is lowered elsewhere -- confirm this is intended.
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger = logging.getLogger()
logger.addHandler(handler)

import torch
import torch.optim as optim
import torchvision.transforms as transforms
from diffusers import DDIMScheduler
from datasets import load_dataset
from diffusers.utils.torch_utils import randn_tensor

from main.wmdiffusion import WMDetectStableDiffusionPipeline
from main.wmpatch import GTWatermark, GTWatermarkMulti
from main.utils import *
from loss.loss import LossProvider
from loss.pytorch_ssim import ssim
from glob import glob
from pq_utils import set_seed
from tqdm import tqdm

# --------------------------------------------------------------------------------


logging.info('===== Load Config =====')
# Device handed to the watermark object, the pipeline and the image loader.
device = torch.device('cuda:1')
# with open('./example/config/config.yaml', 'r') as file:
#     cfgs = yaml.safe_load(file)
# Run settings, inlined here in place of the YAML file referenced above.
cfgs = {
    'method': 'ZoDiac',            # not read by the code visible in this script
    'save_img': 'real_eval_zod',   # output directory for watermarked images
    'model_id': 'stabilityai/stable-diffusion-2-1-base',  # passed to from_pretrained
    'gen_seed': 0,                 # not read by the code visible in this script
    'empty_prompt': True,          # True -> the pipeline is called with '' as prompt

    # Watermark object selection and its constructor arguments.
    'w_type': 'single',            # 'single' -> GTWatermark, 'multi' -> GTWatermarkMulti
    'w_channel': 3,
    'w_radius': 10,
    'w_seed': 10,                  # seed of the torch.Generator given to the watermark

    # Latent optimisation and post-processing.
    'start_latents': 'init_w',     # not read by the code visible in this script
    'iters': 100,                  # optimisation steps per image
    'save_iters': [100],           # 1-based steps at which the image is written to disk
    'loss_weights': [10.0, 0.1, 1.0, 0.0],  # forwarded to LossProvider
    'ssim_threshold': 0.92,        # SSIM threshold used by the theta binary search
    'detect_threshold': 0.9,       # not read by the code visible in this script

    # Input images.
    # 'datasets': 'output_images_wo_wm',
    'datasets': 'real_images',     # directory that is globbed for png/jpg/jpeg files
    'percent': 1.0,                # leading fraction of the sorted file list to process
}
logging.info(cfgs)

# Project helper from pq_utils; presumably seeds the RNGs -- TODO confirm.
set_seed(42)

# --------------------------------------------------------------------------------
### necessary setup for all sections ###

logging.info('===== Init Pipeline =====')
# Generator seeded with cfgs['w_seed']; handed to whichever watermark class is selected.
wm_generator = torch.Generator(device).manual_seed(cfgs['w_seed'])
if cfgs['w_type'] == 'single':
    wm_pipe = GTWatermark(device, w_channel=cfgs['w_channel'], w_radius=cfgs['w_radius'], generator=wm_generator)
elif cfgs['w_type'] == 'multi':
    wm_pipe = GTWatermarkMulti(device, w_settings=cfgs['w_settings'], generator=wm_generator)
else:
    # Without this branch an unsupported value leaves wm_pipe unbound and only
    # surfaces later as a NameError deep inside the processing loop.
    raise ValueError(f"Unsupported w_type {cfgs['w_type']!r}; expected 'single' or 'multi'")

# Build the detection-capable pipeline with the DDIM scheduler config of the same model.
scheduler = DDIMScheduler.from_pretrained(cfgs['model_id'], subfolder="scheduler")
pipe = WMDetectStableDiffusionPipeline.from_pretrained(cfgs['model_id'], scheduler=scheduler).to(device)
# The pipeline is invoked once per optimisation step, so its own progress bar is disabled.
pipe.set_progress_bar_config(disable=True)


# --------------------------------------------------------------------------------


# Collect every png/jpg/jpeg file under cfgs["datasets"], including sub-directories.
# '**' only recurses when it is a path component of its own ('dir/**/*.png');
# the former 'dir/**.png' behaved like 'dir/*.png' despite recursive=True.
# NOTE: output files are named after the input basename, so images with the same
# basename in different sub-directories map to the same output file.
imagedirs = []
for ext in ['png', 'jpg', 'jpeg']:
    imagedirs += glob(cfgs["datasets"] + '/**/*.' + ext, recursive=True)
# Sort for a stable order, then keep the leading cfgs["percent"] fraction.
imagedirs.sort()
imagedirs = imagedirs[:int(len(imagedirs) * cfgs["percent"])]

# imagedirs = imagedirs[:200]
# imagedirs = imagedirs[200:400]
# imagedirs = imagedirs[400:600]
# imagedirs = imagedirs[640:800]
# imagedirs = imagedirs[800:]

# imagedirs = imagedirs[::-1]

# --------------------------------------------------------------------------------

wm_path = cfgs['save_img']
os.makedirs(wm_path, exist_ok=True)

# Fail before any expensive work: the non-empty-prompt branch below needs a prompt,
# which previously referenced an undefined name and crashed with a NameError.
if not cfgs['empty_prompt'] and 'prompt' not in cfgs:
    raise KeyError("cfgs['prompt'] is required when cfgs['empty_prompt'] is False")


def get_init_latent(img_tensor, pipe, text_embeddings, guidance_scale=1.0):
    """Return initial latents for ``img_tensor`` via DDIM inversion.

    The image is encoded with ``pipe.get_image_latents(..., sample=False)`` and
    the result is passed through ``pipe.forward_diffusion`` for 50 inference
    steps using the given text embeddings and guidance scale.
    """
    img_latents = pipe.get_image_latents(img_tensor, sample=False)
    return pipe.forward_diffusion(
        latents=img_latents,
        text_embeddings=text_embeddings,
        guidance_scale=guidance_scale,
        num_inference_steps=50,
    )


def binary_search_theta(threshold, gt_img_tensor, wm_img_tensor, lower=0., upper=1., precision=1e-6, max_iter=1000):
    """Bisect the blend factor theta between the watermarked and original image.

    The blend is ``(gt - wm) * theta + wm``: theta=0 is the watermarked image,
    theta=1 the original. At each step the SSIM of the blend against the
    original is compared with ``threshold``; the bracket moves up while SSIM is
    at or below the threshold, down otherwise. The search stops once the
    bracket is narrower than ``precision`` or after ``max_iter`` steps.

    Returns the lower end of the final bracket, i.e. the largest tested theta
    whose blend had SSIM <= threshold (or the initial ``lower`` if none did).
    """
    for _ in range(max_iter):
        mid_theta = (lower + upper) / 2
        blended = (gt_img_tensor - wm_img_tensor) * mid_theta + wm_img_tensor
        if ssim(blended, gt_img_tensor).item() <= threshold:
            lower = mid_theta
        else:
            upper = mid_theta
        if upper - lower < precision:
            break
    return lower


# Loop-invariant values: the empty-prompt embedding serves both the inversion and
# the detection step, so it is computed once instead of twice per image.
empty_text_embeddings = pipe.get_text_embedding('')
last_save_iter = cfgs['save_iters'][-1]
ssim_threshold = cfgs['ssim_threshold']

for imagedir in tqdm(imagedirs):
    imagename = os.path.basename(imagedir)
    # splitext keeps dotted names intact ('a.b.png' -> 'a.b'); split('.')[0]
    # truncated them to 'a', making distinct inputs collide on one output file.
    stem = os.path.splitext(imagename)[0]

    # The image saved at the last save iteration doubles as the "already done" marker.
    # Check it before loading the input so skipped files cost nothing.
    wm_img_path = os.path.join(wm_path, f"{stem}-{last_save_iter}.png")
    if os.path.exists(wm_img_path):
        print(f'skipping... {imagename}')
        continue

    gt_img_tensor = get_img_tensor(imagedir, device)

    # --------------------------------------------------------------------------------
    ### image watermarking ###

    # Step 1: Get init noise
    init_latents_approx = get_init_latent(gt_img_tensor, pipe, empty_text_embeddings)

    # Step 2: prepare training
    init_latents = init_latents_approx.detach().clone()
    init_latents.requires_grad = True
    optimizer = optim.Adam([init_latents], lr=0.01)
    # Named lr_scheduler so it does not rebind the module-level DDIM `scheduler`.
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 80], gamma=0.3)

    totalLoss = LossProvider(cfgs['loss_weights'], device)
    loss_lst = []  # per-step loss values of the current image, kept for inspection

    # Step 3: train the init latents
    for i in range(cfgs['iters']):
        logging.info(f'iter {i}:')
        init_latents_wm = wm_pipe.inject_watermark(init_latents)
        if cfgs['empty_prompt']:
            pred_img_tensor = pipe('', guidance_scale=1.0, num_inference_steps=50, output_type='tensor', use_trainable_latents=True, init_latents=init_latents_wm).images
        else:
            pred_img_tensor = pipe(cfgs['prompt'], num_inference_steps=50, output_type='tensor', use_trainable_latents=True, init_latents=init_latents_wm).images
        loss = totalLoss(pred_img_tensor, gt_img_tensor, init_latents_wm, wm_pipe)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        loss_lst.append(loss.item())
        # save watermarked image
        if (i + 1) in cfgs['save_iters']:
            path = os.path.join(wm_path, f"{stem}-{i+1}.png")
            save_img(path, pred_img_tensor, pipe)
    torch.cuda.empty_cache()

    # --------------------------------------------------------------------------------
    ### postprocessing with adaptive enhancement

    # Reload the image written at the last save iteration and measure its SSIM.
    wm_img_tensor = get_img_tensor(wm_img_path, device)
    ssim_value = ssim(wm_img_tensor, gt_img_tensor).item()
    logging.info(f'Original SSIM {ssim_value}')

    # Blend the original back into the watermarked image, guided by the SSIM threshold.
    optimal_theta = binary_search_theta(ssim_threshold, gt_img_tensor, wm_img_tensor, precision=0.01)
    logging.info(f'Optimal Theta {optimal_theta}')

    img_tensor = (gt_img_tensor - wm_img_tensor) * optimal_theta + wm_img_tensor

    ssim_value = ssim(img_tensor, gt_img_tensor).item()
    psnr_value = compute_psnr(img_tensor, gt_img_tensor)

    # Detection uses the empty prompt as well.
    # NOTE(review): presumably watermark_prob returns a p-value, making this a
    # detection probability -- confirm against main.utils.
    det_prob = 1 - watermark_prob(img_tensor, pipe, wm_pipe, empty_text_embeddings, device=device)
    logging.info(f'SSIM {ssim_value}, PSNR {psnr_value}, Detect Prob {det_prob} after postprocessing')

    # The final image carries the detection score (as an integer percentage) in its name.
    path = os.path.join(wm_path, f"{stem}-{last_save_iter}-{int(det_prob * 100)}.png")
    save_img(path, img_tensor, pipe)

