#!/usr/bin/env python

import os
import sys
import json
import argparse
import numpy as np
import datetime

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import (
    write_dicts_to_csv,
    save_runtime_histogram,
    save_runtime_density_plot,
    save_runtime_boxplot,
    save_ratio_histogram,
    save_runtime_scatter_plot,
    save_runtime_multi_baselines_scatter_plot,
)

# Root of the per-kind result directories (one sub-directory of JSON result
# files per execution kind); resolved from this script's own location.
KINDS_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, "datasets", "gqa")
)

def load_runtimes(kind, base_dir=None):
    """Load per-program runtimes (in milliseconds) for one execution kind.

    Reads every ``*.json`` file in ``<base_dir>/<kind>`` and maps the file
    stem (the program id) to the file's ``time_ns`` value divided by 1e6.
    Files are visited in sorted order, so the returned dict, and any tuple
    built from its values, is deterministic across runs and platforms.

    Files that cannot be opened, are not valid JSON, have no ``time_ns`` key
    or hold a non-numeric ``time_ns`` are reported and skipped. One bad
    result file does not abort the whole comparison.

    Args:
        kind: Sub-path of the execution kind below *base_dir*.
        base_dir: Directory containing the kind directories; defaults to
            ``KINDS_DIR``.

    Returns:
        ``{program_id: runtime_ms}``. The dict is empty when the kind
        directory does not exist.
    """
    root = KINDS_DIR if base_dir is None else base_dir
    eval_dir = os.path.join(root, kind)

    runtimes = {}
    if not os.path.isdir(eval_dir):
        print(f"Directory not found: {eval_dir}")
        return runtimes

    # os.listdir order is arbitrary; sort it for reproducible output.
    for filename in sorted(os.listdir(eval_dir)):
        if not filename.endswith(".json"):
            continue
        filepath = os.path.join(eval_dir, filename)
        prog_id = os.path.splitext(filename)[0]
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                result = json.load(f)
            runtimes[prog_id] = result["time_ns"] / 1e6  # convert to milliseconds
        except (OSError, ValueError, KeyError, TypeError, OverflowError) as e:
            # OSError: unreadable file/dir; ValueError: bad JSON or encoding;
            # KeyError: no "time_ns"; TypeError/OverflowError: unusable value.
            print(f"Failed to load {filename}: {e}")
    return runtimes

# A ratio (kind runtime / baseline runtime) strictly below this value counts
# as an improvement, i.e. the kind ran in under 97.5% of the baseline's time.
_IMPROVEMENT_RATIO_THRESHOLD = 0.975

# Axis labels for the kind-vs-first-baseline scatter plots.
_SCATTER_LABELS = ("QUASAR", "Python")

# Legend labels for the multi-baseline scatter plot, in command-line order.
# Baselines beyond this list are labelled with their own kind name.
_DEFAULT_BASELINE_LABELS = ("Python actions", "EPIC direct actions")

_RATIO_AXIS_LABEL = "Execution Time Ratio (QUASAR / Python)"
_RATIO_COLOR = "#4682B4"

# (file name, extra keyword arguments) for each kind-vs-first-baseline scatter plot.
_FIRST_BASELINE_SCATTER_PLOTS = (
    ("improved_runtime_scatter.pdf", {"improved_only": True, "truncate": 20}),
    ("improved_runtime_scatter_log.pdf", {"improved_only": True, "truncate": 100, "log_scale": True}),
    ("all_runtime_scatter.pdf", {"truncate": 20}),
    ("all_runtime_scatter_log.pdf", {"truncate": 100, "log_scale": True}),
)


def _summary(values, digits):
    """Return rounded median/mean/std/min/max of *values*, keyed by statistic name.

    Every statistic is ``None`` for empty input: ``np.min``/``np.max`` raise
    on zero-size arrays and ``np.median``/``np.mean`` would yield NaN.
    """
    names = ("median", "mean", "std", "min", "max")
    if len(values) == 0:
        return dict.fromkeys(names)
    reducers = (np.median, np.mean, np.std, np.min, np.max)
    return {name: round(fn(values), digits) for name, fn in zip(names, reducers)}


