import os
import cv2
import numpy as np
import gradio as gr
from copy import deepcopy
import datetime
import PIL
from PIL import Image
from PIL.ImageOps import exif_transpose
from .edit_utils import run_drag
from .colorwheel import flow_to_image
import matplotlib
import yaml
import pickle

def mask_image(image,
               mask,
               color=(255, 0, 0),
               alpha=0.5):
    """Overlay a mask on an image for visualization purposes.

    Pixels where ``mask == 1`` are painted with ``color`` in a copy of
    ``image``. That painted copy is then alpha-blended with an untouched
    copy via ``cv2.addWeighted``. The input ``image`` itself is not modified.

    Args:
        image: input image array, e.g. (H, W, 3). NOTE(review): a 2-D
            (H, W) image only works if ``color`` is a scalar; a 3-element
            colour cannot be assigned to single-channel pixels.
        mask: (H, W) array; only entries exactly equal to 1 are painted.
        color: value written into masked pixels before blending
            (defaults to red, (255, 0, 0)).
        alpha: weight of the painted copy; the original gets ``1 - alpha``.

    Returns:
        The blended image, as produced by ``cv2.addWeighted``.
    """
    # Tuple default (instead of a list) avoids a shared mutable default.
    painted = deepcopy(image)
    painted[mask == 1] = color
    out = deepcopy(image)
    out = cv2.addWeighted(painted, alpha, out, 1 - alpha, 0, out)
    return out

def store_img(img, length=512):
    """Preprocess a newly uploaded image + sketch mask for the drag UI.

    The image is resized to ``length`` pixels wide, keeping its aspect ratio,
    and the mask is resized to the same size and binarized.

    Args:
        img: dict with keys "image" (an (H, W, 3) array) and "mask" (an
            (H, W, C) or (H, W) array with values in 0..255). This looks like
            the value of a gradio sketch-enabled Image component; confirm
            against the caller.
        length: target width in pixels.

    Returns:
        A 4-tuple ``(image, [], update, mask)``:

        - ``image``: the resized image as a numpy array.
        - ``[]``: a fresh, empty list of selected points.
        - ``update``: a ``gr.Image.update`` showing a preview with the region
          outside the mask darkened.
        - ``mask``: a uint8 array with 1 where the resized mask was non-zero.
    """
    image, raw_mask = img["image"], img["mask"]
    # Keep only the first channel of a multi-channel mask.
    # A 2-D mask is used as-is (mirrors store_sample).
    if raw_mask.ndim == 3:
        raw_mask = raw_mask[:, :, 0]
    mask = np.float32(raw_mask) / 255.

    height, width, _ = image.shape
    target_size = (length, int(length * height / width))  # (width, height)

    image = exif_transpose(Image.fromarray(image))
    image = np.array(image.resize(target_size, PIL.Image.BILINEAR))
    # Nearest-neighbour interpolation avoids blending mask values at edges.
    mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
    # Always binarize so the returned mask has one dtype (uint8), whether or
    # not anything was drawn.
    mask = np.uint8(mask > 0)

    if mask.sum() > 0:
        # Darken everything outside the drawn region for the preview.
        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = image.copy()
    # When a new image is uploaded, `selected_points` should be empty.
    return image, [], gr.Image.update(value=masked_img, interactive=True), mask

def get_points(img,
               sel_pix,
               evt: gr.SelectData):
    """Record a clicked point and redraw all handle/target points on ``img``.

    Clicks alternate handle, target, handle, target, and so on:

    - Even-indexed points (handles) are drawn as red filled circles.
    - Odd-indexed points (targets) are drawn as blue filled circles.
    - Every completed handle -> target pair is joined by a white arrow.

    Args:
        img: the image currently shown in the click panel (numpy array or
            anything ``np.array`` can convert).
        sel_pix: list of previously selected points. ``evt.index`` is
            appended to it in place, so the caller's list is updated.
        evt: gradio selection event; ``evt.index`` is the clicked position.

    Returns:
        A numpy array with all points (and arrows) drawn on it.
    """
    # Collect the selected point.
    sel_pix.append(evt.index)
    # cv2 drawing functions need a numpy array, so convert before drawing.
    if not isinstance(img, np.ndarray):
        img = np.array(img)

    pair = []
    for idx, point in enumerate(sel_pix):
        point = tuple(point)
        # Red for handle points, blue for target points.
        colour = (255, 0, 0) if idx % 2 == 0 else (0, 0, 255)
        cv2.circle(img, point, 10, colour, -1)
        pair.append(point)
        # Draw an arrow from the handle point to its target point.
        if len(pair) == 2:
            cv2.arrowedLine(img, pair[0], pair[1], (255, 255, 255), 4, tipLength=0.5)
            pair = []
    return img

