import os
import argparse
from collections import defaultdict
from typing import Dict, List

import numpy as np
import pandas as pd


def calculate_boxplot_stats(times: List[float]) -> Dict[str, float]:
    """Compute the five values needed to draw a Tukey box plot.

    Quartiles and the median come from ``np.percentile`` with its default
    interpolation. Each whisker ends at the most extreme *observed* value
    that still lies within 1.5 * IQR of its quartile, so a whisker never
    points at a value that does not occur in the data.

    Args:
        times: Observed durations. Non-finite entries (NaN, +/-inf) are
            ignored.

    Returns:
        A dict with the keys ``lower_whisker``, ``lower_quartile``,
        ``median``, ``upper_quartile`` and ``upper_whisker`` holding plain
        Python floats. All values are 0 when there is no finite observation.
    """
    values = np.asarray(times, dtype=float).ravel()
    # NaN/inf would turn every percentile into NaN; drop them up front.
    values = values[np.isfinite(values)]

    if values.size == 0:
        return {
            "lower_whisker": 0.0,
            "lower_quartile": 0.0,
            "median": 0.0,
            "upper_quartile": 0.0,
            "upper_whisker": 0.0,
        }

    q1, median, q3 = np.percentile(values, [25, 50, 75])
    iqr = q3 - q1
    lower_fence = q1 - 1.5 * iqr
    upper_fence = q3 + 1.5 * iqr

    # Whiskers snap to real data points inside the fences rather than to the
    # fences themselves. The fallbacks only guard against floating-point
    # rounding leaving a selection empty.
    above_lower_fence = values[values >= lower_fence]
    below_upper_fence = values[values <= upper_fence]
    lower_whisker = (
        above_lower_fence.min() if above_lower_fence.size else values.min()
    )
    upper_whisker = (
        below_upper_fence.max() if below_upper_fence.size else values.max()
    )

    return {
        "lower_whisker": float(lower_whisker),
        "lower_quartile": float(q1),
        "median": float(median),
        "upper_quartile": float(q3),
        "upper_whisker": float(upper_whisker),
    }


def aggregate_times_by_scene(csv_path: str) -> Dict[str, List[float]]:
    """Collect the valid mean answer times per scene from a results CSV.

    For every known scene the CSV is expected to provide two columns named
    ``<scene>_mean_time`` and ``<scene>_score``, where ``<scene>`` is the
    scene name lower-cased with spaces replaced by underscores. A row
    contributes its time to a scene when the time is present and positive
    and the score is present and not equal to -1.

    Args:
        csv_path: Path of the CSV file, passed straight to ``pd.read_csv``.

    Returns:
        A plain dict mapping scene name to its valid times in row order.
        Scenes whose columns are missing, or that have no valid row, are
        omitted. Keys appear in the fixed scene order listed below.
    """
    df = pd.read_csv(csv_path)

    scene_names = [
        "Agent Sight",
        "Full Views",
        "Polyomino",
        "Pyramid",
        "Revolution",
        "Sun Direction",
        "Unfolded",
    ]

    scene_times: Dict[str, List[float]] = {}
    for scene_name in scene_names:
        column_prefix = scene_name.lower().replace(" ", "_")
        time_col = f"{column_prefix}_mean_time"
        score_col = f"{column_prefix}_score"

        # Column presence does not depend on the row, so check it once per
        # scene instead of once per row.
        if time_col not in df.columns or score_col not in df.columns:
            continue

        times = df[time_col]
        scores = df[score_col]
        # Vectorised form of the per-row validity test.
        is_valid = (
            times.notna() & (times > 0) & scores.notna() & (scores != -1)
        )
        valid_times = times[is_valid]
        if not valid_times.empty:
            scene_times[scene_name] = valid_times.astype(float).tolist()

    return scene_times


