# Copyright (C) king.com Ltd 2025
# License: Apache 2.0

import argparse
import os
import random
from datetime import datetime
import numpy as np
import gymnasium as gym
import torch

import envs # register custom environments
from misc.utils import TrajectoryDataset
from models.transformer import DecisionTransformer
from models.bandit import NeuralBanditEpsGreedy, NeuralBanditUCB, ThompsonSamplingBandit
from misc.utils import get_device


def parse_bool_arg(arg):
    """Interpret a command-line style flag value as a boolean.

    The value is converted with ``str`` first, so real booleans and integers
    are accepted in addition to strings (the previous implementation raised
    ``AttributeError`` on anything without a ``.lower()`` method).

    Args:
        arg: The value to interpret, typically a string such as ``"True"``.

    Returns:
        True if the textual form of ``arg`` is "true", "1" or "yes"
        (case-insensitive, surrounding whitespace ignored), otherwise False.
    """
    return str(arg).strip().lower() in ("true", "1", "yes")


def get_all_prompt_segments(target_dataset, model, state_mean, state_std, rtg_scale, state_dim, act_dim, context_len, device):
    """Enumerate every candidate prompt segment and build its bandit features.

    Every window of ``model.traj_prompt_h`` consecutive steps of every selected
    prompt trajectory becomes one candidate segment ("arm"). Reads the
    module-level ``args`` (``data_mixture``, ``bandit_use_transformer_features``).

    Args:
        target_dataset: Object exposing ``expert_prompt_trajs`` (a sequence of
            trajectories) and ``mixture_datasets`` (a mapping from mixture name
            to trajectories). Each trajectory is indexed with the keys
            "observations", "actions" and "returns_to_go".
        model: Prompting Decision Transformer; ``traj_prompt_h``,
            ``traj_prompt_j``, ``h_dim`` and ``forward`` are used.
        state_mean: Unused here; kept for interface compatibility.
        state_std: Unused here; kept for interface compatibility.
        rtg_scale: Divisor applied to the segment returns-to-go.
        state_dim: Number of state dimensions.
        act_dim: Number of action dimensions.
        context_len: Length of the (placeholder) context passed to the model.
        device: Torch device for the tensors passed to the model.

    Returns:
        A tuple ``(segments_raw, segments_features, segment_idx, traj_idxs,
        prompt_trajs)``:
        ``segments_raw`` holds one flat feature list per segment,
        ``segments_features`` one transformer feature vector per segment (empty
        unless ``args.bandit_use_transformer_features`` is truthy),
        ``segment_idx`` the ``(start, end)`` step range of each segment,
        ``traj_idxs`` the index into ``prompt_trajs`` each segment came from,
        and ``prompt_trajs`` the trajectories the segments were cut from.

    Raises:
        ValueError: If ``args.data_mixture`` is neither "expert" nor a key of
            ``target_dataset.mixture_datasets``.
    """
    if args.data_mixture == "expert":
        expert_trajs = target_dataset.expert_prompt_trajs
        # To keep the prompt pool small, segments are cut from at most 10 random
        # expert demos; the cap avoids random.sample raising on smaller pools.
        prompt_trajs = random.sample(expert_trajs, min(10, len(expert_trajs)))
    elif args.data_mixture in target_dataset.mixture_datasets.keys():
        prompt_trajs = target_dataset.mixture_datasets[args.data_mixture]
    else:
        raise ValueError(
            f"data_mixture should be 'expert' or one of {list(target_dataset.mixture_datasets.keys())}"
        )

    segment_len = model.traj_prompt_h
    eval_batch_size = 1

    segments_raw = []
    segments_features = []
    segment_idx = []
    traj_idxs = []

    for traj_idx, prompt_traj in enumerate(prompt_trajs):
        n_steps = len(prompt_traj["observations"])

        # slide a window of `segment_len` steps over the trajectory
        for start in range(0, n_steps - segment_len + 1):
            end = start + segment_len
            segment_observations = prompt_traj["observations"][start:end]
            segment_actions = prompt_traj["actions"][start:end]
            segment_rtgs = prompt_traj["returns_to_go"][start:end] / rtg_scale
            segment_idx.append((start, end))
            traj_idxs.append(traj_idx)

            # Raw arm features: per step, the state entries, then the action
            # entries, then the scaled return-to-go, i.e. (s, a, r, s, a, r, ...).
            arm_features_raw = []
            for t in range(segment_len):
                if state_dim > 1:
                    arm_features_raw.extend(segment_observations[t])
                else:
                    arm_features_raw.append(segment_observations[t])

                if act_dim > 1:
                    arm_features_raw.extend(segment_actions[t])
                else:
                    arm_features_raw.append(segment_actions[t])

                arm_features_raw.append(segment_rtgs[t])

            segments_raw.append(arm_features_raw)

            if args.bandit_use_transformer_features:
                segment_timesteps = np.arange(start, end, step=1)
                segment_timestep_tensor = torch.from_numpy(segment_timesteps).to(torch.int64).to(device).reshape(1, segment_len)
                segment_states_tensor = torch.from_numpy(segment_observations).to(torch.float32).to(device).reshape(1, segment_len, state_dim)
                segment_actions_tensor = torch.from_numpy(segment_actions).to(torch.float32).to(device).reshape(1, segment_len, act_dim)
                segment_rtgs_tensor = torch.from_numpy(segment_rtgs).to(torch.float32).to(device).reshape(1, segment_len, 1)

                # the model is fed traj_prompt_j segments, so the single
                # segment is repeated to fill the whole prompt
                segment_timestep_tensor = segment_timestep_tensor.repeat(1, model.traj_prompt_j)
                segment_states_tensor = segment_states_tensor.repeat(1, model.traj_prompt_j, 1)
                segment_actions_tensor = segment_actions_tensor.repeat(1, model.traj_prompt_j, 1)
                segment_rtgs_tensor = segment_rtgs_tensor.repeat(1, model.traj_prompt_j, 1)

                # Only the features are needed (they are detached right away),
                # so no autograd graph has to be recorded.
                with torch.no_grad():
                    _, _, _, _, _, _, _, _, _, action_features = model.forward(
                        # All-zero placeholder context: per the original author's
                        # note it comes causally after the prompt and therefore
                        # does not affect the prompt representation.
                        timesteps=torch.arange(start=0, end=context_len, step=1).repeat(eval_batch_size, 1).to(device),
                        states=torch.zeros((eval_batch_size, context_len, state_dim), dtype=torch.float32, device=device),
                        actions=torch.zeros((eval_batch_size, context_len, act_dim), dtype=torch.float32, device=device),
                        returns_to_go=torch.zeros((eval_batch_size, context_len, 1), dtype=torch.float32, device=device),
                        traj_prompt_timesteps=segment_timestep_tensor,
                        traj_prompt_states=segment_states_tensor,
                        traj_prompt_actions=segment_actions_tensor,
                        traj_prompt_rtgs=segment_rtgs_tensor,
                        return_token_features=True
                    )

                assert action_features.shape == (
                    1, model.traj_prompt_j * segment_len + context_len, model.h_dim)
                # action-token features at the last step of the first segment copy
                segment_features_transformer = action_features[0, segment_len - 1, :].cpu().detach().numpy()
                segments_features.append(segment_features_transformer)

    return segments_raw, segments_features, segment_idx, traj_idxs, prompt_trajs


