import concurrent.futures
import os
import re
import sys
import warnings
from pathlib import Path

import numpy as np
import pandas as pd

# Suppress RuntimeWarning for division by zero and PerformanceWarning for DataFrame fragmentation
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
from utils import timing_utils as timer

# Domain directory names recognised inside trace paths. get_domain_from_path()
# tries them in this order and returns the first one that appears as a
# "/<domain>/" component of the path.
DOMAINS = [
    "shopping",
    "classifieds",
    "reddit",
    "chrome",
    "gimp",
    "libreoffice_calc",
    "libreoffice_impress",
    "libreoffice_writer",
    "multi_apps",
    "os",
    "thunderbird",
    "vlc",
    "vs_code",
]
# When True, a previously saved evaluations.csv is ignored and every trace is parsed again.
# NOTE(review): the name looks like a typo for "OVERWRITE"; it is left as-is because
# the main script refers to it by this spelling.
OVEWRITE = True
# When True, trace files whose evaluation cannot be parsed are deleted from disk
# (see get_eval_from_trace_file / delete_invalid_files).
DELETE_INVALID_FILES = False

# Glob patterns (relative to a model directory) selecting the trace files to parse.
# An .html trace is only used when no .txt trace with the same base path exists.
file_to_parse_pattern = "**/conversation/**/*.txt"
file_to_parse_pattern_html = "**/conversation/**/*.html"

# Root folder scanned by the main script; it is expected to hold one sub-folder per entry of `envs`.
experiments_path = "./offline_experiments/"

envs = ["osw", "vwa"]

# NOTE(review): not referenced anywhere in the visible part of this file — confirm it is still needed.
original_scores_paths = {
    "osw": "./trace_osworld/ui-tars-1.5_50steps_2025-04-05/consolidated_results.csv",
    "vwa": "./experiments/gpt-4o-2024-08-06/base_stateaware/scores_summary.csv",
}
# Verdict labels that may appear in an evaluation section; also counted individually in the stats.
eval_criteria = ["SUCCESS", "PARTIAL SUCCESS", "FAILURE", "PARTIAL FAILURE"]
# Any trace file or model directory whose path contains one of these substrings is skipped.
exclude_strs = ["zzOld", "logs", "llama", "claude", "k-cond", "k-uncond", "verify_example"]
# File name of the per-model consolidated evaluation table.
results_csv_name = "evaluations.csv"

# Captures (non-greedily, across lines, case-insensitively) the text between an
# "EVALUATION:" / "Status:" / "# EVALUATION" marker and the next "FEEDBACK:" or "</pre>".
EVAL_SECTION_REGEX = re.compile(
    r"(?:EVALUATION:|Status:|# EVALUATION)\s*(.*?)\s*(?:FEEDBACK:|</pre>)",
    re.S | re.IGNORECASE,
)

# Matches one of `eval_criteria` at the very start of a string (no MULTILINE flag),
# allowing leading whitespace, an optional leading colon and an optional trailing colon.
EVAL_CRITERIA_REGEX = re.compile(
    rf"^\s*:?\s*({'|'.join(map(re.escape, eval_criteria))})(?::)?",
    re.IGNORECASE,
)


def parse_eval(content: str) -> str:
    """
    Return the verdict label (e.g. SUCCESS, FAILURE) that opens an evaluation section.

    The label is matched with EVAL_CRITERIA_REGEX and returned upper-cased;
    an empty string is returned when the section does not start with a known label.
    """
    found = EVAL_CRITERIA_REGEX.search(content)
    return found.group(1).upper() if found else ""


def delete_invalid_files(file: Path | str):
    # Split the extension from the file.
    base_path = str(file.with_suffix(""))

    all_files = []
    all_files.append(f"{base_path}.html")
    all_files.append(f"{base_path}.txt")
    parent_dir = Path(file).parent.parent
    all_files.append(f"{parent_dir}/usage/{base_path}.csv")
    for file in all_files:
        if not os.path.exists(file):
            continue
        if os.path.isfile(file):
            os.remove(file)


