import copy
import os
import sys
import platform
import threading
# import multiprocessing
import multiprocess as mp

from matplotlib import pyplot as plt
import matplotlib
import argparse
import torch
import numpy as np
from tqdm import tqdm
import datetime
import pickle as pkl

# Relative roots put on sys.path so the project-local imports further down
# can be resolved (presumably when the script is run from its own directory).
root_path2 = '..'
root_path3 = '../train_agent'
sys.path.extend([root_path2, root_path3])

# Default output locations, relative to the current working directory.
TASK_PATH = './data/tasks'
FIG_PATH = './data/figs'
VID_PATH = './data/vids'
# NOTE: this name is rebound by `from result_logger import ..., RESULT_PATH`
# later in the module, so the imported value is the one that takes effect.
RESULT_PATH = './data/results'

from train_agent.benchmarking1 import mean_and_std_err, get_base_ae, ApagAgentNewVersion
from overcooked_ai_py.agents.benchmarking import AgentEvaluator
from overcooked_ai_py.agents.agent import AgentPair
from result_logger import AttackResultLogger, RESULT_PATH


def load_saved_apag_agent(ppo_model_dir, horizon=None, deterministic=True, idx=0, device='cpu', sample_prob=1.0):
    """Load a saved APAG policy from disk and wrap it as an evaluation agent.

    Reads ``policy.pt`` and ``config.pt`` from ``ppo_model_dir``. It removes
    ``rew_shaping_params`` from the stored mdp params and optionally overrides
    the environment horizon. It then builds an ``ApagAgentNewVersion`` around
    the policy, using the lossless state encoding of a freshly built evaluator
    as the featurizer.

    Args:
        ppo_model_dir: Directory containing ``policy.pt`` and ``config.pt``.
        horizon: If not None, overrides ``config['env_config']['env_params']['horizon']``.
        deterministic: Forwarded to ``ApagAgentNewVersion``.
        idx: Forwarded to ``ApagAgentNewVersion`` as ``agent_index``.
        device: Torch device both files are mapped onto when loading; also
            forwarded to the agent.
        sample_prob: Forwarded to ``ApagAgentNewVersion``.

    Returns:
        ``(agent, config, ae)``. ``agent`` is the wrapped agent. ``config`` is
        the loaded config dict, with the edits above applied in place. ``ae``
        is the evaluator built from the config's env params.
    """
    map_location = torch.device(device)
    policy = torch.load(os.path.join(ppo_model_dir, 'policy.pt'), map_location=map_location)
    # Map the config onto the same device as the policy as well, so a config
    # saved on a GPU machine (if it holds any tensors) still loads on a
    # CPU-only host.
    config = torch.load(os.path.join(ppo_model_dir, 'config.pt'), map_location=map_location)

    eval_params = config['env_config']
    # Strip the stored reward-shaping params before building the evaluator;
    # pop with a default is a no-op when the key is absent.
    eval_params['mdp_params'].pop('rew_shaping_params', None)
    if horizon is not None:
        # Important: the horizon also affects the state encoder, so it must be
        # applied before the evaluator below is built.
        eval_params['env_params'].update({'horizon': horizon})

    ae = get_base_ae(eval_params["mdp_params"], eval_params["env_params"], None, None)
    featurize_fn = ae.env.lossless_state_encoding_mdp
    agent = ApagAgentNewVersion(policy, featurize_fn, deterministic=deterministic, agent_index=idx,
                                sample_prob=sample_prob, device=device)
    return agent, config, ae


