
import os
import pickle
import numpy as np
import torch
        

def combine_envs(envs):
    """Concatenate per-environment ``(X, y)`` tensor pairs into one dataset.

    Each environment is indexed as ``env[0]`` (features) and ``env[1]``
    (targets). All pieces are joined with ``torch.cat`` along dim 0, and the
    results are reshaped to 2-D with one row per sample.

    Args:
        envs: Non-empty sequence of ``(X, y)`` torch tensor pairs. Trailing
            dimensions must agree across environments (required by
            ``torch.cat``).

    Returns:
        Tuple ``(X, y)`` of tensors shaped ``(num, -1)``, where ``num`` is the
        total number of target rows across all environments.

    Raises:
        ValueError: if ``envs`` is empty.
    """
    if not envs:
        raise ValueError("combine_envs requires at least one environment")

    X = torch.cat([env[0] for env in envs], dim=0)
    y = torch.cat([env[1] for env in envs], dim=0)

    # Total sample count is the number of concatenated target rows,
    # i.e. the sum of len(env[1]) over all environments.
    num = y.shape[0]
    return X.reshape(num, -1), y.reshape(num, -1)


def combine_envs_np(envs):
    """Concatenate per-environment ``(X, y)`` NumPy array pairs along axis 0.

    All feature arrays and all target arrays are each joined in a single
    ``np.concatenate`` call. This avoids re-copying a growing accumulator
    once per environment.

    Args:
        envs: Non-empty iterable of ``(X, y)`` pairs. Shapes must agree on all
            axes except the first (required by ``np.concatenate``).

    Returns:
        Tuple ``(X, y)`` of newly allocated arrays containing every
        environment's rows, in input order.

    Raises:
        ValueError: if ``envs`` is empty.
    """
    # Materialize once so generators work and the empty case can be detected.
    pairs = list(envs)
    if not pairs:
        raise ValueError("combine_envs_np requires at least one environment")

    X = np.concatenate([X_i for X_i, _ in pairs], axis=0)
    y = np.concatenate([y_i for _, y_i in pairs], axis=0)
    return X, y
