# Noiseless watermark verification experiments (top / random / misses) and batch image creation (batch_create).

import argparse
import os
import pickle
import sys
import time
from datetime import datetime

import torch
from PIL import Image

from utils.wm.wm_utils import WmProviders
from utils.wm.gs_official_provider import parser as gs_ref_parser
from utils.wm.gs_optimised_provider import parser as gs_opt_parser
from utils.wm.prc_provider import parser as prc_parser
from utils.wm.gs_chroma_provider import parser as gs_chroma_parser
from utils.wm.messages_long import MESSAGES as messages_long
from utils.prompt_utils import get_huggingface_list
# Here import your new parser  <---------------------------------------------------------------------------------------------- HERE


from utils.pipe import pipe_utils

from utils.utils import *
from utils.wm.messages_long import MESSAGES

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# set seeds
set_random_seed(1337)


# args
# The parent parsers contribute the watermark-specific options of every provider.
parser = argparse.ArgumentParser(description="create and noiseless", parents=[gs_ref_parser, gs_opt_parser, gs_chroma_parser, prc_parser])
# wm_type
parser.add_argument("--wm_type", type=str, default="GSreference", choices=["GSreference","PRC","GSoptimised","GSchroma","GSppReference"])  # add your new wm_type here

# experiments settings
parser.add_argument("--batch_size", type=int, default=1, help="1 is recommended for verification experiments. Others are untested.")
parser.add_argument("--resolution", type=int, default=512)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--num_inversion_steps", type=int, default=50)
parser.add_argument("--guidance_scale", type=float, default=7.5)
parser.add_argument("--experiment", type=str, default="top", choices=["top","misses","random","batch_create"], help="what to do")
parser.add_argument("--filepath", type=str, default="./data/", help="where to store the results")
parser.add_argument("--filepath_imgs", type=str, default="./images/")

