# Path: PaRI/contrastive.py
import copy
import itertools
from typing import Any, Callable, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from gymnasium.wrappers.time_aware_observation import TimeAwareObservation

from experiments.envs import ManyDoorsEnv, TwoDoorsEnv
from experiments.envs.wrappers import OneHotFullImage, OneHotPartialImage
from src.popl import popl_policy_search, popl_search

# Module-wide device used when this file moves tensors: CUDA if usable, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def biased_bce_with_logits(adv1, adv2, y, bias=1.0):
    """Biased pairwise preference loss (and accuracy) for two sets of advantages.

    Args:
        adv1: advantages of the first element of every pair.
        adv2: advantages of the second element of every pair.
        y: preference labels; 1 means the second element is preferred,
            0 means the first one is. Soft labels blend both terms.
        bias: multiplier applied to the advantage of the non-preferred element
            inside each logit (1.0 gives the plain difference).

    Returns:
        ``(loss, acc)``: the mean loss and the fraction of pairs for which
        ``adv2 > adv1`` agrees with ``round(y)``.
    """

    def _neg_log_sigmoid(logit):
        # log(1 + exp(-logit)), evaluated with the max-shift (log-sum-exp)
        # trick so that large negative logits do not overflow exp().
        shift = torch.clamp(-logit, min=0, max=None)
        return torch.log(torch.exp(-shift) + torch.exp(-logit - shift)) + shift

    # Negative log-probability of "2 preferred over 1" and of the reverse.
    nlp_second = _neg_log_sigmoid(adv2 - bias * adv1)
    nlp_first = _neg_log_sigmoid(adv1 - bias * adv2)

    loss = (y * nlp_second + (1 - y) * nlp_first).mean()

    # Accuracy is a metric only, so keep it out of the autograd graph.
    with torch.no_grad():
        second_wins = adv2 > adv1
        acc = (second_wins == torch.round(y)).float().mean()

    return loss, acc


def biased_bce_with_scores(adv, scores, bias=1.0):
    """Biased pairwise ranking loss over every pair implied by scalar scores.

    The items are sorted by ascending ``scores``; for every pair ``(i, j)`` with
    ``i < j`` in that order, item ``j`` is treated as preferred and contributes
    ``softplus(-(adv_j - bias * adv_i))`` to the loss.

    Args:
        adv: advantages, indexed along dim 0 like ``scores``.
        scores: per-item scores used only to derive the ordering
            (assumed 1-D — TODO confirm with callers). Items with equal
            scores are still paired, in whatever order ``argsort`` yields.
        bias: multiplier applied to the advantage of the lower-scored item.

    Returns:
        ``(loss, acc)``: the loss averaged over all ordered pairs, and the
        fraction of those pairs whose advantages are ordered like the scores.
        With fewer than two items there are no pairs and both values are 0.
    """
    order = torch.argsort(scores, dim=0)
    adv_sorted = adv[order]

    # logits[i, j] = adv_sorted[j] - bias * adv_sorted[i]
    logits = adv_sorted.unsqueeze(0) - bias * adv_sorted.unsqueeze(1)
    # Numerically stable softplus(-logits) via the max-shift trick.
    max_val = torch.clamp(-logits, min=0, max=None)
    pair_loss = torch.log(torch.exp(-max_val) +
                          torch.exp(-logits - max_val)) + max_val

    # Structural mask of the strict upper triangle (i < j). It must not be
    # derived from the loss values: a well-separated pair can have a loss that
    # underflows to exactly 0 and would then silently drop out of the average.
    pair_mask = torch.ones_like(pair_loss).triu(diagonal=1).bool()
    num_pairs = pair_mask.sum().clamp(min=1)  # avoid 0/0 when there are no pairs

    loss = torch.triu(pair_loss, diagonal=1).sum() / num_pairs

    with torch.no_grad():
        # Use the *sorted* advantages so that entry (i, j) of the mask refers
        # to the same pair as in the loss above.
        unbiased_logits = adv_sorted.unsqueeze(0) - adv_sorted.unsqueeze(1)
        acc = ((unbiased_logits > 0) & pair_mask).sum() / num_pairs

    return loss, acc


