import os

import gradio as gr
import torch
from pathlib import Path
from utils.ui_utils import store_img, get_points, undo_points, clear_all, run_drag_interface, store_sample

# -------- Basic configuration for the Web GUI -------- #
LENGTH = 480  # side length of the square areas that display/edit images
torch.set_num_threads(4)  # cap the number of CPU threads PyTorch uses

# Selectable compute devices: one "cuda:<index>" entry per GPU that PyTorch
# reports (when CUDA is available), always followed by "cpu" as the last entry.
devices = ["cpu"]
if torch.cuda.is_available():
    num_gpu = torch.cuda.device_count()
    devices = [f'cuda:{gpu_index}' for gpu_index in range(num_gpu)] + devices
# ------------------------------------------------------ #

with gr.Blocks() as demo:
    # State component holding a dictionary of default settings (first entry of
    # `devices`, checkpoint/model identifiers, and empty mask/point maps).
    # NOTE(review): no event listener in this file reads or writes it.
    ArgumentsSpace = gr.State(
        value={
            "device": devices[0],
            "sam_checkpoints": "vit_b",
            "model_path": "runwayml/stable-diffusion-v1-5",
            "vae_path": "default",
            "masks": {},
            "points": {},
        },
    )

    # State components passed between the event handlers wired up below.
    mask = gr.State(value=None)  # mask, empty until one is stored
    selected_points = gr.State(value=[])  # points clicked so far
    original_image = gr.State(value=None)  # original input image

    # ------------ Page header ------------ #
    with gr.Row():
        gr.Markdown("""
        # Official Implementation of UnitDrag
        """)
    # ------------------------------------- #
    
    # ------------ Image editing tab ------------ #
    with gr.Tab(label="Editing Image"):
        # Three side-by-side panels, each LENGTH x LENGTH: a sketchable canvas,
        # a read-only image that receives point clicks, and the result image.
        with gr.Row():
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Choose Region</p>""")
                canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
                    show_label=True, height=LENGTH, width=LENGTH) # for mask painting
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
                input_image = gr.Image(type="numpy", label="Click Points",
                    show_label=True, height=LENGTH, width=LENGTH, interactive=False) # for points clicking
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing Results</p>""")
                output_image = gr.Image(type="numpy", label="Editing Results",
                    show_label=True, height=LENGTH, width=LENGTH, interactive=False)

        # Main action row.
        with gr.Row(equal_height=True):
            # This button used to be assigned to `train_lora_button`, which the
            # LoRA accordion below rebinds, leaving this one unreachable. It now
            # has its own name so both buttons can be wired independently.
            # NOTE(review): neither LoRA button has a click handler in this file.
            train_lora_row_button = gr.Button("Train LoRA (optional)")
            undo_button = gr.Button("Undo point")
            op_type = gr.Dropdown(
                value='translate',
                choices=['translate', 'rotate', 'scale', 'stretch'],
                label="Drag type",
                show_label=True,
            )
            run_button = gr.Button("Run")
            clear_all_button = gr.Button("Clear All")

        # Sample save/load row.
        # NOTE(review): `load_sample_dir`, `load_sample_button` and
        # `set_drag_type_button` are not attached to any handler in this file.
        with gr.Row(equal_height=True):
            sample_save_dir = gr.Textbox(value="./sample", label="Sample save path")
            sample_save_button = gr.Button("Save Sample")
            load_sample_dir = gr.Textbox(value="./sample", label="Load Sample path")
            load_sample_button = gr.Button("Load Sample")
            set_drag_type_button = gr.Button("Set Drag Type")

        # Collapsed-by-default LoRA controls.
        with gr.Row():
            with gr.Accordion("Optional LoRA", open=False):
                with gr.Row():
                    # Label typo fixed ("Optinal" -> "Optional").
                    train_lora_button = gr.Button("Train LoRA (Optional)")
                    prompt = gr.Textbox(label="Prompt", info='Text prompt for the model')
                    guidance_scale = gr.Number(value=1, label="Guidance Scale", info="Guidance scale for the model")
                    lora_path = gr.Textbox(value="./lora", label="LoRA path", info="Path to save LoRA model")
                    lora_status_bar = gr.Textbox(label="display LoRA training status", info="Training status of LoRA model")

        # Collapsed-by-default gallery for optical-flow images.
        with gr.Row():
            with gr.Accordion("Optical flows", open=False):
                with gr.Row():
                    # Layout options go to the constructor (as the gr.Image
                    # components above already do) instead of the deprecated
                    # `.style()` method.
                    optical_flow_images = gr.Gallery(label="Optical flows", show_label=True,
                        columns=4, height='auto')
    # -------------------------------------------- #

    # ------------ Configuration tab ------------ #
    with gr.Tab(label="Configure"):
        # Drag-related settings.
        # NOTE(review): `end_step` is not among the inputs that this file
        # passes to `run_drag_interface`.
        with gr.Row():
            with gr.Tab("Drag Config"):
                with gr.Row():
                    inversion_strength = gr.Slider(0, 1.0,
                        value=0.7,
                        label="inversion strength",
                        info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
                    # The info text here was a copy of the inversion-strength
                    # description; it now describes this slider.
                    n_inference_step = gr.Slider(0, 1000,
                        value=50,
                        label="inference step",
                        info="Total number of sampling steps.")
                with gr.Row():
                    start_step = gr.Number(value=0, label="start step", precision=0, info="start step for latent optimization")
                    start_layer = gr.Number(value=10, label="start layer", precision=0, info="start layer for latent optimization")
                    end_step = gr.Number(value=0, label='End Timestep', precision=0, info="end step for latent optimization")

        # Model, VAE, output and numeric settings.
        with gr.Row():
            with gr.Tab("Base Model Config"):
                with gr.Row():
                    # Every sub-directory of `local_pretrained_models` (created
                    # on demand) is offered as an extra model/VAE choice.
                    # Sorted so the dropdown order does not depend on the
                    # filesystem's listing order.
                    local_models_dir = Path('local_pretrained_models')
                    local_models_dir.mkdir(exist_ok=True, parents=True)
                    local_models_choice = sorted(
                        str(entry) for entry in local_models_dir.iterdir() if entry.is_dir()
                    )
                    model_path = gr.Dropdown(value="runwayml/stable-diffusion-v1-5",
                        label="Diffusion Model Path",
                        choices=[
                            "runwayml/stable-diffusion-v1-5",
                            "gsdf/Counterfeit-V2.5",
                            "stablediffusionapi/anything-v5",
                            "SG161222/Realistic_Vision_V2.0",
                        ] + local_models_choice
                    )
                    vae_path = gr.Dropdown(value="default",
                        label="VAE choice",
                        choices=["default",
                        "stabilityai/sd-vae-ft-mse"] + local_models_choice
                    )
                with gr.Row():
                    save_prefix = gr.Textbox(value="results", label="Output path")
                    device = gr.Dropdown(
                        label="Device",
                        choices=devices,
                        value=devices[0]
                    )
                    # Free-form numeric settings; precision=None leaves the
                    # entered value unrounded.
                    lambda_mix = gr.Number(value=-1, label="Lambda", precision=None)
                    gamma_ratio = gr.Number(value=1.0, label="Gamma ratio", precision=None)
                    upper_scale = gr.Number(value=5, label="Upper scale", precision=None)
                    lower_scale = gr.Number(value=0, label="Lower scale", precision=None)
                    alpha = gr.Number(value=1.0, label="Alpha", precision=None)
                    beta = gr.Number(value=2.0, label="Beta", precision=None)

        # LoRA training hyper-parameters.
        # NOTE(review): none of these are consumed by a handler in this file.
        with gr.Tab("LoRA Parameters"):
            with gr.Row():
                lora_step = gr.Number(value=80, label="LoRA training steps", precision=0)
                lora_lr = gr.Number(value=0.0005, label="LoRA learning rate")
                lora_batch_size = gr.Number(value=4, label="LoRA batch size", precision=0)
                lora_rank = gr.Number(value=16, label="LoRA rank", precision=0)
    # ---------------------------------- #

    # ------------ Event wiring ------------ #
    # Editing the sketch canvas calls `store_img`; its results populate the
    # original-image state, the point list, the click panel and the mask state.
    canvas.edit(
        fn=store_img,
        inputs=[canvas],
        outputs=[original_image, selected_points, input_image, mask],
    )

    # Selecting (clicking) on the click panel calls `get_points` with the
    # current panel image and point list; the panel image is replaced.
    input_image.select(
        fn=get_points,
        inputs=[input_image, selected_points],
        outputs=[input_image],
    )

    # "Undo point" calls `undo_points` with the original image and mask and
    # writes back the click panel image and the point list.
    undo_button.click(
        fn=undo_points,
        inputs=[original_image, mask],
        outputs=[input_image, selected_points],
    )

    # "Clear All" calls `clear_all`; its only input is a hidden Number carrying
    # LENGTH, and it writes all three image panels plus the three state values.
    clear_all_button.click(
        fn=clear_all,
        inputs=[gr.Number(value=LENGTH, visible=False, precision=0)],
        outputs=[
            canvas,
            input_image,
            output_image,
            selected_points,
            original_image,
            mask,
        ],
    )

    # "Run" calls `run_drag_interface`; inputs are passed positionally in
    # exactly this order, and the single output is the result panel.
    run_button.click(
        fn=run_drag_interface,
        inputs=[
            original_image,
            input_image,
            mask,
            prompt,
            selected_points,
            inversion_strength,
            model_path,
            vae_path,
            start_step,
            start_layer,
            n_inference_step,
            op_type,
            device,
            save_prefix,
            lambda_mix,
            gamma_ratio,
            upper_scale,
            lower_scale,
            alpha,
            beta,
        ],
        outputs=[output_image],
    )

    # "Save Sample" calls `store_sample` with the canvas, the save directory,
    # the point list and the prompt; no component is updated in return.
    sample_save_button.click(
        fn=store_sample,
        inputs=[canvas, sample_save_dir, selected_points, prompt],
        outputs=[],
    )

# Start the Gradio app. `launch()` is called without arguments, so Gradio's
# default launch settings apply. This runs whenever the module is executed.
demo.launch()