import os
import time
from copy import deepcopy
import uuid

import numpy as np
import pprint

import gym
import torch
import d4rl

from utils import *
from replay_buffer import *
from logger import Logger
from agent.test import test
from sampler import get_data,get_init_obs
from gym import wrappers
import set_dataset
import pytorch_util

def _report_epoch(logger, epoch, log):
    """Print one epoch's metrics and, if a logger exists, record them.

    Args:
        logger: Object exposing ``log_scalar(value, name, step)`` and
            ``row(dict)``, or ``None`` to only print to stdout.
        epoch: Zero-based epoch index, used as the logging step.
        log: Mapping of metric name to value collected during the epoch.
    """
    print('epoch:', epoch)
    if logger is not None:
        logger.log_scalar(epoch, 'epoch', epoch)
    for key, value in log.items():
        print('{}:{}'.format(key, value))
        if logger is not None:
            logger.log_scalar(value, key, epoch)
    if logger is not None:
        logger.row(log)
    print('================================\n')


def _run_training(FLAGS, logdir, my_logger):
    """Build the environment and agent, then run the train/eval epoch loop.

    Args:
        FLAGS: Configuration object (see ``rl_train`` for the attributes read).
        logdir: Directory checkpoints are saved into, or ``None`` when no log
            directory was created (checkpointing is then skipped).
        my_logger: Logger for scalar metrics, or ``None`` to only print.
    """
    set_random_seed(FLAGS.seed)
    env = gym.make(FLAGS.env)
    env.seed(FLAGS.seed)

    # Settings handed to the agent constructor. The keys are part of the
    # agent's interface and must not be renamed here.
    params = {
        'ob_dim': env.observation_space.shape[0],
        'ac_dim': env.action_space.shape[0],
        'policy_log_std_multiplier': FLAGS.policy_log_std_multiplier,
        'policy_log_std_offset': FLAGS.policy_log_std_offset,
        'orthogonal_init': FLAGS.orthogonal_init,
        'batch_size': FLAGS.batch_size,
        'update_freq': FLAGS.update_freq,
        'config': FLAGS.config,
        'device': FLAGS.device,
        'max_trajs_length': FLAGS.max_traj_length,
        'eval_n_trajs': FLAGS.eval_n_trajs,
    }

    # NOTE: the dataset-generation path that FLAGS.get_dataset used to select
    # (set_dataset.save_dataset / set_dataset.load_dataset followed by an
    # early return) is currently disabled, so training always runs.

    dataset = get_d4rl_dataset(env)
    agent = test(env, params, FLAGS.opt_class)
    # Affine reward transform applied once to the whole dataset.
    dataset['rewards'] = dataset['rewards'] * FLAGS.reward_scale + FLAGS.reward_bias
    if FLAGS.load_path != '':
        agent.load(FLAGS.load_path)

    for epoch in range(FLAGS.n_epochs):
        log = {}
        with Timer() as train_timer:
            for _ in range(FLAGS.n_train_step_per_epoch):
                batch = subsample_batch(dataset, FLAGS.batch_size)
                batch = batch_to_torch(batch, FLAGS.device)
                log.update(agent.train(batch))

        with Timer() as eval_timer:
            # Evaluate on the first epoch and then every eval_period epochs.
            if epoch == 0 or (epoch + 1) % FLAGS.eval_period == 0:
                log.update(agent.eval())
                # Checkpoints are written into the log directory, so they can
                # only be saved when one was created.
                if FLAGS.save_model and logdir is not None:
                    agent.save(logdir, epoch)

        # Read each timer once so epoch_time is exactly the sum of the two
        # values that are logged.
        train_time = train_timer()
        eval_time = eval_timer()
        log['train_time'] = train_time
        log['eval_time'] = eval_time
        log['epoch_time'] = train_time + eval_time

        _report_epoch(my_logger, epoch, log)


def rl_train(flags):
    """Train an agent on a D4RL dataset, evaluating and logging every epoch.

    Creates the gym environment named by ``flags.env``, loads its dataset with
    ``get_d4rl_dataset``, builds the agent and runs ``flags.n_epochs`` epochs
    of ``flags.n_train_step_per_epoch`` training steps. Evaluation, and
    checkpointing when ``flags.save_model`` is set, happens on the first epoch
    and on every ``flags.eval_period``-th epoch.

    A log directory and ``Logger`` are created only when ``flags.get_dataset``
    is falsy. When it is truthy, metrics are printed to stdout only and
    checkpoints are not saved, because there is no log directory to hold them.

    Args:
        flags: Configuration object. Attributes read: ``get_dataset``,
            ``logdir_prefix``, ``exp_name``, ``env``, ``seed``,
            ``policy_log_std_multiplier``, ``policy_log_std_offset``,
            ``orthogonal_init``, ``batch_size``, ``update_freq``, ``config``,
            ``device``, ``max_traj_length``, ``eval_n_trajs``, ``opt_class``,
            ``reward_scale``, ``reward_bias``, ``load_path``, ``n_epochs``,
            ``n_train_step_per_epoch``, ``eval_period``, ``save_model``.

    Returns:
        None.
    """
    FLAGS = flags

    # Bind both names up front so later code can test for None instead of
    # failing with UnboundLocalError when no logger was created.
    logdir = None
    my_logger = None
    if not FLAGS.get_dataset:
        logdir = (FLAGS.logdir_prefix + FLAGS.exp_name + '_' + FLAGS.env + '_'
                  + time.strftime("%d-%m-%Y_%H-%M-%S"))
        my_logger = Logger(logdir)

    try:
        _run_training(FLAGS, logdir, my_logger)
    finally:
        # Close the logger even if training or evaluation raised.
        if my_logger is not None:
            my_logger.close()