import sys
from typing import List

import numpy as np
import pyrallis
import torch
from PIL import Image
from diffusers.training_utils import set_seed
import os
sys.path.append(".")
sys.path.append("..")
from pathlib import Path
from transfer_model import TransferModel
from config import RunConfig, Range
from utils import latent_utils
from utils.latent_utils import load_latents_or_invert_images
import argparse
import glob
@pyrallis.wrap()
def pyrallis_main(cfg: RunConfig):
    """Alternative entry point that builds ``cfg`` from the command line via pyrallis.

    ``pyrallis.wrap()`` parses the CLI into a ``RunConfig`` and passes it here;
    the config is then handed to ``run`` exactly like the argparse namespace is.
    A function of this shape was previously also named ``main`` and was silently
    shadowed by the argparse-based ``main`` defined later in the module, so it
    could never be called; the distinct name keeps it reachable. The script's
    ``__main__`` entry point still uses the argparse-based ``main``.
    """
    run(cfg)

def _image_stem(image_path):
    """Return the part of an image file name before its first ``.``.

    ``Path(...).name`` strips the directory with the platform's own separator
    (``glob`` returns ``\\``-separated paths on Windows, which a plain
    ``split('/')`` would not handle). Everything after the FIRST dot is
    dropped, which keeps result file names identical to the ones produced
    before; note this means ``a.1.jpg`` and ``a.2.jpg`` both map to ``a``.
    """
    return Path(image_path).name.split('.')[0]


def run(args):
    """Run appearance transfer for every (structure image, appearance image) pair.

    Every file in ``args.struct_image_dir`` is combined with every file in
    ``args.app_image_dir``. Results are written to a sub-directory of
    ``args.save_dir`` whose name encodes the guidance hyper-parameters, as
    ``c_<struct name>_s_<app name>.png``. Pairs whose output file already
    exists are skipped, so an interrupted run can be resumed.

    Args:
        args: an ``argparse.Namespace`` from ``parse_args`` (or a ``RunConfig``)
            providing every attribute read below (image dirs, caption files,
            seed, guidance weights, mask/AdaIN/swap_kv options, ...).
    """
    set_seed(args.seed)
    model = TransferModel()

    # glob order is arbitrary: sort for a reproducible processing order, and
    # ignore anything that is not a regular file (e.g. sub-directories).
    style_images = sorted(p for p in glob.glob(f"{args.app_image_dir}/*") if os.path.isfile(p))
    struct_images = sorted(p for p in glob.glob(f"{args.struct_image_dir}/*") if os.path.isfile(p))

    if args.use_adain:
        adain = "_adain"
    elif args.use_masked_adain:
        adain = "_mask_adain"
    else:
        adain = ""
    swap_kv = "swap_kv" if args.swap_kv else ""
    image_save_dir = f"{args.save_dir}/{swap_kv}_{args.swap_guidance_scale}_app_{args.w_app}_struct_{args.w_struct}_global_{args.w_global}_{args.attention_guidance_type}_{args.cross_guidance}_pe_{args.pe_scale}_bg_{args.w_background}{adain}"
    os.makedirs(image_save_dir, exist_ok=True)

    i = 0  # number of images generated in this invocation (skipped pairs not counted)
    for struct_image in struct_images:
        for style_image in style_images:
            cfg = RunConfig(
                app_image_dir=args.app_image_dir,
                struct_image_dir=args.struct_image_dir,
                struct_caption_file=args.struct_caption_file,
                style_caption_file=args.style_caption_file,
                app_image_path=Path(style_image),
                struct_image_path=Path(struct_image),
                save_dir=args.save_dir,
                domain_name=args.domain_name,
                seed=args.seed,
                load_latents=False,
                cross_guidance=args.cross_guidance,
                swap_guidance_scale=args.swap_guidance_scale,
                feat_guidance_type=args.feat_guidance_type,
                cross_guidance_type=args.cross_guidance_type,
                use_adain=args.use_adain,
                use_masked_adain=args.use_masked_adain,
                attention_guidance_type=args.attention_guidance_type,
                w_app=args.w_app,
                w_global=args.w_global,
                w_struct=args.w_struct,
                w_background=args.w_background,
                pe_scale=args.pe_scale,
                bg_energy_scale=args.bg_energy_scale,
                mask_type=args.mask_type,
                swap_kv=args.swap_kv)
            # The model carries per-pair state: point it at this pair's config
            # and restart its step counter.
            model.config = cfg
            model.step = 0
            save_file = f"{image_save_dir}/c_{_image_stem(struct_image)}_s_{_image_stem(style_image)}.png"
            if os.path.exists(save_file):
                print(f"Image already exists!!!{save_file}")
                continue

            latents_app, latents_struct, noise_app, noise_struct = load_latents_or_invert_images(model=model, cfg=cfg)
            model.set_latents(latents_app, latents_struct)
            model.set_noise(noise_app, noise_struct)
            print("Running appearance transfer...")

            images = run_appearance_transfer(model=model, cfg=cfg)
            print("Done.")
            i += 1
            # Concatenate the returned images side by side, in reverse order.
            joined_images = np.concatenate(images[::-1], axis=1)
            Image.fromarray(joined_images).save(save_file)
            print(f'{i=}')
            print(f"image saved to {save_file}")


