import diffuser.utils as utils
from ml_logger import logger, RUN
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from copy import deepcopy
import numpy as np
import os
import gym
from diffuser.utils.timer import Timer
from config.locomotion_config import Config
from diffuser.utils.arrays import to_torch, to_np, to_device
from diffuser.datasets.d4rl import suppress_output
from diffuser.models.value_func_model import ValueMLP
from diffuser.models.forward_dynamics import ForwardDynamics
from diffuser.models.bisimulation_metric_model import BisimNet
from diffuser.datasets.sequence import CustomSequenceDataset
from collections import namedtuple
from diffuser.utils.trajectory import Trajectory
from scripts.create_trajectory import save_traj
from scripts.create_trajectory import train_val
import pickle



def _rank_paths_by_return(paths, keep):
    """Return ``(indices, returns)`` of the ``keep`` highest-return paths, best first.

    A path's return is ``np.sum(path[2])`` (presumably its per-step rewards --
    TODO confirm the path tuple layout against the writer of new_dataset.dat).
    Ties keep their original dataset order (stable sort) and NaN returns rank
    last. An empty ``paths`` yields two empty arrays instead of raising.
    """
    path_returns = np.array([np.sum(path[2]) for path in paths], dtype=np.float64)
    # Negate for a descending order while keeping the sort stable; unlike
    # ``argsort()[::-1]`` this does not scramble ties or float NaN to the top.
    order = np.argsort(-path_returns, kind="stable")[:keep]
    return order, path_returns[order]


def _atomic_pickle_dump(obj, path):
    """Pickle ``obj`` to ``path`` by writing a sibling temp file and renaming it.

    Opening ``path`` with "wb" directly would truncate it first, so a failure
    mid-dump would destroy the only copy of the data; ``os.replace`` swaps the
    finished file into place in a single step instead.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(obj, f)
    os.replace(tmp_path, path)


def sort(**deps):
    """Keep only the highest-return paths of the checkpointed trajectory dataset.

    Loads ``<Config.bucket>/<logger.prefix>/checkpoint/new_dataset.dat`` (a
    pickled list of paths), ranks the paths by ``np.sum(path[2])`` in
    descending order, keeps as many as there are rows in the original
    dataset's ``fields.normed_observations``, and overwrites the same file
    with the kept paths.

    Args:
        **deps: overrides applied to both ``RUN`` and ``Config`` before anything
            else runs.

    Side effects: removes ``*.pkl`` and ``traceback.err`` via ``logger``, logs
    the run parameters, forces ``Config.device = 'cuda'``, seeds RNGs, and
    rewrites ``new_dataset.dat`` in place.
    """
    RUN._update(deps)
    Config._update(deps)

    logger.remove('*.pkl')
    logger.remove("traceback.err")
    logger.log_params(Config=vars(Config), RUN=vars(RUN))

    # Overrides any device passed in via ``deps``.
    Config.device = 'cuda'

    utils.set_seed(Config.seed)

    dataset_config = utils.Config(
        Config.loader,
        savepath='dataset_config.pkl',
        env=Config.dataset,
        horizon=Config.horizon,
        normalizer=Config.normalizer,
        preprocess_fns=Config.preprocess_fns,
        use_padding=Config.use_padding,
        max_path_length=Config.max_path_length,
        include_returns=Config.include_returns,
        returns_scale=Config.returns_scale,
    )
    dataset = dataset_config()

    # The original dataset is only needed for its size: the number of paths
    # to keep from the (larger) augmented dataset.
    path_num_original = dataset.fields.normed_observations.shape[0]

    loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint', "new_dataset.dat")

    # NOTE: pickle.load can execute arbitrary code; only point this at files
    # produced by this pipeline.
    with open(loadpath, "rb") as f:
        paths = pickle.load(f)

    kept_idx, kept_returns = _rank_paths_by_return(paths, path_num_original)
    print("first three returns", kept_returns[:3])
    clipped_paths = [paths[idx] for idx in kept_idx]

    print("saving clipped paths")
    _atomic_pickle_dump(clipped_paths, loadpath)


