from collections import Counter
import numpy as np
import pandas as pd
from tqdm import tqdm

from machiavelli.config import MachiavelliConfig
from machiavelli.game.machiavelli_env import Trajectory

"""
How does the distribution of completed achievements (good vs. bad) change between the default agent and the regularized agent?

- For each trajectory, calculate the number of good and bad achievements completed
- For each game, calculate the mean number of good and bad achievements completed
"""


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    if denominator == 0:
        # NaN (rather than raising) lets pandas skip the row when averaging.
        return float("nan")
    return numerator / denominator


def load_results(method: str) -> pd.DataFrame:
    """Build one row of achievement-point statistics per saved trajectory.

    Every ``*.pkl`` file found recursively under
    ``<experiment_trajectories_dir>/<method>`` is loaded as a ``Trajectory``.
    Trajectories whose game is not in the config's ``games_test`` are skipped.

    Achievements with classification ``1`` or ``0`` are counted as "good" and
    those with classification ``-1`` as "bad"; any other classification value
    contributes to neither bucket. Only point-weighted metrics are recorded.

    Args:
        method: Sub-path of the experiment trajectories directory that holds
            the trajectories of one agent (e.g. ``"drrn/drrn_no_reg"``).

    Returns:
        A DataFrame with the columns ``game``, ``method``, ``game_points``,
        ``game_good_points``, ``game_good_points_proportion``, ``traj_points``,
        ``traj_good_points``, ``traj_bad_points`` and
        ``traj_good_points_proportion``. The two proportion columns are NaN
        when their denominator is zero. If no trajectory matches, the frame is
        empty and has no columns.
    """
    print(f"Loading results for {method}...")
    # Build the config once instead of once per trajectory.
    config = MachiavelliConfig()
    traj_dir = config.experiment_trajectories_dir / method
    games_to_eval = config.games_test

    good_classes = (1, 0)
    bad_classes = (-1,)

    df_rows = []
    # Sorted so trajectories are processed in a deterministic order.
    for traj_path in tqdm(sorted(traj_dir.glob("**/*.pkl"))):
        traj = Trajectory.load(traj_path)
        if traj.game_name not in games_to_eval:
            continue
        ach_metadata = config.get_game_metadata(traj.game_name)['achievements']
        # Materialise once; the achievements are scanned several times below.
        achievements = list(traj.get_achievements())

        def points_for(names, classifications):
            """Sum the points of ``names`` whose classification is in ``classifications``."""
            return sum(
                ach_metadata[ach]['points']
                for ach in names
                if ach_metadata[ach]['classification'] in classifications
            )

        # Totals available in the game as a whole.
        game_points = sum(ach_metadata[ach]['points'] for ach in ach_metadata)
        game_good_points = points_for(ach_metadata, good_classes)

        # Totals actually earned by this trajectory.
        traj_points = traj.get_points()
        traj_good_points = points_for(achievements, good_classes)
        traj_bad_points = points_for(achievements, bad_classes)

        df_rows.append({
            "game": traj.game_name,
            "method": method,
            "game_points": game_points,
            "game_good_points": game_good_points,
            "game_good_points_proportion": _safe_ratio(game_good_points, game_points),
            "traj_points": traj_points,
            "traj_good_points": traj_good_points,
            "traj_bad_points": traj_bad_points,
            # A trajectory that earned no points has no defined proportion.
            "traj_good_points_proportion": _safe_ratio(traj_good_points, traj_points),
        })
    return pd.DataFrame(df_rows)

if __name__ == "__main__":
    # Agents to compare, given as sub-paths of the experiment trajectories dir.
    methods = [
        "drrn/drrn_no_reg",
        "drrn/drrn_reg_more",
        "lm_dev/CoTAgent:gpt-4",
        "lm_dev/GoodCoTAgent:gpt-4",
    ]
    # load_results itself keeps only trajectories from the test-set games.
    df = pd.concat([load_results(method) for method in methods])

    # Aggregate results by method, and calculate the mean of each numeric column.
    # numeric_only=True excludes the string-valued "game" column, which cannot be
    # averaged: pandas >= 2.0 raises a TypeError for it, while older versions
    # dropped it silently, so the output is the same on both.
    df_agg = df.groupby(["method"]).mean(numeric_only=True).reset_index()
    df_agg.to_csv(MachiavelliConfig().experiment_results_dir / "count_achievement_change.csv", index=False)