def eval_agent_on_mdp(agent_0, agent_1, mdp, env_params, n_games, start_state=None,
                      save_fig=False, fig_fn='', mkvid=False, vid_fn='', use_batch=False, batch_max=10):
    """Evaluate an agent pair on ``mdp`` and summarise the episode returns.

    Args:
        agent_0, agent_1: The two agents. They are paired with
            ``allow_duplicate_agents=True``, so the same object may be passed
            twice.
        mdp: MDP passed to ``AgentEvaluator.from_mdp``.
        env_params: Environment params passed to ``AgentEvaluator.from_mdp``.
        n_games: Total number of games to roll out.
        start_state: If not None, every game starts from this state.
        save_fig, fig_fn, mkvid, vid_fn: Accepted for call compatibility;
            they are not used by this function.
        use_batch: If True, roll the games out in batches of at most
            ``batch_max`` games each instead of in one call.
        batch_max: Maximum number of games per batch when ``use_batch``.

    Returns:
        ``(mean, se, rollouts)``. ``mean`` and ``se`` come from
        ``mean_and_std_err`` over all episode returns. With ``use_batch``,
        ``rollouts`` holds only the LAST batch.

    Raises:
        ValueError: If ``use_batch`` is set and ``n_games`` or ``batch_max``
            is less than 1.
    """
    ae = AgentEvaluator.from_mdp(mdp, env_params)
    start_state_fn = (lambda: start_state) if start_state is not None else None

    if not use_batch:
        rollouts = ae.evaluate_agent_pair(AgentPair(agent_0, agent_1, allow_duplicate_agents=True),
                                          num_games=n_games, start_state_fn=start_state_fn)
        ep_returns = rollouts["ep_returns"]
    else:
        if n_games < 1:
            raise ValueError(f"n_games must be >= 1 in batch mode, got {n_games}")
        if batch_max < 1:
            raise ValueError(f"batch_max must be >= 1, got {batch_max}")
        n_full, remainder = divmod(n_games, batch_max)
        batch_sizes = [batch_max] * n_full + ([remainder] if remainder else [])

        batch_returns = []
        for n in batch_sizes:
            rollouts = ae.evaluate_agent_pair(AgentPair(agent_0, agent_1, allow_duplicate_agents=True),
                                              num_games=n, start_state_fn=start_state_fn)
            batch_returns.append(rollouts["ep_returns"].copy())
        # Concatenate, not vstack. The last batch may be shorter (the
        # remainder), which vstack rejects. Stacking equal-size batches would
        # also produce a 2-D array instead of one flat sample of all games.
        ep_returns = np.concatenate(batch_returns)

    mean, se = mean_and_std_err(ep_returns)
    return mean, se, rollouts


def eval_agent_single_thread(a_ev: "AgentEvaluator", a_pair, n_ga, ss, result_l, rand_seed):
    """Roll out ``n_ga`` single games with a private copy of ``a_pair``.

    Used as the worker function of ``eval_agent_multi_process``. It seeds
    NumPy's and torch's global RNGs with ``rand_seed`` and sets
    ``a_ev.env.mdp.start_state = ss``. It then plays ``n_ga`` games one at a
    time (``num_games=1, info=False``). The collected returns are both
    appended to ``result_l`` and returned.

    Args:
        a_ev: Evaluator to roll out with; its mdp's ``start_state`` is overwritten.
        a_pair: Agent pair; a deep copy is used, so the passed object is untouched.
        n_ga: Number of games to play (may be 0).
        ss: Value assigned to ``a_ev.env.mdp.start_state``.
        result_l: List-like that receives the returns array via ``append``.
        rand_seed: Seed for ``np.random.seed`` and ``torch.manual_seed``.

    Returns:
        numpy array of the concatenated per-game ``ep_returns``. It is empty
        when ``n_ga`` is 0.
    """
    np.random.seed(rand_seed)
    torch.manual_seed(rand_seed)
    pair_copy = copy.deepcopy(a_pair)
    a_ev.env.mdp.start_state = ss

    per_game = [a_ev.evaluate_agent_pair(pair_copy, num_games=1, info=False)["ep_returns"]
                for _ in range(n_ga)]
    # np.concatenate rejects an empty list. A worker assigned zero games
    # (n_games < n_procs) must still report an (empty) result instead of dying.
    ep_returns = np.concatenate(per_game) if per_game else np.array([])

    result_l.append(ep_returns)
    return ep_returns


