from tqdm import tqdm
import argparse
import os
from pathlib import Path
from typing import List
import torch
import numpy as np
from datetime import datetime
import pandas as pd

from machiavelli.agent.power_regularizer import PowerRegularizer
from tqdm import tqdm

from machiavelli.game.machiavelli_env import MachiavelliEnv, Trajectory
from machiavelli.config import MachiavelliConfig
from machiavelli.utils import random_seed
from machiavelli.agent.lm_agents import CoTAgent, GoodAgent, GoodCoTAgent, LMAgent
# Shared project configuration. This script reads the trajectory and results
# output directories, the list of games, and human demos from it.
cfg = MachiavelliConfig()


def run_trajectories(agent, env, outdir, num_episodes=25, max_traj_length=1000) -> List[Trajectory]:
    """Roll out ``agent`` in ``env`` and save one trajectory file per episode.

    Each episode is seeded with ``episode_idx + 2417`` (both the global RNGs
    via ``random_seed`` and the environment reset), so repeated runs replay
    the same sequence of seeds.

    Args:
        agent: Object exposing ``should_reset_each_episode``, ``reset()``,
            ``test()`` and ``get_action(obs, rew, done, info)``.
        env: Environment exposing ``game``, ``reset(seed=...)``,
            ``step(action)`` and ``close()``.
        outdir: Directory (``str`` or ``Path``) that receives one
            ``<trajectory.timestamp>.pkl`` file per episode.
        num_episodes: Number of episodes to run.
        max_traj_length: An episode is cut off once its trajectory reaches
            this many entries.

    Returns:
        The list of recorded trajectories, in episode order.
    """
    # Coerce once so that a plain string path works as well as a Path.
    outdir = Path(outdir)

    trajs = []
    points = 0
    ach = 0

    print(f"------ Testing {agent} on {env.game} ------")

    try:
        for episode_idx in tqdm(range(num_episodes)):
            seed = episode_idx + 2417
            random_seed(seed)
            if agent.should_reset_each_episode:
                agent.reset()

            # Start new episode.
            obs, info = env.reset(seed=seed)
            trajectory = Trajectory(env.game, agent_name=str(agent))
            agent.test()

            rew = 0
            done = False
            while not done:
                action = agent.get_action(obs, rew, done, info)
                trajectory.add(obs, action, rew, info)
                obs, rew, done, info = env.step(action)
                # Terminate early if we've reached the maximum trajectory length
                if len(trajectory) >= max_traj_length:
                    print("Reached maximum trajectory length! Terminating early.")
                    done = True

            # Record the final observation; -1 marks "no action taken".
            trajectory.add(obs, -1, rew, info)
            points += trajectory.get_points(normalize=True)
            ach += len(info['game_state']['unlocked_achievements'])

            # Save trajectory to file
            outfile = outdir / f"{trajectory.timestamp}.pkl"
            outfile.parent.mkdir(parents=True, exist_ok=True)
            trajectory.save(outfile)

            trajs.append(trajectory)
    finally:
        # Release the environment even if an episode raised.
        env.close()

    if trajs:
        print(f"Average points: {points / len(trajs)}")
        print(f"Average achievements: {ach / len(trajs)}")
    else:
        # Avoid dividing by zero when no episode was requested.
        print("No episodes were run; skipping averages.")
    return trajs


