import os
import random
import time
import glob

import numpy as np
import tqdm
from absl import app, flags
from ml_collections import config_flags
from tensorboardX import SummaryWriter

from jax_rl.agents import AWACLearner, VanillaSACLearner
from jax_rl.datasets import ReplayBuffer
from jax_rl.evaluation import evaluate
from jax_rl.utils import make_env,  _should_record

import wandb
from wandb_log import init_wandb_or_disable

FLAGS = flags.FLAGS

# --- Environment, output location and seeding -------------------------------
flags.DEFINE_string('env_name', 'HalfCheetah-v2', 'Environment name.')
flags.DEFINE_string('save_dir', './tmp/', 'Tensorboard logging dir.')
flags.DEFINE_integer('seed', 42, 'Random seed.')

# --- Training / evaluation schedule ------------------------------------------
flags.DEFINE_integer('eval_episodes', 10,
                     'Number of episodes used for evaluation.')
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 10000, 'Eval interval.')
flags.DEFINE_integer('batch_size', 256, 'Mini batch size.')
flags.DEFINE_integer('updates_per_step', 1, 'Gradient updates per step.')
flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')
flags.DEFINE_integer('start_training', int(1e4),
                     'Number of training steps to start training.')

# --- Console output and video recording --------------------------------------
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')
flags.DEFINE_boolean('save_video', False, 'Save videos during evaluation.')
flags.DEFINE_integer('video_interval', 10000, 'Video saving interval.')

# NOTE: every TensorBoard / W&B call in main() is guarded by
# `not FLAGS.debug`, so enabling this flag turns experiment tracking OFF.
# The help text states that explicitly.
flags.DEFINE_boolean(
    'debug', False,
    'Debug mode: disable TensorBoard and Weights & Biases logging.')

# --- Weights & Biases run metadata -------------------------------------------
flags.DEFINE_string('wandb_project', None, "W&B project name")
flags.DEFINE_string('wandb_entity', None, "W&B entity name")
flags.DEFINE_string('wandb_run_name', None, 'W&B run name')
flags.DEFINE_string('wandb_group', None, 'W&B group name')
flags.DEFINE_list('wandb_tags', [], 'W&B tags')
flags.DEFINE_string('wandb_mode', 'online', 'W&B mode: online, offline, disabled')

# Hyperparameters come from an ml_collections config file; lock_config=False
# leaves the loaded config unlocked.
config_flags.DEFINE_config_file(
    'config',
    'configs/sac_default.py',
    'File path to the training hyperparameter configuration.',
    lock_config=False)


def _logged_step(info, fallback):
    """Return ``info['total']['timesteps']`` when present, else ``fallback``."""
    return info.get('total', {}).get('timesteps', fallback)


def _close_env_best_effort(env):
    """Close ``env``; report a failure instead of raising it."""
    try:
        env.close()
    except Exception as err:  # cleanup must never abort a training run
        tqdm.tqdm.write(f'Warning: failed to close evaluation env: {err!r}')


