from paint import atten
from PIL import Image,ImageDraw, ImageFont
from pathlib import Path
import math
import os
import math
from typing import Callable, Dict, List, Optional, Tuple,Union
import argparse
import time
import json
from diffusers import StableDiffusionPipeline
import ptp_utils_2
from ptp_utils_2 import AttentionStore
def get_args(argv=None):
    """Parse the command-line options for this generation script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic use).

    Returns:
        argparse.Namespace with attributes ``num_affected_steps``, ``fix_lr``,
        ``attend``, ``objects``, ``weight_function``, ``weight_negative`` and
        ``color_context``.

    Raises:
        SystemExit: on invalid arguments, or when ``--color_context`` is
            missing (the script reads it unconditionally, so it cannot run
            without it).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description=(
            "Generate paint-with-words images for a range of seeds and save "
            "them together with an annotated summary grid."
        ),
    )
    parser.add_argument(
        "-num_affected",
        "--num_affected_steps",
        metavar="N",
        type=int,
        default=12,
        help="number of affected steps, forwarded to the painter as `aff`\n"
        "(default: %(default)s)",
    )
    parser.add_argument(
        "-fix_lr",
        "--fix_lr",
        metavar="N",
        type=float,
        default=20.0,
        help="fixed learning rate, forwarded to the painter as `fix_lr`\n"
        "(default: %(default)s)",
    )
    parser.add_argument(
        "-num_attend",
        "--attend",
        metavar="N",
        type=int,
        default=25,
        help="attend count, forwarded to the painter as `attend`\n"
        "(default: %(default)s)",
    )
    parser.add_argument(
        "-num_obj",
        "--objects",
        metavar="N",
        type=int,
        default=2,
        help="number of objects; currently not used to control generation\n"
        "(default: %(default)s)",
    )
    parser.add_argument(
        "-w_f",
        "--weight_function",
        metavar="N",
        type=float,
        default=0.4,
        help="scale factor of the attention weight function w_f\n"
        "(default: %(default)s)",
    )
    parser.add_argument(
        "-w_neg",
        "--weight_negative",
        metavar="N",
        type=float,
        default=-1000.0,
        help="negative weight, forwarded to the painter as `neg`\n"
        "(default: %(default)s)",
    )
    parser.add_argument(
        "-color_context",
        "--color_context",
        metavar="PATH",
        type=str,
        # Required: main() opens this file unconditionally; failing here gives a
        # usage message instead of a TypeError after the model has loaded.
        required=True,
        help="path to a JSON run-config file; its first entry supplies\n"
        "prompt, color_context, token, input_mask, tuned_model and edit_list",
    )
    return parser.parse_args(argv)
def _image_grid(imgs, rows=2, cols=2):
    """Paste ``imgs`` row-major into a ``rows`` x ``cols`` RGB canvas.

    Every tile is sized like ``imgs[0]``; callers must size the grid so that
    ``rows * cols >= len(imgs)`` or the surplus images land off-canvas.
    """
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def _grid_shape(count):
    """Return ``(rows, cols)`` of the near-square grid that fits ``count`` tiles (49 -> 7x7)."""
    cols = max(1, math.ceil(math.sqrt(count)))
    rows = max(1, math.ceil(count / cols))
    return rows, cols


def _load_font(size, path="jetbrains-mono-bold.ttf"):
    """Return the TrueType font at ``path``, or Pillow's built-in font if it cannot be opened."""
    try:
        return ImageFont.truetype(path, size=size)
    except OSError:
        # Missing font file should not throw away a finished generation run.
        return ImageFont.load_default()


def _color_context_labels(color_context):
    """Map each color-context value ``"label,...,weight"`` to ``{label: weight}``.

    A value without a comma is kept as the label with an empty weight instead
    of raising ``IndexError``.
    """
    labels = {}
    for value in color_context.values():
        parts = value.split(",")
        if len(parts) >= 2:
            labels[parts[0]] = parts[-1]
        else:
            labels[value] = ""
    return labels


