#! /usr/bin/env python
import os

import gym
import jax
import tqdm
import wandb
from absl import app, flags
from ml_collections import config_flags
from flax.training import checkpoints
import numpy as np
import jax.numpy as jnp

from jaxOfflineRL.agents import TD3BCOBPLearner
from jaxOfflineRL.data import D4RLDataset
from jaxOfflineRL.evaluation import evaluate, evaluate_normalized_state, evaluate_surrogate_normalized_state
from jaxOfflineRL.wrappers import wrap_gym

FLAGS = flags.FLAGS

# Fixed list of seeds; the first entry is the default value of --seed.
SEEDS = [606847, 191778, 457260, 7718, 322217]
SEED = SEEDS[0]

flags.DEFINE_string("env_name", "halfcheetah-medium-expert-v2", "Environment name.")
flags.DEFINE_integer("seed", SEED, "Random seed.")
flags.DEFINE_integer("eval_episodes", 10, "Number of episodes used for evaluation.")
flags.DEFINE_integer("log_interval", 1000, "Logging interval.")
flags.DEFINE_integer("eval_interval", 5000, "Eval interval.")
flags.DEFINE_integer("batch_size", 256, "Mini batch size.")
flags.DEFINE_integer("max_steps", int(2e5), "Number of training steps.")
flags.DEFINE_boolean("tqdm", True, "Use tqdm progress bar.")
flags.DEFINE_float("filter_percentile", None, "Take top N% trajectories.")
flags.DEFINE_float(
    "filter_threshold", None, "Take trajectories with returns above the threshold."
)
flags.DEFINE_bool("load_model", False, "Set False to train, True to load without training")
# NOTE(review): main() stores checkpoints under a directory named after --env_name
# and never reads this flag; the help text says so instead of repeating the
# --load_model description it was copy-pasted from.
flags.DEFINE_string(
    "model_path",
    "./models/checkpoint",
    "Checkpoint path (currently unused: checkpoints are stored under a directory named after --env_name).",
)
# Training hyperparameters; the default points at the td3_bc_obp_mujoco entry of
# configs/offline_config.py. lock_config=False allows keys to be popped/added later.
config_flags.DEFINE_config_file(
    "config",
    "configs/offline_config.py:td3_bc_obp_mujoco",
    "File path to the training hyperparameter configuration.",
    lock_config=False,
)


def _save_checkpoints(agent, ckpt_dir):
    """Checkpoint the agent's evaluation actor, tilde actor and critic under ``ckpt_dir``.

    Each network is written to its own subdirectory (``actor_evaluation`` for
    ``agent._actor``, ``actor_tilde`` for ``agent._bp_actor``, ``critic`` for
    ``agent._critic``) at step 0, overwriting any checkpoint already there.
    """
    for subdir, network in (
        ("actor_evaluation", agent._actor),
        ("actor_tilde", agent._bp_actor),
        ("critic", agent._critic),
    ):
        checkpoints.save_checkpoint(
            ckpt_dir=os.path.join(ckpt_dir, subdir), target=network, step=0, overwrite=True
        )


def _restore_checkpoints(agent, ckpt_dir):
    """Load the networks written by ``_save_checkpoints`` back onto ``agent``.

    ``checkpoints.restore_checkpoint`` does not fill ``target`` in place. It returns
    the restored object, or ``target`` unchanged when no checkpoint is found. Each
    result must therefore be assigned back onto the agent, otherwise the loaded
    weights are silently discarded.
    """
    # The agent is mutated in place elsewhere (agent.update() is called without
    # rebinding `agent`), so plain attribute assignment is assumed to be allowed.
    agent._actor = checkpoints.restore_checkpoint(
        ckpt_dir=os.path.join(ckpt_dir, "actor_evaluation"), target=agent._actor
    )
    agent._bp_actor = checkpoints.restore_checkpoint(
        ckpt_dir=os.path.join(ckpt_dir, "actor_tilde"), target=agent._bp_actor
    )
    agent._critic = checkpoints.restore_checkpoint(
        ckpt_dir=os.path.join(ckpt_dir, "critic"), target=agent._critic
    )