def get_eval_from_trace_file(file: Path | str) -> tuple[str, str, str, str, Path]:
    """
    Read one trace file and extract the verdict of its last evaluation section.

    Args:
        file: Path of the trace file (``.txt`` or ``.html``).

    Returns:
        ``(task_id, variation_name, eval_result, domain, file)`` where ``eval_result``
        is the upper-cased verdict label, or ``""`` when the file could not be read,
        holds no evaluation section, or the section does not start with a known label.

    Raises:
        ValueError: if no known domain can be derived from the path
            (raised by ``get_domain_from_path`` before the file is opened).
    """
    file = Path(file)

    task_id = get_task_id_from_trace_file(file)
    variation_name = get_variation_name_from_trace_file(file)
    domain = get_domain_from_path(file)

    eval_result = ""
    try:
        with open(file, "r", encoding="utf-8") as f:
            content = f.read()

        # A trace may contain several evaluation sections; only the last one counts.
        matches = EVAL_SECTION_REGEX.findall(content)
        if not matches:
            print(f"No evaluation found in {file}", flush=True)
        else:
            eval_section_str = matches[-1].strip()
            eval_result = parse_eval(eval_section_str)  # SUCCESS, FAILURE, etc.
            if eval_result == "":
                print(f"Unable to parse textual score from eval section: {eval_section_str} for {file}", flush=True)

        if eval_result == "" and DELETE_INVALID_FILES:
            delete_invalid_files(file)
    except Exception as e:
        # Best effort: this runs inside a worker process, so a single unreadable
        # file is reported and skipped instead of aborting the whole run.
        print(f"Error reading file {file}: {e}", flush=True)
        eval_result = ""

    return task_id, variation_name, eval_result, domain, file


def get_task_id_from_trace_file(file: Path | str) -> str:
    if not isinstance(file, Path):
        file = Path(file)

    return file.stem


def get_variation_name_from_trace_file(file: Path | str) -> str:
    if not isinstance(file, Path):
        file = Path(file)

    return file.parent.parent.name


def save_csv(
    df: pd.DataFrame, path: Path | str, gold_scores: pd.DataFrame | None = None, key_gold_score: str = "gold_score"
):
    df_to_save = df.copy()
    if gold_scores is not None:
        # Ensure unique_id is of type string for both DataFrames
        gold_df = gold_scores.copy()
        # Create unique_id concatenating domain and task_id
        gold_df["domain_task_id"] = gold_df["domain"].astype(str) + "_" + gold_df["task_id"].astype(str)

        gold_df["domain_task_id"] = gold_df["domain_task_id"].astype(str)
        df_to_save["domain_task_id"] = df_to_save["domain_task_id"].astype(str)

        # Subset gold_scores to only include unique_id and score, and rename score to gold_score
        if "trace_path" in gold_df.columns:
            gold_scores_subset = gold_df[["domain_task_id", key_gold_score, "trace_path"]]
        else:
            gold_scores_subset = gold_df[["domain_task_id", key_gold_score]]

        # Drop the existing gold_score column if it already exists
        if key_gold_score in df_to_save.columns:
            df_to_save = df_to_save.drop(columns=[key_gold_score])

        df_to_save = pd.merge(df_to_save, gold_scores_subset, on="domain_task_id", how="left")

    # Sort by unique_id
    df_to_save = df_to_save.sort_values(by="domain_task_id")
    # Reorder columns to have domain_task_id, env, domain, task_id first
    cols_first = ["domain_task_id", "env", "domain", "task_id", "gold_score"]
    cols_first = [col for col in cols_first if col in df_to_save.columns]
    other_cols = [col for col in df_to_save.columns if col not in cols_first]
    df_to_save = df_to_save[cols_first + other_cols]

    df_to_save.to_csv(path, index=False)

    return df_to_save


