import glob, tqdm, wandb, os, json, random, time, jax, flax
from absl import app, flags
from ml_collections import config_flags
from log_utils import setup_wandb, get_exp_name, get_flag_dict, get_wandb_video,CsvLogger

from envs.env_utils import make_env_and_datasets
from envs.ogbench_utils import make_ogbench_env_and_datasets

from utils.flax_utils import save_agent, restore_agent_with_file
from utils.datasets import Dataset, ReplayBuffer

from evaluation import evaluate
from agents import agents

import numpy as np

# Mirror the CUDA device selection into the EGL device-id variables, presumably
# so offscreen (EGL) rendering uses the same GPU as training.
# MUJOCO_EGL_DEVICE_ID / EGL_DEVICE_ID take a single device index, so a
# comma-separated CUDA list such as "0,1" would break rendering; use only the
# first visible device, and leave the variables untouched if none is given.
if 'CUDA_VISIBLE_DEVICES' in os.environ:
    _first_cuda_device = os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0].strip()
    if _first_cuda_device:
        os.environ['EGL_DEVICE_ID'] = _first_cuda_device
        os.environ['MUJOCO_EGL_DEVICE_ID'] = _first_cuda_device

FLAGS = flags.FLAGS

# Run identification and output location.
flags.DEFINE_string('run_group', 'Debug', 'Run group.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_string('env_name', 'cube-double-play-singletask-task2-v0', 'Environment (dataset) name.')
flags.DEFINE_string('save_dir', 'exp/', 'Save directory.')
flags.DEFINE_string('exp_name', None, 'Experiment name (overrides automatic generation).')
flags.DEFINE_boolean('resume', False, 'Resume from checkpoint if found.')

# Training schedule, logging, and checkpointing.
flags.DEFINE_integer('offline_steps', 1000000, 'Number of offline steps.')
flags.DEFINE_integer('online_steps', 0, 'Number of online steps.')
flags.DEFINE_integer('buffer_size', 2000000, 'Replay buffer size.')
flags.DEFINE_integer('log_interval', 5000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 100000, 'Evaluation interval.')
flags.DEFINE_integer('save_interval', -1, 'Save interval.')
flags.DEFINE_integer('max_checkpoints', 1, 'Maximum number of checkpoints to keep.')
flags.DEFINE_integer('start_training', 5000, 'Online step at which gradient updates start.')

flags.DEFINE_integer('utd_ratio', 1, "update to data ratio")

flags.DEFINE_float('discount', 0.99, 'discount factor')

# Evaluation / video rendering.
flags.DEFINE_integer('eval_episodes', 50, 'Number of evaluation episodes.')
flags.DEFINE_integer('video_episodes', 0, 'Number of video episodes for each task.')
flags.DEFINE_integer('video_frame_skip', 3, 'Frame skip for videos.')

# Agent config: an ml_collections config file that must define `get_config()`.
config_flags.DEFINE_config_file('agent', default='agents/bb.py', lock_config=False)

# Dataset handling.
flags.DEFINE_float('dataset_proportion', 1.0, "Proportion of the dataset to use")
flags.DEFINE_integer('dataset_replace_interval', 1000, 'Dataset replace interval, used for large datasets because of memory constraints')
flags.DEFINE_string('ogbench_dataset_dir', None, 'OGBench dataset directory')

flags.DEFINE_integer('horizon_length', 5, 'action chunking length.')
flags.DEFINE_bool('sparse', False, "Sparsify rewards: -1 where the reward is nonzero, 0 otherwise.")

flags.DEFINE_bool('save_all_online_states', False, "Save every online step's obs/qpos/qvel to data.npz in the save directory.")
flags.DEFINE_string('support_type', 'universal', 'Support type for distributional RL (universal, data_centric, custom)')
flags.DEFINE_float('data_margin', 0.1, 'Margin for data-based support for value function.')

from utils.smdp import smdp_return_range

class LoggingHelper:
    """Send one metrics dict to a per-prefix CSV logger and to a wandb-like logger.

    Args:
        csv_loggers: Mapping from prefix (e.g. ``"eval"``) to a logger exposing
            ``log(data, step=...)``. Only these prefixes may be logged.
        wandb_logger: Object exposing ``log(dict, step=...)``; it receives the
            same metrics with keys namespaced as ``"<prefix>/<key>"``.
    """

    def __init__(self, csv_loggers, wandb_logger):
        self.csv_loggers = csv_loggers
        self.wandb_logger = wandb_logger
        # Wall-clock timestamps captured at construction time.
        self.first_time = time.time()
        self.last_time = time.time()

    def log(self, data, prefix, step):
        """Log the metrics dict ``data`` under ``prefix`` at ``step``.

        Raises:
            AssertionError: if ``prefix`` has no registered CSV logger.
        """
        assert prefix in self.csv_loggers, prefix
        target_csv = self.csv_loggers[prefix]
        target_csv.log(data, step=step)

        namespaced = {}
        for metric_name, metric_value in data.items():
            namespaced[f'{prefix}/{metric_name}'] = metric_value
        self.wandb_logger.log(namespaced, step=step)

def main(_):
    """Run offline pre-training followed by online fine-tuning for one experiment.

    Phases (all configured by absl FLAGS and the ``--agent`` config file):
      1. Set up the wandb run and the save directory
         ``<save_dir>/<project>/<run_group>/<env_name>/<exp_name>``; write
         ``wandb_id.txt`` (fresh runs only) and ``flags.json`` there.
      2. Load the environments and the training dataset, optionally rotating
         through a directory of OGBench ``.npz`` shards.
      3. Build the agent. With ``--resume``, restore the newest
         ``params_<step>.pkl`` checkpoint and continue from that step.
      4. Offline RL for ``offline_steps`` update steps on the dataset.
      5. Online RL for ``online_steps`` environment steps. The replay buffer is
         seeded with the current offline dataset, and each sampled action chunk
         is executed to completion.

    Args:
        _: Positional argv forwarded by ``absl.app.run`` (unused).
    """
    from collections import defaultdict, deque

    exp_name = FLAGS.exp_name if FLAGS.exp_name else get_exp_name(FLAGS.seed)

    project_name = FLAGS.project_name if hasattr(FLAGS, 'project_name') and FLAGS.project_name else 'cgq'
    # Single source of truth for the run directory. The wandb id is read from
    # here on resume, so it is also written here and all outputs go here.
    save_dir_base = os.path.join(FLAGS.save_dir, project_name, FLAGS.run_group, FLAGS.env_name, exp_name)

    wandb_id = None
    wandb_id_path = os.path.join(save_dir_base, 'wandb_id.txt')
    if FLAGS.resume and os.path.exists(wandb_id_path):
        with open(wandb_id_path, 'r') as f:
            wandb_id = f.read().strip()
            print(f"Resuming WandB run ID: {wandb_id}")

    run = setup_wandb(project=project_name, group=FLAGS.run_group, name=exp_name, resume_id=wandb_id)

    FLAGS.save_dir = save_dir_base
    os.makedirs(FLAGS.save_dir, exist_ok=True)

    # Persist the run ID unless we reattached to an existing run.
    if not wandb_id:
        with open(wandb_id_path, 'w') as f:
            f.write(run.id)

    flag_dict = get_flag_dict()
    with open(os.path.join(FLAGS.save_dir, 'flags.json'), 'w') as f:
        json.dump(flag_dict, f)

    config = FLAGS.agent

    # ---- Data loading ----
    dataset_paths = []
    dataset_idx = 0
    if FLAGS.ogbench_dataset_dir is not None:
        # Custom OGBench dataset split into .npz shards (validation shards
        # excluded). Offline training rotates to the next shard every
        # `dataset_replace_interval` steps.
        assert FLAGS.dataset_replace_interval != 0
        assert FLAGS.dataset_proportion == 1.0
        dataset_paths = [
            file for file in sorted(glob.glob(f"{FLAGS.ogbench_dataset_dir}/*.npz")) if '-val.npz' not in file
        ]
        if not dataset_paths:
            raise FileNotFoundError(
                f"No training .npz files found in ogbench_dataset_dir={FLAGS.ogbench_dataset_dir!r}"
            )
        env, eval_env, train_dataset, val_dataset = make_ogbench_env_and_datasets(
            FLAGS.env_name,
            dataset_path=dataset_paths[dataset_idx],
            compact_dataset=False,
        )
    else:
        env, eval_env, train_dataset, val_dataset = make_env_and_datasets(FLAGS.env_name)

    # ---- Housekeeping ----
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    # Only the first key is used (for online action sampling). The two-way
    # split is kept so the value of online_rng is unchanged.
    online_rng, _ = jax.random.split(jax.random.PRNGKey(FLAGS.seed), 2)

    discount = FLAGS.discount
    config["horizon_length"] = FLAGS.horizon_length
    if config['agent_name'] == 'fql':
        # FQL does not use action chunking.
        # NOTE(review): batches below are still sampled with
        # sequence_length=FLAGS.horizon_length, not config['horizon_length'];
        # confirm the FQL agent expects that.
        config['horizon_length'] = 1

    def process_train_dataset(ds):
        """Wrap a raw dataset dict into a Dataset and apply dataset flags.

        - ``--dataset_proportion < 1``: keep only the leading fraction of transitions.
        - ``--sparse``: replace rewards with -1 where nonzero, 0 otherwise.
        """
        ds = Dataset.create(**ds)
        if FLAGS.dataset_proportion < 1.0:
            new_size = int(len(ds['masks']) * FLAGS.dataset_proportion)
            ds = Dataset.create(**{k: v[:new_size] for k, v in ds.items()})

        if FLAGS.sparse:
            # Build a new dataset instead of mutating the frozen one.
            ds_dict = {k: v for k, v in ds.items()}
            ds_dict["rewards"] = (ds["rewards"] != 0.0) * -1.0
            ds = Dataset.create(**ds_dict)

        return ds

    def load_ogbench_train_dataset(idx):
        """Load shard ``dataset_paths[idx]`` against the existing env and process it."""
        new_train_dataset, _ = make_ogbench_env_and_datasets(
            FLAGS.env_name,
            dataset_path=dataset_paths[idx],
            compact_dataset=False,
            dataset_only=True,
            cur_env=env,
        )
        return process_train_dataset(new_train_dataset)

    def latest_checkpoint_step():
        """Return the largest N among ``params_N.pkl`` files in the save dir, or None."""
        steps = []
        for path in glob.glob(os.path.join(FLAGS.save_dir, 'params_*.pkl')):
            try:
                steps.append(int(path.split('params_')[-1].split('.pkl')[0]))
            except ValueError:
                # Skip files whose suffix is not an integer step number.
                continue
        return max(steps) if steps else None

    train_dataset = process_train_dataset(train_dataset)

    # ---- Value-support setup for distributional agents ----
    if 'deas' in config['agent_name']:
        support_type = FLAGS.support_type
        r_min, r_max = train_dataset.min_rewards, train_dataset.max_rewards
        L = FLAGS.horizon_length
        H = 1000  # NOTE(review): presumably the episode horizon used for the return bound; confirm.
        gamma1 = FLAGS.discount
        gamma2 = config['discount']

        # Fill in config entries that the agent config leaves as placeholders.
        config['action_sequence'] = L
        config['action_dim'] = train_dataset['actions'].shape[-1]
        config['ob_dims'] = train_dataset['observations'].shape[1:]

        if support_type == 'data_centric':
            print('Using data-centric support for value function')
            stats = train_dataset.stats
            delta = FLAGS.data_margin * stats['delta']
            config['v_min'] = float(stats['p1'] - delta)
            config['v_max'] = float(stats['p99'] + delta)

        elif support_type == 'universal':
            print('Using universal support for value function')
            v_min, v_max = smdp_return_range(r_min, r_max, L, H, gamma1, gamma2)
            config['v_min'] = float(v_min)
            config['v_max'] = float(v_max)

        elif support_type == 'custom':
            assert config['v_min'] is not None and config['v_max'] is not None, 'v_min and v_max must be set'

        else:
            raise ValueError(
                f"Invalid support type {support_type!r}. Expected one of: universal, data_centric, custom."
            )

        print(f'v_min: {config["v_min"]}, v_max: {config["v_max"]}')

    example_batch = train_dataset.sample(())
    action_dim = example_batch["actions"].shape[-1]

    agent_class = agents[config['agent_name']]
    agent = agent_class.create(
        FLAGS.seed,
        example_batch['observations'],
        example_batch['actions'],
        config,
    )

    # ---- Resume ----
    start_step = 0
    if FLAGS.resume:
        latest_step = latest_checkpoint_step()
        if latest_step is not None:
            start_step = latest_step
            resume_file = os.path.join(FLAGS.save_dir, f'params_{start_step}.pkl')
            print(f"Resuming from {resume_file}, step {start_step}")
            agent = restore_agent_with_file(agent, resume_file)

            if FLAGS.ogbench_dataset_dir is not None and FLAGS.dataset_replace_interval != 0:
                # The offline loop advances the shard at every multiple of
                # dataset_replace_interval, up to offline_steps. Clamp so that a
                # resume in the online phase reloads the final offline shard.
                shard_step = min(start_step, FLAGS.offline_steps)
                dataset_idx = (shard_step // FLAGS.dataset_replace_interval) % len(dataset_paths)
                print(f"Fast-forwarding dataset to index {dataset_idx}")
                train_dataset = load_ogbench_train_dataset(dataset_idx)

    # ---- Logging ----
    prefixes = ["eval", "env"]
    if FLAGS.offline_steps > 0:
        prefixes.append("offline_agent")
    if FLAGS.online_steps > 0:
        prefixes.append("online_agent")

    # Append to existing CSVs when continuing a run; otherwise start fresh.
    log_mode = 'a' if (FLAGS.resume and start_step > 0) else 'w'
    logger = LoggingHelper(
        csv_loggers={prefix: CsvLogger(os.path.join(FLAGS.save_dir, f"{prefix}.csv"), mode=log_mode)
                     for prefix in prefixes},
        wandb_logger=wandb,
    )

    offline_init_time = time.time()
    log_step = start_step

    # ---- Offline RL ----
    if log_step < FLAGS.offline_steps:
        for i in tqdm.tqdm(range(log_step + 1, FLAGS.offline_steps + 1)):
            log_step += 1

            if (FLAGS.ogbench_dataset_dir is not None and FLAGS.dataset_replace_interval != 0
                    and i % FLAGS.dataset_replace_interval == 0):
                dataset_idx = (dataset_idx + 1) % len(dataset_paths)
                print(f"Using new dataset: {dataset_paths[dataset_idx]}", flush=True)
                train_dataset = load_ogbench_train_dataset(dataset_idx)

            batch = train_dataset.sample_sequence(config['batch_size'], sequence_length=FLAGS.horizon_length, discount=discount)
            agent, offline_info = agent.update(batch)

            if i % FLAGS.log_interval == 0:
                logger.log(offline_info, "offline_agent", step=log_step)

            if FLAGS.save_interval > 0 and i % FLAGS.save_interval == 0:
                save_agent(agent, FLAGS.save_dir, log_step, max_to_keep=FLAGS.max_checkpoints)

            # Evaluate at step 1 and every eval_interval steps; during eval the
            # action chunk is executed fully.
            if FLAGS.eval_interval != 0 and (i % FLAGS.eval_interval == 0 or i == 1):
                eval_info, _, cur_renders = evaluate(
                    agent=agent,
                    env=eval_env,
                    action_dim=action_dim,
                    num_eval_episodes=FLAGS.eval_episodes,
                    num_video_episodes=FLAGS.video_episodes,
                    video_frame_skip=FLAGS.video_frame_skip,
                )
                logger.log(eval_info, "eval", step=log_step)
                if FLAGS.video_episodes > 0:
                    # Build a wandb.Video object from the rendered frames.
                    video = get_wandb_video(cur_renders)
                    wandb.log({"video": video}, step=log_step)

    # ---- Transition from offline to online ----
    # Seed the replay buffer with the current offline dataset; capacity is at
    # least one larger than the dataset so online transitions fit.
    replay_buffer = ReplayBuffer.create_from_initial_dataset(
        dict(train_dataset), size=max(FLAGS.buffer_size, train_dataset.size + 1)
    )

    ob, _ = env.reset()
    action_queue = deque()

    update_info = {}
    data = defaultdict(list)
    online_init_time = time.time()

    # When resuming inside the online phase, continue from the matching online step.
    online_start_i = 1
    if log_step >= FLAGS.offline_steps:
        online_start_i = (log_step - FLAGS.offline_steps) + 1

    # ---- Online RL ----
    for i in tqdm.tqdm(range(online_start_i, FLAGS.online_steps + 1)):
        log_step += 1
        online_rng, action_key = jax.random.split(online_rng)

        # Query the policy only after the previous action chunk is fully executed.
        if not action_queue:
            action = agent.sample_actions(observations=ob, rng=action_key)
            action_queue.extend(np.array(action).reshape(-1, action_dim))
        action = action_queue.popleft()

        next_ob, int_reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

        if FLAGS.save_all_online_states:
            state = env.get_state()
            data["steps"].append(i)
            data["obs"].append(np.copy(next_ob))
            data["qpos"].append(np.copy(state["qpos"]))
            data["qvel"].append(np.copy(state["qvel"]))
            if "button_states" in state:
                data["button_states"].append(np.copy(state["button_states"]))

        # Log the distance-related entries of the env info dict at every step.
        env_info = {name: value for name, value in info.items() if name.startswith("distance")}
        logger.log(env_info, "env", step=log_step)

        if 'antmaze' in FLAGS.env_name and (
            'diverse' in FLAGS.env_name or 'play' in FLAGS.env_name or 'umaze' in FLAGS.env_name
        ):
            # Adjust reward for D4RL antmaze.
            int_reward = int_reward - 1.0
        # elif is_robomimic_env(FLAGS.env_name):
        #     # Adjust online (0, 1) reward for robomimic
        #     int_reward = int_reward - 1.0

        if FLAGS.sparse:
            assert int_reward <= 0.0
            int_reward = (int_reward != 0.0) * -1.0

        transition = dict(
            observations=ob,
            actions=action,
            rewards=int_reward,
            terminals=float(done),
            masks=1.0 - terminated,
            next_observations=next_ob,
        )
        replay_buffer.add_transition(transition)

        if done:
            ob, _ = env.reset()
            action_queue.clear()  # drop the rest of the chunk from the finished episode
        else:
            ob = next_ob

        if i >= FLAGS.start_training:
            batch = replay_buffer.sample_sequence(config['batch_size'] * FLAGS.utd_ratio,
                                                  sequence_length=FLAGS.horizon_length, discount=discount)
            # Split into utd_ratio stacked minibatches for batch_update.
            batch = jax.tree.map(lambda x: x.reshape((
                FLAGS.utd_ratio, config["batch_size"]) + x.shape[1:]), batch)

            agent, update_info["online_agent"] = agent.batch_update(batch)

        if i % FLAGS.log_interval == 0:
            for log_prefix, prefix_info in update_info.items():
                logger.log(prefix_info, log_prefix, step=log_step)
            update_info = {}

        # Evaluation is scheduled independently of log_interval, matching the offline loop.
        if FLAGS.eval_interval != 0 and i % FLAGS.eval_interval == 0:
            eval_info, _, cur_renders = evaluate(
                agent=agent,
                env=eval_env,
                action_dim=action_dim,
                num_eval_episodes=FLAGS.eval_episodes,
                num_video_episodes=FLAGS.video_episodes,
                video_frame_skip=FLAGS.video_frame_skip,
            )
            logger.log(eval_info, "eval", step=log_step)
            if FLAGS.video_episodes > 0:
                wandb.log({"video": get_wandb_video(cur_renders)}, step=log_step)

        if FLAGS.save_interval > 0 and i % FLAGS.save_interval == 0:
            save_agent(agent, FLAGS.save_dir, log_step, max_to_keep=FLAGS.max_checkpoints)

    end_time = time.time()

    for csv_logger in logger.csv_loggers.values():
        csv_logger.close()

    # np.stack needs at least one array. Skip when no online step was executed
    # (e.g. online_steps == 0, or a resumed run that had already finished).
    if FLAGS.save_all_online_states and data["steps"]:
        c_data = {
            "steps": np.array(data["steps"]),
            "qpos": np.stack(data["qpos"], axis=0),
            "qvel": np.stack(data["qvel"], axis=0),
            "obs": np.stack(data["obs"], axis=0),
            "offline_time": online_init_time - offline_init_time,
            "online_time": end_time - online_init_time,
        }
        if len(data["button_states"]) != 0:
            c_data["button_states"] = np.stack(data["button_states"], axis=0)
        np.savez(os.path.join(FLAGS.save_dir, "data.npz"), **c_data)

    with open(os.path.join(FLAGS.save_dir, 'token.tk'), 'w') as f:
        f.write(run.url)

if __name__ == "__main__":
    # absl parses the command-line flags, then invokes main with the remaining argv.
    app.run(main)