def main(_):
    """Train a SAC agent online, evaluating it at a fixed interval.

    The loop runs ``FLAGS.max_steps`` environment steps. Before
    ``FLAGS.start_training`` steps, actions are sampled from the action
    space; afterwards they come from the agent, which also receives
    ``FLAGS.updates_per_step`` updates per environment step. Every
    ``FLAGS.eval_interval`` steps a fresh evaluation environment is built and
    the agent is evaluated for ``FLAGS.eval_episodes`` episodes.

    Unless ``FLAGS.debug`` is set, metrics are written to TensorBoard and
    W&B, and evaluation returns are saved to ``<run_dir>/<seed>.txt``.

    Args:
        _: Unused positional arguments handed over by ``app.run``.

    Raises:
        NotImplementedError: If the config's ``algo`` is not ``'sac'``.
    """
    fname = f'vanilla_sac_seed{FLAGS.seed}'
    run_dir = os.path.join(FLAGS.save_dir, FLAGS.env_name, fname)
    logging_enabled = not FLAGS.debug

    kwargs = dict(FLAGS.config)
    algo = kwargs.pop('algo')

    summary_writer = SummaryWriter(run_dir) if logging_enabled else None

    # Training rollouts are never recorded, hence no video folder.
    env = make_env(FLAGS.env_name, FLAGS.seed, None)

    model_name = "_".join(["Vanilla_SAC", FLAGS.env_name])
    if FLAGS.wandb_run_name is None:
        FLAGS.wandb_run_name = model_name + f"_seed{FLAGS.seed}"
    if logging_enabled:
        # kwargs still holds 'replay_buffer_size' here, so it is part of the
        # config passed on as extra_cfg.
        init_wandb_or_disable(FLAGS, fname, model_name=model_name, extra_cfg=kwargs)

    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    replay_buffer_size = kwargs.pop('replay_buffer_size')
    if algo != 'sac':
        raise NotImplementedError(f'Unsupported algo: {algo!r}')
    agent = VanillaSACLearner(FLAGS.seed,
                              env.observation_space.sample()[np.newaxis],
                              env.action_space.sample()[np.newaxis], **kwargs)

    # A falsy configured size falls back to one slot per training step.
    replay_buffer = ReplayBuffer(env.observation_space, env.action_space.shape[0],
                                 replay_buffer_size or FLAGS.max_steps, None, None, None)

    eval_returns = []
    # Pre-bound so the logging branch is safe even when updates_per_step == 0.
    update_info = {}
    observation, done = env.reset(), False
    for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
                       smoothing=0.1,
                       disable=not FLAGS.tqdm):
        if i < FLAGS.start_training:
            action = env.action_space.sample()
        else:
            action = agent.sample_actions(observation)
        next_observation, reward, done, info = env.step(action)

        # mask is 0.0 only when the episode ended and info carries no
        # 'TimeLimit.truncated' key; otherwise it stays 1.0.
        if not done or 'TimeLimit.truncated' in info:
            mask = 1.0
        else:
            mask = 0.0

        replay_buffer.insert(observation, action, reward, mask, next_observation)
        observation = next_observation

        if done:
            observation, done = env.reset(), False

            if logging_enabled:
                # One step value for both backends, with a fallback to the
                # loop counter when info lacks the total-timesteps entry.
                step = _logged_step(info, i)
                episode_stats = info['episode']
                wandb.log({f"training/{k}": v for k, v in episode_stats.items()},
                          step=step)
                for k, v in episode_stats.items():
                    summary_writer.add_scalar(f'training/{k}', v, step)

        if i >= FLAGS.start_training:
            for _ in range(FLAGS.updates_per_step):
                batch = replay_buffer.sample(FLAGS.batch_size)
                update_info = agent.update(batch)

            if logging_enabled and i % FLAGS.log_interval == 0:
                for k, v in update_info.items():
                    summary_writer.add_scalar(f'training/{k}', v, i)
                summary_writer.flush()

        if i % FLAGS.eval_interval == 0:
            record_now = _should_record(i, FLAGS.save_video, FLAGS.video_interval)
            eval_vid_dir = (os.path.join(run_dir, 'eval_videos', f"step_{i}")
                            if record_now else None)

            # A new evaluation env is built for every evaluation round, so it
            # must be closed every round too -- including in debug mode.
            eval_env = make_env(FLAGS.env_name, FLAGS.seed + 42, save_folder=eval_vid_dir)
            step = _logged_step(info, i)
            try:
                eval_stats = evaluate(agent, eval_env, FLAGS.eval_episodes)

                if logging_enabled:
                    wandb.log({f"evaluation/average_{k}s": v
                               for k, v in eval_stats.items()}, step=step)

                    for k, v in eval_stats.items():
                        summary_writer.add_scalar(f'evaluation/average_{k}s', v, step)
                    summary_writer.flush()

                    if eval_vid_dir is not None:
                        mp4s = sorted(glob.glob(os.path.join(eval_vid_dir, '*.mp4')))
                        for ep_idx, vf in enumerate(mp4s):
                            wandb.log(
                                {f"eval/video/ep{ep_idx}": wandb.Video(
                                    vf, fps=30, format="mp4",
                                    caption=f"{FLAGS.env_name} @ step {step} (ep {ep_idx})")},
                                step=step
                            )
            finally:
                _close_env_best_effort(eval_env)

            if logging_enabled:
                # run_dir was passed to SummaryWriter above, so result files
                # are only written when logging is enabled.
                eval_returns.append((step, eval_stats['return']))
                np.savetxt(os.path.join(run_dir, f'{FLAGS.seed}.txt'),
                           eval_returns,
                           fmt=['%d', '%.1f'])

    if summary_writer is not None:
        # Flush and release the event file once training is finished.
        summary_writer.close()


# Script entry point: absl's app.run parses the command-line flags and then
# invokes main.
if __name__ == '__main__':
    app.run(main)