import copy
import os
import pathlib
import pickle
import pprint
from collections import defaultdict

import absl.app
import absl.flags
import gym
import jax
import numpy as np
import torch
import transformers
from flax import jax_utils
from flax.jax_utils import prefetch_to_device
from flax.training.early_stopping import EarlyStopping
from rich.console import Console
from tqdm import trange

# Shared rich console used for every log/print call in this script.
console = Console()

import bpref_v2.envs.wrappers as wrappers
from bpref_v2.data.arp_factorworld_dataset import ARPFactorworldDataset
from bpref_v2.data.arp_furniturebench_dataset_hdf5 import (
    ARPFurnitureBenchDataset as ARPFurnitureBenchDatasetHDF5,
)
from bpref_v2.data.arp_furniturebench_dataset_inmemory import (
    ARPFurnitureBenchDataset as ARPFuritureBenchDatasetInMemory,
)
from bpref_v2.data.augmentations import pmap_image_aug_fn
from bpref_v2.data.qlearning_dataset import (
    qlearning_ant_dataset,
    qlearning_robosuite_dataset,
)
from bpref_v2.data.replay_buffer import get_d4rl_dataset
from bpref_v2.envs import MetaWorld
from bpref_v2.utils.viskit.logging import setup_logger

from ..utils.jax_utils import batch_to_jax, next_rng
from ..utils.utils import (
    WandBLogger,
    define_flags_with_default,
    get_user_flags,
    prefix_metrics,
    save_pickle,
    set_random_seed,
)
from .algos import ARPV1Learner, REDSLearner
from .sampler import TrajSampler

# Flags read by ``main`` through ``absl.flags.FLAGS``, together with their default values.
FLAGS_DEF = define_flags_with_default(
    # Environment name; ``main`` picks its setup branch from substrings of this value
    # ("metaworld", "factorworld", "furniturebench", "ant") or from the ``robosuite`` flag.
    env="halfcheetah-medium-v2",
    # Learner selector; ``main`` only builds a learner for "ARP-V1" and "REDS".
    model_type="MLP",
    max_traj_length=1000,
    # ``seed`` seeds the run and names the output/checkpoint sub-directory (``s{seed}``).
    seed=42,
    data_seed=42,
    save_model=True,
    # Global batch size; ``main`` divides it by the number of JAX processes.
    batch_size=64,
    num_workers=16,
    # Early stopping on the validation loss (``min_delta`` / ``patience`` feed EarlyStopping).
    early_stop=False,
    min_delta=1e-4,
    patience=10,
    n_epochs=2000,
    # ``log_period`` is counted in steps; ``eval_period`` and ``save_period`` are counted in epochs.
    log_period=100,
    eval_period=10,
    save_period=100,
    # Required non-empty: appended to the logging group name and to the output directory.
    comment="",
    robosuite=False,
    dataset_path="",
    robosuite_dataset_type="ph",
    robosuite_max_episode_steps=500,
    # Dataset configuration blocks.
    factorworld=ARPFactorworldDataset.get_default_config(),
    furniturebench=ARPFuritureBenchDatasetInMemory.get_default_config(),
    # When true, the in-memory furniturebench datasets are used instead of the HDF5 ones.
    in_memory=False,
    # Learner configuration blocks (``arpv2`` configures the REDS learner).
    arpv1=ARPV1Learner.get_default_config(),
    arpv2=REDSLearner.get_default_config(),
    logging=WandBLogger.get_default_config(),
    num_ensembles=1,
    # Image augmentation mode; ``main`` handles "crop|jitter" and "none".
    augmentations="none",
    # Directory holding ``s{seed}/model.pkl``, the checkpoint the learner is initialised from.
    ckpt_path="",
)