def _stats_row(execution, count, impr_count, runtime_values, ratios=(), improvements=()):
    """Build one row of runtime_stats.csv/.json.

    Runtimes (ms) are rounded to 2 decimals and ratios to 3. Ratio columns
    are ``None`` when their input sequence is empty. The key order defines
    the CSV column order.
    """
    row = {"execution": execution, "count": count, "impr_count": impr_count}
    for suffix, values, digits in (
        ("ms", runtime_values, 2),
        ("ratio", ratios, 3),
        ("impr_ratio", improvements, 3),
    ):
        for name, value in _summary(values, digits).items():
            row[f"{name}_{suffix}"] = value
    return row


def _runtime_ratios(kind_runtime, baseline_runtimes, threshold=_IMPROVEMENT_RATIO_THRESHOLD):
    """Return ``(ratios, improvements)``: kind/baseline runtime per shared program.

    Only programs present in both dicts are compared. Programs whose baseline
    runtime is zero are skipped to avoid division by zero. *improvements* is
    the subset of *ratios* strictly below *threshold*. Both lists follow the
    iteration order of *kind_runtime*.
    """
    ratios = [
        runtime / baseline_runtimes[prog_id]
        for prog_id, runtime in kind_runtime.items()
        if prog_id in baseline_runtimes and baseline_runtimes[prog_id] != 0
    ]
    improvements = [ratio for ratio in ratios if ratio < threshold]
    return ratios, improvements


def _output_dir_for(kind, baselines):
    """Return the stats/plots output directory for *kind* compared to *baselines*.

    The kind's leading path component is dropped (``"a/b/c"`` -> ``"b/c"``).
    A kind without ``/`` is used as-is. The last component encodes the
    baseline count, or the single baseline's name.
    """
    count = len(baselines)
    if count == 1:
        dirname = f"baseline_{baselines[0].replace('/', '_')}"
    elif count == 2:
        dirname = "two_baselines"
    elif count == 3:
        dirname = "three_baselines"
    else:
        # Distinct name so larger baseline sets don't overwrite the
        # single-baseline results of their first baseline.
        dirname = f"{count}_baselines"
    _, sep, kind_subdir = kind.partition("/")
    if not sep:
        kind_subdir = kind
    return f"{KINDS_DIR}/plots/runtime/{kind_subdir}/{dirname}"