def remove_files_to_parse_if_exists(files_to_parse: list[Path], existing_results_df: pd.DataFrame) -> list[Path]:
    """
    Filter out trace files whose verdict is already present in a saved results table.

    A file is skipped when ``existing_results_df`` holds a truthy, non-NaN value in the
    column named after the file's variation, on the row whose ``domain_task_id``
    equals ``"<domain>_<task_id>"`` of the file. All other files are returned in
    their original order.
    """
    # (domain_task_id, variation) for every candidate, in input order.
    lookup_keys = []
    for candidate in files_to_parse:
        variation = get_variation_name_from_trace_file(candidate)
        task_id = get_task_id_from_trace_file(candidate)
        domain = get_domain_from_path(candidate)
        lookup_keys.append((f"{domain}_{task_id}", variation))

    # Flatten the saved table into {(domain_task_id, variation column): value}.
    metadata_columns = ["domain_task_id", "source", "gold_score", "env", "domain", "task_id"]
    recorded = {}
    for _, row in existing_results_df.iterrows():
        for column in existing_results_df.columns:
            if column in metadata_columns:
                continue
            recorded[(str(row["domain_task_id"]), column)] = row[column]

    pending = []
    for candidate, key in zip(files_to_parse, lookup_keys):
        if key in recorded:
            previous = recorded[key]
            already_done = previous and pd.notna(previous)
            if already_done:
                continue
        pending.append(candidate)

    return pending


def get_domain_from_path(path: Path | str) -> str:
    """
    Return the first entry of DOMAINS (in list order) that occurs in *path*
    as a ``/<domain>/`` component.

    Raises:
        ValueError: if none of the known domains is found.
    """
    path = str(path)
    hits = (re.match(rf".*/({domain})/.*", path) for domain in DOMAINS)
    first_hit = next((hit for hit in hits if hit), None)
    if first_hit is None:
        raise ValueError(f"Unable to determine domain from {path}")
    return first_hit.group(1)


def _results_to_df(results: dict) -> pd.DataFrame:
    """
    Flatten ``{domain_task_id: {variation: {"eval_result", "domain", "task_id"}}}``
    into a DataFrame with one row per (task, variation) pair.

    Each row carries ``domain_task_id``, ``domain``, ``task_id`` and a single column
    named after the variation that holds its verdict.
    """
    return pd.DataFrame(
        [
            {
                "domain_task_id": domain_task_id,
                "domain": data["domain"],
                "task_id": data["task_id"],
                variation_name: data["eval_result"],
            }
            for domain_task_id, variations in results.items()
            for variation_name, data in variations.items()
        ]
    )


def parse_experiment_results(
    base_path,
    existing_results_df: pd.DataFrame | None = None,
    max_workers: int = 16,
):
    """
    Parse every trace file below *base_path* and collect the verdicts in a DataFrame.

    ``.txt`` traces are preferred; an ``.html`` trace is only parsed when no ``.txt``
    file with the same base path exists. Paths containing any of ``exclude_strs``
    are ignored. Files are parsed in parallel worker processes.

    Args:
        base_path: Model directory to scan.
        existing_results_df: Previously saved results. Files whose verdict is already
            recorded there are not parsed again, and the new verdicts are combined
            with these rows (new values take precedence).
        max_workers: Size of the process pool used for parsing.

    Returns:
        One row per (task, variation) parsed, combined with ``existing_results_df``
        when given. When nothing new was parsed, ``existing_results_df`` is returned
        unchanged (``None`` if there were no previous results), so callers can still
        report on fully cached models.
    """
    print(f"Parsing evaluations in {base_path}", flush=True)

    def _is_kept(candidate: Path) -> bool:
        return not any(exclude in str(candidate) for exclude in exclude_strs)

    root = Path(base_path)
    files_to_parse = [file for file in root.glob(file_to_parse_pattern) if _is_kept(file)]
    html_files = [file for file in root.glob(file_to_parse_pattern_html) if _is_kept(file)]

    # Fall back to the .html trace only when there is no .txt twin.
    txt_base_paths = {str(file).removesuffix(".txt") for file in files_to_parse}
    files_to_parse.extend(file for file in html_files if str(file).removesuffix(".html") not in txt_base_paths)

    if existing_results_df is not None:
        files_to_parse = remove_files_to_parse_if_exists(files_to_parse, existing_results_df)

    if not files_to_parse:
        print(f"No files to parse in {base_path}", flush=True)
        # Everything may already be cached: hand back the cached rows rather than
        # None, consistent with the "nothing could be parsed" case below.
        return existing_results_df

    print(f"Parsing {len(files_to_parse)} files in {base_path}", flush=True)

    results = {}
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(get_eval_from_trace_file, file) for file in files_to_parse]
        for future in concurrent.futures.as_completed(futures):
            task_id, variation_name, eval_result, domain, file = future.result()
            if not eval_result:
                print(f"Not able to parse eval result for {file}")
                continue

            domain_task_id = f"{domain}_{task_id}"
            results.setdefault(domain_task_id, {})[variation_name] = {
                "eval_result": eval_result,
                "domain": domain,
                "task_id": task_id,
            }

    if not results:
        return existing_results_df

    result_df = _results_to_df(results)

    if existing_results_df is not None:
        # Freshly parsed values win; gaps are filled from the previous results.
        result_df = (
            result_df.set_index("domain_task_id")
            .combine_first(existing_results_df.set_index("domain_task_id"))
            .reset_index()
        )

    return result_df


