import glob
import json
import os
import random
import time
from collections import defaultdict

import jax
import numpy as np
import tqdm
import wandb
from absl import app, flags
from ml_collections import config_flags

from agents import agents
from envs.env_utils import make_env_and_datasets
from envs.reward_utils import get_success_fn
from utils.datasets import Dataset, ReplayBuffer, CDataset, GCDataset, HGCDataset, ACGCDataset, normalize, unnormalize
from utils.evaluation import evaluate
from utils.flax_utils import restore_agent, save_agent
from utils.log_utils import CsvLogger, get_exp_name, get_flag_dict, get_wandb_video, setup_wandb

# Shared absl flag registry; every option declared below is read through it.
FLAGS = flags.FLAGS

# --- Experiment identity, data sources and checkpoint locations ---
flags.DEFINE_string(name='job_id', default='', help='Job id')
flags.DEFINE_string(name='run_group', default='Debug', help='Run group.')
flags.DEFINE_integer(name='seed', default=0, help='Random seed.')
flags.DEFINE_string(name='env_name', default='puzzle-4x5-play-oraclerep-v0', help='Environment (dataset) name.')
flags.DEFINE_string(name='dataset_dir', default=None, help='Dataset directory.')
flags.DEFINE_integer(name='dataset_replace_interval', default=1000, help='Dataset replace interval.')
flags.DEFINE_integer(name='num_datasets', default=None, help='Number of datasets to use.')
flags.DEFINE_string(name='save_dir', default='exp/', help='Save directory.')
flags.DEFINE_string(name='restore_path', default=None, help='Restore path.')
flags.DEFINE_integer(name='restore_epoch', default=None, help='Restore epoch.')

# --- Training schedule (all intervals are counted in gradient steps) ---
flags.DEFINE_integer(name='offline_steps', default=5000000, help='Number of offline steps.')
flags.DEFINE_integer(name='log_interval', default=10000, help='Logging interval.')
flags.DEFINE_integer(name='eval_interval', default=100000, help='Evaluation interval.')
flags.DEFINE_integer(name='save_interval', default=100000, help='Saving interval.')

# --- Evaluation and observation normalization ---
flags.DEFINE_integer(name='eval_episodes', default=15, help='Number of episodes for each task.')
flags.DEFINE_float(name='eval_temperature', default=0, help='Actor temperature for evaluation.')
flags.DEFINE_float(name='eval_gaussian', default=None, help='Action Gaussian noise for evaluation.')
flags.DEFINE_integer(name='video_episodes', default=1, help='Number of video episodes for each task.')
flags.DEFINE_integer(name='video_frame_skip', default=3, help='Frame skip for videos.')
flags.DEFINE_boolean(name='normalize_obs', default=False, help='Normalize observations')

# --- Model-based rollouts used to fill the synthetic replay buffer ---
flags.DEFINE_integer(name='mb_batch_size', default=50000, help='batch size for model rollouts')
flags.DEFINE_integer(name='mb_retain_epoch', default=5, help='retaining for model rollouts')
flags.DEFINE_integer(name='mb_rollout_freq', default=1000, help='freq for model rollouts')
flags.DEFINE_float(name='mb_batch_ratio', default=0.5, help='ratio for model rollouts')

# --- Agent hyperparameters, loaded from a config file that stays unlocked ---
config_flags.DEFINE_config_file('agent', 'agents/sharsa.py', lock_config=False)

# Keys of the evaluation info dict that are forwarded to the loggers.
_EVAL_METRIC_NAMES = ('success', 'dynamics_loss', 'state_loss', 'value_loss')


