import os
from absl import app, flags
import numpy as np
import time
from datetime import datetime
import glob

import tqdm

# from src.agents import hilp_ft_v2 as learner
from src.dataset_utils import GCDataset
from src import d4rl_utils
from src.d4rl_utils import kitchen_render
from src.env_dataset import get_env_and_dataset
from src.agents import hilp_ft_v2 as learner

import jax
import flax

from jaxrl_m.wandb import setup_wandb, default_wandb_config
import wandb
from jaxrl_m.evaluation import (
    evaluate_with_trajectories,
    EpisodeMonitor,
    supply_rng,
    get_frame,
    annotate_frame,
    env_reset,
    env_step,
)
from ml_collections import config_flags
import pickle

from src.utils import record_video, CsvLogger

FLAGS = flags.FLAGS

# Run identity. An `agent_name` containing "hilp" makes main() also log
# agent.log_feat_dot_prod metrics at every log_interval.
flags.DEFINE_string("agent_name", "hilp", "")
flags.DEFINE_string("env_name", "antmaze-large-play-v0", "Environment name.")

# Output root; main() appends project / env / group / experiment-id to it.
flags.DEFINE_string("save_dir", "exp/", "")
# Glob that must match exactly one run directory. The checkpoint loaded from
# it is params.pkl, or params_{restore_epoch}.pkl when restore_epoch is set.
flags.DEFINE_string("restore_path", None, "")
flags.DEFINE_integer("restore_epoch", None, "")
# The default seed is drawn once when this module is imported, so runs that
# do not pass --seed get a different seed each time.
flags.DEFINE_integer("seed", np.random.choice(1000000), "Random seed.")
flags.DEFINE_integer("eval_episodes", 50, "Number of episodes used for evaluation.")
flags.DEFINE_integer("num_video_episodes", 2, "")
# Intervals are counted in loop steps (pretraining + fine-tuning combined).
flags.DEFINE_integer("log_interval", 1000, "Logging interval.")
flags.DEFINE_integer("eval_interval", 100000, "Eval interval.")
flags.DEFINE_integer("save_interval", 1000000, "Checkpoint save interval.")
flags.DEFINE_integer("batch_size", 256, "Mini batch size.")
# main() runs num_pretraining_steps offline updates, then max_steps online
# fine-tuning steps.
flags.DEFINE_integer("num_pretraining_steps", int(1e6), "Number of pretraining steps.")
flags.DEFINE_integer("max_steps", int(1e6), "Number of training steps.")
# NOTE(review): not read in this file; FLAGS is handed to get_env_and_dataset,
# which may use it.
flags.DEFINE_integer("warmup_steps", int(0), "Number of warmup steps.")
# Temperature passed to agent.sample_skill_actions while fine-tuning.
flags.DEFINE_float("act_temperature", 1, "")
# NOTE(review): not read in this file.
flags.DEFINE_float("temperature", 3.0, "")
# Used as the wandb group and exp_prefix.
flags.DEFINE_string("run_group", "Debug", "")
# Nonzero: collect annotated frames of the fine-tuning episodes and log them
# as a video at the next log_interval step.
flags.DEFINE_integer("viz_train", 0, "")

# Learner hyperparameters, forwarded to learner.create_learner. Hidden dims
# are built as (hidden_dim,) * num_layers; `discount` also goes to GCDataset.
flags.DEFINE_float("lr", 3e-4, "")
flags.DEFINE_integer("value_hidden_dim", 512, "")
flags.DEFINE_integer("value_num_layers", 3, "")
flags.DEFINE_integer("actor_hidden_dim", 512, "")
flags.DEFINE_integer("actor_num_layers", 3, "")
flags.DEFINE_float("discount", 0.99, "")
flags.DEFINE_float("tau", 0.005, "")
flags.DEFINE_float("expectile", 0.95, "")
flags.DEFINE_integer("use_layer_norm", 1, "")
flags.DEFINE_integer("skill_dim", 32, "")
flags.DEFINE_float("skill_expectile", 0.9, "")
flags.DEFINE_float("skill_temperature", 10, "")
flags.DEFINE_float("skill_discount", 0.99, "")
# Checkpoint restore: when 0, the skill critic/value networks (load_value) or
# the skill actor (load_pi) keep their fresh initialisation instead of the
# checkpoint's parameters and optimizer mu/nu state.
flags.DEFINE_integer("load_value", 1, "")
flags.DEFINE_integer("load_pi", 1, "")
# NOTE(review): not read in this file; presumably consumed by
# get_env_and_dataset.
flags.DEFINE_integer("init_dataset", 1, "")
# Forwarded to learner.create_learner.
flags.DEFINE_integer("use_norm_rewards", 0, "")
# NOTE(review): not read in this file; main() subtracts 1 from antmaze
# rewards unconditionally.
flags.DEFINE_integer("use_maze_neg_rew", 1, "")
# Applied to fine-tuning batch rewards as r * maze_rew_scale + maze_rew_shift.
flags.DEFINE_float("maze_rew_scale", 1.0, "")
flags.DEFINE_float("maze_rew_shift", 0.0, "")

