import rsa.utils as utils
from rsa.baselines import BehaviorCloning
import rsa.utils.plot_utils as pu
from rsa.utils.arg_parser import parse_args

import numpy as np
import matplotlib.pyplot as plt
import argparse
import pandas as pd
import os
import logging
# Logger named "main"; main() uses it to record the log directory.
log = logging.getLogger("main")


def _save_loss_plot(losses, path, window=100):
    """Plot ``losses`` averaged over consecutive windows and save to ``path``.

    Args:
        losses: Sequence of per-iteration training losses.
        path: Destination file for the figure.
        window: Number of consecutive losses averaged into one plotted point.
            Trailing losses that do not fill a whole window are dropped.
    """
    usable = (len(losses) // window) * window
    if usable:
        curve = np.array(losses[:usable]).reshape((-1, window)).mean(1)
    else:
        # Fewer losses than one full window: plot them unsmoothed instead of
        # writing an empty figure.
        curve = np.array(losses)

    fig, ax = plt.subplots()
    ax.plot(curve)
    fig.savefig(path)
    # Release the figure so it does not stay registered with pyplot.
    plt.close(fig)


def _evaluate_policy(policy, env, horizon, num_episodes=100):
    """Roll out ``policy`` in ``env`` and collect the summed reward per episode.

    Args:
        policy: Object exposing ``act(obs)`` that returns an action.
        env: Environment exposing ``reset()`` and ``step(action)``; ``step``
            must return ``(next_obs, reward, done, info)``.
        horizon: Maximum number of steps taken in each episode.
        num_episodes: Number of episodes to run.

    Returns:
        np.ndarray with one entry (the sum of rewards) per episode.
    """
    returns = []
    for episode in range(num_episodes):
        rewards = []
        obs = np.array(env.reset())
        for _ in range(horizon):
            next_obs, reward, done, _info = env.step(policy.act(obs))
            rewards.append(reward)
            obs = np.array(next_obs)
            if done:
                break
        episode_return = sum(rewards)
        returns.append(episode_return)
        # %.5f (matching the average printed by main) rather than %d, which
        # would silently drop the fractional part of a float return.
        print('Traj %d, Reward: %.5f' % (episode, episode_return))
    return np.array(returns)


def main():
    """Train or load a behavior-cloning policy, evaluate it, and write results.

    All settings come from ``parse_args()``. When ``params['checkpoint']`` is
    None the policy is trained on a replay buffer, saved to ``bc.pth`` and its
    loss curve written to ``loss.pdf`` inside the log directory; otherwise the
    checkpoint is loaded. The policy is then run for 100 test episodes and the
    mean/std of the returns are written to ``progress.csv``.

    Raises:
        FileExistsError: If ``params['logdir']`` already exists, because
            ``os.makedirs`` is called without ``exist_ok``.
    """
    params = parse_args()

    logdir = params['logdir']
    print('logging to', logdir)
    # NOTE(review): no exist_ok here, so an existing run directory aborts the
    # script instead of being reused -- confirm this is the intended guard.
    os.makedirs(logdir)
    utils.init_logging(logdir)
    # Emitted only after init_logging so the record is not produced before
    # logging has been set up (presumably init_logging installs the handlers;
    # verify against rsa.utils).
    log.info('logging to %s', logdir)
    utils.seed(params['seed'])

    env, test_env = utils.make_env(params)
    print(env.action_space)
    horizon = params['horizon']

    bc = BehaviorCloning(env, logdir)

    if params['checkpoint'] is None:
        if params['env'] in utils.d4rl_envs:
            replay_buffer = utils.load_d4rl_replay_buffer(env, params, add_drtg=False)
        else:
            replay_buffer = utils.load_replay_buffer(params)
        losses = bc.train(replay_buffer, params['init_iters'])
        bc.save(os.path.join(logdir, 'bc.pth'))
        _save_loss_plot(losses, os.path.join(logdir, 'loss.pdf'))
    else:
        bc.load(params['checkpoint'])

    rets = _evaluate_policy(bc, test_env, horizon, num_episodes=100)
    avg = rets.mean()
    std = rets.std()

    print('Average: %.5f' % avg)

    # The single evaluation result is repeated over `num` synthetic epochs
    # (1000 env interactions apart). Presumably this lets the baseline be
    # drawn as a flat line next to learning curves -- TODO confirm.
    num = 1000
    epochs = np.arange(num)
    data = {
        'Epoch': epochs,
        'TotalEnvInteracts': epochs * 1000,
        'AverageTestEpRet': [avg] * num,
        'StdTestEpRet': [std] * num,
    }
    pd.DataFrame(data=data).to_csv(os.path.join(logdir, 'progress.csv'))




# Run the pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
