import os

# sys.path.append("/home/factor-world")
import pathlib
import pickle
from typing import Optional

import cv2
import jax
import jax.numpy as jnp
import numpy as np
import torch
import torchvision.transforms as T
import transformers
from absl import app, flags
from metaworld_generalization.wrappers import make_wrapped_env
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from vip import load_vip

from bpref_v2.data.qlearning_dataset import qlearning_factorworld_dataset
from bpref_v2.reward_learning.algos import PTLearner, VPTLearner
from bpref_v2.utils.dataset_utils import RelabeledDataset
from bpref_v2.utils.jax_utils import batch_to_jax
from bpref_v2.utils.utils import set_random_seed

FLAGS = flags.FLAGS

# Command-line options (absl flags) read by `main` and by the reward helpers below.
# Task identifier; used for the dataset directory, the checkpoint directory and the environment.
flags.DEFINE_string("task_name", "pick-place-v2", "task name")
# Which reward source to canonicalize; `main` dispatches on "env", "PT", "VPT", "r3m", "vip", "liv".
flags.DEFINE_string("model_name", "VPT", "model_name")
# Suffix appended to the task name in the checkpoint path, and embedded in the output filename.
flags.DEFINE_string("comment", "light", "comment")
# Dataset sub-directory; also split on "-" to obtain the factor names passed to the environment.
flags.DEFINE_string("variations", "light", "variations")
# "success" reads the `train` split, "failure" reads `failure_episodes`.
flags.DEFINE_string("tp", "success", "type of trajectories")
# NOTE(review): `seq_len` is defined here but never read in this file; the helpers use their own default of 30.
flags.DEFINE_integer("seq_len", 30, "sequence length used for training.")
flags.DEFINE_integer("skip_frame", 1, "skip frame.")
# The first `n_trajs` trajectories are canonicalized; the following `n_mean_trajs` supply the mean samples.
flags.DEFINE_integer("n_trajs", 15, "number of trajectories.")
flags.DEFINE_integer("n_mean_trajs", 15, "number of sampled trajectories.")
# Key of the camera stream requested from the dataset and the environment.
flags.DEFINE_string("camera_key", "corner2", "image key")


def _tile_first_dim(xs: np.ndarray, reps: int) -> np.ndarray:
    """Stack `reps` copies of `xs` along its leading axis, leaving the other axes untouched."""
    # One repetition count per axis: `reps` for axis 0, 1 for every trailing axis.
    per_axis_reps = (reps, *([1] * (xs.ndim - 1)))
    tiled = np.tile(xs, per_axis_reps)
    return tiled


def get_traj_info(dataset):
    """Locate trajectory boundaries in a flat dataset.

    Args:
        dataset: Object exposing `dones_float`, a per-transition array whose
            non-zero entries mark the last transition of a trajectory.

    Returns:
        A pair `(traj_indices, trj_mapper)`:
        * `traj_indices` is a list starting with 0 followed by the index just
          past every non-zero entry of `dones_float`.
        * `trj_mapper` is an array giving, for every transition up to the last
          boundary, the index of the trajectory it belongs to.
    """
    boundaries = [0, *(np.nonzero(dataset.dones_float)[0] + 1)]

    owners = []
    for episode, (begin, finish) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        owners += [episode] * (finish - begin)

    return boundaries, np.asarray(owners)