def _gather_prompt_tensors(selected_idxs, segment_idxs, traj_idxs, prompt_trajs):
    """Concatenate the dataset slices behind the selected segments into CPU tensors.

    Returns ``(timesteps, states, actions, rtgs)`` with the selected segments
    laid out back to back along the first dimension.
    """
    timesteps, states, actions, rtgs = [], [], [], []
    for seg in selected_idxs:
        start, end = segment_idxs[seg]
        source_traj = prompt_trajs[traj_idxs[seg]]
        states.append(torch.from_numpy(source_traj['observations'][start:end]).to(torch.float32))
        actions.append(torch.from_numpy(source_traj['actions'][start:end]).to(torch.float32))
        rtgs.append(torch.from_numpy(source_traj['returns_to_go'][start:end]).to(torch.float32))
        timesteps.append(torch.arange(start=start, end=end, step=1))
    return torch.cat(timesteps), torch.cat(states), torch.cat(actions), torch.cat(rtgs)


def _flatten_prompt(prompt_rtgs, prompt_states, prompt_actions, n_tokens, state_dim, act_dim):
    """Pack a prompt into one ``(1, P)`` row laid out as (rtg, state, action) per prompt step."""
    per_step = torch.cat(
        [
            prompt_rtgs.reshape(n_tokens, 1),
            prompt_states.reshape(n_tokens, state_dim),
            prompt_actions.reshape(n_tokens, act_dim),
        ],
        dim=1,
    )
    return per_step.reshape(1, n_tokens * (1 + state_dim + act_dim))


def _unflatten_prompt(flat_prompt, n_tokens, state_dim, act_dim):
    """Inverse of ``_flatten_prompt``; returns batched ``(states, actions, rtgs)`` copies."""
    per_step = flat_prompt.reshape(n_tokens, 1 + state_dim + act_dim)
    prompt_rtgs = per_step[:, :1].clone().reshape(1, n_tokens, 1)
    prompt_states = per_step[:, 1:1 + state_dim].clone().reshape(1, n_tokens, state_dim)
    prompt_actions = per_step[:, 1 + state_dim:].clone().reshape(1, n_tokens, act_dim)
    return prompt_states, prompt_actions, prompt_rtgs


def _split_prompt_states(prompt_states, n_segments, segment_len):
    """Cut batched prompt states ``(1, J*H, state_dim)`` into one numpy array per segment."""
    return [
        prompt_states[0, k * segment_len:(k + 1) * segment_len].cpu().numpy()
        for k in range(n_segments)
    ]


