import os
import pathlib
import pickle
import pprint
from collections import defaultdict

import absl.app
import absl.flags
import gym
import jax
import numpy as np
import torch
import transformers
from flax.training.early_stopping import EarlyStopping
from tqdm import tqdm, trange

import bpref_v2.envs.wrappers as wrappers
from bpref_v2.data.preference_dataset import PrefDataset
from bpref_v2.data.qlearning_dataset import (
    qlearning_ant_dataset,
    qlearning_factorworld_dataset,
    qlearning_metaworld_dataset,
    qlearning_robosuite_dataset,
)
from bpref_v2.data.replay_buffer import get_d4rl_dataset
from bpref_v2.envs import MetaWorld
from bpref_v2.utils.viskit.logging import setup_logger

from ..utils.jax_utils import batch_to_jax
from ..utils.utils import (
    WandBLogger,
    define_flags_with_default,
    get_user_flags,
    prefix_metrics,
    save_pickle,
    set_random_seed,
)
from .algos import (
    EnsembleMRLearner,
    EnsemblePTLearner,
    MRLearner,
    NMRLearner,
    PTLearner,
    VMRLearner,
    VPTLearner,
)
from .sampler import TrajSampler

# JAX/XLA GPU memory settings: turn off up-front preallocation of device
# memory and set XLA's client memory fraction to 0.95.
# NOTE(review): XLA reads these when JAX initialises its GPU backend; that
# happens on first use rather than at `import jax`, so setting them here is
# presumably still effective — confirm if device memory behaviour looks off.
os.environ.update(
    {
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        "XLA_PYTHON_CLIENT_MEM_FRACTION": ".95",
    }
)


# Command-line flags and their defaults. Nested groups (data, mlp, transformer,
# ...) start from the matching class's get_default_config().
FLAGS_DEF = define_flags_with_default(
    # Environment / dataset selection.
    env="halfcheetah-medium-v2",
    # Reward-model family; main() dispatches on MR, VMR, EnsembleMR, NMR, PT,
    # EnsemblePT and VPT. (The former default "MLP" matched none of these, so a
    # run with default flags never built a learner.)
    model_type="MR",
    max_traj_length=1000,  # passed to TrajSampler for D4RL-style envs
    seed=42,  # seeds the env and the global RNGs used for training
    data_seed=42,  # global seed set while the preference dataset is built
    save_model=True,  # write periodic and final pickled checkpoints
    batch_size=64,
    # Early stopping on the "reward/loss" metric, checked at each evaluation;
    # only breaks the training loop when early_stop is True.
    early_stop=False,
    min_delta=1e-4,
    patience=10,
    n_epochs=2000,
    eval_period=10,  # evaluate every `eval_period` epochs
    save_period=100,  # write model_epoch{N}.pkl every `save_period` epochs
    comment="",  # required; becomes part of the W&B group and the save directory
    # Robosuite-specific options.
    robosuite=False,
    dataset_path="",  # root directory of the robosuite / metaworld hdf5 datasets
    robosuite_dataset_type="ph",
    robosuite_max_episode_steps=500,  # NOTE(review): not read by main() in this file
    # Nested configs for the dataset, each learner family, and the logger.
    data=PrefDataset.get_default_config(),
    mlp=MRLearner.get_default_config(),  # MR / EnsembleMR
    visual_mlp=VMRLearner.get_default_config(),  # VMR
    transformer=PTLearner.get_default_config(),  # PT / EnsemblePT
    visual_transformer=VPTLearner.get_default_config(),  # VPT
    lstm=NMRLearner.get_default_config(),  # NMR
    logging=WandBLogger.get_default_config(),
    num_ensembles=1,  # ensemble size for EnsembleMR / EnsemblePT (also passed to VPTLearner)
)