def _evaluate_agent(agent, env, config, scale, goal_conditioned):
    """Run one evaluation round and collect the metrics to log.

    Args:
        agent: Agent passed through to `evaluate`.
        env: Evaluation environment.
        config: Agent configuration passed through to `evaluate`.
        scale: Normalization statistics passed through to `evaluate`.
        goal_conditioned: If True, every task listed in the environment's
            `task_infos` is evaluated separately and per-task as well as
            averaged ('overall') metrics are reported; otherwise a single
            evaluation without a task id is run.

    Returns:
        Dict of `evaluation/...` metrics, plus a `video` entry when
        `FLAGS.video_episodes > 0`.
    """
    # Arguments shared by every `evaluate` call in this round.
    eval_kwargs = dict(
        agent=agent,
        env=env,
        env_name=FLAGS.env_name,
        config=config,
        num_eval_episodes=FLAGS.eval_episodes,
        num_video_episodes=FLAGS.video_episodes,
        video_frame_skip=FLAGS.video_frame_skip,
        eval_temperature=FLAGS.eval_temperature,
        eval_gaussian=FLAGS.eval_gaussian,
        scale=scale,
    )
    renders = []
    eval_metrics = {}

    if goal_conditioned:
        task_infos = env.unwrapped.task_infos if hasattr(env.unwrapped, 'task_infos') else env.task_infos
        overall_metrics = defaultdict(list)
        # Task ids are 1-based, while `task_infos` is indexed from 0.
        for task_id in tqdm.trange(1, len(task_infos) + 1):
            task_name = task_infos[task_id - 1]['task_name']
            eval_info, _, cur_renders = evaluate(goal_conditioned=True, task_id=task_id, **eval_kwargs)
            renders.extend(cur_renders)
            for k, v in eval_info.items():
                if k in _EVAL_METRIC_NAMES:
                    eval_metrics[f'evaluation/{task_name}_{k}'] = v
                    overall_metrics[k].append(v)

        for k, v in overall_metrics.items():
            eval_metrics[f'evaluation/overall_{k}'] = np.mean(v)
        n_cols = 5
    else:
        eval_info, _, cur_renders = evaluate(goal_conditioned=False, task_id=None, **eval_kwargs)
        renders.extend(cur_renders)
        eval_metrics.update({f'evaluation/{k}': v for k, v in eval_info.items() if k in _EVAL_METRIC_NAMES})
        n_cols = 1

    if FLAGS.video_episodes > 0:
        eval_metrics['video'] = get_wandb_video(renders=renders, n_cols=n_cols)
    return eval_metrics