def make_single_agent(agent_type: str, args):
    """Build one agent from its type name.

    Supported ``agent_type`` values:

    * ``'rl_agent'`` -- load a saved agent from ``args.load_agent``. The
      checkpoint file name is parsed for its settings: the text before the
      first ``_`` is the game, and the text after ``thresh`` is split on
      ``_`` into threshold, ``alpha<value>``, ``beta<value>`` and an optional
      flag field. ``args.game`` is overwritten with the parsed game.
    * ``'RandomAgent'`` (or ``'random_agent'``).
    * ``'HumanAgent'`` (or ``'human_agent'``) -- replays the demo identified
      by ``args.playthrough_id``; the demo must be for ``args.game``.
    * ``'LMAgent:<model>'``, ``'GoodAgent:<model>'``, ``'CoTAgent:<model>'``,
      ``'GoodCoTAgent:<model>'`` -- the text after the last ``:`` is passed
      as ``model``.

    Args:
        agent_type: Name selecting the agent to build.
        args: Parsed command-line namespace; only the attributes named above
            are read (and ``game`` written) here.

    Returns:
        The constructed agent.

    Raises:
        ValueError: ``'rl_agent'`` was requested without ``args.load_agent``.
        NotImplementedError: ``agent_type`` is not one of the names above.
    """
    if agent_type == 'rl_agent':
        if args.load_agent is None:
            raise ValueError("agent type 'rl_agent' requires --load_agent")
        info = os.path.splitext(os.path.basename(args.load_agent))[0]
        # Game names containing '_' are truncated by this split, hence the
        # special case below.
        args.game = info.split('_')[0]
        if args.game == 'vampire-the-masquerade':
            args.game = 'vampire-the-masquerade_night-road'
        data = info.split('thresh')[1].split('_')
        # Pad so the fourth field defaults to 'False' when the name omits it.
        data.append('False')
        threshold, alpha, beta, econ = data[0], data[1], data[2], data[3]
        lm = PowerRegularizer(
            'roberta-base',
            'models/power_roberta-base_2e-05_32_20.ckpt',
            float(threshold),
            power_category=(7 if 'True' in econ else 0)
        )
        alpha = float(alpha.replace('alpha', ''))
        beta = float(beta.replace('beta', ''))
        # NOTE(security): torch.load unpickles arbitrary objects; only load
        # checkpoints that come from a trusted source.
        agent = torch.load(args.load_agent)
        agent.lm = lm
        agent.alpha = alpha
        agent.beta = beta
        agent.configure = True
    elif agent_type in ('RandomAgent', 'random_agent'):
        from machiavelli.agent.random_agent import RandomAgent
        agent = RandomAgent()
    elif agent_type in ('HumanAgent', 'human_agent'):
        from machiavelli.agent.human_agent import HumanAgent
        assert args.playthrough_id is not None, \
            "HumanAgent requires --playthrough_id"
        actions = cfg.get_human_demo(args.playthrough_id)
        assert actions['game'] == args.game, \
            "the human demo was recorded for a different game than --game"
        agent = HumanAgent(actions['actions'])
    elif agent_type.startswith('LMAgent'):
        print("Loading LMAgent", agent_type)
        agent = LMAgent(model=agent_type.split(':')[-1])
    elif agent_type.startswith('GoodAgent'):
        print("Loading GoodAgent", agent_type)
        agent = GoodAgent(model=agent_type.split(':')[-1])
    elif agent_type.startswith('CoTAgent'):
        print("Loading CoTAgent", agent_type)
        agent = CoTAgent(model=agent_type.split(':')[-1])
    elif agent_type.startswith('GoodCoTAgent'):
        print("Loading GoodCoTAgent", agent_type)
        agent = GoodCoTAgent(model=agent_type.split(':')[-1])
    else:
        raise NotImplementedError(f"Unknown agent type: {agent_type!r}")
    return agent


def make_agent(args):
    """Build the agent selected by ``args``.

    Returns the agent named by ``args.agent``. When ``args.agent2`` is also
    given, both agents are built and wrapped in an ``InterpolatedAgent``
    together with ``args.interp_alpha``.
    """
    primary = make_single_agent(args.agent, args)
    if args.agent2 is None:
        return primary

    secondary = make_single_agent(args.agent2, args)
    # Imported lazily: only needed when a second agent is requested.
    from machiavelli.agent.interpolated_agent import InterpolatedAgent
    return InterpolatedAgent(primary, secondary, args.interp_alpha)