def _load_env_and_dataset(flags):
    """Create the environment and load the offline dataset selected by ``flags``.

    Dispatch order: ``flags.robosuite`` first, then substring matches on
    ``flags.env`` ("metaworld", "factorworld", "ant"). Anything else is created
    with ``gym.make`` and loaded through ``get_d4rl_dataset``.

    Returns:
        ``(gym_env, ds)``: the environment (whose spaces size the model and
        which is handed to ``PrefDataset``) and the loaded dataset.
    """
    seed = flags.seed
    if flags.robosuite:
        # Optional dependencies, only imported when robosuite is requested.
        import robomimic.utils.env_utils as EnvUtils
        from robosuite.wrappers import GymWrapper

        ds = qlearning_robosuite_dataset(
            os.path.join(flags.dataset_path, flags.env.lower(), flags.robosuite_dataset_type, "low_dim.hdf5")
        )
        robosuite_env = EnvUtils.create_env_from_metadata(
            env_meta=ds["env_meta"], render=False, render_offscreen=False
        ).env
        gym_env = GymWrapper(robosuite_env)
        gym_env._max_episode_steps = gym_env.horizon
        gym_env.seed(seed)
        gym_env.action_space.seed(seed)
        gym_env.observation_space.seed(seed)
        gym_env.ignore_done = False
    elif "metaworld" in flags.env or "factorworld" in flags.env:
        # Drop the first "-"-separated component (presumably the suite name).
        task_name = "-".join(flags.env.split("-")[1:])
        if "metaworld" in flags.env:
            ds = qlearning_metaworld_dataset(os.path.join(flags.dataset_path, task_name, f"{task_name}_train.hdf5"))
        else:
            ds = qlearning_factorworld_dataset(os.path.join(flags.data.data_dir, "train"), camera_name="corner2")
        gym_env = MetaWorld(task_name, seed=seed)
        gym_env._env.seed(seed)
        gym_env.action_space.seed(seed)
        gym_env.observation_space.seed(seed)
    elif "ant" in flags.env:
        gym_env = gym.make(flags.env)
        gym_env = wrappers.EpisodeMonitor(gym_env)
        gym_env = wrappers.SinglePrecision(gym_env)
        gym_env.seed(seed)
        gym_env.action_space.seed(seed)
        gym_env.observation_space.seed(seed)
        ds = qlearning_ant_dataset(gym_env)
    else:
        gym_env = gym.make(flags.env)
        eval_sampler = TrajSampler(gym_env.unwrapped, flags.max_traj_length)
        ds = get_d4rl_dataset(eval_sampler.env)
    return gym_env, ds


def _load_queries(data_config):
    """Load previously saved preference queries and labels, if present.

    Files live under ``<data_dir>/queries`` and are named by ``num_query`` and
    ``query_len``. The label file is the human or scripted variant, depending
    on ``use_human_label``.

    Returns:
        ``(query_1, query_2, label)``. All three are ``None`` when the query
        directory or either index file is missing.

    Raises:
        FileNotFoundError: if both index files exist but the label file does not.
    """
    query_path = pathlib.Path(data_config.data_dir).expanduser() / "queries"
    suffix = f"num{data_config.num_query}_q{data_config.query_len}"
    indices_1_file = query_path / f"indices_{suffix}"
    indices_2_file = query_path / f"indices_2_{suffix}"
    label_kind = "human" if data_config.use_human_label else "scripted"
    label_file = query_path / f"label_{label_kind}_{suffix}"

    if not (query_path.exists() and indices_1_file.exists() and indices_2_file.exists()):
        return None, None, None
    if not label_file.exists():
        raise FileNotFoundError(
            f"[Num {data_config.num_query} | query_len {data_config.query_len}] "
            "There's no saved labels. Please check it."
        )

    print("load existing queries and labels.")
    # NOTE: pickle.load can execute arbitrary code; only point data_dir at trusted files.
    with indices_1_file.open("rb") as fp, indices_2_file.open("rb") as gp, label_file.open("rb") as hp:
        return pickle.load(fp), pickle.load(gp), pickle.load(hp)


