# Invert each (distorted) image once, then verify the recovered latent with several watermark methods.

"""
Script to run the imprinting forgery attack.
"""
import time
from PIL import Image
import torch
import pickle
import os
import argparse

from numpy import mean

from utils.wm.wm_utils import WmProviders
from utils.wm.gs_official_provider import parser as gs_ref_parser
from utils.wm.gs_optimised_provider import parser as gs_opt_parser
from utils.wm.prc_provider import parser as prc_parser
from utils.wm.gs_chroma_provider import parser as gs_chroma_parser
from utils.wm.messages_long import MESSAGES as messages_long
from utils import image_utils
from utils.prompt_utils import PROMPTS_SD_LIST, get_huggingface_list

from datetime import datetime
from utils.utils import *
from utils.wm.messages_long import MESSAGES

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")


# Fixed seed for reproducible runs (set_random_seed presumably comes from the
# `utils.utils` star import).
_SEED = 1337
set_random_seed(_SEED)




# args
parser = argparse.ArgumentParser(description="removal_mass_experiment", parents=[gs_ref_parser, gs_opt_parser,gs_chroma_parser,prc_parser])  # add your parser here  <---------------------------------------------------------------------------------------------- HERE
# wm_type
parser.add_argument("--wm_type", nargs="*", type=str, default="GSreference")  # add your new wm_type here 

parser.add_argument("--prompt_dataset_split", type=str, default="test")

# model
parser.add_argument("--model_id", type=str, default="stabilityai/stable-diffusion-2-1-base")
parser.add_argument("--scheduler", type=str, default="DDIM")

# experiments settings
parser.add_argument("--resolution", type=int, default=512)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--num_inversion_steps", type=int, default=50)
parser.add_argument("--guidance_scale", type=float, default=7.5)
parser.add_argument("--filepath",type=str,default="./data/noisy/")
parser.add_argument("--filepath_imgs",type=str,default="/data/thietj22/images/")
parser.add_argument("--exact_inversion",action=argparse.BooleanOptionalAction,default=False)

parser.add_argument("--distortion",nargs="*",type=str,default=None)
parser.add_argument("--distortion_value",nargs="*",type=float,default=None)

parser.add_argument("--amount",type=int,default=10)

args, unknown_args = parser.parse_known_args()


print(args)

assert(len(args.distortion) ==len(args.distortion_value))

MODELID ="stabilityai/stable-diffusion-2-1-base"
SCHEDULER ="DDIM"

if args.exact_inversion: #according to prc code
    from utils.pipe import prc_inversion    

    invpipe = prc_inversion.stable_diffusion_pipe()

else: #null text inverison as usual
    from utils.pipe import pipe_utils

    pipe_provider = pipe_utils.get_pipe_provider(pretrained_model_name_or_path=MODELID,
                                             schedulers_name=SCHEDULER,
                                             resolution=args.resolution,
                                             device=DEVICE,
                                             eager_loading=True,
                                             disable_tqdm=True)
if args.wm_type =="PRC":
    folder_path = args.filepath_imgs+"PRC/"
elif args.wm_type =="GSppReference":
    folder_path= args.filepath_imgs+"GSppReference/"
else:
    folder_path= args.filepath_imgs+"GSreference/"


timestamp = datetime.now().strftime("%y%m%d")
filename = args.filepath+"noise"+timestamp

# ---- main experiment loop ----
# For each watermarked image and each (distortion, value) pair: distort the
# image, invert it once to recover zT, then verify that single zT with every
# requested watermark method, logging one row per method.

# --wm_type may be a list (nargs="*") or a plain-string default; normalise so we
# iterate over method names rather than over the characters of a string.
wm_methods = [args.wm_type] if isinstance(args.wm_type, str) else list(args.wm_type)

for i in range(args.amount):
    # The source image depends only on i, so read it once per image. copy()
    # loads the pixel data into memory so the file can be closed immediately.
    with Image.open(os.path.join(folder_path, f"{i}.png")) as opened:
        img = opened.copy()

    for distortion, distvalue in zip(args.distortion, args.distortion_value):
        distargs = {distortion: distvalue}
        # Distort a fresh copy so an in-place distortion cannot carry over to
        # the next (distortion, value) pair (the original re-read the file).
        distorted_images = image_utils.distort_images(img.copy(), **distargs)

        # Invert the *distorted* image to get the latent zT. The exact-inversion
        # branch previously inverted the clean `img`, so its results ignored the
        # distortion. NOTE(review): assumes distort_images returns a single image,
        # as implied by images=[distorted_images] below — confirm.
        start = time.perf_counter()
        if args.exact_inversion:
            zT_retrieved = prc_inversion.exact_inversion(
                image=distorted_images,
                guidance_scale=3.0,  # hard-coded instead of args.guidance_scale (default 7.5)
                test_num_inference_steps=args.num_inference_steps,
                pipe=invpipe,
            )
        else:
            results_inversion = pipe_provider.invert_images(
                images=distorted_images,
                num_inference_steps=args.num_inversion_steps,
            )
            zT_retrieved = results_inversion["zT_torch"]
        time_inversion = time.perf_counter() - start

        # ------------------->>>> verify <<<<-------------------
        for method in wm_methods:
            # Per-method copy of the CLI options: overrides for one method must
            # not leak into providers built for later methods (the original
            # mutated the shared `args`).
            provider_kwargs = dict(vars(args))
            if method in ("GSoptimised1", "GSoptimised8"):
                # Both aliases currently map to identical settings.
                # NOTE(review): presumably they were meant to differ — confirm.
                provider_kwargs["randomness_check"] = True
                provider_kwargs["db_lookup"] = False
                wmtype = "GSoptimised"
            else:
                wmtype = method

            print(i, method)

            # NOTE(review): latent_shape is fixed to (1, 4, 64, 64) and does not
            # follow args.resolution — confirm for non-512 runs.
            wm_provider = WmProviders[wmtype].value(
                latent_shape=(1, 4, 64, 64),
                filepath_keys=os.path.join(folder_path, "keys.pkl"),
                **provider_kwargs,
                offset=i,
            )

            # Verify on a detached copy so zT_retrieved, which is reused by the
            # next method, is never handed to a provider directly.
            zT_copy = zT_retrieved.detach().clone()
            start = time.perf_counter()
            accuracy_results = wm_provider.get_accuracies(zT_copy, images=[distorted_images])
            time_verification = time.perf_counter() - start

            bit_accuracy = accuracy_results["bit_accuracies"]
            log_to_file(
                filename,
                [method, bit_accuracy, distortion, distvalue, time_inversion, time_verification],
            )

 
            