def map_eval_to_score(eval: str) -> int | float:
    # If eval is None, empty, or not provided, return np.nan so that it doesn't count.
    if not eval:
        return np.nan
    if eval == "SUCCESS":
        return 1
    elif eval == "PARTIAL SUCCESS":
        return 0
    elif eval == "PARTIAL FAILURE":
        return 0
    elif eval == "FAILURE":
        return 0
    else:
        return np.nan


def compute_confusion_stats(gold_scores: pd.DataFrame, evals: pd.DataFrame):
    """
    Compare predicted verdicts against gold scores and summarise the agreement.

    Args:
        gold_scores: Table with ``domain_task_id`` and ``gold_score`` (0 or 1) columns.
        evals: Table with ``domain_task_id`` and ``eval`` (verdict label) columns.
            Neither input is modified.

    Returns:
        dict with the confusion counts, their ratios, accuracy / recall / precision /
        F1, the split of false negatives by verdict, the number and share of gold
        rows without a usable prediction (``NA``, ``NA %``), ``total`` gold rows, and
        one count per label in ``eval_criteria``. Ratios whose denominator is zero
        are NaN.
    """

    def _ratio(numerator, denominator):
        # Every zero denominator here is a 0/0 case, so NaN is returned explicitly
        # instead of relying on suppressed numpy warnings.
        return numerator / denominator if denominator else np.nan

    # Work on copies: callers pass filtered slices that must not be written to.
    gold_scores = gold_scores.copy()
    evals = evals.copy()
    gold_scores["domain_task_id"] = gold_scores["domain_task_id"].astype(str)
    evals["domain_task_id"] = evals["domain_task_id"].astype(str)
    # An all-NaN column can arrive as float64, which the .str accessor rejects.
    evals["eval"] = evals["eval"].astype(object)

    evals["predicted_score"] = evals["eval"].apply(map_eval_to_score)

    # Right join: every gold row is kept, with NaN predictions where no eval exists.
    merged = pd.merge(evals, gold_scores, on="domain_task_id", how="right", suffixes=("_eval", "_true"))
    predicted = merged["predicted_score"]
    gold = merged["gold_score"]

    false_positive = ((predicted == 1) & (gold == 0)).sum()
    false_negative = ((predicted == 0) & (gold == 1)).sum()
    true_positive = ((predicted == 1) & (gold == 1)).sum()
    true_negative = ((predicted == 0) & (gold == 0)).sum()

    actual_positive = true_positive + false_negative
    actual_negative = false_positive + true_negative
    scored = actual_positive + actual_negative

    recall = _ratio(true_positive, actual_positive)
    precision = _ratio(true_positive, true_positive + false_positive)
    f1_score = _ratio(2 * (precision * recall), precision + recall)

    # Number of evaluations per verdict label.
    eval_labels = evals["eval"].str.upper()
    criteria_counts = {criteria: eval_labels.eq(criteria).sum() for criteria in eval_criteria}

    # Break the false negatives (gold 1, predicted 0) down by verdict label.
    merged["eval_upper"] = merged["eval"].str.upper().str.strip()
    false_negatives = merged[(predicted == 0) & (gold == 1)]
    false_neg_partial_success = false_negatives[false_negatives["eval_upper"] == "PARTIAL SUCCESS"].shape[0]
    false_neg_failure = false_negatives[false_negatives["eval_upper"] == "FAILURE"].shape[0]

    total = len(gold_scores)
    not_available = total - scored

    all_data = {
        "false_positive": false_positive,
        "false_negative": false_negative,
        "true_positive": true_positive,
        "true_negative": true_negative,
        "tp_ratio": recall,
        "tn_ratio": _ratio(true_negative, actual_negative),
        "fp_ratio": _ratio(false_positive, actual_negative),
        "fn_ratio": _ratio(false_negative, actual_positive),
        "accuracy": _ratio(true_positive + true_negative, scored),
        "recall": recall,
        "precision": precision,
        "f1_score": f1_score,
        "false_neg_partial_success": false_neg_partial_success,
        "false_neg_failure": false_neg_failure,
        "NA": not_available,
        "total": total,
        "NA %": _ratio(not_available, total),
    }

    all_data.update(criteria_counts)

    return all_data


