from ast import mod
import sys
from typing import List

import numpy as np
import pyrallis
import torch
from PIL import Image
from diffusers.training_utils import set_seed
import os
sys.path.append(".")
sys.path.append("..")

from transfer_model import TransferModel
from config import RunConfig, Range
from utils import latent_utils
from utils.latent_utils import load_latents_or_invert_images
import argparse

@pyrallis.wrap()
def main(cfg: RunConfig):
    """Forward the ``RunConfig`` supplied through ``pyrallis.wrap`` to ``run``.

    The images returned by ``run`` are discarded; this function returns None.

    NOTE(review): a second ``main`` defined further down in this file rebinds
    the name, so the ``__main__`` guard at the bottom does not call this one.
    """
    run(cfg=cfg)


def run(cfg: RunConfig) -> List[Image.Image]:
    """Run the full appearance-transfer pipeline for one configuration.

    Steps, in order: dump ``cfg`` to ``cfg.output_path / 'config.yaml'``,
    seed the RNGs with ``cfg.seed``, build a ``TransferModel``, load or invert
    the latents/noise for the appearance and structure images, register them
    on the model, and finally call ``run_appearance_transfer``.

    Args:
        cfg: Run configuration; ``output_path`` and ``seed`` are read here,
            everything else is consumed by the helpers it is passed to.

    Returns:
        The list of images produced by ``run_appearance_transfer``.
    """
    # Use a context manager so the config file is flushed and closed right
    # away instead of leaking the handle until garbage collection.
    with open(cfg.output_path / 'config.yaml', 'w') as config_file:
        pyrallis.dump(cfg, config_file)
    set_seed(cfg.seed)
    model = TransferModel(cfg)
    latents_app, latents_struct, noise_app, noise_struct = load_latents_or_invert_images(model=model, cfg=cfg)
    model.set_latents(latents_app, latents_struct)
    model.set_noise(noise_app, noise_struct)
    print("Running appearance transfer...")
    images = run_appearance_transfer(model=model, cfg=cfg)
    print("Done.")
    return images