def _run_prompted_episode(model, env, use_state_dims, rtg_target, rtg_scale, act_dim, state_dim,
                          context_len, state_mean, state_std, max_ep_len, device,
                          prompt_timesteps, prompt_states, prompt_actions, prompt_rtgs):
    """Roll out one episode with the model conditioned on the given (batched) prompt.

    Returns ``(reward_sum, sparse_reward_sum, pos_hist, done)`` where ``done`` is
    False when the episode was cut off by ``max_ep_len``.
    """
    eval_batch_size = 1
    # Observations are normalized unless args.norm_obs exists and parses as false.
    normalize_obs = not hasattr(args, "norm_obs") or parse_bool_arg(str(args.norm_obs))
    use_sparse_reward = parse_bool_arg(str(args.pdt_use_sparse_reward))

    obs, info = env.reset(options={
        "target_angle_rad": args.target_angle * np.pi,
        "target_radius": args.target_radius,
    })
    running_state = obs[use_state_dims]

    reward_sum = 0
    sparse_reward_sum = 0
    running_reward = 0
    pos_hist = []
    done = False

    # full-episode placeholders; the model only ever sees a context_len window
    running_rtg = rtg_target / rtg_scale
    actions = torch.zeros((eval_batch_size, max_ep_len, act_dim), dtype=torch.float32, device=device)
    states = torch.zeros((eval_batch_size, max_ep_len, state_dim), dtype=torch.float32, device=device)
    rewards_to_go = torch.zeros((eval_batch_size, max_ep_len, 1), dtype=torch.float32, device=device)
    timesteps = torch.arange(start=0, end=max_ep_len, step=1)
    timesteps = timesteps.repeat(eval_batch_size, 1).to(device)

    for t in range(max_ep_len):
        if normalize_obs:
            running_state = (running_state - state_mean) / state_std
        states[0, t] = torch.from_numpy(running_state).to(torch.float32).to(device)

        # the return-to-go shrinks by the (scaled) reward just received
        running_rtg = running_rtg - (running_reward / rtg_scale)
        rewards_to_go[0, t] = running_rtg

        if t < context_len:
            # context not full yet: feed the first window, read the prediction at step t
            window = slice(0, context_len)
            pred_idx = t
        else:
            # sliding window ending at step t: read the last prediction
            window = slice(t - context_len + 1, t + 1)
            pred_idx = -1

        # pure inference, so no autograd graph is needed
        with torch.no_grad():
            _, act_preds, _, _, _, _, _ = model.forward(
                timesteps=timesteps[:, window],
                states=states[:, window],
                actions=actions[:, window],
                returns_to_go=rewards_to_go[:, window],
                traj_prompt_timesteps=prompt_timesteps,
                traj_prompt_states=prompt_states,
                traj_prompt_actions=prompt_actions,
                traj_prompt_rtgs=prompt_rtgs
            )
        act = act_preds[0, pred_idx].detach()

        running_state, running_dense_reward, done, trunc, info = env.step(act.cpu().numpy())
        done = done or trunc

        if use_sparse_reward:
            running_reward = info["sparse_reward"]
        else:
            running_reward = running_dense_reward

        running_state = running_state[use_state_dims]
        reward_sum += running_reward
        sparse_reward_sum += info["sparse_reward"]
        pos_hist.append(info["agent_pos"])

        actions[0, t] = act

        if done:
            break

    return reward_sum, sparse_reward_sum, pos_hist, done