def reward_from_pt(model, batch, seq_len=30, label_mode="last", batch_size=16, use_image=False, skip_frame=1, **kwargs):
    """Relabel every transition of `batch` with a reward predicted by a (V)PT reward model.

    The flat batch is split into trajectories at `done == 1.0`. Each trajectory
    is left-padded with `seq_len - 1` zero steps (attention mask 0) so that the
    window of length `seq_len` ending at any step exists. For every transition
    that window is sub-sampled with stride `skip_frame` and fed to
    `model.get_reward`.

    Args:
        model: Object exposing `get_reward(jax_batch)`; the first element of its
            return value is used as the per-step reward of every window.
        batch: Dict with arrays under "obs", "action", "done" and, when
            `use_image` is set, "images"; all indexed by transition.
        seq_len: Length of the (un-strided) window ending at each transition.
        label_mode: "last" keeps the reward of the final window step; "mean"
            divides the summed rewards by the number of unmasked steps.
        batch_size: Number of windows evaluated per model call.
        use_image: Whether to pass image windows to the model as well.
        skip_frame: Stride applied inside each window.
        **kwargs: Ignored; accepted so callers can forward a shared kwarg set.

    Returns:
        1-D numpy array with one reward per transition of `batch`.
    """
    # ---- Split the flat batch into per-trajectory lists of (obs, action, image). ----
    n_total = len(batch["obs"])
    trajs = [[]]
    for i in tqdm(range(n_total), desc="split", leave=False):
        img = batch["images"][i] if use_image else None
        trajs[-1].append([batch["obs"][i], batch["action"][i], img])
        # Open a new trajectory after a terminal step, unless it was the very last sample.
        if batch["done"][i] == 1.0 and i + 1 < n_total:
            trajs.append([])

    trajectories = []
    trj_mapper = []  # flat transition index -> (trajectory index, step within trajectory)
    observation_dim = batch["obs"].shape[-1]
    action_dim = batch["action"].shape[-1]

    if use_image:
        image_dim = batch["images"].shape[1:]

    # ---- Left-pad every trajectory so a full window ends at each real step. ----
    for trj_idx, traj in tqdm(enumerate(trajs), total=len(trajs), desc="chunk trajectories", leave=False):
        _obs, _act, _attn_mask = [], [], []
        if use_image:
            _img = []

        for _ in range(seq_len - 1):
            _obs.append(np.zeros(observation_dim))
            _act.append(np.zeros(action_dim))
            _attn_mask.append(0.0)
            if use_image:
                _img.append(np.zeros(image_dim, dtype=np.uint8))

        for _o, _a, _im in traj:
            _obs.append(_o)
            _act.append(_a)
            _attn_mask.append(1.0)
            if use_image:
                _img.append(_im)

        traj_len = len(traj)
        if use_image:
            _obs, _act, _attn_mask, _img = np.asarray(_obs), np.asarray(_act), np.asarray(_attn_mask), np.asarray(_img)
            trajectories.append((_obs, _act, _attn_mask, _img))
        else:
            _obs, _act, _attn_mask = np.asarray(_obs), np.asarray(_act), np.asarray(_attn_mask)
            trajectories.append((_obs, _act, _attn_mask))

        for seg_idx in range(traj_len):
            trj_mapper.append((trj_idx, seg_idx))

    # ---- Evaluate the model on mini-batches of windows. ----
    data_size = batch["obs"].shape[0]
    interval = (data_size + batch_size - 1) // batch_size  # ceil(data_size / batch_size)
    new_r = np.zeros(data_size)
    for i in trange(interval, desc="compute reward from (V)PT", leave=False):
        start_pt = i * batch_size
        end_pt = min((i + 1) * batch_size, data_size)

        _input_obs, _input_act, _input_timestep, _input_attn_mask = [], [], [], []
        if use_image:
            _input_img = []
        for pt in range(start_pt, end_pt):
            _trj_idx, _seg_idx = trj_mapper[pt]
            # In padded coordinates the window [_seg_idx, _seg_idx + seq_len) ends at the current step.
            __input_obs = trajectories[_trj_idx][0][_seg_idx : _seg_idx + seq_len, :][::skip_frame]
            __input_act = trajectories[_trj_idx][1][_seg_idx : _seg_idx + seq_len, :][::skip_frame]
            if use_image:
                __input_img = trajectories[_trj_idx][-1][_seg_idx : _seg_idx + seq_len, ...][::skip_frame]

            # Timestep ids for the window, with zeros on the padded part.
            # NOTE(review): 500 is presumably the maximum episode length - confirm against the training setup.
            if _seg_idx < seq_len:
                __input_timestep = np.concatenate(
                    [np.zeros(seq_len - _seg_idx, dtype=np.int32), np.arange(_seg_idx, dtype=np.int32)], axis=0
                )
            elif _seg_idx <= 500:
                __input_timestep = np.arange(_seg_idx - seq_len, _seg_idx, dtype=np.int32)
            elif 0 < _seg_idx - 500 < seq_len:
                __input_timestep = np.concatenate(
                    [np.arange(seq_len - _seg_idx + 500, dtype=np.int32), np.zeros(_seg_idx - 500, dtype=np.int32)],
                    axis=0,
                )
            else:
                __input_timestep = np.zeros(seq_len, dtype=np.int32)
            __input_timestep = __input_timestep[::skip_frame]
            __input_attn_mask = trajectories[_trj_idx][2][_seg_idx : _seg_idx + seq_len, ...][::skip_frame]

            _input_obs.append(__input_obs)
            _input_act.append(__input_act)
            _input_timestep.append(__input_timestep)
            _input_attn_mask.append(__input_attn_mask)
            if use_image:
                _input_img.append(__input_img)

        _input_obs, _input_act, _input_timestep, _input_attn_mask = (
            np.asarray(x) for x in (_input_obs, _input_act, _input_timestep, _input_attn_mask)
        )
        if use_image:
            _input_img = np.asarray(_input_img)

        model_input = dict(
            observations=_input_obs,
            actions=_input_act,
            timestep=_input_timestep,
            attn_mask=_input_attn_mask,
        )
        if use_image:
            model_input.update(images=_input_img)

        jax_input = batch_to_jax(model_input)
        new_reward, _ = model.get_reward(jax_input)
        # A strided window holds ceil(seq_len / skip_frame) steps, so let reshape infer the
        # step axis instead of assuming seq_len // skip_frame (wrong when not divisible).
        new_reward = new_reward.reshape(end_pt - start_pt, -1)

        # NOT USE
        if label_mode == "mean":
            new_reward = jnp.sum(new_reward, axis=1) / jnp.sum(_input_attn_mask, axis=1)
            new_reward = new_reward.reshape(-1, 1)
        elif label_mode == "last":
            new_reward = new_reward[:, -1].reshape(-1, 1)

        new_reward = np.asarray(new_reward)
        new_r[start_pt:end_pt, ...] = new_reward.squeeze(-1)

    return new_r


