import copy
import os
import pickle
import pprint
from collections import defaultdict

import absl.app
import absl.flags
import jax
import numpy as np
import torch
import transformers
from flax import jax_utils
from flax.jax_utils import prefetch_to_device
from rich.console import Console
from tqdm import trange

# from jax import config

# config.update("jax_debug_nans", True)

# Shared rich console; main() writes its progress and metric logs through it.
console = Console()

from bpref_v2.data.arp_furniturebench_dataset_inmemory_stream import (
    ARPFurnitureBenchDataset,
    worker_init_fn,
)
from bpref_v2.data.arp_maniskill_dataset_inmemory_stream import ARPManiSkillDataset
from bpref_v2.data.arp_robot_dataset_inmemory_stream import ARPRobotDataset
from bpref_v2.data.augmentations import single_pmap_image_aug_fn, tube_pmap_image_aug_fn
from bpref_v2.utils.viskit.logging import setup_logger

from ..utils.jax_utils import batch_to_jax, next_rng
from ..utils.utils import (
    WandBLogger,
    define_flags_with_default,
    get_user_flags,
    prefix_metrics,
    save_pickle,
    set_random_seed,
)
from .algos import (
    ARPV1Learner,
    DiscriminatorLearner,
    DrSLearner,
    R2RRankLearner,
    REDSCNNLearner,
    REDSLearner,
    REDSNOEPICLearner,
    REDSNOSUPCONLearner,
    REDSNoTransLearner,
)

# Default value of every command-line flag understood by this script.  The
# entries are unpacked in the same order as they are listed here.
FLAGS_DEF = define_flags_with_default(
    **{
        # --- experiment identity / bookkeeping ---
        "env": "halfcheetah-medium-v2",
        "model_type": "MLP",
        "max_traj_length": 1000,
        "seed": 42,
        "data_seed": 42,
        "save_model": True,
        # --- batching / optimisation schedule ---
        "batch_size": 64,
        "num_workers": 16,
        "early_stop": True,
        "min_delta": 1e-4,
        "patience": 10,
        "train_steps": 5000,
        "eval_steps": 100,
        "log_period": 100,
        "eval_period": 1000,
        "save_period": 1000,
        # main() refuses to run while this is empty.
        "comment": "",
        # --- per-benchmark dataset configurations ---
        "furniturebench": ARPFurnitureBenchDataset.get_default_config(),
        "maniskill": ARPManiSkillDataset.get_default_config(),
        "robot": ARPRobotDataset.get_default_config(),
        # --- failure (negative) demonstrations ---
        "use_failure": True,
        "num_failure_demos": 10,
        # --- per-learner configurations, selected through ``model_type`` ---
        "arpv1": ARPV1Learner.get_default_config(),
        "reds": REDSLearner.get_default_config(),
        "redscnn": REDSCNNLearner.get_default_config(),
        "redsnotrans": REDSNoTransLearner.get_default_config(),
        "redsnoepic": REDSNOEPICLearner.get_default_config(),
        "redsnosupcon": REDSNOSUPCONLearner.get_default_config(),
        "disc": DiscriminatorLearner.get_default_config(),
        "r2r": R2RRankLearner.get_default_config(),
        "drs": DrSLearner.get_default_config(),
        # --- logging, checkpoint to resume from, image augmentation mode ---
        "logging": WandBLogger.get_default_config(),
        "checkpoint_path": "",
        "augmentations": "none",
    }
)


def _build_datasets(flags, process_index, process_count):
    """Create the training dataset(s) that match ``flags.env``.

    Args:
        flags: Parsed absl flags; ``env``, ``use_failure``, ``num_failure_demos``
            and the ``furniturebench`` / ``robot`` dataset configs are read.
        process_index: Index of this JAX process.
        process_count: Total number of JAX processes.

    Returns:
        A tuple ``(train_dataset, neg_train_dataset, action_dim, data_config)``.
        ``neg_train_dataset`` is ``None`` when ``flags.use_failure`` is false.

    Raises:
        ValueError: If ``flags.env`` names none of the supported benchmarks.
    """
    if "furniturebench" in flags.env:
        dataset_cls, data_config, action_dim = ARPFurnitureBenchDataset, flags.furniturebench, 8
    elif any(name in flags.env for name in ("metaworld", "rlbench")):
        dataset_cls, data_config, action_dim = ARPRobotDataset, flags.robot, 4
    else:
        raise ValueError(
            f"Unsupported env {flags.env!r}: expected a name containing "
            "'furniturebench', 'metaworld' or 'rlbench'."
        )

    # Every process starts reading at a different offset of the data stream.
    start_offset_ratio = process_index / process_count
    train_dataset = dataset_cls(
        update=data_config,
        split="train",
        start_offset_ratio=start_offset_ratio,
        demo_type="success",
    )

    neg_train_dataset = None
    if flags.use_failure:
        neg_data_config = copy.deepcopy(data_config)
        if dataset_cls is ARPFurnitureBenchDataset:
            # MH: Added for negative data finetune.
            neg_data_config.data_dir = data_config.finetune_data_dir
        else:
            neg_data_config.num_demos = flags.num_failure_demos
        neg_train_dataset = dataset_cls(
            update=neg_data_config,
            split="train",
            start_offset_ratio=start_offset_ratio,
            demo_type="failure",
        )

    return train_dataset, neg_train_dataset, action_dim, data_config


