import os
import pathlib
import pickle

import jax
import transformers
from absl import app, flags

from bpref_v2.data.qlearning_dataset import qlearning_factorworld_dataset
from bpref_v2.envs import MetaWorld
from bpref_v2.reward_learning.algos import PTLearner, VMRLearner, VPTLearner
from bpref_v2.utils.dataset_utils import (
    RelabeledDataset,
    reward_from_preference,
    reward_from_preference_transformer,
)
from bpref_v2.utils.utils import set_random_seed

# Shared absl flag registry; main() reads every option below through it.
FLAGS = flags.FLAGS

# Task whose checkpoint is loaded and whose dataset variations are relabelled.
flags.DEFINE_string(
    name="task_name",
    default="pick-place-v2",
    help="task_name",
)
# Reward-model family; main() handles "PT", "VPT" and "VMR".
flags.DEFINE_string(
    name="model_name",
    default="VPT",
    help="model_name",
)
# Suffix appended to the task name in the checkpoint directory and used in the output file name.
flags.DEFINE_string(
    name="comment",
    default="light",
    help="comment",
)
# Which episode split to relabel: main() distinguishes "success" and "failure".
flags.DEFINE_string(
    name="tp",
    default="success",
    help="type of trajectories",
)
# Forwarded to reward_from_preference_transformer for the transformer-based models.
flags.DEFINE_integer(
    name="seq_len",
    default=30,
    help="sequence length used for training.",
)
flags.DEFINE_integer(
    name="skip_frame",
    default=1,
    help="skip frame.",
)
# Seeds the RNGs and selects the "s{seed}" checkpoint sub-directory.
flags.DEFINE_integer(
    name="seed",
    default=0,
    help="seed.",
)
# Camera whose images are requested from the dataset loader.
flags.DEFINE_string(
    name="camera_key",
    default="corner2",
    help="image key",
)


# Model names this script knows how to instantiate.
_TRANSFORMER_MODELS = ("PT", "VPT")
_VISUAL_MODELS = ("VPT", "VMR")
_SUPPORTED_MODELS = ("PT", "VPT", "VMR")

# image_dim handed to the image-based learners.
_IMAGE_DIM = (224, 224, 3)

# Sub-directory of each dataset variation that holds the requested trajectory type.
_SPLIT_DIR_BY_TYPE = {
    "success": "train",
    "failure": "failure_episodes",
}

# (task directory, variation directory) pairs to relabel for each supported task.
_SAMPLES_BY_TASK = {
    "pick-place-v2": [
        ("pick-place-v2", "light-object_pos-goal_pos-table_pos-floor_texture-table_texture"),
        (
            "pick-place-v2",
            "light-object_pos-goal_pos-object_size-object_texture-table_pos-floor_texture-table_texture",
        ),
        ("pick-place-v2", "light"),
        ("pick-place-v2", "object_texture"),
    ],
    "door-open-v2": [("door-open-v2", "light-goal_pos-table_pos-floor_texture-table_texture")],
}


def _load_checkpoint(model_dir):
    """Return the unpickled checkpoint in ``model_dir``.

    ``best_model.pkl`` is preferred; ``model.pkl`` is used when it does not exist.
    """
    checkpoint_path = model_dir / "best_model.pkl"
    if not checkpoint_path.exists():
        checkpoint_path = model_dir / "model.pkl"

    # NOTE: pickle can execute arbitrary code while loading; only use trusted checkpoints.
    with checkpoint_path.open("rb") as fin:
        return pickle.load(fin)


