import os
import torch
import logging
import argparse
from tqdm import tqdm

from src.utils import *
from lora_diffusion import inject_trainable_lora
from src.watermark import Gaussian_Shading_chacha
from src.inversion.inv_pipe import InversionPipeline
from src.fari import inject_fari, one_step_inversion

# Number of distortion indices passed to `image_distortion` for every generated
# image. Presumably index 0 is the clean image followed by the nine attacks
# configured on the CLI (jpeg, crop, drop, blur, median, resize, noise, s&p,
# brightness) — confirm against `image_distortion`.
_NUM_DISTORTIONS = 10


def _setup_logger(output_path):
    """Attach an INFO-level file handler writing to ``<output_path>/GS.log``.

    Returns ``(logger, handler)`` so the caller can detach and close the handler
    when finished; otherwise repeated ``main`` calls in one process would stack
    duplicate handlers and keep the log files open.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler(os.path.join(output_path, "GS.log"))
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)
    return logger, handler


def _load_fari_pipeline(args, output_path, device):
    """Load the inversion pipeline and restore the fine-tuned LoRA/FARI UNet weights.

    Returns ``(pipe, null_text_embeds)``, where ``null_text_embeds`` is the
    embedding of the empty prompt (later passed to ``one_step_inversion``).
    """
    pipe = InversionPipeline.from_pretrained(args.model_id).to(device)
    pipe.set_progress_bar_config(disable=True)
    null_text_embeds, _ = pipe.encode_prompt("", pipe._execution_device, 1, False)

    # Rebuild the LoRA + FARI module structure so the saved weights have matching modules.
    inject_trainable_lora(pipe.unet, r=args.lora_r)
    pipe.unet = inject_fari(pipe.unet)
    # strict=False: the checkpoint presumably contains only the fine-tuned
    # (LoRA/FARI) parameters rather than the full UNet — confirm.
    state_dict = torch.load(os.path.join(output_path, "fari_weights.pth"), weights_only=True)
    pipe.unet.load_state_dict(state_dict, strict=False)
    pipe.unet.requires_grad_(False)
    return pipe, null_text_embeds


def main(args):
    """Evaluate FARI one-step inversion of Gaussian-Shading watermarks under distortions.

    For each of the first ``args.val_size`` prompts from ``args.val_dataset_id``,
    a watermarked initial latent is created with ``Gaussian_Shading_chacha`` and
    turned into an image by the pipeline. Each of the ``_NUM_DISTORTIONS``
    distortions is applied to that image. The distorted image is encoded,
    inverted with ``one_step_inversion``, and scored with ``gs.eval_watermark``
    (``acc``) and the MSE against the original watermarked latent (``mse``).

    Args:
        args: parsed command-line namespace (see ``build_parser`` / the CLI).

    Files (all inside ``<args.output_dir>/<args.name>``):
        reads  ``fari_weights.pth``;
        writes ``gs_settings.json`` (the CLI settings), ``GS.log`` (per-step
        log) and ``val_gs.json`` (a list with one dict per prompt mapping the
        distortion's ``noise_type`` to ``{"acc": ..., "mse": ...}``).
    """
    # Imported explicitly instead of relying on names leaking through `from src.utils import *`.
    import json
    from itertools import islice

    output_path = os.path.join(args.output_dir, args.name)
    os.makedirs(output_path, exist_ok=True)
    logger, log_handler = _setup_logger(output_path)
    try:
        settings = vars(args)
        print(settings)
        with open(os.path.join(output_path, "gs_settings.json"), "w") as f:
            json.dump(settings, f, indent=4)

        gs = Gaussian_Shading_chacha(
            ch_factor=args.ch_factor,
            hw_factor=args.hw_factor,
            ch_num=args.ch_num,
            fpr=args.fpr,
            user_number=args.user_number,
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipe, null_text_embeds = _load_fari_pipeline(args, output_path, device)
        null_embeds = null_text_embeds.detach()  # loop-invariant
        # fp16 autocast only on CUDA (a hard-coded 'cuda' autocast is disabled on CPU anyway).
        use_amp = device.type == "cuda"

        results = []
        # islice caps the run at val_size prompts, and the loop also ends cleanly
        # (results still get saved) if the dataset yields fewer prompts.
        prompts = islice(load_prompt(args.val_dataset_id), args.val_size)
        for i, prompt in enumerate(tqdm(prompts, total=args.val_size, desc="Validation GS")):
            set_random_seed(args.seed + i)
            initial_latents = gs.create_watermark_and_return_w()
            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                image = pipe(
                    prompt,
                    num_images_per_prompt=1,
                    guidance_scale=args.guidance_scale,
                    num_inference_steps=args.num_inference_steps,
                    height=512,
                    width=512,
                    latents=initial_latents,
                ).images[0]

            per_distortion = {}
            for j in range(_NUM_DISTORTIONS):
                noised_image, noise_type = image_distortion(image, args.seed + i, args, j)
                noised_image_tensor = to_tensor(noised_image).to(device)
                noised_image_latent = pipe.get_image_latents(noised_image_tensor, sample=False)
                noised_latents = one_step_inversion(pipe, noised_image_latent, prompt_embeds=null_embeds)

                acc = gs.eval_watermark(noised_latents)
                # Compare on whatever device the inverted latents live on (works on CPU too).
                mse = torch.nn.functional.mse_loss(
                    noised_latents, initial_latents.to(noised_latents.device)
                ).item()

                per_distortion[noise_type] = {
                    "acc": acc,
                    "mse": mse,
                }
                logger.info(f"Iter {i} Step {j} Acc: {acc} MSE: {mse}")
            results.append(per_distortion)

        with open(os.path.join(output_path, "val_gs.json"), "w") as f:
            json.dump(results, f, indent=4)
        logger.info(f"Iter {len(results) - 1} Save results")
    finally:
        logger.removeHandler(log_handler)
        log_handler.close()


def build_parser():
    """Build the command-line parser for the Gaussian-Shading / FARI evaluation script.

    Numeric defaults are given in the same type as the argument's ``type=``,
    because argparse does not convert non-string defaults.
    """
    parser = argparse.ArgumentParser(description='gaussian shading watermark')

    # Run / model settings. Outputs are read from and written to <output_dir>/<name>.
    parser.add_argument("--name", type=str, default="test")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output_dir", type=str, default="results")
    parser.add_argument("--model_id", type=str, default="stabilityai/stable-diffusion-2-1-base")
    parser.add_argument("--guidance_scale", type=float, default=7.5)
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--lora_r", type=int, default=8)  # LoRA rank
    parser.add_argument("--val_dataset_id", type=str, default="Gustavosta/Stable-Diffusion-Prompts")
    parser.add_argument("--val_size", type=int, default=1000)  # number of prompts to evaluate

    # Watermark parameters, forwarded to Gaussian_Shading_chacha.
    parser.add_argument('--ch_factor', type=int, default=1)
    parser.add_argument('--hw_factor', type=int, default=8)
    parser.add_argument('--ch_num', type=int, default=4)
    parser.add_argument('--fpr', type=float, default=1e-6)
    # Was `default=1e6`, which produced a float despite type=int.
    parser.add_argument('--user_number', type=int, default=1_000_000)

    # Distortion strengths; the whole namespace is passed to image_distortion.
    parser.add_argument('--jpeg_ratio', type=int, default=25)
    parser.add_argument('--random_crop_ratio', type=float, default=0.6)
    parser.add_argument('--random_drop_ratio', type=float, default=0.8)
    parser.add_argument('--gaussian_blur_r', type=int, default=4)
    parser.add_argument('--median_blur_k', type=int, default=7)
    parser.add_argument('--resize_ratio', type=float, default=0.25)
    parser.add_argument('--gaussian_std', type=float, default=0.05)
    parser.add_argument('--sp_prob', type=float, default=0.05)
    parser.add_argument('--brightness_factor', type=float, default=6.0)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)