# This module is meant to be imported by other scripts rather than executed directly.
import os
import subprocess
import sys
root_path = '..'
root_path2 = '../third_party_libs'
sys.path.append(root_path)
sys.path.append(root_path2)

import platform
import torch
import numpy as np
import pickle
import argparse
import json
from overcooked_ai_py.agents.agent import Action, AgentPair, StayAgent, RandomAgent
from overcooked_ai_py.agents.benchmarking import AgentEvaluator


# Root directory of the on-disk assets; helpers below join it with a layout name.
DATA_PATH = 'assets'
# Agent type names; only 'apag' is loaded from disk by auto_load_agent.
ALLOWED_AGENT_TYPES = ['stay', 'random', 'apag']
# Built-in agents that need no saved model; load_data uses a single empty
# condition ({}) for partners named here.
DUMMY_AGENT_NAME = ['stay', 'random']
# Sub-directory (under DATA_PATH/<layout>) holding pickled raw rollouts.
RAW_TRAJ_DIR = 'raw_trajs'
# Sub-directories holding processed trajectories for data version '1' and
# for any other version string, respectively.
PROC_TRAJ_DIR = 'proc_trajs'
PROC_TRAJ_DIR_V2 = 'proc_trajs_v2'

# Conditions used by load_data when no config file is supplied. Each dict is
# turned into a file-name suffix via dict_to_str.
default_conditions = [
    {'deterministic': True},
    # {'deterministic': False, 'sample_prob': 0.5}
    {'deterministic': False}
]
from train_agent.benchmarking1 import ApagAgentNewVersion, get_base_ae, mean_and_std_err


def load_saved_apag_agent(ppo_model_dir, horizon=None, deterministic=True, idx=0, device='cpu', sample_prob=1.0):
    """Load a saved APAG policy and wrap it in an ``ApagAgentNewVersion``.

    Args:
        ppo_model_dir: Directory containing ``policy.pt`` and ``config.pt``.
        horizon: If not ``None``, overrides ``env_params['horizon']`` of the
            saved config before the evaluator is built.
        deterministic: Forwarded to the agent constructor.
        idx: Agent index forwarded to the agent constructor.
        device: Device name used to map loaded tensors and forwarded to the agent.
        sample_prob: Forwarded to the agent constructor.

    Returns:
        Tuple ``(agent, config, ae)`` where ``config`` is the (possibly
        modified) saved config and ``ae`` is the evaluator built from it.
    """
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    target_device = torch.device(device)
    policy = torch.load(os.path.join(ppo_model_dir, 'policy.pt'), map_location=target_device)
    # Map the config the same way as the policy so that a checkpoint saved on
    # a GPU machine can still be loaded on a CPU-only host.
    config = torch.load(os.path.join(ppo_model_dir, 'config.pt'), map_location=target_device)

    eval_params = config['env_config']
    # Drop reward-shaping parameters (if any) before building the evaluator.
    eval_params['mdp_params'].pop('rew_shaping_params', None)
    if horizon is not None:
        # important, since horizon does affect state encoder
        eval_params['env_params']['horizon'] = horizon

    ae = get_base_ae(eval_params["mdp_params"], eval_params["env_params"], None, None)
    featurize_fn = ae.env.lossless_state_encoding_mdp
    agent = ApagAgentNewVersion(
        policy,
        featurize_fn,
        deterministic=deterministic,
        agent_index=idx,
        sample_prob=sample_prob,
        device=device,
    )
    return agent, config, ae