def run_appearance_transfer(model: TransferModel, cfg: RunConfig) -> List[Image.Image]:
    """Run the cross-image-attention pipeline for one structure/appearance pair.

    The initial latents and noises are obtained from
    ``latent_utils.get_init_latents_and_noises``, which reads back what the
    caller stored via ``model.set_latents(...)`` / ``model.set_noise(...)``.

    Masks: when ``cfg.mask_type == "sam"`` they are computed with
    ``model.sam.get_masks``; otherwise no masks are used. If mask computation
    raises, precomputed masks are loaded from ``./temp/masks`` instead.

    Args:
        model: the ``TransferModel`` whose latents/noises were already set.
        cfg: run configuration for this pair.

    Returns:
        The ``.images`` field of the pipeline output.
    """
    init_latents, init_zs = latent_utils.get_init_latents_and_noises(model=model, cfg=cfg)
    model.pipe.scheduler.set_timesteps(cfg.num_timesteps)
    model.enable_edit = True  # Activate our cross-image attention layers
    # Cross-image attention spans from the earliest start to the latest end of
    # the two per-resolution ranges.
    start_step = min(cfg.cross_attn_32_range.start, cfg.cross_attn_64_range.start)
    end_step = max(cfg.cross_attn_32_range.end, cfg.cross_attn_64_range.end)
    with torch.no_grad():
        try:
            if cfg.mask_type == "sam":
                mask_struct, mask_style = model.sam.get_masks(config=cfg)
            else:
                mask_struct, mask_style = None, None
        except Exception as err:
            # Best-effort fallback to precomputed masks. Catching Exception (not
            # a bare except) lets KeyboardInterrupt/SystemExit propagate, and the
            # failure is reported instead of being silently swallowed.
            print(f"Mask computation failed ({err!r}); loading precomputed masks from ./temp/masks")
            struct_mask_name = cfg.struct_image_path.name.split('_s_')[0].replace('c_', '') + '.jpg'
            style_mask_name = cfg.app_image_path.name.split('_s_')[-1]
            mask_struct = Image.open(f"./temp/masks/content_{struct_mask_name}")
            mask_style = Image.open(f"./temp/masks/style_{style_mask_name}")
    torch.cuda.empty_cache()
    images = model.pipe(
        prompt=cfg.prompt,
        latents=init_latents,
        guidance_scale=3.5,  # hard-coded; presumably the classifier-free guidance scale — confirm
        num_inference_steps=cfg.num_timesteps,
        swap_guidance_scale=cfg.swap_guidance_scale,
        callback=model.get_adain_callback(),
        eta=1,
        zs=init_zs,
        generator=torch.Generator('cuda').manual_seed(cfg.seed),
        cross_image_attention_range=Range(start=start_step, end=end_step),
        model=model,
        config=cfg,
        mask_style=mask_style,
        mask_struct=mask_struct,
    ).images

    return images

