import numpy as np
import functools
import math
import torch
import pickle
import os

# Locations of the pickled datasets, one file per training algorithm
# (MAT, HATRPO, HAPPO), as indicated by the directory names.
mat_data_path = "/home/LAB/qiuyue/language_condition_MAT/scripts/data/smac_data/2c_vs_64zg_mat_2k/StarCraft2.pkl"
hatrpo_data_path = "/home/LAB/qiuyue/TRPO-in-MARL-master/scripts/data/2c_vs_64zg-hatrpo-2k/StarCraft2.pkl"
happo_data_path = "/home/LAB/qiuyue/TRPO-in-MARL-master/scripts/data/2c_vs_64zg-happo-2k/StarCraft2.pkl"


def _read_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# NOTE(review): pickle.load executes arbitrary code from the file; these
# paths are presumably trusted local artifacts -- confirm before reusing
# this script on data from elsewhere.
mat_data_set = _read_pickle(mat_data_path)
hatrpo_data_set = _read_pickle(hatrpo_data_path)
happo_data_set = _read_pickle(happo_data_path)

# print('----------------- mat data --------------------')
# for key in mat_data_set[0]:
#     print(key, mat_data_set[0][key][0].shape if isinstance(mat_data_set[0][key], (list, tuple)) else mat_data_set[0][key])
#
# print('----------------- hatrpo data --------------------')
# for key in hatrpo_data_set[0]:
#     print(key, hatrpo_data_set[0][key][0].shape if isinstance(hatrpo_data_set[0][key], (list, tuple)) else hatrpo_data_set[0][key])
#
# print('----------------- happo data --------------------')
# for key in happo_data_set[0]:
#     print(key, happo_data_set[0][key][0].shape if isinstance(happo_data_set[0][key], (list, tuple)) else happo_data_set[0][key])

# Merge the three per-algorithm datasets into a single list of episodes.
# Every episode dict gets an 'agent_tags' list holding the integer id of the
# algorithm that produced it, repeated once per entry of episode['state'].
# Tag ids: 0 = MAT, 1 = HATRPO, 2 = HAPPO.
mix_algo_data_set, step_num = [], 0
for algo_tag, algo_data_set in (
        (0, mat_data_set),
        (1, hatrpo_data_set),
        (2, happo_data_set),
):
    for episode_data in algo_data_set:
        episode_len = len(episode_data['state'])
        # The episode dict is tagged in place, so the loaded source dataset
        # sees the new key as well.
        episode_data['agent_tags'] = [algo_tag] * episode_len
        mix_algo_data_set.append(episode_data)
        # step_num accumulates the total length of all merged episodes.
        step_num += episode_len

# print(mix_algo_data_set[0]['agent_tags'])
# print(mix_algo_data_set[12]['agent_tags'])
# print(mix_algo_data_set[22]['agent_tags'])

# Persist the merged dataset and report its size.
mix_algo_data_dir = "/home/LAB/qiuyue/language_condition_MAT/scripts/data/smac_data/2c_vs_64zg_2k_mixAlgo"
# exist_ok=True creates the directory only when missing, without the
# check-then-create race of a separate isdir() test.
os.makedirs(mix_algo_data_dir, exist_ok=True)
with open(os.path.join(mix_algo_data_dir, 'StarCraft2.pkl'), 'wb') as f:
    pickle.dump(mix_algo_data_set, f)
print(len(mix_algo_data_set))  # number of merged episodes
print(step_num)  # summed length of episode['state'] over all episodes