# Passed to GCDataset for pretraining batches (presumably goal-sampling
# probabilities; the goal defaults sum to 1).
flags.DEFINE_float("p_currgoal", 0.0, "")
flags.DEFINE_float("p_trajgoal", 0.625, "")
flags.DEFINE_float("p_randomgoal", 0.375, "")
flags.DEFINE_float("p_aug", None, "")

# Forwarded to learner.create_learner.
flags.DEFINE_string("encoder", None, "")

flags.DEFINE_string("algo_name", None, "")  # Not used, only for logging

config_flags.DEFINE_config_dict("wandb", default_wandb_config(), lock_config=False)


def get_normalization(dataset):
    """Return the spread of per-episode returns in ``dataset``, divided by 1000.

    Rewards are summed episode by episode. An episode ends at each index whose
    ``dones_float`` entry is truthy. Rewards after the last terminal flag
    (an unfinished trailing episode) are ignored.

    Args:
        dataset: Mapping with parallel sequences ``"rewards"`` and
            ``"dones_float"``.

    Returns:
        ``(max_return - min_return) / 1000`` over the completed episodes. This
        is 0 when every completed episode has the same return.

    Raises:
        ValueError: if ``dataset`` contains no completed episode (no truthy
            ``dones_float`` entry), because the return range is then undefined.
    """
    returns = []
    ret = 0
    for r, term in zip(dataset["rewards"], dataset["dones_float"]):
        ret += r
        if term:
            returns.append(ret)
            ret = 0
    if not returns:
        # Without this check max() fails with an opaque "empty sequence" error.
        raise ValueError(
            "get_normalization: dataset has no completed episode "
            "(no terminal flag in 'dones_float'); cannot compute return range."
        )
    return (max(returns) - min(returns)) / 1000


def _make_exp_name(start_time):
    """Build the run name used for wandb and for the save directory.

    Format: ``sd{seed:03d}_``, then optional SLURM parts (``s_{JOB_ID}.``,
    ``{PROCID}.``, ``rs_{RESTART_COUNT}.``, each present only when that
    environment variable is set), then ``{start_time}``, then
    ``_{FLAGS.wandb["name"]}``.
    """
    env = os.environ
    parts = [f"sd{FLAGS.seed:03d}_"]
    if "SLURM_JOB_ID" in env:
        parts.append(f's_{env["SLURM_JOB_ID"]}.')
    if "SLURM_PROCID" in env:
        parts.append(f'{env["SLURM_PROCID"]}.')
    if "SLURM_RESTART_COUNT" in env:
        parts.append(f'rs_{env["SLURM_RESTART_COUNT"]}.')
    parts.append(f"{start_time}")
    parts.append(f'_{FLAGS.wandb["name"]}')
    return "".join(parts)