def load_embedding(rep="vip"):
    """Load a pretrained visual representation and its matching image transform.

    Args:
        rep: Name of the representation: "vip", "r3m" or "liv".

    Returns:
        A pair `(model, transform)`; the model has been switched to eval mode.

    Raises:
        ValueError: If `rep` is not one of the supported names.
    """
    supported = ("vip", "r3m", "liv")
    # Fail early with a clear message instead of hitting an unbound `model` below.
    if rep not in supported:
        raise ValueError(f"Unknown representation {rep!r}; expected one of {supported}")

    if rep == "vip":
        model = load_vip()
        transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
    elif rep == "r3m":
        # Imported lazily so the package is only required when this representation is used.
        from r3m import load_r3m

        model = load_r3m("resnet50")
        transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
    else:  # rep == "liv"
        from liv import load_liv

        model = load_liv()
        transform = T.Compose([T.ToTensor()])
    model.eval()
    return model, transform


def reward_from_gcr(model_group, batch, model_type="vip", use_normalize=False, batch_size=256, **kwargs):
    """Compute goal-conditioned rewards as negative embedding distance to a goal image.

    Args:
        model_group: Tuple `(model, transform, device)`; `model` maps a batch of
            image tensors to embeddings, `transform` turns a PIL image into a tensor.
        batch: Dict holding "images" and "goal_images", aligned per transition.
            The arrays are not modified.
        model_type: "r3m" / "vip" frames are converted RGB->BGR and scaled by 255
            before being embedded; any other value skips both steps.
        use_normalize: If set, the distances of each chunk of `batch_size` frames
            are divided by the first distance of that chunk and the final reward
            is shifted by +1.
            NOTE(review): normalization is per chunk, not per trajectory - confirm
            this is intended.
        batch_size: Number of frames embedded per forward pass.
        **kwargs: Ignored; accepted so callers can forward a shared kwarg set.

    Returns:
        1-D numpy array with one reward per frame of `batch["images"]`.
    """
    model, transform, device = model_group
    imgs, goal_imgs = batch["images"], batch["goal_images"]
    bgr_and_scale = model_type in ["r3m", "vip"]

    def _to_tensor(frame):
        # Convert a copy per frame instead of writing back into the caller's array,
        # which would corrupt the batch for anything that reads it afterwards.
        if bgr_and_scale:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        return transform(Image.fromarray(frame.astype(np.uint8)))

    chunk_rewards = []
    for idx in trange(0, len(imgs), batch_size, desc="compute reward from gcr", leave=False):
        start, end = idx, min(idx + batch_size, len(imgs))
        imgs_cur = torch.stack([_to_tensor(imgs[i]) for i in range(start, end)]).to(device)
        goal_imgs_cur = torch.stack([_to_tensor(goal_imgs[i]) for i in range(start, end)]).to(device)
        if bgr_and_scale:
            # Both inputs must share the same value range for the distance to be meaningful.
            imgs_cur = imgs_cur * 255
            goal_imgs_cur = goal_imgs_cur * 255

        with torch.no_grad():
            embeddings = model(imgs_cur).cpu().numpy()
            goal_embeddings = model(goal_imgs_cur).cpu().numpy()

        # Per-frame Euclidean distance between goal and current embedding.
        diff = (goal_embeddings - embeddings).reshape(embeddings.shape[0], -1)
        distances = np.linalg.norm(diff, axis=1)

        if use_normalize:
            distances = distances / distances[0]
        chunk_rewards.append(-1 * distances)

    total_distances = np.concatenate(chunk_rewards) if chunk_rewards else np.zeros(0)
    if use_normalize:
        total_distances = total_distances + 1

    return total_distances


def evaluate_model(model, batch, env_flag=True, **kwargs):
    """Evaluate rewards for every transition of `batch` with the selected reward source.

    Args:
        model: With `env_flag`, an environment exposing `reset`,
            `set_factor_values` and `evaluate_state`. Otherwise either the
            `(model, transform, device)` tuple expected by `reward_from_gcr`
            (when `FLAGS.model_name` is "vip", "liv" or "r3m") or a learned
            reward model for `reward_from_pt`.
        batch: Dict of per-transition arrays; the ground-truth path reads
            "obs", "action" and "traj_idx".
        env_flag: Use the environment's ground-truth reward when True.
        **kwargs: Forwarded to the reward helper. The ground-truth path
            requires `data_factor_values`, indexed by trajectory index.

    Returns:
        Array with one reward per transition.
    """
    if env_flag:
        assert batch["obs"].shape[0] == batch["action"].shape[0]
        rewards = []
        # Start with no active trajectory so the environment is always reset and configured
        # for the first sample; starting at 0 skipped that step for trajectory 0 and reused
        # whatever state a previous call had left behind.
        traj_cursor = None
        for i in trange(batch["obs"].shape[0], desc="GT Reward", leave=False):
            traj_idx = batch["traj_idx"][i]
            if traj_cursor is None or traj_cursor != traj_idx:
                model.reset()
                model.set_factor_values(kwargs["data_factor_values"][traj_idx])
                traj_cursor = traj_idx
            rewards.append(model.evaluate_state(batch["obs"][i], batch["action"][i])[0])
        return np.asarray(rewards)

    if FLAGS.model_name in ["vip", "liv", "r3m"]:
        return reward_from_gcr(model, batch, model_type=FLAGS.model_name, **kwargs)

    return reward_from_pt(model, batch, **kwargs)


