import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DDIMScheduler
from modified_pix2pixpipeline import InversableStableDiffusionInstructPix2PixPipeline
from PIL import Image, ImageOps
from argparse import ArgumentParser
from watermark import *
from image_utils import *
import os
import open_clip
from sklearn import metrics
import random

def set_random_seed(seed=0):
    """Seed the torch, CUDA, NumPy and Python random number generators.

    Each seeding function receives ``seed`` plus its position in the call
    sequence below, so every generator is seeded with a different value
    derived from the same base.

    Args:
        seed: Base integer seed.
    """
    # Order matters: the offset added to ``seed`` is the position in this
    # tuple. ``torch.cuda.manual_seed_all`` intentionally appears twice
    # (offsets 2 and 4) so the call sequence stays exactly as it was.
    seeders = (
        torch.manual_seed,           # seed + 0
        torch.cuda.manual_seed,      # seed + 1
        torch.cuda.manual_seed_all,  # seed + 2
        np.random.seed,              # seed + 3
        torch.cuda.manual_seed_all,  # seed + 4
        random.seed,                 # seed + 5
    )
    for offset, seeder in enumerate(seeders):
        seeder(seed + offset)

def transform_img(image, target_size=512):
    """Resize, center-crop and tensorize an image, then rescale it by ``2x - 1``.

    Args:
        image: Input image accepted by the ``transforms`` pipeline
            (presumably a PIL image -- confirm against callers).
        target_size: Size passed to both ``Resize`` and ``CenterCrop``.

    Returns:
        The transformed tensor mapped through ``2.0 * x - 1.0``.
    """
    # NOTE(review): ``transforms`` is not imported explicitly in this file; it
    # presumably comes from torchvision via one of the star imports. If so,
    # ToTensor yields values in [0, 1] and the result lies in [-1, 1] -- confirm.
    preprocess = transforms.Compose([
        transforms.Resize(target_size),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
    ])
    tensor = preprocess(image)
    return 2.0 * tensor - 1.0

def _build_arg_parser():
    """Create the command-line parser for the watermark evaluation script."""
    parser = ArgumentParser()
    parser.add_argument("--resolution", default=512, type=int)
    parser.add_argument("--num", default=100, type=int)
    parser.add_argument("--steps", default=100, type=int)
    parser.add_argument("--num_inversion_steps", default=100, type=int)
    parser.add_argument("--config", default="configs/generate.yaml", type=str)
    parser.add_argument("--ckpt", default="checkpoints/instruct-pix2pix-00-22000.ckpt", type=str)
    parser.add_argument("--vae-ckpt", default=None, type=str)
    parser.add_argument("--cfg-text", default=7.5, type=float)
    parser.add_argument('--reference_model', default="ViT-g-14")
    parser.add_argument('--reference_model_pretrain', default="laion2b_s12b_b42k")
    parser.add_argument("--cfg-image", default=1.5, type=float)
    parser.add_argument("--seed", default=12345, type=int)

    # watermark
    parser.add_argument('--ring_threshold', default=0.5, type=float)
    parser.add_argument('--channel_copy', default=1, type=int)
    parser.add_argument('--hw_copy', default=8, type=int)
    parser.add_argument('--n_split', default=2, type=int)
    parser.add_argument('--user_number', default=1000000, type=int)
    parser.add_argument('--fpr', default=0.000001, type=float)
    parser.add_argument('--chacha', action='store_true', help='chacha20 for cipher')
    parser.add_argument('--no_encrypt', action='store_true', help='ciper')
    parser.add_argument('--w_r_start', default=10, type=int)
    parser.add_argument('--w_r_end', default=0, type=int)
    parser.add_argument('--w_r_interval', default=1, type=int)
    parser.add_argument('--ring', action='store_true')
    parser.add_argument('--channel', default=2, type=int)
    parser.add_argument('--disk', action='store_true')
    parser.add_argument('--fft', action='store_true')

    # for image distortion
    parser.add_argument('--name', default="none", type=str)
    parser.add_argument('--r_degree', default=None, type=int)
    parser.add_argument('--jpeg_ratio', default=None, type=int)
    parser.add_argument('--random_crop_ratio', default=None, type=float)
    parser.add_argument('--random_drop_ratio', default=None, type=float)
    parser.add_argument('--gaussian_blur_r', default=None, type=int)
    parser.add_argument('--median_blur_k', default=None, type=int)
    parser.add_argument('--resize_ratio', default=None, type=float)
    parser.add_argument('--gaussian_std', default=None, type=float)
    parser.add_argument('--sp_prob', default=None, type=float)
    parser.add_argument('--brightness_factor', default=None, type=float)
    parser.add_argument('--h_flip_prob', default=None, type=float)
    parser.add_argument('--v_flip_prob', default=None, type=float)
    parser.add_argument('--x_shift', default=None, type=float)
    parser.add_argument('--y_shift', default=None, type=float)
    parser.add_argument('--perspective_distortion', default=None, type=float)
    return parser