class Policy(nn.Module):
    """Small CNN policy over 3x3 image observations, trainable with CPL or BC.

    Observations are expected channel-last with trailing dimensions
    ``(3, 3, channels)``; every leading dimension after the first is flattened
    into a single "time" axis, so outputs have shape ``(batch, -1, size)``.

    Args:
        input_size: accepted for interface compatibility; not used.
        channels: number of channels of the observation image.
        features_size: width of the hidden feature layer.
        output_size: number of discrete actions.
    """

    def __init__(self, input_size, channels, features_size, output_size):
        super().__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(channels, 16, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(16, 64, (2, 2)),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Probe the CNN once with a dummy 3x3 image to learn the flattened size.
        with torch.no_grad():
            n_flatten = self.cnn(torch.randn(1, channels, 3, 3)).shape[1]
        self.fc1 = nn.Linear(n_flatten, features_size)
        self.last_layer = nn.Linear(features_size, output_size, bias=False)

        # Optional scalar head, in case this network is used as a reward model.
        self.reward_head = nn.Linear(features_size, 1)

        self.softmax = nn.Softmax(dim=-1)

        self.features_size = features_size
        self.output_size = output_size

        self.alpha = 1.0  # scale applied to log-probabilities to form advantages
        self.contrastive_bias = 1.0  # bias forwarded to the contrastive losses
        self.optimizer = None  # must be set by the caller before any train step

    def _encode(self, x):
        """Return the post-ReLU hidden features for every image in ``x``."""
        # Read the channel count from the first conv layer so the reshape agrees
        # with the ``channels`` the network was built with.
        channels = self.cnn[0].in_channels
        # Merge all leading dims, then go channel-last -> channel-first.
        images = x.reshape(-1, 3, 3, channels).permute(0, 3, 1, 2)
        return torch.relu(self.fc1(self.cnn(images)))

    def forward(self, x):
        """Return action probabilities of shape ``(x.shape[0], -1, output_size)``."""
        old_batch = x.shape[0]
        logits = self.last_layer(self._encode(x))
        logits = logits.reshape(old_batch, -1, self.output_size)
        # Softmax is applied exactly once, to the raw logits; applying it a
        # second time to probabilities would flatten the distribution.
        return self.softmax(logits)

    def get_reward(self, x):
        """Return the reward head output with shape ``(x.shape[0], -1, 1)``."""
        old_batch = x.shape[0]
        return self.reward_head(self._encode(x)).reshape(old_batch, -1, 1)

    def get_logprob(self, obs, action):
        """Return the log-probability of ``action`` under the policy at ``obs``.

        For tensor outputs, ``action`` must have the same shape as the
        probabilities (e.g. one-hot). The probability selected by ``action`` is
        logged per step (with a 1e-8 floor, so an all-zero action row
        contributes ``log(1e-8)``) and the result is summed over the last
        remaining axis.
        """
        dist = self.forward(obs)

        if isinstance(dist, torch.distributions.Distribution):
            return dist.log_prob(action)

        assert dist.shape == action.shape
        return torch.log(torch.sum(dist * action, dim=-1) + 1e-8).sum(dim=-1)

    def get_features(self, batch):
        """Return hidden features (input of ``last_layer``), shape ``(batch, -1, features_size)``."""
        old_batch = batch.shape[0]
        return self._encode(batch).reshape(old_batch, -1, self.features_size)

    # gets the loss of the policy from a batch of preferences
    def _get_cpl_loss(self, batch):
        """Compute the contrastive preference loss and accuracy for ``batch``.

        ``batch`` either holds paired data (``obs_1``, ``obs_2``, ``action_1``,
        ``action_2``, ``label``) or scored data (``obs``, ``action``, ``score``).
        """
        if isinstance(batch, dict) and "label" in batch:
            # Stack both halves so a single forward pass covers every segment.
            obs = torch.cat([batch["obs_1"], batch["obs_2"]], dim=0)
            action = torch.cat([batch["action_1"], batch["action_2"]], dim=0)
        else:
            assert "score" in batch
            obs, action = batch["obs"], batch["action"]

        # Advantages are scaled log-probabilities.
        adv = self.alpha * self.get_logprob(obs, action)
        segment_adv = adv.sum(dim=-1) if adv.dim() > 1 else adv

        if "score" in batch:
            cpl_loss, accuracy = biased_bce_with_scores(
                segment_adv, batch["score"].float(), bias=self.contrastive_bias
            )
        else:
            # Undo the concatenation: first half is segment 1, second half segment 2.
            adv1, adv2 = torch.chunk(segment_adv, 2, dim=0)
            cpl_loss, accuracy = biased_bce_with_logits(
                adv1, adv2, batch["label"].float(), bias=self.contrastive_bias
            )

        return cpl_loss, accuracy

    def _get_bc_loss(self, obs, action):
        """Behaviour-cloning loss: NLL for distributions, else mean squared error."""
        dist = self.forward(obs)

        if isinstance(dist, torch.distributions.Distribution):
            loss = -dist.log_prob(action)
        else:
            loss = torch.square(dist - action)

        return loss.mean()

    def bc_train_step(self, batch: Dict, step: int, total_steps: int) -> Dict:
        """Run one optimizer step on the BC loss of ``obs_1`` / ``action_1``.

        ``step`` and ``total_steps`` are accepted but not used.
        """
        bc_loss = self._get_bc_loss(batch["obs_1"], batch["action_1"])

        self.optimizer.zero_grad()
        bc_loss.backward()
        self.optimizer.step()

        return dict(bc_loss=bc_loss.item())

    def train_step(self, batch: Dict, step: int, total_steps: int) -> Dict:
        """Run one optimizer step on the contrastive preference loss.

        ``step`` and ``total_steps`` are accepted but not used.
        """
        cpl_loss, accuracy = self._get_cpl_loss(batch)

        self.optimizer.zero_grad()
        cpl_loss.backward()
        self.optimizer.step()

        return dict(cpl_loss=cpl_loss.item(), accuracy=accuracy.item())


class MetaworldPolicy(nn.Module):
    """One-hidden-layer MLP policy for vector observations, trainable with CPL or BC.

    The network output is used directly as the predicted action; log-probabilities
    are the negative squared error between that output and the given action.
    """

    def __init__(self, input_size, features_size, output_size):
        super().__init__()

        self.fc0 = nn.Linear(input_size, features_size)
        self.last_layer = nn.Linear(features_size, output_size, bias=False)

        # Optional scalar head, in case this network is used as a reward model.
        self.reward_head = nn.Linear(features_size, 1)

        self.features_size = features_size
        self.output_size = output_size

        self.alpha = 1.0  # scale applied to log-probabilities to form advantages
        self.contrastive_bias = 1.0  # bias forwarded to the contrastive losses
        self.optimizer = None  # must be set by the caller before any train step

    def forward(self, x):
        """Map observations to actions.

        A 1-D observation is promoted to shape ``(1, 1, dim)`` first. After the
        hidden layer, up to two leading axes of size 1 are squeezed away, so a
        single observation yields a 1-D action vector.
        """
        if x.dim() == 1:
            x = x[None, None]
        hidden = torch.relu(self.fc0(x))
        # Drop at most two leading singleton axes.
        for _ in range(2):
            if hidden.shape[0] == 1:
                hidden = hidden.squeeze(0)
        return self.last_layer(hidden)

    def get_reward(self, x):
        """Return the reward head output, with leading singleton axes squeezed."""
        if x.dim() == 1:
            x = x[None]
        leading = x.shape[0]
        reward = self.reward_head(torch.relu(self.fc0(x)))
        # The reshape always yields three axes; squeeze(0) only removes an axis
        # when its size is 1.
        return reward.reshape(leading, -1, 1).squeeze(0).squeeze(0)

    def get_logprob(self, obs, action):
        """Return the log-probability of ``action`` at ``obs``.

        For tensor outputs this is the negative squared error summed over the
        last axis; ``action`` must match the output shape.
        """
        predicted = self.forward(obs)

        if isinstance(predicted, torch.distributions.Distribution):
            return predicted.log_prob(action)

        assert predicted.shape == action.shape
        return -torch.square(predicted - action).sum(dim=-1)

    def get_features(self, batch):
        """Return hidden features (input of ``last_layer``), shape ``(batch, -1, features_size)``."""
        hidden = torch.relu(self.fc0(batch))
        return hidden.reshape(batch.shape[0], -1, self.features_size)

    # gets the loss of the policy from a batch of preferences
    def _get_cpl_loss(self, batch):
        """Compute the contrastive preference loss and accuracy for ``batch``.

        ``batch`` either holds paired data (``obs_1``, ``obs_2``, ``action_1``,
        ``action_2``, ``label``) or scored data (``obs``, ``action``, ``score``).
        """
        paired = isinstance(batch, dict) and "label" in batch
        if paired:
            # Stack both halves so a single forward pass covers every segment.
            obs = torch.cat([batch["obs_1"], batch["obs_2"]], dim=0)
            action = torch.cat([batch["action_1"], batch["action_2"]], dim=0)
        else:
            assert "score" in batch
            obs, action = batch["obs"], batch["action"]

        # Advantages are scaled log-probabilities, summed per segment if needed.
        segment_adv = self.alpha * self.get_logprob(obs, action)
        if segment_adv.dim() > 1:
            segment_adv = segment_adv.sum(dim=-1)

        if "score" in batch:
            cpl_loss, accuracy = biased_bce_with_scores(
                segment_adv, batch["score"].float(), bias=self.contrastive_bias
            )
            return cpl_loss, accuracy

        # Undo the concatenation: first half is segment 1, second half segment 2.
        first_half, second_half = torch.chunk(segment_adv, 2, dim=0)
        cpl_loss, accuracy = biased_bce_with_logits(
            first_half, second_half, batch["label"].float(), bias=self.contrastive_bias
        )
        return cpl_loss, accuracy

    def _get_bc_loss(self, obs, action):
        """Behaviour-cloning loss: NLL for distributions, else mean squared error."""
        predicted = self.forward(obs)

        if isinstance(predicted, torch.distributions.Distribution):
            per_element = -predicted.log_prob(action)
        else:
            per_element = torch.square(predicted - action)

        return per_element.mean()

    def bc_train_step(self, batch: Dict, step: int, total_steps: int) -> Dict:
        """Run one optimizer step on the BC loss of ``obs_1`` / ``action_1``.

        ``step`` and ``total_steps`` are accepted but not used.
        """
        bc_loss = self._get_bc_loss(batch["obs_1"], batch["action_1"])

        self.optimizer.zero_grad()
        bc_loss.backward()
        self.optimizer.step()

        return {"bc_loss": bc_loss.item()}

    def train_step(self, batch: Dict, step: int, total_steps: int) -> Dict:
        """Run one optimizer step on the contrastive preference loss.

        ``step`` and ``total_steps`` are accepted but not used.
        """
        cpl_loss, accuracy = self._get_cpl_loss(batch)

        self.optimizer.zero_grad()
        cpl_loss.backward()
        self.optimizer.step()

        return {"cpl_loss": cpl_loss.item(), "accuracy": accuracy.item()}


def train_cpl_last_layers(
    *,
    policy: Policy,
    sample_batch: Callable[[int, str], Dict[str, torch.Tensor]],
    batch_size: int,
    lr: float,
    num_iterations: int,
    device: torch.device,
    popsize: int,
    step_stdev: float,
    env_name: str = "TwoDoors",
    experiment_dir: str = "./results/minigrid/",
) -> list:
    """Train ``popsize`` copies of ``policy`` with CPL, updating only ``last_layer``.

    Each copy starts from the current weights of ``policy``, has every parameter
    frozen except ``last_layer.weight``, and is trained independently for
    ``num_iterations`` steps on batches from ``sample_batch(batch_size, env_name)``.

    Args:
        policy: template policy; moved to ``device`` / float32 and given a fresh
            Adam optimizer, but not trained here.
        sample_batch: callable invoked as ``sample_batch(batch_size, env_name)``.
        batch_size: batch size forwarded to ``sample_batch``.
        lr: Adam learning rate.
        num_iterations: training steps per copy.
        device: device on which the copies are trained.
        popsize: number of copies to train.
        step_stdev: accepted but not used here.
        env_name: forwarded to ``sample_batch``.
        experiment_dir: accepted but not used here.

    Returns:
        A list with the ``popsize`` trained policy copies.
    """
    # this is for experimental rigor, we train CPL like lexicase
    policy.to(device).to(torch.float)
    policy.optimizer = optim.Adam(policy.parameters(), lr=lr)

    policies = []

    for _ in tqdm.trange(popsize, desc="MultiCPL Training"):
        copy_policy = copy.deepcopy(policy).to(device)

        # Freeze everything, then re-enable only the last layer's weight.
        for param in copy_policy.parameters():
            param.requires_grad = False
        copy_policy.last_layer.weight.requires_grad = True

        # Build the optimizer after freezing and hand it only the trainable
        # weight, so the frozen layers can never be stepped.
        copy_policy.optimizer = optim.Adam(
            [copy_policy.last_layer.weight], lr=lr)

        for _ in range(num_iterations):
            batch_data = sample_batch(batch_size, env_name)
            copy_policy.train_step(batch_data, 0, 0)

        policies.append(copy_policy)

    return policies


def train_cpl(
    *,
    policy: Policy,
    sample_batch: Callable[[int], Dict[str, torch.Tensor]],
    batch_size: int,
    lr: float,
    num_iterations: int,
    device: torch.device,
    env_name: str = "TwoDoors",
    experiment_dir: str = "./results/minigrid/",
) -> Policy:
    """Train ``policy`` in place with the contrastive preference (CPL) loss.

    The policy is moved to ``device`` as float32, given a fresh Adam optimizer
    over all of its parameters, and stepped ``num_iterations`` times on batches
    obtained from ``sample_batch(batch_size, env_name)``. The progress bar shows
    the latest loss and accuracy. ``experiment_dir`` is accepted but not used.

    Returns:
        The same ``policy`` object, after training.
    """
    policy.to(device).to(torch.float)
    policy.optimizer = optim.Adam(policy.parameters(), lr=lr)

    bar = tqdm.trange(num_iterations, desc="CPL Training")
    for _ in bar:
        metrics = policy.train_step(sample_batch(batch_size, env_name), 0, 0)
        loss_value, acc_value = metrics["cpl_loss"], metrics["accuracy"]
        bar.set_description(
            f"loss = {loss_value:.4f}, accuracy = {acc_value:.4f}"
        )

    return policy


def train_bc(
    *,
    policy: Policy,
    sample_batch: Callable[[int], Dict[str, torch.Tensor]],
    batch_size: int,
    lr: float,
    num_iterations: int,
    device: torch.device,
    env_name: str = "TwoDoors",
    experiment_dir: str = "./results/minigrid/",
) -> Policy:
    """Train ``policy`` in place with behaviour cloning.

    The policy is moved to ``device`` as float32, given a fresh Adam optimizer
    over all of its parameters, and stepped ``num_iterations`` times with
    ``bc_train_step`` on batches from ``sample_batch(batch_size, env_name)``.
    The progress bar shows the latest loss. ``experiment_dir`` is accepted but
    not used.

    Returns:
        The same ``policy`` object, after training.
    """
    policy.to(device).to(torch.float)
    policy.optimizer = optim.Adam(policy.parameters(), lr=lr)

    bar = tqdm.trange(num_iterations, desc="BC Training")
    for _ in bar:
        metrics = policy.bc_train_step(sample_batch(batch_size, env_name), 0, 0)
        loss_value = metrics["bc_loss"]
        bar.set_description(f"loss = {loss_value:.4f}")

    return policy


def train_popl(
    *,
    policy: Policy,
    sample_batch: Callable[[int, str], Dict[str, torch.Tensor]],
    batch_size: int,
    popsize: int,
    step_stdev: float,
    num_iterations: int,
    device: torch.device,
    env_name: str = "TwoDoors",
    experiment_dir: str = "./results/minigrid/",
    resamples: int = 1,
    num_features: int = 128,
    mutation_fn=None,
    downsample_level=1,
) -> tuple:
    """Search a population of last-layer weights with ``popl_policy_search``.

    For each of ``resamples`` rounds a batch is drawn with
    ``sample_batch(batch_size, env_name)``, the features of both segments are
    computed with ``policy.get_features``, and the search runs for
    ``num_iterations // resamples`` steps. The first round is seeded with
    ``policy`` itself; later rounds continue from the previous population.

    Args:
        policy: feature extractor and template for the returned policies.
        sample_batch: callable invoked as ``sample_batch(batch_size, env_name)``;
            the batch must provide ``obs_1``, ``obs_2``, ``action_1``,
            ``action_2`` and ``label``.
        batch_size, popsize, step_stdev, mutation_fn, downsample_level:
            forwarded to the batch sampler / search routine.
        num_iterations: total number of search steps across all rounds.
        device: device for the template and the returned policies.
        env_name: forwarded to ``sample_batch``.
        experiment_dir: accepted but not used here.
        resamples: number of rounds (each with a freshly sampled batch).
        num_features: feature width, used for the replacement layer only when
            a population member is not a 2-D weight matrix.

    Returns:
        ``(policies, sorted_pop, info)``: ``sorted_pop`` is the final population
        ordered by descending total score, ``policies[i]`` is a copy of
        ``policy`` whose ``last_layer`` weight is ``sorted_pop[i]``, and
        ``info["best_scores"]`` lists the best total score of every round.
    """
    policy.to(device).to(torch.float)

    resample_steps = num_iterations // resamples
    population = policy  # seed for the first search round

    # logging
    info = {}
    best_scores = []

    for _round in tqdm.trange(resamples):
        batch = sample_batch(batch_size, env_name)

        # Features of every observation in both segments of each pair.
        features_1 = policy.get_features(batch["obs_1"])
        features_2 = policy.get_features(batch["obs_2"])
        labels = batch["label"]

        population, scores, _ = popl_policy_search(
            population,
            labels,
            features_1,
            features_2,
            batch["action_1"],
            batch["action_2"],
            popsize,
            resample_steps,
            step_stdev,
            True,
            downsample_level=downsample_level,
            elitism=True,
            alpha=None,
            mutation_fn=mutation_fn,
        )

        best = torch.max(torch.sum(scores, dim=1))
        print(
            f"best total score: {best}/{scores.shape[1]}")
        best_scores.append(best)

    # Order the population by total score, best first.
    sorted_pop = population[torch.argsort(torch.sum(scores, dim=1)).flip(0)]

    action_dim = batch["action_1"].shape[-1]

    policies = []
    for member in sorted_pop:
        # Build each policy from the *sorted* population so that policies[i]
        # and sorted_pop[i] describe the same individual.
        weight = torch.as_tensor(
            member, dtype=torch.float, device=device).detach().clone()
        if weight.dim() == 2:
            # nn.Linear stores its weight as (out_features, in_features).
            out_features, in_features = weight.shape
        else:
            in_features, out_features = num_features, action_dim

        last_layer = nn.Linear(in_features, out_features, bias=False)
        last_layer.weight = nn.Parameter(weight)

        copy_policy = copy.deepcopy(policy)
        copy_policy.last_layer = last_layer
        policies.append(copy_policy.to(device))

    info["best_scores"] = best_scores

    return policies, sorted_pop, info


def train_popl_reward(
    *,
    policy: Policy,
    sample_batch: Callable[[int, str], Dict[str, torch.Tensor]],
    batch_size: int,
    popsize: int,
    step_stdev: float,
    num_iterations: int,
    device: torch.device,
    env_name: str = "TwoDoors",
    experiment_dir: str = "./results/minigrid/",
    resamples: int = 1,
    num_features: int = 1024,
    downsample_level: int = 1,
    mutation_fn = None,
    elitism=True
) -> tuple:
    """Search a population of reward-head weights with ``popl_search``.

    For each of ``resamples`` rounds a batch is drawn with
    ``sample_batch(batch_size, env_name)``, the features of both segments are
    computed with ``policy.get_features``, and the search runs for
    ``num_iterations // resamples`` steps, continuing from the previous
    population (``None`` in the first round).

    Args:
        policy: feature extractor and template for the returned reward models.
        sample_batch: callable invoked as ``sample_batch(batch_size, env_name)``;
            the batch must provide ``obs_1``, ``obs_2`` and ``label``.
        batch_size, popsize, step_stdev, downsample_level, mutation_fn, elitism:
            forwarded to the batch sampler / search routine.
        num_iterations: total number of search steps across all rounds.
        device: device for the template and the returned models.
        env_name: forwarded to ``sample_batch``.
        experiment_dir: accepted but not used here.
        resamples: number of rounds (each with a freshly sampled batch).
        num_features: feature width, used for the replacement head only when
            a transposed population member is not a 2-D weight matrix.

    Returns:
        ``(rfuncs, sorted_pop, info)``: ``sorted_pop`` is the final population
        ordered by descending total score, ``rfuncs[i]`` is a copy of ``policy``
        whose ``reward_head`` weight is ``sorted_pop[i].T``, and
        ``info["best_scores"]`` lists the best total score of every round.
    """
    policy.to(device).to(torch.float)

    # logging
    info = {}
    best_scores = []

    population = None
    resample_steps = num_iterations // resamples

    for _round in tqdm.trange(resamples):
        # Sampler is expected to label pairs by partial return.
        batch = sample_batch(batch_size, env_name)

        # Features of every observation in both segments of each pair.
        features_1 = policy.get_features(batch["obs_1"])
        features_2 = policy.get_features(batch["obs_2"])
        labels = batch["label"]

        population, scores, _ = popl_search(
            population,
            labels,
            features_1,
            features_2,
            popsize,
            resample_steps,
            step_stdev,
            True,
            downsample_level=downsample_level,
            elitism=elitism,  # honour the caller's choice instead of forcing True
            mutation_fn=mutation_fn,
            bt=None,
        )

        best = torch.max(torch.sum(scores, dim=1))
        print(
            f"best total score: {best}/{scores.shape[1]}")
        best_scores.append(best)

    # Order the population by total score, best first. torch tensors do not
    # support negative-step slicing ([::-1]), so reverse with flip().
    sorted_pop = population[torch.argsort(torch.sum(scores, dim=1)).flip(0)]

    rfuncs = []
    for member in sorted_pop:
        # Build each model from the *sorted* population so that rfuncs[i] and
        # sorted_pop[i] describe the same individual.
        weight = torch.as_tensor(
            member.T, dtype=torch.float, device=device
        ).detach().clone().contiguous()
        if weight.dim() == 2:
            # nn.Linear stores its weight as (out_features, in_features).
            out_features, in_features = weight.shape
        else:
            in_features, out_features = num_features, 1

        reward_head = nn.Linear(in_features, out_features, bias=False)
        reward_head.weight = nn.Parameter(weight)

        copy_policy = copy.deepcopy(policy)
        copy_policy.reward_head = reward_head
        rfuncs.append(copy_policy.to(device))

    info["best_scores"] = best_scores
    return rfuncs, sorted_pop, info


def sample_data(env_name, demo_file):
    """Load demonstrations from ``demo_file`` and group them by identity.

    Args:
        env_name: environment name. Names ending in ``"v2"`` keep their actions
            unchanged; every other name is treated as a MiniGrid-style
            environment whose integer actions are one-hot encoded over 7
            actions, with one all-zero row appended at the end.
        demo_file: path of a ``.npy`` file holding a sequence of dicts with the
            keys ``obs``, ``actions``, ``rewards``, ``identity`` and
            ``logprobs``.

    Returns:
        A dict mapping each identity to a list of
        ``{"obs", "actions", "logprobs", "rewards"}`` dicts. Identities 1 and 2
        are always present as keys; any other identity found in the file gets
        its own entry.
    """
    # SECURITY NOTE: allow_pickle=True unpickles arbitrary objects, which can
    # execute code. Only load demo files from trusted sources.
    demos = list(np.load(demo_file, allow_pickle=True))

    demos_by_identity = {1: [], 2: []}
    if not demos:
        return demos_by_identity

    print(f"actions: {demos[0]['actions'][0]}")

    minigrid = not env_name.endswith("v2")
    action_space = 7  # number of discrete actions assumed for MiniGrid-style envs

    for demo in demos:
        actions = demo["actions"]
        if minigrid:
            # Integer actions -> one-hot rows, plus a trailing zero action.
            one_hot = np.eye(action_space)[actions]
            actions = np.concatenate(
                [one_hot, np.zeros((1, action_space))], axis=0)

        demos_by_identity.setdefault(demo["identity"], []).append(
            {
                "obs": demo["obs"],
                "actions": actions,
                "logprobs": demo["logprobs"],
                "rewards": demo["rewards"],
            }
        )

    return demos_by_identity


def sample_all(_, env_name, identities_to_use=(1, 2)):
    """Build one batch holding every same-identity pair of whole demonstrations.

    Args:
        _: ignored; present so the function can stand in for a batch sampler
            that takes a batch size first.
        env_name: ``"TwoDoors"`` or ``"ManyDoors"``; selects the demo file under
            ``demos/``.
        identities_to_use: identities whose demos are paired with each other.

    Returns:
        A dict with ``obs_1`` / ``obs_2`` (lists of per-demo observation
        arrays), ``action_1`` / ``action_2`` (float tensors of one-hot actions
        on the module device) and ``label`` (float tensor; 0 when the first
        demo has the larger summed logprobs, 1 when the second does). Pairs
        with equal summed logprobs are skipped.

    Raises:
        ValueError: if ``env_name`` is not one of the supported environments.
    """
    demo_paths = {
        "TwoDoors": "demos/twodoors.npy",
        "ManyDoors": "demos/manyDoors.npy",
    }
    if env_name not in demo_paths:
        raise ValueError(
            f"unsupported env_name {env_name!r}; expected one of {sorted(demo_paths)}"
        )

    # SECURITY NOTE: allow_pickle=True unpickles arbitrary objects, which can
    # execute code. Only load demo files from trusted sources.
    demos = list(np.load(demo_paths[env_name], allow_pickle=True))

    print(f"size of demos: {len(demos)}")

    # Both supported environments use integer actions that are one-hot encoded.
    action_space = 7

    demos_by_identity = {1: [], 2: []}
    for demo in demos:
        one_hot = np.eye(action_space)[demo["actions"]]
        # add a zero action at the end
        actions = np.concatenate(
            [one_hot, np.zeros((1, action_space))], axis=0)

        demos_by_identity.setdefault(demo["identity"], []).append(
            {
                "obs": demo["obs"],
                "actions": actions,
                "logprobs": demo["logprobs"],
                "rewards": demo["rewards"],
            }
        )

    obs_1, acts_1 = [], []
    obs_2, acts_2 = [], []
    labels = []

    # take every pair of demos from each identity
    for identity in identities_to_use:
        for demo_1, demo_2 in itertools.combinations(
            demos_by_identity.get(identity, []), 2
        ):
            total_1 = np.sum(demo_1["logprobs"])
            total_2 = np.sum(demo_2["logprobs"])

            if total_1 > total_2:
                label = 0
            elif total_1 < total_2:
                label = 1
            else:
                # No preference can be derived from a tie; skip the pair.
                continue

            obs_1.append(demo_1["obs"])
            acts_1.append(demo_1["actions"])
            obs_2.append(demo_2["obs"])
            acts_2.append(demo_2["actions"])
            labels.append(label)

    # Stack through NumPy first; building a tensor straight from a list of
    # arrays is much slower. Demos must share a length for this to succeed.
    acts_1 = torch.as_tensor(np.asarray(acts_1, dtype=np.float32)).to(device)
    acts_2 = torch.as_tensor(np.asarray(acts_2, dtype=np.float32)).to(device)
    labels = torch.as_tensor(np.asarray(labels, dtype=np.float32)).to(device)

    # NOTE(review): observations are returned as plain lists, not tensors,
    # exactly as before — confirm whether callers expect tensors here.
    batch = {
        "obs_1": obs_1,
        "obs_2": obs_2,
        "action_1": acts_1,
        "action_2": acts_2,
        "label": labels,
    }
    return batch


def sample_batch(
    demos_by_identity,
    batch_size,
    env_name,
    identities_to_use=(1, 2),
    snippet_length=8,
    ranking="regret",
    ratio=0.5,
):
    """Sample a batch of labelled snippet pairs from same-identity demonstrations.

    Each sample picks an identity, two demos of that identity (with
    replacement), and one common start index; the two ``snippet_length``
    windows are then compared on their summed ``logprobs``
    (``ranking="regret"``) or summed ``rewards`` (``ranking="partial_return"``).
    Ties are rejected and resampled.

    NOTE: sampling does not terminate if no pair can ever produce a non-tied
    comparison (for example, a single demo per identity).

    Args:
        demos_by_identity: mapping identity -> list of demo dicts with keys
            ``obs``, ``actions``, ``logprobs`` and ``rewards``.
        batch_size: number of labelled pairs to return.
        env_name: accepted for interface compatibility; not used.
        identities_to_use: identities to draw from. With more than one entry,
            exactly two are supported and they are drawn with probabilities
            ``ratio`` and ``1 - ratio``.
        snippet_length: number of steps per snippet.
        ranking: ``"regret"`` or ``"partial_return"``.
        ratio: probability of drawing the first identity.

    Returns:
        A dict of float tensors on the module device: ``obs_1``, ``obs_2``,
        ``action_1``, ``action_2`` and ``label`` (0 when the first snippet
        scores higher, 1 when the second does).

    Raises:
        ValueError: if ``ranking`` is not a supported mode.
    """
    score_keys = {"regret": "logprobs", "partial_return": "rewards"}
    if ranking not in score_keys:
        # Fail fast: an unknown mode would otherwise never add a sample and the
        # loop below would spin forever.
        raise ValueError(
            f"unknown ranking {ranking!r}; expected one of {sorted(score_keys)}"
        )
    score_key = score_keys[ranking]

    obs_1, acts_1 = [], []
    obs_2, acts_2 = [], []
    labels = []

    while len(labels) < batch_size:
        # take a random identity, and a random pair from that identity
        if len(identities_to_use) == 1:
            identity = identities_to_use[0]
        else:
            identity = np.random.choice(
                identities_to_use, p=[ratio, 1 - ratio])

        demo_1 = np.random.choice(demos_by_identity[identity])
        demo_2 = np.random.choice(demos_by_identity[identity])

        # One start index valid for both demos: the smaller of two draws.
        start_point = min(
            np.random.randint(0, len(demo_1["obs"]) - snippet_length),
            np.random.randint(0, len(demo_2["obs"]) - snippet_length),
        )
        window = slice(start_point, start_point + snippet_length)

        total_1 = np.sum(demo_1[score_key][window])
        total_2 = np.sum(demo_2[score_key][window])

        if total_1 > total_2:
            label = 0
        elif total_1 < total_2:
            label = 1
        else:
            # No preference can be derived from a tie; draw another pair.
            continue

        obs_1.append(demo_1["obs"][window])
        acts_1.append(demo_1["actions"][window])
        obs_2.append(demo_2["obs"][window])
        acts_2.append(demo_2["actions"][window])
        labels.append(label)

    # Stack through NumPy first; building a tensor straight from a list of
    # arrays is much slower.
    batch = {
        "obs_1": torch.as_tensor(np.asarray(obs_1, dtype=np.float32)).to(device),
        "obs_2": torch.as_tensor(np.asarray(obs_2, dtype=np.float32)).to(device),
        "action_1": torch.as_tensor(np.asarray(acts_1, dtype=np.float32)).to(device),
        "action_2": torch.as_tensor(np.asarray(acts_2, dtype=np.float32)).to(device),
        "label": torch.as_tensor(np.asarray(labels, dtype=np.float32)).to(device),
    }
    return batch