def _restore_agent(agent):
    """Rebuild ``agent`` from the checkpoint selected by the restore flags.

    ``FLAGS.restore_path`` is a glob that must match exactly one directory.
    Inside it, ``params.pkl`` is loaded, or ``params_{restore_epoch}.pkl``
    when ``FLAGS.restore_epoch`` is set. Sub-networks disabled through
    ``FLAGS.load_value`` / ``FLAGS.load_pi`` keep ``agent``'s current
    parameters and optimizer ``mu``/``nu`` state instead of the checkpoint's.

    Returns:
        The agent produced by ``flax.serialization.from_state_dict``.

    Raises:
        Exception: if the glob matches no path or more than one path.
    """
    restore_path = FLAGS.restore_path
    candidates = glob.glob(restore_path)
    if len(candidates) == 0:
        raise Exception(f"Path does not exist: {restore_path}")
    if len(candidates) > 1:
        raise Exception(f"Multiple matching paths exist for: {restore_path}")
    if FLAGS.restore_epoch is None:
        restore_path = candidates[0] + "/params.pkl"
    else:
        restore_path = candidates[0] + f"/params_{FLAGS.restore_epoch}.pkl"
    # NOTE: pickle.load can execute arbitrary code; only restore trusted files.
    with open(restore_path, "rb") as f:
        load_dict = pickle.load(f)

    def do_not_load(key):
        # Overwrite the checkpoint's entries for sub-network `key` with the
        # current values from `agent`, so from_state_dict keeps them.
        network_state = load_dict["agent"]["network"]
        network_state["params"][key] = agent.network.params[key]
        network_state["opt_state"]["0"]["mu"][key] = agent.network.opt_state[0].mu[key]
        network_state["opt_state"]["0"]["nu"][key] = agent.network.opt_state[0].nu[key]

    if not FLAGS.load_value:
        # "networks_value" is deliberately absent: it is always restored.
        for name in (
            "networks_skill_critic",
            "networks_skill_target_critic",
            "networks_skill_value",
            "networks_target_value",
        ):
            do_not_load(name)
    if not FLAGS.load_pi:
        do_not_load("networks_skill_actor")

    agent = flax.serialization.from_state_dict(agent, load_dict["agent"])
    print(f"Restored from {restore_path}")
    return agent


def _transition_metrics(transition):
    """Flatten one transition into scalar wandb metrics under ``transition/``.

    numpy arrays contribute ``{key}_max`` / ``{key}_mean`` / ``{key}_min``.
    ``float`` instances are logged as-is and ``bool`` values as 0.0 / 1.0.
    Values of any other type are skipped.
    """
    metrics = {}
    for k, v in transition.items():
        if isinstance(v, np.ndarray):
            metrics[f"transition/{k}_max"] = v.max()
            metrics[f"transition/{k}_mean"] = v.mean()
            metrics[f"transition/{k}_min"] = v.min()
        elif isinstance(v, float):
            metrics[f"transition/{k}"] = v
        elif isinstance(v, bool):
            metrics[f"transition/{k}"] = float(v)
    return metrics


def _skill_toward_goal(agent, observation, obs_goal):
    """Return the unit direction from phi(observation) to phi(obs_goal).

    If both embeddings coincide (zero-norm difference), the zero difference is
    returned as-is instead of being divided by zero. Dividing would make every
    entry NaN, and that NaN would end up in the skills of each fine-tuning
    batch. For any nonzero norm the result equals plain normalisation.
    """
    phi_obs, phi_goal = agent.get_phi(np.array([observation, obs_goal]))
    direction = phi_goal - phi_obs
    norm = np.linalg.norm(direction)
    if norm == 0:
        return direction
    return direction / norm


