import os
import re
import sys
import glob
import json
import traceback
import numpy as np
import gradio as gr
from PIL import Image
import plotly.graph_objects as go

# Root directory under which load_all_data looks for metric JSON files.
OUTPUT_ROOT = 'output'
# Method names offered in the "Unlearning Method" dropdown.
METHODS = ['nemo', 'baseline']
# On-disk placeholder used wherever an image is missing or unreadable.
PLACEHOLDER_IMAGE_PATH = 'placeholder.png'
# Create a neutral dark-grey 256x256 placeholder on first run so every
# fallback path in this module has a real file to point at.
if not os.path.exists(PLACEHOLDER_IMAGE_PATH):
    img = Image.new('RGB', (256, 256), (50, 50, 50))
    img.save(PLACEHOLDER_IMAGE_PATH, 'PNG')
# Image.open is lazy: it only reads the header and keeps the file open until
# the pixels are loaded. Decode eagerly and detach from the file so the handle
# is released now and a corrupt placeholder fails here, not inside the UI.
with Image.open(PLACEHOLDER_IMAGE_PATH) as _placeholder_file:
    PLACEHOLDER_IMAGE_PIL = _placeholder_file.copy()
del _placeholder_file

def load_all_data(output_root=None, subdir=os.path.join('baseline', 'laion_memorized')):
    """Scan metric JSON files and build a nested index for fast lookup.

    Files matching ``<output_root>/<subdir>/prompt_*_metrics.json`` are
    scanned. With the defaults this is only ``OUTPUT_ROOT/baseline/
    laion_memorized``. Pass a glob such as ``os.path.join('*', '*')`` as
    *subdir* to scan every ``<method>/<source>`` directory.

    Files whose names do not match ``prompt_<prompt>_<seed>_`` are skipped.
    Files that cannot be read or decoded as a JSON object are reported (the
    traceback is printed) and skipped. They leave no empty entries in the index.

    Args:
        output_root: Directory to scan; ``None`` means ``OUTPUT_ROOT``.
        subdir: Path below *output_root* (may contain glob wildcards). The
            last two directory components of each match are read as
            ``<method>/<source>``.

    Returns:
        ``{source: {prompt_id: {method: {seed_id: entry}}}}``, where ``entry``
        has keys ``metrics`` (parsed JSON), ``json_path``, ``image_path``
        (``..._metrics.json`` -> ``..._image.png``), ``gt_path``
        (``prompt_<prompt_id:04d>_ground_truth.png`` in the same directory)
        and ``source``.
    """
    metrics_suffix = '_metrics.json'
    root = OUTPUT_ROOT if output_root is None else output_root
    # Sort so insertion order (and hence the UI's default source) is stable.
    json_paths = sorted(glob.glob(os.path.join(root, subdir, 'prompt_*' + metrics_suffix)))
    print(f"Found {len(json_paths)} JSON files in {root}")

    data_index = {}
    for f_path in json_paths:
        match = re.match(r"prompt_(\d+)_(\d+)_", os.path.basename(f_path))
        if not match:
            continue
        prompt_id, seed_id = int(match.group(1)), int(match.group(2))

        # Layout is .../<method>/<source>/<file>; normpath tolerates mixed separators.
        parts = os.path.normpath(f_path).split(os.sep)
        if len(parts) < 3:
            continue
        method, source = parts[-3], parts[-2]

        # Parse before touching the index so a bad file leaves no empty entries.
        try:
            with open(f_path, 'r', encoding='utf-8') as f:
                metrics = json.load(f)
        except (OSError, ValueError):
            print(traceback.format_exc())
            continue
        if not isinstance(metrics, dict):
            print(f"Skipping {f_path}: expected a JSON object, got {type(metrics).__name__}")
            continue

        seeds = (data_index.setdefault(source, {})
                 .setdefault(prompt_id, {})
                 .setdefault(method, {}))
        seeds[seed_id] = {
            'metrics': metrics, 'json_path': f_path,
            # Replace only the filename suffix, never a match elsewhere in the path.
            'image_path': f_path[:-len(metrics_suffix)] + '_image.png',
            'gt_path': os.path.join(os.path.dirname(f_path), f"prompt_{prompt_id:04d}_ground_truth.png"),
            'source': source
        }
    return data_index

