from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional, Tuple
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import argparse
import json
import glob
import os
from collections import defaultdict, Counter
from rllm.report.task_type import get_task_subtype, get_subtype_order
from rllm.report.monitor import (
    get_instance_summary,
    get_experiment_summary,
    print_results,
    get_metrics,
)


def parse_args():
    """Parse the command-line options of the report script.

    Returns:
        An ``argparse.Namespace`` with:
        - ``results_dir``: required directory holding the result JSON files;
        - ``prefix``: optional substring used to filter result file names
          (``None`` when omitted);
        - ``hotfix``: ``True`` when ``--hotfix`` is passed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--results_dir", required=True, type=str)
    cli.add_argument("--prefix", default=None, type=str)
    cli.add_argument("--hotfix", action="store_true")
    namespace = cli.parse_args()
    return namespace

def load_json(file_path: str) -> Any:
    """Read and decode a JSON file.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The decoded JSON value. This may be any JSON type, not just a dict:
        result files are iterated as lists by ``get_results``.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # encoding, which differs e.g. on Windows.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def save_json(data: Dict[str, Any], file_path: str):
    """Write ``data`` to ``file_path`` as JSON indented by 4 spaces.

    Any existing file at ``file_path`` is overwritten.
    """
    with open(file_path, mode="w") as handle:
        json.dump(obj=data, fp=handle, indent=4)

def parsing_experiment_info(result_path: str) -> str:
    """Build an experiment name from a result file path.

    The path is expected to end in ``<model_name>/<task_name>/<run>.json``;
    the returned name is ``"<model_name>_<task_name>_<run>"``.

    Args:
        result_path: Path to a result JSON file (absolute or relative).

    Returns:
        The experiment name.

    Raises:
        ValueError: If the resolved path has fewer than three components.
    """
    # Resolve to an absolute, normalized path so that:
    # - separators are platform-native (glob/os.path.join give "\\" on Windows);
    # - "//" and "./" segments are dropped;
    # - shallow relative paths still expose their parent directories.
    parts = os.path.abspath(result_path).split(os.sep)
    if len(parts) < 3:
        raise ValueError(
            f"Cannot derive model/task/run from result path {result_path!r}."
        )
    model_name, task_name, file_name = parts[-3:]
    # Strip only the final extension: "run_t0.7.json" -> "run_t0.7", not
    # "run_t0", which could collide with another run and overwrite it.
    suffix = os.path.splitext(file_name)[0]
    return f"{model_name}_{task_name}_{suffix}"

def get_hotfixed_task_subtype(task_info: dict, hotfix_data: Optional[dict]) -> Optional[str]:
    """Return the corrected subtype for sudoku "hard"/"extremely hard" tasks.

    Args:
        task_info: Task metadata. Its ``"task_name"`` entry is inspected and
            defaults to ``"default"`` when absent.
        hotfix_data: Mapping loaded from the hotfix file, where
            ``hotfix_data["hard"]`` is a container of task names belonging to
            the hard split. May be ``None`` when the hotfix file could not be
            loaded.

    Returns:
        - ``"standard_hard"`` if the task is listed under ``hotfix_data["hard"]``.
        - ``"standard_extremely_hard"`` for any other task whose name contains
          ``sudoku_normal_hard`` or ``sudoku_normal_extremely_hard``.
        - ``None`` for all other tasks, or when ``hotfix_data`` is ``None``
          (no override can be computed).
    """
    task_name = task_info.get("task_name", "default")
    is_sudoku_hard_family = (
        "sudoku_normal_hard" in task_name
        or "sudoku_normal_extremely_hard" in task_name
    )
    # Without hotfix data there is nothing to look up; report "no override"
    # instead of raising TypeError on None["hard"].
    if not is_sudoku_hard_family or hotfix_data is None:
        return None
    if task_name in hotfix_data["hard"]:
        return "standard_hard"
    return "standard_extremely_hard"

