import os
import copy
import torch
import logging
import argparse
from tqdm import tqdm

from src import utils
from src.io_utils import *
from src.optim_utils import *
from lora_diffusion import inject_trainable_lora
from src.inversion.inv_pipe import InversionPipeline
from src.fari import inject_fari, one_step_inversion


def main(args):
    """Evaluate watermark detection on generated images under distortions.

    For every dataset index in ``[args.start, args.end)`` the prompt is rendered
    twice from the same initial latents: once untouched and once after
    ``inject_watermark`` has been applied. Both images are passed through each
    distortion index of ``noise_list`` (via ``utils.image_distortion``), mapped
    back to latents with ``one_step_inversion`` using the empty-prompt
    embedding, and scored with ``eval_watermark``.

    Files used inside ``<args.output_dir>/<args.name>``:
        * ``fari_weights.pth``  -- read; loaded (non-strictly) into the UNet.
        * ``TR.log``            -- written; one line per distortion and prompt.
        * ``tr_settings.json``  -- written; dump of ``vars(args)``.
        * ``val_tr.json``       -- written; a list holding ONE dict per prompt,
          keyed by distortion name, each value carrying ``no_w_metric`` and
          ``w_metric``.

    Args:
        args: Parsed command-line namespace (see the parser at the bottom of
            this file for the available fields).
    """
    # Imported explicitly so this function does not depend on a wildcard
    # import happening to re-export the module.
    import json

    # setup for evaluation
    output_path = os.path.join(args.output_dir, args.name)
    os.makedirs(output_path, exist_ok=True)

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    # os.path.join (rather than "./" + path) keeps absolute output dirs valid.
    fh = logging.FileHandler(os.path.join(output_path, 'TR.log'))
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    settings = vars(args)
    print(settings)
    with open(os.path.join(output_path, "tr_settings.json"), "w") as f:
        json.dump(settings, f, indent=4)

    # load diffusion model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pipe = InversionPipeline.from_pretrained(args.model_id).to(device)
    pipe.set_progress_bar_config(disable=True)
    null_text_embeds, _ = pipe.encode_prompt("", pipe._execution_device, 1, False)
    # Detached once here instead of on every inversion call.
    null_text_embeds = null_text_embeds.detach()

    # The LoRA layers must be injected before the checkpoint is loaded; the
    # returned parameter handles are not needed for evaluation.
    inject_trainable_lora(pipe.unet, r=args.lora_r)
    pipe.unet = inject_fari(pipe.unet)
    weights_path = os.path.join(output_path, "fari_weights.pth")
    # map_location lets a checkpoint saved on another device load here.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    pipe.unet.load_state_dict(state_dict, strict=False)
    pipe.unet.requires_grad_(False)

    # dataset
    dataset, prompt_key = get_dataset(args)

    # ground-truth patch
    gt_patch = get_watermarking_pattern(pipe, args, device)

    # Sampling options shared by the watermarked and non-watermarked runs.
    generation_kwargs = dict(
        num_images_per_prompt=args.num_images,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_inference_steps,
        height=args.image_length,
        width=args.image_length,
    )

    results = []
    for i in tqdm(range(args.start, args.end), desc="Validation TR"):
        seed = i + args.gen_seed
        current_prompt = dataset[i][prompt_key]

        ### generation
        # generation without watermarking
        set_random_seed(seed)
        init_latents_no_w = torch.randn(1, 4, 64, 64).to(device)  # 64x64 for SD2.1
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            outputs_no_w = pipe(
                current_prompt,
                latents=init_latents_no_w,
                **generation_kwargs,
            )
        orig_image_no_w = outputs_no_w.images[0]

        # generation with watermarking: start from a copy of the same latents
        # so the only difference between the two images is the watermark.
        init_latents_w = copy.deepcopy(init_latents_no_w)

        # get watermarking mask
        watermarking_mask = get_watermarking_mask(init_latents_w, args, device)

        # inject watermark
        init_latents_w = inject_watermark(init_latents_w, watermarking_mask, gt_patch, args)

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            outputs_w = pipe(
                current_prompt,
                latents=init_latents_w,
                **generation_kwargs,
            )
        orig_image_w = outputs_w.images[0]

        single_result = {}
        ### test watermark
        # The integer index is what utils.image_distortion expects, hence
        # range(len(...)) rather than iterating the entries directly.
        for j in range(len(noise_list)):
            orig_image_no_w_auged, noise_type = utils.image_distortion(orig_image_no_w, seed, args, j)
            orig_image_w_auged, noise_type = utils.image_distortion(orig_image_w, seed, args, j)

            # reverse img without watermarking
            img_no_w = transform_img(orig_image_no_w_auged).unsqueeze(0).to(device)
            image_latents_no_w = pipe.get_image_latents(img_no_w, sample=False)
            reversed_latents_no_w = one_step_inversion(pipe, image_latents_no_w, prompt_embeds=null_text_embeds)

            # reverse img with watermarking
            img_w = transform_img(orig_image_w_auged).unsqueeze(0).to(device)
            image_latents_w = pipe.get_image_latents(img_w, sample=False)
            reversed_latents_w = one_step_inversion(pipe, image_latents_w, prompt_embeds=null_text_embeds)

            # eval
            no_w_metric, w_metric = eval_watermark(reversed_latents_no_w, reversed_latents_w, watermarking_mask, gt_patch, args)

            logger.info(f'{noise_type}: no_w_metric: {no_w_metric}, w_metric: {w_metric}')

            single_result[noise_type] = {
                'no_w_metric': no_w_metric,
                'w_metric': w_metric,
            }

        # Appended once per prompt. Doing this inside the distortion loop would
        # store the same dict repeatedly and duplicate every entry in the file.
        results.append(single_result)

    with open(os.path.join(output_path, "val_tr.json"), "w") as f:
        json.dump(results, f, indent=4)