def eval_agent_multi_process(ae: AgentEvaluator, agent_pair, n_games, n_procs, start_state=None):
    """Evaluate ``agent_pair`` for ``n_games`` games spread over worker processes.

    Games are split as evenly as possible across at most ``n_procs`` workers.
    No worker is spawned without at least one game. Each worker gets a fresh
    evaluator built from ``ae.env.mdp`` with the same horizon. It also gets a
    seed drawn from NumPy's global RNG. Each worker runs
    ``eval_agent_single_thread``.

    Args:
        ae: Evaluator whose mdp and horizon are reused for every worker.
        agent_pair: Agent pair sent to each worker.
        n_games: Total number of games (>= 1).
        n_procs: Maximum number of worker processes (>= 1).
        start_state: Start state forwarded to every worker.

    Returns:
        ``(mean, se, None)``, with statistics from ``mean_and_std_err`` over
        all episode returns. No rollouts are kept, so the third slot is
        always None.

    Raises:
        ValueError: If ``n_games`` or ``n_procs`` is less than 1.
        Any exception raised inside a worker is re-raised here.
    """
    if n_games < 1:
        raise ValueError(f"n_games must be >= 1, got {n_games}")
    if n_procs < 1:
        raise ValueError(f"n_procs must be >= 1, got {n_procs}")
    # Workers that would receive zero games are pointless; cap the pool size.
    n_procs = min(n_procs, n_games)
    base, extra = divmod(n_games, n_procs)
    n_games_per_subprocess = [base + (1 if i < extra else 0) for i in range(n_procs)]

    env_params = {"horizon": ae.env.horizon}
    with mp.Pool(processes=n_procs) as pool:
        async_results = []
        for n_sub in n_games_per_subprocess:
            new_ae = AgentEvaluator.from_mdp(ae.env.mdp, env_params)
            seed = np.random.randint(0, 10 ** 7)
            # Results come back through the AsyncResult handles, so a plain
            # throwaway list satisfies the worker's ``result_l`` parameter and
            # no Manager process is needed.
            t_args = (new_ae, agent_pair, n_sub, start_state, [], seed)
            async_results.append(pool.apply_async(func=eval_agent_single_thread, args=t_args))
        pool.close()
        # .get() re-raises worker exceptions instead of silently dropping them.
        per_worker_returns = [res.get() for res in async_results]
        pool.join()

    all_results = np.concatenate(per_worker_returns)
    mean, se = mean_and_std_err(all_results)
    return mean, se, None


def batch_eval_start_state(agent1, agent2, mdp, start_states, is_gt: list, horizon=800, n_games=5, exp_name=None, visualize=False,
                           mkvid=False, log_fn=None, result_fn=None, info=''):
    """Evaluate an agent pair from each start state and log the results.

    ``eval_agent_on_mdp`` is run for ``n_games`` games from every entry of
    ``start_states``. The entry flagged True in ``is_gt`` is reported as the
    ground-truth line. The means of all other entries are summarised as
    mean / std / min / max.

    Args:
        agent1, agent2: Agents forwarded to ``eval_agent_on_mdp``.
        mdp: MDP forwarded to ``eval_agent_on_mdp``.
        start_states: Sequence of start states to evaluate.
        is_gt: One flag per start state; True marks the ground-truth run.
        horizon: Episode horizon.
        n_games: Games per start state.
        exp_name: Experiment label, or a list with one label per start state.
        visualize, mkvid: Forwarded to ``eval_agent_on_mdp``.
        log_fn: If given, the per-start-state lines are appended to this file.
        result_fn: If given, a timestamped summary is appended to this file.
        info: Free text written after the timestamp in the summary.
    """
    env_params = {
        "horizon": horizon
    }
    if not isinstance(exp_name, list):
        exp_name = [exp_name] * len(start_states)
    assert len(is_gt) == len(start_states)

    exp_info = []
    scores = []
    gt_info = None
    for i, st_st in enumerate(start_states):
        # An f-string (not '_'.join) so the default exp_name=None does not
        # raise TypeError.
        run_name = f"{exp_name[i]}_{i}"
        mean, se, _ = eval_agent_on_mdp(agent1, agent2, mdp, env_params, n_games, st_st, visualize, run_name, mkvid, run_name)
        line = f"{exp_name[i]} {i}: {mean} +- {se}"
        print(line)
        exp_info.append(line)
        if is_gt[i]:
            gt_info = f"gt: {mean} +- {se} \n"
        else:
            scores.append(mean)

    if log_fn is not None:
        with open(log_fn, 'a+') as f:
            f.write('\n'.join(exp_info) + '\n')

    if result_fn is not None:
        meta_info = datetime.datetime.now().strftime("%y-%m-%d-%H:%M:%S") + '\t' + info + '\n'
        with open(result_fn, 'a+') as f:
            f.write(meta_info)
            if gt_info is not None:
                f.write(gt_info)
            if scores:
                f.write(f"scores={np.mean(scores)} +- {np.std(scores)}, min = {np.min(scores)}, max = {np.max(scores)}" + '\n')
            else:
                # np.min/np.max raise on an empty list (every entry was gt).
                f.write("scores=n/a (no non-gt start states)\n")


