import os
import sys
import argparse
import torch
import numpy as np
from r_utils import env_rob_utils

from models import random_var
from env_rob_benchmarking import batch_eval_start_state, batch_eval_start_state_v2_mp, eval_agent_multi_process, \
    load_saved_apag_agent, get_base_ae
import processor
import processor_v2
from overcooked_ai_py.agents.agent import AgentPair


def eval_random(agent_name, layout, top_k=10, horizon=800, n_games=5, epi=2, exp_name='random', log_fn='random_log.txt'):
    """Evaluate an agent in self-play from the default start plus randomly generated start states.

    The agent is loaded from ``../predictors/assets/<layout>/agents/<agent_name>``.
    The first evaluated start state is ``None`` and is flagged as the "gt" run;
    the remaining ones come from ``random_var.random_gen_start_state_fn_v0``.

    Args:
        agent_name: Agent directory name under the layout's ``agents`` folder.
        layout: Layout name, used for the environment and the asset path.
        top_k: Forwarded to the random start-state generator.
        horizon: Episode horizon used for agent loading, the environment and evaluation.
        n_games: Forwarded to ``batch_eval_start_state``.
        epi: Forwarded to the random start-state generator.
        exp_name: Prefix of the per-start-state experiment names.
        log_fn: Log file name forwarded to ``batch_eval_start_state``.
    """
    print(agent_name, layout)
    agent, _, _ = env_rob_utils.load_saved_apag_agent(
        f'../predictors/assets/{layout}/agents/{agent_name}', horizon, deterministic=False)

    ae = env_rob_utils.get_base_ae({"layout_name": layout}, {"horizon": horizon}, None, None)
    mdp = ae.env.mdp

    # Slot 0 is the None start state; everything after it is randomly generated.
    start_states = [None] + random_var.random_gen_start_state_fn_v0(mdp, epi=epi, top_k=top_k)
    n_states = len(start_states)

    base_name = exp_name + '_' + layout
    exp_names = [base_name + 'gt'] + [base_name] * (n_states - 1)
    is_gt = [idx == 0 for idx in range(n_states)]

    exp_info = f"layout={layout}, agent={agent_name}, epi = {epi}, top = {top_k}"
    batch_eval_start_state(agent, agent, mdp, start_states, is_gt, horizon, n_games, exp_names,
                           visualize=True, mkvid=False, log_fn=log_fn,
                           result_fn='random_result.txt', info=exp_info)


def attack_agent(args, agent_name, layout, top_k=10, horizon=800, eval_n=5, attack_n=10, epi=2, attack_method='stay_targeted',
                 debug=False, exp_id=''):
    """Generate attack start states for an agent and evaluate it in self-play on them.

    The agent is loaded from ``<agent_path_root>/<layout>/<agent_dir>/best`` when
    that path exists, otherwise from ``<agent_path_root>/<layout>/<agent_dir>``,
    where ``agent_dir`` is ``<agent_name>_<agent_appendix>`` (or just
    ``agent_name`` when the appendix is empty). A ``processor.Processor`` is
    prepared with ``n_games=attack_n`` and asked for start states via
    ``get_attack``; those are then passed to ``batch_eval_start_state_v2_mp``.

    Args:
        args: Parsed CLI namespace. Reads ``agent_path_root``, ``agent_appendix``,
            ``n_process``, ``result_root``, ``overwrite`` and ``no_lock``.
        agent_name: Name of the agent; it is evaluated paired with itself.
        layout: Layout name.
        top_k: Forwarded to ``Processor.get_attack``.
        horizon: Episode horizon for agent loading, data collection and evaluation.
        eval_n: Forwarded to ``batch_eval_start_state_v2_mp`` (presumably games per
            start state - confirm against that function).
        attack_n: Value written to ``data_params['n_games']`` for the processor.
        epi: Forwarded to ``Processor.get_attack``.
        attack_method: Attack name given to ``get_attack``; also prefixes the result file name.
        debug: Forwarded to ``Processor.get_attack``.
        exp_id: Forwarded to ``batch_eval_start_state_v2_mp``.
    """
    agent_path_root = args.agent_path_root
    agent_appendix = args.agent_appendix
    n_procs = args.n_process
    result_root = args.result_root
    overwrite = args.overwrite
    result_fn = f'{attack_method}_{layout}_result.pkl'

    agent_dir = f"{agent_name}_{agent_appendix}" if agent_appendix != '' else agent_name
    agent_path = os.path.join(agent_path_root, layout, agent_dir, 'best')
    if not os.path.exists(agent_path):
        # No 'best' checkpoint directory: fall back to the agent directory itself.
        print('not best')
        agent_path = os.path.join(agent_path_root, layout, agent_dir)

    agent, _, _ = env_rob_utils.load_saved_apag_agent(agent_path, horizon, deterministic=False)

    mdp_params = {
        "layout_name": layout,
    }
    env_params = {
        "horizon": horizon
    }
    ae = env_rob_utils.get_base_ae(mdp_params, env_params, None, None)
    meta_data = {}
    data_params = processor.DEFAULT_DATA_PARAMS.copy()
    data_params['n_games'] = attack_n
    # .copy() is shallow, so 'env_params' is still the dict owned by the module-level
    # default. Build a fresh nested dict instead of assigning into it, otherwise the
    # horizon would leak into DEFAULT_DATA_PARAMS for every later user.
    data_params['env_params'] = dict(data_params['env_params'], horizon=horizon)

    pr = processor.Processor(agent, ae.env, meta_data=meta_data, data_params=data_params)
    pr.prepare_data()
    pr.process_data()
    start_states = pr.get_attack(attack_method, epi=epi, top_k=top_k, debug=debug)

    exp_info = {
        'layout': layout,
        'agent_name': agent_name,
        'partner_name': agent_name,
        'epi': epi,
        'top_k': top_k
    }
    batch_eval_start_state_v2_mp(exp_info, agent, agent, ae, start_states, horizon, eval_n, n_procs=n_procs,
                                 result_dir=result_root, result_fn=result_fn, exp_id=exp_id,
                                 no_lock=args.no_lock, overwrite=overwrite)
    return