def sample_mean_rew(
    model,
    mean_from_obs,  ## (n_mean_samples x S', n_samples x s, n_samples x s')
    act_samples,  ## n_mean_samples x (A)
    next_obs_samples,  ## n_mean_samples x (S')
    done_samples,
    traj_samples,
    env_flag,
    batch_size: int = 2**26,
    reward_batch_size: int = 16,
    use_image=False,
    seq_len=30,
    skip_frame=1,
    mean_from_images=None,
    data_factor_values=None,
    goal_images=None,
):
    """Estimate, for every observation, the mean reward over sampled (action, next obs) pairs.

    Each observation in `mean_from_obs` is paired with every entry of
    `act_samples` / `next_obs_samples`, the rewards are computed with
    `evaluate_model`, and they are averaged per observation.

    Args:
        model: Reward source forwarded to `evaluate_model`.
        mean_from_obs: Observations to compute the mean reward from.
        act_samples: Sampled actions.
        next_obs_samples: Sampled next observations, one per action sample.
        done_samples: Per-observation done flags, aligned with `mean_from_obs`.
        traj_samples: Per-observation trajectory indices, aligned with `mean_from_obs`.
        env_flag: Forwarded to `evaluate_model`.
        batch_size: Byte budget used to derive how many observations are
            processed per chunk (ignored when `use_image` is set).
        reward_batch_size: Mini-batch size forwarded to the reward helper.
        use_image: Whether image observations are added to the batch.
        seq_len: Forwarded to the reward helper.
        skip_frame: Forwarded to the reward helper.
        mean_from_images: Images aligned with `mean_from_obs`; read when `use_image` is set.
        data_factor_values: Forwarded to `evaluate_model`.
        goal_images: Goal images aligned with `mean_from_obs`; required when
            `FLAGS.model_name` is "vip", "liv" or "r3m".

    Returns:
        1-D array of length `len(mean_from_obs)` with the mean reward per observation.

    Raises:
        ValueError: If a goal-conditioned model is selected but `goal_images` is None.
    """
    assert act_samples.shape[0] == next_obs_samples.shape[0]
    assert mean_from_obs.shape[1:] == next_obs_samples.shape[1:]

    n_act = len(act_samples)
    n_obs = len(mean_from_obs)
    needs_goal = FLAGS.model_name in ["vip", "liv", "r3m"]

    # Compute indexes to not exceed batch size
    sample_mem_usage = act_samples.nbytes + mean_from_obs.nbytes
    # Guard against zero-byte inputs, which would otherwise divide by zero.
    obs_per_batch = batch_size // sample_mem_usage if sample_mem_usage > 0 else 0
    if use_image:
        # Image batches are far larger than the byte estimate above; use a fixed chunk size.
        obs_per_batch = 32
        print(f"use_image=True: using fixed obs_per_batch: {obs_per_batch}")
    elif obs_per_batch <= 0:
        msg = f"`batch_size` too small to compute a batch: {batch_size} < {sample_mem_usage}. / obs_per_batch: {obs_per_batch}"
        print(msg)
        obs_per_batch = 32
    else:
        print(f"obs_per_batch: {obs_per_batch}")
    idxs = np.arange(0, n_obs, obs_per_batch)
    idxs = np.concatenate((idxs, [n_obs]))  # include end point

    # Compute mean rewards
    mean_rew = []
    reps = min(obs_per_batch, n_obs)
    # Tiled samples line up with `np.repeat` below: each observation is repeated `n_act`
    # times in a row while the sample list cycles once per observation.
    act_tiled = _tile_first_dim(act_samples, reps)
    next_obs_tiled = _tile_first_dim(next_obs_samples, reps)
    # There is one chunk per consecutive pair of boundaries, i.e. len(idxs) - 1 chunks.
    for start, end in tqdm(zip(idxs[:-1], idxs[1:]), desc="compute sample mean", total=len(idxs) - 1):
        obs = mean_from_obs[start:end]
        done = done_samples[start:end]
        traj = traj_samples[start:end]

        obs_repeated = np.repeat(obs, n_act, axis=0)
        done_repeated = np.repeat(done, n_act, axis=0)
        traj_repeated = np.repeat(traj, n_act, axis=0)

        batch = {
            "obs": obs_repeated,
            "action": act_tiled[: len(obs_repeated), :],
            "next_obs": next_obs_tiled[: len(obs_repeated), :],
            "done": done_repeated,
            "traj_idx": traj_repeated,
        }
        if use_image:
            images = mean_from_images[start:end]
            batch["images"] = np.repeat(images, n_act, axis=0)
        if needs_goal:
            if goal_images is None:
                raise ValueError(f"`goal_images` is required when model_name is {FLAGS.model_name!r}")
            _goal_images = goal_images[start:end]
            batch["goal_images"] = np.repeat(_goal_images, n_act, axis=0)
        rew = evaluate_model(
            model,
            batch,
            batch_size=reward_batch_size,
            env_flag=env_flag,
            seq_len=seq_len,
            use_image=use_image,
            skip_frame=skip_frame,
            data_factor_values=data_factor_values,
        )
        # One row per observation, one column per (action, next obs) sample.
        rew = rew.reshape(len(obs), -1)
        mean_rew.extend(np.mean(rew, axis=1))

    mean_rew = np.array(mean_rew)
    assert mean_rew.shape == (n_obs,)

    return mean_rew


