
import argparse
import os
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import PIL
from PIL import Image
import cv2
import matplotlib

from copy import deepcopy
from einops import rearrange
from types import SimpleNamespace

from diffusers import DDIMScheduler, AutoencoderKL
from torchvision.utils import save_image
from pytorch_lightning import seed_everything
from utils.colorwheel import flow_to_image
import sys
from drag_pipeline import DragPipeline
from utils.unet_drag.unet_2d_condition import UNet2DConditionModel

from utils.attn_utils import MutualSelfAttentionControl

from utils.edit_utils import run_drag
import time
import yaml
# Cap the number of threads PyTorch uses for intra-op parallelism on the CPU.
torch.set_num_threads(4)

def preprocess_image(image,
                     device):
    """Turn an ``(h, w, c)`` numpy image into a ``(1, c, h, w)`` float tensor.

    Values are scaled with ``x / 127.5 - 1`` (which maps 0..255 input onto
    [-1, 1]) and the resulting tensor is moved to ``device``.

    Args:
        image: numpy array laid out as height x width x channels.
        device: target device passed to ``Tensor.to``.

    Returns:
        The scaled tensor with a leading batch axis of size 1, on ``device``.
    """
    scaled = torch.from_numpy(image).float()
    scaled = scaled / 127.5 - 1  # [-1, 1]
    # Channels-first layout plus a leading batch dimension.
    batched = rearrange(scaled, "h w c -> 1 c h w")
    return batched.to(device)

def save_depth_map(depth, save_dir):
    """Save a colorized and a grayscale rendering of a depth map.

    The map is min-max normalized to 0..255, cast to ``uint8`` and written
    twice into ``save_dir``: ``colored_depth.png`` (through matplotlib's
    ``Spectral_r`` colormap) and ``gray_depth.png`` (the normalized values).

    Args:
        depth: 2-D numpy array of depth values.
        save_dir: existing directory that receives the two PNG files.
    """
    cmap = matplotlib.colormaps.get_cmap('Spectral_r')

    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        depth = (depth - depth_min) / depth_range * 255.0
        depth = depth.astype(np.uint8)
    else:
        # A constant map has zero range; dividing would produce NaN, whose
        # cast to uint8 is undefined. Render it as all zeros instead.
        depth = np.zeros_like(depth, dtype=np.uint8)

    # The colormap returns RGBA floats in [0, 1]; keep RGB and rescale.
    colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
    Image.fromarray(colored_depth).save(os.path.join(save_dir, "colored_depth.png"))

    gray_depth = Image.fromarray(depth)
    gray_depth.save(os.path.join(save_dir, "gray_depth.png"))


def init_model(model_path, vae_path, device,
               unet_path="SimianLuo/LCM_Dreamshaper_v7"):
    """Build the drag pipeline used for editing.

    Loads ``DragPipeline`` from ``model_path`` in float16 with a DDIM
    scheduler, replaces its UNet with the project's ``UNet2DConditionModel``
    loaded from ``unet_path``, patches the UNet forward via
    ``modify_unet_forward`` and enables model CPU offload.

    Args:
        model_path: checkpoint id or path for ``DragPipeline.from_pretrained``.
        vae_path: checkpoint id or path of a replacement VAE, or ``"default"``
            to keep the VAE that ships with ``model_path``.
        device: device handed to ``enable_model_cpu_offload``.
        unet_path: checkpoint id or path whose ``unet`` subfolder provides the
            UNet weights. Defaults to the previously hard-coded checkpoint.

    Returns:
        The configured ``DragPipeline`` instance.
    """
    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                              beta_schedule="scaled_linear", clip_sample=False,
                              set_alpha_to_one=False, steps_offset=1)
    model = DragPipeline.from_pretrained(model_path, scheduler=scheduler, torch_dtype=torch.float16)

    # Swap in the project's UNet variant before patching its forward pass.
    model.unet = UNet2DConditionModel.from_pretrained(
        unet_path,
        subfolder="unet",
        torch_dtype=torch.float16,
    )
    model.modify_unet_forward()

    if vae_path != "default":
        # Keep the replacement VAE on the same device/dtype as the one it replaces.
        model.vae = AutoencoderKL.from_pretrained(
            vae_path
        ).to(model.vae.device, model.vae.dtype)

    model.enable_model_cpu_offload(device=device)
    return model

