import diffuser.utils as utils
from ml_logger import logger, RUN
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from copy import deepcopy
import numpy as np
import os
import gym
from diffuser.utils.timer import Timer
from config.locomotion_config import Config
from diffuser.utils.arrays import to_torch, to_np, to_device
from diffuser.datasets.d4rl import suppress_output
from diffuser.models.value_func_model import ValueMLP
from diffuser.models.forward_dynamics import ForwardDynamics
from diffuser.models.bisimulation_metric_model import BisimNet
from diffuser.datasets.sequence import CustomSequenceDataset
from collections import namedtuple
from diffuser.utils.trajectory import Trajectory
from scripts.create_trajectory import save_traj
from scripts.create_trajectory import train_val
import pickle



def split(**deps):
    """Dump the dataset's trajectories to ``part0.dat`` in the checkpoint dir.

    The run is configured from ``deps``, the dataset described by ``Config``
    is built, and every path with at least one reward entry is stored as a
    list ``[observations, actions, rewards, next_observations]``.  The
    next-observation array is the observation array shifted by one step,
    with the final observation repeated so both arrays have equal length.
    The largest discounted return and the largest single reward seen are
    printed (in that order) before the paths are pickled.

    Args:
        **deps: Overrides forwarded to ``RUN._update`` and ``Config._update``.

    Returns:
        None.  The result is written to
        ``<Config.bucket>/<logger.prefix>/checkpoint/part0.dat``.
    """
    RUN._update(deps)
    Config._update(deps)

    logger.remove('*.pkl')
    logger.remove("traceback.err")
    logger.log_params(Config=vars(Config), RUN=vars(RUN))

    Config.device = 'cuda'

    utils.set_seed(Config.seed)

    dataset_config = utils.Config(
        Config.loader,
        savepath='dataset_config.pkl',
        env=Config.dataset,
        horizon=Config.horizon,
        normalizer=Config.normalizer,
        preprocess_fns=Config.preprocess_fns,
        use_padding=Config.use_padding,
        max_path_length=Config.max_path_length,
        include_returns=Config.include_returns,
        returns_scale=Config.returns_scale,
    )

    dataset = dataset_config()

    checkpoint_dir = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    loadpath = os.path.join(checkpoint_dir, "new_dataset.dat")

    fields = dataset.fields
    path_num = fields.normed_observations.shape[0]
    paths = []
    # NOTE(review): hard-coded switch; the reload branch below is only taken
    # if this is flipped by hand.
    first_epoch = True
    # Column vector of discount factors; extended on demand inside the loop
    # so that paths longer than the initial table do not break the product.
    discounts = Config.discount ** np.arange(1001)[:, None]
    # NOTE(review): both maxima start at 0, so a dataset whose rewards or
    # returns are all negative reports 0 here -- confirm this is intended.
    max_reward = 0
    max_return = 0

    if first_epoch:
        for path_ind in range(path_num):
            length = fields.path_lengths[path_ind]
            rewards = fields.rewards[path_ind, :length]
            if rewards.size == 0:
                # Paths without any reward entry are not exported.
                continue

            observations = fields.normed_observations[path_ind, :length]
            actions = fields.normed_actions[path_ind, :length]

            reward_max_curr = np.max(rewards)
            if reward_max_curr > max_reward:
                max_reward = reward_max_curr

            if length > discounts.shape[0]:
                discounts = Config.discount ** np.arange(length)[:, None]
            returns = np.sum(rewards * discounts[:length])
            if returns > max_return:
                max_return = returns

            # Shift observations by one step and repeat the last one so the
            # next-observation array has the same length as `observations`.
            next_item = np.concatenate((observations[1:], observations[-1:]), axis=0)
            paths.append([observations, actions, rewards, next_item])
    else:
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this at files produced by this project.
        with open(loadpath, "rb") as f:
            paths = pickle.load(f)
        path_num = len(paths)

    print(max_return)
    print(max_reward)

    # Only a single part is written; the slice keeps at most `path_num` paths.
    num_parts = 1
    part_length = path_num // num_parts
    first_part = paths[:part_length]

    # Make sure the target directory exists before opening the file.
    os.makedirs(checkpoint_dir, exist_ok=True)
    savepath = os.path.join(checkpoint_dir, "part0.dat")
    with open(savepath, "wb") as f:
        pickle.dump(first_part, f)