def _filename_safe(text):
    """Return ``text`` with path separators replaced by ``_`` for use in a file name."""
    text = str(text)
    for sep in {"/", os.sep, os.altsep}:
        if sep:
            text = text.replace(sep, "_")
    return text


def _invert_and_score(pipe, watermark, image, args):
    """Distort ``image``, invert it to latents and score it against the watermark.

    Args:
        pipe: The inversable InstructPix2Pix pipeline.
        watermark: Watermark object providing ``eval_watermark``.
        image: Generated image to run detection on.
        args: Parsed command-line arguments (distortion settings, seed and
            ``num_inversion_steps`` are read from it).

    Returns:
        The value returned by ``watermark.eval_watermark`` for the recovered
        latents; the caller uses it as a detection score.
    """
    detect_image = image_distortion(image, args.seed, args)
    detect_image = pipe.image_processor.preprocess(detect_image).to(
        device=pipe._execution_device, dtype=torch.float16
    )
    image_latents = pipe.get_image_latents(detect_image, sample=False)
    # Inversion runs with an empty prompt and fixed guidance scales.
    reversed_latents = pipe.forward_diffusion(
        latents=image_latents,
        image=detect_image,
        prompt="",
        guidance_scale=7.5,
        image_guidance_scale=1.5,
        num_inference_steps=args.num_inversion_steps,
    )
    return watermark.eval_watermark(reversed_latents)