if __name__ == '__main__':
    # Batch driver: run drag editing on every sample under data/DragBench,
    # save each result image, and report wall-clock time and peak GPU memory
    # to stdout and to ./run_drag_result.txt.
    parser = argparse.ArgumentParser(description="setting arguments")
    parser.add_argument('--lora_steps', type=int, default=80, help='number of lora fine-tuning steps')
    parser.add_argument('--inv_strength', type=float, default=0.7, help='inversion strength')
    parser.add_argument('--latent_lr', type=float, default=0.01, help='latent learning rate')
    parser.add_argument('--unet_feature_idx', type=int, default=3, help='feature idx of unet features')
    parser.add_argument('--result_dir', type=str, default=None,
                        help="prefix of the result directory name (default: 'geodrag')")
    parser.add_argument('--n_inference_step', type=int, default=10, help='number of inference steps')

    parser.add_argument('--lambda_mix', type=float, default=None, help='lambda mix')
    parser.add_argument('--gamma_ratio', type=float, default=1, help='gamma ratio')
    parser.add_argument('--upper_scale', type=float, default=5, help='upper scale')
    parser.add_argument('--lower_scale', type=float, default=0, help='lower scale')
    parser.add_argument('--alpha', type=float, default=1, help='alpha')
    parser.add_argument('--beta', type=float, default=1, help='beta')
    parser.add_argument('--device', type=str, default='cuda', help='device')
    parser.add_argument('--test_fusion', type=str, default='amplitude')
    parser.add_argument('--lora_dir', type=str, default=None, help='lora dir')
    args = parser.parse_args()

    all_category = [
        'art_work',
        'land_scape',
        'building_city_view',
        'building_countryside_view',
        'animals',
        'human_head',
        'human_upper_body',
        'human_full_body',
        'interior_design',
        'other_objects',
    ]

    # Divisor for the "per point" timing figure.
    # NOTE(review): presumably the total number of drag point pairs in
    # DragBench -- confirm against the dataset.
    num_drag_points = 349

    root_dir = 'data/DragBench'
    lora_dir = args.lora_dir

    # The -1 placeholder is only used in the directory name; run_drag below
    # still receives the raw args.lambda_mix (possibly None).
    lambda_mix = args.lambda_mix if args.lambda_mix is not None else -1
    result_prefix = 'geodrag' if args.result_dir is None else args.result_dir
    result_dir = '_'.join([
        result_prefix,
        str(lambda_mix),
        str(args.gamma_ratio),
        str(args.upper_scale),
        str(args.lower_scale),
        str(args.alpha),
        str(args.beta),
        args.test_fusion,
        str(args.n_inference_step),
    ])

    # Create the output tree up front (one sub-directory per category).
    os.makedirs(result_dir, exist_ok=True)
    for cat in all_category:
        os.makedirs(os.path.join(result_dir, cat), exist_ok=True)

    # CUDA memory statistics can only be reset/queried when CUDA is present.
    use_cuda = torch.cuda.is_available()

    save_time_sum = 0
    start_time = time.time()
    mem_list = []
    model = init_model(model_path='runwayml/stable-diffusion-v1-5',
                       vae_path="default", device=args.device)
    for cat in all_category:
        file_dir = os.path.join(root_dir, cat)
        for sample_name in os.listdir(file_dir):
            sample_path = os.path.join(file_dir, sample_name)
            # Each sample is a directory; skip stray files such as .DS_Store.
            if not os.path.isdir(sample_path):
                continue

            # read image file (context manager closes the file handle)
            with Image.open(os.path.join(sample_path, 'original_image.png')) as img:
                source_image = np.array(img)

            # load meta data
            # NOTE: pickle.load can execute arbitrary code; only run this on
            # benchmark data obtained from a trusted source.
            with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
                meta_data = pickle.load(f)
            prompt = meta_data['prompt']
            mask = meta_data['mask']
            points = meta_data['points']

            # load lora
            if args.lora_dir is not None:
                lora_path = os.path.join(lora_dir, cat, sample_name, str(args.lora_steps))
                print("applying lora: " + lora_path)
            else:
                lora_path = None
                print("editing: " + sample_name)

            if use_cuda:
                # Start every sample from a clean slate so the peak is per-sample.
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats()
            out_image = run_drag(model,
                                 source_image,
                                 mask,
                                 prompt,
                                 points,
                                 args.inv_strength,
                                 model_path='runwayml/stable-diffusion-v1-5',
                                 vae_path="default",
                                 start_step=0,
                                 start_layer=10,
                                 n_inference_step=args.n_inference_step,
                                 task_cat="continuous drag",
                                 lambda_mix=args.lambda_mix,
                                 gamma_ratio=args.gamma_ratio,
                                 upper_scale=args.upper_scale,
                                 lower_scale=args.lower_scale,
                                 alpha=args.alpha,
                                 beta=args.beta,
                                 test_fusion=args.test_fusion,
                                 device=args.device,
                                 lora_path=lora_path,)

            save_time0 = time.time()
            save_dir = os.path.join(result_dir, cat, sample_name)
            os.makedirs(save_dir, exist_ok=True)
            if use_cuda:
                peak_mem = torch.cuda.max_memory_reserved() / (1024 ** 2)  # bytes -> MB
                mem_list.append(peak_mem)
            Image.fromarray(out_image).save(os.path.join(save_dir, 'dragged_image.png'))
            save_time_sum += time.time() - save_time0

    # Taken after the loop so it is always defined (even with no samples) and
    # so the total covers every save that save_time_sum accounts for.
    end_time = time.time()

    if mem_list:
        mem_array = np.array(mem_list)
        mem_summary = (f"Peak memory (MB): mean={mem_array.mean():.2f}, std={mem_array.std():.2f}, "
                       f"max={mem_array.max():.2f}, min={mem_array.min():.2f}")
    else:
        # No sample processed or no CUDA device: nothing to summarize.
        mem_summary = "Peak memory (MB): n/a"

    total_time = end_time - start_time
    drag_time = total_time - save_time_sum
    banner = "***************\n" * 2

    print(mem_summary)
    print(banner)
    print(f"use time sum: {total_time}")
    print(f"use save time sum: {save_time_sum}")
    print(f"use drag time sum: {drag_time}")
    print(f"use drag time per point: {drag_time / num_drag_points}")
    print(banner)

    # Append the same summary to a persistent log, tagged with time and run name.
    logg = banner + \
        f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}\n" + \
        f"{result_dir}:  \n" + \
        f"{mem_summary} \n" + \
        f"use time sum: {total_time} \n" + \
        f"use save time sum: {save_time_sum} \n" + \
        f"use drag time sum: {drag_time} \n" + \
        f"use drag time per point: {drag_time / num_drag_points}\n\n\n"
    with open("./run_drag_result.txt", 'a', encoding='utf-8') as f:
        f.write(logg)