def batch_eval_start_state_v2(exp_info, agent1, agent2, mdp, start_states, horizon=800, n_games=5, exp_name=None, visualize=False,
                              mkvid=False, result_fn=None, exp_id=''):
    """Evaluate an agent pair from each start state and store pickled records.

    For every start state, ``eval_agent_on_mdp`` is run (in batched mode) for
    ``n_games`` games. Each record is
    ``{'mean': ..., 'se': ..., 'start_state': ...}`` and is added to an
    ``AttackResultLogger`` under ``(layout, agent_name, metadata)``.
    ``metadata`` is ``(horizon, epi, partner_name, top_k, n_games, exp_id)``.

    Args:
        exp_info: Mapping that must contain 'layout', 'agent_name',
            'partner_name', 'epi' and 'top_k'.
        agent1, agent2, mdp: Forwarded to ``eval_agent_on_mdp``.
        start_states: Sequence of start states to evaluate.
        horizon: Episode horizon.
        n_games: Games per start state.
        exp_name: Label, or a list with one label per start state (printing only).
        visualize, mkvid: Forwarded to ``eval_agent_on_mdp``.
        result_fn: Result file name for the logger; required.
        exp_id: Last element of the metadata key.

    Returns:
        The ``AttackResultLogger`` holding the added records.
    """
    env_params = {
        "horizon": horizon
    }
    if not isinstance(exp_name, list):
        exp_name = [exp_name] * len(start_states)

    assert result_fn is not None, "result_fn can not be None"

    logger = AttackResultLogger(result_fn, file_format='pkl')
    layout, agent_name, partner_name, epi, top_k = \
        exp_info['layout'], exp_info['agent_name'], exp_info['partner_name'], exp_info['epi'], exp_info['top_k']

    metadata = (horizon, epi, partner_name, top_k, n_games, exp_id)
    for i, st_st in enumerate(start_states):
        # An f-string (not '_'.join) so the default exp_name=None does not
        # raise TypeError.
        run_name = f"{exp_name[i]}_{i}"
        mean, se, _ = eval_agent_on_mdp(agent1, agent2, mdp, env_params, n_games, st_st, visualize, run_name, mkvid,
                                        run_name, use_batch=True)
        print(f"{exp_name[i]} {i}: {mean} +- {se}")
        record = {
            'mean': mean,
            'se': se,
            'start_state': st_st
        }
        logger.add_record(layout, agent_name, metadata, record)

    return logger


def batch_eval_start_state_v2_mp(exp_info, agent1, agent2, ae, start_states, horizon=800, n_games=5, n_procs=20,
                                 result_dir=RESULT_PATH, result_fn=None, exp_id='', no_lock=False, overwrite=False):
    """Multi-process variant of ``batch_eval_start_state_v2``.

    For each start state, ``eval_agent_multi_process`` is run with
    ``n_games`` games over ``n_procs`` workers. A record
    ``{'mean', 'se', 'start_state'}`` is added to an ``AttackResultLogger``
    under ``(layout, agent_name, metadata)``. ``metadata`` is
    ``(horizon, epi, partner_name, top_k, n_games, exp_id)``.

    Args:
        exp_info: Mapping with keys 'layout', 'agent_name', 'partner_name',
            'epi' and 'top_k'.
        agent1, agent2: Agents paired with ``allow_duplicate_agents=True``.
        ae: Evaluator whose mdp/horizon the workers reuse.
        start_states: Sequence of start states to evaluate.
        result_dir, result_fn, no_lock: Forwarded to ``AttackResultLogger``;
            ``result_fn`` is required.
        overwrite: If True, ``del_records`` is called for this key before any
            new record is added.

    Returns:
        The ``AttackResultLogger`` holding the records.
    """
    assert result_fn is not None, "result_fn can not be None"
    pair = AgentPair(agent1, agent2, allow_duplicate_agents=True)
    res_logger = AttackResultLogger(result_fn, log_dir=result_dir, file_format='pkl', no_lock=no_lock)

    wanted = ('layout', 'agent_name', 'partner_name', 'epi', 'top_k')
    layout, agent_name, partner_name, epi, top_k = [exp_info[key] for key in wanted]
    metadata = (horizon, epi, partner_name, top_k, n_games, exp_id)

    if overwrite:
        # Presumably clears records already stored under this key — confirm
        # against AttackResultLogger.
        res_logger.del_records(layout, agent_name, metadata)

    for state in start_states:
        mean, se, _ = eval_agent_multi_process(ae, pair, n_games, n_procs, start_state=state)
        res_logger.add_record(layout, agent_name, metadata, {'mean': mean, 'se': se, 'start_state': state})

    return res_logger


