import collections
import os
import pickle

import numpy as np
from tqdm import tqdm

from Environment.env import Environment
# Collect an offline dataset of random-policy rollouts and pickle it.
#
# For each episode a sequence of T uniformly random action indices is drawn,
# stepped through the environment, and the resulting transitions are stored
# as a dict of stacked arrays. The list of per-episode dicts is then written
# to a single pickle file.
data = []
episodes = 100
f_num = 2
domain_num = 1000
T = 100
env = Environment(T=T, domain_num=domain_num, f_num=f_num, function_type="train", seed=0)

# Fields recorded at every step of every episode.
trajectory_keys = ('observations', 'next_observations', 'actions', 'rewards', 'terminals')

path = 'Decision_Transformer/decision_transformer/envs/data'
# Create the output directory before the rollouts: otherwise a missing folder
# only surfaces at the final open() call, after all episodes have been run,
# and the collected data is lost.
os.makedirs(path, exist_ok=True)

for e in range(episodes):
    episode_data = {k: [] for k in trajectory_keys}
    # Random policy: T action indices drawn uniformly from [0, domain_num).
    # No seed is set in this script, so reproducibility of these draws depends
    # on whether something else seeds NumPy's global RNG -- TODO confirm.
    actions = np.random.choice(domain_num, T, replace=True)
    # NOTE(review): the first observation of each episode is random noise, not
    # a value obtained from the environment (the return value of env.reset()
    # is not used). The length is presumably the environment's observation
    # size -- confirm against Environment.
    state_actions = np.random.random(domain_num*2*f_num+f_num+1).reshape(1,-1)
    print("e: {}".format(e))
    for t in tqdm(range(T)):
        action = actions[t]
        next_state_actions, reward, done, _ = env.step(action)

        # Observations are flattened and stored as float32.
        episode_data['observations'].append(state_actions.reshape(-1).astype('float32'))
        episode_data['next_observations'].append(next_state_actions.reshape(-1).astype('float32'))
        # Wrapped in a 1-element array so the stacked actions have shape (T, 1).
        episode_data['actions'].append(np.array([action]))
        episode_data['rewards'].append(reward)
        episode_data['terminals'].append(done)
        state_actions = next_state_actions
    # Turn each per-step list into one array with the step index as axis 0.
    for k in trajectory_keys:
        episode_data[k] = np.stack(episode_data[k])

    data.append(episode_data)
    # Re-seed the environment for the next episode (a distinct seed per episode).
    env.reset(seed=3100+e*10, new_ls=True)

output_file = os.path.join(path, 'train_f{}_domain{}_ep{}.pkl'.format(f_num, domain_num, episodes))
with open(output_file, 'wb') as f:
    pickle.dump(data, f)