def _check_dist(dist: np.ndarray) -> None:
    """Assert that `dist` is a probability distribution: entries sum to 1 and none is negative."""
    total_mass = np.sum(dist)
    assert np.allclose(total_mass, 1)
    is_nonnegative = dist >= 0
    assert np.all(is_nonnegative)


def lp_norm(arr: np.ndarray, p: int, dist: Optional[np.ndarray] = None) -> float:
    r"""Computes the L^{p} norm of arr, weighted by dist.
    Args:
        arr: The array to compute the norm of. It is never modified.
        p: The power to raise elements to.
        dist: A distribution to weight elements of array by. When omitted,
            every element gets the uniform weight 1 / arr.size.
    Returns:
        The L^{p} norm of arr with respect to the measure dist.
        That is, (\sum_i dist_i * |arr_i|^p)^{1/p}.
    """
    if dist is None:
        # Fast path: use optimized np.linalg.norm
        # `arr.size` replaces `np.product(arr.shape)`; `np.product` was removed in NumPy 2.0.
        n = arr.size
        raw_norm = np.linalg.norm(arr.flatten(), ord=p)
        return raw_norm / (n ** (1 / p))

    # Otherwise, weighted; use our implementation (up to 2x slower).
    assert arr.shape == dist.shape
    _check_dist(dist)

    # Out-of-place arithmetic: in-place `*=` with float weights fails for integer `arr`.
    weighted = (np.abs(arr) ** p) * dist
    return np.sum(weighted) ** (1 / p)


def canonical_scale_normalizer(
    rew: np.ndarray, p: int = 1, dist: Optional[np.ndarray] = None, eps: float = 1e-10
) -> float:
    """
    Return the factor that rescales `rew` to unit L^{p} norm.
    Rewards whose norm is below `eps` are treated as identically zero, in which
    case the factor is `0` rather than a huge reciprocal.
    Args:
        rew: The reward array to compute the normalizer for.
        p: The power used by the L^{p} norm.
        dist: Optional measure for the L^{p} norm, forwarded to `lp_norm`.
        eps: Norms smaller than this are considered zero (absorbs floating point error).
    Returns:
        `1 / norm`, or `0` when the norm is smaller than `eps`.
    """
    norm = lp_norm(rew, p, dist)
    if abs(norm) < eps:
        return 0
    return 1 / norm


