import argparse
import yaml
import pandas as pd
import numpy as np
from dash import Dash, dcc, html, Input, Output
import plotly.graph_objs as go
import os
import sys

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

from modules.survey_converter import Survey
from modules.response_converter import Responses, ResponseUtils

# ----------------------------
# Helper: Entropy computation
# ----------------------------
def compute_entropy(counts, vocab_size):
    """Return the Shannon entropy of *counts*, normalized by log2(vocab_size).

    Args:
        counts: Iterable of occurrence counts, one per answer option.
            Entries that are zero or negative are ignored.
        vocab_size: Number of possible answer options. The raw entropy
            (in bits) is divided by ``log2(vocab_size)``; when
            ``vocab_size`` is 1 or less the divisor is 1.

    Returns:
        A built-in ``float``. ``0.0`` when there are no observations or
        when every observation falls into a single option; ``1.0`` when
        the observations are spread uniformly over ``vocab_size`` options.
    """
    # Materialise once so generators work and non-positive entries cannot
    # distort the total used for the probabilities.
    positive = [c for c in counts if c > 0]
    total = sum(positive)
    if total == 0:
        return 0.0
    # p * log2(1 / p) is the usual -p * log2(p) written without the leading
    # negation, so a one-option distribution yields 0.0 instead of -0.0.
    entropy = sum((c / total) * np.log2(total / c) for c in positive)
    max_entropy = np.log2(vocab_size) if vocab_size > 1 else 1
    # Cast so callers always get a plain float, matching the empty-input path.
    return float(entropy / max_entropy)

# ----------------------------
# CLI Argument Parsing
# ----------------------------
# The script takes exactly one option, --config, and refuses to start without it.
parser = argparse.ArgumentParser(
    description="Run interactive entropy diagnostics on survey responses using YAML config.",
)
parser.add_argument("--config", required=True, type=str, help="Path to YAML configuration file.")
args = parser.parse_args()

# ----------------------------
# Load YAML and responses
# ----------------------------
# Read the file as UTF-8 explicitly so parsing does not depend on the
# platform's default text encoding.
with open(args.config, "r", encoding="utf-8") as f:
    # safe_load returns None for an empty document; fall back to an empty
    # mapping so the lookups below do not fail with AttributeError.
    config = yaml.safe_load(f) or {}

# A bare `paths:` key also loads as None, so normalise it the same way.
paths = config.get("paths") or {}
responses_path = paths.get("responses_csv")
survey_csv = paths.get("survey_csv")
survey_yaml = paths.get("survey_yaml")

# Any path missing from the config is handed on as None.
survey = Survey(csv_path=survey_csv, config_path=survey_yaml)
responses = Responses(source=responses_path, survey=survey, output_format="answer")

# ----------------------------
# Combine response matrices across splits
# ----------------------------
# Stack the per-split response matrices on top of each other; the index
# labels of the combined frame are used as question ids below.
splits = ["train", "valid", "test"]
df_list = [responses.get_matrix_by_split(split) for split in splits]
df = pd.concat(df_list, axis=0)

# Display order of the split groups; "unknown" collects every other label.
split_order = ["train", "valid", "test", "unknown"]

response_data = {}  # qid -> non-missing responses, as strings
entropy_data = []   # one record per question, turned into df_entropy below

# NOTE(review): assumes each qid occurs once in the combined index; a repeated
# label would make df.loc[qid] return a DataFrame — confirm against
# Responses.get_matrix_by_split.
for qid in df.index:
    series = df.loc[qid].dropna().astype(str)
    response_data[qid] = series

    question = responses.questions[qid]
    code_to_answer = question.get("code_to_answer", {})
    if responses.output_format == "answer":
        vocab = list(code_to_answer.values())
    else:
        vocab = [str(k) for k in code_to_answer.keys()]

    # Count over the full vocabulary so options nobody picked contribute 0.
    counts = series.value_counts()
    padded_counts = [counts.get(label, 0) for label in vocab]
    entropy = compute_entropy(padded_counts, vocab_size=len(vocab))

    # Labels outside split_order would become NaN in the Categorical below
    # and lose their split; file them under "unknown" instead.
    split_label = question.get("split", "unknown")
    if split_label not in split_order:
        split_label = "unknown"

    entropy_data.append({
        "qid": qid,
        "entropy": entropy,
        "split": split_label,
        "vocab": vocab
    })

