import argparse
import ast
import os
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd


def parse_response_times(response_times_str: str) -> List[float]:
    """Parse a serialized list of response times into positive floats.

    Args:
        response_times_str: Either a string holding a Python literal such as
            ``"[1.2, None, 3.4]"`` or an already-parsed iterable of values.
            Anything else (``None``, a NaN cell from pandas, a bare number)
            is treated as "no data".

    Returns:
        The strictly positive entries converted to ``float``, in input order.
        ``None`` entries, non-numeric entries, NaN and values ``<= 0`` are
        dropped. An empty list is returned when the input cannot be parsed
        or is not a sequence.
    """
    if isinstance(response_times_str, str):
        try:
            parsed = ast.literal_eval(response_times_str)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # literal_eval may raise any of these on malformed input.
            return []
    else:
        parsed = response_times_str

    # A parsed string/bytes literal is iterable but is not a list of times.
    if isinstance(parsed, (str, bytes)):
        return []

    try:
        candidates = iter(parsed)
    except TypeError:
        # Scalars (NaN from an empty CSV cell, None, a lone number).
        return []

    times = []
    for item in candidates:
        if item is None:
            continue
        try:
            value = float(item)
        except (TypeError, ValueError):
            # Skip the malformed entry instead of discarding the whole row.
            continue
        # NaN compares False here and is therefore dropped as well.
        if value > 0:
            times.append(value)
    return times


def calculate_boxplot_stats(times: List[float]) -> Dict[str, float]:
    """Compute the five-number summary used for a prepared box plot.

    Quartiles come from ``np.percentile`` (linear interpolation). Whiskers
    follow the usual Tukey convention: each whisker ends at the most extreme
    *data point* that still lies within 1.5 * IQR of the box, not at the
    fence value itself. A whisker is never allowed to fall inside the box.

    Args:
        times: Samples to summarize; may be empty.

    Returns:
        A dict with the keys ``lower_whisker``, ``lower_quartile``,
        ``median``, ``upper_quartile`` and ``upper_whisker``. All values are
        0 when ``times`` is empty.
    """
    values = np.asarray([] if times is None else times, dtype=float).ravel()
    if values.size == 0:
        return {
            "lower_whisker": 0,
            "lower_quartile": 0,
            "median": 0,
            "upper_quartile": 0,
            "upper_whisker": 0,
        }

    q1, median, q3 = (float(q) for q in np.percentile(values, [25, 50, 75]))

    iqr = q3 - q1
    lower_fence = q1 - 1.5 * iqr
    upper_fence = q3 + 1.5 * iqr

    inside_lower = values[values >= lower_fence]
    inside_upper = values[values <= upper_fence]

    # With interpolated quartiles the nearest in-fence sample can lie inside
    # the box; clamp to the quartile so the whisker never points inward.
    lower_whisker = min(float(inside_lower.min()), q1) if inside_lower.size else q1
    upper_whisker = max(float(inside_upper.max()), q3) if inside_upper.size else q3

    return {
        "lower_whisker": lower_whisker,
        "lower_quartile": q1,
        "median": median,
        "upper_quartile": q3,
        "upper_whisker": upper_whisker,
    }


def aggregate_model_times(csv_path: str) -> Dict[str, List[float]]:
    """Collect model response times per scene from an evaluation CSV.

    The CSV must provide a ``SceneName`` column and a ``K_ResponseTimes``
    column whose cells are parsed with ``parse_response_times``.

    Args:
        csv_path: Path to the model results CSV.

    Returns:
        Mapping of scene name to all response times found for that scene.
        A scene whose rows carry no usable times maps to an empty list.

    Raises:
        KeyError: If the CSV has rows but lacks one of the required columns.
    """
    df = pd.read_csv(csv_path)
    if df.empty:
        return {}

    scene_times = defaultdict(list)

    for scene_name, raw_times in zip(df["SceneName"], df["K_ResponseTimes"]):
        # A row without a scene name cannot be attributed to any box, and a
        # NaN key would later break sorting of the scene names.
        if pd.isna(scene_name):
            continue
        # read_csv yields a float NaN (not a string) for empty cells; only
        # string cells can contain a serialized list of times.
        times = parse_response_times(raw_times) if isinstance(raw_times, str) else []
        scene_times[scene_name].extend(times)

    return dict(scene_times)


