import ast
import json
import math
import re
from collections import defaultdict
from pathlib import Path

# === CONFIG ===
# Directory scanned for run logs; a relative path, so it is resolved against
# the current working directory at run time.
log_dir = Path("..") / "logs"
# <protocol>_<N>n_<R>r[_<method>]_v<version>.log
# Groups: protocol, node count, round count, optional method, version.
filename_pattern = re.compile(r"^(ac|pb)_(\d+)n_(\d+)r(?:_(mcts|dfs))?_v(\d+)\.log$")
# Display order of search methods within one (protocol, nodes, rounds) setting.
method_order = dict(mcts=0, dfs=1, guided=2)
# Human-readable label printed / written for each method key.
method_label_map = dict(mcts="MCTS", dfs="DFS", guided="GGMS")


# === HELPERS ===
# Per-iteration progress line, e.g. "Iteration: 3, failed_cases: 12".
_ITERATION_RE = re.compile(r"Iteration:\s*\d+,\s*(failed_cases|remaining):\s*(\d+)")
# Keys of the timing-summary dict that are summed into the run's total time.
_TIMED_PHASES = ("mcts", "train", "validate")


def extract_log_info(log_path):
    """Parse one run log and return its total time and failed-case trace.

    A log counts as finished only when its last two non-blank lines are a
    dict literal of per-phase timings followed by a ``Total time:`` line.

    Args:
        log_path: Path (or str) of the ``.log`` file to read.

    Returns:
        ``{"total": total, "failed_cases": counts}`` where ``total`` is the sum
        of the ``mcts``/``train``/``validate`` entries of the summary dict
        (missing keys count as 0.0) and ``counts`` lists, in file order, the
        integer from every ``Iteration: N, failed_cases|remaining: K`` line.
        ``None`` if the file cannot be read, is unfinished, or is malformed.
    """
    try:
        # errors="replace": only ASCII structure is parsed, so a stray
        # undecodable byte should not discard an otherwise finished run.
        with open(log_path, "r", encoding="utf-8", errors="replace") as f:
            lines = [line.strip() for line in f]
    except OSError:
        return None

    # An extra newline at EOF must not make a finished run look unfinished.
    while lines and not lines[-1]:
        lines.pop()
    if len(lines) < 2:
        return None

    summary_line, total_line = lines[-2], lines[-1]
    if not summary_line.startswith("{") or not total_line.startswith("Total time:"):
        return None

    try:
        # literal_eval only accepts Python literals, so log content can no
        # longer execute arbitrary code the way eval() allowed.
        summary = ast.literal_eval(summary_line)
        if not isinstance(summary, dict):
            return None
        total = sum(summary.get(k, 0.0) for k in _TIMED_PHASES)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return None

    failed_counts = [int(m.group(2)) for m in map(_ITERATION_RE.match, lines) if m]
    return {"total": total, "failed_cases": failed_counts}


def log_stats(values):
    """Summarise the strictly positive entries of *values* on a log10 scale.

    Non-positive entries are skipped because log10 is undefined for them.

    Returns:
        ``{"avg": ..., "min": ..., "max": ...}`` of the base-10 logarithms,
        each rounded to 4 decimals, or ``None`` when no entry is positive.
    """
    exponents = list(map(math.log10, filter(lambda v: v > 0, values)))
    if len(exponents) == 0:
        return None

    mean = sum(exponents) / len(exponents)
    return {
        "avg": round(mean, 4),
        "min": round(min(exponents), 4),
        "max": round(max(exponents), 4),
    }


def generate_json_summary(grouped_by_combo, output_path="summary_log_times.json"):
    """Write per-setting success rates and log10 run-time statistics to JSON.

    Args:
        grouped_by_combo: Mapping from ``(protocol, nodes, rounds, method)`` to
            the list of per-run results for that combination. A falsy entry
            counts as an unfinished run; otherwise it is a dict whose
            ``"total"`` value feeds the timing statistics.
        output_path: File to write. Defaults to ``summary_log_times.json``
            (relative to the current working directory), the path this
            function always used before it became configurable.

    The JSON maps ``"<protocol>-<nodes>-<rounds>"`` to one entry per method
    label (looked up in ``method_label_map``), ordered by ``method_order``.
    Each entry holds ``success`` (fraction of finished runs rounded to 4
    decimals, or ``None`` when the combination has no runs at all) and
    ``total`` (``log_stats`` of the finished runs' totals, or all-``None``
    when there are none).
    """
    stats_dict = defaultdict(dict)
    # Protocol, nodes, rounds, then the configured method order; methods not
    # listed in method_order sort last.
    ordered_keys = sorted(
        grouped_by_combo,
        key=lambda x: (x[0], x[1], x[2], method_order.get(x[3], 99)),
    )

    for key in ordered_keys:
        proto, nodes, rounds, method = key
        runs = grouped_by_combo[key]
        finished = [r for r in runs if r]

        # An empty run list would otherwise raise ZeroDivisionError.
        success = round(len(finished) / len(runs), 4) if runs else None

        total_times = [r["total"] for r in finished if r["total"] is not None]
        stat = log_stats(total_times)

        stats_dict[f"{proto}-{nodes}-{rounds}"][method_label_map[method]] = {
            "success": success,
            "total": stat if stat else {"avg": None, "min": None, "max": None},
        }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(stats_dict, f, indent=4)
    print(f"✅ Saved summary to {output_path}")


