import diffuser.utils as utils
from ml_logger import logger, RUN
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from copy import deepcopy
import numpy as np
import os
import gym
from diffuser.utils.timer import Timer
from config.locomotion_config import Config
from diffuser.utils.arrays import to_torch, to_np, to_device
from diffuser.datasets.d4rl import suppress_output
from diffuser.models.value_func_model import ValueMLP
from diffuser.models.forward_dynamics import ForwardDynamics
from diffuser.models.bisimulation_metric_model import BisimNet
from diffuser.datasets.sequence import CustomSequenceDataset
from collections import namedtuple
from diffuser.utils.trajectory import Trajectory
from scripts.create_trajectory import save_traj
from scripts.create_trajectory import train_val
import pickle



def merge(**deps):
    """Concatenate pickled dataset shards into one ``new_dataset.dat`` file.

    Steps, in order:

    1. Push ``deps`` into both ``RUN`` and ``Config``.
    2. Remove ``*.pkl`` and ``traceback.err`` artifacts under the logger
       prefix, then log the ``Config``/``RUN`` parameters.
    3. Force ``Config.device = 'cuda'` and seed with ``Config.seed``.
    4. Unpickle ``dataset0.dat`` ... ``dataset{num_parts-1}.dat`` from
       ``<Config.bucket>/<logger.prefix>/checkpoint`` and join them with ``+``.
    5. Pickle the joined result to ``new_dataset.dat`` in the same directory.

    With ``num_parts`` fixed at 1, the output is a re-pickled copy of
    ``dataset0.dat``.

    Args:
        **deps: Overrides forwarded to ``RUN._update`` and ``Config._update``.

    Raises:
        FileNotFoundError: If an expected shard file is missing.
    """
    RUN._update(deps)
    Config._update(deps)

    # Clear stale artifacts for this run before recording its parameters.
    for stale_pattern in ('*.pkl', "traceback.err"):
        logger.remove(stale_pattern)
    logger.log_params(Config=vars(Config), RUN=vars(RUN))

    Config.device = 'cuda'
    num_parts = 1

    utils.set_seed(Config.seed)

    checkpoint_dir = os.path.join(Config.bucket, logger.prefix, 'checkpoint')

    # NOTE(review): shards are joined with `+`; this presumes list-like
    # shards (the empty-list seed suggests lists). If they were numpy arrays,
    # `+` would add element-wise instead of concatenating. TODO confirm.
    merged = []
    for part_idx in range(num_parts):
        # Also creates the directory on the read path; the save step below
        # relies on the same directory existing.
        os.makedirs(checkpoint_dir, exist_ok=True)
        shard_path = os.path.join(checkpoint_dir, f"dataset{part_idx}.dat")
        # pickle.load trusts the file contents; only use on trusted checkpoints.
        with open(shard_path, "rb") as shard_file:
            shard = pickle.load(shard_file)
        merged = shard if part_idx == 0 else merged + shard

    os.makedirs(checkpoint_dir, exist_ok=True)
    output_path = os.path.join(checkpoint_dir, "new_dataset.dat")
    with open(output_path, "wb") as output_file:
        pickle.dump(merged, output_file)
