import os
import random

import numpy as np
import tqdm
from absl import app, flags
from ml_collections import config_flags
from tensorboardX import SummaryWriter

from jax_rl.agents import AWACLearner, SACLearner
from jax_rl.datasets import ReplayBuffer
from jax_rl.evaluation_fetch import evaluate
from jax_rl.utils_fetch import make_env

FLAGS = flags.FLAGS

# --- Experiment / IO -------------------------------------------------------
flags.DEFINE_string('env_name', 'FetchPushDense-v2', 'Environment name.')
flags.DEFINE_string('save_dir', './tmp/',
                    'Root output dir: TensorBoard logs, videos and eval returns.')
flags.DEFINE_integer('seed', 42, 'Random seed.')

# --- Evaluation / logging cadence ------------------------------------------
flags.DEFINE_integer('eval_episodes', 10,
                     'Number of episodes used for evaluation.')
# Both intervals are used as `step % interval` in main(); a value of 0 would
# raise ZeroDivisionError mid-run, so reject it at flag-parse time instead.
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.', lower_bound=1)
flags.DEFINE_integer('eval_interval', 5000, 'Eval interval.', lower_bound=1)

# --- Training schedule -----------------------------------------------------
flags.DEFINE_integer('batch_size', 256, 'Mini batch size.', lower_bound=1)
flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')
# Env steps taken with uniformly random actions before the agent starts
# acting and gradient updates begin.
flags.DEFINE_integer('start_training', int(1e4),
                     'Number of random-action environment steps before '
                     'training starts.')
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')
flags.DEFINE_boolean('save_video', False, 'Save videos during evaluation.')

# --- Learner options (forwarded to the agent constructor) ------------------
flags.DEFINE_boolean('small_init', True, 'Use smaller init for last policy layer')
flags.DEFINE_boolean('standardize', False, "Use equivariant standardization of the state")
flags.DEFINE_boolean("gan_betas", False, "use GAN betas or not")
flags.DEFINE_float("tau", 0.005, 'tau for SAC updates')
flags.DEFINE_float('clipping', 0.5, 'Gradient Norm magnitude at which to clip')

# Hyperparameter config file; lock_config=False lets main() convert it into a
# plain dict and pop/extend keys.
config_flags.DEFINE_config_file(
    'config',
    'configs/sac_default.py',
    'File path to the training hyperparameter configuration.',
    lock_config=False)

from representations import environment_symmetries

