from rsa.algos import MCSAC
import rsa.utils as utils
import rsa.utils.pytorch_utils as ptu
from rsa.utils.arg_parser import parse_args
from rsa.utils.logx import EpochLogger
import rsa.utils.spb_utils as spbu

import numpy as np
from tqdm import trange
import os
import json

def transition_mask(info, done, t, horizon):
    """Return the bootstrap mask stored with the transition taken at step ``t``.

    An explicit ``info['mask']`` supplied by the environment always wins.
    Otherwise the step that exhausts the rollout horizon is treated as a
    truncation (mask 1, i.e. keep bootstrapping), and any other step is
    masked exactly when the environment reported ``done``.

    Args:
        info: info dict returned by ``env.step``.
        done: done flag returned by ``env.step``.
        t: 0-based index of the step that was just taken in this episode.
        horizon: maximum number of steps per training episode.

    Returns:
        ``info['mask']`` if present, else ``1`` at the horizon, else
        ``float(not done)``.
    """
    if 'mask' in info:
        return info['mask']
    # ``t`` has not been incremented yet, so the last step permitted by the
    # horizon is t == horizon - 1. (Comparing ``t == horizon`` can never be
    # true inside the rollout loop, which runs while t < horizon.)
    if t + 1 == horizon:
        return 1
    return float(not done)


def label_episode_returns(ep_buf, discount):
    """Annotate one finished episode's transitions in place with return targets.

    Walking the episode backwards, each transition gets:
      * ``'drtg'``: discounted reward-to-go,
        ``rew + mask * discount * drtg(next)``.
      * ``'succ'``: the ``'goal'`` flag of the episode's last transition.
    The ``'goal'`` key is removed from every transition.

    If the last transition is not terminal (``mask`` truthy), the unseen tail
    of the episode is estimated by assuming the last reward repeats forever:
    ``r / (1 - discount)`` when ``discount < 1``. When ``discount >= 1`` the
    estimate is ``r * inf``, or ``0.0`` when ``r == 0`` (``0 * inf`` would be
    NaN and poison every earlier target).

    TODO: find a better estimate of the remaining rewards for general
    environments. For goal-conditioned tasks the remaining rewards are always
    -1 or 0, but that does not hold in general. Options considered: minimum,
    mean, median, or last reward; the last reward is used for now.

    Args:
        ep_buf: list of transition dicts with at least ``'rew'``, ``'mask'``
            and ``'goal'`` keys, in time order. Mutated in place.
        discount: discount factor.
    """
    if not ep_buf:
        return

    succ = ep_buf[-1]['goal']
    drtg = 0
    for step, transition in enumerate(reversed(ep_buf)):
        rew, mask = transition['rew'], transition['mask']
        if step == 0:
            if not mask:
                drtg = rew
            elif discount < 1:
                drtg = rew / (1 - discount)
            elif rew == 0:
                drtg = 0.0
            else:
                drtg = rew * float('inf')
        else:
            # Skip the product entirely when mask is 0 so an infinite ``drtg``
            # cannot produce 0 * inf = NaN.
            drtg = rew + (mask * discount * drtg if mask else 0)

        transition['drtg'] = drtg
        transition['succ'] = succ
        del transition['goal']


def evaluate_agent(sac, test_env, num_episodes, logger, close_env=False):
    """Run ``num_episodes`` evaluation episodes on ``test_env``.

    Each episode's return and length are stored in ``logger`` as
    ``TestEpRet`` / ``TestEpLen``. Evaluation uses its own local names, so it
    cannot overwrite the caller's in-progress training ``obs`` / ``done``.

    Args:
        sac: agent exposing ``select_action(obs, evaluate=...)``.
        test_env: environment used only for evaluation.
        num_episodes: number of evaluation episodes to run.
        logger: logger exposing ``store(**kwargs)``.
        close_env: if True, call ``test_env.close()`` after every episode
            (the training script enables this for robosuite envs).
    """
    for _ in range(num_episodes):
        test_obs, test_done, ep_ret, ep_len = test_env.reset(), False, 0, 0
        while not test_done:
            # Test-time action selection (evaluate=True).
            act = sac.select_action(test_obs, evaluate=True)
            test_obs, rew, test_done, _ = test_env.step(act)
            ep_ret += rew
            ep_len += 1
        if close_env:
            test_env.close()
        logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)


def log_epoch(logger, epoch, total_steps, metrics, do_diagnostics):
    """Write one row of tabular statistics to ``logger`` and dump it.

    Args:
        logger: ``EpochLogger`` instance.
        epoch: index of the evaluation epoch being logged.
        total_steps: environment steps taken so far.
        metrics: extra ``name -> value`` pairs logged verbatim.
        do_diagnostics: whether the Q / DRTG diagnostic columns are logged.
    """
    diagnostic_keys = ('Q_expert', 'Q_online', 'DRTG_prob_expert', 'DRTG_prob_online')

    logger.log_tabular('Epoch', epoch)
    logger.log_tabular('TotalEnvInteracts', total_steps)
    logger.log_tabular('TestEpRet')
    logger.log_tabular('TestEpLen', average_only=True)
    if epoch == 0:
        # The first evaluation runs at step 0, before any training episode or
        # update has stored these keys, so zero placeholders are logged
        # (presumably to keep the same columns in every row).
        logger.log_tabular('AverageTrainEpRet', 0)
        logger.log_tabular('StdTrainEpRet', 0)
        logger.log_tabular('TrainEpLen', 0)
        logger.log_tabular('Q1', 0)
        logger.log_tabular('Q2', 0)
        if do_diagnostics:
            for key in diagnostic_keys:
                logger.log_tabular(key, 0)
    else:
        # NOTE(review): assumes a training episode has finished and updates
        # have stored Q1/Q2 since the last dump; this may not hold if
        # eval_freq < horizon or start_timesteps > eval_freq — confirm how
        # EpochLogger handles missing keys.
        logger.log_tabular('TrainEpRet')
        logger.log_tabular('TrainEpLen', average_only=True)
        logger.log_tabular('Q1', average_only=True)
        logger.log_tabular('Q2', average_only=True)
        if do_diagnostics:
            for key in diagnostic_keys:
                logger.log_tabular(key, average_only=True)
    for metric, value in metrics.items():
        logger.log_tabular(metric, value)
    logger.dump_tabular()