def sanity_check(args, agent_name, layout, top_k=10, horizon=800, attack_n=10, epi=2):
    """Print the start states from ``no_target`` and ``no_target_efficient`` side by side.

    Loads the agent (preferring the ``best`` sub-directory when it exists), prepares
    a ``processor_v2.Processor``, requests start states from both attack methods
    with ``debug=True``, and prints each pair's objects sorted by position so the
    two methods can be compared by eye. Nothing is returned or written to disk.

    Args:
        args: Parsed CLI namespace. Reads ``agent_path_root`` and ``agent_appendix``.
        agent_name: Name of the agent to load.
        layout: Layout name.
        top_k: Forwarded to ``Processor.get_attack``.
        horizon: Episode horizon for agent loading and data collection.
        attack_n: Value written to ``data_params['n_games']`` for the processor.
        epi: Forwarded to ``Processor.get_attack``.
    """
    agent_path_root = args.agent_path_root
    agent_appendix = args.agent_appendix
    attack_method = 'no_target'
    debug = True

    agent_dir = f"{agent_name}_{agent_appendix}" if agent_appendix != '' else agent_name
    agent_path = os.path.join(agent_path_root, layout, agent_dir, 'best')
    if not os.path.exists(agent_path):
        # No 'best' checkpoint directory: fall back to the agent directory itself.
        print('not best')
        agent_path = os.path.join(agent_path_root, layout, agent_dir)

    agent, _, _ = env_rob_utils.load_saved_apag_agent(agent_path, horizon, deterministic=False)

    mdp_params = {
        "layout_name": layout,
    }
    env_params = {
        "horizon": horizon
    }
    ae = env_rob_utils.get_base_ae(mdp_params, env_params, None, None)
    meta_data = {}
    # NOTE(review): defaults are taken from `processor` while the Processor class below
    # comes from `processor_v2` - confirm this mix is intentional.
    data_params = processor.DEFAULT_DATA_PARAMS.copy()
    data_params['n_games'] = attack_n
    # .copy() is shallow, so 'env_params' is still the dict owned by the module-level
    # default. Build a fresh nested dict instead of assigning into it, otherwise the
    # horizon would leak into DEFAULT_DATA_PARAMS for every later user.
    data_params['env_params'] = dict(data_params['env_params'], horizon=horizon)

    pr = processor_v2.Processor(agent, ae.env, meta_data=meta_data, data_params=data_params)
    pr.prepare_data()
    pr.process_data()
    start_states_1 = pr.get_attack(attack_method, epi=epi, top_k=top_k, debug=debug)
    start_states_2 = pr.get_attack('no_target_efficient', epi=epi, top_k=top_k, debug=debug)

    def by_position(obj):
        return obj.position

    # zip stops at the shorter sequence; any surplus states of the longer one are not shown.
    for st1, st2 in zip(start_states_1, start_states_2):
        print(sorted(st1.objects.values(), key=by_position), "\t",
              sorted(st2.objects.values(), key=by_position))

    return


