import argparse
import copy
from tqdm import tqdm
import torch
from transformers import CLIPModel, CLIPTokenizer
from inverse_stable_diffusion import InversableStableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler, DDIMScheduler
import open_clip
from optim_utils import *
from io_utils import *
from image_utils import *
from watermark import *
from sklearn import metrics
from pytorch_fid.fid_score import *
from dreamsim import dreamsim


def _generate_image(pipe, prompt, latents, args):
    """Run the diffusion pipeline once and return the single generated image.

    Args:
        pipe: The pipeline object created in ``main``.
        prompt: Text prompt passed to the pipeline.
        latents: Initial latents handed to the pipeline via ``latents=``.
        args: Parsed CLI namespace; ``guidance_scale``, ``num_inference_steps``
            and ``image_length`` are read from it.

    Returns:
        ``outputs.images[0]`` of the pipeline call.
    """
    outputs = pipe(
        prompt,
        num_images_per_prompt=1,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_inference_steps,
        height=args.image_length,
        width=args.image_length,
        latents=latents,
    )
    return outputs.images[0]


def _invert_image(pipe, image, seed, text_embeddings, device, args):
    """Distort ``image`` and map it back to latents with the pipeline.

    The image goes through ``image_distortion`` and ``transform_img``, is cast
    to the dtype of ``text_embeddings`` and moved to ``device``, then encoded
    with ``pipe.get_image_latents`` and passed to ``pipe.forward_diffusion``.

    Args:
        pipe: The pipeline object created in ``main``.
        image: Image returned by the pipeline.
        seed: Seed forwarded to ``image_distortion``.
        text_embeddings: Embeddings used during inversion (``main`` passes the
            embedding of the empty prompt).
        device: Target device string.
        args: Parsed CLI namespace; ``num_inversion_steps`` is read here and
            the whole namespace is forwarded to ``image_distortion``.

    Returns:
        The result of ``pipe.forward_diffusion``.
    """
    distorted = image_distortion(image, seed, args)
    distorted = transform_img(distorted).unsqueeze(0).to(text_embeddings.dtype).to(device)
    image_latents = pipe.get_image_latents(distorted, sample=False)
    return pipe.forward_diffusion(
        latents=image_latents,
        text_embeddings=text_embeddings,
        guidance_scale=1,
        num_inference_steps=args.num_inversion_steps,
    )