def create_plotly_plot(metrics, plot_type, t_step=None):
    """Render one metric of a single sample as a Plotly line chart.

    Args:
        metrics: Parsed metrics dict for one sample, or ``None``.
        plot_type: ``'noise'`` plots
            ``metrics['Noise_Difference_Norm']['noise_diff_norm_traj']``.
            ``'eigval'`` plots the sorted ``uncond_eigvals`` (dashed) and
            ``cond_eigvals`` stored under ``metrics['HessianMetric'][t_step]``.
        t_step: Key into ``metrics['HessianMetric']``; only used for ``'eigval'``.

    Returns:
        A 300px-high ``go.Figure``. If nothing can be plotted, it carries a
        "No Data" annotation. If building it raises, the traceback is printed
        and a "Plot Error: <message>" annotation is added instead. An
        unrecognised *plot_type* gives a "No Data" figure titled "Plot Error".
    """
    figure = go.Figure()
    heading = "Plot Error"
    try:
        drew_traces = False
        wants_noise = plot_type == 'noise' and metrics and 'Noise_Difference_Norm' in metrics
        if wants_noise:
            heading = "Noise Trajectory"
            trajectory = metrics['Noise_Difference_Norm'].get('noise_diff_norm_traj', [])
            if trajectory:
                figure.add_trace(go.Scatter(y=trajectory, mode='lines', name='Noise Diff Norm'))
                drew_traces = True
        elif (plot_type == 'eigval' and metrics and 'HessianMetric' in metrics
              and t_step in metrics['HessianMetric']):
            heading = f"Eigenvalues @ {t_step}"
            step_metrics = metrics['HessianMetric'][t_step]
            cond_sorted = sorted(step_metrics.get('cond_eigvals', []))
            uncond_sorted = sorted(step_metrics.get('uncond_eigvals', []))
            # Only plot when both spectra are present; the unconditional one is added first, dashed.
            if cond_sorted and uncond_sorted:
                figure.add_trace(go.Scatter(y=uncond_sorted, mode='lines', name='uncond', line=dict(dash='dash')))
                figure.add_trace(go.Scatter(y=cond_sorted, mode='lines', name='cond'))
                drew_traces = True

        if not drew_traces:
            figure.add_annotation(text="No Data", showarrow=False)
    except Exception as exc:
        print(traceback.format_exc())
        figure.add_annotation(text=f"Plot Error: {exc}", showarrow=False)

    figure.update_layout(title=heading, template="plotly_white", height=300, margin=dict(t=40, b=20))
    return figure

def load_image_safely(path, placeholder_path):
    """Load the image at *path*, falling back to *placeholder_path* on failure.

    ``Image.open`` only reads the file header and keeps the file open until the
    pixels are loaded. The image is therefore decoded eagerly via ``copy()``
    inside a ``with`` block. This way a truncated or corrupt file triggers the
    fallback here, instead of failing later when the UI renders it, and the
    file handle is closed before returning.

    Args:
        path: Image file to load. A missing, empty, or undecodable file is
            reported (printed) and treated as a failure.
        placeholder_path: Image returned when *path* cannot be loaded. Errors
            loading this file propagate to the caller.

    Returns:
        A fully loaded PIL image that is no longer tied to its source file.
    """
    try:
        if not (os.path.exists(path) and os.path.getsize(path) > 0):
            raise FileNotFoundError(f"File does not exist or is empty at path: {path}")
        with Image.open(path) as img:
            return img.copy()
    except Exception as e:
        print(f"[IMAGE LOAD ERROR] Could not load image '{path}'. Reason: {e}")
    with Image.open(placeholder_path) as placeholder:
        return placeholder.copy()


# --- Main Application ---
def _default_method(data_index, candidates):
    """Return the first name in *candidates* that has at least one seed entry.

    *data_index* is the ``{source: {prompt_id: {method: {seed_id: entry}}}}``
    mapping built by ``load_all_data``. Returns ``candidates[0]`` when no
    candidate has data.
    """
    for method in candidates:
        if any(methods_for_prompt.get(method)
               for source_data in data_index.values()
               for methods_for_prompt in source_data.values()):
            return method
    return candidates[0]