def main(_):
    """Train an agent on ``FLAGS.env_name`` and periodically evaluate it.

    Builds a training env, an evaluation env, a replay buffer and a learner
    (only ``algo == 'sac'`` is supported), then runs ``FLAGS.max_steps``
    environment steps. Steps ``1 .. FLAGS.start_training - 1`` use uniformly
    sampled actions. From step ``FLAGS.start_training`` on, actions come from
    the agent, and one ``agent.update`` is performed per environment step.

    Outputs under ``FLAGS.save_dir``:
      * TensorBoard scalars in ``<save_dir>/<env_name>/<seed>``;
      * ``<save_dir>/<seed>.txt``, rewritten after every evaluation with one
        ``timestep return`` row per evaluation so far;
      * when ``FLAGS.save_video`` is set, the folders
        ``<save_dir>/video/{train,eval}`` are passed to ``make_env``.

    Args:
        _: Unused argv forwarded by ``absl.app.run``.

    Raises:
        NotImplementedError: If the config's ``algo`` is not ``'sac'``.
    """
    summary_writer = SummaryWriter(
        os.path.join(FLAGS.save_dir, FLAGS.env_name, str(FLAGS.seed)))

    if FLAGS.save_video:
        video_train_folder = os.path.join(FLAGS.save_dir, 'video', 'train')
        video_eval_folder = os.path.join(FLAGS.save_dir, 'video', 'eval')
    else:
        video_train_folder = None
        video_eval_folder = None

    env = make_env(FLAGS.env_name, FLAGS.seed, video_train_folder)
    # Offset seed so the evaluation env is seeded differently from training.
    eval_env = make_env(FLAGS.env_name, FLAGS.seed + 42, video_eval_folder)

    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    # Config hyperparameters, extended (and overridden on key clashes) by the
    # per-environment symmetry entry. The entry provides at least
    # 'action_space', 'state_rep', 'state_transform' and
    # 'inv_state_transform', which are read below.
    kwargs = dict(FLAGS.config)
    kwargs.update(environment_symmetries[FLAGS.env_name])
    algo = kwargs.pop('algo')
    replay_buffer_size = kwargs.pop('replay_buffer_size')
    # Non-continuous action spaces get action_dim 1 (presumably one discrete
    # index per transition -- confirm against ReplayBuffer).
    action_dim = (env.action_space.shape[0]
                  if kwargs['action_space'] == 'continuous' else 1)
    replay_buffer = ReplayBuffer(env.observation_space, action_dim,
                                 replay_buffer_size or FLAGS.max_steps,
                                 kwargs['state_rep'],
                                 kwargs['state_transform'],
                                 kwargs['inv_state_transform'],
                                 FLAGS.standardize)
    if algo != 'sac':
        raise NotImplementedError(f'Unsupported algo: {algo!r}')
    # Sample observation/action with a leading batch axis of size 1 so the
    # learner can infer input shapes.
    agent = SACLearner(
        FLAGS.seed,
        env.observation_space.sample()[np.newaxis],
        np.asarray(env.action_space.sample())[None],
        standardizer=(replay_buffer.running_stats.standardize
                      if FLAGS.standardize else None),
        clipping=FLAGS.clipping,
        gan_betas=FLAGS.gan_betas,
        tau=FLAGS.tau,
        **kwargs)

    eval_returns = []
    # Seed only the first reset. Passing seed= on every reset would reset the
    # env PRNG each episode (Gymnasium Env.reset semantics), so every training
    # episode would start from the same sampled initial state / goal.
    observation, _ = env.reset(seed=FLAGS.seed)
    for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
                       smoothing=0.1,
                       disable=not FLAGS.tqdm):
        if i < FLAGS.start_training:
            action = env.action_space.sample()
        else:
            action = agent.sample_actions(observation)
        next_observation, reward, terminated, truncated, info = env.step(action)

        # mask is 0 only on true termination, even if `truncated` is also
        # set. A pure time-limit truncation keeps mask 1 (presumably so the
        # critic still bootstraps from next_observation).
        mask = 0.0 if terminated else 1.0

        replay_buffer.insert(observation, action, reward, mask,
                             next_observation)
        observation = next_observation

        if terminated or truncated:
            observation, _ = env.reset()
            # `info` comes from the final step of the finished episode; the
            # 'episode' / 'total' entries are expected from make_env's
            # wrappers.
            for k, v in info['episode'].items():
                summary_writer.add_scalar(f'training/{k}', v,
                                          info['total']['timesteps'])

        if i >= FLAGS.start_training:
            batch = replay_buffer.sample(FLAGS.batch_size)
            update_info = agent.update(batch)

            if i % FLAGS.log_interval == 0:
                for k, v in update_info.items():
                    summary_writer.add_scalar(f'training/{k}', v, i)
                summary_writer.flush()

        if i % FLAGS.eval_interval == 0:
            eval_stats = evaluate(agent, eval_env, FLAGS.eval_episodes)

            for k, v in eval_stats.items():
                summary_writer.add_scalar(f'evaluation/average_{k}s', v,
                                          info['total']['timesteps'])
            summary_writer.flush()

            eval_returns.append(
                (info['total']['timesteps'], eval_stats['return']))
            # Rewrite the whole returns file so it is always complete on disk.
            np.savetxt(os.path.join(FLAGS.save_dir, f'{FLAGS.seed}.txt'),
                       eval_returns,
                       fmt=['%d', '%.1f'])

    # Release the log writer and environments (lets wrappers such as video
    # recorders finalize their output).
    summary_writer.close()
    env.close()
    eval_env.close()


if __name__ == "__main__":
    # Hand control to absl: it parses the command-line flags defined above
    # and then invokes main() with the leftover argv.
    app.run(main)