# ===============================================================
# Main
# ===============================================================
# Assumes folder structure:
# experiments/
#  .../
#   gpt-4o-2024-08-06/
#     experiments_domain_1_date_1/
#       variation_1/
#       variation_2/
#       ...
#     experiments_domain_2_date_2/
#       variation_1/
#       variation_2/
#       ...


def load_gold_scores(path_to_csv: str) -> pd.DataFrame:
    gold_scores = pd.read_csv(path_to_csv)
    gold_scores.rename(columns={"score": "gold_score"}, inplace=True)
    gold_scores["task_id"] = gold_scores["task_id"].astype(str)

    # Ensure gold_scores has a domain; if not, assign a default.
    if "domain" not in gold_scores.columns:
        raise ValueError("Gold scores must have a domain column.")

    gold_scores = gold_scores.copy()
    gold_scores["domain_task_id"] = gold_scores["domain"].astype(str) + "_" + gold_scores["task_id"].astype(str)

    return gold_scores


# Columns of the consolidated evaluation table that describe a task rather than
# hold the verdict of one variation ("trace_path" is added by save_csv).
_NON_EVAL_COLUMNS = ("domain_task_id", "source", "gold_score", "env", "domain", "task_id", "trace_path")


def _confusion_columns_for_model(model: str, consolidated_evals: pd.DataFrame, gold_scores: pd.DataFrame) -> dict:
    """Return ``{"<model>--<variation>--<domain|all>": stats Series}`` for one model."""
    columns = {}
    for col in consolidated_evals.columns:
        if col in _NON_EVAL_COLUMNS:
            continue

        # Stats over every domain of the environment for this variation.
        evals_subset = consolidated_evals[["domain_task_id", col]].rename(columns={col: "eval"})
        columns[f"{model}--{col}--all"] = pd.Series(compute_confusion_stats(gold_scores, evals_subset))

        # Stats restricted to each domain present in the gold scores.
        for domain in gold_scores["domain"].unique():
            domain_mask = consolidated_evals["domain"] == domain
            domain_evals_subset = consolidated_evals.loc[domain_mask, ["domain_task_id", col]].rename(
                columns={col: "eval"}
            )
            # Copy so the callee never operates on a view of gold_scores.
            domain_gold_scores = gold_scores[gold_scores["domain"] == domain].copy()
            columns[f"{model}--{col}--{domain}"] = pd.Series(
                compute_confusion_stats(domain_gold_scores, domain_evals_subset)
            )
    return columns


def _split_stat_column_names(flat_names) -> pd.MultiIndex:
    """Turn "model--config--domain" names into a (domain, model, config) MultiIndex."""
    tuples = []
    for name in flat_names:
        parts = name.split("--")
        # Names that do not split into exactly three parts are kept whole at the
        # model level so the number of labels always matches the number of columns.
        tuples.append((parts[2], parts[0], parts[1]) if len(parts) == 3 else ("", name, ""))
    return pd.MultiIndex.from_tuples(tuples, names=["domain", "model", "config"])