def eval_gt_main(args, agent_name, layout, horizon=800, eval_n=5):
    """Evaluate an agent in self-play from the default start state and append the result to a text file.

    The agent is loaded from ``<agent_path_root>/<layout>/<agent_dir>/best`` when
    that path exists, otherwise from ``<agent_path_root>/<layout>/<agent_dir>``.
    It is paired with itself, evaluated through ``eval_agent_multi_process`` with
    ``start_state=None``, and one summary line is appended to
    ``<result_root>/gt_<layout>_result.txt``.

    Args:
        args: Parsed CLI namespace. Reads ``agent_path_root``, ``agent_appendix``,
            ``n_process`` and ``result_root``.
        agent_name: Name of the agent to load.
        layout: Layout name.
        horizon: Episode horizon for agent loading and the environment.
        eval_n: Forwarded to ``eval_agent_multi_process`` (presumably the number of
            evaluation games - confirm against that function).
    """
    agent_path_root = args.agent_path_root
    agent_appendix = args.agent_appendix
    n_procs = args.n_process
    result_root = args.result_root

    result_fn = f'gt_{layout}_result.txt'
    result_fp = os.path.join(result_root, result_fn)

    print(agent_name, layout)
    agent_dir = f"{agent_name}_{agent_appendix}" if agent_appendix != '' else agent_name
    agent_path = os.path.join(agent_path_root, layout, agent_dir, 'best')
    if not os.path.exists(agent_path):
        # No 'best' checkpoint directory: fall back to the agent directory itself.
        print('not best')
        agent_path = os.path.join(agent_path_root, layout, agent_dir)
    agent, _, _ = load_saved_apag_agent(agent_path, horizon, deterministic=False)

    mdp_params = {
        "layout_name": layout,
    }
    env_params = {
        "horizon": horizon
    }
    ae = get_base_ae(mdp_params, env_params, None, None)
    agent_pair = AgentPair(agent, agent, allow_duplicate_agents=True)

    mean, se, _ = eval_agent_multi_process(ae, agent_pair, eval_n, n_procs, start_state=None)

    # The line contains a non-ASCII '±'. Pin the encoding so the write cannot fail with
    # UnicodeEncodeError under an ASCII/C locale after the whole evaluation has already run.
    with open(result_fp, 'a', encoding='utf-8') as f:
        f.write(f"{layout}: {agent_name}_{agent_appendix} sp = {mean} ± {se} over {eval_n} in {horizon} steps.\n")
    return


if __name__ == '__main__':
    # Command-line entry point: run one experiment per (agent, layout) pair.
    parser = argparse.ArgumentParser()
    parser.add_argument('--attack_method', default='stay_targeted')
    parser.add_argument('-a', '--agent_names', default=[], nargs='+')
    parser.add_argument('--agent_path_root', type=str, default='.')
    parser.add_argument('--agent_appendix', type=str, default='')
    parser.add_argument('-l', '--layout_names', default=[], nargs='+')
    parser.add_argument('--horizon', type=int, default=800)
    parser.add_argument('--eval_n', type=int, default=10)
    parser.add_argument('--attack_n', type=int, default=10)
    parser.add_argument('-p', '--n_process', type=int, default=25)
    parser.add_argument('--epi', type=int, default=3)
    parser.add_argument('-k', type=int, default=10)
    parser.add_argument('--exp_id', default='')
    parser.add_argument('--result_root', type=str, default='.')
    # Boolean switches; store_true already defaults to False.
    parser.add_argument('--no_lock', action='store_true')
    parser.add_argument('--overwrite', action='store_true')
    parser.add_argument('--sanity_check', action='store_true')
    args = parser.parse_args()

    torch.set_num_threads(1)

    # Use the single default layout when none was given on the command line.
    layouts = args.layout_names or ["coordination_ring"]

    for agent_name in args.agent_names:
        for layout in layouts:
            # 'gt' takes precedence over --sanity_check; everything else is an attack run.
            if args.attack_method == 'gt':
                eval_gt_main(args, agent_name, layout, args.horizon, args.eval_n)
                continue
            if args.sanity_check:
                sanity_check(args, agent_name, layout, args.k, args.horizon, args.attack_n, args.epi)
                continue
            attack_agent(args, agent_name, layout, args.k, args.horizon, args.eval_n, args.attack_n,
                         args.epi, attack_method=args.attack_method, debug=False, exp_id=args.exp_id)