def extract_failed_case_summary(individual_results, output_path="counterexamples_by_version.json"):
    """Write the failed-case trace of every run that recorded one to JSON.

    Args:
        individual_results: Iterable of ``(file_name, result)`` pairs, where
            ``result`` is ``None`` or a dict that may carry a
            ``"failed_cases"`` list.
        output_path: File to write. Defaults to
            ``counterexamples_by_version.json`` (relative to the current
            working directory), the path this function always used before it
            became configurable.

    Runs with a missing or empty ``failed_cases`` list, and file names that do
    not match ``filename_pattern``, are skipped. Output keys have the form
    ``"<protocol>-<nodes>-<rounds>_<method>_v<version>"``. A file without a
    method suffix is written as ``ggms``. Keys are sorted lexicographically
    in the output.
    """
    counterexamples = {}

    for fname, result in individual_results:
        if not result or not result.get("failed_cases"):
            continue
        match = filename_pattern.match(fname)
        if match is None:
            continue
        protocol, nodes, rounds, method, version = match.groups()
        # No method suffix means the guided search, labelled GGMS elsewhere.
        method_key = (method or "guided").lower().replace("guided", "ggms")
        key = f"{protocol}-{nodes}-{rounds}_{method_key}_v{version}"
        counterexamples[key] = result["failed_cases"]

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(counterexamples, f, indent=4, sort_keys=True)
    print(f"✅ Saved sorted counterexamples (with method) to {output_path}")


# === MAIN ===
def main():
    """Scan ``log_dir`` for v0–v9 run logs, print the results, and write JSON summaries.

    Only regular files whose names match ``filename_pattern`` and whose
    version is at most 9 are considered. Each one is parsed with
    ``extract_log_info``. The script prints per-file results and per-combination
    success counts, then calls ``generate_json_summary`` and
    ``extract_failed_case_summary``.

    Raises:
        SystemExit: if ``log_dir`` does not exist or is not a directory.
    """
    if not log_dir.is_dir():
        # Fail with a readable message instead of a FileNotFoundError traceback.
        raise SystemExit(f"Log directory not found: {log_dir.resolve()}")

    grouped_by_combo = defaultdict(list)
    individual_results = []

    # iterdir() order is filesystem-dependent; sort for reproducible grouping.
    for file in sorted(log_dir.iterdir()):
        if not file.is_file():
            continue
        match = filename_pattern.match(file.name)
        if match is None:
            continue
        protocol, nodes, rounds, method, version = match.groups()
        if int(version) > 9:  # report covers v0–v9 only
            continue
        # A missing method suffix denotes the guided search.
        key = (protocol, int(nodes), int(rounds), method or "guided")
        result = extract_log_info(file)
        individual_results.append((file.name, result))
        grouped_by_combo[key].append(result)

    print("=== Individual Results (v0–v9) ===")
    for fname, result in sorted(individual_results, key=lambda item: item[0]):
        print(f"{fname}: {result if result else 'Not done'}")

    print("\n=== Summary by (Protocol, Nodes, Rounds, Method) ===")
    ordered_keys = sorted(
        grouped_by_combo,
        key=lambda x: (x[0], x[1], x[2], method_order.get(x[3], 99)),
    )
    for proto, nodes, rounds, method in ordered_keys:
        runs = grouped_by_combo[(proto, nodes, rounds, method)]
        successful = [r for r in runs if r]
        success_str = f"{len(successful)}/{len(runs)}"
        print(f"{proto}-{nodes}n/{rounds}r {method_label_map[method]}: success rate = {success_str}")

    generate_json_summary(grouped_by_combo)
    extract_failed_case_summary(individual_results)


if __name__ == "__main__":
    # Run the report only when executed as a script, not when imported.
    main()
