import gym
import numpy as np
import torch

import argparse
import pickle
import random
import sys
import d4rl.gym_mujoco
import os
from decision_transformer.models.VACO import VACO
from utils import eval_policy
from tensorboardX import SummaryWriter


def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x``.

    Computes ``out[t] = x[t] + gamma * out[t + 1]`` with ``out[-1] = x[-1]``,
    i.e. the discounted return-to-go at every position along axis 0.

    Args:
        x: NumPy array of per-step values, processed along axis 0.
        gamma: Discount factor applied per step.

    Returns:
        Array with the same shape and dtype as ``x`` (built with ``np.zeros_like``),
        so integer inputs are truncated at every step -- pass a float array.
        An empty ``x`` yields an empty array.
    """
    out = np.zeros_like(x)
    if x.shape[0] == 0:
        # Nothing to accumulate; indexing out[-1] would raise IndexError.
        return out
    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out


def _make_env(env_name, env_name_total):
    """Create the environment used for observation/action shapes and for evaluation.

    Args:
        env_name: One of 'hopper', 'halfcheetah', 'walker2d', 'reacher2d', 'antmaze'.
        env_name_total: Full '<env>-<dataset>-<ver>' id; only used for 'antmaze'.

    Raises:
        NotImplementedError: for any other ``env_name``.
    """
    gym_ids = {
        'hopper': 'Hopper-v3',
        'halfcheetah': 'HalfCheetah-v3',
        'walker2d': 'Walker2d-v3',
    }
    if env_name in gym_ids:
        return gym.make(gym_ids[env_name])
    if env_name == 'reacher2d':
        from decision_transformer.envs.reacher_2d import Reacher2dEnv
        return Reacher2dEnv()
    if env_name == 'antmaze':
        return gym.make(env_name_total)
    raise NotImplementedError


def _top_trajectory_indices(returns, traj_lens, pct_traj):
    """Select the highest-return trajectories that fit in a ``pct_traj`` share of all timesteps.

    Trajectories are added from the best return downwards until the next one would
    exceed the timestep budget; the single best trajectory is always kept.

    Args:
        returns: Per-trajectory total reward (1-D array).
        traj_lens: Per-trajectory length (1-D array, same order as ``returns``).
        pct_traj: Fraction of the total timestep count to keep (1.0 keeps everything).

    Returns:
        Indices into the trajectory list, ordered from lowest to highest return.
    """
    budget = max(int(pct_traj * sum(traj_lens)), 1)
    sorted_inds = np.argsort(returns)  # lowest to highest
    num_trajectories = 1
    timesteps = traj_lens[sorted_inds[-1]]
    ind = len(sorted_inds) - 2
    while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= budget:
        timesteps += traj_lens[sorted_inds[ind]]
        num_trajectories += 1
        ind -= 1
    return sorted_inds[-num_trajectories:]


def _load_critic_checkpoint(policy, save_load_dir, load_num, device):
    """Restore ``policy.critic``, ``policy.critic_target`` and ``policy.value`` from disk.

    Files are read from ``/gym/<save_load_dir>/<name>_s<load_num>.pth``. ``map_location``
    lets checkpoints written on another device (e.g. a GPU) be loaded onto ``device``.
    """
    # NOTE(review): the save_* calls in this file receive only ``save_load_dir``; confirm
    # they write under /gym/ so that this path matches what was saved.
    ckpt_dir = f'/gym/{save_load_dir}'
    for module, name in ((policy.critic, 'critic'),
                         (policy.critic_target, 'critic_target'),
                         (policy.value, 'value')):
        module.load_state_dict(torch.load(f'{ckpt_dir}/{name}_s{load_num}.pth', map_location=device))


def _run_name(variant, tag):
    """Build the TensorBoard run directory name for ``tag`` from the hyper-parameters in ``variant``."""
    keys_after_tag = ('discount', 'expectile', 'temperature', 'norm', 'reward_type', 'ver')
    parts = [str(variant['env']), str(variant['dataset']), tag]
    parts += [str(variant[key]) for key in keys_after_tag]
    return '_'.join(parts)


def _train_with_eval(train_step, save_step, num_steps, get_batch, policy, env,
                     state_mean, state_std, variant, save_load_dir, run_name):
    """Call ``train_step`` ``num_steps`` times, evaluating and checkpointing every ``eval_freq`` steps.

    Evaluation results are logged to TensorBoard under ``log_antmaze/<run_name>``. The
    writer is always closed, even if training raises, so pending events are flushed.

    Returns:
        The list of ``eval_policy`` results, one per evaluation.
    """
    evaluations = []
    writer = SummaryWriter(log_dir=os.path.join('log_antmaze', run_name))
    try:
        for t in range(int(num_steps)):
            train_step(get_batch, variant['batch_size'])
            # Evaluate episode
            if (t + 1) % variant['eval_freq'] == 0:
                print(f"Time steps: {t + 1}")
                evaluations.append(eval_policy(policy, env, variant['seed_eval'], state_mean, state_std,
                                               variant['norm'], eval_episodes=variant['num_eval_episodes']))
                # NOTE(review): logs element 0 of the eval_policy result (presumably the evaluation
                # score) under the historical tag 'train_loss'; tag kept for continuity with old runs.
                writer.add_scalar('train_loss', evaluations[-1][0], t)
                save_step(save_load_dir)
    finally:
        writer.close()
    return evaluations


def experiment(
        exp_prefix,
        variant,
):
    """Train (or resume) a VACO agent on an offline dataset and periodically evaluate it.

    Args:
        exp_prefix: Prefix for the experiment name (the composed name is not used further).
        variant: Configuration dict, normally ``vars(args)`` from the command-line parser.

    Side effects:
        Reads ``data/<env>-<dataset>-<ver>.pkl``; with ``load_model`` set, loads critic/value
        checkpoints from ``/gym/<save_load_dir>``; saves checkpoints through the policy's
        ``save*`` methods; writes TensorBoard logs under ``log_antmaze/``; prints progress.
    """
    device = variant.get('device', 'cuda')

    env_name, dataset = variant['env'], variant['dataset']
    group_name = f'{exp_prefix}-{env_name}-{dataset}'
    # The composed name is unused, but this randint call also advances the global `random`
    # state that drives batch sampling below; keep it so seeded runs stay reproducible.
    exp_prefix = f'{group_name}-{random.randint(int(1e5), int(1e6) - 1)}'

    expectile = variant['expectile']
    discount = variant['discount']
    temper = variant['temperature']
    rt = variant['reward_type']
    ver = variant['ver']
    save_load_dir = f'{discount}_{expectile}_{temper}_{env_name}_{dataset}_{rt}_{ver}'

    env = _make_env(env_name, f'{env_name}-{dataset}-{ver}')
    state_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    policy = VACO(
        state_dim=state_dim,
        action_dim=act_dim,
        expectile=expectile,
        discount=discount,
        tau=variant['tau'],
        temperature=temper,
        group_size=variant['group_size'],
        device=device,
    )

    # load dataset
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted dataset files.
    dataset_path = f'data/{env_name}-{dataset}-{ver}.pkl'
    with open(dataset_path, 'rb') as f:
        trajectories = pickle.load(f)

    # save all path information into separate lists
    mode = variant.get('mode', 'normal')
    states, traj_lens, returns = [], [], []
    for path in trajectories:
        if mode == 'delayed':  # delayed: all rewards moved to end of trajectory (mutates `trajectories`)
            path['rewards'][-1] = path['rewards'].sum()
            path['rewards'][:-1] = 0.
        states.append(path['observations'])
        traj_lens.append(len(path['observations']))
        returns.append(path['rewards'].sum())
    traj_lens, returns = np.array(traj_lens), np.array(returns)

    # used for input normalization
    states = np.concatenate(states, axis=0)
    state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

    num_timesteps = sum(traj_lens)

    print('=' * 50)
    print(f'Starting new experiment: {env_name} {dataset}')
    print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
    print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
    print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
    print('=' * 50)

    # only train on top pct_traj trajectories (for %BC experiment)
    sorted_inds = _top_trajectory_indices(returns, traj_lens, variant.get('pct_traj', 1.))
    num_trajectories = len(sorted_inds)

    # Trajectories are drawn proportionally to their length and a timestep uniformly within
    # each, so every kept timestep is equally likely to be sampled.
    p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds])

    normalize_states = variant['norm']

    def _to_tensor(rows, dtype, batch_size):
        # Stack per-sample rows into a (batch_size, -1) tensor on `device`.
        return torch.from_numpy(np.concatenate(rows, axis=0)).to(dtype=dtype, device=device).reshape(batch_size, -1)

    def get_batch_bi(batch_size=256):
        """Sample ``batch_size`` single transitions and return (s, a, s', r, not_done) tensors."""
        batch_inds = np.random.choice(
            np.arange(num_trajectories),
            size=batch_size,
            replace=True,
            p=p_sample,  # reweights so we sample according to timesteps
        )

        s, a, ns, r, nd = [], [], [], [], []
        for i in range(batch_size):
            traj = trajectories[int(sorted_inds[batch_inds[i]])]
            si = random.randint(0, traj['rewards'].shape[0] - 1)
            done_key = 'terminals' if 'terminals' in traj else 'dones'

            # get single transitions from dataset
            s.append(traj['observations'][si].reshape(1, state_dim))
            a.append(traj['actions'][si].reshape(1, act_dim))
            ns.append(traj['next_observations'][si].reshape(1, state_dim))
            r.append(traj['rewards'][si].reshape(1, 1))
            nd.append(1 - traj[done_key][si].reshape(1, -1))

            # state + reward normalization
            if normalize_states:
                s[-1] = (s[-1] - state_mean) / state_std
                ns[-1] = (ns[-1] - state_mean) / state_std
            if rt == 0:
                r[-1] = (r[-1] - 0.5) * 4
            elif rt == 1:
                r[-1] = r[-1] * 100
            elif rt == 2:
                r[-1] = r[-1] * 1

        return (
            _to_tensor(s, torch.float32, batch_size),
            _to_tensor(a, torch.float32, batch_size),
            _to_tensor(ns, torch.float32, batch_size),
            _to_tensor(r, torch.float32, batch_size),
            _to_tensor(nd, torch.long, batch_size),
        )

    common = dict(get_batch=get_batch_bi, policy=policy, env=env, state_mean=state_mean,
                  state_std=state_std, variant=variant, save_load_dir=save_load_dir)

    if variant['bi']:
        # Phase 1: obtain critic/value networks (train them, or load saved checkpoints).
        if not variant['load_model']:
            for t in range(int(variant['max_timesteps_critic'])):
                policy.train_critic(get_batch_bi, variant['batch_size'])
                if (t + 1) % variant['eval_freq'] == 0:
                    print(f"Time steps: {t + 1}")
                    policy.save_critic(save_load_dir)
        else:
            _load_critic_checkpoint(policy, save_load_dir, variant['load_num'], device)
        # Phase 2: train the actor with periodic evaluation.
        _train_with_eval(policy.train_actor, policy.save_actor, variant['max_timesteps_actor'],
                         run_name=_run_name(variant, 'iql_bi'), **common)
    elif not variant['load_model']:
        _train_with_eval(policy.train, policy.save, variant['max_timesteps'],
                         run_name=_run_name(variant, 'iql_ori'), **common)
    else:
        _load_critic_checkpoint(policy, save_load_dir, variant['load_num'], device)
        _train_with_eval(policy.train_ori, policy.save_ori, variant['max_timesteps'],
                         run_name=_run_name(variant, 'iql_ori_load'), **common)


def build_arg_parser():
    """Build the command-line parser for this training script.

    Integer options use integer defaults: argparse applies ``type`` only to string
    defaults, so a float default such as ``5e3`` would stay a float.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='hopper')
    parser.add_argument('--dataset', type=str, default='medium')  # medium, medium-replay, medium-expert, expert
    parser.add_argument('--mode', type=str, default='normal')  # normal for standard setting, delayed for sparse
    parser.add_argument('--pct_traj', type=float, default=1.)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_eval_episodes', type=int, default=100)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument("--eval_freq", default=5000, type=int)  # How often (training steps) we evaluate/save
    parser.add_argument("--max_timesteps", default=1_000_000, type=int)  # Training steps when --bi is 0
    parser.add_argument("--max_timesteps_critic", default=1_000_000, type=int)  # Critic training steps (--bi 1)
    parser.add_argument("--max_timesteps_actor", default=1_000_000, type=int)  # Actor training steps (--bi 1)
    parser.add_argument("--save_model", action="store_true")  # Save model and optimizer parameters
    parser.add_argument("--load_model", default=1, type=int)  # non-zero: load critic/value checkpoints instead of training them
    parser.add_argument("--seed_train", default=0, type=int)
    parser.add_argument("--seed_eval", default=42, type=int)
    parser.add_argument("--expectile", default=0.9, type=float)
    parser.add_argument("--discount", default=0.99, type=float)  # Discount factor
    parser.add_argument("--tau", default=0.005, type=float)  # Target network update rate
    parser.add_argument("--temperature", default=10.0, type=float)
    # bi
    parser.add_argument("--group_size", default=4, type=int)
    parser.add_argument("--load_num", default=1000000, type=int)  # checkpoint suffix to load
    parser.add_argument("--bi", default=1, type=int)
    parser.add_argument("--norm", default=0, type=int)
    parser.add_argument("--reward_type", default=0, type=int)
    parser.add_argument('--ver', type=str, default='v2')
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()

    # Seed every RNG used during training (Python, NumPy, PyTorch CPU and CUDA).
    random.seed(args.seed_train)
    np.random.seed(args.seed_train)
    torch.manual_seed(args.seed_train)
    torch.cuda.manual_seed_all(args.seed_train)

    experiment('gym-experiment', variant=vars(args))