def main(_):
    """Train an agent on a mix of dataset batches and model-generated rollouts.

    The run proceeds as follows:
      1. Set up wandb. When `--job_id` is given, the wandb run id and the
         experiment name are persisted under `wandb_id/` so that a restarted
         job continues the same run and resumes from its last saved step.
      2. Build the environment and datasets, and optionally compute
         normalization statistics.
      3. Create the agent and restore it from a checkpoint if requested.
      4. Seed a replay buffer with `agent.rollout` outputs, then alternate
         between refreshing that buffer and updating the agent on real and
         rolled-out batches, with periodic logging, evaluation, checkpointing
         and dataset rotation.

    Args:
        _: Positional arguments left over after flag parsing (unused).

    Raises:
        ValueError: If no training dataset is available.
    """
    # Set up logger.
    id_file = f'wandb_id/{FLAGS.job_id}'
    step_file = f'wandb_id/step_{FLAGS.job_id}'
    if FLAGS.job_id and os.path.exists(id_file):
        # Known job: continue the existing wandb run under its stored name.
        with open(id_file, 'r') as f:
            # Split only once so the stored name may itself contain spaces.
            wandb_id, exp_name = f.read().split(' ', 1)
        setup_wandb(project='scalembrl', group=FLAGS.run_group, wandb_id=wandb_id)
    else:
        exp_name = get_exp_name(FLAGS.seed)
        run = setup_wandb(project='scalembrl', group=FLAGS.run_group, name=exp_name)
        wandb_id = run.id

    if FLAGS.job_id:
        # The bookkeeping directory may not exist on the first run of a job.
        os.makedirs(os.path.dirname(id_file), exist_ok=True)
        with open(id_file, 'w') as f:
            f.write(f"{wandb_id} {exp_name}")

    FLAGS.save_dir = os.path.join(FLAGS.save_dir, wandb.run.project, FLAGS.run_group, exp_name)
    os.makedirs(FLAGS.save_dir, exist_ok=True)
    flag_dict = get_flag_dict()
    with open(os.path.join(FLAGS.save_dir, 'flags.json'), 'w') as f:
        json.dump(flag_dict, f)

    # Set up environment and datasets.
    config = FLAGS.agent
    if FLAGS.dataset_dir is None:
        # A single `None` path is handed to `make_env_and_datasets` as-is.
        datasets = [None]
    else:
        # Dataset directory: every training shard, skipping the '-val.npz' files.
        datasets = [file for file in sorted(glob.glob(f'{FLAGS.dataset_dir}/*.npz')) if '-val.npz' not in file]
    if FLAGS.num_datasets is not None:
        datasets = datasets[: FLAGS.num_datasets]
    if not datasets:
        raise ValueError(
            f'No training datasets found (dataset_dir={FLAGS.dataset_dir!r}, num_datasets={FLAGS.num_datasets}).'
        )
    dataset_idx = 0
    env, train_dataset, val_dataset = make_env_and_datasets(FLAGS.env_name, dataset_path=datasets[dataset_idx])
    goal_conditioned = 'singletask' not in FLAGS.env_name

    normalized_keys = ['observations', 'oracle_reps']
    if FLAGS.normalize_obs:
        # Statistics come from the validation split of the first dataset and
        # are reused unchanged for every dataset loaded later on.
        scale = {
            k: {
                'mean': val_dataset[k].mean(axis=0),
                'std': val_dataset[k].std(axis=0),
                'min': val_dataset[k].min(axis=0),
                'max': val_dataset[k].max(axis=0),
            }
            for k in normalized_keys
        }
    else:
        scale = {k: None for k in normalized_keys}

    # Initialize agent.
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    dataset_class_dict = {
        'CDataset': CDataset,
        'GCDataset': GCDataset,
        'HGCDataset': HGCDataset,
        'ACGCDataset': ACGCDataset,
    }
    dataset_class = dataset_class_dict[config['dataset_class']]
    train_dataset = dataset_class(Dataset.create(scale=scale, **train_dataset), config)
    val_dataset = dataset_class(Dataset.create(scale=scale, **val_dataset), config)

    example_batch = train_dataset.sample(1)

    agent_class = agents[config['agent_name']]
    _success_fn = get_success_fn(FLAGS.env_name)

    def success_fn(obs, goal):
        # Inputs pass through `unnormalize` before the environment's success check.
        return _success_fn(unnormalize(scale['observations'], obs), unnormalize(scale['oracle_reps'], goal))

    agent = agent_class.create(
        FLAGS.seed,
        example_batch,
        config,
        success_fn,
    )

    # Restore agent.
    epoch = 0
    if FLAGS.job_id and os.path.exists(step_file):
        # Resume this job from the last step recorded at checkpoint time.
        with open(step_file, 'r') as f:
            epoch = int(f.read())
        agent = restore_agent(agent, FLAGS.save_dir, epoch)

    if FLAGS.restore_path is not None:
        # An explicit restore path is applied last and therefore takes precedence.
        agent = restore_agent(agent, FLAGS.restore_path, FLAGS.restore_epoch)

    # Train agent.
    train_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'train.csv'))
    eval_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'eval.csv'))
    try:
        first_time = time.time()
        last_time = time.time()

        # Seed the rollout buffer with `mb_retain_epoch` batches of rollouts.
        rng = jax.random.PRNGKey(FLAGS.seed)
        fake_data = defaultdict(list)
        for _ in range(FLAGS.mb_retain_epoch):
            rng, rollout_rng = jax.random.split(rng)
            batch = train_dataset.sample(FLAGS.mb_batch_size)
            fake_batch = agent.rollout(batch, rollout_rng)
            for k, v in fake_batch.items():
                fake_data[k].append(v)
        fake_data = {k: np.concatenate(v, axis=0) for k, v in fake_data.items()}
        fake_buffer = ReplayBuffer(Dataset.create(scale=scale, **fake_data))
        # NOTE(review): presumably marks the buffer as full so that all seeded
        # rollouts can be sampled -- confirm against ReplayBuffer.
        fake_buffer.size = fake_buffer.max_size

        # Each update batch is split between dataset and rollout transitions.
        real_batch_size = int(config['batch_size'] * FLAGS.mb_batch_ratio)
        fake_batch_size = config['batch_size'] - real_batch_size

        for i in tqdm.tqdm(range(epoch + 1, FLAGS.offline_steps + 1), smoothing=0.1, dynamic_ncols=True):
            if i % FLAGS.mb_rollout_freq == 0:
                # Periodically add rollouts from the current agent to the buffer.
                rng, rollout_rng = jax.random.split(rng)
                batch = train_dataset.sample(FLAGS.mb_batch_size)
                fake_batch = agent.rollout(batch, rollout_rng)
                fake_buffer.add_batch(fake_batch)

            real_batch = train_dataset.sample(real_batch_size)
            fake_batch = fake_buffer.sample(fake_batch_size)
            agent, update_info = agent.update(real_batch, fake_batch)

            # Log metrics.
            if i == 1 or i % FLAGS.log_interval == 0:
                if hasattr(agent, 'compute_metrics'):
                    rng, key = jax.random.split(rng)
                    add_metrics = agent.compute_metrics(real_batch, fake_batch, rng=key)
                    update_info.update({f'metrics/{k}': v for k, v in add_metrics.items()})
                train_metrics = {f'training/{k}': v for k, v in update_info.items()}

                # The validation loss reuses the rollout batch from this step.
                val_batch = val_dataset.sample(real_batch_size)
                _, val_info = agent.total_loss(val_batch, fake_batch, grad_params=None)
                train_metrics.update({f'validation/{k}': v for k, v in val_info.items()})

                # Always divided by log_interval, including for the i == 1 entry.
                train_metrics['time/epoch_time'] = (time.time() - last_time) / FLAGS.log_interval
                train_metrics['time/total_time'] = time.time() - first_time
                last_time = time.time()
                wandb.log(train_metrics, step=i)
                train_logger.log(train_metrics, step=i)

            # Evaluate agent (only on multiples of eval_interval, never at the first step).
            if FLAGS.eval_interval != 0 and i % FLAGS.eval_interval == 0:
                eval_metrics = _evaluate_agent(agent, env, config, scale, goal_conditioned)
                wandb.log(eval_metrics, step=i)
                eval_logger.log(eval_metrics, step=i)

            # Save agent.
            if i == 1 or i % FLAGS.save_interval == 0:
                save_agent(agent, FLAGS.save_dir, i)
                if FLAGS.job_id:
                    # Record the step so a restarted job knows where to resume.
                    with open(step_file, 'w') as f:
                        f.write(f"{i}")

            # Rotate to the next dataset file, wrapping around at the end of the list.
            if FLAGS.dataset_replace_interval != 0 and i % FLAGS.dataset_replace_interval == 0 and len(datasets) > 1:
                dataset_idx = (dataset_idx + 1) % len(datasets)
                train_dataset, val_dataset = make_env_and_datasets(
                    FLAGS.env_name, dataset_path=datasets[dataset_idx], dataset_only=True, cur_env=env
                )
                train_dataset = dataset_class(Dataset.create(scale=scale, **train_dataset), config)
                val_dataset = dataset_class(Dataset.create(scale=scale, **val_dataset), config)
    finally:
        # Close the CSV loggers even when training is interrupted by an error.
        train_logger.close()
        eval_logger.close()


# Script entry point: hand `main` to absl's `app.run` when executed directly.
if __name__ == '__main__':
    app.run(main)
