import argparse
import ast
import os
import sys
from collections import defaultdict
from typing import Dict, List

import numpy as np
import pandas as pd


def parse_response_times(response_times_str: str) -> List[float]:
    """Parse one CSV cell of response times into a list of positive floats.

    A string cell is decoded with :func:`ast.literal_eval`, which only accepts
    Python literals and never executes code; typically it holds a list such
    as ``"[0.8, 1.2, None]"``. Values that are already decoded (a list, an
    array, or pandas' NaN for an empty cell) are used as-is, and a bare
    number is treated as a single measurement.

    Entries that are ``None``, non-numeric, zero, negative, NaN or infinite
    are dropped. Input that cannot be decoded yields an empty list instead of
    raising.

    Args:
        response_times_str: Raw cell value, normally the string form of a list.

    Returns:
        The valid response times, in their original order.
    """
    if isinstance(response_times_str, str):
        try:
            parsed = ast.literal_eval(response_times_str)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return []
    else:
        parsed = response_times_str

    # A decoded string literal would otherwise be iterated character by character.
    if isinstance(parsed, (str, bytes)):
        return []
    try:
        candidates = list(parsed)
    except TypeError:
        # Non-iterable scalar: a bare number, None, or NaN from an empty cell.
        candidates = [parsed]

    times: List[float] = []
    for item in candidates:
        if isinstance(item, (str, bytes)):
            continue
        try:
            value = float(item)
        except (TypeError, ValueError):
            continue
        # Rejects 0, negatives and +inf; NaN fails every comparison and is dropped too.
        if 0 < value < float("inf"):
            times.append(value)
    return times


def calculate_boxplot_stats(times: List[float]) -> Dict[str, float]:
    """Compute the five values consumed by pgfplots' ``boxplot prepared``.

    Quartiles use NumPy's default (linear-interpolation) percentiles. Whiskers
    follow Tukey's convention: each one extends to the most extreme *observed*
    value that lies within 1.5 x IQR of the box. A whisker therefore always
    ends on an actual data point, never on the fence itself.

    Args:
        times: Sample values (a list or 1-D array); may be empty.

    Returns:
        Dict with keys ``lower_whisker``, ``lower_quartile``, ``median``,
        ``upper_quartile`` and ``upper_whisker``, all plain floats. Every
        value is ``0.0`` when *times* is empty.
    """
    # len() rather than truthiness so NumPy arrays are accepted too.
    if len(times) == 0:
        return {
            "lower_whisker": 0.0,
            "lower_quartile": 0.0,
            "median": 0.0,
            "upper_quartile": 0.0,
            "upper_whisker": 0.0,
        }

    values = np.asarray(times, dtype=float)
    q1, median, q3 = np.percentile(values, [25, 50, 75])

    iqr = q3 - q1
    lower_fence = q1 - 1.5 * iqr
    upper_fence = q3 + 1.5 * iqr

    # Samples inside the fences. This is never empty for real data, because
    # the points around the median lie between q1 and q3; falling back to all
    # values only guards against floating-point corner cases.
    inside = values[(values >= lower_fence) & (values <= upper_fence)]
    if inside.size == 0:
        inside = values

    return {
        "lower_whisker": float(inside.min()),
        "lower_quartile": float(q1),
        "median": float(median),
        "upper_quartile": float(q3),
        "upper_whisker": float(inside.max()),
    }


def format_model_name(model_name: str) -> str:
    """Return the final ``/``-separated component of *model_name*.

    ``"org/model-7b"`` becomes ``"model-7b"``; a name without a slash is
    returned unchanged.
    """
    _, _, short_name = model_name.rpartition("/")
    return short_name


def aggregate_times_by_model(csv_path: str) -> Dict[str, List[float]]:
    """Collect every valid response time in a CSV, grouped by model.

    The CSV must provide a ``ModelName`` column and a ``K_ResponseTimes``
    column. Each ``K_ResponseTimes`` cell is decoded with
    :func:`parse_response_times`. All rows for the same model are merged into
    one list, in file order.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        Mapping from model name to its response times. A model whose cells
        held no valid times still appears, with an empty list.

    Raises:
        KeyError: If either required column is missing.
    """
    df = pd.read_csv(csv_path)

    model_times = defaultdict(list)
    # Zipping the two columns avoids DataFrame.iterrows, which builds a Series per row.
    for model_name, response_times_str in zip(df["ModelName"], df["K_ResponseTimes"]):
        # A blank model name reads as NaN. NaN cannot be sorted alongside
        # string names later, so such rows are skipped.
        if pd.isna(model_name):
            continue
        model_times[model_name].extend(parse_response_times(response_times_str))

    return dict(model_times)


# (name, RGB) pairs for the box fills. One table keeps the \definecolor
# statements and the colour names used by \addplot in sync.
_BOXPLOT_PALETTE = (
    ("myred", (230, 97, 1)),
    ("myorange", (253, 184, 99)),
    ("mybrown", (178, 171, 210)),
    ("mygreen", (120, 198, 121)),
    ("myteal", (27, 158, 119)),
    ("myblue", (66, 146, 198)),
    ("mypurple", (158, 154, 200)),
    ("mycyan", (102, 194, 165)),
)

# LaTeX special characters and the text that typesets each one literally.
_LATEX_ESCAPES = {
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
}


def _escape_latex(text: str) -> str:
    """Return *text* with LaTeX special characters escaped."""
    # Mapping one character at a time never re-escapes backslashes it inserted.
    return "".join(_LATEX_ESCAPES.get(char, char) for char in text)