def main(_):
    """Pretrain the agent offline, then fine-tune it online toward one goal.

    Steps ``1..num_pretraining_steps`` call ``agent.update`` on batches from a
    ``GCDataset``. At step ``num_pretraining_steps + 1`` the skill is fixed to
    the unit direction from phi(start observation) to phi(goal). From that
    step on, each step:

    - takes one environment action with that skill,
    - appends the transition to the replay buffer,
    - calls ``agent.finetune`` on a replay batch.

    Metrics go to wandb and to ``train.csv`` / ``eval.csv`` in the run
    directory. Checkpoints are pickled every ``save_interval`` steps.
    """
    g_start_time = int(datetime.now().timestamp())
    exp_name = _make_exp_name(g_start_time)

    # Create wandb logger
    FLAGS.wandb["project"] = "u2o_gcrl"
    FLAGS.wandb["name"] = FLAGS.wandb["exp_descriptor"] = exp_name
    FLAGS.wandb["group"] = FLAGS.wandb["exp_prefix"] = FLAGS.run_group
    setup_wandb(dict(), **FLAGS.wandb)

    FLAGS.save_dir = os.path.join(
        FLAGS.save_dir,
        wandb.run.project,
        FLAGS.env_name,
        wandb.config.exp_prefix,
        wandb.config.experiment_id,
    )
    os.makedirs(FLAGS.save_dir, exist_ok=True)

    env_and_dataset_dict = get_env_and_dataset(FLAGS, kitchen_full_obs=False)
    train_env = env_and_dataset_dict["env"]
    eval_env = env_and_dataset_dict["eval_env"]
    replay_buffer = env_and_dataset_dict["replay_buffer"]
    dataset = env_and_dataset_dict["dataset"]
    goal_info = env_and_dataset_dict["goal_info"]

    # First entry of the replay buffer's observations; passed to env_reset and
    # evaluate_with_trajectories.
    base_observation = jax.tree_util.tree_map(
        lambda arr: arr[0], replay_buffer["observations"]
    )
    train_env.reset()

    gc_dataset = GCDataset(
        dataset,
        p_currgoal=FLAGS.p_currgoal,
        p_trajgoal=FLAGS.p_trajgoal,
        p_randomgoal=FLAGS.p_randomgoal,
        discount=FLAGS.discount,
        p_aug=FLAGS.p_aug,
    )

    train_logger = CsvLogger(os.path.join(FLAGS.save_dir, "train.csv"))
    eval_logger = CsvLogger(os.path.join(FLAGS.save_dir, "eval.csv"))
    # Close both CSV loggers even if restoring or training raises.
    try:
        first_time = time.time()
        last_time = time.time()

        example_batch = dataset.sample(1)

        agent = learner.create_learner(
            FLAGS.seed,
            example_batch["observations"],
            example_batch["actions"],
            max_steps=FLAGS.num_pretraining_steps + FLAGS.max_steps,
            lr=FLAGS.lr,
            value_hidden_dims=(FLAGS.value_hidden_dim,) * FLAGS.value_num_layers,
            actor_hidden_dims=(FLAGS.actor_hidden_dim,) * FLAGS.actor_num_layers,
            discount=FLAGS.discount,
            tau=FLAGS.tau,
            expectile=FLAGS.expectile,
            use_layer_norm=FLAGS.use_layer_norm,
            skill_dim=FLAGS.skill_dim,
            skill_expectile=FLAGS.skill_expectile,
            skill_temperature=FLAGS.skill_temperature,
            skill_discount=FLAGS.skill_discount,
            encoder=FLAGS.encoder,
            use_norm_rewards=FLAGS.use_norm_rewards,
        )

        if FLAGS.restore_path is not None:
            agent = _restore_agent(agent)

        policy_type = "eval"
        observation, obs_goal = env_reset(
            FLAGS.env_name,
            train_env,
            goal_info,
            base_observation,
            policy_type,
        )
        if "kitchen" in FLAGS.env_name and "visual" not in FLAGS.env_name:
            observation = observation[:30]
        done = False

        pretraining_end = FLAGS.num_pretraining_steps + 1
        total_steps = FLAGS.num_pretraining_steps + FLAGS.max_steps
        train_render = []
        for i in tqdm.tqdm(range(1, total_steps + 1), smoothing=0.1):
            # The skill follows the agent only up to pretraining_end and is
            # fixed afterwards. Before that it is read only on log/eval steps,
            # so recompute it just there; it has the same value the old
            # every-step recomputation had. Each recomputation calls get_phi
            # and converts to NumPy, which waits for the pending update.
            if i <= pretraining_end and (
                i in (1, pretraining_end)
                or i % FLAGS.log_interval == 0
                or i % FLAGS.eval_interval == 0
            ):
                skill = _skill_toward_goal(agent, observation, obs_goal)

            if i >= pretraining_end:
                # Online fine-tuning: act, store the transition, then update.
                agent, action = agent.sample_skill_actions(
                    observations=observation,
                    skills=skill,
                    temperature=FLAGS.act_temperature,
                )
                action = np.array(action)
                next_observation, reward, done, info = env_step(
                    FLAGS.env_name,
                    train_env,
                    action,
                )

                # mask is 1.0 unless the episode is done. For antmaze, a done
                # whose info carries "TimeLimit.truncated" still gets 1.0.
                if not done or (
                    "antmaze" in FLAGS.env_name and "TimeLimit.truncated" in info
                ):
                    mask = 1.0
                else:
                    mask = 0.0

                # NOTE(review): FLAGS.use_maze_neg_rew is not consulted here;
                # the -1 shift is unconditional for antmaze. Confirm this
                # matches how the offline dataset rewards are built.
                if "antmaze" in FLAGS.env_name:
                    reward -= 1.0

                if FLAGS.viz_train:
                    cur_frame = get_frame(FLAGS.env_name, train_env)
                    cur_frame = annotate_frame(cur_frame, reward, {"mask": mask})
                    train_render.append(cur_frame)

                transition = dict(
                    observations=observation,
                    actions=action,
                    rewards=reward,
                    masks=mask,
                    dones_float=done,
                    next_observations=next_observation,
                    traj_ends=done,
                )
                replay_buffer.add_transition(transition)
                wandb.log(_transition_metrics(transition), step=i)
                observation = next_observation

                if done:
                    # NOTE(review): this calls train_env.reset() directly
                    # rather than env_reset(...) as at start-up; obs_goal is
                    # not refreshed (the skill stays fixed while fine-tuning).
                    observation, done = train_env.reset(), False
                    if "kitchen" in FLAGS.env_name:
                        if "visual" in FLAGS.env_name:
                            observation = kitchen_render(train_env)
                        else:
                            observation = observation[:30]

                batch = replay_buffer.sample(FLAGS.batch_size)
                batch["rewards"] *= FLAGS.maze_rew_scale
                batch["rewards"] += FLAGS.maze_rew_shift
                # Every sample in the batch is conditioned on the same skill.
                batch["skills"] = np.repeat(
                    np.array(skill)[None],
                    batch["observations"].shape[0],
                    axis=0,
                )
                agent, update_info = agent.finetune(batch)
            else:
                # Offline pretraining on goal-conditioned dataset batches.
                batch = gc_dataset.sample(FLAGS.batch_size)
                agent, update_info = agent.update(batch)

            if i % FLAGS.log_interval == 0:
                train_metrics = {}
                if len(train_render) > 1:
                    train_metrics["train_video"] = record_video(
                        "Video", i, renders=[np.array(train_render)]
                    )
                    train_render = []
                for k, v in update_info.items():
                    train_metrics[f"training/{k}"] = v
                train_metrics["time/epoch_time"] = (
                    time.time() - last_time
                ) / FLAGS.log_interval
                train_metrics["time/total_time"] = time.time() - first_time

                if "hilp" in FLAGS.agent_name:
                    log_batch, log_batch_next = replay_buffer.sample_seq(
                        FLAGS.batch_size
                    )
                    train_metrics.update(
                        agent.log_feat_dot_prod(log_batch, log_batch_next, skill)
                    )

                last_time = time.time()
                wandb.log(train_metrics, step=i)
                train_logger.log(train_metrics, step=i)

            if i == 1 or i % FLAGS.eval_interval == 0:
                eval_info, _, renders = evaluate_with_trajectories(
                    agent,
                    eval_env,
                    goal_info=goal_info,
                    env_name=FLAGS.env_name,
                    num_episodes=FLAGS.eval_episodes,
                    base_observation=base_observation,
                    num_video_episodes=FLAGS.num_video_episodes,
                    policy_type=policy_type,
                    planning_info=None,
                    skill=skill,
                )
                eval_metrics = {
                    f"{policy_type}/{k}": v for k, v in eval_info.items()
                }
                if FLAGS.num_video_episodes > 0:
                    eval_metrics["eval_video"] = record_video(
                        "Video", i, renders=renders
                    )
                wandb.log(eval_metrics, step=i)
                eval_logger.log(eval_metrics, step=i)

            if i % FLAGS.save_interval == 0:
                save_dict = dict(agent=flax.serialization.to_state_dict(agent))
                fname = os.path.join(FLAGS.save_dir, f"params_{i}.pkl")
                print(f"Saving to {fname}")
                with open(fname, "wb") as f:
                    pickle.dump(save_dict, f)
    finally:
        train_logger.close()
        eval_logger.close()


if __name__ == "__main__":
    # absl's app.run parses the command-line flags (including the --wandb
    # config dict) and then calls main with the remaining argv.
    app.run(main)