def eval_gt(agent_name, layout, horizon=800, eval_n=5):
    """Self-play evaluation of a stored agent, located by name and layout.

    This is a thin wrapper around ``eval_gt_v2``. It resolves the agent
    directory as ``../predictors/assets/{layout}/agents/{agent_name}``. It
    then runs ``eval_n`` self-play games of ``horizon`` steps and appends a
    one-line summary to ``gt_result.txt``. Keeping one implementation stops
    the two functions from drifting apart.

    Returns:
        None.
    """
    agent_path = f'../predictors/assets/{layout}/agents/{agent_name}'
    return eval_gt_v2(agent_name, agent_path, layout, horizon=horizon, eval_n=eval_n)


def eval_gt_v2(agent_name, agent_path, layout, horizon=800, eval_n=5):
    """Self-play evaluation of the agent stored at ``agent_path``.

    The agent is loaded with ``deterministic=False`` and paired with itself
    on ``layout``. It is evaluated for ``eval_n`` games of ``horizon`` steps.
    The ``mean_and_std_err`` result is appended as one line to
    ``gt_result.txt`` in the current working directory.

    Args:
        agent_name: Name used in the printed/written summary.
        agent_path: Directory passed to ``load_saved_apag_agent``.
        layout: Layout name used to build the evaluation mdp.
        horizon: Episode horizon.
        eval_n: Number of games.

    Returns:
        None.
    """
    result_fn = 'gt_result.txt'

    print(agent_name, layout)

    agent, _, _ = load_saved_apag_agent(agent_path, horizon, deterministic=False)

    mdp_params = {
        "layout_name": layout,
    }
    env_params = {
        "horizon": horizon
    }
    ae = get_base_ae(mdp_params, env_params, None, None)
    ap = AgentPair(agent, agent, allow_duplicate_agents=True)
    rollouts = ae.evaluate_agent_pair(ap, num_games=eval_n)
    mean, se = mean_and_std_err(rollouts["ep_returns"])

    # Explicit UTF-8: the line contains '±', which the platform's default
    # encoding may not be able to represent.
    with open(result_fn, 'a+', encoding='utf-8') as f:
        f.write(f"{layout}: {agent_name} + {agent_name} = {mean} ± {se} over {eval_n} in {horizon} steps.\n")
    return


def load_agent_api(agent_name, layout, horizon=800):
    """Load a stored agent by name and layout.

    The agent directory is ``../predictors/assets/{layout}/agents/{agent_name}``.
    It is passed to ``load_saved_apag_agent`` with ``deterministic=False``.

    Returns:
        ``(agent, config, ae)`` exactly as produced by ``load_saved_apag_agent``.
    """
    model_dir = f'../predictors/assets/{layout}/agents/{agent_name}'
    loaded_agent, loaded_config, evaluator = load_saved_apag_agent(model_dir, horizon, deterministic=False)
    return loaded_agent, loaded_config, evaluator


def batch_cross_play(main_agent_pools, partner_pools, layout, horizon=800):
    """Placeholder for cross-play evaluation between agent pools.

    Not implemented yet: it ignores its arguments and returns None.
    """
    return None


def record_result():
    """Placeholder for result recording; not implemented, returns None."""
    return None