def sample_canon_shaping(
    model,
    batch,  ## n_samples x (s, a, s')
    act_samples,  ## n_mean_samples x (A)
    next_obs_samples,  ## n_mean_samples x (S')
    done_samples,
    traj_samples,
    env_flag=True,
    discount: float = 1.0,
    p: int = 1,
    use_image: bool = False,
    next_image_samples: np.ndarray = None,
    goal_image_samples: np.ndarray = None,
    batch_size: int = 16,
    seq_len: int = 30,
    skip_frame: int = 1,
    data_factor_values=None,
):
    """Canonicalize the rewards of `batch` with a sample-based estimate of the mean reward.

    The raw reward of every transition is shifted by the estimated mean reward
    at its state and next state and by the overall mean, then rescaled with
    `canonical_scale_normalizer`.

    Args:
        model: Reward source forwarded to `evaluate_model` / `sample_mean_rew`.
        batch: Dict of per-transition arrays; reads "obs", "next_obs", "done",
            "next_done", "traj_idx", "next_traj_idx" and, with `use_image`,
            "images", "next_images" (plus "goal_images" / "next_goal_images"
            when `goal_image_samples` is given).
        act_samples: Sampled actions used for the mean estimate.
        next_obs_samples: Sampled next observations, one per action sample.
        done_samples: Done flags aligned with `next_obs_samples`.
        traj_samples: Trajectory indices aligned with `next_obs_samples`.
        env_flag: Forwarded to the reward evaluation.
        discount: Discount factor applied to the next-state and total means.
        p: Power of the norm used for rescaling.
        use_image: Whether image observations are used.
        next_image_samples: Images aligned with `next_obs_samples`.
        goal_image_samples: Goal images aligned with `next_obs_samples`.
        batch_size: Mini-batch size for reward evaluation.
        seq_len: Forwarded to the reward evaluation.
        skip_frame: Forwarded to the reward evaluation.
        data_factor_values: Forwarded to the reward evaluation.

    Returns:
        Array with one canonicalized reward per transition of `batch`.

    Raises:
        ValueError: If `act_samples` and `next_obs_samples` differ in length.
    """
    # Sample-based estimate of mean reward
    n_mean_samples = len(act_samples)
    if len(next_obs_samples) != n_mean_samples:
        raise ValueError(f"Different sample length: {len(next_obs_samples)} != {n_mean_samples}")

    # Options shared by every reward evaluation below.
    shared_opts = dict(
        env_flag=env_flag,
        seq_len=seq_len,
        skip_frame=skip_frame,
        use_image=use_image,
        data_factor_values=data_factor_values,
    )

    # EPIC only defined on infinite-horizon MDPs, so pretend episodes never end.
    ## n_samples x R(s, a, s')
    raw_rew = evaluate_model(model, batch, batch_size=batch_size, **shared_opts)

    ## Stack (n_mean_samples x S', n_samples x s, n_samples x s') and drop duplicate observations.
    stacked_obs = np.concatenate((next_obs_samples, batch["obs"], batch["next_obs"]), axis=0)
    stacked_done = np.concatenate((done_samples, batch["done"], batch["next_done"]), axis=0)
    stacked_traj = np.concatenate((traj_samples, batch["traj_idx"], batch["next_traj_idx"]), axis=0)
    unique_obs, first_idx, inverse_idx = np.unique(
        stacked_obs, return_index=True, return_inverse=True, axis=0
    )  ## (K, D)
    unique_done = stacked_done[first_idx]
    unique_traj = stacked_traj[first_idx]

    # Image side-information exists only in image mode; otherwise both stay None.
    unique_images = None
    unique_goal_images = None
    if use_image:
        stacked_images = np.concatenate((next_image_samples, batch["images"], batch["next_images"]), axis=0)
        unique_images = stacked_images[first_idx]
        if goal_image_samples is not None:
            stacked_goals = np.concatenate(
                (goal_image_samples, batch["goal_images"], batch["next_goal_images"]), axis=0
            )
            unique_goal_images = stacked_goals[first_idx]

    per_unique_mean = sample_mean_rew(
        model,
        unique_obs,
        act_samples,
        next_obs_samples,
        unique_done,
        unique_traj,
        reward_batch_size=batch_size,
        mean_from_images=unique_images,
        goal_images=unique_goal_images,
        **shared_opts,
    )
    # Scatter the per-unique means back onto the stacked layout.
    stacked_mean = per_unique_mean[inverse_idx]  ## (n_mean_samples+2*n_samples, )

    ## E[R(S, A, S')], averaged over the n_mean_samples sampled states.
    total = np.mean(stacked_mean[:n_mean_samples])

    ## Row 0: E[R(s, A, S')], row 1: E[R(s', A, S')], each x n_samples.
    mean_obs, mean_next_obs = stacked_mean[n_mean_samples:].reshape(2, -1)

    # Use mean rewards to canonicalize reward up to shaping.
    # Note this is the only part of the computation that depends on discount, so it'd be
    # cheap to evaluate for many values of `discount` if needed.
    deshaped = raw_rew + discount * mean_next_obs - mean_obs - discount * total  ## n_samples x C_R(s, a, s')
    deshaped *= canonical_scale_normalizer(deshaped, p)

    return deshaped


