import diffuser.utils as utils
from ml_logger import logger, RUN
import torch
from torch import nn
import matplotlib.pyplot as plt
from torch.distributions.multivariate_normal import MultivariateNormal
from copy import deepcopy
import numpy as np
import os
import gym
from diffuser.utils.timer import Timer
from config.locomotion_config import Config
from diffuser.utils.arrays import to_torch, to_np, to_device
from diffuser.datasets.d4rl import suppress_output
from diffuser.models.value_func_model import ValueMLP
from diffuser.models.forward_dynamics import ForwardDynamics
from diffuser.models.bisimulation_metric_model import BisimNet
from diffuser.datasets.sequence import CustomSequenceDataset
from collections import namedtuple
from diffuser.utils.trajectory import Trajectory
from scripts.create_trajectory import save_traj
from scripts.create_trajectory import train_val
import pickle


# Record holding one training batch. The typename handed to ``namedtuple`` is
# 'Batch' (not 'RewardBatch'), so that is the name shown by repr().
RewardBatch = namedtuple(
    'Batch',
    ['trajectories', 'conditions', 'returns', 'rewards'],
)


def save_data(dataset, bucket, part_num):
    """Pickle ``dataset`` to ``<bucket>/<logger.prefix>/checkpoint/dataset<part_num>.dat``.

    The checkpoint directory is created when it does not exist yet; an
    existing file with the same name is overwritten.

    Args:
        dataset: Any picklable object.
        bucket: Root directory under which the logger prefix is nested.
        part_num: Value appended (via ``str``) to the ``dataset`` file stem.
    """
    checkpoint_dir = os.path.join(bucket, logger.prefix, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)

    target_file = os.path.join(checkpoint_dir, "".join(("dataset", str(part_num), ".dat")))
    with open(target_file, "wb") as handle:
        pickle.dump(dataset, handle)


def cycle(dl):
    """Yield the items of ``dl`` forever, restarting it whenever it is exhausted.

    ``dl`` is re-iterated from scratch on every pass, so it must be
    re-iterable (e.g. a list or a DataLoader), not a one-shot iterator.
    If ``dl`` yields nothing, the generator spins without ever producing
    a value.
    """
    while True:
        yield from dl


def concat_state(states, batch_size, transpose=False):
    """Build every ordered pair of rows of ``states`` as one flat matrix.

    Args:
        states: 2-D tensor; ``states.shape[1]`` is taken as the feature size.
        batch_size: Number of repetitions. The concatenation below only
            lines up when this equals ``states.shape[0]``.
        transpose: Selects which member of the pair comes first.

    Returns:
        Tensor of shape ``(batch_size ** 2, 2 * states.shape[1])``. Row
        ``i * batch_size + j`` is ``[states[j], states[i]]`` by default and
        ``[states[i], states[j]]`` when ``transpose`` is true.
    """
    feature_dim = states.shape[1]

    # grid[i, j] == states[j]; swapping the first two axes gives states[i].
    grid = states.unsqueeze(0).repeat(batch_size, 1, 1)
    grid_t = grid.transpose(0, 1)

    ordered = (grid_t, grid) if transpose else (grid, grid_t)
    paired = torch.cat(ordered, dim=2)
    return paired.reshape(batch_size ** 2, feature_dim * 2)


def create_neighbour(paths, s, s_prime, device, paths_lengths_cumsum, path_length=1000, eps=0.8, min_thres=0.00):
    """Find every stored observation inside an L2 ball around ``s_prime``.

    All observations of all trajectories are expected to be stacked row-wise
    in ``paths``; ``paths_lengths_cumsum`` holds the cumulative trajectory
    lengths and is used to translate a flat row index back into a
    (trajectory index, step index) pair.

    Args:
        paths: 2-D ``numpy`` array of observations, one row per transition.
        s: Unused; kept so existing call sites keep working.
        s_prime: 1-D ``numpy`` array, the query observation.
        device: Torch device on which the distances are computed.
        paths_lengths_cumsum: 1-D ``numpy`` integer array with the cumulative
            sum of the trajectory lengths, in the same order as ``paths``.
        path_length: Unused; kept so existing call sites keep working.
        eps: Exclusive upper bound on the distance.
        min_thres: Exclusive lower bound on the distance.

    Returns:
        ``float32`` array of shape ``(K, 2 + obs_dim)``; each row is
        ``[trajectory index, step index, *observation]`` for one match, in
        ascending flat-row order. When nothing matches, an empty 1-D
        ``float32`` array is returned (so ``len(result) == 0``).
    """
    obs_input = torch.from_numpy(paths).to(device)
    target = torch.from_numpy(s_prime).to(device).unsqueeze(0)
    # searchsorted needs boundaries and values of the same dtype; the flat
    # indices produced below are int64.
    cumsum = torch.from_numpy(paths_lengths_cumsum).to(device=device, dtype=torch.int64)

    # expand() gives the query the same shape as the dataset without
    # materialising a copy of it for every row.
    query = target.expand(obs_input.shape[0], -1)
    dists = nn.PairwiseDistance(p=2)(query, obs_input)

    # Select on the boolean condition itself rather than on the non-zero
    # distance values, so a distance of exactly 0 is not dropped when
    # min_thres is negative.
    in_ball = (dists < eps) & (dists > min_thres)
    flat_idx = torch.argwhere(in_ball).squeeze(1)
    if flat_idx.numel() == 0:
        return np.array([], dtype=np.float32)

    obs_select = obs_input[flat_idx]

    # right=True: a flat index equal to a cumulative length is the first
    # step of the *next* trajectory.
    path_idx = torch.searchsorted(cumsum, flat_idx, right=True)
    # Start offset of each matched trajectory (0 for the first one). The
    # clamp only keeps the gather index valid for path_idx == 0, whose
    # gathered value is discarded by the where().
    path_start = torch.where(
        path_idx == 0,
        torch.zeros_like(path_idx),
        cumsum[(path_idx - 1).clamp(min=0)],
    )
    step_idx = flat_idx - path_start

    neighbour = np.concatenate(
        (
            path_idx.cpu().numpy()[:, None],
            step_idx.cpu().numpy()[:, None],
            obs_select.cpu().numpy(),
        ),
        axis=1,
    )
    return neighbour.astype(np.float32)