def main():
    """Compare one execution kind's runtimes against one or more baselines.

    Loads per-program runtimes for ``--kind`` and every ``--baselines`` entry.
    Writes summary statistics to ``runtime_stats.csv`` and
    ``runtime_stats.json``, and renders histogram, density, box, scatter and
    ratio plots into a directory under ``KINDS_DIR/plots/runtime``.

    Exits with an error message when ``--kind`` has no runtimes, since no
    statistic or plot can be computed without them.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-k", "--kind",
        required=True,
        help="Execution kind directory name"
    )
    parser.add_argument(
        "-b", "--baselines",
        nargs="+",
        required=True,
        help="Execution kinds as reference point"
    )
    args = parser.parse_args()

    kind_runtime = load_runtimes(args.kind)
    if not kind_runtime:
        sys.exit(f"[{args.kind}] No runtimes found.")
    kind_runtime_values = tuple(kind_runtime.values())

    all_runtimes = {args.kind: kind_runtime_values}
    all_ratios = {}
    impr_ratios = {}
    # Each baseline is read from disk once and reused by every plot below.
    baseline_runtimes_by_name = {}
    stats = [_stats_row(args.kind, len(kind_runtime), 0, kind_runtime_values)]

    for baseline in args.baselines:
        baseline_runtimes = load_runtimes(baseline)
        baseline_runtimes_by_name[baseline] = baseline_runtimes
        if not baseline_runtimes:
            print(f"[{baseline}] No runtimes found.")
            continue

        baseline_runtime_values = tuple(baseline_runtimes.values())
        all_runtimes[baseline] = baseline_runtime_values
        ratios, improvements = _runtime_ratios(kind_runtime, baseline_runtimes)
        if ratios:
            all_ratios[baseline] = ratios
            impr_ratios[baseline] = improvements
        else:
            print(f"[{baseline}] No programs in common with {args.kind}; ratios skipped.")

        baseline_stats = _stats_row(
            baseline,
            len(baseline_runtimes),
            len(improvements),
            baseline_runtime_values,
            ratios,
            improvements,
        )
        stats.append(baseline_stats)
        print(
            f"[{baseline}] Count: {baseline_stats['count']}, "
            f"Mean: {baseline_stats['mean_ms']:.2f} ms, "
            f"Std: {baseline_stats['std_ms']:.2f} ms"
        )

    output_dir = _output_dir_for(args.kind, args.baselines)
    os.makedirs(output_dir, exist_ok=True)

    # Save statistics as CSV and JSON.
    write_dicts_to_csv(f"{output_dir}/runtime_stats.csv", stats)
    with open(f"{output_dir}/runtime_stats.json", "w") as f_json:
        json.dump(stats, f_json, indent=2)

    # Runtime distributions of the kind and every baseline that had data.
    truncation = None
    save_runtime_histogram(
        all_runtimes, f"{output_dir}/runtime_histogram.pdf",
        truncate=truncation, title="Execution Time Comparison Histogram",
    )
    save_runtime_density_plot(
        all_runtimes, f"{output_dir}/runtime_density.pdf",
        truncate=truncation, title="Execution Time Comparison Density",
    )
    save_runtime_boxplot(
        all_runtimes, f"{output_dir}/runtime_boxplot.pdf",
        truncate=truncation, title="Execution Time Comparison Boxplot",
    )

    first_baseline = args.baselines[0]
    if first_baseline in all_runtimes:
        for filename, options in _FIRST_BASELINE_SCATTER_PLOTS:
            save_runtime_scatter_plot(
                kind_runtime,
                # Fresh copy per call, as the former per-call reload gave,
                # in case a plotting helper mutates its input.
                dict(baseline_runtimes_by_name[first_baseline]),
                save_path=f"{output_dir}/{filename}",
                labels=_SCATTER_LABELS,
                **options,
            )

    if len(args.baselines) > 1:
        # One label per baseline; fall back to the kind name past the defaults.
        baseline_labels = [
            _DEFAULT_BASELINE_LABELS[i] if i < len(_DEFAULT_BASELINE_LABELS) else baseline
            for i, baseline in enumerate(args.baselines)
        ]
        save_runtime_multi_baselines_scatter_plot(
            kind_runtime,
            [dict(baseline_runtimes_by_name[b]) for b in args.baselines],
            f"{output_dir}/runtime_scatter_all.pdf",
            kind_label="QUASAR actions",
            baseline_labels=baseline_labels,
            truncate=20,
        )

    # all_ratios and impr_ratios always share keys: a baseline is added to
    # both only when it has at least one program in common with the kind.
    if all_ratios:
        save_ratio_histogram(
            impr_ratios, f"{output_dir}/improved_runtime_ratio_histogram.pdf",
            label=_RATIO_AXIS_LABEL, color=_RATIO_COLOR,
        )
        save_ratio_histogram(
            all_ratios, f"{output_dir}/all_runtime_ratio_histogram.pdf",
            label=_RATIO_AXIS_LABEL, color=_RATIO_COLOR,
        )

    print(f"Saved runtime comparison to: {output_dir}")

if __name__ == "__main__":
    # Run the comparison only when executed as a script, not when imported.
    main()