def do_online_rollouts(
        model,
        env,
        num_rollouts,
        use_state_dims,
        rtg_target,
        rtg_scale,
        act_dim,
        state_dim,
        context_len,
        state_mean,
        state_std,
        prompt_sampling_method,
        mab=None,
        max_ep_len=100,
        segments_raw=None,
        segments_features=None,
        segment_idxs=None,
        prompt_trajs=None,
        traj_idxs=None,
):
    """Run online rollouts while tuning the trajectory prompt of a prompting DT.

    For each rollout a prompt of ``model.traj_prompt_j`` segments with
    ``model.traj_prompt_h`` steps each is chosen by ``prompt_sampling_method``:

    * "eps_greedy" / "ucb" / "ts": segments are picked by the bandit ``mab``,
      which is then updated with ``sparse_reward_sum / 10`` as reward.
    * "random": segments are drawn uniformly with ``np.random.choice``.
    * "hillclimbing": the best prompt so far is perturbed with Gaussian noise
      whose scale is annealed over the rollouts.
    * "zoranksgd": ``args.zoranksgd_m`` noisy copies of the current prompt are
      evaluated online, ranked by sparse return and used for a rank-based
      gradient step with learning rate ``args.zoranksgd_eta``.

    Reads the module-level ``args`` (``target_angle``, ``target_radius``,
    ``norm_obs``, ``pdt_use_sparse_reward``, ``bandit_use_transformer_features``,
    ``zoranksgd_m``, ``zoranksgd_eta``). Observations are normalized with
    ``state_mean`` / ``state_std`` unless ``args.norm_obs`` exists and parses as
    false; this rule now applies to every episode (the final rollout previously
    required the attribute to exist).

    Args:
        model: Prompting Decision Transformer (``traj_prompt_j``,
            ``traj_prompt_h``, ``which_model`` and ``forward`` are used).
        env: Gymnasium-style environment whose ``info`` dict provides
            "sparse_reward" and "agent_pos".
        num_rollouts: Number of prompt-tuning rollouts.
        use_state_dims: Index applied to raw observations to select state dims.
        rtg_target: Target return, divided by ``rtg_scale`` before use.
        rtg_scale: Return-to-go scaling divisor.
        act_dim: Number of action dimensions.
        state_dim: Number of (selected) state dimensions.
        context_len: Context window length of the model.
        state_mean: Mean used for observation normalization.
        state_std: Standard deviation used for observation normalization.
        prompt_sampling_method: One of the methods listed above.
        mab: Bandit object (``take_action``, ``epsilon``, ``store_data``,
            ``train``); required for the bandit methods.
        max_ep_len: Maximum number of environment steps per episode.
        segments_raw: Raw per-segment features; its length defines the number
            of candidate segments.
        segments_features: Transformer per-segment features (bandit methods with
            ``args.bandit_use_transformer_features``).
        segment_idxs: ``(start, end)`` step range of each candidate segment.
        prompt_trajs: Trajectories the segments were cut from.
        traj_idxs: For each segment, its index into ``prompt_trajs``.

    Returns:
        ``(rollout_trajs, rollouts_rewards, mab_results, rollout_prompt_states)``.
        ``rollout_trajs`` and ``rollouts_rewards`` have one entry per rollout
        (agent positions / summed training reward). ``mab_results`` is a dict
        with "mab_losses", "mab_segment_t_hist" and "mab_epsilon_hist" for the
        bandit methods and None otherwise. ``rollout_prompt_states`` holds the
        per-segment prompt states of every evaluated prompt; for "zoranksgd"
        with a "traj_pdt" model this includes the noisy evaluation prompts, so
        it is longer than ``rollout_trajs``.

    Raises:
        ValueError: For an unknown ``prompt_sampling_method``, or when
            ``args.zoranksgd_m`` is below 2 (nothing to rank).
    """
    is_bandit = prompt_sampling_method in ("eps_greedy", "ucb", "ts")
    n_segments = model.traj_prompt_j
    segment_len = model.traj_prompt_h
    n_prompt_tokens = n_segments * segment_len
    eval_batch_size = 1
    device = get_device(use_cuda=True)

    rollout_trajs = []
    rollouts_rewards = []
    rollout_prompt_states = []
    mab_loss_hist = [[] for _ in range(n_segments)]
    mab_segment_t_hist = [[] for _ in range(n_segments)]
    mab_epsilon_hist = []

    # hill-climbing state
    best_hillclimbing_prompt = None
    best_hillclimbing_reward = -np.inf

    # ZORankSGD state: the current prompt is kept as one flat (1, P) vector so
    # that no axis can get lost between rollouts (e.g. when act_dim == 1)
    zorank_prompt_vec = None
    zorank_timesteps = None
    zorank_noise_dist = None

    def run_episode(p_timesteps, p_states, p_actions, p_rtgs):
        return _run_prompted_episode(
            model, env, use_state_dims, rtg_target, rtg_scale, act_dim, state_dim,
            context_len, state_mean, state_std, max_ep_len, device,
            p_timesteps, p_states, p_actions, p_rtgs,
        )

    for rollout_idx in range(num_rollouts):
        # noise scale annealed linearly from 1.1 towards 0.1 over the rollouts
        expl_std = ((num_rollouts + num_rollouts * 0.1) - rollout_idx) / (num_rollouts)
        selected_segment_idxs = None
        selected_segments_start_end = None

        if is_bandit or prompt_sampling_method == "random":
            if is_bandit:
                arm_features = segments_features if args.bandit_use_transformer_features else segments_raw
                selected_segment_idxs = mab.take_action(arm_features)
                mab_epsilon_hist.append(mab.epsilon)
            else:
                selected_segment_idxs = np.random.choice(len(segments_raw), n_segments)

            selected_segments_start_end = [segment_idxs[arm_idx] for arm_idx in selected_segment_idxs]
            traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs = _gather_prompt_tensors(
                selected_segment_idxs, segment_idxs, traj_idxs, prompt_trajs)
            assert traj_prompt_states.shape == (n_prompt_tokens, state_dim)

        elif prompt_sampling_method == "hillclimbing":
            if best_hillclimbing_prompt is None:
                # take a random initial prompt from the available segments
                seg_idx = np.random.choice(len(segments_raw), n_segments)
                traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs = _gather_prompt_tensors(
                    seg_idx, segment_idxs, traj_idxs, prompt_trajs)
                best_hillclimbing_prompt = [
                    traj_prompt_timesteps.clone(),
                    traj_prompt_states.clone(),
                    traj_prompt_actions.clone(),
                    traj_prompt_rtgs.clone()
                ]
            else:
                # mutate the best prompt with Gaussian noise (timesteps stay untouched)
                traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs = [
                    tensor.clone() for tensor in best_hillclimbing_prompt
                ]
                traj_prompt_states = traj_prompt_states + torch.randn_like(traj_prompt_states) * expl_std
                traj_prompt_actions = traj_prompt_actions + torch.randn_like(traj_prompt_actions) * expl_std
                traj_prompt_rtgs = traj_prompt_rtgs + torch.randn_like(traj_prompt_rtgs) * expl_std

        elif prompt_sampling_method == "zoranksgd":
            B = 1  # batch size for ZORankSGD is just one
            m = args.zoranksgd_m  # num of samples per step for ZORankSGD
            eta = args.zoranksgd_eta  # learning rate with estimated gradient
            # NOTE: the noise scale follows the annealing schedule, so
            # args.zoranksgd_mu is not used (as in the previous implementation).
            mu = expl_std
            if m < 2:
                # with a single sample there are no pairs to rank (division by zero below)
                raise ValueError("ZORankSGD needs args.zoranksgd_m >= 2 to rank noisy prompts")

            prompt_size = n_prompt_tokens * (state_dim + act_dim + 1)

            if zorank_prompt_vec is None:
                print("Sampling initial random prompt for ZORankSGD")
                # take a random initial prompt from the available segments
                seg_idx = np.random.choice(len(segments_raw), n_segments)
                init_timesteps, init_states, init_actions, init_rtgs = _gather_prompt_tensors(
                    seg_idx, segment_idxs, traj_idxs, prompt_trajs)
                zorank_timesteps = init_timesteps.reshape(eval_batch_size, n_prompt_tokens).to(device)
                zorank_prompt_vec = _flatten_prompt(
                    init_rtgs, init_states, init_actions, n_prompt_tokens, state_dim, act_dim).to(device)
                # standard normal over the flattened prompt; built once and reused
                zorank_noise_dist = torch.distributions.MultivariateNormal(
                    torch.zeros(prompt_size, device=device),
                    torch.eye(prompt_size, device=device)
                )

            xi = zorank_noise_dist.sample((B, m))  # sample noise vectors
            pts = zorank_prompt_vec.unsqueeze(1) + mu * xi  # make noised versions of the prompt
            assert pts.shape == (B, m, prompt_size)

            # ZORankSGD online requires evaluating the m noisy prompts by online return
            noise_scores = []
            for noisy_idx in range(m):
                noisy_states, noisy_actions, noisy_rtgs = _unflatten_prompt(
                    pts[0, noisy_idx], n_prompt_tokens, state_dim, act_dim)

                eval_reward_sum, eval_sparse_sum, eval_pos_hist, eval_done = run_episode(
                    zorank_timesteps, noisy_states, noisy_actions, noisy_rtgs)
                if eval_done:
                    print(f"Eval prompt {noisy_idx}, reward: {eval_reward_sum}, len: {len(eval_pos_hist) - 1}",  flush=True)

                noise_scores.append(eval_sparse_sum)

                if model.which_model == "traj_pdt":
                    # the noisy evaluation prompts are recorded as well
                    rollout_prompt_states.append(
                        _split_prompt_states(noisy_states, n_segments, segment_len))

            # rank according to oracle = episode return (best first)
            s_vals, s_idx = torch.sort(torch.tensor(noise_scores).unsqueeze(0), descending=True)
            assert s_idx.shape == (B, m)
            s_idx = s_idx.to(device)

            edges_count = m * (m - 1) // 2
            # best-ranked noise gets weight -(m-1), worst +(m-1), so subtracting
            # the estimate moves the prompt towards the better-ranked samples
            weight = 2 * torch.arange(m, device=device) - (m - 1)
            ranked_xi = torch.gather(xi, 1, s_idx.unsqueeze(-1).expand(-1, -1, prompt_size))
            s = (ranked_xi * weight.view(1, m, 1)).sum(dim=1)
            g = s / edges_count
            zorank_prompt_vec = zorank_prompt_vec - eta * g  # update the prompt with estimated gradient

            traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs = _unflatten_prompt(
                zorank_prompt_vec, n_prompt_tokens, state_dim, act_dim)
            traj_prompt_timesteps = zorank_timesteps

        else:
            raise ValueError("Invalid prompt sampling method")

        # bring the prompt into batched form on the evaluation device
        traj_prompt_states = traj_prompt_states.reshape(eval_batch_size, n_prompt_tokens, state_dim).to(device)
        traj_prompt_actions = traj_prompt_actions.reshape(eval_batch_size, n_prompt_tokens, act_dim).to(device)
        traj_prompt_rtgs = traj_prompt_rtgs.reshape(eval_batch_size, n_prompt_tokens, 1).to(device)
        traj_prompt_timesteps = traj_prompt_timesteps.reshape(eval_batch_size, n_prompt_tokens).to(device)

        # do rollout using the selected prompt
        reward_sum, sparse_reward_sum, pos_hist, _ = run_episode(
            traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs)

        rollout_trajs.append(pos_hist)
        rollouts_rewards.append(reward_sum)
        rollout_prompt_states.append(
            _split_prompt_states(traj_prompt_states, n_segments, segment_len))

        print(f"Rollout {rollout_idx}, len: {len(pos_hist)}, reward: {sparse_reward_sum}", flush=True)

        if is_bandit:
            arm_features = segments_features if args.bandit_use_transformer_features else segments_raw
            store_segments = [arm_features[segment_update_idx] for segment_update_idx in selected_segment_idxs]
            for j_idx, (start, end) in enumerate(selected_segments_start_end):
                # remember the temporal midpoint of the segment chosen for slot j
                mab_segment_t_hist[j_idx].append((start + end) / 2)

            mab.store_data(store_segments, sparse_reward_sum / 10)

            losses = mab.train()
            if losses is not None:
                for loss_idx, loss in enumerate(losses):
                    mab_loss_hist[loss_idx].append(loss)

        if prompt_sampling_method == "hillclimbing":
            if sparse_reward_sum > best_hillclimbing_reward:
                best_hillclimbing_reward = sparse_reward_sum
                best_hillclimbing_prompt = [
                    traj_prompt_timesteps.clone(),
                    traj_prompt_states.clone(),
                    traj_prompt_actions.clone(),
                    traj_prompt_rtgs.clone()
                ]
                print(f"New best hillclimbing reward: {best_hillclimbing_reward}")

    if is_bandit:
        mab_results = {
            "mab_losses": mab_loss_hist,
            "mab_segment_t_hist": mab_segment_t_hist,
            "mab_epsilon_hist": mab_epsilon_hist
        }
    else:
        mab_results = None

    return rollout_trajs, rollouts_rewards, mab_results, rollout_prompt_states