def _build_reward_learner(flags, observation_dim, action_dim, jax_devices, step_per_epoch):
    """Construct the reward learner named by ``flags.model_type``.

    Transformer-style models (NMR, PT, EnsemblePT, VPT) get a ``GPT2Config``.
    Its ``warmup_steps`` is 10% of the total number of optimisation steps, and
    ``total_steps`` is ``n_epochs * step_per_epoch``.

    Raises:
        ValueError: if ``flags.model_type`` is not a supported model.
    """
    model_type = flags.model_type

    def gpt2_config(config_dict):
        config = transformers.GPT2Config(**config_dict)
        config.warmup_steps = int(flags.n_epochs * 0.1 * step_per_epoch)
        config.total_steps = flags.n_epochs * step_per_epoch
        return config

    if model_type == "MR":
        return MRLearner(flags.mlp, observation_dim, action_dim, jax_devices)
    if model_type == "VMR":
        image_dim = (flags.data.image_size, flags.data.image_size, 3)
        return VMRLearner(flags.visual_mlp, image_dim, action_dim)
    if model_type == "EnsembleMR":
        return EnsembleMRLearner(flags.mlp, observation_dim, action_dim, jax_devices, flags.num_ensembles)
    if model_type == "NMR":
        return NMRLearner(gpt2_config(flags.lstm), observation_dim, action_dim, jax_devices)
    if model_type == "PT":
        return PTLearner(gpt2_config(flags.transformer), observation_dim, action_dim, jax_devices)
    if model_type == "EnsemblePT":
        return EnsemblePTLearner(
            gpt2_config(flags.transformer), observation_dim, action_dim, jax_devices, flags.num_ensembles
        )
    if model_type == "VPT":
        image_dim = (flags.data.image_size, flags.data.image_size, 3)
        return VPTLearner(gpt2_config(flags.visual_transformer), image_dim, action_dim, flags.num_ensembles)
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected one of MR, VMR, EnsembleMR, NMR, PT, EnsemblePT, VPT."
    )


def _checkpoint_payload(reward_learner, epoch):
    """Return a host-side snapshot (epoch, train states, config dict) for pickling."""
    return {
        "epoch": epoch,
        "state": jax.device_get(reward_learner._train_states),
        "config": reward_learner.config.to_dict(),
    }


def _update_early_stopping(early_stop, criteria):
    """Feed ``criteria`` to ``early_stop`` and return ``(has_improved, new_state)``.

    Older flax releases return that tuple from ``EarlyStopping.update``. Newer
    ones return only the new state and expose ``has_improved`` as an attribute.
    """
    result = early_stop.update(criteria)
    if isinstance(result, tuple):
        return result
    return result.has_improved, result


def _log_epoch_metrics(wb_logger, metrics, step, epoch):
    """Average accumulated metric lists, print them, and send them to the W&B logger."""
    log_metrics = {k: np.mean(v) for k, v in metrics.items()}
    log_metrics.update({"step": step, "epoch": epoch})
    tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
    wb_logger.log(log_metrics)


