import os
import glob
from pathlib import Path
import re
import pandas as pd
from experiments_utils.extract_trajectory_html import process_html_file
from utils.file_utils import get_task_id_from_file_path


# Run configuration for the usage aggregation below.

# Name of the sub-directory (inside each domain folder) that holds the usage CSVs.
usage_data_dirname = "lm_data"

# Only domain folders whose name appears in this list are processed.
domains = ["shopping", "reddit", "classifieds"]

# Root folder of the run whose usage CSVs are aggregated.
experiment_folder = "experiments/gpt-4o-2024-08-06/____EXEC____mem=text-a_t=100____CRIT____m=2p_eval=tri_no_reset_1thread"

# Root folder whose rendered HTML trajectories are used to find where the
# critique starts; a falsy value turns off the "*_no_critique" rows.
critique_folder = (
    "experiments/gpt-4o-2024-08-06/____EXEC____mem=text-a_t=100____CRIT____m=2p_eval=tri_no_reset_1thread"
)


def is_failed_execution(task_id, domain):
    """Tell whether a task produced no usable HTML render.

    The render is looked up at
    ``<experiment_folder>/<domain>/htmls/render_<task_id>.html``.

    Args:
        task_id: Identifier interpolated into the render file name.
        domain: Name of the domain sub-directory.

    Returns:
        True when the render file is missing or has a size of zero bytes,
        False otherwise.
    """
    render_file = os.path.join(experiment_folder, domain, "htmls", f"render_{task_id}.html")
    if not os.path.exists(render_file):
        return True
    return os.path.getsize(render_file) == 0


def get_stop_line(task_id, domain):
    """Count the states of a task's trajectory parsed with ``stop_at_critique=True``.

    The render is read from
    ``<critique_folder>/<domain>/htmls/render_<task_id>.html`` and handed to
    ``process_html_file``.

    Args:
        task_id: Identifier interpolated into the render file name.
        domain: Name of the domain sub-directory.

    Returns:
        The length of the ``"states"`` entry of the parsed trajectory.
    """
    render_file = os.path.join(critique_folder, domain, "htmls", f"render_{task_id}.html")
    states = process_html_file(render_file, stop_at_critique=True)["states"]
    return len(states)


def get_agent_name(csv_file):
    """Return the name of the directory two levels above ``csv_file``.

    For ``a/b/c.csv`` the containing directory is ``a/b`` and the result is
    ``"a"``. A path with fewer than two parent directories yields ``""``.

    Args:
        csv_file: Path of a usage CSV file.

    Returns:
        The final path component of the file's grandparent directory.
    """
    containing_dir = os.path.dirname(csv_file)
    grandparent_dir = os.path.dirname(containing_dir)
    return os.path.basename(grandparent_dir)


def get_valid_csvs(csv_files, num_agents=2):
    """Group CSV paths by task and keep the tasks that have the expected agent count.

    The task id is the file name without its extension; the agent name comes
    from ``get_agent_name``. When several paths share the same task id and
    agent name, the last one encountered replaces the earlier ones.

    Args:
        csv_files: Iterable of CSV file paths.
        num_agents: Number of distinct agents a task must have to be kept.

    Returns:
        A dict mapping task id to a ``{agent_name: csv_path}`` dict, limited
        to tasks with exactly ``num_agents`` agents.
    """
    by_task = {}
    for path in csv_files:
        agent = get_agent_name(path)
        by_task.setdefault(Path(path).stem, {})[agent] = path

    return {task: per_agent for task, per_agent in by_task.items() if len(per_agent) == num_agents}


def _summarize_usage(rows, agent, domain, task_id, n_actions):
    """Collapse usage rows into a single-row DataFrame tagged with its origin.

    Args:
        rows: DataFrame whose numeric columns are summed.
        agent: Value stored in the ``agent`` column.
        domain: Value stored in the ``domain`` column.
        task_id: Value stored in the ``task_id`` column.
        n_actions: Value stored in the ``n_actions`` column.

    Returns:
        A one-row DataFrame holding the numeric column sums plus the four
        metadata columns.
    """
    summary = rows.sum(numeric_only=True).to_frame().T
    summary["agent"] = agent
    summary["domain"] = domain
    summary["task_id"] = task_id
    summary["n_actions"] = n_actions
    return summary


# Walk the domain folders of the experiment and build one summary row per
# (task, agent), plus a "<agent>_no_critique" row for executor agents.
usage_frames = []
for domain in os.listdir(experiment_folder):
    if domain not in domains:
        continue

    usage_data_dir = os.path.join(experiment_folder, domain, usage_data_dirname)
    if not os.path.exists(usage_data_dir):
        continue

    csv_files = glob.glob(os.path.join(usage_data_dir, "**/*.csv"), recursive=True)

    for task_id, csvs in get_valid_csvs(csv_files).items():
        if is_failed_execution(task_id, domain):
            continue

        for agent_name, csv_file in csvs.items():
            current_df = pd.read_csv(csv_file)

            # Usage over every row of the CSV.
            usage_frames.append(_summarize_usage(current_df, agent_name, domain, task_id, len(current_df)))

            # The truncated row is produced for executor agents only. It is
            # appended inside this branch so that other agents neither hit an
            # undefined variable nor re-add the previous executor's row.
            if "executor" in agent_name and critique_folder:
                last_line_no_critique = get_stop_line(task_id, domain)
                # NOTE(review): n_actions is the stop line itself, even if it
                # exceeds the number of rows in the CSV -- confirm intended.
                usage_frames.append(
                    _summarize_usage(
                        current_df.iloc[:last_line_no_critique],
                        f"{agent_name}_no_critique",
                        domain,
                        task_id,
                        last_line_no_critique,
                    )
                )

# Concatenate once at the end; pd.concat rejects an empty list, so fall back
# to an empty frame when nothing was collected.
df = pd.concat(usage_frames, ignore_index=True) if usage_frames else pd.DataFrame()


# Total usage per agent (includes the "*_no_critique" pseudo-agents).
usage_per_agent = df.groupby("agent").sum(numeric_only=True).reset_index()

# Rows whose agent name contains "no_critique" are kept apart from the
# full-trajectory rows so the two views are never summed together.
no_critique_mask = df["agent"].str.contains("no_critique")
task_keys = ["domain", "task_id"]

# Per (domain, task) totals across agents, full trajectories only.
usage_per_domain_task = df[~no_critique_mask].groupby(task_keys).sum(numeric_only=True).reset_index()

# Per (domain, task) totals for the "*_no_critique" rows only.
usage_per_domain_task_no_critique = df[no_critique_mask].groupby(task_keys).sum(numeric_only=True).reset_index()

# Average the per-task totals within each domain, rounded to whole numbers.
mean_per_domain = usage_per_domain_task.groupby("domain").mean(numeric_only=True).reset_index().round(0)
mean_per_domain_no_critique = (
    usage_per_domain_task_no_critique.groupby("domain").mean(numeric_only=True).reset_index().round(0)
)

# Write both per-domain tables to the current working directory.
for table, out_name in (
    (mean_per_domain, "mean_per_domain.csv"),
    (mean_per_domain_no_critique, "mean_per_domain_no_critique.csv"),
):
    table.to_csv(out_name, index=False)


# Means over all tasks regardless of domain. These are bare expressions: the
# values are computed but not stored or written anywhere by this script.
usage_per_domain_task.mean(numeric_only=True).reset_index().round(0)

usage_per_domain_task_no_critique.mean(numeric_only=True).reset_index().round(0)