def _collect_bc_loss_q_values(agent, dataset, batch_size, use_tqdm=True):
    """Score every transition of ``dataset`` in order with ``agent.eval_bcLoss_qValue``.

    Args:
        agent: Learner exposing ``eval_bcLoss_qValue(observations, actions)``, which
            returns five per-batch arrays.
        dataset: Dataset supporting ``len()`` and ``sample(n, indx=...)``.
        batch_size: Number of consecutive transitions scored per call.
        use_tqdm: Show a progress bar.

    Returns:
        Tuple ``(mse_loss, mse_loss_tilde, q_value, q_value_tilde, q_value_data)``.
        Each element is the per-batch outputs concatenated along axis 0.
    """
    n = len(dataset)
    mse_losses, mse_losses_tilde, q_values, q_values_tilde, q_values_data = [], [], [], [], []
    # Stepping from 0 by batch_size yields only non-empty slices. The previous
    # `n // batch_size + 1` iteration count produced a trailing empty slice
    # whenever n was an exact multiple of batch_size.
    for start in tqdm.tqdm(range(0, n, batch_size), disable=not use_tqdm):
        indx = np.arange(start, min(start + batch_size, n))
        # NOTE(review): the first argument is kept as in the original; it is presumably
        # ignored when `indx` is given -- TODO confirm against D4RLDataset.sample.
        batch = dataset.sample(batch_size * 10, indx=indx)
        mse, mse_tilde, q, q_tilde, q_data = agent.eval_bcLoss_qValue(
            batch["observations"], batch["actions"]
        )
        mse_losses.append(mse)
        mse_losses_tilde.append(mse_tilde)
        q_values.append(q)
        q_values_tilde.append(q_tilde)
        q_values_data.append(q_data)

    return tuple(
        np.concatenate(chunks, axis=0)
        for chunks in (mse_losses, mse_losses_tilde, q_values, q_values_tilde, q_values_data)
    )


def main(_):
    """Train (or load) the agent, then dump per-transition BC losses and Q-values.

    Without ``--load_model`` the agent is trained for ``--max_steps`` updates. Its
    actor, ``_bp_actor`` and critic are then checkpointed under a directory named
    after ``--env_name``. With ``--load_model`` those checkpoints are restored
    instead. In both cases every dataset transition is scored with
    ``agent.eval_bcLoss_qValue``. The five resulting arrays are saved to
    ``<env_name>-TD3BCEP.npz`` as ``arr_0`` ... ``arr_4``.
    """
    env = gym.make(FLAGS.env_name)
    env = wrap_gym(env)
    env.seed(FLAGS.seed)

    dataset = D4RLDataset(env)
    # The returned (mean, std) were only used by the disabled periodic evaluation.
    dataset.normalize_state()

    if FLAGS.filter_percentile is not None or FLAGS.filter_threshold is not None:
        dataset.filter(
            percentile=FLAGS.filter_percentile, threshold=FLAGS.filter_threshold
        )
    dataset.seed(FLAGS.seed)

    if "antmaze" in FLAGS.env_name:
        dataset.dataset_dict["rewards"] *= 100
    elif FLAGS.env_name.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]:
        dataset.normalize_returns(scaling=1000)

    kwargs = dict(FLAGS.config.model_config)
    # `cosine_decay` is a script-level switch, not a constructor argument.
    if kwargs.pop("cosine_decay", False):
        kwargs["decay_steps"] = FLAGS.max_steps
    agent = globals()[FLAGS.config.model_constructor](
        FLAGS.seed, env.observation_space, env.action_space, **kwargs
    )

    ckpt_dir = FLAGS.env_name
    for subdir in ("actor_evaluation", "actor_tilde", "critic"):
        os.makedirs(os.path.join(ckpt_dir, subdir), exist_ok=True)

    if FLAGS.load_model:
        _restore_checkpoints(agent, ckpt_dir)
    else:
        for i in tqdm.tqdm(
            range(1, FLAGS.max_steps + 1), smoothing=0.1, disable=not FLAGS.tqdm
        ):
            batch = dataset.sample(FLAGS.batch_size)
            # The flag is True every `policy_delay` steps (presumably gating the
            # delayed actor update -- confirm in the learner).
            agent.update(batch, i % kwargs["policy_delay"] == 0)
            # NOTE: periodic evaluation (evaluate_normalized_state /
            # evaluate_surrogate_normalized_state + wandb logging) is disabled.
        _save_checkpoints(agent, ckpt_dir)

    results = _collect_bc_loss_q_values(agent, dataset, FLAGS.batch_size, FLAGS.tqdm)
    # Positional arrays -> stored as arr_0..arr_4, matching the original output layout.
    np.savez(f"{FLAGS.env_name}-TD3BCEP", *results)


if __name__ == "__main__":
    # Script entry point: hand control to absl's app runner, which invokes main.
    app.run(main)