def _write_json_atomically(path, payload):
    """Write *payload* to *path* as ``indent=2`` JSON without a half-written file.

    The JSON is written to ``<path>.tmp`` and then moved over *path* with
    ``os.replace``. If serialisation or writing fails, the temp file is removed
    and the exception re-raised, leaving *path* unchanged.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def launch_inspector(share=True):
    """Build and launch the Gradio memorization-labelling UI.

    Loads the metric index with ``load_all_data`` and returns early (printing a
    message) if it is empty. Otherwise it shows, for one (method, source,
    prompt, seed) selection at a time, the ground-truth and generated images,
    the prompt text and three metric plots. A checkbox and a Save button write
    a ``memorized`` flag back into that sample's metrics JSON.

    Args:
        share: Forwarded to ``demo.launch``. ``True`` (the default) creates a
            public Gradio share link. Anyone with that link can also overwrite
            labels through the Save button.
    """
    master_data_index = load_all_data()
    if not master_data_index:
        print("No data found. Cannot launch.")
        return

    all_sources = list(master_data_index.keys())
    prompt_ids = sorted({prompt_id
                         for source_data in master_data_index.values()
                         for prompt_id in source_data})
    # The loader may index only some of METHODS. Starting on a method with no
    # data would make the initial view "No data for this selection.".
    default_method = _default_method(master_data_index, METHODS)

    def lookup(method, prompt_id, seed_id, source):
        """Return the index entry for a UI selection, or None if it is absent."""
        # Slider values may arrive as floats; the index is keyed by int.
        return (master_data_index.get(source, {}).get(int(prompt_id), {})
                .get(method, {}).get(int(seed_id)))

    # --- UI Update Functions ---
    def update_view(method, prompt_id, seed_id, source):
        """Return values for the 7 output components.

        The order is: GT image, info markdown, generated image, memorized
        label, noise plot, t1 eigenvalue plot, t20 eigenvalue plot.
        """
        data = lookup(method, prompt_id, seed_id, source)
        if not data:
            placeholder_img = load_image_safely(PLACEHOLDER_IMAGE_PATH, PLACEHOLDER_IMAGE_PATH)
            return (
                placeholder_img, "No data for this selection.", placeholder_img, False,
                create_plotly_plot(None, 'noise'), create_plotly_plot(None, 'eigval', 't1'),
                create_plotly_plot(None, 'eigval', 't20')
            )

        metrics = data['metrics']
        # Pre-check existence so routinely missing files do not spam load errors.
        gt_path = data['gt_path'] if os.path.exists(data['gt_path']) else PLACEHOLDER_IMAGE_PATH
        gen_path = data['image_path'] if os.path.exists(data['image_path']) else PLACEHOLDER_IMAGE_PATH

        info = f"**Source:** {data['source']}\n\n**Prompt:** {metrics.get('prompt', 'N/A')}"

        gt_image_data = load_image_safely(gt_path, PLACEHOLDER_IMAGE_PATH)
        gen_image_data = load_image_safely(gen_path, PLACEHOLDER_IMAGE_PATH)

        label = metrics.get('memorized', False)
        noise_plot = create_plotly_plot(metrics, 'noise')
        eig1_plot = create_plotly_plot(metrics, 'eigval', 't1')
        eig20_plot = create_plotly_plot(metrics, 'eigval', 't20')

        return (gt_image_data, info, gen_image_data, label, noise_plot, eig1_plot, eig20_plot)

    def save_label(method, prompt_id, seed_id, source, new_label):
        """Persist *new_label* as ``memorized`` in the selection's JSON file.

        The in-memory index is updated only after the write succeeds.
        Returns a status message for the UI.
        """
        data = lookup(method, prompt_id, seed_id, source)
        if not data:
            return "Could not find data to save label."
        json_path = data['json_path']
        try:
            # Re-read from disk so keys not held in memory are preserved.
            with open(json_path, 'r', encoding='utf-8') as f:
                metrics_data = json.load(f)
            metrics_data['memorized'] = new_label
            _write_json_atomically(json_path, metrics_data)
            data['metrics']['memorized'] = new_label
            return f"Label saved for {os.path.basename(json_path)}"
        except Exception as e:
            print(traceback.format_exc())
            return f"Error saving label: {e}"

    # --- JavaScript for Keyboard Shortcuts ---
    # Runs once on page load. It locates components by the elem_id values set
    # below (prompt_slider, seed_slider, label_checkbox, save_button).
    js_keyboard_shortcuts = """
    () => {
        function getGradioInput(elem_id) {
            const elem = document.getElementById(elem_id);
            if (!elem) return null;
            // The actual input is often a child of the main component div
            return elem.querySelector('input');
        }

        function getGradioButton(elem_id) {
            const elem = document.getElementById(elem_id);
            if (!elem) return null;
            // The button can be the element itself or a child
            return elem.tagName === 'BUTTON' ? elem : elem.querySelector('button');
        }

        function changeSliderValue(elem_id, direction) {
            const sliderInput = getGradioInput(elem_id);
            if (sliderInput) {
                const currentValue = parseInt(sliderInput.value, 10);
                const step = parseInt(sliderInput.step, 10) || 1;
                const min = parseInt(sliderInput.min, 10);
                const max = parseInt(sliderInput.max, 10);
                let newValue = currentValue + (direction * step);
                // Clamp value within the slider's min/max range
                if (newValue < min) newValue = min;
                if (newValue > max) newValue = max;
                sliderInput.value = newValue;
                // Dispatch events to notify Gradio of the change
                sliderInput.dispatchEvent(new Event('input', { bubbles: true }));
                sliderInput.dispatchEvent(new Event('change', { bubbles: true }));
            }
        }
        
        function setSliderValue(elem_id, value) {
             const sliderInput = getGradioInput(elem_id);
             if (sliderInput) {
                const min = parseInt(sliderInput.min, 10);
                const max = parseInt(sliderInput.max, 10);
                let newValue = parseInt(value, 10);
                // Set value only if it's within the valid range
                if (newValue >= min && newValue <= max) {
                    sliderInput.value = newValue;
                    sliderInput.dispatchEvent(new Event('input', { bubbles: true }));
                    sliderInput.dispatchEvent(new Event('change', { bubbles: true }));
                }
             }
        }

        function setCheckboxValue(elem_id, isChecked) {
            const checkboxInput = getGradioInput(elem_id);
            if (checkboxInput && checkboxInput.type === 'checkbox') {
                // Only trigger a change if the state is different
                if (checkboxInput.checked !== isChecked) {
                    checkboxInput.checked = isChecked;
                    checkboxInput.dispatchEvent(new Event('change', { bubbles: true }));
                }
            }
        }

        window.addEventListener('keydown', (event) => {
            // Ignore key events if the user is typing in an input field
            if (event.target.tagName === 'INPUT' || event.target.tagName === 'TEXTAREA') {
                return;
            }
            
            switch (event.key) {
                case 'ArrowRight':
                    event.preventDefault();
                    changeSliderValue('prompt_slider', 1);
                    break;
                case 'ArrowLeft':
                    event.preventDefault();
                    changeSliderValue('prompt_slider', -1);
                    break;
                case 'm':
                case 'M':
                    event.preventDefault();
                    setCheckboxValue('label_checkbox', true);
                    break;
                case 'n':
                case 'N':
                    event.preventDefault();
                    setCheckboxValue('label_checkbox', false);
                    break;
                case 'Enter':
                case ' ': // Space key
                    event.preventDefault();
                    const saveButton = getGradioButton('save_button');
                    if (saveButton) {
                        saveButton.click();
                    }
                    break;
                default:
                    // Handle single digit number keys for the seed slider
                    if (!isNaN(parseInt(event.key, 10)) && event.key.length === 1) {
                        event.preventDefault();
                        setSliderValue('seed_slider', event.key);
                    }
                    break;
            }
        });
    }
    """

    with gr.Blocks(title="Memorization Inspector", css=".gradio-container {max-width: 98% !important;}", js=js_keyboard_shortcuts) as demo:
        gr.Markdown("# Unified Memorization Inspector")
        with gr.Accordion("Keyboard Shortcuts", open=False):
            gr.Markdown("- **Left/Right Arrows**: Change Prompt Index\n"
                        "- **Number Keys (0-9)**: Change Seed Index\n"
                        "- **M / N**: Mark as Memorized / Not Memorized\n"
                        "- **Enter / Space**: Save Label")

        with gr.Row():
            method_dropdown = gr.Dropdown(METHODS, value=default_method, label="Unlearning Method")
            source_dropdown = gr.Dropdown(all_sources, value=all_sources[0] if all_sources else None, label="Data Source")
            prompt_slider = gr.Slider(minimum=prompt_ids[0], maximum=prompt_ids[-1], step=1, value=prompt_ids[0], label="Prompt Index", elem_id="prompt_slider")
            seed_slider = gr.Slider(minimum=0, maximum=10, step=1, value=0, label="Seed Index", elem_id="seed_slider")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Info & Ground Truth")
                gt_image = gr.Image(value=PLACEHOLDER_IMAGE_PIL, label="Ground Truth", height=300)
                info_display = gr.Markdown("Info will appear here.")

            with gr.Column(scale=2):
                gr.Markdown("### Generated Result")
                gen_image = gr.Image(value=PLACEHOLDER_IMAGE_PIL, label="Generated Image", height=300)
                with gr.Row():
                    label_checkbox = gr.Checkbox(label="Is Memorized?", interactive=True, elem_id="label_checkbox")
                    save_button = gr.Button("Save Label", elem_id="save_button")
                status_text = gr.Textbox(label="Status", interactive=False)

            with gr.Column(scale=2):
                gr.Markdown("### Metric Plots")
                noise_plot = gr.Plot()
                eig_t1_plot = gr.Plot()
                eig_t20_plot = gr.Plot()

        # --- Event Wiring ---
        # Input order must match the update_view / save_label parameter order.
        inputs = [method_dropdown, prompt_slider, seed_slider, source_dropdown]
        outputs = [gt_image, info_display, gen_image, label_checkbox, noise_plot, eig_t1_plot, eig_t20_plot]

        for component in inputs:
            component.change(fn=update_view, inputs=inputs, outputs=outputs)

        save_button.click(fn=save_label, inputs=inputs + [label_checkbox], outputs=[status_text])

        demo.load(fn=update_view, inputs=inputs, outputs=outputs)

    print("Mem Inspector App with Keyboard Shortcuts is launching...")
    demo.launch(share=share)

if __name__ == '__main__':
    # Launch the UI only when run as a script, not when this module is imported.
    launch_inspector()