def main(args):
    """Generate watermarked / non-watermarked image pairs and score detection.

    For each of the first ``args.num`` prompts, one image is generated from
    watermarked initial latents and one from random latents. Both are
    distorted, inverted back to latents and scored with
    ``watermark.eval_watermark``. A DreamSim distance between the two
    undistorted images and (optionally) CLIP scores are collected as well.
    After the loop, ROC statistics over the detection scores are printed and
    appended to ``results_test_full_s_non_consecutive.txt``.

    Side effects: creates the image output directories and
    ``args.output_path``, saves every generated image, prints per-sample and
    summary metrics, and appends to the results text file.

    Args:
        args: Parsed command-line namespace produced by this script's parser.

    Raises:
        ValueError: If ``args.num`` is smaller than 1 or larger than the
            number of available prompts (and, with ``--coco``, image entries).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # NOTE(review): the preprocess transform returned by dreamsim() is not
    # used; images are fed to the model through transform_img() instead.
    # Confirm that matches the input DreamSim expects.
    dreamsim_model, _dreamsim_preprocess = dreamsim(pretrained=True, device=device)

    scheduler = DPMSolverMultistepScheduler.from_pretrained(args.model_path, subfolder='scheduler')
    pipe = InversableStableDiffusionPipeline.from_pretrained(
        args.model_path,
        scheduler=scheduler,
        torch_dtype=torch.float16,
        revision='fp16',
    )
    pipe.safety_checker = None
    pipe = pipe.to(device)

    # reference model for CLIP Score
    if args.reference_model is not None:
        ref_tokenizer = open_clip.get_tokenizer(args.reference_model)
        ref_model, _, ref_clip_preprocess = open_clip.create_model_and_transforms(
            args.reference_model,
            pretrained=args.reference_model_pretrain,
            device=device,
        )

    # dataset
    if args.coco:
        with open(args.prompt_file) as f:
            coco_meta = json.load(f)
        image_files = coco_meta['images']
        dataset = coco_meta['annotations']
        prompt_key = 'caption'
        w_dir = f"./fid_outputs/coco/{args.run_name}/w"
        no_w_dir = f"./fid_outputs/coco/{args.run_name}/no_w"
        # Both lists are indexed with the same loop counter below.
        available = min(len(dataset), len(image_files))
    else:
        dataset, prompt_key = get_dataset(args)
        w_dir = f"imgs/{args.run_name}/w"
        no_w_dir = f"imgs/{args.run_name}/no_w"
        available = len(dataset)

    # Fail before any generation work instead of hitting an IndexError in the
    # middle of the loop (or an empty ROC computation when num is 0).
    if not 0 < args.num <= available:
        raise ValueError(
            f"--num must be between 1 and {available} "
            f"(the number of available prompts), got {args.num}"
        )

    os.makedirs(w_dir, exist_ok=True)
    os.makedirs(no_w_dir, exist_ok=True)

    # class for watermark
    watermark = ring_full(args.channel_copy, args.hw_copy, args.fpr, args.user_number, args)

    os.makedirs(args.output_path, exist_ok=True)

    # assume at the detection time, the original prompt is unknown
    tester_prompt = ''
    text_embeddings = pipe.get_text_embedding(tester_prompt)

    # detection scores returned by watermark.eval_watermark
    acc = []
    acc_no = []
    # CLIP scores
    clip_scores = []
    clip_scores_no = []
    # DreamSim distances between the watermarked and non-watermarked image
    d_list = []

    for i in tqdm(range(args.num)):
        seed = i + args.gen_seed
        current_prompt = dataset[i][prompt_key]

        # generate with watermark; the seed is set once per sample, before the
        # watermarked latents are created
        set_random_seed(seed)
        init_latents_w = watermark.create_watermark_and_return_w()
        image_w = _generate_image(pipe, current_prompt, init_latents_w, args)

        # distortion + inversion of the watermarked image
        reversed_latents_w = _invert_image(pipe, image_w, seed, text_embeddings, device, args)
        acc_metric = watermark.eval_watermark(reversed_latents_w)
        acc.append(acc_metric)

        # generate without watermark
        init_latents_no_w = pipe.get_random_latents()
        image_no_w = _generate_image(pipe, current_prompt, init_latents_no_w, args)

        # distortion + inversion of the non-watermarked image
        reversed_latents_no_w = _invert_image(pipe, image_no_w, seed, text_embeddings, device, args)

        # DreamSim distance between the two undistorted images
        d0 = dreamsim_model(
            transform_img(image_no_w).unsqueeze(0).to(device),
            transform_img(image_w).unsqueeze(0).to(device),
        )
        d = d0.item()
        d_list.append(d)
        print("dreamsim: ", d)

        acc_metric_no = watermark.eval_watermark(reversed_latents_no_w)
        acc_no.append(acc_metric_no)

        # save both images; COCO runs reuse the file name from the metadata
        if args.coco:
            image_file_name = image_files[i]['file_name']
            image_w.save(f'{w_dir}/{image_file_name}')
            image_no_w.save(f"{no_w_dir}/{image_file_name}")
        else:
            image_w.save(f"{w_dir}/{i}.jpg")
            image_no_w.save(f"{no_w_dir}/{i}.jpg")

        # CLIP Score (0 for both images when no reference model is configured)
        if args.reference_model is not None:
            scores = measure_similarity([image_w, image_no_w], current_prompt, ref_model,
                                        ref_clip_preprocess,
                                        ref_tokenizer, device)
            clip_score_w = scores[0].item()
            clip_score_no_w = scores[1].item()
        else:
            clip_score_w = 0
            clip_score_no_w = 0

        print(acc_metric, acc_metric_no, clip_score_w, clip_score_no_w)
        clip_scores.append(clip_score_w)
        clip_scores_no.append(clip_score_no_w)

    # roc: non-watermarked samples are label 0, watermarked samples label 1
    preds = acc_no + acc
    t_labels = [0] * len(acc_no) + [1] * len(acc)

    fpr, tpr, _ = metrics.roc_curve(t_labels, preds, pos_label=1)
    auc = metrics.auc(fpr, tpr)
    # Kept separate from the per-sample list `acc` instead of rebinding it.
    best_acc = np.max(1 - (fpr + (1 - tpr)) / 2)
    # TPR at the last ROC point whose FPR is still below 1%
    low = tpr[np.where(fpr < .01)[0][-1]]

    clip_mean_w = mean(clip_scores)
    clip_mean_no_w = mean(clip_scores_no)

    print(f'clip_score_mean: {clip_mean_no_w}')
    print(f'w_clip_score_mean: {clip_mean_w}')
    print(f'auc: {auc}, acc: {best_acc}, TPR@1%FPR: {low}')

    with open("results_test_full_s_non_consecutive.txt", 'a+') as f:
        f.write(str(args)+"\n")
        f.write(f"{args.run_name} {args.ring} {args.ring_threshold}  {args.gen_seed} {args.r_degree} {args.channel} {args.w_r_start} {args.w_r_end} {args.w_r_interval}\n")
        f.write(f'dreamsim: {mean(d_list)}\n')
        f.write(f'w_clip_score_mean: {clip_mean_w} clip_score_mean: {clip_mean_no_w}\n')
        f.write(f'auc: {auc} acc: {best_acc} TPR@1%FPR: {low}\n')
        f.write("\n")

    # tpr metric: the call is kept, but its results are not consumed anywhere
    # in this function at the moment.
    watermark.get_tpr()


def build_parser():
    """Build the command-line parser for this evaluation script.

    Kept as a function so the argument definitions can be imported and
    inspected without launching a run.

    Returns:
        argparse.ArgumentParser: Parser with all options of this script.
    """
    parser = argparse.ArgumentParser(description='Gaussian Shading')

    # run / generation settings
    parser.add_argument('--run_name', default='gs')
    parser.add_argument('--num', default=1000, type=int)
    parser.add_argument('--image_length', default=512, type=int)
    parser.add_argument('--guidance_scale', default=7.5, type=float)
    parser.add_argument('--num_inference_steps', default=50, type=int)
    # None means "use num_inference_steps"; resolved after parsing.
    parser.add_argument('--num_inversion_steps', default=None, type=int)
    parser.add_argument('--gen_seed', default=0, type=int)

    # watermark settings
    parser.add_argument('--channel_copy', default=1, type=int)
    parser.add_argument('--hw_copy', default=8, type=int)
    parser.add_argument('--n_split', default=2, type=int)
    parser.add_argument('--user_number', default=1000000, type=int)
    parser.add_argument('--fpr', default=0.000001, type=float)
    parser.add_argument('--output_path', default='./output/')
    parser.add_argument('--chacha', action='store_true', help='chacha20 for cipher')
    parser.add_argument('--no_encrypt', action='store_true', help='cipher')

    # models and data
    parser.add_argument('--reference_model', default="ViT-g-14")
    parser.add_argument('--reference_model_pretrain', default="laion2b_s12b_b42k")
    parser.add_argument('--dataset_path', default='Gustavosta/Stable-Diffusion-Prompts')
    parser.add_argument('--model_path', default='stabilityai/stable-diffusion-2-1-base')
    parser.add_argument('--ring_threshold', default=0.5, type=float)
    parser.add_argument('--coco', action='store_true')
    parser.add_argument('--prompt_file', default='./fid_outputs/coco/meta_data.json')
    parser.add_argument('--gt_folder', default='./fid_outputs/coco/ground_truth')

    # ring watermark settings
    parser.add_argument('--w_r_start', default=10, type=int)
    parser.add_argument('--w_r_end', default=0, type=int)
    parser.add_argument('--w_r_interval', default=1, type=int)
    parser.add_argument('--ring', action='store_true')
    parser.add_argument('--channel', default=2, type=int)
    parser.add_argument('--disk', action='store_true')
    parser.add_argument('--fft', action='store_true')

    # for image distortion
    parser.add_argument('--r_degree', default=None, type=int)
    parser.add_argument('--jpeg_ratio', default=None, type=int)
    parser.add_argument('--random_crop_ratio', default=None, type=float)
    parser.add_argument('--random_drop_ratio', default=None, type=float)
    parser.add_argument('--gaussian_blur_r', default=None, type=int)
    parser.add_argument('--median_blur_k', default=None, type=int)
    parser.add_argument('--resize_ratio', default=None, type=float)
    parser.add_argument('--gaussian_std', default=None, type=float)
    parser.add_argument('--sp_prob', default=None, type=float)
    parser.add_argument('--brightness_factor', default=None, type=float)
    parser.add_argument('--h_flip_prob', default=None, type=float)
    parser.add_argument('--v_flip_prob', default=None, type=float)
    parser.add_argument('--x_shift', default=None, type=float)
    parser.add_argument('--y_shift', default=None, type=float)
    parser.add_argument('--perspective_distortion', default=None, type=float)

    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    # Inversion uses the same number of steps as generation unless overridden.
    if args.num_inversion_steps is None:
        args.num_inversion_steps = args.num_inference_steps

    main(args)