def aggregate_human_times(
    csv_path: str, scene_names: Optional[List[str]] = None
) -> Dict[str, List[float]]:
    """Collect human mean answer times per scene from a results CSV.

    For every scene the columns ``<scene>_mean_time`` and ``<scene>_score``
    are looked up, where ``<scene>`` is the lower-cased scene name with
    spaces replaced by underscores. A time is kept when it is a positive
    number and the matching score is present and not ``-1``.

    Args:
        csv_path: Path to the human results CSV.
        scene_names: Display names of the scenes to extract. Defaults to the
            seven built-in scenes.

    Returns:
        Mapping of scene name to the accepted times in row order. Scenes
        whose columns are missing or that have no accepted time are omitted.
    """
    if scene_names is None:
        scene_names = [
            "Agent Sight",
            "Full Views",
            "Polyomino",
            "Pyramid",
            "Revolution",
            "Sun Direction",
            "Unfolded",
        ]

    df = pd.read_csv(csv_path)
    scene_times: Dict[str, List[float]] = {}

    for scene_name in scene_names:
        column_stem = scene_name.lower().replace(" ", "_")
        time_col = f"{column_stem}_mean_time"
        score_col = f"{column_stem}_score"

        if time_col not in df.columns or score_col not in df.columns:
            continue

        # Coerce so that a stray non-numeric cell is skipped rather than
        # raising on the "> 0" comparison; NaN compares False and drops out.
        times = pd.to_numeric(df[time_col], errors="coerce")
        scores = df[score_col]
        accepted = times[(times > 0) & scores.notna() & (scores != -1)]

        if not accepted.empty:
            scene_times[scene_name] = [float(value) for value in accepted]

    return scene_times


def _prepared_boxplot_lines(
    stats: Dict[str, float], draw_position: str, fill_color: str
) -> List[str]:
    """Return the addplot lines drawing one prepared box at ``draw_position``."""
    return [
        "\\addplot+[",
        (
            f"  boxplot prepared={{lower whisker={stats['lower_whisker']:.2f}, "
            f"lower quartile={stats['lower_quartile']:.2f}, "
            f"median={stats['median']:.2f}, "
            f"upper quartile={stats['upper_quartile']:.2f}, "
            f"upper whisker={stats['upper_whisker']:.2f}}},"
        ),
        f"  boxplot prepared={{draw position={draw_position}}},",
        f"  fill={fill_color}, draw=black",
        "] coordinates {};",
    ]


