# Copyright (C) king.com Ltd 2025
# License: Apache 2.0
#
# NOTE: This code was produced quickly for the conference deadline and needs refactoring. Aiming to provide a nicer, updated version in the future.
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from prompt_dt.prompt_evaluate_episodes import prompt_evaluate_episode_rtg
from prompt_dt.prompt_utils import discount_cumsum


class RewardModel(nn.Module):
    """Small fully connected network that maps a feature vector to one scalar.

    The network is a stack of ``Linear -> ReLU`` pairs, one pair per entry of
    ``hidden_arch``, followed by a final ``Linear`` layer with a single output
    and no activation.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_arch: Iterable with the width of each hidden layer, in order.
            An empty iterable yields a purely linear model.
    """

    # The default is a tuple rather than a list so that the default object
    # cannot be mutated and shared between instances.
    def __init__(self, input_dim, hidden_arch=(16,)):
        super().__init__()

        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_arch:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            prev_dim = hidden_dim

        # Output head: one scalar per input row.
        layers.append(nn.Linear(prev_dim, 1))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``; the last output dimension has size 1."""
        return self.network(x)


class NeuralBandit:
    """Bandit that keeps one independent reward model per prompt slot.

    For each of the ``n_segments`` slots a separate ``RewardModel`` is trained
    to regress the observed (shared) reward from the feature vector of the
    segment that was placed in that slot.

    Args:
        input_dim: Feature dimension of one segment.
        n_segments: Number of prompt slots, i.e. number of reward models.
        segment_length: Stored on the instance; not used by this class itself.
        regularization_lambda: Passed to SGD as ``weight_decay``.
        learning_rate: SGD learning rate.
        steps_per_update: Number of full-batch gradient steps per model in
            each call to :meth:`update_model`.
        rm_hidden_arch: Hidden layer widths forwarded to ``RewardModel``.
    """

    # ``rm_hidden_arch`` defaults to a tuple (not a list) so the default
    # object is immutable and cannot be shared/mutated across instances.
    def __init__(self, input_dim, n_segments, segment_length, regularization_lambda=1.0, learning_rate=0.01,
                 steps_per_update=100, rm_hidden_arch=(16,)):
        self.n_segments = n_segments
        self.segment_length = segment_length

        self.segment_reward_models = []
        self.initial_weights = []
        self.optimizers = []
        self.input_dim = input_dim
        for _ in range(n_segments):
            reward_model = RewardModel(input_dim, hidden_arch=rm_hidden_arch)

            # Snapshot of the freshly initialised weights, kept so callers can
            # reset a model to its starting point before re-training.
            initial_weights = {k: v.clone().detach() for k, v in reward_model.state_dict().items()}
            optimizer = optim.SGD(reward_model.parameters(), lr=learning_rate, weight_decay=regularization_lambda)

            self.initial_weights.append(initial_weights)
            self.optimizers.append(optimizer)
            self.segment_reward_models.append(reward_model)

        self.regularization_lambda = regularization_lambda
        self.learning_rate = learning_rate
        self.steps_per_update = steps_per_update

        self.loss_fn = nn.MSELoss()
        # segments[j] holds the features observed in slot j, aligned by index
        # with self.rewards.
        self.segments = [[] for _ in range(n_segments)]
        self.rewards = []

    def store_observation(self, segments, reward):
        """Store one observation: the segment used in each slot and the reward.

        Args:
            segments: Sequence with exactly ``n_segments`` feature vectors.
            reward: Reward observed for the prompt built from those segments.
        """
        assert len(segments) == self.n_segments
        for segment_idx, segment in enumerate(segments):
            self.segments[segment_idx].append(segment)
        self.rewards.append(reward)

    def update_model(self):
        """Train every per-slot reward model on all stored observations.

        Returns:
            ``None`` when fewer than two observations are stored; otherwise a
            list with one entry per slot, each a list of the MSE loss values
            of the ``steps_per_update`` gradient steps.
        """
        if len(self.rewards) < 2:
            return

        all_segment_model_losses = [[] for _ in range(self.n_segments)]
        for segment_idx in range(self.n_segments):

            # Re-shuffle the stored data before fitting each model. The same
            # permutation is applied to the rewards and to every slot's
            # segments so the observations stay aligned.
            shuffled_indices = np.random.permutation(len(self.rewards))
            self.rewards = [self.rewards[i] for i in shuffled_indices]
            for segment_shuffle_idx in range(self.n_segments):
                self.segments[segment_shuffle_idx] = [self.segments[segment_shuffle_idx][i] for i in shuffled_indices]

            contexts_tensor = torch.from_numpy(np.array(self.segments[segment_idx])).float()
            rewards_tensor = torch.from_numpy(np.array(self.rewards)).float()

            losses = []
            for step in range(self.steps_per_update):
                self.optimizers[segment_idx].zero_grad()

                # Full-batch forward pass and MSE against the observed rewards.
                predictions = self.segment_reward_models[segment_idx](contexts_tensor)
                loss = self.loss_fn(predictions.squeeze(), rewards_tensor.squeeze())
                losses.append(loss.item())

                loss.backward()
                # Clip the gradient norm to 10 before the SGD step.
                torch.nn.utils.clip_grad_norm_(self.segment_reward_models[segment_idx].parameters(), 10.0)
                self.optimizers[segment_idx].step()

            all_segment_model_losses[segment_idx].extend(losses)

        return all_segment_model_losses

    def predict(self, context):
        """Score ``context`` with every per-slot reward model.

        Args:
            context: Array-like of candidate feature vectors; converted with
                ``np.array`` and cast to ``float32``.

        Returns:
            A list with ``n_segments`` tensors, the squeezed output of each
            slot's model. Computed without gradient tracking.
        """
        segment_preds = []
        context_tensor = torch.from_numpy(np.array(context)).to(torch.float32)
        with torch.no_grad():
            for segment_idx in range(self.n_segments):
                reward_pred = self.segment_reward_models[segment_idx](context_tensor).squeeze()
                segment_preds.append(reward_pred)

        return segment_preds


class SingleArmContextualTS:
    """
    Thompson-sampling model for a single arm (one per prompt slot).

    The reward is modelled as linear in the context (the prompt segment
    features) with a Gaussian posterior over the weight vector θ and a known
    observation-noise variance.
    """
    def __init__(self, context_dim, prior_variance=1.0, noise_variance=1.0):
        self.noise_variance = noise_variance  # known observation noise σ²
        self.Sigma = np.eye(context_dim) * prior_variance  # posterior covariance of θ (starts at the prior)
        self.mu = np.zeros(context_dim)  # posterior mean of θ (starts at the prior)

    def select(self, context):
        """Draw θ from the current posterior and score ``context`` with it.

        Returns a pair ``(θ·context, θ)`` for the sampled θ.
        """
        sampled_theta = np.random.multivariate_normal(self.mu, self.Sigma)
        return np.dot(sampled_theta, context), sampled_theta

    def update(self, context, reward):
        """Fold one ``(context, reward)`` observation into the posterior."""
        inv_noise = 1 / self.noise_variance
        column = context.reshape(-1, 1)

        # Work in precision (inverse covariance) form for the rank-one update.
        prior_precision = np.linalg.inv(self.Sigma)
        posterior_precision = prior_precision + inv_noise * (column @ column.T)
        posterior_cov = np.linalg.inv(posterior_precision)

        # Precision-weighted prior mean plus the new evidence.
        weighted_mean = prior_precision @ self.mu + inv_noise * reward * context

        self.Sigma = posterior_cov
        self.mu = posterior_cov @ weighted_mean


class UCBReplayBuffer:
    """Fixed-capacity ring buffer of (context, reward) pairs stored in NumPy arrays.

    Once ``capacity`` entries have been written, new entries overwrite the
    oldest slot.
    """

    def __init__(self, d, capacity):
        contexts = np.zeros((capacity, d))
        rewards = np.zeros((capacity, 1))
        self.buffer = {'context': contexts, 'reward': rewards}
        self.capacity = capacity
        self.size = 0      # number of slots filled so far (at most capacity)
        self.pointer = 0   # index of the slot the next add() writes to

    def add(self, context, reward):
        """Write one pair at the current pointer and advance it, wrapping around."""
        slot = self.pointer
        self.buffer['context'][slot] = context
        self.buffer['reward'][slot] = reward
        self.pointer = (slot + 1) % self.capacity
        self.size = min(self.capacity, self.size + 1)

    def sample(self, n):
        """Return ``n`` stored pairs drawn uniformly with replacement from the filled slots."""
        picks = np.random.randint(0, self.size, size=n)
        contexts = self.buffer['context'][picks]
        rewards = self.buffer['reward'][picks]
        return contexts, rewards


class NeuralUCB:
    """Neural UCB bandit: a reward network plus a gradient-based exploration bonus.

    Each candidate is scored as ``net(x) + beta * sqrt(g^T Sigma^-1 g)`` where
    ``g`` is the (scaled) gradient of the network output with respect to its
    parameters. ``Sigma^-1`` is maintained with rank-one Sherman-Morrison
    updates, and the network is re-fitted from its initial weights on the
    replay buffer after every observation.

    Args:
        d: Dimension of one context vector.
        K: Number of candidate contexts scored by :meth:`take_action`.
        beta: Weight of the exploration bonus.
        lamb: Scale of the initial inverse covariance, which is ``lamb * I``.
            NOTE(review): if ``lamb`` is meant as the ridge parameter of
            ``Sigma = lamb * I`` the inverse would be ``I / lamb``; both agree
            for the default of 1 - confirm the intended meaning.
        hidden_size: Width of the single hidden layer of the reward network.
        lr: Adam learning rate used when re-fitting the network.
        reg: Weight of the squared-distance-to-initial-weights penalty.
        device: Torch device for the network and its inputs.
    """

    def __init__(self, d, K, beta=1, lamb=1, hidden_size=16, lr=1e-4, reg=0.000625, device='cpu'):
        self.d = d
        self.K = K
        self.T = 0  # number of update() calls so far
        self.reg = reg
        self.beta = beta
        self.net = RewardModel(d, hidden_arch=[hidden_size])
        # Snapshot of the initial weights; train() restarts from these.
        self.initial_param = {k: v.clone().detach() for k, v in self.net.state_dict().items()}
        self.hidden_size = hidden_size
        self.net.to(device)
        self.lr = lr
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)
        self.numel = sum(w.numel() for w in self.net.parameters() if w.requires_grad)
        self.sigma_inv = lamb * np.eye(self.numel, dtype=np.float32)
        self.device = device
        self.loss_hist = []
        self.reward_pred_hist = []
        self.ucb_pred_hist = []

        # Flattened copy of the initial parameters used as the anchor of the
        # regulariser. It must be detached: if it stays attached to the live
        # parameters, the penalty's gradient flows back through the anchor
        # with the opposite sign and cancels to zero.
        self.theta0 = torch.cat(
            [w.detach().flatten() for w in self.net.parameters() if w.requires_grad]
        )
        self.replay_buffer = UCBReplayBuffer(d, 10000)

    def take_action(self, context):
        """Return the index of the candidate with the highest UCB score.

        Args:
            context: Array-like with ``K`` rows, one context vector per row.

        Side effects:
            Appends the exploration term and the predicted reward of this call
            to ``ucb_pred_hist`` and ``reward_pred_hist``.
        """
        context = torch.tensor(context, dtype=torch.float32)
        context = context.to(self.device)

        # Per-candidate gradient of the network output w.r.t. its parameters.
        g = np.zeros((self.K, self.numel), dtype=np.float32)
        for k in range(self.K):
            g[k] = self.grad(context[k].unsqueeze(0)).cpu().numpy()

        with torch.no_grad():
            reward_model_term = self.net(context).cpu().numpy()

        # Quadratic form g^T Sigma^-1 g per candidate, shape (K, 1). Clamp at
        # zero so float round-off cannot produce NaN under the square root.
        quad_form = np.matmul(np.matmul(g[:, None, :], self.sigma_inv), g[:, :, None])[:, 0, :]
        ucb_term = np.sqrt(np.maximum(quad_form, 0.0))
        p = reward_model_term + self.beta * ucb_term

        self.ucb_pred_hist.append(ucb_term)
        self.reward_pred_hist.append(reward_model_term)

        action = np.argmax(p)
        return action

    def grad(self, x):
        """Return the flattened parameter gradient of ``net(x)``, divided by ``sqrt(hidden_size)``.

        ``x`` must produce a single-element output so that ``backward()`` can
        be called without an explicit gradient argument.
        """
        self.optimizer.zero_grad()
        y = self.net(x)
        y.backward()
        return torch.cat(
            [w.grad.detach().flatten() / np.sqrt(self.hidden_size) for w in self.net.parameters() if w.requires_grad]
        ).to(self.device)

    def update(self, context, action, reward):
        """Record one observation, update ``sigma_inv`` and re-fit the network.

        Args:
            context: Context vector of shape ``(d,)`` that was played.
            action: Unused by this method; kept for interface compatibility.
            reward: Observed reward for ``context``.
        """
        context = torch.tensor(context, dtype=torch.float32)
        context = context.to(self.device)

        to_add = context
        assert to_add.shape == (self.d,)

        to_add_unsqueeze = context.unsqueeze(0)
        assert to_add_unsqueeze.shape == (1, self.d), to_add_unsqueeze.shape

        grad = self.grad(to_add_unsqueeze)
        assert grad.shape == (self.numel,), grad.shape

        grad_unsqueeze = grad[:, None]
        assert grad_unsqueeze.shape == (self.numel, 1), grad_unsqueeze

        self.sherman_morrison_update(grad_unsqueeze.cpu().numpy())
        self.replay_buffer.add(to_add.cpu().numpy(), reward)

        self.T += 1
        self.train()

    def sherman_morrison_update(self, v):
        """Apply the rank-one update for ``Sigma += v v^T`` directly to ``sigma_inv``.

        ``v`` is a column vector of shape ``(numel, 1)``.
        """
        self.sigma_inv -= (self.sigma_inv @ v @ v.T @ self.sigma_inv) / (1+v.T @ self.sigma_inv @ v)

    def train(self):
        """Re-fit the network from its initial weights on the stored observations.

        Runs 100 full-batch Adam steps on the MSE plus ``reg`` times the
        squared distance of the parameters from their initial values. Only the
        filled part of the replay buffer is used; with fewer than two stored
        observations the network is just reset and no step is taken.
        """
        self.net.load_state_dict({k: v.clone() for k, v in self.initial_param.items()})
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)

        # Use the number of stored entries, not the length of the backing
        # array (which is the capacity and would include unfilled zero rows).
        n_stored = self.replay_buffer.size
        if n_stored < 2:
            return

        # The batch is the same for every step, so build the tensors once.
        # No shuffling is needed: the loss is a mean over the full batch.
        x = torch.tensor(self.replay_buffer.buffer['context'][:n_stored], dtype=torch.float32).to(self.device)
        y = torch.tensor(self.replay_buffer.buffer['reward'][:n_stored], dtype=torch.float32).to(self.device).view(-1, 1)

        for _ in range(100):
            y_hat = self.net(x)
            loss = F.mse_loss(y_hat, y)
            theta = torch.cat([w.flatten() for w in self.net.parameters() if w.requires_grad])
            loss = loss + self.reg * torch.norm(theta - self.theta0) ** 2
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.loss_hist.append(loss.item())


def all_possible_segments(
        model, expert_prompt_trajs, rtg_scale, state_dim, act_dim, traj_prompt_j, traj_prompt_h, device,
        context_len, eval_batch_size, bandit_use_transformer_features=False, bandit_feature_dim=None):
    """Enumerate every length-``traj_prompt_h`` window of the expert trajectories.

    For each window a flat "raw" feature list is built and, optionally, a
    feature vector is extracted from the transformer ``model``.

    Args:
        model: Model whose ``forward(..., prompt=..., features=True)`` returns
            six values, the last being per-token action features.
        expert_prompt_trajs: Sequence of trajectory dicts with the keys
            ``"observations"``, ``"actions"`` and ``"rewards"``.
        rtg_scale: Divisor applied to the returns-to-go.
        state_dim: Observation dimension.
        act_dim: Action dimension.
        traj_prompt_j: Number of segments in a prompt; the single window is
            repeated this many times to give the prompt the expected length.
        traj_prompt_h: Window (segment) length in timesteps.
        device: Torch device for the tensors passed to ``model``.
        context_len: Length of the dummy (all-zero) context fed to ``model``.
        eval_batch_size: Batch size of the dummy context.
        bandit_use_transformer_features: Whether to compute transformer features.
        bandit_feature_dim: Expected size of one transformer feature vector;
            only used in a shape assertion.

    Returns:
        Tuple ``(segments_raw, segments_features, segment_idx, traj_idxs)``:
        the raw feature list per window, the transformer feature vector per
        window (empty when transformer features are disabled), the
        ``(start, end)`` indices of each window, and the index of the
        trajectory each window was taken from.
    """
    segments_raw = []
    segments_features = []
    segment_idx = []
    traj_idxs = []
    for traj_prompt_segment_idx, prompt_traj in enumerate(expert_prompt_trajs):
        return_to_go = discount_cumsum(prompt_traj["rewards"], gamma=1.)

        for i in range(0, len(prompt_traj["observations"]) - traj_prompt_h + 1):
            segment_observations = prompt_traj["observations"][i:i + traj_prompt_h]
            segment_actions = prompt_traj["actions"][i:i + traj_prompt_h]
            segment_rtgs = return_to_go[i:i + traj_prompt_h]
            segment_idx.append((i, i + traj_prompt_h))
            segment_timesteps = np.arange(i, i + traj_prompt_h, step=1)
            segment_masks = np.ones((traj_prompt_h))

            traj_idxs.append(traj_prompt_segment_idx)

            segment_rtgs = segment_rtgs / rtg_scale

            # Flatten the window timestep by timestep in the order
            # (observation, action, return-to-go): s, a, rtg, s, a, rtg, ...
            arm_features_raw = []
            for t in range(traj_prompt_h):
                if state_dim > 1:
                    arm_features_raw.extend(segment_observations[t])
                else:
                    arm_features_raw.append(segment_observations[t])

                if act_dim > 1:
                    arm_features_raw.extend(segment_actions[t])
                else:
                    arm_features_raw.append(segment_actions[t])
                # Indexing [t][0] expects one return-to-go row per timestep
                # - presumably shape (T, 1); TODO confirm against discount_cumsum.
                arm_features_raw.append(segment_rtgs[t][0])
            segments_raw.append(arm_features_raw)

            if bandit_use_transformer_features:
                # Build a batch-of-one prompt from this window.
                segment_timestep_tensor = torch.from_numpy(segment_timesteps).to(torch.int64).to(device).reshape(1, traj_prompt_h)
                segment_masks_tensor = torch.from_numpy(segment_masks).to(torch.int64).to(device).reshape(1, traj_prompt_h)

                segment_states_tensor = torch.from_numpy(segment_observations).to(torch.float32).to(device).reshape(1, traj_prompt_h, state_dim)
                segment_actions_tensor = torch.from_numpy(segment_actions).to(torch.float32).to(device).reshape(1, traj_prompt_h, act_dim)
                segment_rtgs_tensor = torch.from_numpy(segment_rtgs).to(torch.float32).to(device).reshape(1, traj_prompt_h, 1)

                # Repeat the window traj_prompt_j times so the prompt has the
                # length of a full multi-segment prompt; only the features of
                # the first copy are read below.
                segment_timestep_tensor = segment_timestep_tensor.repeat(1, traj_prompt_j)
                segment_masks_tensor = segment_masks_tensor.repeat(1, traj_prompt_j)
                segment_states_tensor = segment_states_tensor.repeat(1, traj_prompt_j, 1)
                segment_actions_tensor = segment_actions_tensor.repeat(1, traj_prompt_j, 1)
                segment_rtgs_tensor = segment_rtgs_tensor.repeat(1, traj_prompt_j, 1)

                prompt = (
                    segment_states_tensor, segment_actions_tensor, None, None, segment_rtgs_tensor,
                    segment_timestep_tensor,
                    segment_masks_tensor)

                # Pure feature extraction: the result is detached right away,
                # so skip building the autograd graph for every window.
                with torch.no_grad():
                    _, _, _, _, _, action_features = model.forward(
                        torch.zeros((eval_batch_size, context_len, state_dim), dtype=torch.float32, device=device),
                        torch.zeros((eval_batch_size, context_len, act_dim), dtype=torch.float32, device=device),
                        None,
                        torch.zeros((eval_batch_size, context_len, 1), dtype=torch.float32, device=device),
                        torch.arange(start=0, end=context_len, step=1).repeat(eval_batch_size, 1).to(device),
                        attention_mask=torch.ones((context_len)).repeat(eval_batch_size, 1).to(device),
                        prompt=prompt, features=True,
                    )

                assert action_features.shape == (
                    1, traj_prompt_j * traj_prompt_h + context_len, bandit_feature_dim)
                # Feature at the last timestep of the first copy of the window.
                segment_features_transformer = action_features[0, traj_prompt_h - 1, :].cpu().detach().numpy()
                segments_features.append(segment_features_transformer)

    return segments_raw, segments_features, segment_idx, traj_idxs


def select_segments(
        mab, eval_batch_size, expert_prompt_trajs, traj_idxs, segments_raw, segments_features, segment_idx,
        epsilon, state_dim, act_dim, device, traj_prompt_j, traj_prompt_h,
        bandit_use_transformer_features, info, variant):
    """Pick one segment per prompt slot (epsilon-greedy) and build the prompt tensors.

    With probability ``epsilon`` every slot gets a uniformly random segment;
    otherwise each slot gets the segment with the highest reward predicted by
    that slot's model in ``mab``.

    Returns:
        ``(states, actions, rtgs, timesteps, mask, selected_segment_idxs,
        epsilon)`` where the five tensors are reshaped to
        ``eval_batch_size`` rows of ``traj_prompt_j * traj_prompt_h`` steps.
    """
    max_ep_len, state_mean, state_std, scale = info['max_ep_len'], info['state_mean'], info['state_std'], info['scale']

    explore = np.random.rand() < epsilon
    if explore:
        n_candidates = len(segments_raw)
        selected_segment_idxs = [np.random.randint(n_candidates) for _ in range(mab.n_segments)]
    else:
        candidates = segments_features if bandit_use_transformer_features else segments_raw
        with torch.no_grad():
            rewards_pred = mab.predict(candidates)

        # Rows: candidates, columns: prompt slots; take the best row per slot.
        rewards_pred = torch.stack(rewards_pred).numpy().T
        selected_segment_idxs = np.argmax(rewards_pred, axis=0)
        assert len(selected_segment_idxs) == mab.n_segments
    selected_segments_start_end = [segment_idx[arm_idx] for arm_idx in selected_segment_idxs]

    s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []

    for seg_idx, (start, end) in zip(selected_segment_idxs, selected_segments_start_end):
        prompt_traj = expert_prompt_trajs[traj_idxs[seg_idx]]

        # Slice the chosen window out of the trajectory.
        obs = prompt_traj['observations'][start:end].reshape(1, -1, state_dim)
        act = prompt_traj['actions'][start:end].reshape(1, -1, act_dim)
        rew = prompt_traj['rewards'][start:end].reshape(1, -1, 1)
        done_key = 'terminals' if 'terminals' in prompt_traj else 'dones'
        done = prompt_traj[done_key][start:end].reshape(1, -1)

        tlen = obs.shape[1]
        steps = np.arange(start, start + tlen).reshape(1, -1)
        steps[steps >= max_ep_len] = max_ep_len - 1
        returns = discount_cumsum(prompt_traj['rewards'][start:], gamma=1.)[:tlen].reshape(1, -1, 1)

        # Left-pad to the nominal window length, then normalise states and returns.
        n_pad = (end - start) - tlen
        obs = np.concatenate([np.zeros((1, n_pad, state_dim)), obs], axis=1)
        if not variant['no_state_normalize']:
            obs = (obs - state_mean) / state_std
        act = np.concatenate([np.ones((1, n_pad, act_dim)) * -10., act], axis=1)
        rew = np.concatenate([np.zeros((1, n_pad, 1)), rew], axis=1)
        done = np.concatenate([np.ones((1, n_pad)) * 2, done], axis=1)
        returns = np.concatenate([np.zeros((1, n_pad, 1)), returns], axis=1) / scale
        steps = np.concatenate([np.zeros((1, n_pad)), steps], axis=1)
        valid = np.concatenate([np.zeros((1, n_pad)), np.ones((1, tlen))], axis=1)

        s.append(obs)
        a.append(act)
        r.append(rew)
        d.append(done)
        rtg.append(returns)
        timesteps.append(steps)
        mask.append(valid)

    s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device)
    a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)
    rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device)
    timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)
    mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)

    # Lay the selected segments out back to back along the time axis.
    prompt_len = traj_prompt_j * traj_prompt_h
    s = s.reshape(eval_batch_size, prompt_len, state_dim)
    a = a.reshape(eval_batch_size, prompt_len, act_dim)
    rtg = rtg.reshape(eval_batch_size, prompt_len, 1)
    timesteps = timesteps.reshape(eval_batch_size, prompt_len)
    mask = mask.reshape(eval_batch_size, prompt_len)
    return s, a, rtg, timesteps, mask, selected_segment_idxs, epsilon


def rollout_bandit(
        env, model, max_test_ep_len, rtg_target, rtg_scale, act_dim, state_dim, device,
        traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs,
        traj_prompt_masks, info, variant, traj_prompt_j, traj_prompt_h):
    """Evaluate one episode with the given prompt and return its result.

    Returns:
        ``(prompt_state_segments, ret)``: the prompt states of batch row 0
        split into ``traj_prompt_j`` NumPy chunks of ``traj_prompt_h`` steps,
        and the first value returned by ``prompt_evaluate_episode_rtg``.
    """
    prompt = (
        traj_prompt_states, traj_prompt_actions, None, None, traj_prompt_rtgs, traj_prompt_timesteps, traj_prompt_masks)

    # Evaluation only: no gradients are needed for the rollout.
    with torch.no_grad():
        eval_kwargs = dict(
            max_ep_len=max_test_ep_len,
            scale=rtg_scale,
            target_return=rtg_target[0] / rtg_scale,
            state_mean=info['state_mean'],
            state_std=info['state_std'],
            device=device,
            prompt=prompt,
            no_r=variant['no_r'],
            no_rtg=variant['no_rtg'],
            no_state_normalize=variant['no_state_normalize'],
        )
        ret, _ = prompt_evaluate_episode_rtg(env, state_dim, act_dim, model, **eval_kwargs)

    # Only batch row 0 is read: this code handles a single prompt per call.
    prompt_state_segments = [
        traj_prompt_states[0, k * traj_prompt_h:(k + 1) * traj_prompt_h].cpu().numpy()
        for k in range(traj_prompt_j)
    ]

    return prompt_state_segments, ret


def update_mab(mab, sparse_reward_sum, selected_segment_idxs, segments_raw, segments_features,
               bandit_use_transformer_features):
    """Record the latest observation in ``mab`` and re-train it from scratch.

    The features of the selected segments (transformer features or raw
    features, depending on the flag) are stored together with the reward.
    Every reward model is then reset to its initial weights, switched to
    training mode, and ``mab.update_model()`` is run.

    Returns:
        Whatever ``mab.update_model()`` returns.
    """
    feature_source = segments_features if bandit_use_transformer_features else segments_raw
    store_segments = [feature_source[chosen] for chosen in selected_segment_idxs]
    mab.store_observation(store_segments, sparse_reward_sum)

    # Restart every reward model from its stored initial weights.
    for model_idx, reward_model in enumerate(mab.segment_reward_models):
        fresh_state = {name: tensor.clone() for name, tensor in mab.initial_weights[model_idx].items()}
        reward_model.load_state_dict(fresh_state)

    for reward_model in mab.segment_reward_models:
        reward_model.train()

    return mab.update_model()


def prompt_tuning_bandit(model, prompt_trajectory, env, info, variant, env_name, wandb):
    """Tune the prompt with an epsilon-greedy neural bandit over trajectory segments.

    In every rollout the bandit selects one segment per prompt slot, the
    resulting prompt is evaluated for one episode, and the bandit is re-trained
    on the observed return.

    Args:
        model: Prompt-conditioned policy model.
        prompt_trajectory: Sequence of expert trajectory dicts to cut segments from.
        env: Environment used for the rollouts.
        info: Dict read for ``scale``, ``env_targets``, ``max_ep_len``,
            ``act_dim``, ``state_dim`` (and state statistics further down).
        variant: Run configuration; read for ``seed``, ``eval_rollouts``,
            ``K``, ``num_traj_prompt_j``, ``prompt_length``,
            ``bandit_use_transformer_features``, ``embed_dim``,
            ``bandit_arch`` and ``bandit_epsilon``.
        env_name: Name used as a prefix/suffix in the log keys.
        wandb: Object with a ``log(dict)`` method, called once per rollout.

    Returns:
        ``(rollouts_rewards, logs)``: the list of episode returns and the
        log dict of the final rollout.
    """
    logs = dict()
    seed = variant['seed']
    print("Running bandit on seed", seed)
    # Seed all three RNGs with the configured seed. Calling random.seed()
    # without an argument would seed from OS entropy instead.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    eval_batch_size = 1
    num_rollouts = variant["eval_rollouts"]
    # Linear schedule that reaches epsilon = 0 after the first 20% of the
    # rollouts. Guarded so that zero rollouts does not divide by zero.
    epsilon_decay = 1.0 / (num_rollouts * 0.2) if num_rollouts > 0 else 0.0

    rtg_scale = info["scale"]
    rtg_target = info["env_targets"]
    max_test_ep_len = info["max_ep_len"]
    act_dim = info["act_dim"]
    state_dim = info["state_dim"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    context_len = variant["K"]
    traj_prompt_j = variant["num_traj_prompt_j"]
    traj_prompt_h = variant['prompt_length']
    bandit_use_transformer_features = variant['bandit_use_transformer_features']

    if bandit_use_transformer_features:
        bandit_feature_dim = variant['embed_dim']
    else:
        # Raw features: one observation, one action and one return-to-go
        # scalar per timestep of a segment.
        bandit_feature_dim = traj_prompt_h * (state_dim + act_dim + 1)

    rollouts_rewards = []

    mab = NeuralBandit(
        input_dim=bandit_feature_dim,
        n_segments=traj_prompt_j,
        segment_length=traj_prompt_h * (state_dim + act_dim + 1),
        regularization_lambda=0.0,
        learning_rate=0.01,
        steps_per_update=1000,
        rm_hidden_arch=variant["bandit_arch"]
    )

    # Candidate segments are enumerated once and reused for every rollout.
    segments_raw, segments_features, segment_idx, traj_idxs = all_possible_segments(
        model, prompt_trajectory, rtg_scale, state_dim, act_dim, traj_prompt_j, traj_prompt_h,
        device, context_len, eval_batch_size, bandit_use_transformer_features, bandit_feature_dim)

    for rollout_idx in range(num_rollouts):
        if variant["bandit_epsilon"] == "schedule":
            epsilon = max(0, 1 - (epsilon_decay * rollout_idx))
        else:
            epsilon = float(variant["bandit_epsilon"])

        # Bandit selects the segments that make up the prompt.
        traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs, traj_prompt_timesteps, traj_prompt_masks, selected_segment_idxs, epsilon = select_segments(
            mab, eval_batch_size, prompt_trajectory, traj_idxs, segments_raw, segments_features, segment_idx, epsilon,
            state_dim, act_dim, device, traj_prompt_j, traj_prompt_h, bandit_use_transformer_features,
            info, variant)

        # Evaluate the selected prompt for one episode.
        _, ret = rollout_bandit(env, model, max_test_ep_len, rtg_target,
                                rtg_scale,
                                act_dim, state_dim, device, traj_prompt_timesteps,
                                traj_prompt_states,
                                traj_prompt_actions, traj_prompt_rtgs,
                                traj_prompt_masks, info, variant, traj_prompt_j, traj_prompt_h)

        rollouts_rewards.append(ret)
        print(f"{rollout_idx}: {ret}, (epsilon: {epsilon})")

        # Update the bandit with the return divided by 10.
        # NOTE(review): the rationale for the fixed factor 10 is not visible here.
        train_losses = update_mab(mab, ret / 10, selected_segment_idxs, segments_raw, segments_features,
                                  bandit_use_transformer_features)

        # update_mab returns None until the bandit has at least two observations.
        if train_losses is not None:
            logs[f'training/{env_name}_train_loss_mean'] = np.mean(train_losses)
            logs[f'training/{env_name}_train_loss_std'] = np.std(train_losses)

        logs.update({f'{env_name}_target_0_return_mean': ret})
        wandb.log(logs)

    return rollouts_rewards, logs


def select_segments_thompson(
        mab, eval_batch_size, expert_prompt_trajs, traj_idxs, segments_raw, segments_features, segment_idx,
        state_dim, act_dim, device, traj_prompt_j, traj_prompt_h,
        bandit_use_transformer_features, info, variant):
    """Pick one segment per prompt slot by Thompson sampling and build the prompt tensors.

    ``mab`` is indexable by slot; for every slot each candidate segment is
    scored with that slot's ``select`` and the highest-scoring one is used.

    Returns:
        ``(states, actions, rtgs, timesteps, mask, selected_segment_idxs,
        None)``; the last entry takes the place of epsilon in the
        epsilon-greedy variant.
    """
    max_ep_len, state_mean, state_std, scale = info['max_ep_len'], info['state_mean'], info['state_std'], info['scale']

    # One sampled score per (slot, candidate); slots outer, candidates inner.
    candidates = segments_features if bandit_use_transformer_features else segments_raw
    per_segment_preds = [
        [mab[segment_idx_j].select(segment)[0] for segment in candidates]
        for segment_idx_j in range(traj_prompt_j)
    ]

    rewards_pred = np.array(per_segment_preds).T
    selected_segment_idxs = np.argmax(rewards_pred, axis=0)
    assert len(selected_segment_idxs) == traj_prompt_j
    selected_segments_start_end = [segment_idx[arm_idx] for arm_idx in selected_segment_idxs]
    print(f"MAB: Selecting segments {selected_segment_idxs}, start end: {selected_segments_start_end}")

    s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []

    for seg_idx, (start, end) in zip(selected_segment_idxs, selected_segments_start_end):
        prompt_traj = expert_prompt_trajs[traj_idxs[seg_idx]]

        # Slice the chosen window out of the trajectory.
        obs = prompt_traj['observations'][start:end].reshape(1, -1, state_dim)
        act = prompt_traj['actions'][start:end].reshape(1, -1, act_dim)
        rew = prompt_traj['rewards'][start:end].reshape(1, -1, 1)
        done_key = 'terminals' if 'terminals' in prompt_traj else 'dones'
        done = prompt_traj[done_key][start:end].reshape(1, -1)

        tlen = obs.shape[1]
        steps = np.arange(start, start + tlen).reshape(1, -1)
        steps[steps >= max_ep_len] = max_ep_len - 1  # padding cutoff
        returns = discount_cumsum(prompt_traj['rewards'][start:], gamma=1.)[:tlen].reshape(1, -1, 1)

        # Left-pad to the nominal window length, then normalise states and returns.
        n_pad = (end - start) - tlen
        obs = np.concatenate([np.zeros((1, n_pad, state_dim)), obs], axis=1)
        if not variant['no_state_normalize']:
            obs = (obs - state_mean) / state_std
        act = np.concatenate([np.ones((1, n_pad, act_dim)) * -10., act], axis=1)
        rew = np.concatenate([np.zeros((1, n_pad, 1)), rew], axis=1)
        done = np.concatenate([np.ones((1, n_pad)) * 2, done], axis=1)
        returns = np.concatenate([np.zeros((1, n_pad, 1)), returns], axis=1) / scale
        steps = np.concatenate([np.zeros((1, n_pad)), steps], axis=1)
        valid = np.concatenate([np.zeros((1, n_pad)), np.ones((1, tlen))], axis=1)

        s.append(obs)
        a.append(act)
        r.append(rew)
        d.append(done)
        rtg.append(returns)
        timesteps.append(steps)
        mask.append(valid)

    s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device)
    a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)
    rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device)
    timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)
    mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)

    # Lay the selected segments out back to back along the time axis.
    prompt_len = traj_prompt_j * traj_prompt_h
    s = s.reshape(eval_batch_size, prompt_len, state_dim)
    a = a.reshape(eval_batch_size, prompt_len, act_dim)
    rtg = rtg.reshape(eval_batch_size, prompt_len, 1)
    timesteps = timesteps.reshape(eval_batch_size, prompt_len)
    mask = mask.reshape(eval_batch_size, prompt_len)

    epsilon = None
    return s, a, rtg, timesteps, mask, selected_segment_idxs, epsilon


def prompt_tuning_thompson(model, prompt_trajectory, env, info, variant, env_name, wandb):
    """Tune the prompt with per-slot linear Thompson sampling over trajectory segments.

    In every rollout each prompt slot's sampler picks a segment, the resulting
    prompt is evaluated for one episode, and every slot's posterior is updated
    with the features of its chosen segment and the observed return.

    Args:
        model: Prompt-conditioned policy model.
        prompt_trajectory: Sequence of expert trajectory dicts to cut segments from.
        env: Environment used for the rollouts.
        info: Dict read for ``scale``, ``env_targets``, ``max_ep_len``,
            ``act_dim``, ``state_dim`` (and state statistics further down).
        variant: Run configuration; read for ``seed``, ``eval_rollouts``,
            ``K``, ``num_traj_prompt_j``, ``prompt_length``,
            ``bandit_use_transformer_features`` and ``embed_dim``.
        env_name: Name used as a prefix in the log key.
        wandb: Object with a ``log(dict)`` method, called once per rollout.

    Returns:
        ``(rollouts_rewards, logs)``: the list of episode returns and the
        log dict of the final rollout.
    """
    logs = dict()
    seed = variant['seed']
    print("Running bandit on seed", seed)
    # Seed all three RNGs with the configured seed. Calling random.seed()
    # without an argument would seed from OS entropy instead.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    eval_batch_size = 1
    num_rollouts = variant["eval_rollouts"]

    rtg_scale = info["scale"]
    rtg_target = info["env_targets"]
    max_test_ep_len = info["max_ep_len"]
    act_dim = info["act_dim"]
    state_dim = info["state_dim"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    context_len = variant["K"]
    traj_prompt_j = variant["num_traj_prompt_j"]
    traj_prompt_h = variant['prompt_length']
    bandit_use_transformer_features = variant['bandit_use_transformer_features']

    if bandit_use_transformer_features:
        bandit_feature_dim = variant['embed_dim']
    else:
        # Raw features: one observation, one action and one return-to-go
        # scalar per timestep of a segment.
        bandit_feature_dim = traj_prompt_h * (state_dim + act_dim + 1)

    rollouts_rewards = []

    # One independent sampler per prompt slot.
    mab = [
        SingleArmContextualTS(
            context_dim=bandit_feature_dim,
        ) for _ in range(traj_prompt_j)
    ]

    # Candidate segments are enumerated once and reused for every rollout.
    segments_raw, segments_features, segment_idx, traj_idxs = all_possible_segments(
        model, prompt_trajectory, rtg_scale, state_dim, act_dim, traj_prompt_j, traj_prompt_h,
        device, context_len, eval_batch_size, bandit_use_transformer_features, bandit_feature_dim)

    # The features stored for an update come from the same source that was scored.
    feature_source = segments_features if bandit_use_transformer_features else segments_raw

    for rollout_idx in range(num_rollouts):

        traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs, traj_prompt_timesteps, traj_prompt_masks, selected_segment_idxs, epsilon = select_segments_thompson(
            mab, eval_batch_size, prompt_trajectory, traj_idxs, segments_raw, segments_features, segment_idx,
            state_dim, act_dim, device, traj_prompt_j, traj_prompt_h, bandit_use_transformer_features,
            info, variant)

        _, ret = rollout_bandit(env, model, max_test_ep_len, rtg_target,
                                rtg_scale,
                                act_dim, state_dim, device, traj_prompt_timesteps,
                                traj_prompt_states,
                                traj_prompt_actions, traj_prompt_rtgs,
                                traj_prompt_masks, info, variant, traj_prompt_j, traj_prompt_h)

        rollouts_rewards.append(ret)
        print(f"{rollout_idx}: {ret}, (epsilon: {epsilon})")

        # Every slot is updated with its own chosen segment and the shared
        # return divided by 10.
        # NOTE(review): the rationale for the fixed factor 10 is not visible here.
        for j_idx, segment_update_idx in enumerate(selected_segment_idxs):
            mab[j_idx].update(
                context=np.array(feature_source[segment_update_idx]),
                reward=ret / 10,
            )

        logs.update({f'{env_name}_target_0_return_mean': ret})
        wandb.log(logs)

    return rollouts_rewards, logs


def select_segments_ucb(
        mab, eval_batch_size, expert_prompt_trajs, traj_idxs, segments_raw, segments_features, segment_idx,
        state_dim, act_dim, device, traj_prompt_j, traj_prompt_h,
        bandit_use_transformer_features, info, variant):
    """Ask every bandit for one segment and stack the chosen segments into a prompt.

    Each bandit in `mab` picks an index into the candidate segments through its
    `take_action` method. The corresponding `(start, end)` windows are sliced out
    of the source trajectories, left-padded to `end - start` steps, optionally
    state-normalised, and reshaped so that all segments lie back to back on the
    time axis.

    Args:
        mab: iterable of bandits exposing `take_action(contexts)`; the returned
            value is used as an index into `segment_idx` and `traj_idxs`.
        eval_batch_size: leading dimension of the returned tensors.
        expert_prompt_trajs: indexable collection of trajectory dicts providing
            'observations', 'actions' and 'rewards'.
        traj_idxs: maps a segment index to the index of its source trajectory.
        segments_raw: candidate contexts handed to the bandits when
            `bandit_use_transformer_features` is falsy.
        segments_features: candidate contexts handed to the bandits (as one
            numpy array) when `bandit_use_transformer_features` is truthy.
        segment_idx: maps a segment index to its `(start, end)` window.
        state_dim: size of one observation.
        act_dim: size of one action.
        device: torch device the returned tensors are placed on.
        traj_prompt_j: number of segments per prompt (used for the final reshape).
        traj_prompt_h: number of steps per segment (used for the final reshape).
        bandit_use_transformer_features: selects which candidate set is used.
        info: dict providing 'max_ep_len', 'state_mean', 'state_std', 'scale'.
        variant: dict providing 'no_state_normalize'.

    Returns:
        Tuple `(states, actions, rtgs, timesteps, mask, selected_segment_idxs,
        epsilon)`. The first five are tensors whose first two dimensions are
        `(eval_batch_size, traj_prompt_j * traj_prompt_h)`; `epsilon` is always
        `None` and only exists so the tuple matches the other selectors' shape.
    """
    max_ep_len, state_mean, state_std, scale = info['max_ep_len'], info['state_mean'], info['state_std'], info['scale']

    # One pick per bandit; position i in the list belongs to bandit i.
    selected_segment_idxs = []
    for s_mab in mab:
        if bandit_use_transformer_features:
            selected_segment_idx = s_mab.take_action(np.array(segments_features))
        else:
            selected_segment_idx = s_mab.take_action(segments_raw)
        selected_segment_idxs.append(selected_segment_idx)

    s, a, rtg, timesteps, mask = [], [], [], [], []

    for seg_idx in selected_segment_idxs:
        prompt_traj = expert_prompt_trajs[traj_idxs[seg_idx]]
        start, end = segment_idx[seg_idx]

        seg_states = prompt_traj['observations'][start:end].reshape(1, -1, state_dim)
        seg_actions = prompt_traj['actions'][start:end].reshape(1, -1, act_dim)
        tlen = seg_states.shape[1]  # may be shorter than the window if the slice ran off the end

        seg_timesteps = np.arange(start, start + tlen).reshape(1, -1)
        seg_timesteps[seg_timesteps >= max_ep_len] = max_ep_len - 1  # padding cutoff

        # Return-to-go is accumulated to the end of the trajectory, then cut to the segment.
        seg_rtg = discount_cumsum(prompt_traj['rewards'][start:], gamma=1.)[:tlen].reshape(1, -1, 1)

        # Left-pad up to the nominal window length. States are normalised AFTER
        # padding, so padded rows are normalised as well.
        pad = (end - start) - tlen
        seg_states = np.concatenate([np.zeros((1, pad, state_dim)), seg_states], axis=1)
        if not variant['no_state_normalize']:
            seg_states = (seg_states - state_mean) / state_std
        seg_actions = np.concatenate([np.ones((1, pad, act_dim)) * -10., seg_actions], axis=1)
        seg_rtg = np.concatenate([np.zeros((1, pad, 1)), seg_rtg], axis=1) / scale
        seg_timesteps = np.concatenate([np.zeros((1, pad)), seg_timesteps], axis=1)
        seg_mask = np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1)

        s.append(seg_states)
        a.append(seg_actions)
        rtg.append(seg_rtg)
        timesteps.append(seg_timesteps)
        mask.append(seg_mask)

    s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device)
    a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)
    rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device)
    timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)
    mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)

    # Fold the per-segment axis into the time axis: (J, H, ...) -> (batch, J * H, ...).
    prompt_len = traj_prompt_j * traj_prompt_h
    s = s.reshape(eval_batch_size, prompt_len, state_dim)
    a = a.reshape(eval_batch_size, prompt_len, act_dim)
    rtg = rtg.reshape(eval_batch_size, prompt_len, 1)
    timesteps = timesteps.reshape(eval_batch_size, prompt_len)
    mask = mask.reshape(eval_batch_size, prompt_len)

    epsilon = None
    return s, a, rtg, timesteps, mask, selected_segment_idxs, epsilon


def prompt_tuning_ucb(model, prompt_trajectory, env, info, variant, env_name, wandb):
    """Tune the prompt with one NeuralUCB bandit per prompt segment.

    For every rollout, each bandit selects one candidate segment, the assembled
    prompt is rolled out in `env`, and every bandit is updated with the context
    it selected and the episode return divided by 10.

    Args:
        model: policy passed through to `all_possible_segments` and `rollout_bandit`.
        prompt_trajectory: collection of trajectories the candidate segments are cut from.
        env: environment passed through to `rollout_bandit`.
        info: dict providing 'scale', 'env_targets', 'max_ep_len', 'act_dim',
            'state_dim' (plus the keys read by `select_segments_ucb`).
        variant: dict providing 'seed', 'eval_rollouts', 'K', 'num_traj_prompt_j',
            'prompt_length', 'bandit_use_transformer_features' and, when
            transformer features are used, 'embed_dim'. An optional
            'log_to_wandb' entry set to a falsy value disables `wandb.log`.
        env_name: prefix of the logged metric name.
        wandb: object exposing `log(dict)`.

    Returns:
        `(rollouts_rewards, logs)`: the list of per-rollout returns and the log
        dict holding the most recent return.
    """
    logs = dict()
    seed = variant['seed']
    # Seed all three RNGs with the configured seed. Calling random.seed() with
    # no argument would reseed the stdlib RNG from OS entropy and break
    # reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    eval_batch_size = 1
    num_rollouts = variant["eval_rollouts"]

    rtg_scale = info["scale"]
    rtg_target = info["env_targets"]
    max_test_ep_len = info["max_ep_len"]
    act_dim = info["act_dim"]
    state_dim = info["state_dim"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    context_len = variant["K"]
    traj_prompt_j = variant["num_traj_prompt_j"]
    traj_prompt_h = variant['prompt_length']
    bandit_use_transformer_features = variant['bandit_use_transformer_features']

    if bandit_use_transformer_features:
        bandit_feature_dim = variant['embed_dim']
    else:
        # Raw context = flattened (state, action, return-to-go) for every step of a segment.
        bandit_feature_dim = traj_prompt_h * (state_dim + act_dim + 1)

    # Missing key keeps the historical behaviour (always log).
    log_to_wandb = variant.get("log_to_wandb", True)

    rollouts_rewards = []

    # get all possible segments
    segments_raw, segments_features, segment_idx, traj_idxs = all_possible_segments(
        model, prompt_trajectory, rtg_scale, state_dim, act_dim, traj_prompt_j, traj_prompt_h, device, context_len,
        eval_batch_size, bandit_use_transformer_features, bandit_feature_dim)

    # The bandits see whichever representation was selected above.
    arm_contexts = segments_features if bandit_use_transformer_features else segments_raw
    num_arms = len(arm_contexts)
    mab = [NeuralUCB(d=bandit_feature_dim, K=num_arms, device=device, reg=0, beta=3, lr=0.001)
           for _ in range(traj_prompt_j)]

    for rollout_idx in range(num_rollouts):

        # bandit select segments
        (traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs, traj_prompt_timesteps,
         traj_prompt_masks, selected_segment_idxs, epsilon) = select_segments_ucb(
            mab, eval_batch_size, prompt_trajectory, traj_idxs, segments_raw, segments_features, segment_idx,
            state_dim, act_dim, device, traj_prompt_j, traj_prompt_h, bandit_use_transformer_features,
            info, variant)

        # roll out the prompt assembled from the selected segments
        _, ret = rollout_bandit(env, model, max_test_ep_len, rtg_target,
                                rtg_scale,
                                act_dim, state_dim, device, traj_prompt_timesteps,
                                traj_prompt_states,
                                traj_prompt_actions, traj_prompt_rtgs,
                                traj_prompt_masks, info, variant, traj_prompt_j, traj_prompt_h)

        rollouts_rewards.append(ret)
        print(f"{rollout_idx}: {ret}, (epsilon: {epsilon})")

        # Bandit j is credited for the segment it placed in slot j.
        for j_idx, segment_update_idx in enumerate(selected_segment_idxs):
            mab[j_idx].update(
                context=arm_contexts[segment_update_idx],
                action=segment_update_idx,
                reward=ret / 10)

        logs.update({f'{env_name}_target_0_return_mean': ret})
        if log_to_wandb:
            wandb.log(logs)

    return rollouts_rewards, logs


def prompt_tuning_hill_climbing(model, prompt_trajectory, env, info, variant, env_name, wandb):
    """Tune the prompt by Gaussian-noise hill climbing.

    A prompt is assembled from randomly chosen segments. The first rollout
    evaluates it unchanged; every later rollout evaluates a copy with Gaussian
    noise added to its states, actions and returns-to-go, and the noisy copy
    replaces the current prompt whenever it achieves a strictly higher return
    than the best seen so far.

    Args:
        model: policy passed through to `all_possible_segments` and `rollout_bandit`.
        prompt_trajectory: collection of trajectory dicts providing
            'observations', 'actions' and 'rewards'.
        env: environment passed through to `rollout_bandit`.
        info: dict providing 'scale', 'env_targets', 'max_ep_len', 'act_dim',
            'state_dim', 'state_mean', 'state_std'.
        variant: dict providing 'seed', 'eval_rollouts', 'K', 'num_traj_prompt_j',
            'prompt_length', 'bandit_use_transformer_features' ('embed_dim' when
            that flag is set), 'no_state_normalize', 'log_to_wandb' and
            'hillclimbing_eta' (either the string "schedule" for a linear decay
            from 1 to 0.01, or a value convertible to float).
        env_name: prefix of the logged metric name.
        wandb: object exposing `log(dict)`; only used when 'log_to_wandb' is truthy.

    Returns:
        `(rollouts_rewards, logs)`: the list of per-rollout returns and the log
        dict (only populated when 'log_to_wandb' is truthy).
    """
    logs = dict()
    seed = variant['seed']
    # Seed all three RNGs with the configured seed. Calling random.seed() with
    # no argument would reseed the stdlib RNG from OS entropy and break
    # reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    eval_batch_size = 1
    num_rollouts = variant["eval_rollouts"]
    rtg_scale = info["scale"]
    rtg_target = info["env_targets"]
    max_test_ep_len = info["max_ep_len"]
    act_dim = info["act_dim"]
    state_dim = info["state_dim"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    context_len = variant["K"]
    traj_prompt_j = variant["num_traj_prompt_j"]
    traj_prompt_h = variant['prompt_length']
    bandit_use_transformer_features = variant['bandit_use_transformer_features']

    if bandit_use_transformer_features:
        bandit_feature_dim = variant['embed_dim']
    else:
        # Raw context = flattened (state, action, return-to-go) for every step of a segment.
        bandit_feature_dim = traj_prompt_h * (state_dim + act_dim + 1)

    rollouts_rewards = []

    # get all possible segments
    segments_raw, segments_features, segment_idx, traj_idxs = all_possible_segments(
        model, prompt_trajectory, rtg_scale, state_dim, act_dim, traj_prompt_j, traj_prompt_h, device, context_len,
        eval_batch_size, bandit_use_transformer_features, bandit_feature_dim)

    def select_random_prompt():
        """Build the initial prompt from `traj_prompt_j` uniformly drawn segments."""
        max_ep_len, state_mean, state_std, scale = info['max_ep_len'], info['state_mean'], info['state_std'], info['scale']

        selected_segment_idxs = [np.random.randint(len(segments_raw)) for _ in range(traj_prompt_j)]

        s, a, rtg, timesteps, mask = [], [], [], [], []

        for seg_idx in selected_segment_idxs:
            prompt_traj = prompt_trajectory[traj_idxs[seg_idx]]
            start, end = segment_idx[seg_idx]

            # get sequences from dataset
            seg_states = prompt_traj['observations'][start:end].reshape(1, -1, state_dim)
            seg_actions = prompt_traj['actions'][start:end].reshape(1, -1, act_dim)
            tlen = seg_states.shape[1]

            seg_timesteps = np.arange(start, start + tlen).reshape(1, -1)
            seg_timesteps[seg_timesteps >= max_ep_len] = max_ep_len - 1  # padding cutoff
            seg_rtg = discount_cumsum(prompt_traj['rewards'][start:], gamma=1.)[:tlen].reshape(1, -1, 1)

            # Left-pad to the nominal window length; states are normalised
            # after padding, so padded rows are normalised as well.
            pad = (end - start) - tlen
            seg_states = np.concatenate([np.zeros((1, pad, state_dim)), seg_states], axis=1)
            if not variant['no_state_normalize']:
                seg_states = (seg_states - state_mean) / state_std
            seg_actions = np.concatenate([np.ones((1, pad, act_dim)) * -10., seg_actions], axis=1)
            seg_rtg = np.concatenate([np.zeros((1, pad, 1)), seg_rtg], axis=1) / scale
            seg_timesteps = np.concatenate([np.zeros((1, pad)), seg_timesteps], axis=1)
            seg_mask = np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1)

            s.append(seg_states)
            a.append(seg_actions)
            rtg.append(seg_rtg)
            timesteps.append(seg_timesteps)
            mask.append(seg_mask)

        s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device)
        a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device)
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)

        # Fold the per-segment axis into the time axis: (J, H, ...) -> (batch, J * H, ...).
        prompt_len = traj_prompt_j * traj_prompt_h
        s = s.reshape(eval_batch_size, prompt_len, state_dim)
        a = a.reshape(eval_batch_size, prompt_len, act_dim)
        rtg = rtg.reshape(eval_batch_size, prompt_len, 1)
        timesteps = timesteps.reshape(eval_batch_size, prompt_len)
        mask = mask.reshape(eval_batch_size, prompt_len)

        return s, a, rtg, timesteps, mask, selected_segment_idxs

    (traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs,
     traj_prompt_timesteps, traj_prompt_masks, _) = select_random_prompt()

    best_return = -np.inf
    eta_schedule = np.linspace(1, 0.01, num_rollouts)

    for rollout_idx in range(num_rollouts):

        if rollout_idx == 0:
            # The very first rollout scores the unperturbed prompt.
            noisy_prompt_states = traj_prompt_states.clone()
            noisy_prompt_actions = traj_prompt_actions.clone()
            noisy_prompt_rtgs = traj_prompt_rtgs.clone()
        else:
            # apply noise to prompt
            if variant["hillclimbing_eta"] == "schedule":
                noise_scale = eta_schedule[rollout_idx]
            else:
                noise_scale = float(variant["hillclimbing_eta"])

            # Timesteps and mask are left untouched; only the continuous parts are perturbed.
            noisy_prompt_states = traj_prompt_states + torch.randn_like(traj_prompt_states) * noise_scale
            noisy_prompt_actions = traj_prompt_actions + torch.randn_like(traj_prompt_actions) * noise_scale
            noisy_prompt_rtgs = traj_prompt_rtgs + torch.randn_like(traj_prompt_rtgs) * noise_scale

        # roll out the candidate prompt
        _, ret = rollout_bandit(
            env, model, max_test_ep_len, rtg_target, rtg_scale, act_dim, state_dim, device,
            traj_prompt_timesteps, noisy_prompt_states, noisy_prompt_actions, noisy_prompt_rtgs, traj_prompt_masks,
            info, variant, traj_prompt_j, traj_prompt_h)

        rollouts_rewards.append(ret)

        if ret > best_return:  # update prompt if noisy version was better!
            best_return = ret
            traj_prompt_states = noisy_prompt_states
            traj_prompt_actions = noisy_prompt_actions
            traj_prompt_rtgs = noisy_prompt_rtgs

        if variant["log_to_wandb"]:
            logs.update({f'{env_name}_target_0_return_mean': ret})
            wandb.log(logs)

    return rollouts_rewards, logs


def prompt_tuning_zoranksgd(model, prompt_trajectory, env, info, variant, env_name, wandb):
    """Tune the prompt with a rank-based zeroth-order gradient estimate.

    A prompt is assembled from randomly chosen segments and evaluated. Before
    every later rollout, `m` Gaussian perturbations of the flattened prompt
    (states, actions, returns-to-go) are scored, ranked, and combined into a
    gradient estimate that is subtracted from the prompt with step size `eta`.

    Scoring depends on `variant["zorank_mode"]`:
      * "online": each perturbed prompt is rolled out in `env`; a higher return
        ranks better.
      * "offline": each perturbed prompt is scored by the MSE between the
        model's predicted action and the recorded next action on 32 windows
        sampled from one random trajectory; a lower loss ranks better.

    Args:
        model: policy exposing `get_action(...)` (offline mode) and passed
            through to `all_possible_segments` and `rollout_bandit`.
        prompt_trajectory: collection of trajectory dicts providing
            'observations', 'actions' and 'rewards'.
        env: environment passed through to `rollout_bandit`.
        info: dict providing 'scale', 'env_targets', 'max_ep_len', 'act_dim',
            'state_dim', 'state_mean', 'state_std'.
        variant: dict providing 'seed', 'eval_rollouts', 'zorank_m',
            'zorank_mode', 'zorank_eta' ("schedule" for a linear decay from 1 to
            0.01, or a value convertible to float), 'K', 'num_traj_prompt_j',
            'prompt_length', 'bandit_use_transformer_features' ('embed_dim' when
            that flag is set), 'no_state_normalize' and 'log_to_wandb'.
        env_name: prefix of the logged metric name.
        wandb: object exposing `log(dict)`; only used when 'log_to_wandb' is truthy.

    Returns:
        `(rollouts_rewards, logs)`: the list of per-rollout returns of the
        current prompt and the log dict (only populated when 'log_to_wandb' is
        truthy).

    Raises:
        ValueError: if `variant["zorank_mode"]` is neither "online" nor "offline"
            when a gradient step is attempted.

    Note:
        The estimate is divided by `m * (m - 1) // 2`, so `zorank_m` must be at
        least 2 for the update to be finite.
    """
    logs = dict()
    seed = variant['seed']
    # Seed all three RNGs with the configured seed. Calling random.seed() with
    # no argument would reseed the stdlib RNG from OS entropy and break
    # reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    eval_batch_size = 1
    num_rollouts = variant["eval_rollouts"]
    rtg_scale = info["scale"]
    rtg_target = info["env_targets"]
    max_test_ep_len = info["max_ep_len"]
    act_dim = info["act_dim"]
    state_dim = info["state_dim"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    B = 1  # number of prompts optimised in parallel; the ranking code below assumes 1
    m = variant["zorank_m"]  # perturbations per gradient estimate
    mu = 1  # perturbation scale
    eta_schedule = np.linspace(1, 0.01, num_rollouts)
    offline_eval_batch_size = 32

    context_len = variant["K"]
    traj_prompt_j = variant["num_traj_prompt_j"]
    traj_prompt_h = variant['prompt_length']
    bandit_use_transformer_features = variant['bandit_use_transformer_features']

    if bandit_use_transformer_features:
        bandit_feature_dim = variant['embed_dim']
    else:
        # Raw context = flattened (state, action, return-to-go) for every step of a segment.
        bandit_feature_dim = traj_prompt_h * (state_dim + act_dim + 1)

    rollouts_rewards = []

    # get all possible segments
    segments_raw, segments_features, segment_idx, traj_idxs = all_possible_segments(
        model, prompt_trajectory, rtg_scale, state_dim, act_dim, traj_prompt_j, traj_prompt_h, device, context_len,
        eval_batch_size, bandit_use_transformer_features, bandit_feature_dim)

    def select_random_prompt():
        """Build the initial prompt from `traj_prompt_j` uniformly drawn segments."""
        max_ep_len, state_mean, state_std, scale = info['max_ep_len'], info['state_mean'], info['state_std'], info['scale']

        selected_segment_idxs = [np.random.randint(len(segments_raw)) for _ in range(traj_prompt_j)]

        s, a, rtg, timesteps, mask = [], [], [], [], []

        for seg_idx in selected_segment_idxs:
            prompt_traj = prompt_trajectory[traj_idxs[seg_idx]]
            start, end = segment_idx[seg_idx]

            # get sequences from dataset
            seg_states = prompt_traj['observations'][start:end].reshape(1, -1, state_dim)
            seg_actions = prompt_traj['actions'][start:end].reshape(1, -1, act_dim)
            tlen = seg_states.shape[1]

            seg_timesteps = np.arange(start, start + tlen).reshape(1, -1)
            seg_timesteps[seg_timesteps >= max_ep_len] = max_ep_len - 1  # padding cutoff
            seg_rtg = discount_cumsum(prompt_traj['rewards'][start:], gamma=1.)[:tlen].reshape(1, -1, 1)

            # Left-pad to the nominal window length; states are normalised
            # after padding, so padded rows are normalised as well.
            pad = (end - start) - tlen
            seg_states = np.concatenate([np.zeros((1, pad, state_dim)), seg_states], axis=1)
            if not variant['no_state_normalize']:
                seg_states = (seg_states - state_mean) / state_std
            seg_actions = np.concatenate([np.ones((1, pad, act_dim)) * -10., seg_actions], axis=1)
            seg_rtg = np.concatenate([np.zeros((1, pad, 1)), seg_rtg], axis=1) / scale
            seg_timesteps = np.concatenate([np.zeros((1, pad)), seg_timesteps], axis=1)
            seg_mask = np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1)

            s.append(seg_states)
            a.append(seg_actions)
            rtg.append(seg_rtg)
            timesteps.append(seg_timesteps)
            mask.append(seg_mask)

        s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device)
        a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device)
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)

        # Fold the per-segment axis into the time axis: (J, H, ...) -> (batch, J * H, ...).
        prompt_len = traj_prompt_j * traj_prompt_h
        s = s.reshape(eval_batch_size, prompt_len, state_dim)
        a = a.reshape(eval_batch_size, prompt_len, act_dim)
        rtg = rtg.reshape(eval_batch_size, prompt_len, 1)
        timesteps = timesteps.reshape(eval_batch_size, prompt_len)
        mask = mask.reshape(eval_batch_size, prompt_len)

        return s, a, rtg, timesteps, mask, selected_segment_idxs

    (traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs,
     traj_prompt_timesteps, traj_prompt_masks, _) = select_random_prompt()

    # Layout of the flattened prompt vector: [states | actions | returns-to-go].
    prompt_len = traj_prompt_j * traj_prompt_h
    state_end = prompt_len * state_dim
    action_end = prompt_len * (state_dim + act_dim)
    prompt_size = prompt_len * (state_dim + act_dim + 1)

    def split_prompt(flat_vec):
        """Cut a 1-D prompt vector back into (states, actions, rtgs) prompt tensors."""
        states = flat_vec[:state_end].reshape(eval_batch_size, prompt_len, state_dim)
        actions = flat_vec[state_end:action_end].reshape(eval_batch_size, prompt_len, act_dim)
        rtgs = flat_vec[action_end:].reshape(eval_batch_size, prompt_len, 1)
        return states, actions, rtgs

    def to_float_tensor(array):
        return torch.from_numpy(array).to(dtype=torch.float32, device=device)

    # The noise distribution does not depend on the rollout; it is built once,
    # on first use, because constructing it factorises a prompt_size x prompt_size matrix.
    noise_dist = None

    for rollout_idx in range(num_rollouts):

        if rollout_idx > 0:

            flat_prompt = torch.cat([traj_prompt_states.flatten(), traj_prompt_actions.flatten(), traj_prompt_rtgs.flatten()])
            flattened_prompt_vec = flat_prompt.reshape(1, prompt_size).to(device)

            if noise_dist is None:
                noise_dist = torch.distributions.MultivariateNormal(
                    torch.zeros(prompt_size, device=device),
                    torch.eye(prompt_size, device=device)
                )

            xi = noise_dist.sample((B, m))
            pts = flattened_prompt_vec.unsqueeze(1) + mu * xi
            assert pts.shape == (B, m, prompt_size)

            zorank_mode = variant["zorank_mode"]
            noise_scores = []  # one score per perturbed prompt, ranked below

            if zorank_mode == "online":

                for noisy_idx in range(m):
                    noisy_states, noisy_actions, noisy_rtgs = split_prompt(pts[:, noisy_idx, :].squeeze())

                    # do rollout with the mth perturbed prompt
                    _, mth_ret = rollout_bandit(
                        env, model, max_test_ep_len, rtg_target, rtg_scale, act_dim, state_dim, device,
                        traj_prompt_timesteps, noisy_states, noisy_actions, noisy_rtgs, traj_prompt_masks,
                        info, variant, traj_prompt_j, traj_prompt_h
                    )
                    noise_scores.append(mth_ret)

            elif zorank_mode == "offline":

                # Sample one trajectory and cut a batch of (window -> next action) pairs from it.
                traj_idx = np.random.randint(len(prompt_trajectory))
                offline_eval_traj = prompt_trajectory[traj_idx]
                offline_eval_traj_len = len(offline_eval_traj["observations"])

                normalize_states = not variant['no_state_normalize']
                if normalize_states:
                    state_mean = to_float_tensor(info['state_mean'])
                    state_std = to_float_tensor(info['state_std'])

                # Tensors are prepared once per batch and reused for all m perturbed prompts.
                offline_batch = []
                target_actions = []
                for _ in range(offline_eval_batch_size):
                    start = np.random.randint(offline_eval_traj_len - traj_prompt_h)
                    end = start + traj_prompt_h

                    raw_states = offline_eval_traj["observations"][start:end]
                    state = to_float_tensor(raw_states)
                    if normalize_states:
                        state = (state - state_mean) / state_std
                    action = to_float_tensor(offline_eval_traj["actions"][start:end])
                    reward = to_float_tensor(offline_eval_traj["rewards"][start:end])
                    target_return = to_float_tensor(
                        discount_cumsum(offline_eval_traj['rewards'][start:], gamma=1.)[:traj_prompt_h])
                    timesteps = torch.from_numpy(
                        np.arange(start, start + raw_states.shape[0]).reshape(1, -1)
                    ).to(dtype=torch.long, device=device)

                    offline_batch.append((state, action, reward, target_return, timesteps))
                    # The action recorded right after the window is the prediction target.
                    target_actions.append(to_float_tensor(offline_eval_traj["actions"][end]))

                stacked_targets = torch.stack(target_actions)

                for noisy_idx in range(m):
                    noisy_states, noisy_actions, noisy_rtgs = split_prompt(pts[:, noisy_idx, :].squeeze())
                    prompt = (noisy_states, noisy_actions, None, None, noisy_rtgs, traj_prompt_timesteps, traj_prompt_masks)

                    pred_actions = [
                        model.get_action(state, action, reward, target_return, timesteps, prompt=prompt)
                        for state, action, reward, target_return, timesteps in offline_batch
                    ]

                    noise_scores.append(F.mse_loss(stacked_targets, torch.stack(pred_actions)))

            else:
                raise ValueError(
                    f"Unknown zorank_mode {zorank_mode!r}; expected 'online' or 'offline'")

            # Rank the m perturbations best-first. online: descending (maximise
            # return); offline: ascending (minimise loss). unsqueeze adds the batch dimension.
            _, s_idx = torch.sort(torch.tensor(noise_scores).unsqueeze(0), descending=(zorank_mode == "online"))
            assert s_idx.shape == (B, m)
            s_idx = s_idx.to(device)

            # Rank r (0 = best) gets weight 2r - (m - 1): the best perturbation
            # is weighted -(m - 1), the worst +(m - 1). Subtracting the weighted
            # sum therefore moves the prompt towards the better perturbations.
            edges_count = m * (m - 1) // 2
            weight = 2 * torch.arange(m, device=device) - (m - 1)
            ranked_xi = torch.gather(
                xi, 1, s_idx.unsqueeze(-1).expand(-1, -1, prompt_size)
            )
            s = (ranked_xi * weight.view(1, m, 1)).sum(dim=1)  # scale the noise based on relative ranking
            g = s / edges_count

            if variant["zorank_eta"] == "schedule":
                eta = eta_schedule[rollout_idx]
            else:
                eta = float(variant["zorank_eta"])

            flattened_prompt_vec = flattened_prompt_vec - eta * g

            new_states, new_actions, new_rtgs = split_prompt(flattened_prompt_vec[0])
            traj_prompt_states = new_states.to(device)
            traj_prompt_actions = new_actions.to(device)
            traj_prompt_rtgs = new_rtgs.to(device)

        # Evaluate the current (possibly just updated) prompt.
        _, ret = rollout_bandit(
            env, model, max_test_ep_len, rtg_target, rtg_scale, act_dim, state_dim, device,
            traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs, traj_prompt_masks,
            info, variant, traj_prompt_j, traj_prompt_h
        )

        rollouts_rewards.append(ret)
        print(f"ZORANKSGD return {rollout_idx}", ret)

        if variant["log_to_wandb"]:
            logs.update({f'{env_name}_target_0_return_mean': ret})
            wandb.log(logs)

    return rollouts_rewards, logs






