#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple
import argparse
import json
import glob
import os
from collections import defaultdict, Counter

import numpy as np

from datasets import load_dataset

# Dataset identifier used as the default argument of
# `build_initial_board_to_group_map`, which passes it to `load_dataset`.
# NOTE(review): `YOUR_PATH` is a bare name that is not defined or imported
# anywhere in this file, so this line raises NameError at import time as
# written. Replace it with the real dataset name/path before running.
HF_DATASET_NAME = YOUR_PATH

# Numeric id reported for each variant name returned by
# `_infer_variant_name_from_path`. `get_results` falls back to 999 for any
# name missing from this table.
VARIANT_ID = {
    "new_abc_variant": 0,
    "variant": 1,
    "default": 2,
}


def parse_args():
    """Parse the command-line options of this script.

    Options:
        --results_dir  (required) directory that is scanned for result JSONs.
        --prefix       optional substring a result file name must contain.
        --hotfix       boolean flag; it is forwarded to `get_results`, which
                       does not currently read it.

    Returns:
        The `argparse.Namespace` built from `sys.argv`.
    """
    cli = argparse.ArgumentParser()
    option_table = (
        ("--results_dir", {"type": str, "required": True}),
        ("--prefix", {"type": str, "default": None}),
        ("--hotfix", {"action": "store_true"}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def load_json(file_path: str) -> Any:
    """Read and decode the JSON document stored at *file_path*.

    The file is always decoded as UTF-8 (the encoding JSON interchange
    uses) instead of the platform's locale encoding, so non-ASCII content
    is read the same way on every machine.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON value (dict, list, str, number, bool or None).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def save_json(data: Dict[str, Any], file_path: str):
    """Write *data* to *file_path* as JSON indented by four spaces.

    The payload is serialized before the file is opened: opening in "w"
    mode truncates immediately, so serializing first means a value that
    cannot be encoded does not wipe out an existing file. The file is
    written as UTF-8, matching how `load_json` reads it.

    Args:
        data: JSON-serializable mapping to store.
        file_path: Destination path; an existing file is overwritten.

    Raises:
        TypeError: If *data* contains a value that is not JSON-serializable.
        OSError: If the file cannot be opened for writing.
    """
    serialized = json.dumps(data, indent=4)
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(serialized)


def parsing_experiment_info(result_path: str) -> str:
    """Build an experiment label "<model>_<task>_<suffix>" from a result path.

    The path is expected to look like ".../<model>/<task>/<file>". Both
    slash and backslash are accepted as separators. All three parts are
    taken from the same separator-normalized path, so the label does not
    depend on which operating system the script runs on.

    Args:
        result_path: Path of a result file.

    Returns:
        The label. The model and task parts fall back to "unknown_model"
        and "unknown_task" when the path has too few components; the suffix
        is the file name up to (not including) its first ".".
    """
    parts = result_path.replace("\\", "/").split("/")
    model_name = parts[-3] if len(parts) >= 3 else "unknown_model"
    task_name = parts[-2] if len(parts) >= 2 else "unknown_task"
    # Use the normalized last component rather than os.path.basename on the
    # raw path: on POSIX, basename does not split on backslashes.
    suffix = parts[-1].split(".")[0]
    return f"{model_name}_{task_name}_{suffix}"


def _normalize_board(board: Any) -> str:
    if board is None:
        return ""
    if isinstance(board, str):
        return board.strip()
    if isinstance(board, (list, tuple)):
        return json.dumps(board, ensure_ascii=False, separators=(",", ":"))
    if isinstance(board, dict):
        return json.dumps(board, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
    return str(board).strip()


def build_initial_board_to_group_map(dataset_name: str = HF_DATASET_NAME):
    """Index the dataset by normalized initial board.

    Every row of every split returned by `load_dataset(dataset_name)` is
    visited. A row's group is its "group" value as a string, or "default"
    when that value is missing or falsy.

    Returns:
        A `(board_to_group, group_order)` tuple. `board_to_group` maps each
        non-empty normalized "initial_board" to the group of the first row
        that carried it. `group_order` lists the distinct groups in the
        order they were first seen.
    """
    splits = load_dataset(dataset_name)

    group_by_board = {}
    # A dict keeps insertion order, so its keys act as an ordered set.
    groups_seen = {}

    for split in splits.values():
        for record in split:
            board_key = _normalize_board(record.get("initial_board"))
            group = str(record.get("group", "default") or "default")
            groups_seen.setdefault(group, None)
            if board_key:
                # The first occurrence of a board wins.
                group_by_board.setdefault(board_key, group)

    return group_by_board, list(groups_seen)


def _infer_variant_name_from_path(path: str) -> str:
    p = path.replace("\\", "/")
    if "/new_abc_variant/" in p:
        return "new_abc_variant"
    if "/variant/" in p:
        return "variant"
    return "default"


def _collect_json_files(results_dir: str, prefix: Optional[str]):
    prefix_filter = f"*{prefix}*.json" if prefix else "*.json"
    files = []
    files.extend(glob.glob(os.path.join(results_dir, prefix_filter)))
    for sub in ["new_abc_variant", "variant"]:
        subdir = os.path.join(results_dir, sub)
        if os.path.isdir(subdir):
            files.extend(glob.glob(os.path.join(subdir, "**", prefix_filter), recursive=True))
    return sorted(set(files))


@dataclass
class Result:
    """Summary of one evaluated task instance, built by `get_instance_summary`."""

    # Task identifier read from the record's "task" entry ("default" if absent).
    task_name: str
    # Group looked up from the instance's normalized initial board
    # ("default" when the board is not in the lookup table).
    task_subtype: str
    # Numeric variant id supplied by the caller (see VARIANT_ID).
    variant_id: int
    # True when the termination reason contains "COMPLETE".
    is_success: bool
    # Termination reason taken from the record's "last_observation".
    termination_reason: str
    # Number of entries in the record's "steps" list.
    step_num: int
    # Progress value from "last_observation", coerced to float (0.0 on failure).
    progress: float


def get_instance_summary(result, board_to_group, variant_id):
    """Condense one raw result record into a `Result`.

    Args:
        result: Decoded JSON record of a single run. The keys read are
            "task" (with "initial_board" and "task_name"),
            "last_observation" (with "termination_reason" and "progress")
            and "steps". Missing or null entries are tolerated.
        board_to_group: Mapping from normalized initial board to group name.
        variant_id: Numeric variant id to store on the summary.

    Returns:
        A `Result`. The subtype is "default" when the board is not in
        *board_to_group*. The run counts as successful when its termination
        reason contains "COMPLETE". Progress falls back to 0.0 when it is
        missing or cannot be converted to float.
    """
    task_info = result.get("task") or {}
    last_observation = result.get("last_observation") or {}

    board = _normalize_board(task_info.get("initial_board"))
    task_subtype = board_to_group.get(board, "default")

    # A null reason in the JSON would otherwise reach the substring test
    # below as None and raise TypeError, so treat it as "no reason".
    termination_reason = last_observation.get("termination_reason") or ""
    is_success = "COMPLETE" in termination_reason

    step_num = len(result.get("steps") or [])

    try:
        progress = float(last_observation.get("progress") or 0.0)
    except (TypeError, ValueError, OverflowError):
        # Only the failures float() can produce for decoded JSON values.
        progress = 0.0

    return Result(
        task_name=task_info.get("task_name", "default"),
        task_subtype=task_subtype,
        variant_id=variant_id,
        is_success=is_success,
        termination_reason=termination_reason,
        step_num=step_num,
        progress=progress,
    )


def get_experiment_summary(experiment_results, subtype_order):
    """Aggregate per-instance results into one statistics dict per subtype.

    Args:
        experiment_results: Mapping from experiment name to that
            experiment's list of results. The number of entries is used as
            the trial count.
        subtype_order: Subtypes to report. Results whose subtype is not
            listed here do not appear in the output.

    Returns:
        A dict keyed by subtype. Each value holds, averaged over the
        subtype's tasks (0.0 when there is nothing to average):
            "pass_rate": successes of a task divided by the trial count.
            "pass_at_k": whether any result of a task succeeded.
            "progress_mean": summed progress of a task divided by the
                trial count.
            "eff_step_num": mean step count over results whose termination
                reason contains "COMPLETE" or "FILLED".
    """
    trial_count = len(experiment_results)

    def _per_trial(total):
        return total / trial_count if trial_count else 0.0

    def _mean_or_zero(values):
        return float(np.mean(values)) if values else 0.0

    # subtype -> task name -> results gathered from every experiment
    grouped = defaultdict(lambda: defaultdict(list))
    for trial_results in experiment_results.values():
        for item in trial_results:
            grouped[item.task_subtype][item.task_name].append(item)

    summary = {}
    for subtype in subtype_order:
        runs_per_task = list(grouped.get(subtype, {}).values())

        success_rates = [
            _per_trial(sum(r.is_success for r in runs)) for runs in runs_per_task
        ]
        solved_flags = [any(r.is_success for r in runs) for runs in runs_per_task]
        progress_rates = [
            _per_trial(sum(r.progress for r in runs)) for runs in runs_per_task
        ]
        finished_steps = [
            r.step_num
            for runs in runs_per_task
            for r in runs
            if any(tag in r.termination_reason for tag in ("COMPLETE", "FILLED"))
        ]

        summary[subtype] = {
            "pass_rate": _mean_or_zero(success_rates),
            "pass_at_k": _mean_or_zero(solved_flags),
            "progress_mean": _mean_or_zero(progress_rates),
            "eff_step_num": _mean_or_zero(finished_steps),
        }

    return summary


def print_results_table_by_variant(subtype_order, subtype_stats_by_variant):
    """Print the per-variant statistics as tab-separated rows on stdout.

    A header line is printed first. Four sections follow in this order:
    pass@k, pass rate and mean progress (each shown as a percentage), then
    mean effective steps (shown unscaled). Every section has one row per
    variant id in ascending order; a row lists the values for
    *subtype_order* followed by the variant id, with two decimals.

    Rows carry no metric label even though the header starts with a
    "metric" column, so the section is identified only by its position.
    """
    columns = "\t".join(subtype_order)
    print(f"metric\t{columns}\tvariant")

    variant_ids = sorted(subtype_stats_by_variant.keys())
    # (stat key, show as percentage) in output order
    sections = (
        ("pass_at_k", True),
        ("pass_rate", True),
        ("progress_mean", True),
        ("eff_step_num", False),
    )

    for stat_key, as_percent in sections:
        for vid in variant_ids:
            stats = subtype_stats_by_variant[vid]
            if as_percent:
                values = [stats[s][stat_key] * 100 for s in subtype_order]
            else:
                values = [stats[s][stat_key] for s in subtype_order]
            cells = "\t".join(f"{value:.2f}" for value in values)
            print(f"{cells}\t{vid}")


def get_results(results_dir, prefix, hotfix):
    """Load every result file, summarize it per variant and print the table.

    Args:
        results_dir: Directory scanned by `_collect_json_files`.
        prefix: Optional file-name filter passed to `_collect_json_files`.
        hotfix: Accepted for the command-line interface; not read here.

    Each file is assigned a variant from its path (id 999 when the name is
    not in VARIANT_ID) and stored under the label produced by
    `parsing_experiment_info`. One statistics dict per variant is then
    computed and printed.
    """
    board_to_group, subtype_order = build_initial_board_to_group_map()

    runs_by_variant = defaultdict(dict)
    for path in _collect_json_files(results_dir, prefix):
        variant_name = _infer_variant_name_from_path(path)
        variant_id = VARIANT_ID.get(variant_name, 999)

        summaries = []
        for record in load_json(path):
            summaries.append(get_instance_summary(record, board_to_group, variant_id))

        # Files that produce the same label within a variant overwrite each other.
        runs_by_variant[variant_id][parsing_experiment_info(path)] = summaries

    stats_by_variant = {
        variant_id: get_experiment_summary(runs, list(subtype_order))
        for variant_id, runs in runs_by_variant.items()
    }

    print_results_table_by_variant(subtype_order, stats_by_variant)


if __name__ == "__main__":
    # Script entry point: parse the command line and print the results table.
    cli_options = parse_args()
    get_results(
        results_dir=cli_options.results_dir,
        prefix=cli_options.prefix,
        hotfix=cli_options.hotfix,
    )
