import pickle as pkl
import argparse
import random
import shutil

from r_utils import env_rob_utils
from train_agent.benchmarking1 import get_dummy_ae
import subprocess
import os
import json

from result_logger import AttackResultLogger, RESULT_PATH

# Attack method identifiers. NOTE(review): this list is not referenced in the
# visible part of this script -- confirm whether other modules rely on it.
attack_methods = [
    'no_target',
    'random',
    'random+',
]

# Built-in agent names; their concatenation is the agent set used when
# -a/--agent_names is not given on the command line.
sp_agent_names = [f'sp_1204-{i}' for i in range(5)]
fcp_agent_names = [f'fcp_1204-{i}' for i in range(5)]

# Layouts processed when -l/--layouts is not given on the command line.
layouts = [
    "coordination_ring",
]


# Command-line interface. Arguments are parsed at import time and the values
# are exposed as module-level names that main() reads.
parser = argparse.ArgumentParser()
parser.add_argument('-a', '--agent_names', nargs='+', default=[],
                    help='agents to process (default: the built-in sp/fcp agent names)')
parser.add_argument('-l', '--layouts', nargs='+', default=[],
                    help='layouts to process (default: the built-in layout list)')
# NOTE(review): --dst, --t_path, --mode and --attack_n are parsed but not
# referenced in the visible part of this script.
parser.add_argument('-d', '--dst', default='predictor')
parser.add_argument('-t', '--t_path', default='sp')
parser.add_argument('-m', '--mode', default='copy')
parser.add_argument('--method', default='no_target',
                    help='single method to use when --methods is not given')
parser.add_argument('--methods', default=[], nargs='+',
                    help='methods whose result files are loaded')
parser.add_argument('--horizon', type=int, default=800,
                    help='horizon component of the result lookup key')
parser.add_argument('--eval_n', type=int, default=10,
                    help='eval_n component of the result lookup key')
parser.add_argument('--attack_n', type=int, default=10)
# type=int keeps this option consistent with --ks and rejects non-numeric
# values at parse time rather than later during sampling.
parser.add_argument('--send_ns', type=int, default=[10], nargs='+',
                    help='number of start states to sample, one value per method')
parser.add_argument('--epi', type=int, default=3,
                    help='epi component of the result lookup key')
parser.add_argument('-k', type=int, default=10)
parser.add_argument('--ks', type=int, default=[], nargs='+',
                    help='k component of the result lookup key, one value per method')
parser.add_argument('--exp_ids', default=[], nargs='+',
                    help='experiment id component of the result lookup key, one value per method')
parser.add_argument('-f', '--f_exp_name', default='12_31_t0',
                    help='file name (without .pkl) of the written start-state pickle')

args = parser.parse_args()

# Module-level aliases for the parsed options.
method = args.method
horizon = args.horizon
eval_n = args.eval_n
attack_n = args.attack_n
epi = args.epi
top_k = args.k
exp_ids = args.exp_ids
f_exp_name = args.f_exp_name
send_ns = args.send_ns

# Command-line selections, when present, take precedence over the defaults.
if args.layouts:
    layouts = args.layouts
agents = args.agent_names if args.agent_names else sp_agent_names + fcp_agent_names

print(agents, layouts)

# Root directory under which try_send_states builds
# <layout>/<agent_name>/start_states.
base_dst_dir = '../train_agent/saved_models/sp'