def generate_tikz_boxplot(scene_times: Dict[str, List[float]]) -> str:
    """Render per-scene answer times as pgfplots box-plot source code.

    Scenes are drawn in alphabetical order, one ``\\addplot+`` with
    ``boxplot prepared`` per scene, and fill colours cycle through a fixed
    eight-colour palette. The x axis runs from 0 to 1.1 times the largest
    upper whisker, or to 10 when ``scene_times`` is empty.

    Args:
        scene_times: Mapping of scene name to its list of answer times.

    Returns:
        The complete ``tikzpicture`` environment as a newline-joined string
        (no trailing newline).
    """
    palette = (
        "myred",
        "myorange",
        "mybrown",
        "mygreen",
        "myteal",
        "myblue",
        "mypurple",
        "mycyan",
    )

    ordered_names = sorted(scene_times)

    # Summarise every scene once; the summaries feed both the axis limit and
    # the individual plots.
    summaries = {
        name: calculate_boxplot_stats(samples)
        for name, samples in scene_times.items()
    }

    if scene_times:
        widest = max(0, *(box["upper_whisker"] for box in summaries.values()))
        x_limit = widest * 1.1
    else:
        x_limit = 10

    output = [
        "\\begin{tikzpicture}",
        "\\begin{axis}[",
        "    xbar,",
        "    boxplot/draw direction = x,",
        "    width=12cm,",
        f"    height={len(ordered_names) * 0.8 + 2}cm,",
        "    xlabel={Answer Time [s]},",
        f"    ytick={{1,...,{len(ordered_names)}}},",
        "    yticklabels={",
        *(f"        {name}," for name in ordered_names),
        "    },",
        "    yticklabel style={align=right, font=\\small},",
        f"    xmin=0, xmax={x_limit:.1f},",
        "    axis x line=bottom,",
        "    axis y line=left,",
        "    boxplot/variable width,",
        "    enlarge y limits=0.05,",
        "    enlarge x limits=0.05,",
        "    line width=1pt,",
        "    every boxplot/.append style={thick},",
        "]",
        "",
        # Colour definitions for the palette names used below.
        "\\definecolor{myred}{RGB}{230,97,1}",
        "\\definecolor{myorange}{RGB}{253,184,99}",
        "\\definecolor{mybrown}{RGB}{178,171,210}",
        "\\definecolor{mygreen}{RGB}{120,198,121}",
        "\\definecolor{myteal}{RGB}{27,158,119}",
        "\\definecolor{myblue}{RGB}{66,146,198}",
        "\\definecolor{mypurple}{RGB}{158,154,200}",
        "\\definecolor{mycyan}{RGB}{102,194,165}",
        "",
    ]

    for position, name in enumerate(ordered_names):
        box = summaries[name]
        fill = palette[position % len(palette)]
        prepared = ", ".join(
            (
                f"lower whisker={box['lower_whisker']:.2f}",
                f"lower quartile={box['lower_quartile']:.2f}",
                f"median={box['median']:.2f}",
                f"upper quartile={box['upper_quartile']:.2f}",
                f"upper whisker={box['upper_whisker']:.2f}",
            )
        )
        output += [
            "\\addplot+[",
            f"    boxplot prepared={{{prepared}}},",
            f"    fill={fill}, draw=black",
            "] coordinates {};",
            "",
        ]

    output += ["\\end{axis}", "\\end{tikzpicture}"]

    return "\n".join(output)


def main():
    """Command-line entry point: turn a results CSV into TikZ box-plot code.

    Reads ``--stats-csv``, aggregates the per-scene times and either writes
    the generated TikZ code to ``--output`` or prints it to stdout.

    Raises:
        FileNotFoundError: If the path given via ``--stats-csv`` does not
            exist.
    """
    cli = argparse.ArgumentParser(
        description="Generate TikZ boxplot for response times by task"
    )
    cli.add_argument(
        "--stats-csv",
        type=str,
        required=True,
        help="Path to evaluation results CSV file",
    )
    cli.add_argument(
        "--output",
        type=str,
        help="Output file path (if not specified, prints to stdout)",
    )
    options = cli.parse_args()

    csv_path = options.stats_csv
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"CSV file not found: {csv_path}")

    per_scene = aggregate_times_by_scene(csv_path)
    if not per_scene:
        print("No timing data found in CSV file.")
        return

    rendered = generate_tikz_boxplot(per_scene)

    destination = options.output
    if not destination:
        # No output file requested: emit the TikZ source on stdout.
        print(rendered)
        return

    with open(destination, "w") as handle:
        handle.write(rendered)
    print(f"TikZ code saved to {destination}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