def main(args):
    """Run online prompt-tuning rollouts for a pre-trained prompting Decision Transformer.

    Steps, in order:
      1. Seed ``random``, ``numpy`` and ``torch`` and create the gym environment.
      2. Load the state normalisation statistics and the stored training
         configuration (``args.txt``) from ``args.load_path``. Every key in
         ``args.txt`` except ``seed`` is written onto ``args`` as a *string*,
         overriding any CLI value of the same name.
      3. Rebuild the ``DecisionTransformer`` and load its checkpoint.
      4. Locate the single dataset whose angle/radius (parsed from its path)
         matches ``args.target_angle`` / ``args.target_radius``.
      5. Extract all candidate prompt segments, optionally create a bandit,
         and run ``do_online_rollouts``.
      6. Pickle the results into a new directory below ``args.load_path`` and
         optionally plot the returns.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` section).
            It is mutated in place by step 2.

    Returns:
        Path of the written ``results.pkl`` file.

    Raises:
        ValueError: If ``args.txt`` contains a non-blank line without the
            ``": "`` separator, if no dataset matches the requested task, or
            if ``args.sampling_method`` is not a known method.
    """
    import pickle

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    device = get_device(use_cuda=bool(args.cuda))

    # NOTE: env and paths are read from the CLI values *before* args.txt overrides `args`.
    env = gym.make(args.env)
    env.reset(seed=int(args.seed))
    act_dim = env.action_space.shape[0]

    load_path = args.load_path
    model_local_path = f"{load_path}/{args.model_file}"
    args_local_path = f"{load_path}/args.txt"

    state_mean = np.load(f"{load_path}/state_mean.npy")
    state_std = np.load(f"{load_path}/state_std.npy")

    # Overlay the stored training configuration onto `args` (values stay strings).
    with open(args_local_path, 'r') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            # Split on the first separator only, so values may themselves contain ": ".
            key, sep, val = line.partition(": ")
            if not sep:
                raise ValueError(f"Malformed line in {args_local_path}: {line!r}")
            if key == "seed":
                continue  # keep the seed requested for this run
            setattr(args, key, val.strip())

    context_len = int(args.context_len)
    use_dims = args.use_state_dims.replace("[", "").replace("]", "").replace("'", "").split(", ")
    use_dims = [int(dim) for dim in use_dims]
    state_dim = len(use_dims)
    traj_prompt_j = int(args.traj_prompt_j)
    traj_prompt_h = int(args.traj_prompt_h)
    num_rollouts = int(args.num_rollouts)  # may have been overwritten with a string by args.txt

    assert args.model == "traj_pdt", "Prompt tuning requires an underlying PROMPTING (!) DT model."
    model = DecisionTransformer(
        state_dim=state_dim,
        act_dim=act_dim,
        n_blocks=int(args.n_blocks),
        h_dim=int(args.embed_dim),
        context_len=context_len,
        n_heads=int(args.n_heads),
        transformer_drop_p=float(args.transformer_dropout_p),
        mlp_drop_p=float(args.pred_mlp_dropout_p),
        mlp_num_layers=int(args.pred_mlp_num_layers),
        which_model=args.model,
        traj_prompt_j=traj_prompt_j,
        traj_prompt_h=traj_prompt_h,
    ).to(device)

    # SECURITY NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    model.load_state_dict(torch.load(model_local_path, weights_only=False, map_location=torch.device('cpu')))
    model.eval()

    dataset_dirs = args.dataset_dirs.replace("[", "").replace("]", "").replace("'", "").split(", ")

    target_dataset = None
    for dataset_path in dataset_dirs:
        # The task parameters are encoded in the path as "...CircleStopEnv<...>angle<A>_<...>radius<R>-v0...".
        dataset_env_name = dataset_path.split("CircleStopEnv")[1]
        dataset_env_name = dataset_env_name.split("-v0")[0]
        angle = dataset_env_name.split("angle")[1]
        angle = float(angle.split("_")[0])
        radius = float(dataset_env_name.split("radius")[1])

        if angle != float(args.target_angle) or radius != float(args.target_radius):
            continue

        assert target_dataset is None, "Seems there are multiple datasets with same angle and radius (aka for one task), that shouldn't be the case."

        # NOTE(review): the attribute name is spelled "traj_prmopt_noise_scale"; confirm this matches
        # the key written to args.txt, otherwise the noise scale silently stays 0.0.
        noise_scale = float(args.traj_prmopt_noise_scale) if hasattr(args, "traj_prmopt_noise_scale") else 0.0
        target_dataset = TrajectoryDataset(
            dataset_path=dataset_path + "/trajectories.pkl",
            context_len=context_len,
            rtg_scale=float(args.rtg_scale),
            traj_prompt_j=traj_prompt_j,
            traj_prompt_h=traj_prompt_h,
            use_state_dims=use_dims,
            traj_prompt_noise_scale=noise_scale,
            use_sparse_reward=parse_bool_arg(str(args.pdt_use_sparse_reward))
        )

        if parse_bool_arg(str(args.norm_obs)):
            target_dataset.state_mean = state_mean
            target_dataset.state_std = state_std
            target_dataset.normalize_states()

    if target_dataset is None:
        raise ValueError(
            f"No dataset in {dataset_dirs} matches target_angle={args.target_angle} "
            f"and target_radius={args.target_radius}."
        )

    # extract all possible prompt segments
    # NOTE(review): rtg_scale is parsed with int() here but with float() for the dataset above;
    # a non-integer value in args.txt would fail here — confirm which is intended.
    segments_raw, segments_features, segment_idxs, traj_idxs, prompt_trajs = get_all_prompt_segments(
        target_dataset=target_dataset,
        model=model,
        state_mean=state_mean,
        state_std=state_std,
        rtg_scale=int(args.rtg_scale),
        state_dim=state_dim,
        act_dim=act_dim,
        context_len=context_len,
        device=device
    )

    # init bandit
    # A raw segment holds one state, one action and one scalar rtg per prompt step.
    segment_length = traj_prompt_h * (state_dim + act_dim + 1)
    if args.bandit_use_transformer_features:
        bandit_feature_dim = model.h_dim
    else:
        bandit_feature_dim = segment_length

    if args.sampling_method in ["random", "hillclimbing", "zoranksgd"]:
        mab = None
    else:
        bandit_classes = {
            "eps_greedy": NeuralBanditEpsGreedy,
            "ucb": NeuralBanditUCB,
            "ts": ThompsonSamplingBandit,
        }
        bandit = bandit_classes.get(args.sampling_method)
        if bandit is None:
            raise ValueError("Invalid sampling method")

        mab = bandit(
            input_dim=bandit_feature_dim,
            segments_per_prompt=traj_prompt_j,
            device=device,
            hidden_size=16,
            segment_length=segment_length,
            epsilon=args.epsilon,
            num_rollouts=num_rollouts
        )

    # do online rollouts
    trajectories, rewards, mab_results, prompt_states = do_online_rollouts(
        model=model,
        env=env,
        num_rollouts=num_rollouts,
        use_state_dims=use_dims,
        rtg_target=int(args.rtg_target),
        rtg_scale=int(args.rtg_scale),
        act_dim=act_dim,
        state_dim=state_dim,
        context_len=context_len,
        state_mean=state_mean,
        state_std=state_std,
        prompt_sampling_method=args.sampling_method,
        mab=mab,
        max_ep_len=int(args.max_eval_ep_len),
        segments_raw=segments_raw,
        segments_features=segments_features,
        segment_idxs=segment_idxs,
        prompt_trajs=prompt_trajs,
        traj_idxs=traj_idxs,
    )

    time_stamp = datetime.now().strftime("%d-%m-%y_%H-%M-%S")
    # Tag the run with the bandit input type. Only bandit methods (mab is not None) can use
    # transformer features; the previous check against the non-existent method name 'mab'
    # could never be true, so every run was labelled 'SegmentsRaw'.
    if mab is not None and args.bandit_use_transformer_features:
        feature_tag = "SegTFeatures"
    else:
        feature_tag = "SegmentsRaw"
    results_dir = (
        f"{load_path}/{args.sampling_method}_{feature_tag}"
        f"_targetRad{args.target_radius}_targetAngle{args.target_angle}"
        f"_{args.model_file}_EXP:{time_stamp}_seed:{args.seed}"
    )
    os.makedirs(results_dir, exist_ok=True)

    # save results to disk
    results = {
        "trajectories": trajectories,
        "rewards": rewards,
        "prompt_states": prompt_states,
        "mab_results": mab_results,
        "seed": args.seed,
        "state_mean": state_mean,
        "state_std": state_std,
    }

    results_save_file = f"{results_dir}/results.pkl"
    with open(results_save_file, "wb") as f:
        pickle.dump(results, f)

    if not args.hide_plots:
        import matplotlib.pyplot as plt
        plt.plot(rewards, "-o")
        plt.axhline(10, label="optimal", ls="--", c="k")
        plt.xlabel("Rollouts")
        plt.ylabel("Return")
        plt.show()

    return results_save_file