parser.add_argument("--amount", type=int, default=10, help="how many images in total?")
parser.add_argument("--step_size", default=10, type=int, help="how many to leave out at every step if experiment random")
parser.add_argument("--verbose", action="store_true", default=False)
parser.add_argument("--exact_inversion", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--prompt_dataset_split", type=str, default="test", choices=["train","test"], help="which split of the prompts dataset? test does not contain 10k images")

args, unknown_args = parser.parse_known_args()
if unknown_args:
    # parse_known_args() silently drops unrecognised flags (e.g. a typo such as
    # --batchsize would otherwise run with the default); surface them instead.
    print(f"WARNING: ignoring unrecognised arguments: {unknown_args}", file=sys.stderr)


log_to_out(args.verbose, args)


# ------------------->>>> get prompts <<<<-----------------------
all_prompts = get_huggingface_list(amount=args.amount, split=args.prompt_dataset_split)
# Keep at most the first args.amount prompts (args.amount is parsed as int).
prompts = all_prompts[: args.amount]
MODELID = "stabilityai/stable-diffusion-2-1-base"
SCHEDULER = "DDIM"

# Only one pipeline is built per run: `pipe_provider` for the usual null-text
# inversion path, or `invpipe` (plus the `prc_inversion` module) for exact
# inversion following the PRC code. Later code picks based on args.exact_inversion.
if not args.exact_inversion:
    from utils.pipe import pipe_utils

    pipe_provider = pipe_utils.get_pipe_provider(
        pretrained_model_name_or_path=MODELID,
        schedulers_name=SCHEDULER,
        resolution=args.resolution,
        device=DEVICE,
        eager_loading=True,
        disable_tqdm=True,
    )
else:
    from utils.pipe import prc_inversion

    invpipe = prc_inversion.stable_diffusion_pipe()

if args.experiment == "batch_create":
    # Generate args.amount watermarked images in batches of BATCH_SIZE and save
    # them as <filepath_imgs><method>/<index>.png for the verification experiments.
    BATCH_SIZE = args.batch_size
    # GSoptimised images are generated with the GSreference provider and stored
    # in the GSreference folder; every other wm_type uses its own name.
    method = "GSreference" if args.wm_type == "GSoptimised" else args.wm_type

    # create img folder
    folder_path = args.filepath_imgs + method
    os.makedirs(folder_path, exist_ok=True)
    folder_path = folder_path + "/"
    if method in ["PRC", "GSppReference"]:  # create key and store it
        prc = WmProviders[method].value(latent_shape=(BATCH_SIZE, 4, 64, 64),
                                        **vars(args),
                                        offset=0)
        prc.dump_keys(folder_path + "keys.pkl")

    # create images; a trailing partial batch (args.amount % BATCH_SIZE images)
    # is not generated.
    repeats = args.amount // BATCH_SIZE
    for rep in range(repeats):
        offset = rep * BATCH_SIZE
        log_to_out(args.verbose, [rep, offset, len(prompts)])
        wm_provider = WmProviders[method].value(latent_shape=(BATCH_SIZE, 4, 64, 64),
                                                filepath_keys=folder_path + "keys.pkl",
                                                **vars(args),
                                                offset=offset)

        wm_initial_results = wm_provider.get_wm_latents()
        wm_zT = wm_initial_results["zT_torch"]  # shape: (<batch_size>, 4, 64, 64)
        if args.exact_inversion:
            # Only the PRC pipe is loaded in this mode, so generate with it
            # (generation is basically the same call as the normal pipe).
            generated_images_PIL = []
            for i in range(BATCH_SIZE):
                # prompts[offset + i], not prompts[i]: otherwise every batch
                # would reuse the first BATCH_SIZE prompts.
                # NOTE(review): image_num is the index within the batch, as
                # before — confirm generate() does not expect the global index.
                img, _, _ = prc_inversion.generate(image_num=i, pipe=invpipe,
                                                   init_latents=wm_zT[i],
                                                   prompt=prompts[offset + i])
                generated_images_PIL.append(img)
        else:
            results_generation = pipe_provider.generate(prompts=prompts[offset:offset + BATCH_SIZE],
                                                        num_inference_steps=args.num_inference_steps,
                                                        guidance_scale=args.guidance_scale,
                                                        latents=wm_zT)
            generated_images_PIL = results_generation["images_PIL"]  # list of PIL images

        # save under the global image index so files line up with prompt/offset indices
        for i, img in enumerate(generated_images_PIL):
            img.save(folder_path + str(offset + i) + ".png")
    # sys.exit instead of the site-provided exit(), which is meant for the
    # interactive interpreter and is absent when Python runs with -S.
    sys.exit(0)


def _stored_images_folder(wm_type):
    """Return the folder (ending in "/") where batch_create saved images for wm_type.

    batch_create generates GSoptimised images with the GSreference provider and
    saves them under "GSreference/"; every other wm_type is saved under a folder
    named after itself. Reading must use the same mapping, otherwise e.g.
    GSchroma would be verified against GSreference images.
    """
    method = "GSreference" if wm_type == "GSoptimised" else wm_type
    return args.filepath_imgs + method + "/"


def _invert(images):
    """Invert image(s) back to an initial latent and time the inversion.

    Args:
        images: a single PIL image. Without --exact_inversion a list of PIL
            images is also accepted; it is forwarded as-is to
            pipe_provider.invert_images.

    Returns:
        (zT_retrieved, seconds spent inverting)
    """
    start = time.perf_counter()
    if args.exact_inversion:
        # Exact inversion following the PRC code; called with a single image.
        # NOTE(review): steps come from args.num_inference_steps rather than
        # args.num_inversion_steps (both default to 50) — confirm intended.
        zT_retrieved = prc_inversion.exact_inversion(
            image=images,
            guidance_scale=3.0,  # vs args.guidance_scale = 7.5
            test_num_inference_steps=args.num_inference_steps,
            pipe=invpipe,
        )
    else:
        results_inversion = pipe_provider.invert_images(images=images,
                                                        num_inference_steps=args.num_inversion_steps)
        zT_retrieved = results_inversion["zT_torch"]
    return zT_retrieved, time.perf_counter() - start


def _verify_and_log(wm_provider, zT_retrieved, images, time_inversion, timestamp):
    """Score recovered latents with wm_provider and append one row to the results file.

    The row is [wm_type, bit_accuracies, inversion seconds, verification seconds]
    and goes to <filepath><experiment><wm_type><timestamp>.
    """
    # ------------------->>>> verify <<<<-------------------
    start = time.perf_counter()
    accuracy_results = wm_provider.get_accuracies(zT_retrieved, images=images)
    time_verification = time.perf_counter() - start

    wm_type = wm_provider.get_wm_type()
    filenameresults = args.filepath + args.experiment + wm_type + str(timestamp)
    log_to_file(filepath=filenameresults, log_list=[wm_type,
                                                    accuracy_results["bit_accuracies"],
                                                    time_inversion,
                                                    time_verification])


def _verify_stored_images(indices, timestamp):
    """Verify, one by one (batch size 1), the images batch_create saved at the given indices."""
    folder_path = _stored_images_folder(args.wm_type)
    for i in indices:
        log_to_out(args.verbose, [args.wm_type, i, args.amount])
        # offset=i matches the offset batch_create used for image i when run with batch size 1
        wm_provider = WmProviders[args.wm_type].value(latent_shape=(1, 4, 64, 64),
                                                      filepath_keys=folder_path + "keys.pkl",
                                                      **vars(args),
                                                      offset=i)
        img = Image.open(folder_path + str(i) + ".png")
        # Read the pixels now so Pillow releases the file handle of this
        # single-frame PNG instead of holding it open lazily.
        img.load()
        zT_retrieved, time_inversion = _invert(img)
        _verify_and_log(wm_provider, zT_retrieved, [img], time_inversion, timestamp)


def _run_misses(timestamp):
    """Generate images from plain (unwatermarked) Gaussian latents and log how the detector scores them."""
    if args.exact_inversion and args.batch_size != 1:
        # exact inversion is applied to a single image per batch below
        raise ValueError("the misses experiment with --exact_inversion requires --batch_size 1")

    # how many batches fit into args.amount (a trailing partial batch is skipped)
    repeats = args.amount // args.batch_size
    for rep in range(repeats):
        log_to_out(args.verbose, [rep, repeats])
        offset = rep * args.batch_size
        batch_prompts = prompts[offset:offset + args.batch_size]

        # generate normal random latents (no watermark)
        latents = torch.randn(size=(args.batch_size, 4, 64, 64))
        if args.exact_inversion:
            # pipe_provider is only built without --exact_inversion, so generate
            # with the PRC pipe, exactly as batch_create does.
            # NOTE(review): latents are created on CPU as float32 — confirm
            # prc_inversion.generate moves them to the pipe's device/dtype.
            generated_images_PIL = []
            for j in range(args.batch_size):
                img, _, _ = prc_inversion.generate(image_num=j, pipe=invpipe,
                                                   init_latents=latents[j],
                                                   prompt=batch_prompts[j])
                generated_images_PIL.append(img)
        else:
            results_generation = pipe_provider.generate(prompts=batch_prompts,
                                                        num_inference_steps=args.num_inference_steps,
                                                        guidance_scale=args.guidance_scale,
                                                        latents=latents)
            generated_images_PIL = results_generation["images_PIL"]  # list of PIL images

        wm_provider = WmProviders[args.wm_type].value(latent_shape=(args.batch_size, 4, 64, 64),
                                                      **vars(args))

        # exact inversion takes the single image; null-text inversion takes the whole list
        to_invert = generated_images_PIL[0] if args.exact_inversion else generated_images_PIL
        zT_retrieved, time_inversion = _invert(to_invert)
        _verify_and_log(wm_provider, zT_retrieved, generated_images_PIL, time_inversion, timestamp)


if args.experiment == "top":
    _verify_stored_images(range(args.amount), datetime.now().strftime("%y%m%d_%H%M%S"))
elif args.experiment == "random":
    # amount is the total number of stored images (or where to stop) and
    # step_size the stride between verified images.
    # We ran amount=10000, step_size=10. This experiment uses batch_size = 1.
    _verify_stored_images(range(0, args.amount, args.step_size),
                          datetime.now().strftime("%y%m%d_%H%M%S"))
elif args.experiment == "misses":
    _run_misses(datetime.now().strftime("%y%m%d_%H%M%S"))
else:
    raise NotImplementedError