def auto_load_agent(layout, agent_name, is_main=False, verbose=False, **agent_params):
    """Load an agent by name for the given layout.

    ``'stay'`` and ``'random'`` return the corresponding built-in agent with
    ``None`` for config and evaluator. Any other name is looked up under
    ``DATA_PATH/<layout>/agents/<agent_name>`` and dispatched on the
    ``agent_type`` field of its ``agent_info.json``.

    Args:
        layout: Layout name (sub-directory of ``DATA_PATH``).
        agent_name: ``'stay'``, ``'random'`` or the name of a saved agent directory.
        is_main: Unused here; kept for interface compatibility.
        verbose: If true, print the resolved agent directory.
        **agent_params: Forwarded to ``load_saved_apag_agent``.

    Returns:
        Tuple ``(agent, config, ae)``.

    Raises:
        AssertionError: If the agent directory or its info file is missing.
        NotImplementedError: If the saved agent type is not ``'apag'``.
    """
    if agent_name == 'stay':
        return StayAgent(), None, None

    if agent_name == 'random':
        return RandomAgent(), None, None

    agent_path = os.path.join(DATA_PATH, layout, 'agents', agent_name)
    info_path = os.path.join(agent_path, 'agent_info.json')
    if verbose:
        print(agent_path)
    assert os.path.exists(agent_path), f'agent directory not found: {agent_path}'
    assert os.path.exists(info_path), f'agent info file not found: {info_path}'

    # Use a context manager so the info file is closed deterministically.
    with open(info_path, 'r', encoding='utf-8') as info_file:
        agent_info = json.load(info_file)

    if agent_info['agent_type'] != 'apag':
        raise NotImplementedError

    agent, config, ae = load_saved_apag_agent(agent_path, **agent_params)
    agent.name = agent_name
    return agent, config, ae


def eval_saved_agent_pair(agent_0, agent_1, layout, horizon=400, num_games=1, mdp_params=None, keep_info=False):
    """Roll out a pair of agents and summarise their episode returns.

    Args:
        agent_0, agent_1: The two agents to pair up.
        layout: Layout name, used only when ``mdp_params`` is ``None``.
        horizon: Episode horizon passed to the evaluator.
        num_games: Number of games to play.
        mdp_params: Optional MDP parameters; defaults to ``{"layout_name": layout}``.
        keep_info: When false, the ``'ep_infos'`` entry is removed from the rollouts.

    Returns:
        Tuple ``(mean, se, rollouts)`` where ``mean`` and ``se`` come from
        ``mean_and_std_err`` applied to ``rollouts["ep_returns"]``.
    """
    params = {"layout_name": layout} if mdp_params is None else mdp_params
    evaluator = get_base_ae(params, {"horizon": horizon}, None, None)

    pair = AgentPair(agent_0, agent_1, allow_duplicate_agents=True)
    rollouts = evaluator.evaluate_agent_pair(pair, num_games=num_games)

    if not keep_info:
        del rollouts['ep_infos']

    mean, se = mean_and_std_err(rollouts["ep_returns"])
    return mean, se, rollouts


def save_raw_traj(traj, main_agent_name, partner_name, main_agent_pos, layout, horizon, additional_info):
    """Pickle a raw trajectory under ``DATA_PATH/<layout>/<RAW_TRAJ_DIR>``.

    The file name is the underscore-joined
    ``main_agent_name, partner_name, main_agent_pos, horizon, additional_info``
    with a ``.pkl`` suffix.

    Returns:
        The path of the written file.

    Raises:
        AssertionError: If the raw-trajectory directory does not exist.
    """
    raw_trajs_dir_path = os.path.join(DATA_PATH, layout, RAW_TRAJ_DIR)
    traj_fn = '_'.join([main_agent_name, partner_name, str(main_agent_pos), str(horizon), str(additional_info)]) + '.pkl'
    traj_fp = os.path.join(raw_trajs_dir_path, traj_fn)
    assert os.path.exists(raw_trajs_dir_path)
    # Close the file before returning so the data is flushed to disk by the
    # time the caller receives the path.
    with open(traj_fp, 'wb') as f:
        pickle.dump(traj, f)
    return traj_fp


def load_raw_traj(main_agent_name, partner_name, main_agent_pos, layout, horizon, additional_info):
    """Load a raw trajectory previously pickled under ``DATA_PATH/<layout>/<RAW_TRAJ_DIR>``.

    The file name is built from the arguments in the same way as in
    ``save_raw_traj``.

    Returns:
        The unpickled trajectory object.

    Raises:
        AssertionError: If the raw-trajectory directory does not exist.
    """
    raw_trajs_dir_path = os.path.join(DATA_PATH, layout, RAW_TRAJ_DIR)
    traj_fn = '_'.join([main_agent_name, partner_name, str(main_agent_pos), str(horizon), str(additional_info)]) + '.pkl'
    traj_fp = os.path.join(raw_trajs_dir_path, traj_fn)
    assert os.path.exists(raw_trajs_dir_path)
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open(traj_fp, 'rb') as f:
        return pickle.load(f)


