import argparse
import csv
import json
import math
import os
from collections import defaultdict

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np


# Global Matplotlib defaults, applied once when this module is imported.
# NOTE(review): main() calls plt.style.use(...) after creating its figure,
# which may override some of these values for artists created afterwards.
mpl.rcParams.update(
    {
        # Resolution and font sizes (points).
        "figure.dpi": 300,
        "font.size": 11.2,
        "axes.titlesize": 14,
        "axes.labelsize": 11.2,
        "legend.fontsize": 12.5,
        "xtick.labelsize": 10.2,
        "ytick.labelsize": 10.2,
        # Line widths for the axes frame and the grid.
        "axes.linewidth": 2.9,
        "grid.linewidth": 0.7,
        # Serif text font is currently disabled; only math text uses "cm".
        # "font.family": "serif",
        "mathtext.fontset": "cm",
        # Background colors; main() overrides both with --bg-color.
        "figure.facecolor": "white",
        "axes.facecolor": "#fcfcfd",
        # Use a tight bounding box when saving figures.
        "savefig.bbox": "tight",
    }
)


def infer_model_count(model_value: str) -> int:
    """Infer the number of models encoded in a ``model`` field value.

    The value is read as a hyphen-separated list of model identifiers and the
    result is the number of non-blank identifiers in that list.

    Args:
        model_value: Raw cell value. Non-string values are converted with
            ``str()``; ``None`` is treated as an empty value instead of being
            stringified to ``"None"``.

    Returns:
        The number of non-blank hyphen-separated tokens. Returns 0 when the
        value is ``None``, empty, or made only of whitespace and hyphens.

    Examples:
        >>> infer_model_count("1")
        1
        >>> infer_model_count("1-2")
        2
        >>> infer_model_count("5-6-7-8-9")
        5
        >>> infer_model_count("  3  ")
        1
        >>> infer_model_count("")
        0
    """
    if model_value is None:
        return 0

    text = str(model_value).strip()
    # Splitting an empty string yields [""], which the blank-token filter
    # discards, so empty input naturally counts as 0 without a special case.
    return sum(1 for token in text.split("-") if token.strip())
    


def extract_overall_ce_losses(csv_path: str, problem_key: str = "overall") -> dict:
    """Collect CE-loss values for one problem, grouped by number of models.

    Rows are kept when their ``problem`` cell equals ``problem_key``
    (case-insensitive, surrounding whitespace ignored). Each kept row is
    assigned to a group by the model count inferred from its ``model`` cell
    via ``infer_model_count``; rows whose count falls outside 1..9 are
    skipped, as are rows whose ``CE Loss`` cell is empty, not a number, or
    not finite (``nan`` / ``inf``).

    Args:
        csv_path: Path to the input CSV file. It is decoded as UTF-8; a
            leading byte-order mark is tolerated.
        problem_key: Problem name to select.

    Returns:
        A dict with integer keys 1..9 (always all present), each mapping to
        the list of CE-loss floats, rounded to 6 decimals, in file order.

    Raises:
        KeyError: If the header lacks a ``model``, ``problem`` or ``CE Loss``
            column (matched case-insensitively).
    """
    # Pre-populate every group so callers can rely on all keys existing.
    result: dict[int, list[float]] = {k: [] for k in range(1, 10)}
    wanted_problem = problem_key.strip().lower()

    # "utf-8-sig" reads plain UTF-8 and also drops a leading BOM, which would
    # otherwise be glued to the first column name and break header matching.
    with open(csv_path, "r", newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)

        # Map normalized header -> header as written, for flexible matching.
        field_map = {name.strip().lower(): name for name in reader.fieldnames or []}

        model_col = field_map.get("model")
        problem_col = field_map.get("problem")
        ce_col = field_map.get("ce loss")

        required = {"model": model_col, "problem": problem_col, "CE Loss": ce_col}
        missing = [label for label, found in required.items() if found is None]
        if missing:
            raise KeyError(
                f"Missing expected column(s) in CSV: {', '.join(missing)}. "
                f"Found columns: {reader.fieldnames}"
            )

        for row in reader:
            # DictReader fills cells missing from short rows with None;
            # ``or ""`` keeps those from turning into the text "None".
            problem_text = str(row.get(problem_col) or "").strip().lower()
            if problem_text != wanted_problem:
                continue

            model_count = infer_model_count(row.get(model_col) or "")
            if not 1 <= model_count <= 9:
                continue

            ce_text = str(row.get(ce_col) or "").strip()
            try:
                # float("") raises ValueError too, so blank cells are skipped here.
                ce_value = float(ce_text)
            except ValueError:
                continue
            # float() accepts "nan"/"inf"; such values would corrupt the
            # statistics and cannot be written as standard JSON numbers.
            if not math.isfinite(ce_value):
                continue

            # Keep 6 decimal places.
            result[model_count].append(round(ce_value, 6))

    return result