def _make_loader(dataset, batch_size, num_workers):
    """Wrap ``dataset`` in a ``DataLoader`` using this script's fixed loader settings."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        prefetch_factor=2,
        pin_memory=True,
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
    )


def _build_aug_fns(augmentations, image_size, target_config, jax_devices):
    """Return ``(single_aug_fn, tube_aug_fn)`` for the requested augmentation mode.

    Both callables take ``(images, rng)`` and return ``(images, rng)``.

    Raises:
        ValueError: If ``augmentations`` is neither ``"crop|jitter"`` nor ``"none"``.
    """
    if augmentations == "crop|jitter":
        aug_kwargs = dict(
            image_size=image_size,
            padding=int(image_size * (4 / 84)),
            window_size=target_config.window_size,
            jax_devices=jax_devices,
        )
        return single_pmap_image_aug_fn(**aug_kwargs), tube_pmap_image_aug_fn(**aug_kwargs)

    if augmentations == "none":

        def identity(images, rng):
            return images, rng

        return identity, identity

    raise ValueError(f"Unsupported augmentations {augmentations!r}: expected 'crop|jitter' or 'none'.")


def _generate_batch(iterator, rng, n_devices, target_config, single_aug_fn, tube_aug_fn, split="train"):
    """Yield device-sharded batches from ``iterator`` forever.

    Every leaf of a batch is converted to numpy and reshaped so that its leading
    axis has length ``n_devices``.  For ``split == "train"``, entries whose key
    contains ``"image"`` are additionally passed through the augmentation
    functions.

    Args:
        iterator: Re-iterable source of batches (restarted whenever exhausted).
        rng: RNG handed to the augmentation functions.
        n_devices: Number of local devices to shard over.
        target_config: Dataset config; ``window_size`` / ``pearson_size`` are read
            for 5-D image inputs.
        single_aug_fn: Augmentation applied to non-5-D images and to 5-D images
            whose key contains ``"pearson"``.
        tube_aug_fn: Augmentation applied to the remaining 5-D images.
        split: Augmentation is only applied when this is ``"train"``.

    Yields:
        dict with the same keys as the incoming batch.
    """

    def shard(x):
        return x.numpy().reshape(n_devices, -1, *x.shape[1:])

    def shard_frames(x):
        # Folds the second axis into the per-device axis
        # (presumably batch x window -> frames; TODO confirm the input layout).
        return x.numpy().reshape(n_devices, -1, *x.shape[2:])

    while True:
        for batch in iterator:
            data = {}
            # Falls back to the current rng when nothing is augmented; the original
            # code hit an UnboundLocalError for train batches without image keys.
            new_rng = rng
            for key in batch:
                if split != "train" or "image" not in key:
                    data[key] = jax.tree_util.tree_map(shard, batch[key])
                    continue

                is_pearson = "pearson" in key
                first_leaf = list(batch[key].values())[0]
                if first_leaf.ndim == 5:
                    aug_fn = single_aug_fn if is_pearson else tube_aug_fn
                    images = jax.tree_util.tree_map(shard_frames, batch[key])
                    # The same rng is used for every image stream of this batch.
                    for image_key in images:
                        images[image_key], new_rng = aug_fn(images[image_key], rng)
                    window_size = target_config.pearson_size if is_pearson else target_config.window_size
                    # Restore the window axis that shard_frames folded away.
                    data[key] = jax.tree_util.tree_map(
                        lambda x: x.reshape(n_devices, -1, window_size, *x.shape[2:]), images
                    )
                else:
                    images = jax.tree_util.tree_map(shard, batch[key])
                    for image_key in images:
                        images[image_key], new_rng = single_aug_fn(images[image_key], rng)
                    data[key] = images
            rng = new_rng
            yield data


def _build_reward_learner(flags, target_config, image_size, action_dim, num_steps, jax_devices):
    """Instantiate the reward learner selected by ``flags.model_type``.

    Raises:
        ValueError: If ``flags.model_type`` is not one of the supported names.
    """
    image_dim = (image_size, image_size, 3)
    model_type = flags.model_type

    if model_type == "ARP-V1":
        config = transformers.GPT2Config(**flags.arpv1)
        config.warmup_steps = int(0.1 * num_steps)
        config.total_steps = num_steps
        return ARPV1Learner(config, image_dim, action_dim)

    # model_type -> (name of the flag holding the config, learner class).
    gpt2_config_learners = {
        "REDS": ("reds", REDSLearner),
        "REDSCNN": ("redscnn", REDSCNNLearner),
        "REDSNoTrans": ("redsnotrans", REDSNoTransLearner),
        "REDSNOEPIC": ("redsnoepic", REDSNOEPICLearner),
        "REDSNOSUPCON": ("redsnosupcon", REDSNOSUPCONLearner),
    }
    plain_config_learners = {
        "DISC": ("disc", DiscriminatorLearner),
        "DRS": ("drs", DrSLearner),
        "R2R": ("r2r", R2RRankLearner),
    }

    if model_type in gpt2_config_learners:
        flag_name, learner_cls = gpt2_config_learners[model_type]
        config = transformers.GPT2Config(**getattr(flags, flag_name))
        config.warmup_steps = int(0.1 * num_steps)
        config.total_steps = num_steps
    elif model_type in plain_config_learners:
        flag_name, learner_cls = plain_config_learners[model_type]
        config = getattr(flags, flag_name).unlock()
    else:
        supported = ["ARP-V1", *gpt2_config_learners, *plain_config_learners]
        raise ValueError(f"Unsupported model_type {model_type!r}: expected one of {supported}.")

    config.image_keys = target_config.image_keys
    config.num_images = len(target_config.image_keys.split("|"))
    config.window_size = target_config.window_size
    task_name = flags.env.split("-", 1)[-1]
    return learner_cls(config, task_name, image_dim, action_dim, jax_devices=jax_devices)


def _save_checkpoint(reward_learner, step, filename, save_dir):
    """Pickle the learner's unreplicated train state, its config and ``step``."""
    save_data = {
        "step": step,
        "state": jax.device_get(jax_utils.unreplicate(reward_learner._train_states)),
        "config": reward_learner.config.to_dict(),
    }
    save_pickle(save_data, filename, save_dir)