def undo_points(original_image,
                mask):
    """Discard all clicked points and restore the masked preview.

    Args:
        original_image: the clean image array (without any drawn points).
        mask: array whose positive entries mark the editable region.

    Returns:
        ``(preview, [])``. If the mask has any positive sum, ``preview``
        darkens the area outside the mask; otherwise it is a plain copy of
        ``original_image``. The empty list resets the selected points.
    """
    if not (mask.sum() > 0):
        return original_image.copy(), []

    inside = np.uint8(mask > 0)
    preview = mask_image(original_image, 1 - inside, color=[0, 0, 0], alpha=0.3)
    return preview, []

def clear_all(length=480):
    """Reset the UI: three blank ``length`` x ``length`` image panels plus state.

    The first panel is left interactive; the other two are made
    non-interactive. The trailing ``[], None, None`` reset additional
    outputs. Presumably these are the selected points and two stored arrays;
    confirm against the caller's outputs list.
    """
    def _blank_panel(interactive):
        # Empty image panel of the requested size.
        return gr.Image.update(value=None, height=length, width=length,
                               interactive=interactive)

    return (
        _blank_panel(True),
        _blank_panel(False),
        _blank_panel(False),
        [],
        None,
        None,
    )
        
def save_depth_map(depth, save_dir, timestamp):
    """Write three renderings of a depth map into ``save_dir/timestamp``.

    The target directory must already exist; this function does not create
    it. Files written:

    - ``row_depth.png``: ``depth`` cast to uint16 with no clipping or
      scaling. NOTE(review): "row" is presumably a typo for "raw"; the name
      is kept unchanged in case downstream tooling reads it.
    - ``colored_depth.png``: min-max normalised depth (0..255) mapped through
      matplotlib's ``Spectral_r`` colormap, RGB only.
    - ``gray_depth.png``: the same normalised depth as 8-bit grayscale.

    Args:
        depth: 2-D numeric depth array (assumed (H, W); TODO confirm).
        save_dir: root output directory.
        timestamp: sub-directory name inside ``save_dir``.
    """
    out_dir = os.path.join(save_dir, timestamp)
    cmap = matplotlib.colormaps.get_cmap('Spectral_r')

    raw_depth = Image.fromarray(depth.astype('uint16'))
    raw_depth.save(os.path.join(out_dir, "row_depth.png"))

    d_min, d_max = depth.min(), depth.max()
    span = d_max - d_min
    if span > 0:
        depth = ((depth - d_min) / span * 255.0).astype(np.uint8)
    else:
        # Constant (or NaN-containing) depth: avoid a 0/0 -> NaN division
        # and render as all zeros.
        depth = np.zeros(depth.shape, dtype=np.uint8)

    # Integer inputs index the colormap's lookup table directly.
    colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
    Image.fromarray(colored_depth).save(os.path.join(out_dir, "colored_depth.png"))

    gray_depth = Image.fromarray(depth)
    gray_depth.save(os.path.join(out_dir, "gray_depth.png"))
    
def _make_summary(source_image, image_with_clicks, edited_image, gap=25):
    """Concatenate three images left-to-right, separated by white bars ``gap`` px wide."""
    height = source_image.shape[0]
    separator = np.full((height, gap, 3), 255, dtype=np.uint8)  # white bar
    return np.concatenate([
        source_image,
        separator,
        image_with_clicks,
        separator,
        edited_image,
    ], axis=1)