def main():
    """Evaluate watermark detection and CLIP score on InstructPix2Pix edits.

    For each selected sample the script generates one edit from watermarked
    latents and one from plain random latents, applies the configured image
    distortion, inverts both images back to latents and scores them with the
    watermark detector. Detection scores are summarised as AUC, best balanced
    accuracy and TPR at FPR < 1%, and appended to ``results_test.txt``
    together with the mean CLIP scores.

    Raises:
        ValueError: If there are no samples to evaluate (``--num`` <= 0 or an
            empty ``select_seeds.json``).
    """
    # Imported here so the script does not rely on these names leaking in
    # through the star imports at the top of the file.
    import json

    import numpy as np

    args = _build_arg_parser().parse_args()

    # reference model for CLIP Score
    if args.reference_model is not None:
        ref_tokenizer = open_clip.get_tokenizer(args.reference_model)
        ref_model, _, ref_clip_preprocess = open_clip.create_model_and_transforms(
            args.reference_model,
            pretrained=args.reference_model_pretrain,
            device="cuda",
        )

    watermark = ring(args.channel_copy, args.hw_copy, args.fpr, args.user_number, args)
    model_id = "timbrooks/instruct-pix2pix"
    pipe = InversableStableDiffusionInstructPix2PixPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16, safety_checker=None
    )
    pre_path = "../instruct-pix2pix_old/edit_data"
    # Replace the pipeline's scheduler with DDIM, built from the existing config.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")

    # Detection scores (watermarked / not watermarked)
    scores_w = []
    scores_no_w = []
    # CLIP scores (watermarked / not watermarked)
    clip_scores = []
    clip_scores_no = []

    with open("select_seeds.json", encoding="utf-8") as f:
        input_data = json.load(f)
    inputs = list(input_data.values())
    set_random_seed(args.seed)
    random.shuffle(inputs)

    # Never index past the end of the dataset when --num is larger than it.
    num_samples = min(args.num, len(inputs))
    if num_samples <= 0:
        raise ValueError(
            f"No samples to evaluate (--num={args.num}, available={len(inputs)})."
        )
    if num_samples < args.num:
        print(
            f"Requested {args.num} samples but only {len(inputs)} are available; "
            f"evaluating {num_samples}."
        )

    # The generated images are written below; make sure the folder exists.
    os.makedirs("imgs", exist_ok=True)

    for i, sample in enumerate(inputs[:num_samples]):
        input_image = Image.open(os.path.join(pre_path, sample["image"])).convert("RGB")
        input_text = sample["edit"]
        edit_prompt = sample["output"]
        # The instruction is embedded in file names; neutralise path separators.
        tag = _filename_safe(input_text)
        input_image.save(f"imgs/test_ori_{i}_{tag}.jpg")

        # watermarked
        latents = watermark.create_watermark_and_return_w()
        image_w = pipe(input_text, image=input_image, latents=latents).images[0]
        image_w.save(f"imgs/test_w_{i}_{tag}.jpg")
        acc_metric_w = _invert_and_score(pipe, watermark, image_w, args)

        # not watermarked: plain random latents
        latents_no_w = torch.randn([1, 4, 64, 64], dtype=torch.float16)
        image_no_w = pipe(input_text, image=input_image, latents=latents_no_w).images[0]
        image_no_w.save(f"imgs/test_no_w_{i}_{tag}.jpg")
        acc_metric_no_w = _invert_and_score(pipe, watermark, image_no_w, args)

        if args.reference_model is not None:
            scores = measure_similarity(
                [image_w, image_no_w],
                edit_prompt,
                ref_model,
                ref_clip_preprocess,
                ref_tokenizer,
                "cuda",
            )
            clip_score_w = scores[0].item()
            clip_score_no_w = scores[1].item()
        else:
            clip_score_w = 0
            clip_score_no_w = 0

        scores_w.append(acc_metric_w)
        scores_no_w.append(acc_metric_no_w)
        clip_scores.append(clip_score_w)
        clip_scores_no.append(clip_score_no_w)
        print(acc_metric_w, clip_score_w, acc_metric_no_w, clip_score_no_w)

    # Non-watermarked samples are the negatives (0), watermarked the positives (1).
    preds = scores_no_w + scores_w
    t_labels = [0] * len(scores_no_w) + [1] * len(scores_w)

    fpr, tpr, _ = metrics.roc_curve(t_labels, preds, pos_label=1)
    auc = metrics.auc(fpr, tpr)
    # Best value over all thresholds of 1 - (FPR + FNR) / 2.
    best_acc = np.max(1 - (fpr + (1 - tpr)) / 2)
    # TPR at the last ROC point whose FPR is still below 1%.
    low = tpr[np.where(fpr < .01)[0][-1]]

    print(f'clip_score_mean: {np.mean(clip_scores_no)}')
    print(f'w_clip_score_mean: {np.mean(clip_scores)}')
    print(f'auc: {auc}, acc: {best_acc}, TPR@1%FPR: {low}')
    with open("results_test.txt", 'a+') as f:
        f.write(str(args)+"\n")
        f.write(f"{args.name} {args.ring} {args.ring_threshold}  {args.seed} {args.r_degree} {args.channel} {args.w_r_start} {args.w_r_end} {args.w_r_interval}\n")
        f.write(f'w_clip_score_mean: {np.mean(clip_scores)} clip_score_mean: {np.mean(clip_scores_no)}\n')
        f.write(f'auc: {auc} acc: {best_acc} TPR@1%FPR: {low}\n')
        f.write("\n")
    
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()