import argparse
import gym
import json
import os
import pickle
import random
import time
import torch

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import wasserstein

from decision_transformer.models.decision_transformer import DistributionalDecisionTransformer


# Global matplotlib configuration for every figure this script saves:
# use the 'ggplot' style sheet, and embed fonts in PDF/PS output as
# Type 42 (TrueType) rather than matplotlib's default Type 3 fonts.
plt.style.use('ggplot')
mpl.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Column index(es) into each environment's observation vector that
# ``experiment`` reads as velocity targets. Element 0 is the column used for
# the 'xvel' condition (x-velocity). For 'ant', element 1 is the additional
# column used for the 'xyvel' condition (y-velocity).
# NOTE(review): presumably these match the default gym ``*-v3`` observation
# layout; confirm if env versions or observation options change.
VELOCITY_DIM = dict(
    halfcheetah=(8,),
    hopper=(5,),
    walker2d=(8,),
    ant=(13, 14),
)


def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along axis 0.

    Computes ``out[-1] = x[-1]`` and ``out[t] = x[t] + gamma * out[t + 1]``,
    i.e. ``out[t] = sum_{k >= t} gamma ** (k - t) * x[k]``.

    Args:
        x: Array-like of per-step values. The recursion runs over the first
            axis; any trailing axes are carried along element-wise.
        gamma: Discount factor applied once per step.

    Returns:
        A new array with the same shape as ``x``. Floating/complex inputs keep
        their dtype. Integer/bool inputs are promoted to float64 so that the
        discounted sums are not truncated. An empty input returns an empty
        array.
    """
    x = np.asarray(x)
    # An integer output buffer would silently truncate ``gamma * out[t + 1]``.
    dtype = np.float64 if x.dtype.kind in 'biu' else x.dtype
    out = np.zeros_like(x, dtype=dtype)
    if x.shape[0] == 0:
        # Nothing to accumulate; avoid IndexError on x[-1].
        return out
    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out


def gauss(x, mu=0, sigma=1):
    """Evaluate the normal density N(mu, sigma**2) at ``x``.

    Both divisions are guarded by a 1e-12 epsilon so ``sigma == 0`` does not
    raise or produce a division-by-zero. ``x`` may be a scalar or a numpy
    array; the computation is element-wise.
    """
    eps = 1e-12
    inv_sqrt_two_pi = 1 / np.sqrt(2 * np.pi)
    coeff = inv_sqrt_two_pi / (sigma + eps)
    exponent = -(x - mu) ** 2 / (2 * sigma ** 2 + eps)
    return coeff * np.exp(exponent)


def experiment(output_dir, eval_dir, variant):
    gpu = variant.get('gpu', 0)
    device = torch.device(
        f"cuda:{gpu}" if (torch.cuda.is_available() and gpu >= 0) else "cpu"
    )

    env_name, dataset = variant['env'], variant['dataset']
    seed = variant['seed']
    dist_dim = variant['dist_dim']
    mode = 'normal'
    n_bins = variant['n_bins']
    distributions = variant['distributions']
    assert distributions in ['categorical', 'gaussian', 'deterministic']
    gamma = variant['gamma']
    if distributions != 'categorical':
        assert gamma == 1.
    condition = variant['condition']
    assert condition in ['reward', 'xvel', 'xyvel']
    shift_const = variant['const_shift']

    if env_name == 'hopper':
        env = gym.make('Hopper-v3')
        eval_env = gym.make('Hopper-v3')
        target_path = 'data/synthesized/hopper_medium_expert_x_vel_synthesized.pkl'
    elif env_name == 'halfcheetah':
        env = gym.make('HalfCheetah-v3')
        eval_env = gym.make('HalfCheetah-v3')
        target_path = 'data/synthesized/half_cheetah_medium_expert_x_vel_synthesized.pkl'
    elif env_name == 'walker2d':
        env = gym.make('Walker2d-v3')
        eval_env = gym.make('Walker2d-v3')
        target_path = 'data/synthesized/walker2d_medium_expert_x_vel_synthesized.pkl'
    elif env_name == 'ant':
        env = gym.make('Ant-v3')
        eval_env = gym.make('Ant-v3')
        target_path = 'data/synthesized/ant_medium_expert_xy_vel_synthesized.pkl'
    else:
        raise NotImplementedError
    vel_dim = VELOCITY_DIM[env_name]
    scale = 1000.
    max_ep_len = 1000
    env.seed(seed)
    eval_env.seed(2 ** 32 - 1 - seed)
    with open(target_path, 'rb') as f:
        synthesized_target = pickle.load(f)

    state_dim = eval_env.observation_space.shape[0]
    act_dim = eval_env.action_space.shape[0]

    if condition == 'reward' or condition == 'xvel':
        if distributions == 'gaussian':
            r_dists_dim = 2
        elif distributions == 'categorical':
            r_dists_dim = dist_dim
        elif distributions == 'deterministic':
            r_dists_dim = 1
    elif condition == 'xyvel':
        if distributions == 'gaussian':
            r_dists_dim = 2 * 2  # 1d gaussian * 2
        elif distributions == 'categorical':
            r_dists_dim = dist_dim * 2  # 1d categorical * 2
        elif distributions == 'deterministic':
            r_dists_dim = 2

    dataset_path = f'data/{env_name}-{dataset}-v2.pkl'

    with open(dataset_path, 'rb') as f:
        trajectories = pickle.load(f)

    if condition == 'reward' or condition == 'xvel':
        states, traj_lens, returns, rewards = [], [], [], []
        for path in trajectories:
            states.append(path['observations'])
            traj_lens.append(len(path['observations']))
            returns.append(path['rewards'].sum())
            if condition == 'reward':
                rewards.extend(path['rewards'])
            elif condition == 'xvel':
                rewards.extend(path['observations'][:, vel_dim[0]])
        traj_lens, returns = np.array(traj_lens), np.array(returns)

        # for categorical distribution matching
        r_min = min(rewards)
        r_max = max(rewards)
        bins = np.linspace(r_min, r_max, n_bins)
        label = [(bins[i]+bins[i+1])/2 for i in range(len(bins)-1)]
        width = bins[1] - bins[0]
        shift_width = shift_const * width
    elif condition == 'xyvel':
        states, traj_lens, returns, xvels, yvels = [], [], [], [], []
        for path in trajectories:
            states.append(path['observations'])
            traj_lens.append(len(path['observations']))
            returns.append(path['rewards'].sum())
            xvels.extend(path['observations'][:, vel_dim[0]])
            yvels.extend(path['observations'][:, vel_dim[1]])
        traj_lens, returns = np.array(traj_lens), np.array(returns)

        # for categorical distribution matching
        r_min = (min(xvels), min(yvels))
        r_max = (max(xvels), max(yvels))
        bins = (np.linspace(r_min[0], r_max[0], n_bins), np.linspace(r_min[1], r_max[1], n_bins))
        label = [[
            (bins[0][i]+bins[0][i+1])/2 for i in range(len(bins[0])-1)],
            [(bins[1][i]+bins[1][i+1])/2 for i in range(len(bins[1])-1)]]
        width = (bins[0][1] - bins[0][0], bins[1][1] - bins[1][0])
        shift_width = (shift_const * width[0], shift_const * width[1])

    # used for input normalization
    states = np.concatenate(states, axis=0)
    state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

    num_timesteps = sum(traj_lens)

    print('=' * 50)
    print(f'Starting new experiment: {env_name} {dataset}')
    print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
    print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
    print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
    print(f'Modality: {condition}')
    print(f'Distribution: {distributions}')
    print(f'Shift width: {shift_width}')
    print('=' * 50)

    K = variant['K']
    batch_size = variant['batch_size']

    print('Preparing empirical distributions.')
    # for evaluation with best/50% trajectories
    _idxes = np.argsort([np.sum(path['rewards']) for path in trajectories]) # rank 0 is the most bad demo.
    trajs_rank = np.empty_like(_idxes)
    trajs_rank[_idxes] = np.arange(len(_idxes))
    n_evals = 5

    r_dists = []
    if condition in ('reward', 'xvel') and distributions == 'categorical':
        for path in trajectories:
            dist = np.zeros(n_bins - 1)
            distributional_rewards = []
            steps_to_go = 0
            if condition == 'reward':
                modality = path['rewards'] + shift_width
            elif condition == 'xvel':
                modality = path['observations'][:, vel_dim[0]] + shift_width
            for t, r in enumerate(reversed(modality)):
                discretized_r = np.histogram(np.clip(r, r_min, r_max), bins=bins)[0]
                steps_to_go *= gamma
                dist *= steps_to_go
                dist = discretized_r + dist
                dist_norm = dist.sum()
                dist /= dist_norm
                steps_to_go += 1
                distributional_rewards.append(dist)
            path['r_dists'] = np.concatenate(distributional_rewards[::-1], axis=0).reshape(-1, n_bins - 1)
            r_dists.append(path['r_dists'])
    elif condition in ('reward', 'xvel') and distributions == 'gaussian':
        for path in trajectories:
            dist_mean = 0
            dist_std = 0
            distributional_rewards = []
            if condition == 'reward':
                modality = path['rewards'] + shift_width
            elif condition == 'xvel':
                modality = path['observations'][:, vel_dim[0]] + shift_width
            for t, r in enumerate(reversed(modality)):
                dist_mean *= max_ep_len
                dist_mean = gamma * dist_mean + r
                dist_std = np.std(modality[-t-1:])
                dist_mean /= max_ep_len
                distributional_rewards.append([dist_mean, dist_std])
            path['r_dists'] = np.array(distributional_rewards[::-1]).reshape(-1, 2)
            r_dists.append(path['r_dists'])
    elif condition == 'xyvel' and distributions == 'categorical':
        for path in trajectories:
            distx = np.zeros(n_bins - 1)
            disty = np.zeros(n_bins - 1)
            distributional_rewardsx = []
            distributional_rewardsy = []
            steps_to_go = 0
            modality = path['observations'][:, vel_dim[0]:vel_dim[1]+1]
            for t, xy in enumerate(reversed(modality)):
                discretized_x = np.histogram(np.clip(xy[0] + shift_width[0], r_min[0], r_max[0]), bins=bins[0])[0]
                discretized_y = np.histogram(np.clip(xy[1] + shift_width[1], r_min[1], r_max[1]), bins=bins[1])[0]
                steps_to_go *= gamma
                distx *= steps_to_go
                disty *= steps_to_go
                distx = discretized_x + distx
                disty = discretized_y + disty
                distx_norm = distx.sum()
                disty_norm = disty.sum()
                distx /= distx_norm
                disty /= disty_norm
                steps_to_go += 1
                distributional_rewardsx.append(distx)
                distributional_rewardsy.append(disty)
            path['r_dists'] = (
                np.concatenate(distributional_rewardsx[::-1], axis=0).reshape(-1, n_bins - 1),
                np.concatenate(distributional_rewardsy[::-1], axis=0).reshape(-1, n_bins - 1))
            r_dists.append(path['r_dists'])
    elif condition == 'xyvel' and distributions == 'gaussian':
        for path in trajectories:
            distx_mean = 0
            disty_mean = 0
            distx_std = 0
            disty_std = 0
            distributional_rewardsx = []
            distributional_rewardsy = []
            modality = path['observations'][:, vel_dim[0]:vel_dim[1]+1]
            for t, xy in enumerate(reversed(modality)):
                # x-vel
                distx_mean *= max_ep_len
                distx_mean = gamma * distx_mean + xy[0] + shift_width[0]
                distx_std = np.std(modality[0, -t-1:])
                distx_mean /= max_ep_len
                distributional_rewards.append([distx_mean, distx_std])
                # y-vel
                disty_mean *= max_ep_len
                disty_mean = gamma * disty_mean + xy[1] + shift_width[1]
                disty_std = np.std(modality[1, -t-1:])
                disty_mean /= max_ep_len
                distributional_rewards.append([disty_mean, disty_std])
                distributional_rewardsx.append(distx)
                distributional_rewardsy.append(disty)
            path['r_dists'] = (
                np.concatenate(distributional_rewardsx[::-1], axis=0).reshape(-1, 2),
                np.concatenate(distributional_rewardsy[::-1], axis=0).reshape(-1, 2))
            r_dists.append(path['r_dists'])
    elif condition in ('reward', 'xvel') and distributions == 'deterministic':
        for path in trajectories:
            dist = 0
            distributional_rewards = []
            if condition == 'reward':
                modality = path['rewards'] + shift_width
            elif condition == 'xvel':
                modality = path['observations'][:, vel_dim[0]] + shift_width
            for t, r in enumerate(reversed(modality)):
                dist *= max_ep_len
                dist = gamma * dist + r
                dist /= max_ep_len
                distributional_rewards.append(dist)
            path['r_dists'] = np.array(distributional_rewards[::-1]).reshape(-1, 1)
            r_dists.append(path['r_dists'])
    elif condition == 'xyvel' and distributions == 'deterministic':
        for path in trajectories:
            distx = 0
            disty = 0
            distributional_rewards = []
            modality = path['observations'][:, vel_dim[0]:vel_dim[1]+1] + shift_width
            for t, xy in enumerate(reversed(modality)):
                distx *= max_ep_len
                disty *= max_ep_len
                distx = xy[0] + gamma * distx
                disty = xy[1] + gamma * disty
                distx /= max_ep_len
                disty /= max_ep_len
                distributional_rewards.append([distx, disty])
            path['r_dists'] = np.array(distributional_rewards[::-1]).reshape(-1, 2)
            r_dists.append(path['r_dists'])
    else:
        raise NotImplementedError
    assert len(trajs_rank) == len(r_dists)
    # train / eval split
    best_trajs = [r_dists[np.where(trajs_rank == len(trajs_rank)-idx-1)[0][0]] for idx in range(n_evals)]  # top-{n_evals}
    middle_trajs = [r_dists[np.where(trajs_rank == int(len(trajs_rank)/2)+idx-2)[0][0]] for idx in range(n_evals)]  # 50%-{n_evals}

    if condition == 'reward':
        best_trajs_all = [
            np.histogram(np.clip(trajectories[np.where(trajs_rank == len(trajs_rank)-idx-1)[0][0]]['rewards'] + shift_width, r_min, r_max), bins=bins)[0].astype(float) for idx in range(n_evals)]
        best_trajs_all = [t/(t.sum() + 1e-12) for t in best_trajs_all]
        middle_trajs_all = [
            np.histogram(np.clip(trajectories[np.where(trajs_rank == int(len(trajs_rank)/2)+idx-2)[0][0]]['rewards'] + shift_width, r_min, r_max), bins=bins)[0].astype(float) for idx in range(n_evals)]
        middle_trajs_all = [t/(t.sum() + 1e-12) for t in middle_trajs_all]
    elif condition == 'xvel':
        best_trajs_all = [
            np.histogram(np.clip(trajectories[np.where(trajs_rank == len(trajs_rank)-idx-1)[0][0]]['observations'][:, vel_dim[0]] + shift_width, r_min, r_max), bins=bins)[0].astype(float) for idx in range(n_evals)]
        best_trajs_all = [t/(t.sum() + 1e-12) for t in best_trajs_all]
        middle_trajs_all = [
            np.histogram(np.clip(trajectories[np.where(trajs_rank == int(len(trajs_rank)/2)+idx-2)[0][0]]['observations'][:, vel_dim[0]] + shift_width, r_min, r_max), bins=bins)[0].astype(float) for idx in range(n_evals)]
        middle_trajs_all = [t/(t.sum() + 1e-12) for t in middle_trajs_all]
    elif condition == 'xyvel':
        best_trajs_all = [
            np.histogram2d(
                x=np.clip(trajectories[np.where(trajs_rank == len(trajs_rank)-idx-1)[0][0]]['observations'][:, vel_dim[0]] + shift_width[0], r_min[0], r_max[0]),
                y=np.clip(trajectories[np.where(trajs_rank == len(trajs_rank)-idx-1)[0][0]]['observations'][:, vel_dim[1]] + shift_width[1], r_min[1], r_max[1]),
                bins=bins)[0].astype(float) for idx in range(n_evals)]
        best_trajs_all = [t/(t.sum() + 1e-12) for t in best_trajs_all]
        best_samplesx = [trajectories[np.where(trajs_rank == len(trajs_rank)-idx-1)[0][0]]['observations'][:, vel_dim[0]] + shift_width[0] for idx in range(n_evals)]
        best_samplesy = [trajectories[np.where(trajs_rank == len(trajs_rank)-idx-1)[0][0]]['observations'][:, vel_dim[1]] + shift_width[1] for idx in range(n_evals)]
        middle_trajs_all = [
            (
                np.histogram(np.clip(trajectories[np.where(trajs_rank == int(len(trajs_rank)/2)+idx-2)[0][0]]['observations'][:, vel_dim[0]] + shift_width[0], r_min[0], r_max[0]), bins=bins[0])[0].astype(float),
                np.histogram(np.clip(trajectories[np.where(trajs_rank == int(len(trajs_rank)/2)+idx-2)[0][0]]['observations'][:, vel_dim[1]] + shift_width[1], r_min[1], r_max[1]), bins=bins[1])[0].astype(float)) for idx in range(n_evals)]
        middle_trajs_all = [(x/(x.sum() + 1e-12), y/(y.sum() + 1e-12)) for x, y in middle_trajs_all]
        middle_trajs_all = [
            np.histogram2d(
                x=np.clip(trajectories[np.where(trajs_rank == int(len(trajs_rank)/2)+idx-2)[0][0]]['observations'][:, vel_dim[0]] + shift_width[0], r_min[0], r_max[0]),
                y=np.clip(trajectories[np.where(trajs_rank == int(len(trajs_rank)/2)+idx-2)[0][0]]['observations'][:, vel_dim[1]] + shift_width[1], r_min[1], r_max[1]),
                bins=bins)[0].astype(float) for idx in range(n_evals)]
        middle_trajs_all = [t/(t.sum() + 1e-12) for t in middle_trajs_all]
        middle_samplesx = [trajectories[np.where(trajs_rank == int(len(trajs_rank)/2)+idx-2)[0][0]]['observations'][:, vel_dim[0]] + shift_width[0] for idx in range(n_evals)]
        middle_samplesy = [trajectories[np.where(trajs_rank == int(len(trajs_rank)/2)+idx-2)[0][0]]['observations'][:, vel_dim[1]] + shift_width[1] for idx in range(n_evals)]
    else:
        raise NotImplementedError

    eval_trajectories = {}
    for i in range(n_evals):
        eval_trajectories['best_traj_{}'.format(i)] = (best_trajs[i], best_trajs_all[i])
        eval_trajectories['middle_traj_{}'.format(i)] = (middle_trajs[i], middle_trajs_all[i])
        if condition == 'xyvel':
            eval_trajectories['best_traj_{}'.format(i)] = (
                best_trajs[i],
                best_trajs_all[i],
                best_samplesx[i],
                best_samplesy[i])
            eval_trajectories['middle_traj_{}'.format(i)] = (
                middle_trajs[i],
                middle_trajs_all[i],
                middle_samplesx[i],
                middle_samplesy[i])
    if condition == 'xvel' and distributions == 'categorical':
        # add synthethized distribution
        for name, syn in zip(synthesized_target.keys(), synthesized_target.values()):
            dist = np.zeros(n_bins - 1)
            distributional_rewards = []
            steps_to_go = 0
            for t, r in enumerate(reversed(syn)):
                discretized_r = np.histogram(np.clip(r, r_min, r_max), bins=bins)[0]
                steps_to_go *= gamma
                dist *= steps_to_go
                dist = discretized_r + dist
                dist_norm = dist.sum()
                dist /= dist_norm
                steps_to_go += 1
                distributional_rewards.append(dist)
            syn_all = np.histogram(np.clip(syn, r_min, r_max), bins=bins)[0].astype(float)
            eval_trajectories[name] = (
                np.concatenate(distributional_rewards[::-1], axis=0).reshape(-1, n_bins - 1),
                syn_all / (syn_all.sum() + 1e-12))
    elif condition == 'xvel' and distributions == 'gaussian':
        # add synthethized distribution
        for name, syn in zip(synthesized_target.keys(), synthesized_target.values()):
            dist_mean = 0
            dist_std = 0
            distributional_rewards = []
            for t, r in enumerate(reversed(syn)):
                dist_mean *= max_ep_len
                dist_mean = gamma * dist_mean + r
                dist_std = np.std(syn[-t-1:])
                dist_mean /= max_ep_len
                distributional_rewards.append([dist_mean, dist_std])
            syn_all = np.histogram(np.clip(syn, r_min, r_max), bins=bins)[0].astype(float)
            eval_trajectories[name] = (
                np.array(distributional_rewards[::-1]).reshape(-1, 2),
                syn_all / (syn_all.sum() + 1e-12))
    elif condition == 'xyvel' and distributions in ('categorical', 'gaussian'):
        for name, syn in zip(synthesized_target.keys(), synthesized_target.values()):
            distx = np.zeros(n_bins - 1)
            distributional_rewardsx = []
            disty = np.zeros(n_bins - 1)
            distributional_rewardsy = []
            steps_to_go = 0
            raw_xvel = []
            raw_yvel = []
            for t, xy in enumerate(reversed(syn)):
                raw_xvel.append(xy[0])
                raw_yvel.append(xy[1])
                discretized_x = np.histogram(np.clip(xy[0], r_min[0], r_max[0]), bins=bins[0])[0]
                discretized_y = np.histogram(np.clip(xy[1], r_min[1], r_max[1]), bins=bins[1])[0]
                steps_to_go *= gamma
                distx *= steps_to_go
                disty *= steps_to_go
                distx = discretized_x + distx
                disty = discretized_y + disty
                distx_norm = distx.sum()
                disty_norm = disty.sum()
                distx /= distx_norm
                disty /= disty_norm
                steps_to_go += 1
                distributional_rewardsx.append(distx)
                distributional_rewardsy.append(disty)
            syn_all = np.histogram2d(
                x=np.clip(syn[:,0], r_min[0], r_max[0]),
                y=np.clip(syn[:,1], r_min[1], r_max[1]),
                bins=bins)[0].astype(float)
            syn_all = syn_all / (syn_all.sum() + 1e-12)
            eval_trajectories[name] = [
                (
                    np.concatenate(distributional_rewardsx[::-1], axis=0).reshape(-1, n_bins - 1),
                    np.concatenate(distributional_rewardsy[::-1], axis=0).reshape(-1, n_bins - 1)),
                syn_all,
                np.array(raw_xvel),
                np.array(raw_yvel),
                ]
    elif condition == 'xvel' and distributions == 'deterministic':
        # add synthethized distribution
        for name, syn in zip(synthesized_target.keys(), synthesized_target.values()):
            dist = 0
            distributional_rewards = []
            for t, r in enumerate(reversed(syn)):
                dist *= max_ep_len
                dist = r + gamma * dist
                dist /= max_ep_len
                distributional_rewards.append(dist)
            syn_all = np.histogram(np.clip(syn, r_min, r_max), bins=bins)[0].astype(float)
            eval_trajectories[name] = (
                np.array(distributional_rewards[::-1]).reshape(-1, 1),
                syn_all / (syn_all.sum() + 1e-12))
    elif condition == 'xyvel' and distributions == 'deterministic':
        # add synthethized distribution
        for name, syn in zip(synthesized_target.keys(), synthesized_target.values()):
            distx = 0
            disty = 0
            distributional_rewards = []
            raw_xvel = []
            raw_yvel = []
            for t, xy in enumerate(reversed(syn)):
                raw_xvel.append(xy[0])
                raw_yvel.append(xy[1])
                distx *= max_ep_len
                disty *= max_ep_len
                distx = xy[0] + gamma * distx
                disty = xy[1] + gamma * disty
                distx /= max_ep_len
                disty /= max_ep_len
                distributional_rewards.append([distx, disty])
            syn_all = np.histogram2d(
                x=np.clip(syn[:,0], r_min[0], r_max[0]),
                y=np.clip(syn[:,1], r_min[1], r_max[1]),
                bins=bins)[0].astype(float)
            syn_all = syn_all / (syn_all.sum() + 1e-12)
            eval_trajectories[name] = [
                np.array(distributional_rewards[::-1]).reshape(-1, 2),
                syn_all,
                np.array(raw_xvel),
                np.array(raw_yvel),
                ]
    elif condition == 'reward':
        pass

    model = DistributionalDecisionTransformer(
        state_dim=state_dim,
        act_dim=act_dim,
        max_length=K,
        max_ep_len=max_ep_len,
        hidden_size=variant['embed_dim'],
        dist_dim=r_dists_dim,
        n_layer=variant['n_layer'],
        n_head=variant['n_head'],
        n_inner=4*variant['embed_dim'],
        activation_function=variant['activation_function'],
        n_positions=1024,
        resid_pdrop=variant['dropout'],
        attn_pdrop=variant['dropout'],
    )

    model = model.to(device=device)

    print('Starting evaluation loop.')
    model.eval()

    state_mean = torch.from_numpy(state_mean).to(device=device)
    state_std = torch.from_numpy(state_std).to(device=device)

    # for itr in range(variant['max_iters']):
    for itr in [9]:
        outputs = dict()
        model.load_state_dict(torch.load(os.path.join(output_dir, f'dt_{itr}.pth'), map_location=device))
        eval_start = time.time()
        for k, v in eval_trajectories.items():
            returns, traj_rewards = [], []
            traj_rewardsy = []
            for _ in range(variant['num_eval_episodes']):

                state = eval_env.reset()
                states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
                actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
                rewards = torch.zeros(0, device=device, dtype=torch.float32)
                input_distribution = v[0]

                if condition == 'xyvel' and distributions in ('categorical', 'gaussian'):
                    next_target = (input_distribution[0][0], input_distribution[1][0])
                elif condition == 'xyvel' and distributions == 'deterministic':
                    next_target = (input_distribution[0][0], input_distribution[0][1])
                else:
                    next_target = input_distribution[0]
                    # for avoiding error
                    _next_target = []
                    if condition == 'xvel' and distributions == 'gaussian':
                        v1 = next_target[0]
                        if type(v1) != np.float64:
                            _next_target.append(v1[0])
                        else:
                            _next_target.append(v1)
                        v2 = next_target[1]
                        if type(v2) != np.float64:
                            _next_target.append(v2[0])
                        else:
                            _next_target.append(v2)
                        next_target = np.array(_next_target).astype(float)

                if condition in ('reward', 'xvel') and distributions in ('categorical', 'gaussian'):
                    target_distributions = torch.from_numpy(
                        next_target).to(device=device, dtype=torch.float32).reshape(1, r_dists_dim)
                elif condition in ('reward', 'xvel') and distributions == 'deterministic':
                    target_distributions = torch.from_numpy(
                        next_target).to(device=device, dtype=torch.float32).reshape(1, 1)
                elif condition == 'xyvel':
                    if distributions == 'gaussian':
                        target_distributions = torch.from_numpy(
                            np.concatenate([next_target[0], next_target[1]]).reshape(1, r_dists_dim)
                            ).to(device=device, dtype=torch.float32).reshape(1, r_dists_dim)
                    elif distributions == 'deterministic':
                        target_distributions = torch.tensor(
                            [next_target[0], next_target[1]]
                            ).to(device=device, dtype=torch.float32).reshape(1, r_dists_dim)
                    else:
                        target_distributions = torch.from_numpy(
                            np.concatenate([next_target[0], next_target[1]]).reshape(1, r_dists_dim)
                            ).to(device=device, dtype=torch.float32).reshape(1, r_dists_dim)

                timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

                # dummy
                target_return = torch.tensor(0, device=device, dtype=torch.float32).reshape(1, 1)

                sim_states = []

                episode_return, episode_length = 0, 0
                if condition == 'xyvel' and distributions in ('categorical', 'gaussian'):
                    max_traj_len = len(input_distribution[0])
                else:
                    max_traj_len = len(input_distribution)

                for t in range(max_traj_len):
                    # add padding
                    actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
                    rewards = torch.cat([rewards, torch.zeros(1, device=device)])
                    action = model.get_action(
                        (states.to(dtype=torch.float32) - state_mean) / state_std,
                        actions.to(dtype=torch.float32),
                        rewards.to(dtype=torch.float32),
                        target_return.to(dtype=torch.float32),
                        timesteps.to(dtype=torch.long),
                        target_distributions.to(dtype=torch.float32),
                    )
                    actions[-1] = action
                    action = action.detach().cpu().numpy()

                    state, reward, done, _ = eval_env.step(action)

                    if condition == 'xvel':
                        traj_rewards.append(state[vel_dim[0]])
                    elif condition == 'reward':
                        traj_rewards.append(reward)
                    elif condition == 'xyvel':
                        traj_rewards.append(state[vel_dim[0]])
                        traj_rewardsy.append(state[vel_dim[1]])
                    cur_state = torch.from_numpy(state).to(device=device, dtype=torch.float32).reshape(1, state_dim)
                    states = torch.cat([states, cur_state], dim=0)
                    rewards[-1] = reward

                    if t < max_traj_len - 1:
                        assert variant['eval_target'] != 'update'  # 'update' is deplicated
                        if condition in ('reward', 'xvel') and distributions in ('categorical', 'gaussian'):
                            # update target distribution
                            if variant['eval_target'] == 'update':
                                discretized_r = np.histogram(np.clip(reward, r_max, r_min), bins=bins)[0]
                                next_target = next_target * (max_traj_len - t) - discretized_r
                                next_target = np.clip(next_target, 0., None)
                                next_target /= (next_target.sum() + 1e-12)
                            elif variant['eval_target'] == 'fix':
                                next_target = input_distribution[t+1]

                            target = next_target
                            # for avoiding error
                            _next_target = []
                            if condition == 'xvel' and distributions == 'gaussian':
                                v1 = next_target[0]
                                if type(v1) != np.float64:
                                    _next_target.append(v1[0])
                                else:
                                    _next_target.append(v1)
                                v2 = next_target[1]
                                if type(v2) != np.float64:
                                    _next_target.append(v2[0])
                                else:
                                    _next_target.append(v2)
                                target = np.array(_next_target).astype(float)
                        elif condition in ('reward', 'xvel') and distributions == 'deterministic':
                            # update target distribution
                            if variant['eval_target'] == 'update':
                                next_target = (next_target * max_ep_len - reward) / max_ep_len
                            elif variant['eval_target'] == 'fix':
                                next_target = input_distribution[t+1]
                            target = next_target
                        elif condition in 'xyvel':
                            if variant['eval_target'] == 'update':
                                raise NotImplementedError
                            elif variant['eval_target'] == 'fix' and distributions in ('categorical', 'gaussian'):
                                next_target = (input_distribution[0][t+1], input_distribution[1][t+1])
                                target = np.concatenate([next_target[0], next_target[1]]).reshape(1, r_dists_dim)
                            elif variant['eval_target'] == 'fix' and distributions == 'deterministic':
                                target = np.array(input_distribution[t+1])

                        target_distributions = torch.cat(
                            [
                                target_distributions,
                                torch.from_numpy(target).reshape(1, r_dists_dim).to(device=device, dtype=torch.float32)
                                ],
                            dim=1
                        )

                    timesteps = torch.cat(
                        [timesteps,
                        torch.ones((1, 1), device=device, dtype=torch.long) * (t+1)], dim=1)

                    episode_return += reward
                    episode_length += 1

                    if done:
                        break
                returns.append(episode_return)

            # evaluation
            target_all = v[1]
            wsd = wasserstein.EMD()
            if condition in ('reward', 'xvel'):
                all_reward_distribution = np.histogram(np.clip(np.array(traj_rewards), r_min, r_max), bins=bins)[0].astype(float)
                all_reward_distribution /= (all_reward_distribution.sum() + 1e-12)
                distance = wsd(
                    target_all,
                    np.array(label).reshape(-1, 1),
                    all_reward_distribution,
                    np.array(label).reshape(-1, 1)
                    )
                plt.bar(label, target_all, width, color='dodgerblue', alpha=0.5, label='target')
                plt.bar(label, all_reward_distribution, width, color='tomato', alpha=0.5, label='rollout')
                plt.legend()
                if condition == 'reward':
                    xlabel = 'Reward'
                elif condition == 'xvel':
                    xlabel = 'x-Velocity'
                plt.xlabel(xlabel)
                plt.ylabel('Probability')
                plt.title('Distance={:.5f}'.format(distance))
                plt.savefig(os.path.join(eval_dir, f'categorical_{k}_{itr}.pdf'), dpi=300)
                plt.close()
                if condition == 'reward':
                    outputs[f'target_{k}_w_dis_r'] = distance
                elif condition == 'xvel':
                    outputs[f'target_{k}_w_dis_x'] = distance
                outputs[f'target_{k}_return'] = np.mean(returns)
                if (itr == variant['max_iters'] - 1) and variant['save_rollout']:
                    np.save(os.path.join(eval_dir, f'rollout_{k}_{itr}.npy'), traj_rewards)
                    np.save(os.path.join(eval_dir, f'target_{k}_{itr}.npy'), target_all)
            elif condition == 'xyvel':
                # xvelocity
                all_reward_distribution = np.histogram(np.clip(np.array(traj_rewards), r_min[0], r_max[0]), bins=bins[0])[0].astype(float)
                all_reward_distribution /= (all_reward_distribution.sum() + 1e-12)
                target_xvel_distribution = np.histogram(np.clip(v[2], r_min[0], r_max[0]), bins=bins[0])[0].astype(float)
                target_xvel_distribution /= (target_xvel_distribution.sum() + 1e-12)
                xdistance = wsd(
                    target_xvel_distribution,
                    np.array(label[0]).reshape(-1, 1),
                    all_reward_distribution,
                    np.array(label[0]).reshape(-1, 1)
                    )
                plt.bar(label[0], target_xvel_distribution, width[0], color='dodgerblue', alpha=0.5, label='target')
                plt.bar(label[0], all_reward_distribution, width[0], color='tomato', alpha=0.5, label='rollout')
                plt.legend()
                plt.xlabel('x-Velocity')
                plt.ylabel('Probability')
                plt.title('Distance={:.5f}'.format(xdistance))
                plt.savefig(os.path.join(eval_dir, f'categorical_x_{k}_{itr}.pdf'), dpi=300)
                plt.close()

                # yvelocity
                all_reward_distribution = np.histogram(np.clip(np.array(traj_rewardsy), r_min[1], r_max[1]), bins=bins[1])[0].astype(float)
                all_reward_distribution /= (all_reward_distribution.sum() + 1e-12)
                target_yvel_distribution = np.histogram(np.clip(v[3], r_min[1], r_max[1]), bins=bins[1])[0].astype(float)
                target_yvel_distribution /= (target_yvel_distribution.sum() + 1e-12)
                ydistance = wsd(
                    target_yvel_distribution,
                    np.array(label[1]).reshape(-1, 1),
                    all_reward_distribution,
                    np.array(label[1]).reshape(-1, 1)
                    )
                plt.bar(label[1], target_yvel_distribution, width[1], color='dodgerblue', alpha=0.5, label='target')
                plt.bar(label[1], all_reward_distribution, width[1], color='tomato', alpha=0.5, label='rollout')
                plt.legend()
                plt.xlabel('y-Velocity')
                plt.ylabel('Probability')
                plt.title('Distance={:.5f}'.format(ydistance))
                plt.savefig(os.path.join(eval_dir, f'categorical_y_{k}_{itr}.pdf'), dpi=300)
                plt.close()

                # 2D
                plt.scatter(v[2], v[3], color='dodgerblue', alpha=0.5)
                plt.xlabel('x-Velocity')
                plt.ylabel('y-Velocity')
                plt.savefig(os.path.join(eval_dir, f'categorical_xy_target_{k}_{itr}.pdf'), dpi=300)
                plt.close()
                plt.scatter(np.array(traj_rewards), np.array(traj_rewardsy), color='tomato', alpha=0.05)
                plt.xlabel('x-Velocity')
                plt.ylabel('y-Velocity')
                plt.savefig(os.path.join(eval_dir, f'categorical_xy_rollout_{k}_{itr}.pdf'), dpi=300)
                plt.close()

                label_2d = np.array([[label[0][i], label[1][j]] for j in range(len(label[1])) for i in range(len(label[0]))]).reshape(-1, 2)
                all_reward_distribution = np.histogram2d(
                    x=np.clip(np.array(traj_rewards), r_min[0], r_max[0]),
                    y=np.clip(np.array(traj_rewardsy), r_min[1], r_max[1]),
                    bins=bins)[0].astype(float)
                all_reward_distribution /= (all_reward_distribution.sum() + 1e-12)
                all_reward_distribution = all_reward_distribution.reshape(-1, )
                target_all = target_all.reshape(-1, )

                distance = wsd(
                    target_all,
                    label_2d,
                    all_reward_distribution,
                    label_2d
                    )

                if (itr == variant['max_iters'] - 1) and variant['save_rollout']:
                    rollout = np.concatenate(
                        [np.array(traj_rewards).reshape(-1, 1), np.array(traj_rewardsy).reshape(-1, 1)], axis=-1).reshape(-1, 2)
                    target_all = np.concatenate([v[2].reshape(-1, 1), v[3].reshape(-1, 1)], axis=-1).reshape(-1, 2)
                    np.save(os.path.join(eval_dir, f'rollout_{k}_{itr}.npy'), rollout)
                    np.save(os.path.join(eval_dir, f'target_{k}_{itr}.npy'), target_all)
                outputs[f'target_{k}_return'] = np.mean(returns)
                outputs[f'target_{k}_w_dis_x'] = xdistance
                outputs[f'target_{k}_w_dis_y'] = ydistance
                outputs[f'target_{k}_w_dis_xy'] = distance
            else:
                raise NotImplementedError
        # record training loss, etc...
        outputs['time/evaluation'] = time.time() - eval_start

        print('=' * 80)
        print(f'Iteration {itr}')
        for k, v in outputs.items():
            print(f'{k}: {v}')

        _record_values = [itr]
        if itr == 0:
            _basic_columns = ['iter']
            for k, v in outputs.items():
                _basic_columns.append(k)
                _record_values.append(v)
            with open(os.path.join(eval_dir, "eval_log.txt"), "w") as f:
                print("\t".join(_basic_columns), file=f)
            with open(os.path.join(eval_dir, "eval_log.txt"), "a+") as f:
                print("\t".join(str(x) for x in _record_values), file=f)
        else:
            for v in outputs.values():
                _record_values.append(v)
            with open(os.path.join(eval_dir, "eval_log.txt"), "a+") as f:
                print("\t".join(str(x) for x in _record_values), file=f)


def str2bool(value):
    """Parse a command-line string into a bool for argparse.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (so ``--flag False`` would enable the flag). This parser
    accepts the usual spellings explicitly and rejects anything else.

    Args:
        value: the raw string from the command line (a bool is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # '' is kept falsy to match the previous bool('') behaviour.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='halfcheetah')
    parser.add_argument('--dataset', type=str, default='medium-expert')
    # Allowed values mirror the assertion in experiment().
    parser.add_argument('--condition', type=str, default='reward',
                        choices=['reward', 'xvel', 'xyvel'])
    parser.add_argument('--K', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--embed_dim', type=int, default=128)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=1)
    parser.add_argument('--activation_function', type=str, default='relu')
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4)
    parser.add_argument('--warmup_steps', type=int, default=10000)
    parser.add_argument('--max_iters', type=int, default=10)
    parser.add_argument('--num_steps_per_iter', type=int, default=10000)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--dist_dim', type=int, default=30)
    parser.add_argument('--n_bins', type=int, default=31)
    # Allowed values mirror the assertion in experiment().
    parser.add_argument('--distributions', type=str, default='categorical',
                        choices=['categorical', 'gaussian', 'deterministic'])
    parser.add_argument('--gamma', type=float, default=1.00)
    # for eval
    parser.add_argument('--num_eval_episodes', type=int, default=20)
    # 'fix' replays the target distribution; 'update' shifts it by observed reward.
    parser.add_argument('--eval_target', type=str, default='fix',
                        choices=['fix', 'update'])
    parser.add_argument('--const_shift', type=float, default=0.0)  # [-3, -2, -1, 0, 1, 2, 3]
    # str2bool instead of type=bool: bool('False') is True. A bare
    # `--save_rollout` (no value) also means True.
    parser.add_argument('--save_rollout', type=str2bool, nargs='?',
                        const=True, default=False)

    args = parser.parse_args()

    # random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # log dir: must match the directory name produced by the training run.
    save_dir = f'{args.env}-{args.dataset}-{args.distributions}-dim_{args.dist_dim}-bin_{args.n_bins}-gamma_{args.gamma}-{args.condition}-ctx_{args.K}-seed_{args.seed}'
    print(save_dir)
    output_dir = os.path.join('./results', save_dir)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(output_dir):
        parser.error(f'results directory not found: {output_dir}')

    eval_dir = os.path.join(output_dir, f'eval-{args.eval_target}_shift-{args.const_shift}_wasserstein')
    print(eval_dir)

    os.makedirs(eval_dir, exist_ok=True)

    with open(os.path.join(eval_dir, 'params_eval.json'), mode="w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=4)

    experiment(output_dir, eval_dir, variant=vars(args))