def run_drag_interface(source_image,
                       image_with_clicks,
                       mask,
                       prompt,
                       points,
                       inversion_strength,
                       model_path,
                       vae_path,
                       start_step,
                       start_layer,
                       n_inference_step,
                       task_cat,
                       device,
                       save_dir="./results",
                       lambda_mix=None,
                       gamma_ratio=0.5,
                       upper_scale=1.5,
                       lower_scale=0.5,
                       alpha=2.0,
                       beta=2.0,
                       ):
    """Run a drag edit via ``run_drag`` and save its inputs, output and settings.

    All editing arguments are forwarded unchanged to ``run_drag``, except
    ``lambda_mix``: ``None`` or any negative value is passed as ``None``.

    Outputs go to ``save_dir/<YYYY-mm-dd_HH-MM-SS>/``:

    - ``summary.png``: source | clicks | edited, separated by white bars.
    - ``original.png``, ``clicks.png``, ``edited.png``: the three images.
    - ``params.yaml``: the editing parameters.

    The timestamp has one-second resolution, so two runs within the same
    second write into the same folder.

    NOTE(review): the three images must share a height for the summary
    concatenation to succeed.

    Returns:
        The edited image cast to uint8.
    """
    # A negative lambda_mix from the UI means "disabled", as does the default
    # None. Guard the comparison: `None < 0` raises TypeError.
    if lambda_mix is not None and lambda_mix < 0:
        lambda_mix = None
    edited_image = run_drag(None, source_image,
                            mask,
                            prompt,
                            points,
                            inversion_strength,
                            model_path,
                            vae_path,
                            start_step,
                            start_layer,
                            n_inference_step,
                            task_cat,
                            lambda_mix=lambda_mix,
                            gamma_ratio=gamma_ratio,
                            upper_scale=upper_scale,
                            lower_scale=lower_scale,
                            alpha=alpha,
                            beta=beta,
                            device=device,)

    source_image = source_image.astype(np.uint8)
    image_with_clicks = image_with_clicks.astype(np.uint8)
    edited_image = edited_image.astype(np.uint8)

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    out_dir = os.path.join(save_dir, timestamp)
    # Creates save_dir too; exist_ok avoids the exists-then-create race.
    os.makedirs(out_dir, exist_ok=True)

    # Depth / optical-flow dumps (see save_depth_map and flow_to_image) are
    # not written here: run_drag's result is only used as the edited image.
    summary = _make_summary(source_image, image_with_clicks, edited_image)
    Image.fromarray(summary).save(os.path.join(out_dir, "summary.png"))
    Image.fromarray(source_image).save(os.path.join(out_dir, "original.png"))
    Image.fromarray(image_with_clicks).save(os.path.join(out_dir, "clicks.png"))
    Image.fromarray(edited_image).save(os.path.join(out_dir, "edited.png"))

    # Record the parameters used for this run.
    param_dict = {
        "prompt": prompt,
        "inversion_strength": inversion_strength,
        "model_path": model_path,
        "vae_path": vae_path,
        "start_step": start_step,
        "start_layer": start_layer,
        "n_inference_step": n_inference_step,
        "task_cat": task_cat,
        "lambda_mix": None if lambda_mix is None else float(lambda_mix),
        "gamma_ratio": float(gamma_ratio),
        "upper_scale": float(upper_scale),
        "lower_scale": float(lower_scale),
        "alpha": float(alpha),
        "beta": float(beta),
    }
    param_path = os.path.join(out_dir, "params.yaml")
    with open(param_path, "w", encoding="utf-8") as f:
        yaml.dump(param_dict, f, default_flow_style=False)

    return edited_image

def store_sample(input_image, sample_save_dir, selected_points, prompt):
    """Persist an input image and its drag metadata as a reusable sample.

    Two files are written into ``sample_save_dir``; the directory is created
    if missing. Both share a ``YYYY-mm-dd_HH-MM-SS`` stem:

    - ``<stem>.png``: the image ``input_image["image"]``.
    - ``<stem>.pkl``: a pickled dict with keys "points" (``selected_points``
      as given), "prompt", and "mask". The mask is float32, scaled by 1/255,
      and uses only the first channel if it is 3-D.

    Args:
        input_image: dict with "image" and "mask" arrays (presumably from a
            gradio sketch-enabled Image component; confirm against caller).
        sample_save_dir: output directory.
        selected_points: clicked points to store alongside the image.
        prompt: text prompt to store alongside the image.
    """
    if not os.path.exists(sample_save_dir):
        os.makedirs(sample_save_dir)

    stem = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    png_file = os.path.join(sample_save_dir, f"{stem}.png")
    pkl_file = os.path.join(sample_save_dir, f"{stem}.pkl")

    pixels = input_image["image"]
    raw_mask = input_image["mask"]
    # Collapse a multi-channel mask to its first channel, then scale to [0, 1].
    single = raw_mask[:, :, 0] if raw_mask.ndim == 3 else raw_mask
    mask = np.float32(single) / 255.

    Image.fromarray(pixels).save(png_file)

    metadata = {"points": selected_points, "prompt": prompt, "mask": mask}
    with open(pkl_file, "wb") as handle:
        pickle.dump(metadata, handle)
    