def main() -> None:
    """
    Parse all traces and write the consolidated reports.

    For every environment in ``envs`` and every model directory below
    ``experiments_path/<env>/``:

    * parse the traces and save ``<model>/evaluations.csv``;
    * compute confusion statistics per variation, overall and per domain, and save
      them to ``<env>/consolidated_stats.csv``.

    Finally the statistics of all environments are written to
    ``experiments_path/consolidated_stats.csv`` and the evaluation tables of all
    models of all environments to ``experiments_path/consolidated_evals.csv``
    (indexed by environment, model and row number).
    """
    confusion_per_env = {}
    raw_evals_per_env = {}

    for env in envs:
        experiments_path_env = Path(experiments_path) / env
        stat_columns = {}
        evals_per_model = {}

        for model_dir in (entry for entry in experiments_path_env.iterdir() if entry.is_dir()):
            model_subdir = str(model_dir)
            if any(exclude in model_subdir for exclude in exclude_strs) or not any(
                known_env in model_subdir for known_env in envs
            ):
                continue

            gold_scores = load_gold_scores(model_dir / "gold_scores.csv")

            # Saves `evaluations.csv` at model subdirectory.
            save_path = model_dir / results_csv_name
            existing_results = None
            if not OVEWRITE and save_path.exists():
                existing_results = pd.read_csv(save_path)
                if existing_results.empty:
                    existing_results = None

            model = model_dir.name

            model_results_all_domains = parse_experiment_results(
                base_path=model_subdir, existing_results_df=existing_results
            )
            if model_results_all_domains is None:
                continue

            # One row per task, one column per variation.
            consolidated_evals = model_results_all_domains.pivot_table(index="domain_task_id", aggfunc="first")
            consolidated_evals.reset_index(inplace=True)

            consolidated_evals["env"] = env
            save_csv(consolidated_evals, save_path, gold_scores=gold_scores)
            print(f"Evaluation results consolidated for {model} at {save_path}")

            # The join key must be a string on both sides.
            consolidated_evals["domain_task_id"] = consolidated_evals["domain_task_id"].astype(str)
            gold_scores["domain_task_id"] = gold_scores["domain_task_id"].astype(str)

            stat_columns.update(_confusion_columns_for_model(model, consolidated_evals, gold_scores))
            evals_per_model[model] = consolidated_evals

        # Building the frame once avoids inserting columns one by one.
        confusion_matrices = pd.DataFrame(stat_columns)
        confusion_matrices.columns = _split_stat_column_names(confusion_matrices.columns)
        confusion_matrices = confusion_matrices.sort_index(axis=1, level=["domain", "model", "config"])

        confusion_matrices.T.to_csv(experiments_path_env / "consolidated_stats.csv", index=True)
        print(f"Offline stats for {env} saved to", experiments_path_env / "consolidated_stats.csv")
        confusion_per_env[env] = confusion_matrices

        # Keep the evaluations of every model of this environment, keyed by model,
        # and only when at least one model produced results.
        if evals_per_model:
            raw_evals_per_env[env] = pd.concat(evals_per_model)

    # Aggregate confusion matrices across all environments.
    if confusion_per_env:
        # Column levels after concatenation: env, domain, model, config.
        consolidated_confusion = pd.concat(confusion_per_env, axis=1)
        level_values = consolidated_confusion.columns.get_level_values

        # Merge the env and domain levels into a single "<env>_<domain>" label.
        merged_env_domain = [
            f"{env_name}_{domain}" for env_name, domain in zip(level_values(0), level_values(1))
        ]

        if consolidated_confusion.columns.nlevels > 2:
            new_tuples = list(zip(merged_env_domain, level_values(2), level_values(3)))
            new_index = pd.MultiIndex.from_tuples(new_tuples, names=["env_domain", "model", "config"])
        else:
            new_index = pd.Index(merged_env_domain, name="env_domain")

        consolidated_confusion.columns = new_index

        # Transposed so that each (env_domain, model, config) combination is one row.
        transposed_confusion = consolidated_confusion.T

        transposed_output_path = Path(experiments_path) / "consolidated_stats.csv"
        transposed_confusion.to_csv(transposed_output_path, index=True)
        print(f"Consolidated stats saved at {transposed_output_path}")
    else:
        print("No confusion matrices to consolidate.")

    if raw_evals_per_env:
        all_models_evals = pd.concat(raw_evals_per_env, axis=0)
        consolidated_output_path = Path(experiments_path) / "consolidated_evals.csv"
        all_models_evals.to_csv(consolidated_output_path, index=True)
        print(f"Consolidated evals saved at {consolidated_output_path}")
    else:
        print("No evals to consolidate.")


# parse_experiment_results() starts a process pool. Under the "spawn" and "forkserver"
# start methods the workers import this module, so the script body must only run
# when the file is executed directly.
if __name__ == "__main__":
    main()