def get_args():
    """Parse the command-line arguments of the evaluation script.

    ``--agent`` / ``--agent2`` are free-form strings rather than a fixed
    ``choices`` list: agent names may carry a model suffix (for example
    ``LMAgent:<model>``), which a closed list cannot enumerate.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    agent_help = (
        "Agent type, e.g. 'rl_agent', 'RandomAgent', 'HumanAgent', or "
        "'LMAgent:<model>' / 'GoodAgent:<model>' / 'CoTAgent:<model>' / "
        "'GoodCoTAgent:<model>'."
    )

    parser = argparse.ArgumentParser()
    parser.add_argument('--game', type=str, default='')
    parser.add_argument('--agent', required=True, type=str, help=agent_help)
    parser.add_argument('--agent2', required=False, type=str, default=None,
                        help="Optional second agent to interpolate with. "
                             + agent_help)
    parser.add_argument('--interp_alpha', type=float, default=None)
    parser.add_argument('--num_episodes', type=int, default=25)
    parser.add_argument('--baseline_data', type=str, default=None)
    parser.add_argument('--lm_name', type=str, default='roberta-base')
    parser.add_argument('--alpha', type=float, default=0.0)
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--regularizer_lm_name',
                        type=str, default='roberta-base')
    parser.add_argument('--logit_threshold', type=float, default=0.5)

    parser.add_argument('--achievements', nargs='+', default=None)
    parser.add_argument('--load_agent', type=str, default=None)
    parser.add_argument('--playthrough_id', type=str, default=None)
    parser.add_argument('--run_from_folder', type=str, default=None)
    parser.add_argument('--name', type=str, default=None)
    return parser.parse_args()


def _summarize_trajectories(trajs, game, agent_label):
    """Return one results record holding the mean of each metric over ``trajs``."""
    def mean_of(metric):
        return np.mean([metric(t) for t in trajs])

    return {
        'game': game,
        'agent': agent_label,
        'timestamp': datetime.now().strftime("%Y%m%d-%H%M%S"),

        'steps': mean_of(len),
        'achievements': mean_of(lambda t: len(t.get_achievements())),
        'points': mean_of(lambda t: t.get_points()),
        'norm_points': mean_of(lambda t: t.get_points(normalize=True)),

        'deception': mean_of(lambda t: t.get_deception()),
        'selfishness': mean_of(lambda t: t.get_selfishness()),
        'resources': mean_of(lambda t: t.get_power('resources')),
        'money': mean_of(lambda t: t.get_power('money')),
        'influence': mean_of(lambda t: t.get_power('social')),
        'utility': mean_of(lambda t: t.get_power('utility')),
        'watts': mean_of(lambda t: t.get_power('watts')),
    }


def _evaluate_and_record(agent, game, args, run_name, agent_label):
    """Run ``agent`` on ``game`` and write its one-row summary CSV."""
    env = MachiavelliEnv(
        game=game,
        desired_achievements=args.achievements,
    )
    trajs = run_trajectories(
        agent,
        env,
        outdir=cfg.agent_trajectories_dir / run_name / game,
        num_episodes=args.num_episodes,
    )

    df = pd.DataFrame.from_records(
        [_summarize_trajectories(trajs, game, agent_label)])
    os.makedirs(cfg.agent_results_dir, exist_ok=True)
    # NOTE: the file name depends only on the game and --agent, so a later
    # run with the same pair overwrites the earlier CSV.
    df.to_csv(cfg.agent_results_dir / f'{game}_{args.agent}_eval.csv')


def main():
    """Evaluate the requested agent(s) and write per-game summary CSVs.

    With ``--run_from_folder``, every file in that folder is loaded as an
    agent checkpoint and evaluated in turn. Otherwise the agent is evaluated
    on ``--game``, or on every game in ``cfg.games`` when ``--game`` is
    ``ALL``. Trajectories are saved under
    ``cfg.agent_trajectories_dir / <name> / <game>``.
    """
    args = get_args()
    # --name defaults to None, which cannot be joined into a path; fall back
    # to the agent type so the run still has an output directory.
    run_name = args.name if args.name is not None else args.agent

    if args.run_from_folder:
        # Sorted so checkpoints are processed in a reproducible order.
        for agent_path in sorted(os.listdir(args.run_from_folder)):
            print(agent_path)
            args.load_agent = os.path.join(args.run_from_folder, agent_path)
            # For 'rl_agent', building the agent overwrites args.game with
            # the game parsed from the checkpoint file name.
            agent = make_agent(args)
            _evaluate_and_record(agent, args.game, args, run_name, str(agent))
        return

    games = [args.game] if args.game != 'ALL' else cfg.games
    for game in games:
        agent = make_agent(args)
        _evaluate_and_record(agent, game, args, run_name, run_name)


# Run the evaluation only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()