def main() -> None:
    """Command-line entry point: extract CE losses, save JSON, report, plot.

    Steps:
      1. Parse the command-line arguments.
      2. Call ``extract_overall_ce_losses`` and write the result as JSON.
      3. Print the sample count per group, then the per-group mean,
         variance and standard deviation for the non-empty groups.
      4. Draw a scatter of all samples plus a line through the group means,
         save it as an image, and optionally show it.

    Both output paths default to the directory that contains the input CSV.
    Missing output directories are created. If no group has any sample the
    function returns after the JSON/count output without plotting.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Extract Overall CE Loss from CSV grouped by number of models (1..9), and save as JSON."
        )
    )
    parser.add_argument(
        "--csv",
        dest="csv_path",
        type=str,
        default="data/results_llama_8B.csv",
        help="Path to input CSV file.",
    )
    parser.add_argument(
        "--output",
        dest="output_path",
        type=str,
        default=None,
        help="Path to output JSON file. Defaults to <csv_dir>/overall_ce_by_model_count.json",
    )
    parser.add_argument(
        "--problem",
        dest="problem_key",
        type=str,
        default="overall",
        help="Problem name to match (case-insensitive). Default: overall",
    )
    parser.add_argument(
        "--plot-output",
        dest="plot_output",
        type=str,
        default=None,
        help=(
            "Path to save the plot image. Defaults to <csv_dir>/example-<name>.png, "
            "where <name> is derived from the CSV file name."
        ),
    )
    parser.add_argument(
        "--show",
        dest="show_plot",
        action="store_true",
        help="Show the plot window after saving (off by default).",
    )
    parser.add_argument(
        "--bg-color",
        dest="bg_color",
        type=str,
        default="white",
        help="Background color for figure and axes (e.g., '#ffffff' or 'white').",
    )

    args = parser.parse_args()

    data = extract_overall_ce_losses(args.csv_path, problem_key=args.problem_key)

    # Both default outputs live next to the input CSV.
    csv_dir = os.path.dirname(os.path.abspath(args.csv_path))

    # ----------------------------
    # JSON output
    # ----------------------------
    out_path = args.output_path or os.path.join(csv_dir, "overall_ce_by_model_count.json")
    os.makedirs(os.path.dirname(os.path.abspath(out_path)), exist_ok=True)

    # The integer group keys are written as strings, as JSON object keys must be.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    print(f"Saved JSON to: {out_path}")

    # Print lengths of value lists for each key (1..9)
    for key in range(1, 10):
        values = data.get(key, [])
        print(f"models={key}\tcount={len(values)}")

    # ----------------------------
    # Plotting and statistics
    # ----------------------------
    # Only groups with at least one sample are reported and plotted.
    x_values = sorted(k for k, v in data.items() if v)
    if not x_values:
        print("No data available to plot.")
        return

    # Per-group mean, rounded to 4 decimals (these rounded values are plotted).
    y_avg = [round(float(np.mean(data[x])), 4) for x in x_values]
    print("各X值的平均Y值:", y_avg)

    # Per-group population variance / standard deviation (NumPy default ddof=0).
    print("\n各X值的Y值方差和标准差:")
    for x in x_values:
        variance = float(np.var(data[x]))
        std_dev = float(np.std(data[x]))
        print(f"X={x}: 方差 = {variance:.6f}, 标准差 = {std_dev:.6f}")

    # Overall Y range across all samples; non-empty because x_values is non-empty.
    all_y_values = [y for y_list in data.values() for y in y_list]
    y_min, y_max = min(all_y_values), max(all_y_values)

    # Figure and style
    fig, ax = plt.subplots(figsize=(8, 6))
    # NOTE(review): the style sheet is applied after the figure and axes
    # exist, so it presumably only affects artists created below. The name
    # 'seaborn-v0_8' looks like the one introduced in Matplotlib 3.6 —
    # confirm the minimum supported Matplotlib version.
    plt.style.use('seaborn-v0_8')
    fig.patch.set_alpha(1.0)
    ax.patch.set_alpha(1.0)
    fig.patch.set_facecolor(args.bg_color)
    ax.set_facecolor(args.bg_color)

    # Scatter points for all samples, one column of points per group.
    # Only the first call carries a label so the legend shows a single entry.
    for index, x in enumerate(x_values):
        y_values = data[x]
        plt.scatter(
            [x] * len(y_values),
            y_values,
            marker='o',
            color='#4e8ca2',
            s=64,
            alpha=0.45,
            edgecolors='white',
            linewidths=0.5,
            label='Samples' if index == 0 else '',
        )

    # Average line, drawn above the sample points (zorder=5).
    plt.plot(
        x_values,
        y_avg,
        marker='o',
        linestyle='-',
        color='#004c6d',
        markersize=14,
        linewidth=3.5,
        markeredgecolor='white',
        markeredgewidth=1.8,
        alpha=0.9,
        label='8B-TA',
        zorder=5,
    )

    # Labels
    plt.xlabel('Expert Number', fontsize=16)
    plt.ylabel('Overall Loss', fontsize=16)

    # Axis ranges: half a unit of padding on X, a small fixed margin on Y.
    plt.xlim(min(x_values) - 0.5, max(x_values) + 0.5)
    plt.xticks(x_values, fontsize=16)
    plt.ylim(y_min - 0.01, y_max + 0.01)

    # Legend and layout
    plt.legend(loc='upper right', frameon=True, fancybox=True, shadow=True, fontsize=13.5)
    plt.tight_layout()

    # Save figure
    plot_out = args.plot_output
    if not plot_out:
        # Derive a short tag from the CSV file name, e.g.
        # "results_llama_8B.csv" -> "llama-8B". splitext removes only the
        # real extension, whatever it is.
        stem = os.path.splitext(os.path.basename(args.csv_path))[0]
        tag = stem.replace('results_', '').replace('_', '-')
        plot_out = os.path.join(csv_dir, f'example-{tag}.png')
    os.makedirs(os.path.dirname(os.path.abspath(plot_out)), exist_ok=True)
    plt.savefig(plot_out, dpi=300, bbox_inches='tight', transparent=False)
    print(f"Saved plot to: {plot_out}")

    # Optional show; otherwise release the figure.
    if args.show_plot:
        plt.show()
    else:
        plt.close(fig)


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