def parse_args(argv=None):
    """Parse the command-line options of the batch appearance-transfer script.

    Args:
        argv: optional list of argument strings; ``None`` (the default) parses
            ``sys.argv[1:]`` as before. Passing a list makes the parser usable
            from tests and other scripts.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a editing script.")
    parser.add_argument(
        "--struct_image_dir",
        type=str,
        required=True,
        help="directory containing the structure images",
    )
    parser.add_argument(
        "--app_image_dir",
        type=str,
        required=True,
        help="directory containing the appearance images",
    )
    parser.add_argument(
        "--struct_caption_file",
        type=str,
        help="caption file for the structure images",
    )
    parser.add_argument(
        "--style_caption_file",
        type=str,
        help="caption file for the appearance (style) images",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        # The misspelled default is kept as-is so existing result
        # directories continue to be found.
        default="test_resluts",
        help="root directory where result images are written",
    )
    parser.add_argument(
        "--save_image_path",
        type=str,
        default=None,
        help="save image path (not used by run())",
    )
    parser.add_argument(
        "--domain_name",
        type=str,
        nargs="+",
        default=[],
        help="domain name",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="random seed",
    )
    parser.add_argument(
        "--bg_energy_scale",
        # float, not int: the default is 3e4 and int("3e4") raises, so
        # scientific-notation values from the command line used to be rejected.
        type=float,
        default=3e4,
        help="background guidance energy scale",
    )
    parser.add_argument(
        "--swap_guidance_scale",
        type=float,
        default=0,
        help="swap guidance",
    )
    parser.add_argument(
        "--w_global",
        type=float,
        default=0.6,
        help="global guidance weight",
    )
    parser.add_argument(
        "--w_struct",
        type=float,
        default=0.9,
        help="struct feature guidance weight",
    )
    parser.add_argument(
        "--w_app",
        type=float,
        default=0.3,
        help="appearance feature guidance",
    )
    parser.add_argument(
        "--w_background",
        type=float,
        default=3,
        help="background feature guidance",
    )
    parser.add_argument(
        "--pe_scale",
        type=float,
        default=10.0,
        help="position embedding scale",
    )
    parser.add_argument(
        "--cross_guidance",
        type=float,
        default=0.1,
        help="cross-attention guidance",
    )
    parser.add_argument(
        "--feat_guidance_type",
        type=str,
        default="app_struct",
        choices=["app", "struct", "app_new_affine", "app_struct"],
        help="guidance type",
    )
    parser.add_argument(
        "--cross_guidance_type",
        type=str,
        default="l2",
        choices=["l2", "cos"],
        help="guidance type",
    )
    parser.add_argument(
        "--attention_guidance_type",
        type=str,
        default="cross",
        choices=["self", "cross"],
        help="attention guidance type",
    )
    parser.add_argument(
        "--mask_type",
        type=str,
        default="sam",
        choices=["sam", "kmeans"],
        help="get mask type",
    )
    parser.add_argument(
        "--use_adain",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--use_masked_adain",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        # swap_kv is ON by default; giving the flag turns it OFF.
        # "--no_swap_kv" is an alias that states that meaning explicitly.
        "--swap_kv",
        "--no_swap_kv",
        dest="swap_kv",
        default=True,
        action="store_false",
        help="disable key/value swapping (enabled by default)",
    )
    args = parser.parse_args(argv)
    return args



def main():
    """Script entry point: parse command-line arguments and run the batch transfer."""
    # Path and Image are already imported at module level; the previous
    # function-local re-imports were unused and have been dropped.
    args = parse_args()

    run(args=args)


if __name__ == '__main__':
    main()