def try_send_states(agent_name, layout, states, state_fn='start_state.pkl', base_dir=None):
    """Pickle ``states`` into an agent's ``start_states`` directory without overwriting.

    The destination is ``<base_dir>/<layout>/<agent_name>/start_states/<state_fn>``.
    Missing directories are created. If the destination file already exists,
    a message is printed and the existing file is left untouched.

    Args:
        agent_name: Agent sub-directory name.
        layout: Layout sub-directory name.
        states: Object to serialize with pickle.
        state_fn: File name of the pickle inside ``start_states``.
        base_dir: Root directory; defaults to the module-level ``base_dst_dir``.

    Returns:
        None.
    """
    root = base_dst_dir if base_dir is None else base_dir
    dst_dir = os.path.join(root, layout, agent_name, 'start_states')
    os.makedirs(dst_dir, exist_ok=True)
    dst_fn = os.path.join(dst_dir, state_fn)

    try:
        # Exclusive-create mode fails if the file exists, so the existence
        # check and the creation are a single step.
        out_file = open(dst_fn, 'xb')
    except FileExistsError:
        print(f"{state_fn} already exists")
        return

    # The context manager guarantees the handle is flushed and closed.
    with out_file:
        pkl.dump(states, out_file)


def _load_method_results(method_name):
    """Load the pickled results for one method.

    Reads ``<RESULT_PATH>/<method_name>_result.pkl`` when that file exists.
    Otherwise the per-layout files ``<method_name>_<layout>_result.pkl`` are
    loaded for every entry of the module-level ``layouts`` and merged into
    one dict with ``dict.update``.

    NOTE: pickle can execute arbitrary code on load; the result files are
    assumed to be trusted local artifacts.
    """
    combined_path = os.path.join(RESULT_PATH, f'{method_name}_result.pkl')
    if os.path.exists(combined_path):
        with open(combined_path, 'rb') as fh:
            return pkl.load(fh)

    merged = {}
    for layout in layouts:
        with open(f'{RESULT_PATH}/{method_name}_{layout}_result.pkl', 'rb') as fh:
            merged.update(pkl.load(fh))
    return merged


def main():
    """Sample start states from stored results and write them per agent.

    For each method (``--methods``, or the single ``--method`` when the list
    is empty) the result pickle is loaded. Then, for every layout and agent,
    the records stored under
    ``data[layout][agent_name][(horizon, epi, agent_name, k, eval_n, exp_id)]``
    are read, ``send_n`` of their ``'start_state'`` entries are drawn with
    ``random.sample``, and the states collected over all methods are handed
    to ``try_send_states`` as ``<f_exp_name>.pkl``.

    ``k``, ``exp_id`` and ``send_n`` are taken per method from ``--ks``,
    ``--exp_ids`` and ``--send_ns``. When ``--ks`` is omitted the single
    ``-k`` value is used for every method.

    Raises:
        ValueError: If ``--ks``, ``--exp_ids`` or ``--send_ns`` supplies fewer
            values than there are methods, or (from ``random.sample``) if
            more states are requested than a record holds.
        KeyError: If a layout, agent or metadata key is missing from the
            loaded results.
    """
    methods = args.methods or [method]

    # Without --ks, fall back to the otherwise unused -k value so that a
    # plain single-method invocation does not die on an empty list.
    ks = args.ks or [top_k] * len(methods)

    # Fail early with a clear message instead of an IndexError mid-run.
    for option, values in (('--ks', ks), ('--exp_ids', exp_ids), ('--send_ns', send_ns)):
        if len(values) < len(methods):
            raise ValueError(
                f"{option} needs one value per method: "
                f"got {len(values)} for {len(methods)} method(s)"
            )

    all_data = [_load_method_results(m) for m in methods]

    for layout in layouts:
        for agent_name in agents:
            print(layout, agent_name)
            # The partner component of the lookup key is the agent itself.
            partner_name = agent_name

            send_states = []
            for data, k, exp_id, send_n in zip(all_data, ks, exp_ids, send_ns):
                metadata = (horizon, epi, partner_name, k, eval_n, exp_id)
                rec = data[layout][agent_name][metadata]
                states = [r['start_state'] for r in rec]
                # Random subset without replacement, not the first send_n states.
                send_states += random.sample(states, int(send_n))
            try_send_states(agent_name, layout, send_states, state_fn=f"{f_exp_name}.pkl")


if __name__ == '__main__':
    # get_dummy_ae() runs before main() and its return value is discarded.
    # NOTE(review): presumably called for a setup side effect -- confirm.
    get_dummy_ae()
    main()