# Naming the columns keeps the frame well-formed even when no questions were
# loaded, so the column accesses below cannot raise KeyError.
df_entropy = pd.DataFrame(entropy_data, columns=["qid", "entropy", "split", "vocab"])
df_entropy["qid"] = df_entropy["qid"].astype(str)

# Set split as ordered categorical
df_entropy["split"] = pd.Categorical(df_entropy["split"], categories=split_order, ordered=True)

# Sort by split (custom order) then by entropy descending
df_entropy = (
    df_entropy
    .sort_values(["split", "entropy"], ascending=[True, False])
    .reset_index(drop=True)
)


# ----------------------------
# Build Dash App
# ----------------------------


# Bar colour per split; a split missing from this map is drawn in gray.
split_colors = dict(
    train="skyblue",
    valid="lightgreen",
    test="gold",
    unknown="lightgray",
)

# One bar trace per split found in the entropy table, so every split gets
# its own colour and legend entry.
bars = []
for split in df_entropy["split"].unique():
    split_df = df_entropy.loc[df_entropy["split"] == split]
    trace = go.Bar(
        name=split,
        marker_color=split_colors.get(split, "gray"),
        x=split_df["qid"],
        y=split_df["entropy"],
    )
    bars.append(trace)

entropy_layout = go.Layout(
    title="Click a question to view response distribution",
    xaxis_title="QID",
    yaxis_title="Normalized Entropy",
    barmode="group",
    legend={"title": "Split"},
)
entropy_figure = go.Figure(data=bars, layout=entropy_layout)

# Page structure: heading, the clickable entropy chart, and an initially
# empty container that the click callback fills in.
app = Dash(__name__)
app.layout = html.Div(
    children=[
        html.H2("Interactive Question Entropy Explorer"),
        dcc.Graph(id="entropy-plot", figure=entropy_figure),
        html.Div(id="response-plot-container"),
    ]
)

@app.callback(
    Output("response-plot-container", "children"),
    Input("entropy-plot", "clickData"),
)
def update_response_plot(clickData):
    """Render the answer distribution of the question whose bar was clicked.

    Args:
        clickData: The ``clickData`` property of the ``entropy-plot`` graph.
            Falsy until a bar has been clicked; otherwise the x value of the
            first clicked point identifies the question.

    Returns:
        A hint string when nothing has been clicked or the clicked question
        cannot be matched to loaded responses; otherwise a ``dcc.Graph``
        holding a horizontal bar chart of answer frequencies.
    """
    if not clickData:
        return "Click a bar to view its response distribution."

    # The entropy chart plots str(qid), while response_data and
    # responses.questions are keyed by the original index labels. Compare on
    # the string form so non-string question ids still resolve.
    clicked = str(clickData["points"][0]["x"])
    qid = next((key for key in response_data if str(key) == clicked), None)
    if qid is None:
        return f"No responses found for question {clicked}."

    series = response_data[qid]
    question = responses.questions[qid]
    code_to_answer = question.get("code_to_answer", {})
    if responses.output_format == "answer":
        vocab = list(code_to_answer.values())
    else:
        vocab = [str(k) for k in code_to_answer.keys()]

    # Count over the full vocabulary so unchosen answers appear with 0.
    counts = series.value_counts()
    padded_counts = [counts.get(label, 0) for label in vocab]

    # Fall back to an empty subtitle when the question text is absent.
    question_text = question.get("question", "")

    fig = go.Figure(
        data=[
            go.Bar(
                y=vocab,
                x=padded_counts,
                orientation="h",
                marker_color="#588c73"
            )
        ],
        layout=go.Layout(
            title=f"Response Distribution for {qid}<br><sub>{question_text}</sub>",
            xaxis_title="Frequency",
            yaxis_title="Answer"
        )
    )
    return dcc.Graph(figure=fig)

# ----------------------------
# Run the server
# ----------------------------
if __name__ == "__main__":
    # Start the Dash server only when this file is executed as a script.
    # NOTE(review): debug=True is a development setting — confirm it is
    # switched off before this app is reachable beyond the local machine.
    app.run(debug=True)