def main(_):
    """Fine-tune a reward model starting from the checkpoint in ``--checkpoint_path``.

    Builds the dataset(s) and loaders for ``--env``, constructs the learner chosen
    by ``--model_type``, restores its state from the checkpoint (optimizer reset,
    scheduler disabled) and runs ``--train_steps`` additional training steps,
    logging every ``--log_period`` steps and saving every ``--save_period`` steps.

    Args:
        _: Positional argv forwarded by ``absl.app.run``; unused.

    Raises:
        ValueError: If ``--comment`` is empty, or ``--env``, ``--augmentations`` or
            ``--model_type`` is unsupported.
        FileNotFoundError: If ``--checkpoint_path`` does not exist.
    """
    FLAGS = absl.flags.FLAGS

    # Validate the cheap preconditions before any dataset or model is built.
    # Explicit raises are used because ``assert`` is stripped under ``python -O``.
    if not FLAGS.comment:
        raise ValueError("You must leave your comment for logging experiment.")
    if not os.path.exists(FLAGS.checkpoint_path):
        raise FileNotFoundError(f"Best model not found from {FLAGS.checkpoint_path}.")

    #######################################################################################
    ################################ DEFINE HYPERPARAMETERS ###############################
    #######################################################################################
    jax_devices = jax.local_devices()
    n_devices = len(jax_devices)
    jax_process_index = jax.process_index()
    jax_process_count = jax.process_count()

    # rich's Console.log does not interpolate printf-style placeholders, so the
    # message is formatted here and the device list is passed as its own object.
    console.log(f"JAX process: {jax_process_index} / {jax_process_count}")
    console.log("JAX local devices:", jax_devices)

    process_batch_size = FLAGS.batch_size // jax_process_count

    save_dir = f"{FLAGS.logging.output_dir}/{FLAGS.model_type}/{FLAGS.env}/{FLAGS.comment}/s{FLAGS.seed}"
    FLAGS.logging.group = f"{FLAGS.model_type}_{FLAGS.comment}"
    FLAGS.logging.experiment_id = f"{FLAGS.logging.group}_s{FLAGS.seed}"
    FLAGS.logging.output_dir = save_dir
    FLAGS.logging.project = "CoRL-REDS"

    variant = get_user_flags(FLAGS, FLAGS_DEF)
    setup_logger(variant=variant, seed=FLAGS.seed, base_log_dir=save_dir, include_exp_prefix_sub_dir=False)
    wb_logger = WandBLogger(FLAGS.logging, variant=variant, enable=(jax_process_index == 0))

    set_random_seed(FLAGS.seed)

    step_per_log = FLAGS.log_period
    step_per_save = FLAGS.save_period

    # use fixed seed for collecting segments.
    # NOTE(review): nothing is sampled between these two calls, so the data seed
    # is overridden immediately; kept as-is to preserve the seeding sequence.
    set_random_seed(FLAGS.data_seed)
    set_random_seed(FLAGS.seed)

    #######################################################################################
    ################################ DataLoader Setup #####################################
    #######################################################################################
    # Only train splits are built; no validation dataset or loader is created.
    train_dataset, neg_train_dataset, action_dim, target_config = _build_datasets(
        FLAGS, jax_process_index, jax_process_count
    )
    image_size = target_config.image_size

    train_loader = _make_loader(train_dataset, process_batch_size, target_config.num_workers)
    neg_train_loader = None
    if FLAGS.use_failure:
        neg_train_loader = _make_loader(neg_train_dataset, process_batch_size, target_config.num_workers)

    sharded_rng = jax.device_put_sharded(next_rng(n_devices), jax_devices)
    aug_rng = jax.device_put_sharded(next_rng(n_devices), jax_devices)
    single_aug_fn, tube_aug_fn = _build_aug_fns(FLAGS.augmentations, image_size, target_config, jax_devices)

    def generate_batch(iterator, rng, split="train"):
        return _generate_batch(iterator, rng, n_devices, target_config, single_aug_fn, tube_aug_fn, split=split)

    #######################################################################################
    ################################ DEFINE REWARD MODEL ##################################
    #######################################################################################
    num_steps = FLAGS.train_steps
    reward_learner = _build_reward_learner(FLAGS, target_config, image_size, action_dim, num_steps, jax_devices)

    #######################################################################################
    ################################ GLOBAL TRAINING PHASE ################################
    #######################################################################################
    # SECURITY: pickle.load executes arbitrary code embedded in the file; only
    # point --checkpoint_path at checkpoints from a trusted source.
    with open(FLAGS.checkpoint_path, "rb") as fin:
        checkpoint_data = pickle.load(fin)
    reward_learner.load_state(
        checkpoint_data["state"], jax_devices=jax_devices, reset_optimizer=True, use_scheduler=False
    )

    # Continue the step count from where the checkpoint left off.
    start_step = checkpoint_data["step"]
    total_steps = start_step + num_steps + 1
    step_counter = trange(start_step, total_steps, desc="Reward Learning Step...", ncols=0)

    train_loader.dataset.set_mode("global")
    train_iter = prefetch_to_device(generate_batch(train_loader, aug_rng, split="train"), 2, jax_devices)
    neg_train_iter = None
    if FLAGS.use_failure:
        neg_train_loader.dataset.set_mode("global")
        neg_train_iter = prefetch_to_device(generate_batch(neg_train_loader, aug_rng, split="train"), 2, jax_devices)

    is_main_process = jax_process_index == 0
    metrics = defaultdict(list)
    for step in step_counter:
        batch = next(train_iter)
        neg_batch = next(neg_train_iter) if FLAGS.use_failure else None
        batch = batch_to_jax(batch)
        train_metrics, sharded_rng = reward_learner.train_step(batch, sharded_rng, neg_batch=neg_batch)
        for key, val in prefix_metrics(train_metrics, "reward").items():
            metrics[key].append(val)

        if step and step % step_per_log == 0:
            log_metrics = {k: np.mean(v) for k, v in metrics.items()}
            log_metrics.update({"step": step})
            console.log("\n" + pprint.pformat(log_metrics) + "\n")
            wb_logger.log(log_metrics, step=step)
            # Start a fresh window only after logging.  The original reset the
            # buffer before appending on log steps, so each logged "mean" covered
            # a single step and the rest of the period was thrown away.
            metrics = defaultdict(list)

        if FLAGS.save_model and step and step % step_per_save == 0 and is_main_process:
            _save_checkpoint(reward_learner, step, f"model_step{step}.pkl", save_dir)

    if FLAGS.save_model and is_main_process:
        _save_checkpoint(reward_learner, step, "last_model.pkl", save_dir)


# Script entry point: hand control to absl, which calls main() once the
# command-line flags have been parsed.
if __name__ == "__main__":
    absl.app.run(main)