def _format_coordinate(value: float) -> str:
    """Format a data value for pgfplots without rounding small values to zero."""
    # Two decimals is fine from 0.1 upwards. Below that, keep three significant
    # digits so a positive value never prints as 0.00, which a log axis cannot place.
    return f"{value:.2f}" if value >= 0.1 else f"{value:.3g}"


def generate_tikz_boxplot(model_times: Dict[str, List[float]]) -> str:
    """Render per-model response times as a pgfplots boxplot inside a tikzpicture.

    Each model with at least one sample becomes one ``boxplot prepared`` box
    on a logarithmic y axis. Boxes are ordered by model name and labelled with
    the escaped output of :func:`format_model_name`. Models without samples
    are omitted, because their all-zero statistics cannot be drawn on a log
    axis.

    The y range spans from 0.1 (or lower, if a whisker reaches below it) up
    to 10% above the tallest whisker. It defaults to 0.1-10 when there is
    nothing to plot. The box fill colours are defined inside the picture
    itself.

    Args:
        model_times: Mapping from model name to its response times (seconds).

    Returns:
        The TikZ source; lines are joined with ``"\\n"`` and there is no
        trailing newline.
    """
    # Models with no valid samples would sit at 0, which a log axis cannot show.
    sorted_models = sorted(model for model, times in model_times.items() if times)
    # Compute each model's statistics once; they feed both the axis limits and the boxes.
    stats_by_model = {
        model: calculate_boxplot_stats(model_times[model]) for model in sorted_models
    }

    if stats_by_model:
        # 10% headroom above the tallest whisker.
        max_time = max(s["upper_whisker"] for s in stats_by_model.values()) * 1.1
        # Keep the 0.1 floor unless a whisker reaches below it.
        min_time = min(
            0.1, min(s["lower_whisker"] for s in stats_by_model.values()) / 1.1
        )
    else:
        max_time = 10
        min_time = 0.1

    lines = ["\\begin{tikzpicture}"]
    # Colours are defined before the axis so every \addplot below can use them.
    for name, (red, green, blue) in _BOXPLOT_PALETTE:
        lines.append(f"\\definecolor{{{name}}}{{RGB}}{{{red},{green},{blue}}}")
    lines.append("")

    lines.append("\\begin{axis}[")
    lines.append("    boxplot/draw direction = y,")
    lines.append("    ymode=log,")
    lines.append("    width=12cm,")
    lines.append(f"    height={len(sorted_models) * 0.8 + 2:g}cm,")
    lines.append("    ylabel={Response Time [s]},")
    lines.append(f"    xtick={{1,...,{len(sorted_models)}}},")
    lines.append("    xticklabels={")
    for model in sorted_models:
        # Braces keep a comma inside a name from splitting the label list.
        lines.append(f"        {{{_escape_latex(format_model_name(model))}}},")
    lines.append("    },")
    lines.append("    xticklabel style={rotate=60, anchor=north east, font=\\small},")
    lines.append(f"    ymin={min_time:g}, ymax={max_time:g},")
    lines.append("    axis x line=bottom,")
    lines.append("    axis y line=left,")
    lines.append("    boxplot/variable width,")
    lines.append("    enlarge x limits=0.05,")
    lines.append("    enlarge y limits=0.05,")
    lines.append("    line width=1pt,")
    lines.append("    every boxplot/.append style={thick},")
    lines.append("]")
    lines.append("")

    for index, model in enumerate(sorted_models):
        stats = stats_by_model[model]
        color = _BOXPLOT_PALETTE[index % len(_BOXPLOT_PALETTE)][0]

        lines.append("\\addplot[")
        lines.append(
            "    boxplot prepared={"
            f"lower whisker={_format_coordinate(stats['lower_whisker'])}, "
            f"lower quartile={_format_coordinate(stats['lower_quartile'])}, "
            f"median={_format_coordinate(stats['median'])}, "
            f"upper quartile={_format_coordinate(stats['upper_quartile'])}, "
            f"upper whisker={_format_coordinate(stats['upper_whisker'])}"
            "},"
        )
        lines.append(f"    fill={color}, draw=black, solid")
        lines.append("] coordinates {};")
        lines.append("")

    lines.append("\\end{axis}")
    lines.append("\\end{tikzpicture}")

    return "\n".join(lines)


def main():
    """Command-line entry point: build the TikZ boxplot from a results CSV.

    Options:
        --stats-csv: Path of the evaluation-results CSV (required).
        --output: File to write the TikZ code to. When omitted, the code is
            printed to stdout.

    Raises:
        FileNotFoundError: If ``--stats-csv`` does not name an existing file.
    """
    parser = argparse.ArgumentParser(
        description="Generate TikZ boxplot for response times by model"
    )
    parser.add_argument(
        "--stats-csv",
        type=str,
        required=True,
        help="Path to evaluation results CSV file",
    )
    parser.add_argument(
        "--output",
        type=str,
        help="Output file path (if not specified, prints to stdout)",
    )

    args = parser.parse_args()

    # isfile rather than exists: a directory path would otherwise fail later inside read_csv.
    if not os.path.isfile(args.stats_csv):
        raise FileNotFoundError(f"CSV file not found: {args.stats_csv}")

    model_times = aggregate_times_by_model(args.stats_csv)

    # Models can be present with no usable samples, so check for any data at all.
    if not any(model_times.values()):
        # stderr keeps the message out of stdout when the TikZ is redirected into a file.
        print("No timing data found in CSV file.", file=sys.stderr)
        return

    tikz_code = generate_tikz_boxplot(model_times)

    if args.output:
        # Explicit UTF-8: model names may be non-ASCII and the platform default varies.
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(tikz_code + "\n")
        print(f"TikZ code saved to {args.output}")
    else:
        print(tikz_code)


if __name__ == "__main__":
    main()