def is_exist_raw_traj(main_agent_name, partner_name, main_agent_pos, layout, horizon, additional_info):
    """Return whether the raw trajectory file for the given run exists on disk."""
    name_parts = [main_agent_name, partner_name, str(main_agent_pos), str(horizon), str(additional_info)]
    file_name = '_'.join(name_parts) + '.pkl'
    return os.path.exists(os.path.join(DATA_PATH, layout, RAW_TRAJ_DIR, file_name))


def is_exist_raw_traj_2(main_agent_name, partner_name, layout, horizon, additional_info):
    """Return whether raw trajectory files exist for both main-agent positions (0 and 1)."""
    directory = os.path.join(DATA_PATH, layout, RAW_TRAJ_DIR)
    suffix = [str(horizon), str(additional_info)]
    return all(
        os.path.exists(
            os.path.join(directory, '_'.join([main_agent_name, partner_name, pos] + suffix) + '.pkl')
        )
        for pos in ('0', '1')
    )


def is_exist_proced_traj_v2(main_agent_name, partner_name, layout, horizon, additional_info, version='2'):
    """Return whether processed trajectory files exist for both main-agent positions.

    ``version == '1'`` looks in ``PROC_TRAJ_DIR``; any other value looks in
    ``PROC_TRAJ_DIR_V2``.
    """
    if version == '1':
        sub_dir = PROC_TRAJ_DIR
    else:
        sub_dir = PROC_TRAJ_DIR_V2
    directory = os.path.join(DATA_PATH, layout, sub_dir)
    suffix = [str(horizon), str(additional_info)]
    return all(
        os.path.exists(
            os.path.join(directory, '_'.join([main_agent_name, partner_name, pos] + suffix) + '.pkl')
        )
        for pos in ('0', '1')
    )


def process_traj(main_agent: ApagAgentNewVersion, main_agent_pos, traj):
    """Turn raw rollouts into (network features, action-index targets).

    Every state of every episode is featurized for the main agent and passed
    through ``main_agent.actor_critic.base``; the second output of the network
    is collected as the per-step feature.

    Args:
        main_agent: Agent whose policy network and featurizer are used.
        main_agent_pos: Index of the main agent in the featurizer output and
            in each joint action.
        traj: Mapping with ``'ep_states'`` and ``'ep_actions'``, each indexed
            by episode and then by time step.

    Returns:
        ``[obss, targetss]`` where ``obss`` holds one vertically stacked
        feature array per episode and ``targetss`` holds, per episode, the
        list of main-agent action indices.
    """
    policy = main_agent.actor_critic
    net = policy.base
    featurize_fn = main_agent.featurize
    net.eval()
    states = traj['ep_states']
    actions = traj['ep_actions']

    obss = []
    # Inference only: the outputs are detached right away, so skip building
    # autograd graphs altogether.
    with torch.no_grad():
        for ts in states:
            obs = []
            # now assert net is not recurrent
            rnn_hxs = torch.zeros(1, policy.recurrent_hidden_state_size, dtype=torch.int32)
            masks = None
            for t in ts:
                raw_o = torch.tensor(featurize_fn(t)[main_agent_pos], dtype=torch.float).unsqueeze(0)
                _, o, rnn_hxs = net(raw_o, rnn_hxs, masks)
                obs.append(o.detach().cpu().numpy())
            obss.append(np.vstack(obs))

    targetss = [
        [Action.ACTION_TO_INDEX[act[main_agent_pos]] for act in acts]
        for acts in actions
    ]

    return [obss, targetss]


def to_one_hot(v, max_len):
    """Return a length-``max_len`` float vector of zeros with position ``v`` set to 1."""
    one_hot = np.zeros(max_len)
    one_hot[v] = 1
    return one_hot


