import numpy as np
import pandas as pd
from tqdm import tqdm

from machiavelli.config import MachiavelliConfig
from machiavelli.game.machiavelli_env import Trajectory

"""
Calculates the results for our main results table.

Inputs:
- Trajectories (.pkl files) in `MachiavelliConfig().experiment_trajectories_dir`, organized by subdirectories for each method

Outputs:
- main_results.csv: Aggregated results for each method
    - Each row is a different baseline method (add baseline methods to the list in the `if __name__ == "__main__":` block)
    - Each column is a different harm type
    - The cell corresponds to the harm count for that method and harm type, normalized with the mean and stddev from "normalization_coeffs.json"
- main_results_expanded.csv: Aggregated results for each method and game
    - As above, but each row is a (game, method) pair
"""


def load_results(method: str, test_set_only=True):
    """Collect one summary row per saved trajectory of `method`.

    Every `.pkl` file found (recursively) under
    `MachiavelliConfig().experiment_trajectories_dir / method` is loaded with
    `Trajectory.load`. Trajectories whose game is not in the evaluated game
    list are skipped.

    Args:
        method: Sub-path of the trajectories directory holding this method's
            runs, e.g. "drrn/drrn_no_reg".
        test_set_only: If True, keep only games in `MachiavelliConfig().games_test`;
            otherwise keep games in `MachiavelliConfig().working_games`.

    Returns:
        A DataFrame with columns "game", "method", "points", "achievements",
        "steps", plus one column per key returned by `Trajectory.get_mach_scores()`.
        The DataFrame is empty (no columns) when no trajectory qualifies.
    """
    print(f"Loading results for {method}...")
    traj_dir = MachiavelliConfig().experiment_trajectories_dir / method

    if test_set_only:
        games_to_eval = MachiavelliConfig().games_test
    else:
        games_to_eval = MachiavelliConfig().working_games

    # Sort the paths so rows are produced in a deterministic order.
    pkl_paths = sorted(traj_dir.glob("**/*.pkl"))

    records = []
    for pkl_path in tqdm(pkl_paths):
        trajectory = Trajectory.load(pkl_path)
        if trajectory.game_name in games_to_eval:
            harm_scores = trajectory.get_mach_scores()
            record = {
                "game": trajectory.game_name,
                "method": method,
                "points": trajectory.get_points(normalize=True),
                "achievements": len(trajectory.get_achievements()),
                "steps": len(trajectory),
            }
            # Score columns are appended after the fixed summary columns.
            record.update(harm_scores)
            records.append(record)

    return pd.DataFrame(records)

def aggregate_results(df, group_cols, count_games=False):
    """Average every numeric column of `df` within each group.

    Args:
        df: DataFrame with one row per trajectory. Must contain the columns in
            `group_cols` and a "points" column; a "game" column is also
            required when `count_games` is True.
        group_cols: List of column names to group by.
        count_games: If True, add a "games" column holding the number of
            distinct "game" values in each group.

    Returns:
        A DataFrame with one row per group containing the group keys, the mean
        of each numeric column, and a "samples" column with the number of
        non-null "points" entries that were aggregated.
    """
    grouped = df.groupby(group_cols)
    # numeric_only=True: a non-key string column (e.g. "game" when grouping by
    # "method" alone) cannot be averaged. pandas >= 2.0 raises TypeError for it
    # instead of silently dropping the column as older versions did.
    df_agg = grouped.mean(numeric_only=True).reset_index()
    # Group results are ordered by the sorted group keys, exactly like df_agg,
    # so the counts can be attached positionally.
    df_agg["samples"] = grouped["points"].count().to_numpy()
    if count_games:
        df_agg["games"] = grouped["game"].nunique().to_numpy()
    return df_agg


if __name__ == "__main__":
    test_set_only = True

    # Add baseline methods here; each entry is a sub-path of the trajectories directory.
    methods = [
        "random_baselines_1k/RandomAgent",

        "drrn/drrn_no_reg",
        "drrn/drrn_reg_less",
        "drrn/drrn_reg_more",

        "lm_dev/LMAgent:davinci",
        "lm_dev/LMAgent:gpt-3.5-turbo",
        "lm_dev/CoTAgent:gpt-3.5-turbo",
        "lm_dev/GoodAgent:gpt-3.5-turbo",
        "lm_dev/GoodCoTAgent:gpt-3.5-turbo",

        "lm_dev/CoTAgent:gpt-4",
        "lm_dev/GoodAgent:gpt-4",
        "lm_dev/GoodCoTAgent:gpt-4",

        "lm_dev/GoodAgent:llama-13b",
        "lm_dev/GoodAgent:llama-30b",
        "lm_dev/GoodAgent:llama-7b",
        "lm_dev/LMAgent:llama-13b",
        "lm_dev/LMAgent:llama-30b",
        "lm_dev/LMAgent:llama-7b",
    ]
    df = pd.concat([load_results(m, test_set_only=test_set_only) for m in methods])

    results_dir = MachiavelliConfig().experiment_results_dir

    # Per-(game, method) means, with the number of trajectories behind each row.
    df_expanded = aggregate_results(df, ["game", "method"])
    df_expanded.to_csv(results_dir / "main_results_expanded.csv", index=False)

    # Per-method means, with sample counts and the number of unique games evaluated.
    df_main = aggregate_results(df, ["method"], count_games=True)
    df_main.to_csv(results_dir / "main_results.csv", index=False)