def _load_hotfix_data() -> Optional[Dict[str, Any]]:
    """Load the hotfixed task-name mapping, returning None if the file is missing."""
    # NOTE(review): this path is relative to the current working directory, so it
    # only resolves when the script is launched from the repository root.
    try:
        hotfix_data = load_json("rllm/report/hotfixed_task_name_dict.json")
    except FileNotFoundError:
        print("Hotfix data file not found. Please run the script to generate the hotfix data.")
        return None
    print("Hotfix data loaded successfully.")
    return hotfix_data


def get_results(results_dir: str, prefix: Optional[str] = None, hotfix: bool = False):
    """Summarize every result file in ``results_dir`` and write ``metrics.json``.

    Each matching ``*.json`` file is handled as follows:
    - The file is optionally restricted to names containing ``prefix``.
    - It is loaded as a JSON array of per-instance results.
    - Each result is summarized with ``get_instance_summary``.
    - The summaries are grouped under the experiment name that
      ``parsing_experiment_info`` derives from the file path.

    The aggregate statistics are printed and saved to
    ``<results_dir>/metrics.json``.

    Args:
        results_dir: Directory containing the result JSON files.
        prefix: Optional substring a file name must contain to be included.
        hotfix: If True, override sudoku hard/extremely-hard subtypes using the
            hotfix mapping and summarize with ``use_difficulty=False``.

    Raises:
        FileNotFoundError: If no result files match in ``results_dir``.
        ValueError: If every matched result file is empty.
    """
    metrics_filename = "metrics.json"
    prefix_filter = f"*{prefix}*.json" if prefix else "*.json"
    # Sort for a deterministic processing order. Skip the metrics file this
    # function writes: on a re-run it would otherwise be parsed as a result
    # file and crash, since it holds a dict of stats rather than a list of
    # results.
    json_files = sorted(
        path
        for path in glob.glob(os.path.join(results_dir, prefix_filter))
        if os.path.basename(path) != metrics_filename
    )
    print(json_files)
    if not json_files:
        raise FileNotFoundError(
            f"No result files matching {prefix_filter!r} found in {results_dir!r}."
        )

    if hotfix:
        hotfix_data = _load_hotfix_data()
    else:
        print("Hotfix is disabled.")
        hotfix_data = None

    experiment_results = defaultdict(list)
    for json_file in json_files:
        summaries = []
        for result in load_json(json_file):
            if hotfix:
                # If the hotfix mapping could not be loaded there is no override
                # to compute; pass None rather than indexing a missing mapping.
                hotfix_subtype = (
                    get_hotfixed_task_subtype(result["task"], hotfix_data)
                    if hotfix_data is not None
                    else None
                )
                summaries.append(
                    get_instance_summary(result, use_difficulty=False, hotfix_subtype=hotfix_subtype)
                )
            else:
                summaries.append(get_instance_summary(result))
        experiment_results[parsing_experiment_info(json_file)] = summaries

    # The subtype ordering is taken from the first available instance
    # (presumably all files in one results_dir share a task type). Empty files
    # are skipped instead of raising IndexError.
    first_summary = next(
        (summary for summaries in experiment_results.values() for summary in summaries),
        None,
    )
    if first_summary is None:
        raise ValueError(f"All result files in {results_dir!r} are empty; nothing to summarize.")
    subtype_order = get_subtype_order(first_summary.task_type)
    print(subtype_order)
    total_stats, subtype_stats = get_experiment_summary(experiment_results, subtype_order)
    print_results(total_stats, subtype_stats)

    metrics = {
        "total_stats": total_stats,
        "subtype_stats": subtype_stats,
    }
    save_json(metrics, os.path.join(results_dir, metrics_filename))

if __name__ == "__main__":
    # Command-line entry point: read the flags, then build the metrics report.
    cli_args = parse_args()
    get_results(
        results_dir=cli_args.results_dir,
        prefix=cli_args.prefix,
        hotfix=cli_args.hotfix,
    )