def main(_):
    """Train a preference-based reward model and checkpoint it.

    Steps:
        1. Set up the run directory and the W&B logger.
        2. Load the environment, the offline dataset, and any saved preference queries.
        3. Build ``PrefDataset`` and the learner selected by ``--model_type``.
        4. Train for ``n_epochs``, with periodic evaluation, optional early
           stopping, and pickled checkpoints.

    Raises:
        ValueError: if ``--comment`` is empty, the dataset holds fewer samples
            than ``batch_size``, or ``--model_type`` is unknown.
        FileNotFoundError: if saved query indices exist without labels.
    """
    FLAGS = absl.flags.FLAGS
    jax_devices = jax.local_devices()
    variant = get_user_flags(FLAGS, FLAGS_DEF)

    # Explicit raise rather than assert so the check survives `python -O`.
    if not FLAGS.comment:
        raise ValueError("You must leave your comment for logging experiment.")

    save_dir = "/".join(
        [FLAGS.logging.output_dir, FLAGS.env, str(FLAGS.model_type), FLAGS.comment, f"s{FLAGS.seed}"]
    )
    FLAGS.logging.group = f"{FLAGS.env}_{FLAGS.model_type}_{FLAGS.comment}"
    FLAGS.logging.experiment_id = f"{FLAGS.logging.group}_s{FLAGS.seed}"

    setup_logger(variant=variant, seed=FLAGS.seed, base_log_dir=save_dir, include_exp_prefix_sub_dir=False)

    FLAGS.logging.output_dir = save_dir
    wb_logger = WandBLogger(FLAGS.logging, variant=variant)

    set_random_seed(FLAGS.seed)
    gym_env, ds = _load_env_and_dataset(FLAGS)

    # Collect segments under a fixed data seed, independent of --seed.
    set_random_seed(FLAGS.data_seed)
    query_1, query_2, label = _load_queries(FLAGS.data)
    dataset = PrefDataset(
        env=gym_env,
        update=FLAGS.data,
        ds=ds,
        query_1=query_1,
        query_2=query_2,
        label=label,
    )

    set_random_seed(FLAGS.seed)
    observation_dim = gym_env.observation_space.shape[0]
    action_dim = gym_env.action_space.shape[0]

    data_size = len(dataset)
    step_per_epoch = data_size // FLAGS.batch_size
    if step_per_epoch == 0:
        # Every schedule below is a multiple of step_per_epoch and used as a modulus.
        raise ValueError(f"Dataset has {data_size} samples, fewer than batch_size={FLAGS.batch_size}.")

    early_stop = EarlyStopping(min_delta=FLAGS.min_delta, patience=FLAGS.patience)
    reward_learner = _build_reward_learner(FLAGS, observation_dim, action_dim, jax_devices, step_per_epoch)

    # prefetch_factor is omitted: with the default num_workers=0 it has no effect,
    # and torch>=2.0 rejects it with a ValueError.
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=FLAGS.batch_size, shuffle=True, drop_last=True)
    # NOTE(review): validation iterates the same dataset as training.
    val_loader = torch.utils.data.DataLoader(dataset, batch_size=FLAGS.batch_size, shuffle=False, drop_last=True)

    def generate_batch(loader):
        """Yield batches from ``loader`` forever, converting torch tensors to numpy."""
        while True:
            for batch in loader:
                yield {key: jax.tree_util.tree_map(lambda x: x.numpy(), batch[key]) for key in batch}

    train_iter = generate_batch(train_loader)
    val_iter = generate_batch(val_loader)
    criteria_key = "reward/loss"

    total_steps = step_per_epoch * FLAGS.n_epochs
    step_per_eval = step_per_epoch * FLAGS.eval_period
    save_step = step_per_epoch * FLAGS.save_period

    epoch = 0  # keeps the final checkpoint well-defined when the loop never runs
    metrics = defaultdict(list)
    step_counter = trange(0, total_steps, desc="Train...", ncols=0)
    for step, batch in zip(step_counter, train_iter):
        if step % step_per_epoch == 0:
            metrics = defaultdict(list)  # fresh accumulator for each epoch
        epoch = step // step_per_epoch
        for key, val in prefix_metrics(reward_learner.train(batch_to_jax(batch)), "reward").items():
            metrics[key].append(val)

        # Evaluation phase: `step_per_epoch` validation batches, i.e. one full pass of val_loader.
        if step and step % step_per_eval == 0:
            for _, val_batch in zip(trange(step_per_epoch, desc="val...", ncols=0), val_iter):
                eval_metrics = prefix_metrics(reward_learner.evaluation(batch_to_jax(val_batch)), "reward")
                for key, val in eval_metrics.items():
                    metrics[key].append(val)

            criteria = np.mean(metrics[criteria_key])
            has_improved, early_stop = _update_early_stopping(early_stop, criteria)
            if early_stop.should_stop and FLAGS.early_stop:
                _log_epoch_metrics(wb_logger, metrics, step, epoch)
                print("Met early stopping criteria, breaking...")
                break
            elif epoch > 0 and has_improved:
                metrics["best_epoch"] = epoch
                # Keyed by the criterion, not by whichever metric name a preceding loop left in `key`.
                metrics[f"{criteria_key}_best"] = criteria
                save_pickle(_checkpoint_payload(reward_learner, epoch), "best_model.pkl", save_dir)

        # Log on the last step of every epoch (including epoch 0 when an epoch is one step).
        if step % step_per_epoch == step_per_epoch - 1:
            _log_epoch_metrics(wb_logger, metrics, step, epoch)

        if FLAGS.save_model and step % save_step == 0:
            save_pickle(_checkpoint_payload(reward_learner, epoch), f"model_epoch{epoch}.pkl", save_dir)

    if FLAGS.save_model:
        save_pickle(_checkpoint_payload(reward_learner, epoch), "model.pkl", save_dir)


if __name__ == "__main__":
    # absl.app parses the command-line flags, then calls main with the remaining argv.
    absl.app.run(main=main)