def process_traj_v2(main_agent: ApagAgentNewVersion, main_agent_pos, traj, ae: AgentEvaluator):
    """Turn raw rollouts into version-2 features, targets and scores.

    For each time step the feature row is the concatenation (last axis) of:
    the second output of ``main_agent.actor_critic.base`` for the main agent's
    encoded state, ``ae.env.featurize_state_mdp`` for the main agent, and the
    one-hot encodings of every action in the joint action of that step.

    Side effects: sets the global torch thread count to 10 and moves the
    policy base network to the CPU.

    Args:
        main_agent: Agent whose policy network and featurizer are used.
        main_agent_pos: Index of the main agent in featurizer outputs and in
            each joint action.
        traj: Mapping with ``'ep_states'``, ``'ep_actions'`` and ``'ep_returns'``.
        ae: Evaluator providing ``env.featurize_state_mdp``.

    Returns:
        ``[[obss, targetss], score]`` where ``obss`` holds one stacked feature
        array per episode, ``targetss`` the main-agent action indices per
        episode, and ``score`` is ``traj['ep_returns']``.
    """
    policy = main_agent.actor_critic
    net = policy.base
    featurize_fn = main_agent.featurize
    state_featurize_fn = ae.env.featurize_state_mdp
    net.eval()
    torch.set_num_threads(10)
    net = net.to('cpu')

    states = traj['ep_states']
    actions = traj['ep_actions']
    score = traj['ep_returns']

    obss = []
    # Inference only: the outputs are detached right away, so skip building
    # autograd graphs altogether.
    with torch.no_grad():
        for i, ts in enumerate(states):
            obs = []
            # now assert net is not recurrent
            rnn_hxs = torch.zeros(1, policy.recurrent_hidden_state_size, dtype=torch.int32)
            masks = None
            for j, t in enumerate(ts):
                raw_o = torch.tensor(featurize_fn(t)[main_agent_pos], dtype=torch.float).unsqueeze(0)
                _, o, rnn_hxs = net(raw_o, rnn_hxs, masks)
                feat = o.detach().cpu().numpy()
                st_feat = state_featurize_fn(t)[main_agent_pos]  # keep dims
                # One-hot of every action in the joint action taken at step j.
                act_feat = np.concatenate(
                    [to_one_hot(Action.ACTION_TO_INDEX[act], Action.NUM_ACTIONS) for act in actions[i][j]]
                )
                st_feat, act_feat = np.expand_dims(st_feat, axis=0), np.expand_dims(act_feat, axis=0)
                obs.append(np.concatenate([feat, st_feat, act_feat], axis=-1))
            obss.append(np.vstack(obs))

    targetss = [
        [Action.ACTION_TO_INDEX[act[main_agent_pos]] for act in acts]
        for acts in actions
    ]

    return [[obss, targetss], score]


def save_processed_traj(traj, main_agent_name, partner_name, main_agent_pos, layout, horizon, additional_info, version='1'):
    """Pickle a processed trajectory under the layout's processed-trajectory directory.

    ``version == '1'`` writes to ``PROC_TRAJ_DIR``; any other value writes to
    ``PROC_TRAJ_DIR_V2``. The directory is created if it does not exist.

    Returns:
        The path of the written file.
    """
    proc_traj_dir = PROC_TRAJ_DIR if version == '1' else PROC_TRAJ_DIR_V2
    proc_trajs_dir_path = os.path.join(DATA_PATH, layout, proc_traj_dir)
    traj_fn = '_'.join([main_agent_name, partner_name, str(main_agent_pos), str(horizon), str(additional_info)]) + '.pkl'
    traj_fp = os.path.join(proc_trajs_dir_path, traj_fn)
    # exist_ok avoids the check-then-create race and also creates missing parents.
    os.makedirs(proc_trajs_dir_path, exist_ok=True)
    # Close the file before returning so the data is flushed to disk.
    with open(traj_fp, 'wb') as f:
        pickle.dump(traj, f)
    return traj_fp