def _center(x: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Subtract the `weights`-weighted average of `x` from every element of `x`."""
    return x - np.average(x, weights=weights)


def pearson_distance(rewa: np.ndarray, rewb: np.ndarray, dist: Optional[np.ndarray] = None) -> float:
    """Computes pseudometric derived from the Pearson correlation coefficient.
    It is invariant to positive affine transformations like the Pearson correlation coefficient.
    Args:
        rewa: A reward array.
        rewb: A reward array.
        dist: Optionally, a probability distribution of the same shape as rewa and rewb.
            Defaults to the uniform distribution over all elements.
    Returns:
        Computes the Pearson correlation coefficient rho, optionally weighted by dist.
        Returns the square root of half of 1 minus rho, i.e. sqrt(0.5 * (1 - rho)).
    """
    if dist is None:
        # `rewa.size` replaces `np.product(rewa.shape)`; `np.product` was removed in NumPy 2.0.
        dist = np.ones_like(rewa) / rewa.size
    _check_dist(dist)
    assert rewa.shape == dist.shape
    assert rewa.shape == rewb.shape

    dist = dist.flatten()
    rewa = _center(rewa.flatten(), dist)
    rewb = _center(rewb.flatten(), dist)

    vara = np.average(np.square(rewa), weights=dist)
    varb = np.average(np.square(rewb), weights=dist)
    cov = np.average(rewa * rewb, weights=dist)
    corr = cov / (np.sqrt(vara) * np.sqrt(varb))
    # Floating point error can push the coefficient slightly outside [-1, 1].
    corr = np.clip(corr, -1.0, 1.0)

    return np.sqrt(0.5 * (1 - corr))


def main(_):
    """Compute canonicalized rewards for the selected reward source and pickle them.

    Steps: load the FactorWorld dataset for the configured task / variation,
    use the first `n_trajs` trajectories as the evaluation batch and the next
    `n_mean_trajs` as mean samples, build the environment, load the reward
    source chosen by `--model_name`, run `sample_canon_shaping`, and write the
    result next to the dataset.

    Raises:
        ValueError: For an unsupported `--model_name` or `--tp`, or when the
            dataset holds fewer trajectories than `n_trajs + n_mean_trajs`.
    """
    base_path = pathlib.Path("/home/pref_data/reward_learning").expanduser()

    task_name = FLAGS.task_name
    model_name = FLAGS.model_name
    comment = task_name if FLAGS.comment == "" else f"{task_name}-{FLAGS.comment}"
    seed = 0

    gcr_models = ("r3m", "vip", "liv")
    supported_models = ("env", "PT", "VPT") + gcr_models
    # Fail before any expensive loading instead of hitting an unbound result at the very end.
    if model_name not in supported_models:
        raise ValueError(f"Unsupported model_name {model_name!r}; expected one of {supported_models}")

    # JAX reads XLA_PYTHON_CLIENT_PREALLOCATE (not XLA_PYTHON_MEM_PREALLOCATE) to decide
    # whether to preallocate GPU memory.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    set_random_seed(seed)

    # Load Dataset.
    ds_base_path = pathlib.Path("/home/pref_data").expanduser()
    ds_task_name, ds_variation_name, tp = FLAGS.task_name, FLAGS.variations, FLAGS.tp
    if tp == "success":
        dataset_path = ds_base_path / ds_task_name / ds_variation_name / "train"
    elif tp == "failure":
        dataset_path = ds_base_path / ds_task_name / ds_variation_name / "failure_episodes"
    else:
        raise ValueError(f"Unsupported tp {tp!r}; expected 'success' or 'failure'")
    print(f"dataset_path: {dataset_path}")

    ds = qlearning_factorworld_dataset(dataset_path, camera_keys=[FLAGS.camera_key])
    if len(ds[FLAGS.camera_key]) == 0:
        return
    dataset = RelabeledDataset(
        ds["observations"],
        ds["actions"],
        ds["rewards"],
        ds["terminals"],
        ds["next_observations"],
        images=ds[FLAGS.camera_key],
    )

    traj_indices, trj_mapper = get_traj_info(dataset)
    n_trajs, n_mean_trajs = FLAGS.n_trajs, FLAGS.n_mean_trajs
    if n_trajs + n_mean_trajs >= len(traj_indices):
        raise ValueError(
            f"Dataset has {len(traj_indices) - 1} trajectories, "
            f"but n_trajs + n_mean_trajs = {n_trajs + n_mean_trajs} are required"
        )
    n_samples = traj_indices[n_trajs]
    n_mean_samples = traj_indices[n_trajs + n_mean_trajs]

    # ====== Load Data ======
    # "next_*" entries are the same arrays shifted by one transition.
    batch = {
        "obs": dataset.observations[:n_samples],
        "next_obs": dataset.next_observations[:n_samples],
        "action": dataset.actions[:n_samples],
        "images": dataset.images[:n_samples],
        "next_images": dataset.images[1 : n_samples + 1],
        "done": dataset.dones_float[:n_samples],
        "next_done": dataset.dones_float[1 : n_samples + 1],
        "traj_idx": trj_mapper[:n_samples],
        "next_traj_idx": trj_mapper[1 : n_samples + 1],
    }

    # NOTE(review): the "+ 1"-shifted slices below come up one element short when the mean
    # samples end exactly at the last mapped transition of the dataset - confirm that at least
    # one further trajectory always follows.
    act_samples = dataset.actions[n_samples:n_mean_samples]
    next_obs_samples = dataset.next_observations[n_samples:n_mean_samples]
    next_image_samples = dataset.images[n_samples + 1 : n_mean_samples + 1]
    done_samples = dataset.dones_float[n_samples + 1 : n_mean_samples + 1]
    traj_samples = trj_mapper[n_samples + 1 : n_mean_samples + 1]

    # ===== Load Environment, if env_flag is True. =====
    # NOTE: pickle executes arbitrary code on load; only use trusted dataset files.
    with open(dataset_path.parent / "data_factor_values.pkl", "rb") as f:
        data_factor_values = pickle.load(f)

    cfg = OmegaConf.load(pathlib.Path("/home/factor-world/metaworld_generalization/cfgs") / "data.yaml")
    factors = ds_variation_name.split("-")

    factor_kwargs = {factor: cfg.env.factors[factor] for factor in factors}
    env = make_wrapped_env(
        task_name,
        use_train_xml=True,
        factor_kwargs=factor_kwargs,
        image_obs_size=cfg.env.image_obs_size,
        camera_name=[FLAGS.camera_key],
        observe_goal=cfg.env.observe_goal,
        random_init=cfg.env.random_init,
        default_num_resets_per_randomize=1,
    )
    _ = env.reset()
    observation_dim = env.unwrapped.observation_space["proprio"].shape[0]
    action_dim = env.unwrapped.action_space.shape[0]

    # ====== Load Model =====
    if model_name in ["PT", "VPT"]:
        path = base_path / f"factorworld-{task_name}" / f"{model_name}" / comment / f"s{seed}"

        best_model = path / "best_model.pkl"
        if not best_model.exists():
            best_model = path / "model.pkl"

        # NOTE: pickle executes arbitrary code on load; only use trusted checkpoints.
        with best_model.open("rb") as fin:
            checkpoint_data = pickle.load(fin)
        state = checkpoint_data["state"]

        jax_devices = jax.local_devices()

        # Load trained PT model.
        if checkpoint_data.get("config") is not None:
            transformer = checkpoint_data["config"]
        else:
            # No config stored with the checkpoint: fall back to the learner defaults
            # with the overrides below.
            if model_name == "PT":
                transformer = PTLearner.get_default_config()
            elif model_name == "VPT":
                transformer = VPTLearner.get_default_config()
            transformer.embd_dim = 128
            transformer.n_layer = 1
            transformer.n_head = 4
            transformer.skip_frame = FLAGS.skip_frame

        config = transformers.GPT2Config(**transformer)
        config.warmup_steps = 10
        config.total_steps = 1000

        if model_name == "PT":
            reward_learner = PTLearner(config, observation_dim, action_dim, jax_devices)
        elif model_name == "VPT":
            image_dim = (224, 224, 3)
            reward_learner = VPTLearner(config, image_dim, action_dim)
        reward_learner.load(state)

    elif model_name in gcr_models:
        device = torch.device("cuda")
        model, transform = load_embedding(rep=model_name)

        # Make batches of goal images: one entry per transition, shared within a trajectory.
        # NOTE(review): the goal is taken from the FIRST frame of each trajectory
        # (`traj_indices[idx]`); a goal image is usually the final frame - confirm.
        goal_images = []
        for idx in range(len(traj_indices) - 1):
            traj_len = traj_indices[idx + 1] - traj_indices[idx]
            for _ in range(traj_len):
                goal_images.append(dataset.images[traj_indices[idx]])
        goal_images = np.asarray(goal_images)
        # "goal_images" must hold exactly n_samples entries like every other batch array;
        # an extra entry shifts the concatenated goal images against the observations.
        batch.update({"goal_images": goal_images[:n_samples], "next_goal_images": goal_images[1 : n_samples + 1]})
        goal_image_samples = goal_images[n_samples + 1 : n_mean_samples + 1]

    # ====== Canonicalize rewards ======
    if model_name == "env":
        canon_rew = sample_canon_shaping(
            env,
            batch,
            act_samples,
            next_obs_samples,
            done_samples,
            traj_samples,
            env_flag=True,
            data_factor_values=data_factor_values,
        )
    elif model_name == "PT":
        canon_rew = sample_canon_shaping(
            reward_learner,
            batch,
            act_samples,
            next_obs_samples,
            done_samples,
            traj_samples,
            batch_size=4096,
            env_flag=False,
            use_image=False,
            skip_frame=FLAGS.skip_frame,
        )
    elif model_name == "VPT":
        canon_rew = sample_canon_shaping(
            reward_learner,
            batch,
            act_samples,
            next_obs_samples,
            done_samples,
            traj_samples,
            batch_size=48,
            env_flag=False,
            use_image=True,
            next_image_samples=next_image_samples,
            skip_frame=FLAGS.skip_frame,
        )
    else:  # model_name in gcr_models: "r3m", "vip" and "liv" share the same evaluation path.
        canon_rew = sample_canon_shaping(
            (model, transform, device),
            batch,
            act_samples,
            next_obs_samples,
            done_samples,
            traj_samples,
            batch_size=256,
            env_flag=False,
            use_image=True,
            next_image_samples=next_image_samples,
            goal_image_samples=goal_image_samples,
            skip_frame=1,
        )

    filename = f"canon-{FLAGS.camera_key}-n{n_trajs}-mean{n_mean_trajs}-{FLAGS.tp}-{model_name}-{FLAGS.comment}.pkl"
    with open(dataset_path.parent / filename, "wb") as f:
        pickle.dump(canon_rew, f)


# Script entry point: hand `main` to absl's `app.run` when this file is executed directly.
if __name__ == "__main__":
    app.run(main)
