
import argparse
import os
os.environ["HF_ENDPOINT"]="https://hf-mirror.com"
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import PIL
from PIL import Image
import cv2
import matplotlib

from copy import deepcopy
from einops import rearrange
from types import SimpleNamespace

from diffusers import DDIMScheduler, AutoencoderKL
from torchvision.utils import save_image
from pytorch_lightning import seed_everything
from utils.colorwheel import flow_to_image
import sys
sys.path.insert(0, '../')
from drag_pipeline import DragPipeline
from utils.unet_drag.unet_2d_condition import UNet2DConditionModel

from utils.attn_utils import MutualSelfAttentionControl

from utils.edit_utils import run_drag
import time
import yaml
torch.set_num_threads(4)

def preprocess_image(image,
                     device):
    """Convert an (H, W, C) NumPy image into a normalised (1, C, H, W) tensor.

    Each pixel value ``x`` is mapped to ``x / 127.5 - 1`` in float32. For
    8-bit input (assumed, not checked) that maps [0, 255] onto [-1, 1].

    Args:
        image: NumPy array laid out as (height, width, channels).
        device: torch device (or device string) the result is moved to.

    Returns:
        float32 tensor of shape (1, channels, height, width) on ``device``.
    """
    normalised = torch.from_numpy(image).float().div(127.5).sub(1)
    # Channel-first layout with a leading batch axis of size one.
    batched = rearrange(normalised, "h w c -> 1 c h w")
    return batched.to(device)

def save_depth_map(depth, save_dir):
    """Write a depth map into ``save_dir`` as a colour PNG and a greyscale PNG.

    The depth values are min-max normalised to [0, 255] and truncated to
    uint8. A constant map has a zero range, so it cannot be min-max
    normalised; it is written as all zeros instead of dividing by zero.

    Two files are written:
      * ``colored_depth.png`` - the normalised map rendered through
        matplotlib's ``Spectral_r`` colormap (RGB, alpha channel dropped);
      * ``gray_depth.png`` - the normalised map as an 8-bit greyscale image.

    Args:
        depth: NumPy array of depth values. It is assumed to be 2-D (H, W)
            (TODO confirm); ``cmap(depth)[:, :, :3]`` below requires that.
        save_dir: existing directory that receives the two PNG files.
    """
    cmap = matplotlib.colormaps.get_cmap('Spectral_r')

    d_min = depth.min()
    d_range = depth.max() - d_min
    if d_range == 0:
        # A flat map would divide by zero (NaN/inf), and casting that to
        # uint8 is undefined, so emit a well-defined all-zero map instead.
        depth = np.zeros(depth.shape, dtype=np.uint8)
    else:
        depth = ((depth - d_min) / d_range * 255.0).astype(np.uint8)

    # Integer inputs index the colormap's lookup table directly, so the
    # uint8 values 0..255 select one of the colormap's 256 entries.
    colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
    Image.fromarray(colored_depth).save(os.path.join(save_dir, "colored_depth.png"))

    gray_depth = Image.fromarray(depth)
    gray_depth.save(os.path.join(save_dir, "gray_depth.png"))

def parse_args():
    """Parse the model-path / server command-line options from ``sys.argv``.

    Options (all optional):
      --base_sd_path, --ip_adapter_path,
      --lightning_drag_model_path, --lcm_lora_path : str, default None
      --vae_path                                   : str, default "default"
      --server_port                                : int, default 8888

    NOTE(review): the parser description looks copy-pasted from a training
    script. The ``__main__`` block of this file builds its own, separate
    parser rather than calling this function.

    Returns:
        argparse.Namespace with one attribute per option above.
    """
    # (flag, value type, default); the order is preserved for --help output.
    option_table = (
        ("--base_sd_path", str, None),
        ("--vae_path", str, "default"),
        ("--ip_adapter_path", str, None),
        ("--lightning_drag_model_path", str, None),
        ("--lcm_lora_path", str, None),
        ("--server_port", int, 8888),
    )
    cli = argparse.ArgumentParser(description="Simple example of a training script.")
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()