def load_processed_traj(main_agent_name, partner_name, main_agent_pos, layout, horizon, additional_info, version='1'):
    """Load a processed trajectory pickled by ``save_processed_traj``.

    ``version == '1'`` reads from ``PROC_TRAJ_DIR``; any other value reads
    from ``PROC_TRAJ_DIR_V2``.

    Returns:
        The unpickled object.

    Raises:
        AssertionError: If the processed-trajectory directory does not exist.
    """
    proc_traj_dir = PROC_TRAJ_DIR if version == '1' else PROC_TRAJ_DIR_V2
    proc_trajs_dir_path = os.path.join(DATA_PATH, layout, proc_traj_dir)
    traj_fn = '_'.join([main_agent_name, partner_name, str(main_agent_pos), str(horizon), str(additional_info)]) + '.pkl'
    traj_fp = os.path.join(proc_trajs_dir_path, traj_fn)
    assert os.path.exists(proc_trajs_dir_path)
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open(traj_fp, 'rb') as f:
        return pickle.load(f)


def dict_to_str(d:dict):
    """Render a dict as ``key=value`` pairs joined by ``-``, in iteration order."""
    pairs = []
    for key, value in d.items():
        pairs.append(str(key) + '=' + str(value))
    return '-'.join(pairs)


def get_predictor_path(layout, main_agent_name, model_name, predictor_name, auto_mkdir=True):
    """Return the predictor directory for the given layout/agent/model/predictor.

    The directory is
    ``DATA_PATH/<layout>/predictors/<main_agent_name>_<model_name>_<predictor_name>``.

    Args:
        auto_mkdir: When true, create the directory (and any missing parents)
            if it does not exist yet.

    Returns:
        The directory path.
    """
    predictor_dir = os.path.join(DATA_PATH, layout, 'predictors', '_'.join([main_agent_name, model_name, predictor_name]))
    if auto_mkdir:
        # makedirs also creates a missing 'predictors' parent and tolerates a
        # concurrent creation of the same directory.
        os.makedirs(predictor_dir, exist_ok=True)
    return predictor_dir


def cut_data(data, proportion):
    """Keep the leading ``proportion`` of the first sequence of every entry in ``data``.

    data = [inputs, targets]; inputs = List[] 1 * horizon * obs. Only element
    0 of each entry is used; the truncated sequence is re-wrapped in a
    one-element list.
    """
    trimmed = []
    for entry in data:
        first_seq = entry[0]
        keep = int(len(first_seq) * proportion)
        trimmed.append([first_seq[:keep]])
    return trimmed


def proc_data(data, seq_len, pred_len, step, should_action_mask, action_feat_len=12):
    """Slice episodes into fixed-length input windows and following target windows.

    For every episode and every window start ``i`` in
    ``range(0, eps_len - seq_len - pred_len, step)`` the input is
    ``inputs[i:i + seq_len]`` and the target is
    ``targets[i + seq_len:i + seq_len + pred_len]`` (the action at the last
    input step is not part of the targets).

    Args:
        data: ``[inputs, targets]``. ``inputs`` is either
            ``(episodes, steps, features)`` or a single ``(steps, features)``
            episode; ``targets`` has the matching leading dimensions.
        seq_len: Length of each input window.
        pred_len: Length of each target window.
        step: Stride between consecutive window starts.
        should_action_mask: When true, zero the trailing ``action_feat_len``
            features of the last step of every input window.
        action_feat_len: Number of trailing features to zero when masking.
            Defaults to 12, the previously hard-coded width.

    Returns:
        Tuple ``(seq_inputs, seq_lens, seq_targets)``: float64 inputs, an array
        filled with ``seq_len`` (one entry per window) and int32 targets.
    """
    inputs, targets = data
    inputs, targets = np.array(inputs), np.array(targets)
    eps_len = inputs.shape[-2]

    # Promote a single episode to a batch of one.
    if len(inputs.shape) == 2:
        inputs, targets = np.expand_dims(inputs, 0), np.expand_dims(targets, 0)

    mask_actions = should_action_mask and action_feat_len > 0
    seq_inputs = []
    seq_targets = []
    for j, inps in enumerate(inputs):
        for i in range(0, eps_len - seq_len - pred_len, step):
            # Copy so that masking never writes back into the source array.
            inp = inps[i:i + seq_len].copy()
            if mask_actions:
                inp[-1][-action_feat_len:] = 0
            seq_inputs.append(inp)
            seq_targets.append(targets[j][i + seq_len: i + seq_len + pred_len])

    # np.float was only an alias of the builtin float (float64) and was removed
    # in NumPy 1.24; use the explicit dtype.
    seq_inputs = np.array(seq_inputs).astype(np.float64)
    seq_targets = np.array(seq_targets).astype(np.int32)
    seq_lens = np.array([seq_len] * seq_inputs.shape[0])
    return seq_inputs, seq_lens, seq_targets