if __name__ == "__main__":
    # Command-line entry point: parse the arguments, run main() and print the path of the
    # written results file.
    # NOTE: `args` must stay a module-level name, because get_all_prompt_segments reads
    # `args.data_mixture` from module scope instead of taking it as a parameter.
    parser = argparse.ArgumentParser()

    # which PDT model to load
    parser.add_argument("--load_path", type=str, help="Path to dir to load model (and model config) from", required=True)
    parser.add_argument('--model_file', type=str, help="The model CP file to load", required=True)

    # which task to tune on
    parser.add_argument('--target_angle', type=float, default=0.0)
    parser.add_argument('--target_radius', type=float, default=1.9)
    parser.add_argument('--env', type=str, default='CircleStopEnv-randomAngle_randomRad-v0')
    parser.add_argument('--rtg_target', type=int, default=10)
    parser.add_argument('--max_eval_ep_len', type=int, default=100)
    parser.add_argument('--num_rollouts', type=int, default=250)

    # tuning args
    parser.add_argument(
        '--sampling_method', type=str, default="ts",
        # Restricting the choices makes a typo fail at parse time instead of after model loading.
        choices=["random", "eps_greedy", "ucb", "ts", "hillclimbing", "zoranksgd"],
        help="The prompt sampling method. 'random' for standard PDT without prompt tuning. 'eps_greedy', 'ucb', or 'ts' for bandit prompt-tuning with respective exploration. 'hillclimbing' or 'zoranksgd' for the other baselines.")
    parser.add_argument('--bandit_use_transformer_features', action='store_true', dest='bandit_use_transformer_features', default=False, help="Whether to use transformer features of segments for MAB")
    parser.add_argument('--epsilon', type=str, default="decay", help="Epsilon for MAB, either 'decay' or a float (string) value for a fixed epsilon")
    parser.add_argument('--data_mixture', type=str, default="mixture-100percent-expert", help="For prompt quality experiment. Either 'expert' to sample from all expert demonstrations, or 'mixture-{INT}percent-expert', where {INT} can be replaced with 10, ..., 100.")
    parser.add_argument('--zoranksgd_m', type=int, default=5)
    parser.add_argument('--zoranksgd_mu', type=float, default=1)
    parser.add_argument('--zoranksgd_eta', type=float, default=0.1)

    # other
    parser.add_argument('--seed', type=int, default=1)
    # The three switches below default to True, so the original store_true flags alone could
    # never turn them off. Each gets a complementary store_false flag writing to the same dest.
    parser.add_argument('--cuda', action='store_true', default=True, help="Use CUDA (default)")
    parser.add_argument('--no_cuda', action='store_false', dest='cuda', help="Do not use CUDA")
    parser.add_argument('--pdt_use_sparse_reward', action='store_true', default=True, help="Use the sparse reward (default)")
    parser.add_argument('--no_pdt_use_sparse_reward', action='store_false', dest='pdt_use_sparse_reward', help="Do not use the sparse reward")
    parser.add_argument('--hide_plots', action='store_true', default=True, help="Do not show the return plot (default)")
    parser.add_argument('--show_plots', action='store_false', dest='hide_plots', help="Show the return plot after the rollouts")

    args = parser.parse_args()

    results_save_file = main(args)
    print(results_save_file)