def run_appearance_transfer(model: TransferModel, cfg: RunConfig) -> List[Image.Image]:
    """Run the diffusion pipeline on the prepared model and save its outputs.

    The first three returned images are written to ``cfg.output_path`` as
    ``out_transfer``, ``out_style`` and ``out_struct`` (in that index order),
    followed by a side-by-side ``out_joined`` image built from the reversed
    image list. All file names carry the seed.

    Args:
        model: Transfer model whose ``pipe`` is invoked. Its ``sam`` attribute
            is deleted by this function once the masks have been computed
            (or skipped).
        cfg: Run configuration (prompt, timesteps, attention ranges, seed,
            mask type, output path, ...).

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    init_latents, init_zs = latent_utils.get_init_latents_and_noises(model=model, cfg=cfg)
    model.pipe.scheduler.set_timesteps(cfg.num_timesteps)
    model.enable_edit = True

    # The cross-image attention window is the union of the 32- and 64-resolution ranges.
    attn_ranges = (cfg.cross_attn_32_range, cfg.cross_attn_64_range)
    first_step = min(attn_ranges[0].start, attn_ranges[1].start)
    last_step = max(attn_ranges[0].end, attn_ranges[1].end)

    # Masks are only computed for the "sam" mask type; otherwise both stay None.
    mask_struct = mask_style = None
    if cfg.mask_type == "sam":
        mask_struct, mask_style = model.sam.get_masks(config=cfg)
    del model.sam

    pipe_kwargs = dict(
        prompt=cfg.prompt,
        latents=init_latents,
        guidance_scale=3.5,
        num_inference_steps=cfg.num_timesteps,
        swap_guidance_scale=cfg.swap_guidance_scale,
        callback=model.get_adain_callback(),
        eta=1,
        zs=init_zs,
        generator=torch.Generator('cuda').manual_seed(cfg.seed),
        cross_image_attention_range=Range(start=first_step, end=last_step),
        model=model,
        config=cfg,
        mask_style=mask_style,
        mask_struct=mask_struct,
    )
    images = model.pipe(**pipe_kwargs).images

    # Save the individual outputs; index access keeps the failure mode of a
    # too-short result list the same as explicit images[0..2] lookups.
    for position, label in enumerate(("transfer", "style", "struct")):
        images[position].save(cfg.output_path / f"out_{label}---seed_{cfg.seed}.png")

    # Side-by-side strip, with the image order reversed.
    side_by_side = np.concatenate(images[::-1], axis=1)
    joined_file = cfg.output_path / f"out_joined---seed_{cfg.seed}.png"
    Image.fromarray(side_by_side).save(joined_file)
    return images

def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the editing script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behavior callers that pass no argument already rely on.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a editing script.")
    parser.add_argument(
        "--struct_image_path",
        type=str,
        required=True,
        help="struct image path",
    )
    parser.add_argument(
        "--app_image_path",
        type=str,
        required=True,
        help="appearance image path",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="output",
        help="directory under which results are saved",
    )
    parser.add_argument(
        "--save_image_path",
        type=str,
        default=None,
        help="save image path",
    )
    parser.add_argument(
        "--domain_name",
        type=str,
        nargs="+",
        default=None,
        help="domain name",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="random seed",
    )
    parser.add_argument(
        "--bg_energy_scale",
        # float, not int: the default is 3e4 and int("3e4") would reject the
        # same value when it is passed explicitly on the command line.
        type=float,
        default=3e4,
        help="background guidance energy scale",
    )
    parser.add_argument(
        "--swap_guidance_scale",
        type=float,
        default=0,
        help="swap guidance",
    )
    parser.add_argument(
        "--w_global",
        type=float,
        default=0.5,
        help="global guidance weight",
    )
    parser.add_argument(
        "--w_struct",
        type=float,
        default=0.3,
        help="struct feature guidance weight",
    )
    parser.add_argument(
        "--w_app",
        type=float,
        default=0.3,
        help="appearance feature guidance",
    )
    parser.add_argument(
        "--w_background",
        type=float,
        default=0.3,
        help="background feature guidance",
    )
    parser.add_argument(
        "--pe_scale",
        type=float,
        default=10.0,
        help="position embedding scale",
    )
    parser.add_argument(
        "--cross_guidance",
        type=float,
        default=0.1,
        help="cross-attention guidance",
    )
    parser.add_argument(
        "--feat_guidance_type",
        type=str,
        default="app",
        choices=["app", "struct", "app_new_affine", "app_struct"],
        help=(
            "guidance type"
        ),
    )
    parser.add_argument(
        "--cross_guidance_type",
        type=str,
        default="l2",
        choices=["l2", "cos"],
        help=(
            "guidance type"
        ),
    )
    parser.add_argument(
        "--attention_guidance_type",
        type=str,
        default="cross",
        choices=["self", "cross"],
        help=(
            "attention guidance type"
        ),
    )
    parser.add_argument(
        "--mask_type",
        type=str,
        default="sam",
        choices=["sam", "kmeans"],
        help=(
            "get mask type"
        ),
    )
    parser.add_argument(
        "--use_adain",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--use_masked_adain",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--swap_kv",
        # store_false: the option is on by default and the flag turns it OFF.
        default=True,
        action="store_false",
        help="pass this flag to disable swap_kv (it is enabled by default)",
    )
    args = parser.parse_args(argv)
    return args



def main():
    """Command-line entry point: parse arguments, run the transfer, save the result.

    Builds a ``RunConfig`` from the parsed arguments, derives the output file
    (either ``--save_image_path`` or an auto-generated name that encodes the
    main hyper-parameters), and returns early without running anything when
    that file already exists. Otherwise the pipeline is run and the returned
    images are concatenated side by side (in reversed order) and saved.
    """
    from pathlib import Path

    args = parse_args()

    config = RunConfig(
        app_image_path=Path(args.app_image_path),
        struct_image_path=Path(args.struct_image_path),
        domain_name=args.domain_name,
        seed=args.seed,
        load_latents=False,
        cross_guidance=args.cross_guidance,
        swap_guidance_scale=args.swap_guidance_scale,
        feat_guidance_type=args.feat_guidance_type,
        cross_guidance_type=args.cross_guidance_type,
        use_adain=args.use_adain,
        use_masked_adain=args.use_masked_adain,
        attention_guidance_type=args.attention_guidance_type,
        w_app=args.w_app,
        w_global=args.w_global,
        w_struct=args.w_struct,
        w_background=args.w_background,
        pe_scale=args.pe_scale,
        bg_energy_scale=args.bg_energy_scale,
        mask_type=args.mask_type,
        swap_kv=args.swap_kv,
    )

    # --domain_name is optional and accepts one or more values, so it may be
    # None or shorter than two entries. Use at most the first two names for
    # the sub-directory (unchanged for >= 2 names) and fall back to the bare
    # save_dir when none were given, instead of crashing on the index.
    domain_names = args.domain_name or []
    save_path = f"./{args.save_dir}"
    if domain_names:
        domain_tag = "_".join(domain_names[:2])
        save_path = f"{save_path}/{domain_tag}"
    os.makedirs(save_path, exist_ok=True)

    # File-name suffix: plain AdaIN takes precedence over masked AdaIN.
    adain = ""
    if config.use_adain:
        adain = "_adain"
    elif config.use_masked_adain:
        adain = "_mask_adain"

    swap_kv = "swap_kv" if args.swap_kv else ""

    if args.save_image_path is None:
        save_file = os.path.join(save_path,f"{swap_kv}_{config.swap_guidance_scale}_app_{config.w_app}_struct_{config.w_struct}_global_{config.w_global}_{config.attention_guidance_type}_{config.cross_guidance}_pe_{config.pe_scale}_bg_{config.w_background}{adain}.png")
    else:
        save_file = args.save_image_path

    # Skip work that has already been done.
    if os.path.exists(save_file):
        return

    # Make sure an explicitly requested output location exists before the
    # expensive pipeline run, so the final save cannot fail on a missing folder.
    save_parent = os.path.dirname(save_file)
    if save_parent:
        os.makedirs(save_parent, exist_ok=True)

    images = run(cfg=config)
    joined_images = Image.fromarray(np.concatenate(images[::-1], axis=1))
    joined_images.save(save_file)
    print(f'Image saved to {save_file}')
    
# Script entry point. `main` resolves to the last definition in this file,
# i.e. the argparse-based one, not the pyrallis-wrapped one near the top.
if __name__ == '__main__':
    main()