def generate_combined_tikz_plot(
    human_times: Dict[str, List[float]], model_times: Dict[str, List[float]]
) -> str:
    """Render a horizontal pgfplots box plot comparing human and model times.

    Scenes from both inputs are merged and sorted; each scene occupies one
    y position, with the model box drawn slightly below it and the human box
    slightly above it. Scenes with an empty sample list get a tick and label
    but no box.

    Args:
        human_times: Mapping of scene name to human answer times.
        model_times: Mapping of scene name to model answer times.

    Returns:
        The complete ``tikzpicture`` environment as a single string.
    """
    sorted_scenes = sorted(set(human_times) | set(model_times))
    num_scenes = len(sorted_scenes)

    # Summarize every non-empty scene once and reuse the result both for
    # the axis range and for drawing.
    human_stats = {
        scene: calculate_boxplot_stats(times)
        for scene, times in human_times.items()
        if times
    }
    model_stats = {
        scene: calculate_boxplot_stats(times)
        for scene, times in model_times.items()
        if times
    }

    upper_whiskers = [
        stats["upper_whisker"]
        for stats in (*human_stats.values(), *model_stats.values())
    ]
    max_time = max(upper_whiskers, default=0)
    # Leave 10% headroom; fall back to a fixed range when there is no data.
    max_time = max_time * 1.1 if max_time > 0 else 10

    # One tick per scene so ticks, labels and ymax always agree.
    ytick_values = ",".join(str(position) for position in range(1, num_scenes + 1))

    lines = [
        "\\begin{tikzpicture}",
        "\\begin{axis}[",
        "    xbar,",
        "    boxplot/draw direction = x,",
        "    width=12cm,",
        "    height=6cm,",
        "    xlabel={Answer Time [s]},",
        f"    ymin=0.5, ymax={num_scenes + 0.5},",
        f"    ytick={{{ytick_values}}},",
        "    yticklabels={",
    ]
    lines.extend(f"        {scene}," for scene in sorted_scenes)
    lines.extend(
        [
            "    },",
            "    yticklabel style={align=right, font=\\small},",
            f"    xmin=0, xmax={max_time:.1f},",
            "    axis x line=bottom,",
            "    axis y line=left,",
            "    boxplot/variable width,",
            "    boxplot/box extend=0.25,",
            "    enlarge y limits=0.02,",
            "    enlarge x limits=0.05,",
            "    line width=1pt,",
            "    every boxplot/.append style={thick},",
            "    legend style={at={(0.98,0.98)}, anchor=north east},",
            "]",
            "",
            "\\definecolor{humancolor}{RGB}{255,165,0}",
            "\\definecolor{modelcolor}{RGB}{70,130,180}",
            "",
            "\\addlegendimage{area legend, draw=black, fill=humancolor}",
            "\\addlegendentry{Human}",
            "\\addlegendimage{area legend, draw=black, fill=modelcolor}",
            "\\addlegendentry{Models}",
            "",
        ]
    )

    for y_pos, scene in enumerate(sorted_scenes, start=1):
        lines.append(f"% {y_pos}: {scene}")

        if scene in model_stats:
            lines.extend(
                _prepared_boxplot_lines(
                    model_stats[scene], f"{y_pos}-0.18", "modelcolor"
                )
            )

        if scene in human_stats:
            lines.extend(
                _prepared_boxplot_lines(
                    human_stats[scene], f"{y_pos}+0.18", "humancolor"
                )
            )

        lines.append("")

    lines.append("\\end{axis}")
    lines.append("\\end{tikzpicture}")

    return "\n".join(lines)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the plot generator."""
    cli = argparse.ArgumentParser(
        description="Generate combined TikZ plot for human vs model response times"
    )
    cli.add_argument(
        "--model-csv",
        type=str,
        required=True,
        help="Path to model evaluation results CSV file",
    )
    cli.add_argument(
        "--human-csv",
        type=str,
        required=True,
        help="Path to human results CSV file",
    )
    cli.add_argument(
        "--output",
        type=str,
        help="Output file path (if not specified, prints to stdout)",
    )
    return cli


def main():
    """Command-line entry point.

    Reads the model and human CSV files named on the command line, builds
    the combined TikZ plot, and either writes it to ``--output`` or prints
    it to stdout.

    Raises:
        FileNotFoundError: If either CSV path does not exist.
    """
    options = _build_arg_parser().parse_args()

    # Check both inputs up front (model first) before reading anything.
    for label, csv_path in (
        ("Model", options.model_csv),
        ("Human", options.human_csv),
    ):
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"{label} CSV file not found: {csv_path}")

    model_times = aggregate_model_times(options.model_csv)
    human_times = aggregate_human_times(options.human_csv)

    if not (model_times or human_times):
        print("No timing data found in CSV files.")
        return

    tikz_code = generate_combined_tikz_plot(human_times, model_times)

    if not options.output:
        print(tikz_code)
        return

    with open(options.output, "w") as out_file:
        out_file.write(tikz_code)
    print(f"TikZ code saved to {options.output}")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