def main():
    """Run paint-with-words generation over a fixed seed range and save the results.

    Reads the JSON run config named by ``--color_context`` (only its first entry
    is used), renders one call per seed in ``range(701, 750)``, saves each first
    image as ``<seed>.png`` in a timestamped directory, and finally writes
    ``all_images.png``: a grid of every returned image annotated with the run
    parameters.
    """
    import warnings

    args = get_args()
    if args.color_context is None:
        raise SystemExit("--color_context must point to a JSON run-config file")

    # Read the cheap config before loading the expensive pipeline so a bad
    # path or malformed file fails fast.
    with open(args.color_context, "r") as f:
        config = json.load(f)[0]
    prompt = config["prompt"]
    raw_color_context = config["color_context"]
    token = config["token"]
    input_img = Path(config["input_mask"])
    tuned_model = config["tuned_model"]
    edit_list = config["edit_list"]
    # JSON object keys are strings: turn "a,b,c" keys into int tuples (a, b, c).
    # NOTE(review): presumably RGB colors of the mask image — confirm.
    color_context = {
        tuple(map(int, key.split(","))): value
        for key, value in raw_color_context.items()
    }

    model_id = ""  # NOTE(review): placeholder — set a model path / Hub id before running.
    pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cuda")
    # Silences every Python warning for the rest of the process.
    warnings.simplefilter("ignore")

    model_options = {"v1.4": "", "v2.1": ""}
    hf_model_path = model_options["v2.1"]

    def weight_function(w, sigma, qk):
        # CLI scale * w * log(1 + sigma^2), multiplied by the spread (std) of qk.
        return (args.weight_function * w * math.log(1 + sigma**2)) * qk.std()

    with Image.open(input_img) as mask:
        color_map_image = mask.convert("RGB")

    controller = AttentionStore()
    ptp_utils_2.register_attention_control(pipe, controller)

    # One timestamp for the whole run (previously computed twice).
    time_num = (
        time.strftime("%m-%d %H:%M", time.localtime())
        .replace("-", "")
        .replace(":", "")
    )
    save_dir = ""
    run_dir = os.path.join(save_dir, tuned_model, "_".join(prompt.split())) + "_{}".format(time_num)
    os.makedirs(run_dir, exist_ok=True)

    images = []
    for seed in range(701, 750):
        # NOTE(review): paint_with_words_recurent_attend_always is not defined or
        # imported in this file (only `paint.atten` is imported) — confirm its
        # source, otherwise this raises NameError.
        img = paint_with_words_recurent_attend_always(
            seed=seed,
            hf_model_path=hf_model_path,
            color_context=color_context,
            color_map_image=color_map_image,
            input_prompt=prompt,
            num_inference_steps=50,
            guidance_scale=7.5,
            device="cuda:0",
            weight_function=weight_function,
            pipe=pipe,
            aff=args.num_affected_steps,
            neg=args.weight_negative,
            token=token,
            controller=controller,
            attend=args.attend,
            fix_lr=args.fix_lr,
            edit_list=edit_list,
        )
        img[0].save(os.path.join(run_dir, f"{seed}.png"))
        images += img

    # Size the grid from the actual image count (49 -> 7x7) so nothing is clipped.
    rows, cols = _grid_shape(len(images))
    grid_image = _image_grid(imgs=images, rows=rows, cols=cols)

    print_param = {
        "lr": args.fix_lr,
        "affected_steps": args.num_affected_steps,
        "w_function": args.weight_function,
        "prompt": prompt,
        "neg": args.weight_negative,
        "seg": input_img,
        "attend": args.attend,
        "token": token,
    }
    print_param.update(_color_context_labels(color_context))

    draw = ImageDraw.Draw(grid_image)
    font = _load_font(20)
    line_height = 25
    for row, (key, value) in enumerate(print_param.items()):
        draw.text((0, row * line_height), f"{key}: {value}", font=font, fill=(0, 0, 0))
    grid_image.save(os.path.join(run_dir, "all_images.png"))


if __name__ == "__main__":
    main()