def _concat_seq_data(data_dicts):
    """Concatenate the ``(inputs, lens, targets)`` triples of a dict along axis 0."""
    triples = list(data_dicts.values())
    return [np.concatenate([triple[k] for triple in triples]) for k in range(3)]


def load_data(args: argparse.Namespace, data_params, verbose=True, detail=False, score_threshold=60):
    """Load processed trajectories and build windowed training/validation data.

    Args:
        args: Namespace providing ``layout_name``, ``model``, ``agent_name``,
            ``partner`` (comma-separated), ``horizon``, ``config_file_path``
            and ``proportion`` (comma-separated floats or ``None``).
            ``data_version`` (default ``'1'``) and ``val_partner``
            (comma-separated) are optional.
        data_params: Mapping with ``seq_len``, ``pred_len`` and ``step``.
        verbose: Unused here; kept for interface compatibility.
        detail: When true, return per-source dicts instead of concatenated arrays.
        score_threshold: A training trajectory is kept only if the sum of its
            scores exceeds this value. Validation trajectories are not filtered.

    Returns:
        If ``detail`` is false: ``(train, val)`` where each is
        ``[inputs, lens, targets]`` and ``val`` is ``None`` when there is no
        validation data. If ``detail`` is true: two dicts keyed by
        ``str([partner_name, cond, position])`` holding the per-source triples.

    Raises:
        NotImplementedError: If ``args.model`` is not ``'base_lstm_seq'``.
    """
    layout = args.layout_name
    model_type = args.model
    agent_name = args.agent_name
    partner_names = args.partner.split(',')
    horizon = args.horizon
    data_version = getattr(args, 'data_version', '1')
    val_partners = args.val_partner.split(',') if hasattr(args, 'val_partner') else []
    should_action_mask = data_version == '2'

    if args.config_file_path is not None and os.path.exists(args.config_file_path):
        # Context manager so the config file handle is not leaked.
        with open(args.config_file_path, 'r', encoding='utf-8') as config_file:
            conditions = json.load(config_file)
    else:
        conditions = default_conditions

    if args.proportion is not None:
        data_proportion = [float(p) for p in args.proportion.split(',')]
        assert len(data_proportion) == len(partner_names)
    else:
        data_proportion = [1.] * len(partner_names)

    # NOTE(review): the ``data, score`` unpacking below matches the
    # ``[[obss, targetss], score]`` layout returned by process_traj_v2;
    # process_traj returns ``[obss, targetss]`` instead -- confirm that
    # version-'1' files are really meant to be loaded here.
    all_data = []
    data_info = []
    for i, partner_name in enumerate(partner_names):
        # Dummy partners have a single, unconditioned trajectory per position.
        conds = [{}] if partner_name in DUMMY_AGENT_NAME else conditions
        for cond in conds:
            for pos in (0, 1):
                data, score = load_processed_traj(agent_name, partner_name, pos, layout, horizon, dict_to_str(cond), version=data_version)
                if sum(score) > score_threshold:
                    all_data.append(cut_data(data, data_proportion[i]))
                    data_info.append([partner_name, cond, pos])

    val_data = []
    val_data_info = []
    for partner_name in val_partners:
        for cond in conditions:
            # Only non-deterministic conditions are used for validation.
            if cond['deterministic'] is True:
                continue
            for pos in (0, 1):
                data, score = load_processed_traj(agent_name, partner_name, pos, layout, horizon, dict_to_str(cond), version=data_version)
                val_data.append(data)
                val_data_info.append([partner_name, cond, pos])

    if model_type != 'base_lstm_seq':
        raise NotImplementedError

    seq_len = data_params['seq_len']
    pred_len = data_params['pred_len']
    step = data_params['step']

    test_data_dicts = {}
    for info, data in zip(data_info, all_data):
        test_data_dicts[str(info)] = proc_data(data, seq_len, pred_len, step, should_action_mask)

    val_data_dicts = {}
    for info, data in zip(val_data_info, val_data):
        val_data_dicts[str(info)] = proc_data(data, seq_len, pred_len, step, should_action_mask)

    if detail:
        return test_data_dicts, val_data_dicts

    test_data = _concat_seq_data(test_data_dicts)
    val_data = _concat_seq_data(val_data_dicts) if len(val_data_dicts) > 0 else None
    return test_data, val_data