if __name__ == '__main__':
    # Command-line options, declared as (flag, default, type) rows and
    # registered in this order. A type of None leaves the value as a string,
    # which is argparse's behaviour when no type is given.
    option_table = [
        # run / generation
        ('--name', 'test', None),
        ('--output_dir', 'results', None),
        ('--dataset', 'Gustavosta/Stable-Diffusion-Prompts', None),
        ('--start', 0, int),
        ('--end', 1000, int),
        ('--image_length', 512, int),
        # alternative model id: stable-diffusion-v1-5/stable-diffusion-v1-5
        ('--model_id', 'stabilityai/stable-diffusion-2-1-base', None),
        ('--num_images', 1, int),
        ('--guidance_scale', 7.5, float),
        ('--num_inference_steps', 50, int),
        ('--test_num_inference_steps', None, int),
        ('--gen_seed', 0, int),
        ('--lora_r', 8, int),
        # watermark
        ('--w_seed', 999999, int),
        ('--w_channel', 3, int),
        ('--w_pattern', 'rand', None),
        ('--w_mask_shape', 'circle', None),
        ('--w_radius', 10, int),
        ('--w_measurement', 'l1_complex', None),
        ('--w_injection', 'complex', None),
        ('--w_pattern_const', 0, float),
        # for image distortion
        ('--jpeg_ratio', 25, int),
        ('--random_crop_ratio', 0.6, float),
        ('--random_drop_ratio', 0.8, float),
        ('--gaussian_blur_r', 4, int),
        ('--median_blur_k', 7, int),
        ('--resize_ratio', 0.25, float),
        ('--gaussian_std', 0.05, float),
        ('--sp_prob', 0.05, float),
        ('--brightness_factor', 6, float),
    ]

    parser = argparse.ArgumentParser(description='tree-ring watermark')
    for flag, default_value, value_type in option_table:
        parser.add_argument(flag, default=default_value, type=value_type)
    args = parser.parse_args()

    # When no separate step count is given for testing, reuse the one used
    # for generation.
    if args.test_num_inference_steps is None:
        args.test_num_inference_steps = args.num_inference_steps

    main(args)