def stitching(**deps):
    """Rebuild the trajectories of one dataset part by "stitching" and pickle them.

    Visible flow of the active (uncommented) code:

    1. ``RUN`` and ``Config`` are updated from ``deps`` and ``Config.device``
       is then forced to ``'cuda'``.
    2. From ``<Config.bucket>/<logger.prefix>/checkpoint`` the function loads
       two diffusion checkpoints (``state.pt`` / ``state_include_next.pt``
       when ``Config.save_checkpoints`` is false), a value-function
       checkpoint ``V_model.pt`` (two ``ValueMLP`` networks) and
       ``forward.pt`` (seven ``ForwardDynamics`` networks).
    3. The full trajectory list is read from ``new_dataset.dat`` (because
       ``first_epoch`` is hard-coded to ``False``) and the trajectories to
       process from ``part0.dat``.
    4. Each trajectory of the part is walked step by step. At every step the
       neighbours of the current next state are looked up with
       ``create_neighbour``; the neighbour with the largest
       ``min(V_model_1, V_model_2)`` is the jump candidate. The walk jumps to
       it when the forward-model likelihood test and the value test both
       pass, otherwise the original transition is copied.
    5. The rebuilt trajectory replaces the original one only when its reward
       sum exceeds ``(1 + threshold)`` times the original reward sum. The
       resulting list is written with ``save_data``.

    Args:
        **deps: Overrides forwarded to ``RUN._update`` and ``Config._update``.

    Returns:
        None. Results are written to disk.

    NOTE(review): when ``Config.save_checkpoints`` is true the checkpoint
    paths are built from ``self.step``, but no ``self`` exists in this
    function, so that branch would raise ``NameError``.
    """

    RUN._update(deps)
    Config._update(deps)

    logger.remove('*.pkl')
    logger.remove("traceback.err")
    logger.log_params(Config=vars(Config), RUN=vars(RUN))

    # The device is overridden here, after the parameters were logged.
    Config.device = 'cuda'

    # NOTE(review): `prefix` is not referenced by the active code below.
    if Config.predict_epsilon:
        prefix = f'predict_epsilon_{Config.n_diffusion_steps}_1000000.0'
    else:
        prefix = f'predict_x0_{Config.n_diffusion_steps}_1000000.0'

    loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    next_loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')

    # NOTE(review): `self` is undefined in this function; the
    # save_checkpoints=True branches below would raise NameError.
    if Config.save_checkpoints:
        loadpath = os.path.join(loadpath, f'state_{self.step}.pt')
    else:
        loadpath = os.path.join(loadpath, 'state.pt')

    if Config.save_checkpoints:
        next_loadpath = os.path.join(next_loadpath, f'state_{self.step}.pt')
    else:
        next_loadpath = os.path.join(next_loadpath, 'state_include_next.pt')

    state_dict = torch.load(loadpath, map_location=Config.device)
    next_state_dict = torch.load(next_loadpath, map_location=Config.device)

    # Load configs
    torch.backends.cudnn.benchmark = True
    utils.set_seed(Config.seed)

    dataset_config = utils.Config(
        Config.loader,
        savepath='dataset_config.pkl',
        env=Config.dataset,
        horizon=Config.horizon,
        normalizer=Config.normalizer,
        preprocess_fns=Config.preprocess_fns,
        use_padding=Config.use_padding,
        max_path_length=Config.max_path_length,
        include_returns=Config.include_returns,
        returns_scale=Config.returns_scale,
    )

    render_config = utils.Config(
        Config.renderer,
        savepath='render_config.pkl',
        env=Config.dataset,
    )

    dataset = dataset_config()
    renderer = render_config()

    observation_dim = dataset.observation_dim
    action_dim = dataset.action_dim

    # Only the inverse-dynamics diffusion pair models observations alone;
    # every other combination models observations and actions together.
    if Config.diffusion == 'models.GaussianInvDynDiffusion' and Config.next_diffusion == 'models.NextGaussianInvDynDiffusion':
        transition_dim = observation_dim
    else:
        transition_dim = observation_dim + action_dim

    model_config = utils.Config(
        Config.model,
        savepath='model_config.pkl',
        horizon=Config.horizon,
        transition_dim=transition_dim,
        cond_dim=observation_dim,
        dim_mults=Config.dim_mults,
        dim=Config.dim,
        returns_condition=Config.returns_condition,
        device=Config.device,
    )

    # NOTE(review): same savepath as model_config above ('model_config.pkl');
    # whether the second config overwrites the first depends on utils.Config
    # -- confirm.
    next_model_config = utils.Config(
        Config.next_model,
        savepath='model_config.pkl',
        horizon=Config.horizon,
        transition_dim=transition_dim,
        cond_dim=observation_dim,
        dim_mults=Config.dim_mults,
        dim=Config.dim,
        returns_condition=Config.returns_condition,
        device=Config.device,
    )

    diffusion_config = utils.Config(
        Config.diffusion,
        savepath='diffusion_config.pkl',
        horizon=Config.horizon,
        observation_dim=observation_dim,
        action_dim=action_dim,
        n_timesteps=Config.n_diffusion_steps,
        loss_type=Config.loss_type,
        clip_denoised=Config.clip_denoised,
        predict_epsilon=Config.predict_epsilon,
        hidden_dim=Config.hidden_dim,
        ## loss weighting
        action_weight=Config.action_weight,
        loss_weights=Config.loss_weights,
        loss_discount=Config.loss_discount,
        returns_condition=Config.returns_condition,
        device=Config.device,
        condition_guidance_w=Config.condition_guidance_w,
    )

    # NOTE(review): shares savepath 'diffusion_config.pkl' with
    # diffusion_config above -- confirm this is intended.
    next_diffusion_config = utils.Config(
        Config.next_diffusion,
        savepath='diffusion_config.pkl',
        horizon=Config.horizon,
        observation_dim=observation_dim,
        action_dim=action_dim,
        n_timesteps=Config.n_diffusion_steps,
        loss_type=Config.loss_type,
        clip_denoised=Config.clip_denoised,
        predict_epsilon=Config.predict_epsilon,
        hidden_dim=Config.hidden_dim,
        ## loss weighting
        action_weight=Config.action_weight,
        loss_weights=Config.loss_weights,
        loss_discount=Config.loss_discount,
        returns_condition=Config.returns_condition,
        device=Config.device,
        condition_guidance_w=Config.condition_guidance_w,
    )

    trainer_config = utils.Config(
        utils.Trainer,
        savepath='trainer_config.pkl',
        train_batch_size=Config.batch_size,
        train_lr=Config.learning_rate,
        gradient_accumulate_every=Config.gradient_accumulate_every,
        ema_decay=Config.ema_decay,
        sample_freq=Config.sample_freq,
        save_freq=Config.save_freq,
        log_freq=Config.log_freq,
        label_freq=int(Config.n_train_steps // Config.n_saves),
        save_parallel=Config.save_parallel,
        bucket=Config.bucket,
        n_reference=Config.n_reference,
        train_device=Config.device,
    )

    # Two diffusion models / trainers are instantiated from the same trainer
    # config; they differ in the model and diffusion classes and in the
    # checkpoint that is loaded into them.
    model = model_config()
    next_model = next_model_config()
    diffusion = diffusion_config(model)
    next_diffusion = next_diffusion_config(next_model)
    trainer = trainer_config(diffusion, dataset, renderer)
    next_trainer = trainer_config(next_diffusion, dataset, renderer)
    logger.print(utils.report_parameters(model), color='green')

    trainer.step = state_dict['step']
    trainer.model.load_state_dict(state_dict['model'])
    trainer.ema_model.load_state_dict(state_dict['ema'])

    next_trainer.step = next_state_dict['step']
    next_trainer.model.load_state_dict(next_state_dict['model'])
    next_trainer.ema_model.load_state_dict(next_state_dict['ema'])

    # Two value networks restored from the 'model1' / 'model2' entries of
    # V_model.pt; further down their element-wise minimum is used.
    state_dim = observation_dim
    V_model_1 = ValueMLP(hidden_dim=256, input_dim=state_dim, output_dim=1).to(Config.device)
    V_model_2 = ValueMLP(hidden_dim=256, input_dim=state_dim, output_dim=1).to(Config.device)
    value_loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    value_loadpath = os.path.join(value_loadpath, 'V_model.pt')
    value_state_dict = torch.load(value_loadpath, map_location=Config.device)
    V_model_1.load_state_dict(value_state_dict['model1'])
    V_model_2.load_state_dict(value_state_dict['model2'])

    # Seven forward-dynamics networks restored from keys 'model0'..'model6'.
    # Note that the loop variable rebinds the name `model` that previously
    # referred to the diffusion backbone.
    forward_models = []
    forward_loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    forward_loadpath = os.path.join(forward_loadpath, 'forward.pt')
    forward_state_dict = torch.load(forward_loadpath, map_location=Config.device)
    for i in range(7):
        hidden_dim = 200
        model = ForwardDynamics(state_dim=state_dim, hidden_dim=hidden_dim).to(Config.device)
        model_idx = "model" + str(i)
        model.load_state_dict(forward_state_dict[model_idx])
        forward_models.append(model)

    # path_num = len(dataset.indices)
    path_num = dataset.fields.normed_observations.shape[0]
    paths = []
    # Hard-coded switch: with False the trajectories come from the pickled
    # 'new_dataset.dat' instead of being rebuilt from `dataset.fields`.
    first_epoch = False
    obs_numpy = np.array([])
    loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    loadpath = os.path.join(loadpath, "new_dataset.dat")
    # os.makedirs(loadpath, exist_ok=True)
    if first_epoch:
        for path_ind in range(path_num):
            # path_ind, start, end = dataset.indices[i]
            observations = dataset.fields.normed_observations[path_ind, :dataset.fields.path_lengths[path_ind]]
            if path_ind == 0:
                obs_numpy = observations
            else:
                obs_numpy = np.concatenate((obs_numpy, observations), axis=0)
            actions = dataset.fields.normed_actions[path_ind, :dataset.fields.path_lengths[path_ind]]
            rewards = dataset.fields.rewards[path_ind, :dataset.fields.path_lengths[path_ind]] / 4  # modified reward
            # Next observations: the sequence shifted by one step, with the
            # last observation repeated so the length is unchanged.
            next_item = np.concatenate((dataset.fields.normed_observations[path_ind, 1:dataset.fields.path_lengths[path_ind]], np.array([dataset.fields.normed_observations[path_ind, dataset.fields.path_lengths[path_ind] - 1]])), axis=0)
            paths.append([observations, actions, rewards, next_item])
    else:
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only safe for files produced by this project.
        with open(loadpath, "rb") as f:
            paths = pickle.load(f)
        path_num = len(paths)
        # obs_numpy: observations of all trajectories stacked row-wise.
        for path_ind in range(path_num):
            if path_ind == 0:
                obs_numpy = paths[path_ind][0]
            else:
                obs_numpy = np.concatenate((obs_numpy, paths[path_ind][0]), axis=0)

    loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    os.makedirs(loadpath, exist_ok=True)
    part_num = 0
    # `// 1` leaves the value unchanged: one part spans all trajectories.
    part_length = path_num // 1
    loadpath = os.path.join(loadpath, "part0.dat")
    with open(loadpath, "rb") as f:
        paths_part = pickle.load(f)

    # logger.print(paths[-1][0].shape, paths[-1][1].shape, paths[-1][2].shape, paths[-1][3].shape)

    num_layer = 1
    # bisim_model = BisimNet(state_dim=state_dim, num_layers=num_layer).to(Config.device)
    # bisim_loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    # bisim_loadpath = os.path.join(bisim_loadpath, 'bisim.pt')
    # bisim_state_dict = torch.load(bisim_loadpath, map_location=Config.device)
    # bisim_model.load_state_dict(bisim_state_dict['online'])

    # NOTE(review): of the constants below only `max_path_length`,
    # `threshold` and `returns` are read by the active code; the others
    # (and `num_layer`, `timer`) appear only in commented-out sections.
    discount = 0.99
    returns_scale = 400
    max_path_length = 1000
    discounts = discount ** np.arange(max_path_length)[:, None]
    epoch_num = 1
    train_diffuser_epoch = 10
    horizon = 10
    threshold = 0.1
    returns = to_device(Config.test_ret * torch.ones(1, 1), Config.device)
    # dataloader = cycle(torch.utils.data.DataLoader(
    #     dataset, batch_size=batch_size, num_workers=0, shuffle=True, pin_memory=True
    # ))
    timer = Timer()
    paths_lengths = []
    new_paths_temp = []
    # Bring every trajectory to the 4-element layout
    # [observations, actions, rewards, next_observations]. Three-element
    # entries get next observations synthesised by shifting the observations
    # by one step and repeating the last one.
    for path in paths:
        paths_lengths.append(path[0].shape[0])
        if len(path) == 3:
            observations = path[0]
            actions = path[1]
            rewards = path[2]
            if observations.shape[0] > 2:
                next_obs = np.concatenate((observations[1:], np.array([observations[-1]])), axis=0)
            elif observations.shape[0] == 2:
                next_obs = np.concatenate((np.array([observations[-1]]), np.array([observations[-1]])), axis=0)
            else:
                next_obs = np.array([observations[-1]])
            path_temp = [observations, actions, rewards, next_obs]
        else:
            path_temp = path
        new_paths_temp.append(path_temp)
    paths = new_paths_temp
    paths_lengths = np.array(paths_lengths)
    # Debug output: indices of trajectories whose length is exactly 490.
    print(np.where(paths_lengths == 490))
    # Cumulative lengths let create_neighbour map a row of obs_numpy back to
    # (trajectory index, step index).
    paths_lengths_cumsum = np.cumsum(paths_lengths)

    new_paths_part_temp = []

    # Same 3 -> 4 element normalisation as above, applied to the part.
    for path in paths_part:
        if len(path) == 3:
            observations = path[0]
            actions = path[1]
            rewards = path[2]
            if observations.shape[0] > 2:
                next_obs = np.concatenate((observations[1:], np.array([observations[-1]])), axis=0)
            elif observations.shape[0] == 2:
                next_obs = np.concatenate((np.array([observations[-1]]), np.array([observations[-1]])), axis=0)
            else:
                next_obs = np.array([observations[-1]])
            path_temp = [observations, actions, rewards, next_obs]
        else:
            path_temp = path
        new_paths_part_temp.append(path_temp)
    paths_part = new_paths_part_temp
    new_paths = []
    stitching_num = 0
    count = 0
    obs_numpy_new = np.array([])
    with torch.no_grad():
        for path in paths_part:
            # path = paths_part[368]
            # s / s_prime: current state and its successor on the walk.
            s = path[0][0]
            path_length = len(path[0])
            s_prime = path[3][0]
            new_path = []
            # Position of this trajectory inside `paths`.
            # Presumably paths_part[i] corresponds to
            # paths[part_num * part_length + i] -- TODO confirm.
            s_path_idx = part_num * part_length
            # s_path_idx = 3043
            s_step_idx = 0
            step_cnt = 0
            s_path_idx += count
            # States already emitted for this rebuilt trajectory; used below
            # to refuse jumping to a state that was visited before.
            new_states = []
            flag = False

            # Walk until the end of the trajectory currently being followed,
            # or until max_path_length iterations were performed.
            while (s_step_idx < paths_lengths[s_path_idx]) and step_cnt < max_path_length:
            # while not np.array_equal(s, s_prime):
                # neighbour = create_neighbour(bisim_model, obs_numpy, s, s_prime, Config.device)
                # Rows of `neighbour`: [trajectory index, step index, *observation].
                neighbour = create_neighbour(obs_numpy, s, s_prime, Config.device, paths_lengths_cumsum)
                # Created on the CPU (no device given); one row per forward model.
                gaussian_params = torch.zeros((7, state_dim * 2))
                step_cnt += 1
                neighbour_size = len(neighbour)
                s_prime_prob = torch.zeros(7, dtype=torch.float32).to(Config.device)
                s_hat_prime_prob = torch.zeros(7, dtype=torch.float32).to(Config.device)
                values = torch.zeros(neighbour_size, dtype=torch.float32)
                s_prime_val_1 = V_model_1(torch.from_numpy(s_prime).to(Config.device))
                s_prime_val_2 = V_model_2(torch.from_numpy(s_prime).to(Config.device))
                # Python's built-in min on the two network outputs.
                s_prime_val = min(s_prime_val_1, s_prime_val_2)
                logger.print(neighbour_size)
                # Case 1: no neighbour found -> copy the original transition
                # and advance one step along the current trajectory.
                if neighbour_size == 0:
                    new_states.append(s)
                    temp_action = dataset.normalizer.unnormalize(paths[s_path_idx][1][s_step_idx], "actions")
                    if not new_path:
                        new_path = [np.expand_dims(s, axis=0), np.expand_dims(temp_action, axis=0),
                                    np.expand_dims(paths[s_path_idx][2][s_step_idx], axis=0),
                                    np.expand_dims(paths[s_path_idx][3][s_step_idx], axis=0)]
                    else:
                        new_path[0] = np.concatenate((new_path[0], np.expand_dims(s, axis=0)), axis=0)
                        new_path[1] = np.concatenate((new_path[1], np.expand_dims(temp_action, axis=0)), axis=0)
                        new_path[2] = np.concatenate((new_path[2], np.expand_dims(paths[s_path_idx][2][s_step_idx], axis=0)), axis=0)
                        new_path[3] = np.concatenate((new_path[3], np.expand_dims(paths[s_path_idx][3][s_step_idx], axis=0)), axis=0)
                    s = s_prime
                    logger.print("normal step", s_path_idx, s_step_idx)
                    s_step_idx += 1
                    # Self-assignment: no effect.
                    s_path_idx = s_path_idx
                    if s_step_idx < paths_lengths[s_path_idx]:
                        s_prime = paths[s_path_idx][3][s_step_idx]
                    continue

                # Columns 0 and 1 hold the indices; the rest is the observation.
                neighbour_torch = torch.from_numpy(neighbour[:, 2:])
                # The neighbours are evaluated in two halves
                # (presumably to limit peak GPU memory -- TODO confirm).
                neighbour_split = neighbour_torch.shape[0] // 2
                neighbour_1 = neighbour_torch[:neighbour_split, :].to(Config.device)
                neighbour_2 = neighbour_torch[neighbour_split:, :].to(Config.device)

                value1_1 = V_model_1(neighbour_1)
                value1_2 = V_model_1(neighbour_2)
                value1 = torch.concatenate((value1_1, value1_2), dim=0)
                value2_1 = V_model_2(neighbour_1)
                value2_2 = V_model_2(neighbour_2)
                value2 = torch.concatenate((value2_1, value2_2), dim=0)
                # Per-neighbour minimum over the two value networks.
                value_cat = torch.concatenate((value1, value2), dim=1).to(torch.device("cpu"))
                values = torch.min(value_cat, dim=1).values
                neighbour_1 = neighbour_1.to(torch.device("cpu"))
                neighbour_2 = neighbour_2.to(torch.device("cpu"))

                # for j in range(neighbour_size):
                #     value1 = V_model_1(torch.from_numpy(neighbour[j][2:]).to(Config.device))
                #     value2 = V_model_2(torch.from_numpy(neighbour[j][2:]).to(Config.device))
                #     values[j] = min(value1, value2)
                # Jump candidate: the neighbour with the highest value.
                max_val = torch.max(values)
                max_val_idx = torch.argmax(values)
                max_val_path = int(neighbour[max_val_idx][0])
                max_val_step = int(neighbour[max_val_idx][1])
                # Case 2: the candidate state was already emitted in this
                # rebuilt trajectory -> take a normal step instead.
                flag = False
                for item in new_states:
                    if np.array_equal(item, paths[max_val_path][0][max_val_step]):
                        flag = True
                        break

                if flag:
                    new_states.append(s)
                    temp_action = dataset.normalizer.unnormalize(paths[s_path_idx][1][s_step_idx], "actions")
                    if not new_path:
                        new_path = [np.expand_dims(s, axis=0), np.expand_dims(temp_action, axis=0),
                                    np.expand_dims(paths[s_path_idx][2][s_step_idx], axis=0),
                                    np.expand_dims(paths[s_path_idx][3][s_step_idx], axis=0)]
                    else:
                        new_path[0] = np.concatenate((new_path[0], np.expand_dims(s, axis=0)), axis=0)
                        new_path[1] = np.concatenate((new_path[1], np.expand_dims(temp_action, axis=0)), axis=0)
                        new_path[2] = np.concatenate((new_path[2], np.expand_dims(paths[s_path_idx][2][s_step_idx], axis=0)), axis=0)
                        new_path[3] = np.concatenate((new_path[3], np.expand_dims(paths[s_path_idx][3][s_step_idx], axis=0)), axis=0)
                    s = s_prime
                    logger.print("normal step", s_path_idx, s_step_idx)
                    s_step_idx += 1
                    # Self-assignment: no effect.
                    s_path_idx = s_path_idx
                    if s_step_idx < paths_lengths[s_path_idx]:
                        s_prime = paths[s_path_idx][3][s_step_idx]
                    continue

                # For each forward model: read its output for `s` as
                # [mean | std], build a Gaussian with diagonal covariance
                # std**2, and evaluate the density of the original next state
                # and of the jump candidate.
                for j in range(7):
                    gaussian_params[j] = forward_models[j](torch.from_numpy(s).to(Config.device))
                    mean = gaussian_params[j, :state_dim]
                    std = gaussian_params[j, state_dim:]
                    var = torch.diag(std)
                    var = torch.square(var)
                    normal = MultivariateNormal(mean, var)
                    # prob = torch.zeros(neighbour_size, dtype=torch.float32).to(Config.device)
                    # for k in range(neighbour_size):
                    #     prob[k] = torch.exp(normal.log_prob(torch.from_numpy(neighbour[k][2])))
                    s_prime_prob[j] = torch.exp(normal.log_prob(torch.from_numpy(s_prime)))
                    s_hat_prime_prob[j] = torch.exp(normal.log_prob(torch.from_numpy(neighbour[max_val_idx][2:])))
                # Ascending sort, so the last three entries are the largest.
                s_hat_prime_prob_sorted, _ = torch.sort(s_hat_prime_prob, dim=0)
                # chosen the top three
                s_hat_prime_prob_chosen = s_hat_prime_prob_sorted[-3:]
                # Case 3 (stitch): the smallest of the candidate's top-three
                # densities beats the mean density of the original next
                # state, and the candidate's value beats the original's.
                if torch.min(s_hat_prime_prob_chosen) > torch.mean(s_prime_prob) and max_val > s_prime_val:
                    new_s_prime_torch = torch.from_numpy(paths[max_val_path][0][max_val_step]).to(Config.device)
                    s_torch = torch.from_numpy(s).to(Config.device)
                    # NOTE(review): `reward` is not used by the active code below.
                    reward = torch.from_numpy(paths[max_val_path][2][max_val_step]).to(Config.device)
                    # Inverse-dynamics input: current state followed by the
                    # jump target.
                    s_comb = torch.concatenate((s_torch, new_s_prime_torch), dim=-1)
                    s_comb = torch.unsqueeze(s_comb, dim=0)
                    s_comb = torch.reshape(s_comb, (-1, 2 * state_dim))
                    # NOTE(review): the trailing comment says the prediction
                    # is unnormalized, yet `unnormalize` is applied to it
                    # further down -- confirm which is correct.
                    action_pred = trainer.ema_model.inv_model(s_comb)   #unnormalized
                    # action_pred = dataset.normalizer.unnormalize(action_pred, 'actions')
                    # s_comb_reward = torch.concatenate((s_torch, new_s_prime_torch, reward), dim=-1)
                    # s_comb_reward = torch.unsqueeze(s_comb_reward, dim=0)
                    # s_comb_reward = torch.reshape(s_comb_reward, (-1, 2 * state_dim + 1))

                    # previous_states = torch.from_numpy(path[s_path_idx][0][(s_step_idx - horizon + 1): s_step_idx]).to(Config.device)
                    # previous_next = torch.from_numpy(path[s_path_idx][3][(s_step_idx - horizon + 1): s_step_idx]).to(Config.device)
                    # previous_reward = torch.from_numpy(path[s_path_idx][2][(s_step_idx - horizon + 1): s_step_idx]).to(Config.device)

                    input_states = torch.unsqueeze(s_torch.to(Config.device), dim=0)
                    input_next = torch.unsqueeze(new_s_prime_torch.to(Config.device), dim=0)

                    condition = {0: input_states}
                    next_condition = {0: input_next}

                    samples = next_trainer.ema_model.conditional_sample(condition, returns=returns, next_conditions=next_condition)
                    # The last axis is split into two equal halves plus one
                    # trailing channel, which is used as the reward.
                    dims = samples.shape[2]
                    samples, sample_next, sample_reward = samples[:, :, :((dims - 1) // 2)], samples[:, :, ((dims - 1) // 2):-1], samples[:, :, -1]
                    # Reward taken at index 1 of the sampled sequence.
                    reward_pred = to_np(sample_reward[:, 1])
                    # action_pred = dataset.normlizer.unnormalize(action_pred, "actions")
                    action_pred = to_np(action_pred)
                    action_pred = dataset.normalizer.unnormalize(action_pred, "actions")
                    new_states.append(s)
                    if not new_path:
                        new_path = [np.expand_dims(s, axis=0), action_pred, np.expand_dims(reward_pred, axis=0), np.expand_dims(paths[max_val_path][0][max_val_step], axis=0)]
                    else:
                        new_path[0] = np.concatenate((new_path[0], np.expand_dims(s, axis=0)), axis=0)
                        new_path[1] = np.concatenate((new_path[1], action_pred), axis=0)
                        new_path[2] = np.concatenate((new_path[2], np.expand_dims(reward_pred, axis=0)), axis=0)
                        new_path[3] = np.concatenate((new_path[3], np.expand_dims(paths[max_val_path][0][max_val_step], axis=0)), axis=0)
                    # Continue the walk from the jump target, on its trajectory.
                    s = paths[max_val_path][0][max_val_step]
                    s_prime = paths[max_val_path][3][max_val_step]
                    logger.print("prior to stitching", s_path_idx, s_step_idx)
                    s_path_idx = max_val_path
                    s_step_idx = max_val_step
                    logger.print("stitching happened", s_path_idx, s_step_idx)
                    stitching_num += 1
                else:
                    # Case 4: tests failed -> copy the original transition.
                    new_states.append(s)
                    temp_action = dataset.normalizer.unnormalize(paths[s_path_idx][1][s_step_idx], "actions")
                    if not new_path:
                        new_path = [np.expand_dims(s, axis=0), np.expand_dims(temp_action, axis=0), np.expand_dims(paths[s_path_idx][2][s_step_idx], axis=0), np.expand_dims(paths[s_path_idx][3][s_step_idx], axis=0)]
                    else:
                        new_path[0] = np.concatenate((new_path[0], np.expand_dims(s, axis=0)), axis=0)
                        new_path[1] = np.concatenate((new_path[1], np.expand_dims(temp_action, axis=0)), axis=0)
                        new_path[2] = np.concatenate((new_path[2], np.expand_dims(paths[s_path_idx][2][s_step_idx], axis=0)), axis=0)
                        new_path[3] = np.concatenate((new_path[3], np.expand_dims(paths[s_path_idx][3][s_step_idx], axis=0)), axis=0)
                    s = s_prime

                    logger.print("normal step", s_path_idx, s_step_idx)
                    s_step_idx += 1
                    # Self-assignment: no effect.
                    s_path_idx = s_path_idx
                    if s_step_idx < paths_lengths[s_path_idx]:
                        s_prime = paths[s_path_idx][3][s_step_idx]

            # NOTE(review): if the while loop never ran, new_path is still
            # [] here and the new_path[...] indexing below raises IndexError.
            new_path_reward_sum = np.sum(new_path[2])

            path_reward_sum = np.sum(path[2])



            # Keep the rebuilt trajectory only when its reward sum exceeds
            # the original one by more than `threshold` (relative).
            if new_path_reward_sum > (1 + threshold) * path_reward_sum:
                new_paths.append(new_path)
                print("stitched used")
                if obs_numpy_new.size == 0:
                    obs_numpy_new = new_path[0]
                else:
                    obs_numpy_new = np.concatenate((obs_numpy_new, new_path[0]), axis=0)
            else:
                # NOTE(review): the original `path` is kept here, but the
                # observation pool is still extended with new_path[0] rather
                # than path[0] -- confirm this is intended.
                new_paths.append(path)
                if obs_numpy_new.size == 0:
                    obs_numpy_new = new_path[0]
                else:
                    obs_numpy_new = np.concatenate((obs_numpy_new, new_path[0]), axis=0)

            torch.cuda.empty_cache()
            logger.print("finish one path")
            count += 1
            logger.print("finished number", count)

        paths = new_paths
        obs_numpy = obs_numpy_new
        # paths_numpy = np.array(new_paths)
        logger.print(f"stitching num at current epoch:", stitching_num)
        # Written to <bucket>/<logger.prefix>/checkpoint/dataset<part_num>.dat
        save_data(paths, Config.bucket, part_num)

    # for path in paths:
    #     path = [to_torch(item, device=Config.device) for item in path]
    #     paths_torch.append(path)

    # TODO: retrain the value function uncomment
    # new_dataset = CustomSequenceDataset(paths, dataset.indices)
    # new_dataloader = cycle(torch.utils.data.DataLoader(
    #     new_dataset, batch_size=Config.train_batch_size, num_workers=0, shuffle=True, pin_memory=True
    # ))
    #
    # for step in range(Config.n_train_steps // 10):
    #     for k in range(Config.gradient_accumulate_every):
    #         batch = next(new_dataloader)
    #         batch = [to_torch(item, device=Config.device) for item in batch]
    #         loss, infos = trainer.model.loss(*batch)
    #         loss = loss / Config.gradient_accumulate_every
    #         loss.backward()
    #
    #     trainer.optimizer.step()
    #     trainer.optimizer.zero_grad()
    #
    #     if trainer.step % trainer.update_ema_every == 0:
    #         trainer.step_ema()
    #
    #     if trainer.step % trainer.save_freq == 0:
    #         trainer.save()
    #
    #     if trainer.step % trainer.log_freq == 0:
    #         infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()])
    #         logger.print(f'{trainer.step}: {loss:8.4f} | {infos_str} | t: {timer():8.4f}')
    #         metrics = {k:v.detach().item() for k, v in infos.items()}
    #         metrics['steps'] = trainer.step
    #         metrics['loss'] = loss.detach().item()
    #         logger.log_metrics_summary(metrics, default_stats='mean')
    #
    #     if trainer.step == 0 and trainer.sample_freq:
    #         trainer.render_reference(trainer.n_reference)
    #
    #     trainer.step += 1
    #
    # trainer.step = 0
    #
    # # TODO: create trajectories and retrain value function
    #
    # num_eval = 100
    # device = Config.device
    #
    # env_list = [gym.make(Config.dataset) for _ in range(num_eval)]
    # dones = [0 for _ in range(num_eval)]
    # episode_rewards = [0 for _ in range(num_eval)]
    #
    # assert trainer.ema_model.condition_guidance_w == Config.condition_guidance_w
    # returns = to_device(Config.test_ret * torch.ones(num_eval, 1), device)
    #
    # t = 0
    # obs_list = [env.reset()[None] for env in env_list]
    # obs = np.concatenate(obs_list, axis=0)
    # recorded_obs = [deepcopy(obs[:, None])]
    #
    # reward_list = [[] for i in range(num_eval)]
    # true_list = [[] for i in range(num_eval)]
    #
    # traj_length = 1000
    # state_dim = env_list[0].observation_space.shape[0]
    # action_dim = env_list[0].action_space.shape[0]
    # length = [0 for i in range(num_eval)]
    # traj = [Trajectory(length=traj_length, state_dim=state_dim, action_dim=action_dim) for i in range(num_eval)]
    #
    # while sum(dones) < num_eval:
    #     obs = dataset.normalizer.normalize(obs, 'observations')
    #     conditions = {0: to_torch(obs, device=device)}
    #     samples = trainer.ema_model.conditional_sample(conditions, returns=returns)
    #     samples, sample_reward = samples[:, :, :-1], samples[:, :, -1]
    #     obs_comb = torch.cat([samples[:, 0, :], samples[:, 1, :]], dim=-1)  # TODO:increase window length
    #     obs_comb = obs_comb.reshape(-1, 2 * observation_dim)
    #     action = trainer.ema_model.inv_model(obs_comb)
    #
    #     samples = to_np(samples)
    #     # samples = dataset.normalizer.unnormalize(samples, 'observations')
    #     action = to_np(action)
    #     sample_t = to_np(samples[:, 0, :])
    #
    #     reward_t = to_np(sample_reward[:, 0])
    #     reward_t_1 = to_np(sample_reward[:, 1])
    #
    #     # action = dataset.normalizer.unnormalize(action, 'actions')
    #
    #     # if t == 0:
    #     #     normed_observations = samples[:, :, :]
    #     #     observations = dataset.normalizer.unnormalize(normed_observations, 'observations')
    #     #     savepath = os.path.join('images', 'sample-planned.png')
    #     #     renderer.composite(savepath, observations)
    #
    #     obs_list = []
    #
    #     reward_t = np.array(reward_t)
    #     reward_t = reward_t * 4
    #
    #     for i in range(num_eval):
    #         traj[i].add_item(input_state=sample_t[i, :], input_action=action[i], input_reward=reward_t[i],
    #                          idx=length[i])
    #         length[i] += 1
    #         # this_obs, this_reward, this_done, _ = env_list[i].step(action[i])
    #         # true_list[i].append(this_reward)
    #         # reward_list[i].append(reward_t)
    #         # logger.print(f"R_t: {R_t[i] * 400 * 4}")    #for horizon 100 (100 * 3)
    #         # logger.print(f"R_t_1: {R_t_1[i] * 400 * 4}")
    #         # logger.print(f"this reward {this_reward}")
    #
    #         # obs_list.append(this_obs[None])
    #         if length[i] == traj_length:
    #             dones[i] = 1
    #
    #     # obs = np.concatenate(obs_list, axis=0)
    #     # recorded_obs.append(deepcopy(obs[:, None]))
    #     t += 1
    #
    # save_traj(traj, Config.bucket)
    #
    # # for i in range(num_eval):
    # #     print(f"true {i}", true_list[i])
    # #     print(f"reward {i}", reward_list[i] * 4)
    # #
    # # recorded_obs = np.concatenate(recorded_obs, axis=1)
    # # savepath = os.path.join('images', f'sample-executed.png')
    # # renderer.composite(savepath, recorded_obs)
    # # episode_rewards = np.array(episode_rewards)
    # #
    # # logger.print(f"average_ep_reward: {np.mean(episode_rewards)}, std_ep_reward: {np.std(episode_rewards)}",
    # #              color='green')
    # # logger.log_metrics_summary(
    # #     {'average_ep_reward': np.mean(episode_rewards), 'std_ep_reward': np.std(episode_rewards)})
    # train_val(traj, state_dim, action_dim, device, Config.seed, Config.bucket)
    #
    # save_data(paths, Config.bucket)