def init_model(model_path, vae_path, device,
               unet_path="SimianLuo/LCM_Dreamshaper_v7"):
    """Build an fp16 ``DragPipeline`` whose UNet is replaced by ``unet_path``.

    Args:
        model_path: hub id or local path of the base pipeline passed to
            ``DragPipeline.from_pretrained``.
        vae_path: ``"default"`` keeps the VAE shipped with ``model_path``.
            Any other value is loaded with ``AutoencoderKL.from_pretrained``
            and moved to the existing VAE's device and dtype.
        device: device handed to ``enable_model_cpu_offload``.
        unet_path: hub id or local path whose ``unet`` subfolder supplies the
            replacement UNet. The default is the checkpoint this function
            previously hard-coded, so existing callers are unaffected.

    Returns:
        The configured pipeline with model CPU offload enabled.
    """
    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                              beta_schedule="scaled_linear", clip_sample=False,
                              set_alpha_to_one=False, steps_offset=1)
    model = DragPipeline.from_pretrained(model_path, scheduler=scheduler,
                                         torch_dtype=torch.float16)
    # Swap in the project's UNet2DConditionModel (utils.unet_drag) loaded in fp16.
    model.unet = UNet2DConditionModel.from_pretrained(
        unet_path,
        subfolder="unet",
        torch_dtype=torch.float16,
    )

    # NOTE(review): presumably patches the forward of the UNet assigned
    # above, so it is kept after the swap; confirm in DragPipeline.
    model.modify_unet_forward()
    if vae_path != "default":
        # Match the replacement VAE to the current VAE's device and dtype.
        model.vae = AutoencoderKL.from_pretrained(
            vae_path
        ).to(model.vae.device, model.vae.dtype)
    model.enable_model_cpu_offload(device=device)
    return model

if __name__ == '__main__':
    # Base Stable Diffusion checkpoint. It is shared by init_model and the
    # model_path argument of run_drag so the two cannot drift apart.
    BASE_MODEL_PATH = 'runwayml/stable-diffusion-v1-5'

    parser = argparse.ArgumentParser(description="setting arguments")
    # Required: a None path would crash os.path.join / os.makedirs below with
    # a TypeError, and previously only after the expensive model load.
    parser.add_argument('--sample_path', type=str, required=True,
                        help='directory containing original_image.png and meta_data.pkl')

    # NOTE(review): --lora_steps, --latent_lr and --unet_feature_idx are
    # parsed but never read in this script; they are kept so existing
    # command lines that pass them still parse.
    parser.add_argument('--lora_steps', type=int, default=80, help='number of lora fine-tuning steps')
    parser.add_argument('--inv_strength', type=float, default=0.7, help='inversion strength')
    parser.add_argument('--latent_lr', type=float, default=0.01, help='latent learning rate')
    parser.add_argument('--unet_feature_idx', type=int, default=3, help='feature idx of unet features')
    parser.add_argument('--result_dir', type=str, required=True,
                        help='directory where dragged_image.png is written (created if missing)')
    parser.add_argument('--n_inference_step', type=int, default=10,
                        help='number of inference steps passed to run_drag')

    parser.add_argument('--lambda_mix', type=float, default=None, help='lambda mix')
    parser.add_argument('--gamma_ratio', type=float, default=1, help='gamma ratio')
    parser.add_argument('--upper_scale', type=float, default=5, help='upper scale')
    parser.add_argument('--lower_scale', type=float, default=0, help='lower scale')
    parser.add_argument('--alpha', type=float, default=1, help='alpha')
    parser.add_argument('--beta', type=float, default=1, help='beta')
    parser.add_argument('--device', type=str, default='cuda', help='device')
    parser.add_argument('--test_fusion', type=str, default='amplitude')
    parser.add_argument('--lora_path', type=str, default=None, help='lora dir')
    args = parser.parse_args()

    sample_path = args.sample_path
    save_dir = args.result_dir
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)

    # Read the input image. The context manager closes the file handle, and
    # convert('RGB') guarantees an (H, W, 3) array even for RGBA, greyscale
    # or palette PNGs.
    with Image.open(os.path.join(sample_path, 'original_image.png')) as img:
        source_image = np.array(img.convert('RGB'))

    # Load the drag metadata.
    # SECURITY: pickle.load can execute arbitrary code; only load
    # meta_data.pkl files from trusted sources.
    with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
        meta_data = pickle.load(f)
    prompt = meta_data['prompt']
    mask = meta_data['mask']
    points = meta_data['points']
    lora_path = args.lora_path

    # Build the model only after the inputs were read, so bad paths fail fast.
    model = init_model(model_path=BASE_MODEL_PATH,
                       vae_path="default", device=args.device)

    out_image = run_drag(model,
                         source_image,
                         mask,
                         prompt,
                         points,
                         args.inv_strength,
                         model_path=BASE_MODEL_PATH,
                         vae_path="default",
                         start_step=0,
                         start_layer=10,
                         n_inference_step=args.n_inference_step,
                         task_cat="continuous drag",
                         lambda_mix=args.lambda_mix,
                         gamma_ratio=args.gamma_ratio,
                         upper_scale=args.upper_scale,
                         lower_scale=args.lower_scale,
                         alpha=args.alpha,
                         beta=args.beta,
                         test_fusion=args.test_fusion,
                         device=args.device,
                         lora_path=lora_path,)
    Image.fromarray(out_image).save(os.path.join(save_dir, 'dragged_image.png'))