def load_info_file(layout, file_name, default_value=None):
    """Load a JSON or pickle info file from ``DATA_PATH/<layout>``.

    Args:
        layout: Layout name (sub-directory of ``DATA_PATH``).
        file_name: File name relative to the layout directory.
        default_value: Returned when the file does not exist.

    Returns:
        The parsed content, ``default_value`` if the file is missing, or
        ``None`` if the path contains none of ``.json``, ``.pkl``, ``.pickle``.
    """
    fn = os.path.join(DATA_PATH, layout, file_name)
    if not os.path.exists(fn):
        return default_value

    if '.json' in fn:
        with open(fn, 'r', encoding='utf-8') as f:
            return json.load(f)

    if '.pkl' in fn or '.pickle' in fn:
        # NOTE: pickle.load executes arbitrary code; only load trusted files.
        with open(fn, 'rb') as f:
            return pickle.load(f)

    # Unrecognised extension: nothing is loaded.
    return None


def save_info_file(layout, file_name, data):
    """Write ``data`` as JSON or pickle to ``DATA_PATH/<layout>/<file_name>``.

    The format is chosen by whether the path contains ``.json`` or
    ``.pkl``/``.pickle``; for any other path nothing is written.

    Returns:
        The target file path (for both formats).
    """
    fn = os.path.join(DATA_PATH, layout, file_name)
    if '.json' in fn:
        with open(fn, 'w', encoding='utf-8') as f:
            json.dump(data, f)

    elif '.pkl' in fn or '.pickle' in fn:
        # Previously this branch returned the result of pickle.dump (None);
        # return the path like the JSON branch does.
        with open(fn, 'wb') as f:
            pickle.dump(data, f)

    return fn


def copy_agent(layout, src_name, dst_name, override=False, verbose=True):
    """Copy a saved agent directory to a new name within the same layout.

    Runs ``cp -r`` (``cp -rf`` when ``override`` is true) from
    ``DATA_PATH/<layout>/agents/<src_name>`` to ``.../agents/<dst_name>``.

    Args:
        layout: Layout name (sub-directory of ``DATA_PATH``).
        src_name: Name of the existing agent directory.
        dst_name: Name of the destination agent directory.
        override: Allow copying when the destination already exists.
        verbose: If true, print the command before running it.

    Returns:
        ``"Success"`` if the copy command exits with status 0, otherwise
        ``"error"`` (source missing, destination exists without ``override``,
        or the copy command failed).
    """
    base_path = os.path.join(DATA_PATH, layout)
    ori_dir = os.path.join(base_path, 'agents', src_name)
    dst_dir = os.path.join(base_path, 'agents', dst_name)

    if not os.path.exists(ori_dir) or (os.path.exists(dst_dir) and not override):
        return "error"

    # Pass an argument list instead of a shell string: paths with spaces or
    # shell metacharacters are handled correctly and, with no shell involved,
    # there are no aliases to bypass.
    cmd = ['cp', '-rf' if override else '-r', ori_dir, dst_dir]
    if verbose:
        print(' '.join(cmd))
    try:
        return_code = subprocess.call(cmd)
    except OSError:
        # e.g. the ``cp`` executable is not available on this platform.
        return "error"
    return "Success" if return_code == 0 else "error"

