import argparse
import gym
import json
import os
import pickle
import random
import torch

import numpy as np

from decision_transformer.models.decision_transformer import DistributionalDecisionTransformer
from decision_transformer.training.seq_trainer import DistributionalSequenceTrainer


# Observation indices of the velocity signals per environment. Element 0 is
# read as the x-velocity ('xvel' / 'xyvel' conditioning); only 'ant' provides
# a second index, which is read as the y-velocity for 'xyvel'.
VELOCITY_DIM = dict(
    halfcheetah=(8,),
    hopper=(5,),
    walker2d=(8,),
    ant=(13, 14),
)


def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along axis 0.

    ``out[-1] = x[-1]`` and ``out[t] = x[t] + gamma * out[t + 1]``, i.e.
    ``out[t] = sum_k gamma**k * x[t + k]``.

    Args:
        x: array (or array-like) whose first axis is time; trailing axes are
            carried along element-wise.
        gamma: discount factor applied per step.

    Returns:
        Array with the same shape as ``x``. Floating/complex inputs keep their
        dtype; integer and boolean inputs are promoted to float64 so the
        discounted values are not truncated. An empty input yields an empty
        result.
    """
    x = np.asarray(x)
    # zeros_like(x) alone would inherit an integer dtype and truncate results.
    dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    out = np.zeros_like(x, dtype=dtype)
    if x.shape[0] == 0:
        return out
    out[-1] = x[-1]
    for t in range(x.shape[0] - 2, -1, -1):
        out[t] = x[t] + gamma * out[t + 1]
    return out


def _condition_values(path, condition, vel_dim):
    """Return the per-step signal a trajectory is conditioned on.

    Args:
        path: one trajectory dict from the dataset pickle; its ``'rewards'``
            or ``'observations'`` entry is read.
        condition: ``'reward'``, ``'xvel'`` or ``'xyvel'``.
        vel_dim: observation indices of the velocity entries (a
            ``VELOCITY_DIM`` value).

    Returns:
        Array of shape ``(T, C)``: ``C == 1`` for ``'reward'``/``'xvel'`` and
        ``C == 2`` (x column, then y column) for ``'xyvel'``.
    """
    if condition == 'reward':
        return np.asarray(path['rewards']).reshape(-1, 1)
    if condition == 'xvel':
        return path['observations'][:, [vel_dim[0]]]
    if condition == 'xyvel':
        return path['observations'][:, [vel_dim[0], vel_dim[1]]]
    raise NotImplementedError(condition)


def _running_categorical_dists(values, bins, v_min, v_max, gamma):
    """Discounted empirical distribution of ``values[t:]`` for every step ``t``.

    Row ``t`` is the histogram over ``bins`` of ``values[t], values[t + 1], ...``
    weighted ``1, gamma, gamma**2, ...`` and normalized to sum to one.

    Args:
        values: 1-D sequence of per-step values.
        bins: histogram edges (``len(bins) - 1`` cells).
        v_min, v_max: values are clipped into this range before binning.
        gamma: per-step discount of future values.

    Returns:
        Float array of shape ``(len(values), len(bins) - 1)``.
    """
    n_cells = len(bins) - 1
    out = np.empty((len(values), n_cells))
    # Unnormalized discounted counts of values[t:]. Rows are copied into
    # `out`, so later updates never alter rows that were already stored.
    acc = np.zeros(n_cells)
    for t in reversed(range(len(values))):
        onehot = np.histogram(np.clip(values[t], v_min, v_max), bins=bins)[0]
        acc = onehot + gamma * acc
        out[t] = acc / acc.sum()
    return out


def _gaussian_moments(probs, centers):
    """Mean and std of categorical distributions whose cells sit at ``centers``.

    Args:
        probs: ``(T, n_cells)`` probabilities.
        centers: ``(n_cells,)`` cell center values.

    Returns:
        ``(T, 2)`` array with columns ``[mean, std]``.
    """
    mean = (probs * centers).sum(axis=-1, keepdims=True)
    std = np.sqrt((((centers - mean) ** 2) * probs).sum(axis=-1, keepdims=True))
    return np.concatenate([mean, std], axis=-1)


def _trajectory_targets(values, distributions, bins, v_min, v_max, centers, gamma, max_ep_len):
    """Per-step conditioning targets for one trajectory.

    Args:
        values: ``(T, C)`` array from ``_condition_values``.
        distributions: ``'categorical'``, ``'gaussian'`` or ``'deterministic'``.
        bins, centers: per-channel histogram edges and cell centers.
        v_min, v_max: per-channel clipping range.
        gamma: discount factor.
        max_ep_len: divisor for the deterministic discounted sums.

    Returns:
        ``(T, D)`` float array. Per channel the columns are the cell
        probabilities (categorical), ``[mean, std]`` (gaussian) or the
        discounted sum of future values divided by ``max_ep_len``
        (deterministic). Channels are concatenated in order (x before y).
    """
    if distributions == 'deterministic':
        return discount_cumsum(np.asarray(values, dtype=np.float64), gamma) / max_ep_len
    per_channel = []
    for c in range(values.shape[1]):
        probs = _running_categorical_dists(values[:, c], bins[c], v_min[c], v_max[c], gamma)
        if distributions == 'gaussian':
            probs = _gaussian_moments(probs, centers[c])
        per_channel.append(probs)
    return np.concatenate(per_channel, axis=-1)


def experiment(output_dir, variant):
    """Train a distributional Decision Transformer and log per-iteration stats.

    Loads ``data/{env}-{dataset}-v2.pkl`` and builds per-step conditioning
    targets: reward, x-velocity or x/y-velocity, represented as categorical,
    gaussian or deterministic. It holds out the 5 highest-return trajectories
    and 5 ranked around the median return, and trains on the rest for
    ``variant['max_iters']`` iterations. The trainer's outputs are appended to
    ``{output_dir}/train_log.txt``, and the model state is optionally saved
    each iteration.

    Args:
        output_dir: existing directory for logs and checkpoints.
        variant: dict of hyper-parameters (the parsed command-line args).

    Raises:
        ValueError: if ``dist_dim != n_bins - 1`` for categorical/gaussian
            targets, if ``'xyvel'`` is requested for an environment without a
            y-velocity index, or if there are too few trajectories to hold out
            the evaluation set.
        NotImplementedError: for an unknown environment name.
    """
    gpu = variant.get('gpu', 0)
    device = torch.device(
        f"cuda:{gpu}" if (torch.cuda.is_available() and gpu >= 0) else "cpu"
    )

    env_name, dataset = variant['env'], variant['dataset']
    seed = variant['seed']
    dist_dim = variant['dist_dim']
    n_bins = variant['n_bins']
    distributions = variant['distributions']
    assert distributions in ['categorical', 'gaussian', 'deterministic']
    gamma = variant['gamma']
    if distributions != 'categorical':
        assert gamma == 1.
    condition = variant['condition']
    assert condition in ['reward', 'xvel', 'xyvel']
    if distributions in ('categorical', 'gaussian') and dist_dim != n_bins - 1:
        # The histograms have n_bins - 1 cells, while the model and the
        # batches are sized with dist_dim; the two must agree.
        raise ValueError(
            f'dist_dim ({dist_dim}) must equal n_bins - 1 ({n_bins - 1}) '
            f'for {distributions} targets'
        )

    gym_ids = {
        'hopper': 'Hopper-v3',
        'halfcheetah': 'HalfCheetah-v3',
        'walker2d': 'Walker2d-v3',
        'ant': 'Ant-v3',
    }
    if env_name not in gym_ids:
        raise NotImplementedError
    env = gym.make(gym_ids[env_name])
    eval_env = gym.make(gym_ids[env_name])
    vel_dim = VELOCITY_DIM[env_name]
    if condition == 'xyvel' and len(vel_dim) < 2:
        raise ValueError(f"condition 'xyvel' needs a y-velocity index, which {env_name} does not have")
    scale = 1000.
    max_ep_len = 1000
    env.seed(seed)
    eval_env.seed(2 ** 32 - 1 - seed)

    state_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # Width of the per-step target: one group of columns per conditioned channel.
    n_channels = 2 if condition == 'xyvel' else 1
    per_channel_dim = {'gaussian': 2, 'categorical': dist_dim, 'deterministic': 1}[distributions]
    r_dists_dim = n_channels * per_channel_dim

    dataset_path = f'data/{env_name}-{dataset}-v2.pkl'

    # NOTE: pickle can execute arbitrary code; only load trusted dataset files.
    with open(dataset_path, 'rb') as f:
        trajectories = pickle.load(f)

    states, traj_lens, returns, cond_values = [], [], [], []
    for path in trajectories:
        states.append(path['observations'])
        traj_lens.append(len(path['observations']))
        returns.append(path['rewards'].sum())
        cond_values.append(_condition_values(path, condition, vel_dim))
    traj_lens, returns = np.array(traj_lens), np.array(returns)

    # Per-channel range of the conditioned signal; it places the histogram bins
    # used for categorical distribution matching.
    all_values = np.concatenate(cond_values, axis=0)
    v_min, v_max = all_values.min(axis=0), all_values.max(axis=0)
    bins = [np.linspace(lo, hi, n_bins) for lo, hi in zip(v_min, v_max)]
    centers = [(edges[:-1] + edges[1:]) / 2 for edges in bins]

    # used for input normalization
    states = np.concatenate(states, axis=0)
    state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

    num_timesteps = sum(traj_lens)

    print('=' * 50)
    print(f'Starting new experiment: {env_name} {dataset}')
    print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
    print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
    print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
    print(f'Modality: {condition}')
    print(f'Distribution: {distributions}')
    print('=' * 50)

    K = variant['K']
    batch_size = variant['batch_size']

    print('Preparing empirical distributions.')
    for path, values in zip(trajectories, cond_values):
        path['r_dists'] = _trajectory_targets(
            values, distributions, bins, v_min, v_max, centers, gamma, max_ep_len)
    # Undiscounted return-to-go per trajectory, computed once instead of once
    # per sampled window.
    returns_to_go = [discount_cumsum(path['rewards'], gamma=1.) for path in trajectories]

    # Train / eval split: hold out the n_evals best trajectories and the
    # n_evals ranked around the median return.
    n_evals = 5
    ranked = np.argsort(returns)  # ranked[0] is the lowest-return demo
    n_trajs = len(ranked)
    if n_trajs <= 2 * n_evals:
        raise ValueError(f'need more than {2 * n_evals} trajectories to hold out evaluation ones, got {n_trajs}')
    best = [int(ranked[n_trajs - 1 - i]) for i in range(n_evals)]
    middle = [int(ranked[n_trajs // 2 - n_evals // 2 + i]) for i in range(n_evals)]
    held_out = set(best + middle)  # a set also tolerates overlap between the two groups
    train_indices = np.array([i for i in range(n_trajs) if i not in held_out])

    def get_batch(batch_size=256, max_len=K):
        """Sample ``batch_size`` windows of up to ``max_len`` steps from training trajectories.

        Windows shorter than ``max_len`` are left-padded. The padding is zeros
        for states (before normalization), rewards, return-to-go, timesteps and
        targets; -10 for actions; 2 for dones; and 0 in ``mask``.
        """
        batch_inds = np.random.choice(
            train_indices,
            size=batch_size,
            replace=True,
        )
        s, a, r, d, rtg, timesteps, mask, dist = [], [], [], [], [], [], [], []
        for i in range(batch_size):
            traj_idx = int(batch_inds[i])
            traj = trajectories[traj_idx]
            si = random.randint(0, traj['rewards'].shape[0] - 1)

            s.append(traj['observations'][si:si + max_len].reshape(1, -1, state_dim))
            tlen = s[-1].shape[1]
            a.append(traj['actions'][si:si + max_len].reshape(1, -1, act_dim))
            r.append(traj['rewards'][si:si + max_len].reshape(1, -1, 1))
            if 'terminals' in traj:
                d.append(traj['terminals'][si:si + max_len].reshape(1, -1))
            else:
                d.append(traj['dones'][si:si + max_len].reshape(1, -1))
            timesteps.append(np.arange(si, si + tlen).reshape(1, -1))
            timesteps[-1][timesteps[-1] >= max_ep_len] = max_ep_len - 1  # padding cutoff
            # One extra return-to-go entry past the window; zero if the episode ended.
            rtg.append(returns_to_go[traj_idx][si:si + tlen + 1].reshape(1, -1, 1))
            if rtg[-1].shape[1] <= tlen:
                rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)
            dist.append(traj['r_dists'][si:si + max_len].reshape(1, -1, r_dists_dim))

            pad = max_len - tlen
            s[-1] = np.concatenate([np.zeros((1, pad, state_dim)), s[-1]], axis=1)
            s[-1] = (s[-1] - state_mean) / state_std
            a[-1] = np.concatenate([np.ones((1, pad, act_dim)) * -10., a[-1]], axis=1)
            r[-1] = np.concatenate([np.zeros((1, pad, 1)), r[-1]], axis=1)
            d[-1] = np.concatenate([np.ones((1, pad)) * 2, d[-1]], axis=1)
            rtg[-1] = np.concatenate([np.zeros((1, pad, 1)), rtg[-1]], axis=1)
            timesteps[-1] = np.concatenate([np.zeros((1, pad)), timesteps[-1]], axis=1)
            mask.append(np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1))
            dist[-1] = np.concatenate([np.zeros((1, pad, r_dists_dim)), dist[-1]], axis=1)

        s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device)
        a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)
        r = torch.from_numpy(np.concatenate(r, axis=0)).to(dtype=torch.float32, device=device)
        d = torch.from_numpy(np.concatenate(d, axis=0)).to(dtype=torch.long, device=device)
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device) / scale
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)
        dist = torch.from_numpy(np.concatenate(dist, axis=0)).to(dtype=torch.float32, device=device)

        return s, a, r, d, rtg, timesteps, mask, dist

    model = DistributionalDecisionTransformer(
        state_dim=state_dim,
        act_dim=act_dim,
        max_length=K,
        max_ep_len=max_ep_len,
        hidden_size=variant['embed_dim'],
        dist_dim=r_dists_dim,
        n_layer=variant['n_layer'],
        n_head=variant['n_head'],
        n_inner=4*variant['embed_dim'],
        activation_function=variant['activation_function'],
        n_positions=1024,
        resid_pdrop=variant['dropout'],
        attn_pdrop=variant['dropout'],
    )

    model = model.to(device=device)

    warmup_steps = variant['warmup_steps']
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=variant['learning_rate'],
        weight_decay=variant['weight_decay'],
    )
    # Linear learning-rate warmup over `warmup_steps`, then constant.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda steps: min((steps+1)/warmup_steps, 1)
    )

    trainer = DistributionalSequenceTrainer(
        model=model,
        optimizer=optimizer,
        batch_size=batch_size,
        get_batch=get_batch,
        scheduler=scheduler,
        loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a)**2),
        eval_fns=None,
    )

    log_path = os.path.join(output_dir, "train_log.txt")
    print('Starting training loop.')
    for itr in range(variant['max_iters']):
        outputs = trainer.train_only_iteration(num_steps=variant['num_steps_per_iter'], iter_num=itr+1, print_logs=True)
        if variant['save_model']:
            torch.save(model.state_dict(), os.path.join(output_dir, f'dt_{itr}.pth'))
        # Tab-separated training log: the header is (re)written on the first
        # iteration, then one row is appended per iteration.
        if itr == 0:
            with open(log_path, "w") as f:
                print("\t".join(['iter', *outputs.keys()]), file=f)
        with open(log_path, "a+") as f:
            print("\t".join(str(x) for x in [itr, *outputs.values()]), file=f)


def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='halfcheetah')
    parser.add_argument('--dataset', type=str, default='medium-expert')
    parser.add_argument('--K', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--embed_dim', type=int, default=128)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=1)
    parser.add_argument('--activation_function', type=str, default='relu')
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4)
    parser.add_argument('--warmup_steps', type=int, default=10000)
    parser.add_argument('--max_iters', type=int, default=10)
    parser.add_argument('--num_steps_per_iter', type=int, default=10000)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--dist_dim', type=int, default=30)
    parser.add_argument('--n_bins', type=int, default=31)
    parser.add_argument('--gamma', type=float, default=1.00)
    # Accepts `--save_model`, `--save_model True` and `--save_model False`.
    parser.add_argument('--save_model', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--condition', type=str, default='reward')  # or xvel, xyvel
    parser.add_argument('--distributions', type=str, default='categorical')  # or gaussian, deterministic

    args = parser.parse_args()

    # random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # log dir
    save_dir = f'{args.env}-{args.dataset}-{args.distributions}-dim_{args.dist_dim}-bin_{args.n_bins}-gamma_{args.gamma}-{args.condition}-ctx_{args.K}-seed_{args.seed}'
    output_dir = os.path.join('./results', save_dir)
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, 'params.json'), mode="w") as f:
        json.dump(args.__dict__, f, indent=4)

    experiment(output_dir, variant=vars(args))