def _make_loader(dataset, batch_size, num_workers, shuffle):
    """Build a DataLoader with the settings shared by every split in this script."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=True,
        prefetch_factor=2,
        num_workers=num_workers,
        persistent_workers=True,
    )


def _checkpoint_payload(reward_learner, epoch):
    """Gather the host-side dict (epoch, train state, config) that is pickled as a checkpoint."""
    return {
        "epoch": epoch,
        "state": jax.device_get(jax_utils.unreplicate(reward_learner._train_states)),
        "config": reward_learner.config.to_dict(),
    }


def main(_):
    """Continue training a reward learner that is initialised from a saved checkpoint.

    The function configures logging, builds the environment/datasets selected by
    ``FLAGS.env``, restores parameters from ``FLAGS.ckpt_path``, and then runs a
    train / validate / checkpoint loop over paired success and failure batches.

    Args:
        _: Positional argv remainder handed over by ``absl.app.run``; unused.

    Raises:
        ValueError: If ``FLAGS.comment`` is empty, if the selected environment did
            not yield all four datasets plus ``action_dim`` / ``image_size``, if
            the training set is smaller than one batch, or if ``FLAGS.model_type``
            / ``FLAGS.augmentations`` hold an unsupported value.
    """
    FLAGS = absl.flags.FLAGS
    # rich's Console.log does not apply printf-style formatting, so format explicitly.
    console.log(f"JAX process: {jax.process_index()} / {jax.process_count()}")
    # Pass the device list as its own object so rich renders it instead of parsing it as markup.
    console.log("JAX local devices:", jax.local_devices())

    jax_devices = jax.local_devices()
    n_devices = len(jax_devices)
    jax_process_index = jax.process_index()
    jax_process_count = jax.process_count()

    process_batch_size = FLAGS.batch_size // jax_process_count

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not FLAGS.comment:
        raise ValueError("You must leave your comment for logging experiment.")

    save_dir = f"{FLAGS.logging.output_dir}/{FLAGS.env}/{FLAGS.model_type}/{FLAGS.comment}/s{FLAGS.seed}"

    FLAGS.logging.group = f"{FLAGS.env}_{FLAGS.model_type}_{FLAGS.comment}"
    FLAGS.logging.experiment_id = FLAGS.logging.group + f"_s{FLAGS.seed}"
    FLAGS.logging.output_dir = save_dir
    FLAGS.logging.project = "REDS"

    variant = get_user_flags(FLAGS, FLAGS_DEF)
    setup_logger(variant=variant, seed=FLAGS.seed, base_log_dir=save_dir, include_exp_prefix_sub_dir=False)
    wb_logger = WandBLogger(FLAGS.logging, variant=variant, enable=(jax_process_index == 0))

    set_random_seed(FLAGS.seed)

    # Filled in by the environment branches below; validated once afterwards so a
    # missing piece produces a readable error instead of a NameError much later.
    train_dataset = neg_train_dataset = None
    val_dataset = neg_val_dataset = None
    action_dim = image_size = None

    if FLAGS.robosuite:
        import robomimic.utils.env_utils as EnvUtils
        from robosuite.wrappers import GymWrapper

        ds = qlearning_robosuite_dataset(
            os.path.join(FLAGS.dataset_path, FLAGS.env.lower(), FLAGS.robosuite_dataset_type, "low_dim.hdf5")
        )
        env = EnvUtils.create_env_from_metadata(env_meta=ds["env_meta"], render=False, render_offscreen=False).env
        gym_env = GymWrapper(env)
        gym_env._max_episode_steps = gym_env.horizon
        gym_env.seed(FLAGS.seed)
        gym_env.action_space.seed(FLAGS.seed)
        gym_env.observation_space.seed(FLAGS.seed)
        gym_env.ignore_done = False
    elif "metaworld" in FLAGS.env:
        task_name = "-".join(FLAGS.env.split("-")[1:])
        gym_env = MetaWorld(task_name, seed=FLAGS.seed)
        gym_env._env.seed(FLAGS.seed)
        gym_env.action_space.seed(FLAGS.seed)
        gym_env.observation_space.seed(FLAGS.seed)
    elif "factorworld" in FLAGS.env:
        task_name = "-".join(FLAGS.env.split("-")[1:])
        gym_env = MetaWorld(task_name, seed=FLAGS.seed)
        gym_env._env.seed(FLAGS.seed)
        gym_env.action_space.seed(FLAGS.seed)
        gym_env.observation_space.seed(FLAGS.seed)
        # NOTE(review): the dataset is constructed but never kept, so this branch
        # cannot reach the training loop as written — confirm the intended wiring.
        ARPFactorworldDataset(update=FLAGS.factorworld, split="train")
        action_dim = gym_env.action_space.shape[0]
        image_size = FLAGS.factorworld.image_size
    elif "furniturebench" in FLAGS.env:
        task_name = "-".join(FLAGS.env.split("-")[1:])
        train_data_dir = pathlib.Path(FLAGS.furniturebench.data_dir)
        filename = f"data_w{FLAGS.furniturebench.window_size}_s{FLAGS.furniturebench.skip_frame}.hdf5"
        train_h5_path = train_data_dir / "train" / filename
        train_h5_file_name = train_h5_path if train_h5_path.exists() else None
        start_offset_ratio = jax_process_index / jax_process_count

        # Built before branching: the validation configs below are derived from it
        # whichever training backend is selected.
        neg_data_config = copy.deepcopy(FLAGS.furniturebench)
        neg_data_config.num_demos = int(FLAGS.furniturebench.num_demos * 0.5)
        neg_data_config.use_liv, neg_data_config.use_nfp = False, False

        if not FLAGS.in_memory and train_h5_file_name:
            train_dataset = ARPFurnitureBenchDatasetHDF5(
                update=FLAGS.furniturebench,
                h5_file_name=train_h5_file_name,
                start_offset_ratio=start_offset_ratio,
                split="train",
            )
        else:
            train_dataset = ARPFuritureBenchDatasetInMemory(
                update=FLAGS.furniturebench,
                split="train",
                start_offset_ratio=start_offset_ratio,
                demo_type="success",
            )
            neg_train_dataset = ARPFuritureBenchDatasetInMemory(
                update=neg_data_config,
                split="train",
                start_offset_ratio=start_offset_ratio,
                demo_type="failure",
            )

        val_data_config = copy.deepcopy(FLAGS.furniturebench)
        neg_val_data_config = copy.deepcopy(neg_data_config)
        val_num_demos = 50 if not FLAGS.in_memory else 10
        val_data_config.num_demos = val_num_demos
        neg_val_data_config.num_demos = val_num_demos
        val_h5_path = train_data_dir / "val" / filename
        val_h5_file_name = val_h5_path if val_h5_path.exists() else None
        if not (train_data_dir / "val").exists():
            # No dedicated validation directory: validate on the training source.
            if FLAGS.in_memory:
                val_dataset = ARPFuritureBenchDatasetInMemory(
                    update=val_data_config,
                    h5_file_name=train_h5_file_name,
                    start_offset_ratio=start_offset_ratio,
                    split="train",
                    demo_type="success",
                )
                neg_val_dataset = ARPFuritureBenchDatasetInMemory(
                    update=neg_val_data_config,
                    h5_file_name=train_h5_file_name,
                    start_offset_ratio=start_offset_ratio,
                    split="train",
                    demo_type="failure",
                )
            else:
                val_dataset = ARPFurnitureBenchDatasetHDF5(
                    update=val_data_config,
                    h5_file_name=train_h5_file_name,
                    start_offset_ratio=start_offset_ratio,
                    split="val",
                )
        elif not FLAGS.in_memory and val_h5_file_name:
            val_dataset = ARPFurnitureBenchDatasetHDF5(
                update=val_data_config,
                h5_file_name=val_h5_file_name,
                start_offset_ratio=start_offset_ratio,
                split="val",
            )
        else:
            val_dataset = ARPFuritureBenchDatasetInMemory(
                update=val_data_config,
                split="val",
                start_offset_ratio=start_offset_ratio,
                demo_type="success",
            )
            # Failure demos use the failure-specific config, matching the branch above.
            neg_val_dataset = ARPFuritureBenchDatasetInMemory(
                update=neg_val_data_config,
                split="val",
                start_offset_ratio=start_offset_ratio,
                demo_type="failure",
            )
        action_dim = 8
        image_size = FLAGS.furniturebench.image_size
    elif "ant" in FLAGS.env:
        gym_env = gym.make(FLAGS.env)
        gym_env = wrappers.EpisodeMonitor(gym_env)
        gym_env = wrappers.SinglePrecision(gym_env)
        gym_env.seed(FLAGS.seed)
        gym_env.action_space.seed(FLAGS.seed)
        gym_env.observation_space.seed(FLAGS.seed)
        ds = qlearning_ant_dataset(gym_env)
    else:
        gym_env = gym.make(FLAGS.env)
        eval_sampler = TrajSampler(gym_env.unwrapped, FLAGS.max_traj_length)
        ds = get_d4rl_dataset(eval_sampler.env)

    # Kept from the original flow: seed with data_seed, then re-seed with the run seed.
    set_random_seed(FLAGS.data_seed)
    set_random_seed(FLAGS.seed)

    required = {
        "train_dataset": train_dataset,
        "neg_train_dataset": neg_train_dataset,
        "val_dataset": val_dataset,
        "neg_val_dataset": neg_val_dataset,
        "action_dim": action_dim,
        "image_size": image_size,
    }
    missing = [name for name, value in required.items() if value is None]
    if missing:
        raise ValueError(
            f"env={FLAGS.env!r} with in_memory={FLAGS.in_memory} did not provide: {', '.join(missing)}. "
            "Success and failure datasets for both splits are only built by the in-memory furniturebench branches."
        )

    data_size = len(train_dataset)
    step_per_epoch = data_size // FLAGS.batch_size
    if step_per_epoch < 1:
        raise ValueError(
            f"Training set has {data_size} samples, fewer than one batch of {FLAGS.batch_size}; "
            "lower batch_size or provide more data."
        )

    def load_model(epoch="last"):
        """Return the train state stored in a checkpoint under ``ckpt_path/s{seed}``.

        ``epoch=None`` selects ``best_model.pkl``, ``"last"`` selects ``model.pkl``,
        and any other value selects ``model_epoch{epoch}.pkl``.
        """
        ckpt_dir = pathlib.Path(FLAGS.ckpt_path) / f"s{FLAGS.seed}"
        if epoch is None:
            ckpt_file = ckpt_dir / "best_model.pkl"
        elif epoch == "last":
            ckpt_file = ckpt_dir / "model.pkl"
        else:
            ckpt_file = ckpt_dir / f"model_epoch{epoch}.pkl"
        # NOTE: unpickling can execute arbitrary code; only load trusted checkpoints.
        with ckpt_file.open("rb") as fin:
            checkpoint_data = pickle.load(fin)
        return checkpoint_data["state"]

    early_stop = EarlyStopping(min_delta=FLAGS.min_delta, patience=FLAGS.patience)
    if FLAGS.model_type == "ARP-V1":
        total_epochs = FLAGS.n_epochs
        config = transformers.GPT2Config(**FLAGS.arpv1)
        config.warmup_steps = int(total_epochs * 0.1 * step_per_epoch)
        config.total_steps = total_epochs * step_per_epoch
        image_dim = (image_size, image_size, 3)
        reward_learner = ARPV1Learner(config, image_dim, action_dim, FLAGS.num_ensembles)
    elif FLAGS.model_type == "REDS":
        total_epochs = FLAGS.n_epochs
        config = transformers.GPT2Config(**FLAGS.arpv2)
        config.warmup_steps = int(total_epochs * 0.1 * step_per_epoch)
        config.total_steps = total_epochs * step_per_epoch
        image_dim = (image_size, image_size, 3)
        reward_learner = REDSLearner(config, image_dim, action_dim, FLAGS.num_ensembles, jax_devices=jax_devices)
    else:
        raise ValueError(f"Unsupported model_type {FLAGS.model_type!r}; expected 'ARP-V1' or 'REDS'.")
    console.print(f"Load learned reward model from {FLAGS.ckpt_path}")
    state = load_model()
    params = state["trans"].params
    reward_learner.load_pretrained_model(params, jax_devices)

    train_loader = _make_loader(train_dataset, process_batch_size, FLAGS.num_workers, shuffle=True)
    neg_train_loader = _make_loader(neg_train_dataset, process_batch_size, FLAGS.num_workers, shuffle=True)

    val_batch_size = min(process_batch_size, len(val_dataset) // jax_process_count)
    val_loader = _make_loader(val_dataset, val_batch_size, FLAGS.num_workers, shuffle=False)
    neg_val_loader = _make_loader(neg_val_dataset, val_batch_size, FLAGS.num_workers, shuffle=False)

    sharded_rng = jax.device_put_sharded(next_rng(n_devices), jax_devices)
    aug_rng = jax.device_put_sharded(next_rng(n_devices), jax_devices)
    if FLAGS.augmentations == "crop|jitter":
        aug_fn = pmap_image_aug_fn(
            image_size=image_size,
            padding=int(image_size * (4 / 84)),
            window_size=FLAGS.furniturebench.window_size,
            jax_devices=jax_devices,
        )
    elif FLAGS.augmentations == "none":

        def aug_fn(x, y):
            """Identity augmentation: hand back the images and the rng untouched."""
            return x, y

    else:
        raise ValueError(f"Unsupported augmentations {FLAGS.augmentations!r}; expected 'crop|jitter' or 'none'.")

    def generate_batch(iterator, rng, split="train"):
        """Endlessly yield batches from ``iterator`` reshaped to a leading device axis.

        For ``split == "train"``, entries whose key contains ``"image"`` are run
        through ``aug_fn`` before being reshaped back to include the window axis.
        """

        def reshape_fn(x):
            return x.numpy().reshape(n_devices, -1, *x.shape[1:])

        while True:
            for batch in iterator:
                data = {}
                new_rng = rng
                for key in batch:
                    if split == "train" and "image" in key:
                        # Merge the two leading axes after the device split before augmenting.
                        images = jax.tree_util.tree_map(
                            lambda x: x.numpy().reshape(n_devices, -1, *x.shape[2:]), batch[key]
                        )
                        for _key in images:
                            # Every image entry of this batch is augmented with the same ``rng``.
                            _val, new_rng = aug_fn(images[_key], rng)
                            images[_key] = _val
                        data[key] = jax.tree_util.tree_map(
                            lambda x: x.reshape(n_devices, -1, FLAGS.furniturebench.window_size, *x.shape[2:]), images
                        )
                    else:
                        data[key] = jax.tree_util.tree_map(reshape_fn, batch[key])
                rng = new_rng
                yield data

    train_iter = prefetch_to_device(generate_batch(train_loader, aug_rng, split="train"), 2, jax_devices)
    neg_train_iter = prefetch_to_device(generate_batch(neg_train_loader, aug_rng, split="train"), 2, jax_devices)
    val_iter = prefetch_to_device(generate_batch(val_loader, aug_rng, split="val"), 2, jax_devices)
    neg_val_iter = prefetch_to_device(generate_batch(neg_val_loader, aug_rng, split="val"), 2, jax_devices)
    criteria_key = None

    start_step, total_steps = 0, step_per_epoch * FLAGS.n_epochs + 1
    step_counter = trange(start_step, total_steps, desc="Train...", ncols=0)
    step_per_log = FLAGS.log_period
    step_per_eval = step_per_epoch * FLAGS.eval_period
    step_per_save = step_per_epoch * FLAGS.save_period
    step_per_eval_epoch = int(len(val_dataset) / val_batch_size)

    # Defined up front so the final save below works even if the loop body never runs.
    epoch = 0
    metrics = defaultdict(list)
    is_main_process = jax_process_index == 0

    for step, batch, negative_batch in zip(step_counter, train_iter, neg_train_iter):
        if step % step_per_epoch == 0:
            # Start a fresh metric accumulator at every epoch boundary.
            metrics = defaultdict(list)
        epoch = step // step_per_epoch
        batch = batch_to_jax(batch)
        train_metrics, sharded_rng = reward_learner.train_arp_step(batch, sharded_rng, negative_batch=negative_batch)
        for key, val in prefix_metrics(train_metrics, "reward").items():
            metrics[key].append(val)

        # eval phase
        if step and step % step_per_eval == 0:
            for _, val_batch, neg_val_batch in zip(
                trange(step_per_eval_epoch, desc="val...", ncols=0), val_iter, neg_val_iter
            ):
                val_batch = batch_to_jax(val_batch)
                eval_metrics, sharded_rng = reward_learner.eval_arp_step(
                    val_batch, sharded_rng, negative_batch=neg_val_batch
                )
                for key, val in prefix_metrics(eval_metrics, "reward_eval").items():
                    metrics[key].append(val)

            criteria_key = "reward_eval/loss"
            criteria = np.mean(metrics[criteria_key])
            has_improved, early_stop = early_stop.update(criteria)
            if early_stop.should_stop and FLAGS.early_stop:
                log_metrics = {k: np.mean(v) for k, v in metrics.items()}
                log_metrics.update({"step": step, "epoch": epoch})
                console.print("\n" + pprint.pformat(log_metrics) + "\n")
                wb_logger.log(log_metrics)
                console.print("Met early stopping criteria, breaking...")
                break
            elif epoch > 0 and has_improved:
                metrics["best_epoch"] = epoch
                metrics[f"{criteria_key}_best"] = criteria
                # Only process 0 writes, consistent with the periodic and final saves.
                if is_main_process:
                    save_pickle(_checkpoint_payload(reward_learner, epoch), "best_model.pkl", save_dir)

        if step and step % step_per_log == 0:
            log_metrics = {k: np.mean(v) for k, v in metrics.items()}
            log_metrics.update({"step": step, "epoch": epoch})
            console.log("\n" + pprint.pformat(log_metrics) + "\n")
            wb_logger.log(log_metrics)

        if FLAGS.save_model and is_main_process and step and step % step_per_save == 0:
            save_pickle(_checkpoint_payload(reward_learner, epoch), f"model_epoch{epoch}.pkl", save_dir)

    if FLAGS.save_model and is_main_process:
        save_pickle(_checkpoint_payload(reward_learner, epoch), "model.pkl", save_dir)


# Script entry point: hand ``main`` to absl's application runner.
if __name__ == "__main__":
    absl.app.run(main)