def _build_reward_learner(model_name, checkpoint_data, observation_dim, action_dim, jax_devices):
    """Instantiate the (not yet loaded) reward learner for ``model_name``.

    Args:
        model_name: One of "PT", "VPT" or "VMR".
        checkpoint_data: Unpickled checkpoint; its optional "config" entry is used
            when present, otherwise the learner's default config is used.
        observation_dim: Observation size, used by the state-based "PT" learner.
        action_dim: Action size.
        jax_devices: Devices passed to the "PT" learner.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    if model_name not in _SUPPORTED_MODELS:
        raise ValueError(f"Unsupported model_name {model_name!r}; expected one of {_SUPPORTED_MODELS}.")

    saved_config = checkpoint_data.get("config")

    if model_name == "VMR":
        config = VMRLearner.get_default_config(saved_config)
        return VMRLearner(config, _IMAGE_DIM, action_dim)

    # Transformer-based models (PT / VPT).
    if saved_config is not None:
        transformer = saved_config
    elif model_name == "PT":
        transformer = PTLearner.get_default_config()
    else:
        transformer = VPTLearner.get_default_config()
    # These architecture settings are applied on top of whichever config was selected.
    transformer.embd_dim = 128
    transformer.n_layer = 1
    transformer.n_head = 4

    config = transformers.GPT2Config(**transformer)
    config.warmup_steps = 10
    config.total_steps = 1000

    if model_name == "PT":
        return PTLearner(config, observation_dim, action_dim, jax_devices)
    return VPTLearner(config, _IMAGE_DIM, action_dim)


def main(_):
    """Relabel FactorWorld datasets with rewards predicted by a trained reward model.

    Loads the checkpoint for ``--task_name`` / ``--model_name`` / ``--comment`` /
    ``--seed``, then, for every dataset variation registered for the task, computes
    rewards with the model and pickles them next to the dataset as
    ``{model_name}-{comment}_reward_{tp}.pkl``.

    Raises:
        ValueError: If ``--model_name``, ``--task_name`` or ``--tp`` is not supported.
    """
    base_path = pathlib.Path("/home/pref_data/reward_learning").expanduser()

    task_name = FLAGS.task_name
    model_name = FLAGS.model_name
    comment = task_name if FLAGS.comment == "" else f"{task_name}-{FLAGS.comment}"
    seed = FLAGS.seed
    tp = FLAGS.tp

    # Fail fast on unsupported options, before any checkpoint or environment is loaded.
    if model_name not in _SUPPORTED_MODELS:
        raise ValueError(f"Unsupported model_name {model_name!r}; expected one of {_SUPPORTED_MODELS}.")
    if task_name not in _SAMPLES_BY_TASK:
        raise ValueError(f"Unsupported task_name {task_name!r}; expected one of {sorted(_SAMPLES_BY_TASK)}.")
    if tp not in _SPLIT_DIR_BY_TYPE:
        raise ValueError(f"Unsupported tp {tp!r}; expected one of {sorted(_SPLIT_DIR_BY_TYPE)}.")

    # Ask XLA not to preallocate device memory up front (see JAX GPU memory docs).
    os.environ["XLA_PYTHON_CLIENT_MEM_PREALLOCATE"] = "false"

    set_random_seed(seed)

    model_dir = base_path / f"factorworld-{task_name}" / f"{model_name}" / comment / f"s{seed}"
    checkpoint_data = _load_checkpoint(model_dir)
    state = checkpoint_data["state"]

    # The environment is only needed for its observation/action dimensions.
    env = MetaWorld(task_name)
    observation_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    jax_devices = jax.local_devices()

    reward_learner = _build_reward_learner(model_name, checkpoint_data, observation_dim, action_dim, jax_devices)
    reward_learner.load(state)

    ds_base_path = pathlib.Path("/home/pref_data").expanduser()
    split_dir = _SPLIT_DIR_BY_TYPE[tp]

    for ds_task_name, ds_variation_name in _SAMPLES_BY_TASK[task_name]:
        variation_dir = ds_base_path / ds_task_name / ds_variation_name
        dataset_path = variation_dir / split_dir
        print(f"dataset_path: {dataset_path}")

        ds = qlearning_factorworld_dataset(dataset_path, camera_keys=[FLAGS.camera_key])
        if len(ds[FLAGS.camera_key]) == 0:
            # An empty variation must not stop the remaining variations from being relabelled.
            print(f"no '{FLAGS.camera_key}' frames found in {dataset_path}; skipping")
            continue

        # Only the image-based models receive the camera frames.
        extra_fields = {"images": ds[FLAGS.camera_key]} if model_name in _VISUAL_MODELS else {}
        dataset = RelabeledDataset(
            ds["observations"],
            ds["actions"],
            ds["rewards"],
            ds["terminals"],
            ds["next_observations"],
            **extra_fields,
        )

        if model_name in _TRANSFORMER_MODELS:
            batch_size = 32 if model_name == "VPT" else 1024
            rds = reward_from_preference_transformer(
                env_name="",
                dataset=dataset,
                reward_model=reward_learner,
                seq_len=FLAGS.seq_len,
                batch_size=batch_size,
                skip_frame=FLAGS.skip_frame,
            )
        else:
            rds = reward_from_preference(
                env_name="",
                dataset=dataset,
                reward_model=reward_learner,
                batch_size=256,
            )

        output_path = variation_dir / f"{model_name}-{FLAGS.comment}_reward_{tp}.pkl"
        with output_path.open("wb") as f:
            pickle.dump(rds.rewards, f)


# Script entry point: absl's app.run invokes main(); flags are read through FLAGS there.
if __name__ == "__main__":
    app.run(main)