if __name__ == '__main__':
    params = parse_args(sac_args=True)

    utils.seed(params['seed'])
    logdir = params['logdir']
    # makedirs without exist_ok: refuses to reuse an existing log directory.
    os.makedirs(logdir)
    os.makedirs(os.path.join(logdir, 'misc'))
    ptu.setup(params['device'])
    with open(os.path.join(logdir, 'hparams.json'), 'w') as f:
        json.dump(params, f)

    env, test_env = utils.make_env(params)
    is_pointbot_env = params['env'] in ('spb', 'rpb', 'lpb', 'hpb', 'mpb', 'lpb_easy')
    # Robosuite envs are closed after every episode (train and test).
    robosuite = params['env'] in ('Lift', 'Door', 'NutAssembly', 'TwoArmPegInHole')

    logger = EpochLogger(output_dir=logdir, exp_name=params['exper_name'])
    loss_plotter = utils.LossPlotter(os.path.join(logdir, 'loss_plots'))

    sac = MCSAC(params)

    if params['env'] in utils.d4rl_envs:
        replay_buffer = utils.load_d4rl_replay_buffer(env, params, add_drtg=True, add_gqe=params['gqe'])
    else:
        if params['gen_data']:
            expert_policy = utils.make_expert_policy(params, test_env)
            utils.generate_offline_data(test_env, expert_policy, params)
        # NOTE: unlike the D4RL path, add_gqe is hard-coded to False here, and
        # online transitions collected below carry no GQE fields either.
        replay_buffer = utils.load_replay_buffer(params, add_drtg=True, add_gqe=False)

    # Either restore a checkpoint or pretrain on the offline buffer.
    if params['checkpoint'] is not None:
        sac.load(params['checkpoint'])
    else:
        print('Pretraining Policy')
        os.makedirs(os.path.join(logdir, 'pretrain_plots'))
        for i in trange(params['init_iters']):
            info = sac.update(replay_buffer, init=True)
            loss_plotter.add_data(**info)
            if i > 0 and i % 1000 == 0:
                loss_plotter.plot()
        if params['init_iters'] > 0:
            sac.save(os.path.join(logdir, 'pretrain'))
            loss_plotter.plot()

    # Online training loop.
    i = 0  # total environment steps taken so far
    n_episodes = 0
    epoch = 0
    metrics = {
        'Timesteps': 0,
    }
    horizon = params['horizon']
    total_timesteps = params['total_timesteps']

    while i < total_timesteps:
        # Collect one trajectory.
        obs, done, t = env.reset(), False, 0
        ep_buf, rets = [], []
        while not done and t < horizon:
            # Every params['eval_freq'] steps: evaluate, log, checkpoint, plot.
            if i % params['eval_freq'] == 0:
                print('Testing Agent')
                evaluate_agent(sac, test_env, params['num_eval_episodes'], logger,
                               close_env=robosuite)
                log_epoch(logger, epoch, i, metrics, params['do_diagnostics'])

                epoch += 1
                loss_plotter.plot()
                sac.save(os.path.join(logdir, 'models'))

                if is_pointbot_env:
                    spbu.plot_Q(sac, env,
                                os.path.join(logdir, 'misc', 'q_%d.pdf' % i),
                                skip=2)
                if params['plot_drtg_maxes']:
                    spbu.plot_maxes(sac, env,
                                    os.path.join(logdir, 'misc', 'q_maxes_%d.pdf' % n_episodes))
                    sac.drtg_buffer = set()
                    sac.bellman_buffer = set()

            # Random actions during warm-up, then epsilon-greedy around the
            # policy. ``or`` short-circuits so no random draw happens during
            # warm-up, exactly as before.
            if i < params['start_timesteps'] or np.random.random() < params['greedy_exp_eps']:
                act = env.action_space.sample()
            else:
                act = sac.select_action(obs)

            next_obs, rew, done, info = env.step(act)
            ep_buf.append({
                'obs': obs,
                'next_obs': next_obs,
                'act': act,
                'rew': utils.shift_reward(rew, params),
                'done': done,
                'expert': 0,
                'goal': info['goal'] if 'goal' in info else 0,
                'mask': transition_mask(info, done, t, horizon),
            })
            obs = next_obs

            i += 1
            t += 1
            rets.append(rew)
            metrics['Timesteps'] += 1

            # Gradient steps.
            if i >= params['start_timesteps']:
                for _ in range(params['update_n_steps']):
                    if len(replay_buffer) == 0:
                        break
                    update_info = sac.update(replay_buffer)
                    logger.store(**update_info)
                    loss_plotter.add_data(**update_info)

        label_episode_returns(ep_buf, params['discount'])
        for transition in ep_buf:
            replay_buffer.store_transition(transition)

        if robosuite:
            env.close()

        logger.store(TrainEpRet=sum(rets), TrainEpLen=len(rets))
        n_episodes += 1