from collections import defaultdict
import torch
import time

#from scripts.plan import start_time

REWARD_DIM = VALUE_DIM = 1
#from .mcts import *
#from .mcts_beam import *
from .mcts_expand import *
import networkx as nx
import matplotlib.pyplot as plt
#from collections import Counter
import torch.nn.functional as F
import time

@torch.no_grad()
def model_rollout_continuous(model, x, latent, denormalize_rew, denormalize_val, discount, prob_penalty_weight=1e4):
    """Score continuous latent plans by decoding them into trajectories.

    The decoder is conditioned on the observation part of the LAST timestep of
    ``x``. Each decoded transition is read as ``[..., reward, value, terminal]``
    (channels ``-3``, ``-2`` and ``-1``).

    Args:
        model: autoencoder exposing ``decode(latent, state)``,
            ``observation_dim`` and ``transition_dim``.
        x: tensor indexed as ``x[:, -1, :observation_dim]``; its first
            dimension is the number of candidate plans.
        latent: latent plans; the penalty averages ``latent ** 2`` over the
            last dimension.
        denormalize_rew: optional callable mapping normalized rewards to raw
            scale (``None`` leaves rewards unchanged).
        denormalize_val: optional callable mapping normalized values to raw
            scale (``None`` leaves values unchanged).
        discount: per-step discount factor.
        prob_penalty_weight: weight of the latent-magnitude penalty.

    Returns:
        ``(objective, prediction)`` as NumPy arrays:
        - ``objective`` is the per-plan discounted return minus the latent penalty.
        - ``prediction`` is the decoded transitions flattened to
          ``[-1, transition_dim]``.
    """
    prediction = model.decode(latent, x[:, -1, :model.observation_dim])
    prediction = prediction.reshape([-1, model.transition_dim])

    batch = x.shape[0]
    r_t, V_t = prediction[:, -3], prediction[:, -2]
    if denormalize_rew is not None:
        r_t = denormalize_rew(r_t)
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
    # Reshape unconditionally: the per-step indexing below needs [batch, T]
    # even when no denormalizer is supplied.
    r_t = r_t.reshape([batch, -1])
    V_t = V_t.reshape([batch, -1])

    # Discounts with terminal flag. Weights start at ``discount`` for the first
    # step. Once a step is flagged terminal, its weight and all later weights
    # become zero.
    terminal = prediction[:, -1].reshape([batch, -1])
    discounts = torch.cumprod(torch.ones_like(r_t) * discount * (1 - terminal), dim=-1)
    values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
    prob_penalty = prob_penalty_weight * torch.mean(torch.square(latent), dim=-1)
    objective = values - prob_penalty
    return objective.cpu().numpy(), prediction.cpu().numpy()


import numpy as np


@torch.no_grad()
def sample(model, x, denormalize_rew, denormalize_val, discount, steps, nb_samples=4096, rounds=8):
    """Random-shooting planner over discrete latent codes.

    Draws ``nb_samples`` code sequences uniformly from the whole codebook
    ``[0, model.model.K)``. It decodes them from the first observation of
    ``x`` and returns the decoded trajectory with the highest discounted
    return. The best return is printed.

    Args:
        model: exposes ``model.K``, ``latent_step``, ``observation_dim``,
            ``transition_dim`` and ``decode_from_indices(indices, state)``.
            Decoded transitions are read as ``[..., reward, value, terminal]``.
        x: observations indexed as ``x[:, 0, :observation_dim]``.
        denormalize_rew: optional callable mapping normalized rewards to raw
            scale (``None`` skips it).
        denormalize_val: optional callable mapping normalized values to raw
            scale (``None`` skips it).
        discount: per-step discount factor.
        steps: horizon; ``steps // model.latent_step`` codes per sequence.
        nb_samples: number of random sequences.
        rounds: accepted for signature compatibility; not used.

    Returns:
        The best decoded trajectory (``prediction_raw[best]``) as a NumPy array.
    """
    n_latent = steps // model.latent_step
    # torch.randint's upper bound is exclusive: ``K`` (not ``K-1``) is needed
    # so that every code 0..K-1 can be drawn.
    indices = torch.randint(0, model.model.K, size=[nb_samples, n_latent],
                            device=x.device, dtype=torch.int32)
    prediction_raw = model.decode_from_indices(indices, x[:, 0, :model.observation_dim])
    prediction = prediction_raw.reshape([-1, model.transition_dim])

    r_t, V_t = prediction[:, -3], prediction[:, -2]
    if denormalize_rew is not None:
        r_t = denormalize_rew(r_t)
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
    # Reshape unconditionally so [:, :-1] / [:, -1] indexing works without denormalizers.
    r_t = r_t.reshape([indices.shape[0], -1])
    V_t = V_t.reshape([indices.shape[0], -1])

    discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
    values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
    optimal = prediction_raw[values.argmax()]
    print(values.max().item())
    return optimal.cpu().numpy()


@torch.no_grad()
def sample_with_prior(prior, model, x, denormalize_rew, denormalize_val, discount, steps, nb_samples=4096, rounds=8,
                      likelihood_weight=5e2, prob_threshold=0.05, uniform=False, return_info=False):
    """Sample latent code sequences autoregressively from ``prior`` and pick the best.

    Sampling runs in ``rounds`` rounds of ``nb_samples // rounds`` sequences.
    In each round, the first latent step draws all sequences from the prior
    conditioned on ``state``, and every later step draws one code per
    sequence. Sequences are decoded with ``model.decode_from_indices`` and
    scored by their discounted return plus a likelihood bonus. The bonus is
    ``likelihood_weight`` times the accumulated prior log-probability, clamped
    to at most ``log(prob_threshold)`` per latent step.

    Args:
        prior: callable ``prior(context_or_None, state) -> (logits, _)``; logits
            are indexed as ``[B, t, K]``.
        model: exposes ``observation_dim``, ``transition_dim``, ``latent_step``
            and ``decode_from_indices``. Decoded transitions are read as
            ``[..., reward, value, terminal]``.
        x: observations indexed as ``x[:, 0, :observation_dim]``.
        denormalize_rew: optional callable mapping normalized rewards to raw
            scale (``None`` skips it).
        denormalize_val: optional callable mapping normalized values to raw
            scale (``None`` skips it).
        discount: per-step discount factor.
        steps: horizon; ``steps // model.latent_step`` codes per sequence.
        nb_samples: total number of sequences across all rounds.
        rounds: number of sampling rounds.
        likelihood_weight: weight of the likelihood bonus.
        prob_threshold: per-step probability cap used to clamp the bonus.
        uniform: if True, sample uniformly over codes with non-zero prior
            probability. The log-likelihood is still measured under the prior.
        return_info: if True, also return per-sequence diagnostics.

    Returns:
        The best decoded trajectory as a NumPy array. If ``return_info`` is
        True, also a dict of concatenated arrays with keys ``log_probs``,
        ``returns``, ``predictions``, ``objectives`` and ``latent_codes``.
    """
    state = x[:, 0, :model.observation_dim]
    n_latent = steps // model.latent_step
    samples_per_round = nb_samples // rounds
    # Upper clamp on the accumulated log-likelihood: beyond this level, extra
    # likelihood earns no further bonus.
    log_prob_cap = np.log(prob_threshold) * n_latent
    optimals = []
    optimal_values = []
    info = defaultdict(list)
    for _ in range(rounds):
        contex = None
        acc_probs = torch.zeros([1]).to(x)
        for step in range(n_latent):
            logits, _ = prior(contex, state)  # [B x t x K]
            probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
            # Log-likelihood is always taken under the prior, even when sampling uniformly.
            log_probs = torch.log(probs)
            if uniform:
                valid = probs > 0
                probs = valid / valid.sum(dim=-1)[:, None]
            n_draw = samples_per_round if step == 0 else 1
            samples = torch.multinomial(probs, num_samples=n_draw, replacement=True)  # [B, M]
            samples_prob = torch.gather(log_probs, 1, samples)  # [B, M]
            acc_probs = acc_probs + samples_prob.reshape([-1])
            if contex is None:
                contex = samples.reshape([-1, step + 1])  # [(B*M) x t]
            else:
                contex = torch.cat([contex, samples.reshape([-1, 1])], dim=1)
        prediction_raw = model.decode_from_indices(contex, state)
        prediction = prediction_raw.reshape([-1, model.transition_dim])

        n_seq = contex.shape[0]
        r_t = prediction[:, -3]
        V_t = prediction[:, -2]
        if denormalize_rew is not None:
            r_t = denormalize_rew(r_t)
        if denormalize_val is not None:
            V_t = denormalize_val(V_t)
        # Reshape unconditionally so the per-step indexing works without denormalizers.
        r_t = r_t.reshape([n_seq, -1])
        V_t = V_t.reshape([n_seq, -1])

        discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
        values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
        likelihood_bonus = likelihood_weight * torch.clamp(acc_probs, -1e5, log_prob_cap)
        info["log_probs"].append(acc_probs.cpu().numpy())
        info["returns"].append(values.cpu().numpy())
        info["predictions"].append(prediction_raw.cpu().numpy())
        info["objectives"].append(values.cpu().numpy() + likelihood_bonus.cpu().numpy())
        info["latent_codes"].append(contex.cpu().numpy())
        max_idx = (values + likelihood_bonus).argmax()
        optimals.append(prediction_raw[max_idx])
        optimal_values.append(values[max_idx].item())

    for key, val in info.items():
        info[key] = np.concatenate(val, axis=0)
    # NOTE(review): within a round the winner is chosen by return + bonus, but
    # across rounds only the raw return is compared. Kept as-is; confirm intent.
    max_idx = np.array(optimal_values).argmax()
    optimal = optimals[max_idx]
    print(f"predicted max value {optimal_values[max_idx]}")
    if return_info:
        return optimal.cpu().numpy(), info
    return optimal.cpu().numpy()


@torch.no_grad()
def sample_with_prior_tree(prior, model, x, denormalize_rew, denormalize_val, discount, steps, samples_per_latent=16, likelihood_weight=0.0):
    """Expand a sampling tree of latent codes under ``prior`` and pick the best leaf.

    At every latent step, each current sequence is extended by
    ``samples_per_latent`` codes drawn (with replacement) from the prior. This
    gives ``samples_per_latent ** n_steps`` leaves. Leaves are decoded and
    scored by discounted return plus ``likelihood_weight`` times the summed
    prior log-probability.

    Args:
        prior: callable ``prior(context_or_None, state) -> (logits, _)``; logits
            are indexed as ``[B, t, K]``.
        model: exposes ``observation_dim``, ``transition_dim``, ``latent_step``
            and ``decode_from_indices``. Decoded transitions are read as
            ``[..., reward, value, terminal]``.
        x: observations indexed as ``x[:, 0, :observation_dim]``.
        denormalize_rew: optional callable mapping normalized rewards to raw
            scale (``None`` skips it).
        denormalize_val: optional callable mapping normalized values to raw
            scale (``None`` skips it).
        discount: per-step discount factor.
        steps: horizon; ``steps // model.latent_step`` codes per sequence.
        samples_per_latent: branching factor per latent step.
        likelihood_weight: weight of the log-likelihood bonus.

    Returns:
        The best decoded trajectory as a NumPy array.
    """
    contex = None
    state = x[:, 0, :model.observation_dim]
    # Accumulate log-probabilities instead of products. A product of many
    # per-step probabilities underflows to 0 in float32, and
    # log(0) * likelihood_weight (0.0 by default) would be NaN, which argmax
    # would then select.
    acc_log_probs = torch.zeros([1]).to(x)
    for step in range(steps // model.latent_step):
        logits, _ = prior(contex, state)  # [B x t x K]
        last_logits = logits[:, -1, :]
        probs = torch.softmax(last_logits, dim=-1)  # [B x K]
        log_probs = torch.log_softmax(last_logits, dim=-1)  # [B x K]
        samples = torch.multinomial(probs, num_samples=samples_per_latent, replacement=True)  # [B, M]
        samples_log_prob = torch.gather(log_probs, 1, samples)  # [B, M]
        # Parent i's score is repeated for its M children, which are contiguous in samples.reshape(-1).
        acc_log_probs = acc_log_probs.repeat_interleave(samples_per_latent, 0) + samples_log_prob.reshape([-1])
        if contex is None:
            contex = samples.reshape([-1, step + 1])  # [(B*M) x t]
        else:
            contex = torch.cat([torch.repeat_interleave(contex, samples_per_latent, 0), samples.reshape([-1, 1])],
                               dim=1)

    prediction_raw = model.decode_from_indices(contex, state)
    prediction = prediction_raw.reshape([-1, model.transition_dim])

    n_seq = contex.shape[0]
    r_t, V_t = prediction[:, -3], prediction[:, -2]
    if denormalize_rew is not None:
        r_t = denormalize_rew(r_t)
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
    # Reshape unconditionally so the per-step indexing works without denormalizers.
    r_t = r_t.reshape([n_seq, -1])
    V_t = V_t.reshape([n_seq, -1])

    discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
    values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
    likelihood_bonus = likelihood_weight * acc_log_probs
    max_idx = (values + likelihood_bonus).argmax()
    optimal = prediction_raw[max_idx]
    acc_probs = torch.exp(acc_log_probs)
    print(f"predicted max value {values[max_idx]}, likelihood {acc_probs[max_idx]} with bouns {likelihood_bonus[max_idx]}")
    return optimal.cpu().numpy()


@torch.no_grad()
def build_full_sequences_step1(ctx_indices_all: torch.Tensor,  # [D,B,Tctx]
                               next1_ids: torch.Tensor         # [D,B,K]
                               ) -> torch.Tensor:
    """
    Extend every context row with each of its K step-(t+1) candidate codes.

    Row ``b`` of ``ctx_indices_all`` is copied K times (copies are adjacent).
    Copy ``k`` gets ``next1_ids[:, b, k]`` appended as a new last time position.

    Returns:
      Index tensor of shape [D, B*K, Tctx+1]; row ``b*K + k`` holds context
      ``b`` followed by candidate ``k``.
    """
    depth, batch, _ = ctx_indices_all.shape
    _, _, n_branch = next1_ids.shape

    # adjacent copies of each context row, one per candidate
    tiled_ctx = torch.repeat_interleave(ctx_indices_all, n_branch, dim=1)
    # flatten (B, K) in the same b-major order and make it a length-1 time slice
    new_column = next1_ids.reshape(depth, batch * n_branch, 1)
    return torch.cat((tiled_ctx, new_column), dim=2)


@torch.no_grad()
def build_full_sequences_step2(ctx_indices_all: torch.Tensor,  # [D,B,Tctx]
                               next1_ids: torch.Tensor,        # [D,B,K]
                               next2_ids: torch.Tensor         # [D,B,K,M]
                               ) -> "tuple[torch.Tensor, torch.Tensor]":
    """
    Make full sequences for steps t+1 and t+2 by appending (K, M) branches.

    Returns a pair ``(full1_flat, full2_flat)``:
      full1_flat: indices [D, B*K, Tctx+1]; context plus one step-1 candidate.
      full2_flat: indices [D, B*K*M, Tctx+2]; row ``(b*K + k)*M + m`` is row
        ``b*K + k`` of full1_flat followed by ``next2_ids[:, b, k, m]``.
    """
    D, B, Tctx = ctx_indices_all.shape
    _, _, K    = next1_ids.shape
    _, _, _, M = next2_ids.shape

    # First, extend by step-1 (K branches)
    full1_flat = build_full_sequences_step1(ctx_indices_all, next1_ids)     # [D, B*K, Tctx+1]

    # Now append step-2 inside each K branch (fan out by M):
    # repeat the step-1 extended sequences across M (adjacent copies).
    full1_rep  = full1_flat.repeat_interleave(M, dim=1)                     # [D, B*K*M, Tctx+1]
    # Flatten K,M candidates. reshape (not view) so non-contiguous inputs,
    # e.g. permuted or sliced tensors, are accepted instead of raising.
    next2_flat = next2_ids.reshape(D, B*K*M)                                # [D, B*K*M]
    full2_flat = torch.cat([full1_rep, next2_flat.unsqueeze(-1)], dim=2)    # [D, B*K*M, Tctx+2]
    return full1_flat, full2_flat



@torch.no_grad()
def decode_full_step1(vq_ae,
                      ctx_indices_all: torch.Tensor,  # [D,B,Tctx]
                      next1_ids: torch.Tensor,        # [D,B,K]
                      state: torch.Tensor             # [B, obs_dim]
                      ):
    """
    Decode every (context + one candidate latent) sequence in a single batched call.

    Args:
      vq_ae: model exposing ``decode_from_indices(indices, state)``.
      ctx_indices_all: context code indices, [D, B, Tctx].
      next1_ids: K candidate codes per batch row, [D, B, K].
      state: conditioning observation per batch row, [B, obs_dim].

    Returns:
      recon: decoder output regrouped as [B, K, T_dec_total, transition_dim].
    """
    _, batch, _ = ctx_indices_all.shape
    _, _, n_branch = next1_ids.shape

    # context + candidate sequences, row b*K + k -> (batch b, candidate k)
    sequences = build_full_sequences_step1(ctx_indices_all, next1_ids)  # [D, B*K, Tctx+1]

    # Tile each row's state K times so it lines up with the sequence rows above.
    obs_dim = state.size(-1)
    tiled_state = state.unsqueeze(1).expand(batch, n_branch, obs_dim).reshape(batch * n_branch, -1)

    decoded = vq_ae.decode_from_indices(sequences, tiled_state)  # [B*K, T_dec_total, D_out]
    steps_out, channels = decoded.size(1), decoded.size(2)
    return decoded.view(batch, n_branch, steps_out, channels).contiguous()


@torch.no_grad()
def decode_full_step2(vq_ae,
                      ctx_indices_all: torch.Tensor,  # [D,B,Tctx]
                      next1_ids: torch.Tensor,        # [D,B,K]
                      next2_ids: torch.Tensor,        # [D,B,K,M]
                      state: torch.Tensor             # [B, obs_dim]
                      ):
    """
    Decode all K*M (context + step-1 + step-2) sequences in a single batched call.

    Args:
      vq_ae: model exposing ``decode_from_indices(indices, state)``.
      ctx_indices_all: context code indices, [D, B, Tctx].
      next1_ids: step-1 candidates, [D, B, K].
      next2_ids: step-2 candidates per step-1 branch, [D, B, K, M].
      state: conditioning observation per batch row, [B, obs_dim].

    Returns:
      (recon, full1_flat, full2_flat):
        recon: decoder output regrouped as [B, K, M, T_dec_total, transition_dim].
        full1_flat: the [D, B*K, Tctx+1] step-1 index sequences.
        full2_flat: the [D, B*K*M, Tctx+2] step-2 index sequences that were decoded.
    """
    _, batch, _ = ctx_indices_all.shape
    _, _, n_first = next1_ids.shape
    _, _, _, n_second = next2_ids.shape

    full1_flat, full2_flat = build_full_sequences_step2(ctx_indices_all, next1_ids, next2_ids)

    # One state copy per (k, m) leaf, in the same b-major order as full2_flat.
    n_leaves = n_first * n_second
    tiled_state = state.unsqueeze(1).expand(batch, n_leaves, state.size(-1)).reshape(batch * n_leaves, -1)

    decoded = vq_ae.decode_from_indices(full2_flat, tiled_state)  # [B*K*M, T_dec_total, D_out]
    recon = decoded.view(batch, n_first, n_second, decoded.size(1), decoded.size(2)).contiguous()
    return recon, full1_flat, full2_flat



@torch.no_grad()
def beam_with_prior(prior, model, x, context_matrix, denormalize_macro, denormalize_val, normalize_val, discount, steps,
                    beam_width, n_expand, n_action, b_percent, action_percent,
                    pw_alpha, mcts_itr, prob_threshold=0.05, likelihood_weight=5e2, prob_acc="product", return_info=False):
    """Two-level latent tree expansion followed by MCTS to choose the next latent code.

    What the body does:
      1. Builds an initial latent context covering ``context_window_size`` (6)
         positions. If ``context_matrix`` is given, it is encoded with
         ``model.encode_runtime`` and left-padded with encoded all-zero inputs.
         Otherwise the whole window is encoded padding.
      2. For ``max_depth`` (2) levels, it:
         - calls ``prior.two_stage_expand`` to propose step-1 codes and step-2
           expansions;
         - decodes every branch with ``decode_full_step2`` and scores it;
         - records each expanded state in ``state_dict`` via ``store_value``.
         Only the top ``b_percent`` (level 0) / ``action_percent`` (level 1)
         branches are carried forward, after de-duplication with
         ``torch.unique``.
      3. Runs ``MCTS`` (provided by the star import of ``.mcts_expand``) for
         ``mcts_itr`` iterations over ``state_dict`` and prints Q-value
         statistics. It then decodes the chosen action appended to the
         initial context.

    Args (as used in this body):
        prior: provides ``observation_dim`` and ``two_stage_expand``.
        model: provides ``encode_runtime``, ``decode_from_indices``,
            ``observation_dim`` and ``transition_dim``.
        x: observations; the root state is ``x[:, 0, :prior.observation_dim]``.
        context_matrix: optional history passed to ``model.encode_runtime``.
        denormalize_macro, denormalize_val, normalize_val: scaling callables
            used for scoring and forwarded to ``MCTS``.
        beam_width: number of step-1 candidates at level 0 (``K_topk``).
        n_action: number of step-1 candidates at level 1.
        n_expand: step-2 samples per step-1 candidate (``M_samples``).
        b_percent, action_percent: fraction of branches kept at level 0 / 1.
        mcts_itr: number of MCTS search iterations.

    Returns:
        ``(prediction, predicted_states)``: the last decoded step for the best
        action as a NumPy array, and the second value returned by
        ``MCTS.best_action``.

    NOTE(review): ``discount``, ``steps``, ``pw_alpha``, ``prob_threshold``,
    ``likelihood_weight``, ``prob_acc`` and ``return_info`` are not read in this
    body; presumably kept for signature compatibility with other planners.
    """
    # Root observation (first timestep of x). `state` conditions the level-0
    # expansion, the final decode and MCTS; `initial_state` keys the root
    # entry of state_dict.
    state = x[:, 0, :prior.observation_dim]
    initial_state = x[:, 0, :prior.observation_dim]
    # Hashable key for a state tensor (flattened values as a tuple).
    def tensor_to_tuple(tensor):
        return tuple(tensor.cpu().numpy().flatten())
    # state key -> [scored expansion tensor, kept top-k indices, latent context];
    # handed to MCTS below.
    state_dict = {}
    # Record an expansion for `state` unless that state is already present
    # (first write wins).
    def store_value(state, action_matrix, index, context):
        state_key = tensor_to_tuple(state)
        if state_key not in state_dict:
            state_dict[state_key] = [action_matrix,index, context]
    #for step in range(steps//model.latent_step):
    # NOTE(review): redundant with the module-level `import time`.
    import time
    start = time.time()
    # Number of expansion levels performed before MCTS.
    max_depth = 2
    tree_gamma = 0.99
    # Exponent for tree_gamma when discounting the expansion value; presumably
    # the number of primitive actions per latent step — confirm.
    action_sequence = 3
    #batch_size = state.size(0)
    device = state.device
    # Number of latent positions kept as decoder context.
    context_window_size = 6
    # Initialize an empty context tensor
    # Starting with zeros as placeholders (or you could use a special start token index)
    #context = torch.zeros((batch_size, context_window_size), dtype=torch.long, device=device)
    pad_size = context_window_size
    # pad_context = torch.zeros(1, pad_size, model.transition_dim - 1, device=device)
    # one_tensor = torch.ones(1, pad_size, 1, device=device)
    # Encode the padding
    #pad_codes = model.encode(pad_context, one_tensor)
    #context = pad_codes
    #if context_matrix is not None and context_matrix.size(1) >= 6:
    if context_matrix is not None:
        #print(context_matrix.shape)
        #print(context_matrix.size(1))
        #if context_matrix.size(1) >= 6:
        # Real history is encoded with a zero flag column; padding (below) uses
        # ones. Presumably a terminal/padding flag — confirm against encode_runtime.
        zero_tensor = torch.zeros(1, context_matrix.size(1), 1, device=context_matrix.device)
        context_codes = model.encode_runtime(context_matrix, zero_tensor)
        # pad_size = context_window_size - context_codes.size(1)
        # zero_pad = torch.zeros(context_codes.size(0), pad_size, device=context_codes.device, dtype=context_codes.dtype)
        # context = torch.cat([zero_pad, context_codes], dim=1)
        # Calculate padding needed
        #print(context_codes.shape, context_matrix.shape)
        # The code sequence runs along dim 2 of context_codes (it is concatenated on dim=2 below).
        pad_size = context_window_size - context_codes.size(2)

        if pad_size > 0:
            # Create a padding context tensor
            pad_context = torch.zeros(1, pad_size, context_matrix.size(2), device=device)
            # Use ones instead of zeros for the second tensor
            one_tensor = torch.ones(1, pad_size, 1, device=device)
            # Encode the padding
            #print(pad_context.shape, one_tensor.shape)
            pad_codes = model.encode_runtime(pad_context, one_tensor)
            # Concatenate with the real context (padding goes first, i.e. left-pad)
            #print(pad_codes.shape, context_codes.shape)
            initial_context = torch.cat([pad_codes, context_codes], dim=2)
            #print(context)
        else:
            initial_context = context_codes

    else:
        # No history: the whole window is encoded padding.
        pad_size = context_window_size
        pad_context = torch.zeros(1, pad_size, model.transition_dim-1, device=device)
        one_tensor = torch.ones(1, pad_size, 1, device=device)
        # Encode the padding
        pad_codes = model.encode_runtime(pad_context, one_tensor)
        initial_context = pad_codes
        #print(context)
    #print("context", initial_context.shape)
    ctx_indices = initial_context
    #print("ctx_indices", ctx_indices.shape)
    #initial_context = initial_context[0,:,:]
    # Forwarded to prior.two_stage_expand; the number of step-1 candidates is
    # action_sampling * N_per_coarse (see nb_samples below).
    N_per_coarse = 4
    for step in range(max_depth):
        # Step-1 breadth: beam_width at the root level, n_action afterwards.
        action_sampling = beam_width if step == 0 else n_action
        if step == 0:
            (stage1_ids, stage1_lp), (stage2_ids, stage2_lp) = prior.two_stage_expand(
                model = model,
                ctx_indices_all=ctx_indices,
                state=state,
                K_topk=action_sampling,  # no-replacement breadth for step t+1
                N_per_coarse=N_per_coarse,
                M_samples=n_expand,
                coarse_temperature_stage1=2,
                fine_temperature_stage1=1,
                coarse_temperature_stage2=2,
                fine_temperature_stage2=1,
                deeper_policy_stage1="sample",
                deeper_policy_stage2="sample",  # set to "sample" if you want stochastic deeper depths
                topk_each=None
            )
            #print(stage2_ids)
            # 2) Decode (context + t+1) for all K branches in one call
            #recon_t1_full = decode_full_step1(model, initial_context, stage1_ids, state)  # [B, K, T_dec_total, D_out]
            #print("ctx_indices", stage1_ids.shape, stage1_lp.shape)
            action_probs_sampled = stage1_lp
            # 3) Decode (context + t+1 + t+2) for all K×M branches in one call
            recon_t2_full, full1_flat, full2_flat = decode_full_step2(model, ctx_indices, stage1_ids, stage2_ids,
                                              state)  # [B, K, M, T_dec_total, D_out]

            # Step-1 sequences (their last code is the candidate action) and full step-2 sequences.
            action_contex = full1_flat
            context = full2_flat
            #torch.set_printoptions(threshold=float('inf'))
            # best_ids, best_scores, best_recon = prior.propose_fine_for_topk_coarse_and_select(
            #     model,
            #     ctx_indices_all=ctx_indices,
            #     state=state,
            #     K_coarse=32, M_fine=10,
            #     temperature_coarse=1.0, temperature_fine=1.0,
            #     fine_policy="sample",  # "argmax" is okay too
            #     topk_each=3,  # set e.g. 64 to prune fine vocab
            #     score_mode="value_last",  # or "value_mean"
            #     value_index=0,  # your 'value' channel is 0
            #     terminal_index=-1,  # last channel is terminals
            #     measure_tail=False,  # useful if bottleneck == "attention"
            #     chunk_decode=None,  # or e.g. 1024 to cap VRAM
            #     return_all=False
            # )
            # #print(best_recon.shape)
            # return best_recon.squeeze(0).cpu().numpy(), 0
            #return best_action.unsqueeze(0).cpu().numpy(), 0
            # decode full sequence; slice the last 2 if that's all you need
            #recon_full = model.decode_from_indices(indices_full, state_batch2)  # [B', Tctx+2, D]
        else:
            #contex = None
            # Level-0 states were tiled across D before torch.unique; keep one copy.
            state_for_next_prior = state_for_next_prior[0]
            (stage1_ids, stage1_lp), (stage2_ids, stage2_lp) = prior.two_stage_expand(
                model=model,
                ctx_indices_all=context,
                state=state_for_next_prior,
                K_topk=action_sampling,  # no-replacement breadth for step t+1
                N_per_coarse=N_per_coarse,
                M_samples=n_expand,  # with-replacement samples for step t+2
                coarse_temperature_stage1=2,
                fine_temperature_stage1=1,
                coarse_temperature_stage2=2,
                fine_temperature_stage2=1,
                deeper_policy_stage1="sample",
                deeper_policy_stage2="sample",  # set to "sample" if you want stochastic deeper depths
                topk_each=None
            )
            #print("step2", stage1_ids.shape, stage2_ids.shape)
            action_probs_sampled = stage1_lp
            # 3) Decode (context + t+1 + t+2) for all K×M branches in one call
            recon_t2_full, full1_flat, full2_flat = decode_full_step2(model, context, stage1_ids, stage2_ids,
                                              state_for_next_prior)  # [B, K, M, T_dec_total, D_out]
            action_contex = full1_flat
            context = full2_flat
        # Number of step-1 branches per expanded state at this level.
        nb_samples = action_sampling*N_per_coarse
        if step == 0:
            D = context.shape[0]
            prediction_raw = recon_t2_full.squeeze(0)
            #print(prediction_raw.shape)
            # [nb_samples, n_expand, context_window_size+2, decoder channels]
            reshaped_prediction_raw = prediction_raw.view(nb_samples, n_expand, context_window_size+2, -1)
            # Layout of final_tensor along dim 3 (built below):
            #   [decoder channels..., step-1 prior log-prob (1), step-1 action code (D), constant ones (1)]
            #expanded_action_contex = action_contex[:, :, -1:].unsqueeze(1).unsqueeze(2).expand(nb_samples, n_expand,context_window_size + 2, 1)
            expanded_action_contex = action_contex[:, :, -1:]
            expanded_prior_probs = action_probs_sampled.reshape([-1, 1]).unsqueeze(2).unsqueeze(3).expand(nb_samples, n_expand, context_window_size + 2, 1)
            concatenated_tensor = torch.cat([reshaped_prediction_raw, expanded_prior_probs], dim=3)
            expanded_action_contex = expanded_action_contex.permute(1, 2, 0)
            expanded_action_contex = expanded_action_contex.unsqueeze(1).expand(-1, n_expand, context_window_size + 2, -1)
            #print("after unsqueeze", expanded_action_contex)
            concatenated_tensor = torch.cat([concatenated_tensor,expanded_action_contex], dim=3)

            # NOTE(review): assigning F here makes F a local for this whole
            # function, shadowing the module-level `torch.nn.functional as F`.
            # F is not otherwise used in this body.
            B, H, W, F = concatenated_tensor.shape
            C = 1  # how many channels you want to append
            # Constant placeholder for a disabled MSE/OOD term (the real computation is commented out in the level-1 branch).
            expanded_mse_loss = concatenated_tensor.new_ones(B, H, W, C)
            final_tensor = torch.cat([concatenated_tensor, expanded_mse_loss], dim=3)
            # Channel 0 at the last decoded position, denormalized; presumably a value estimate — confirm.
            expansion_values = denormalize_val(final_tensor[:, :, -1, 0])

            #expansion_values = final_tensor[:, :, -1, 0]
            # Channel 0 at position -2 of the first expansion, used without denormalization.
            action_values = final_tensor[:, 0, -2, 0].view(-1, 1)
            # Index -4-D from the end is the second-to-last decoder channel
            # (D action codes, the prob and the ones column follow it), at position -2.
            macro_values = final_tensor[:, :, -2, -4-D]
            expansion_values *= (tree_gamma ** action_sequence)
            #print(denormalize_macro(macro_values), expansion_values)

            # Combine in raw scale, then re-normalize so it is comparable to action_values.
            expansion_values += denormalize_macro(macro_values)

            expansion_values = normalize_val(expansion_values)
            #print(final_tensor.shape, expansion_values.shape, action_values.shape)
            #print(expansion_values, action_values, macro_values, expansion_values)
            mean_values = torch.cat((expansion_values, action_values), dim=1)
            #print(context)
            #print(mean_values)

            # Branch score: mean over its n_expand expansion values and its action value.
            mean_values = mean_values.mean(dim=1)
            k = int(mean_values.size(0)*b_percent) if int(mean_values.size(0)*b_percent) >=1 else 1
            values_with_b, index = torch.topk(mean_values, k)
            store_value(initial_state, final_tensor, index, action_contex)

            # Next-level states: channels 1..observation_dim at the last decoded position of kept branches.
            state_for_next_prior = final_tensor[index,:,-1,1:1+model.observation_dim]
            #print(context)
            history_contex = context.view(D, nb_samples, n_expand, -1)[:,index,:,:]
            history_contex = history_contex.view(D, -1, history_contex.size(-1))
            #print(history_contex)
            original_ctx_dtype = history_contex.dtype
            original_state_dtype = state_for_next_prior.dtype
            # #unique_tensor = torch.unique(state_for_next_prior, dim=1)
            # #print("Shape after unique along dim=1:", unique_tensor.shape)
            state_for_next_prior = state_for_next_prior.view(-1, state_for_next_prior.size(-1))
            state_for_next_prior = state_for_next_prior.unsqueeze(0).expand(history_contex.size(0), -1, -1)
            #print(state_for_next_prior.shape, history_contex)
            # De-duplicate (context, state) pairs jointly. torch.unique also sorts them,
            # so row order after this point differs from the expansion order.
            combined = torch.cat([history_contex, state_for_next_prior], dim=-1)
            unique_combined = torch.unique(combined, dim=1)
            # # Split the unique tensor back into the original components
            # [1:-1] drops the oldest code and the step-2 code, keeping the
            # window length while sliding it forward by the step-1 code.
            context = unique_combined[:,:, :history_contex.size(2)].to(original_ctx_dtype)[:, :, 1:-1]
            state_for_next_prior = unique_combined[:,:, history_contex.size(2):].to(original_state_dtype)
            #print(expanded_action_contex)
            print("history_contex depth 0", context.shape, state_for_next_prior.shape)
            # mcts_instance1 = MCTS(state, state_dict, tree_gamma, prior, model, int(n_action * action_percent), n_expand,
            #                      0, 0, D, denormalize_macro, denormalize_val, normalize_val)
            # # # print(state_dict.keys())
            # mcts_instance1.search(mcts_itr)
            # argmax_best_action, predicted_states = mcts_instance1.best_action()
            # argmax_best_action = argmax_best_action[:, None, None]
            # print("argmax_best_action", argmax_best_action)
        else:
            #decode_for_ood
            #state_for_next_prior_expanded = state_for_next_prior_expanded.repeat_interleave(n_expand, 0)
            #prediction_raw = recon_t2_full
            #prediction_raw = model.decode_for_ood(context, state_for_next_prior_expanded)
            #predicted_first_state = prediction_raw[:, -2,1:model.observation_dim+1]
            #decoded_state_compare = state_for_next_prior_expanded
            #mse_loss_per_element = F.mse_loss(predicted_first_state, decoded_state_compare, reduction='none')
            #mse_loss_per_example = mse_loss_per_element.mean(dim=1)
            #mse_loss_per_example = mse_loss_per_example.view(-1, nb_samples, n_expand)
            #print(mse_loss_per_example)
            #expanded_mse_loss = mse_loss_per_example.unsqueeze(3).unsqueeze(4).expand(-1, nb_samples, n_expand, context_window_size+2, 1)
            #print(expanded_mse_loss[0])
            #reshaped_prediction_raw = prediction_raw.view(-1,nb_samples, n_expand, context_window_size+2, model.observation_dim + action_sequence*model.action_dim+ 2)
            # Same scoring as level 0, but with an extra leading dimension over
            # the expanded states carried over from level 0.
            D = action_contex.shape[0]
            reshaped_prediction_raw = recon_t2_full
            expanded_action_contex = action_contex[:,:,-1:]
            #print(expanded_action_contex.shape)
            expanded_action_contex = expanded_action_contex.permute(1, 2, 0).view(-1, nb_samples, 1, D)
            #print(expanded_action_contex.shape)
            expanded_action_contex = expanded_action_contex.unsqueeze(2).expand(-1, nb_samples, n_expand, context_window_size + 2, D)
            #print("expanded_action_contex", recon_t2_full.shape, action_contex.shape, expanded_action_contex.shape)
            #print(action_probs_sampled.shape)
            expanded_prior_probs = action_probs_sampled.unsqueeze(2).unsqueeze(3).unsqueeze(4).expand(-1, nb_samples,
                                                                                             n_expand, context_window_size+2, 1)
            #print("reshaped_prediction_raw", reshaped_prediction_raw.shape)
            # expanded_action_contex = expanded_action_contex.unsqueeze(3).unsqueeze(4).expand(-1, nb_samples,
            #                                                                                  n_expand, context_window_size+2, 1)
            concatenated_tensor = torch.cat([reshaped_prediction_raw, expanded_prior_probs], dim=4)
            concatenated_tensor = torch.cat([concatenated_tensor,expanded_action_contex], dim=4)
            # NOTE(review): again rebinds the local F (see level 0).
            B, H, W, F, A = concatenated_tensor.shape
            # Constant ones: placeholder for the disabled MSE term above.
            expanded_mse_loss = concatenated_tensor.new_ones(B, H, W, F, 1)
            final_tensor = torch.cat([concatenated_tensor, expanded_mse_loss], dim=4)
            #print(final_tensor.shape)
            macro_values = final_tensor[:, :, :, -2, -4-D]
            expansion_values = denormalize_val(final_tensor[:, :, :, -1, 0])
            #expansion_values = final_tensor[:, :, :, -1, 0]
            action_values = final_tensor[:, :, 0, -2, 0].view(state_for_next_prior.shape[0], -1, 1)
            action_mse = final_tensor[:, :, 0, -2, -1]
            expansion_values *= (tree_gamma ** action_sequence)
            expansion_values += denormalize_macro(macro_values)
            expansion_values = normalize_val(expansion_values)
            mean_values = torch.cat((expansion_values, action_values), dim=2)
            # action_mse is multiplied by 0, so it currently has no effect on the score.
            mean_values = mean_values.mean(dim=2) -  0*action_mse
            #mean_values = action_probs_sampled.squeeze(-1) * mean_values

            #print("step 1", mean_values)
            #mean_values = mean_values.mean(dim=2)
            # Keep the top action_percent branches independently for every expanded state.
            k = max(int(mean_values.size(1) * action_percent), 1)
            all_selected_tensors = []
            all_selected_history = []
            values_with_b, index = torch.topk(mean_values, k)
            #print("context", context.shape, final_tensor.shape)
            history_contex = context.view(D, -1, nb_samples, n_expand, context_window_size+2)
            #print(final_tensor.shape, history_contex.shape, action_contex.shape)
            for i in range(state_for_next_prior.shape[0]):
                store_value(state_for_next_prior[i], final_tensor[i], index[i], history_contex[:,i,:,:,:])
                #print(len(state_dict.keys()))
                all_selected_tensors.append(final_tensor[i][index[i]])
                selected_history = history_contex[:,i,index[i],:,:]
                all_selected_history.append(selected_history)
            final_selected_state = torch.cat(all_selected_tensors, dim=0)
            #final_selected_state = final_selected_state.view(-1,final_selected_state.size(2), final_selected_state.size(3))
            final_selected_state = final_selected_state[:,:,-1,1:1+model.observation_dim]
            final_history = torch.cat(all_selected_history, dim=1)
            final_selected_state = final_selected_state.unsqueeze(0).expand(final_history.size(0), -1, -1,-1)
            final_selected_state = final_selected_state.view(final_selected_state.size(0), -1, final_selected_state.size(3))
            final_history = final_history.view(final_history.size(0), -1, final_history.size(3))
            #final_history = final_history.view(final_selected_state.size(0), -1)
            # history_contex = history_contex.view(-1, history_contex.size(-1))
            original_ctx_dtype = final_history.dtype
            original_state_dtype = final_selected_state.dtype
            # Joint de-duplication (sorted by torch.unique), as at level 0.
            combined = torch.cat([final_history, final_selected_state], dim=-1)
            unique_combined = torch.unique(combined, dim=1)
            # # Split the unique tensor back into the original components
            context = unique_combined[:,:, :final_history.size(2)].to(original_ctx_dtype)[:, :, 1:-1]
            state_for_next_prior = unique_combined[:,:, final_history.size(2):].to(original_state_dtype)
            #print("history_contex depth 1", context.shape, state_for_next_prior.shape)
    print("inference time,",time.time() - start)
    #print(tensor_to_tuple(state), state_dict)
    # MCTS comes from the star import of .mcts_expand; the meaning of its
    # positional arguments is defined there. D is the value from the last level.
    mcts_instance = MCTS(state, state_dict, tree_gamma, prior, model, int(n_action*action_percent), n_expand, 0, max_depth-1, D, denormalize_macro, denormalize_val, normalize_val)
    #print(state_dict.keys())
    start_time = time.time()
    mcts_instance.search(mcts_itr)
    values_list = list(mcts_instance.Qsa.values())
    # Stack the tensors into one tensor
    values_tensor = torch.stack(values_list)

    # Compute the mean and standard deviation
    # NOTE(review): torch.std of a single Q value is NaN; only used for printing.
    value_mean = torch.mean(values_tensor)
    value_std = torch.std(values_tensor)
    value_max = torch.max(values_tensor)
    value_min = torch.min(values_tensor)
    # Print the results
    print("Mean:", value_mean.item(), "Std:", value_std.item(), "Max:", value_max.item(), "Min:", value_min.item())
    # Stop the timer
    end_time = time.time()

    # Calculate the running time
    running_time = end_time - start_time
    print("search time,", running_time)
    #print(context.shape, best_action.view(1, -1).shape)
    best_action, predicted_states = mcts_instance.best_action()
    # Add two trailing singleton dims so it can be concatenated onto initial_context along the last dim.
    best_action = best_action[:, None, None]
    print("best_action,", best_action)
    #print(model.decode_from_indices(torch.cat((initial_context, best_action.long()), dim=-1), state).shape)
    #prediction_raw = model.decode_from_pos(best_action.view(1, -1), state, context_window_size).squeeze(0)
    #print("best_action,", best_action, best_action.shape)
    #print(best_action)
    # Decode initial context + chosen code and keep only the last decoded position.
    prediction_raw = model.decode_from_indices(torch.cat((initial_context, best_action.long()), dim=-1), state)[:,-1:,:].squeeze(0)
    #print("prediction_raw,", prediction_raw.shape, prediction_raw)
    return prediction_raw.cpu().numpy(), predicted_states





@torch.no_grad()
def beam_with_prior_contex(prior, model, x, context_matrix, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand, n_action, b_percent, action_percent,
                    pw_alpha, mcts_itr, prob_threshold=0.05, likelihood_weight=5e2, prob_acc="product", return_info=False):
    """Choose the next latent action by expanding a shallow tree with the prior, then running MCTS.

    Steps:
      1. Optionally encode ``context_matrix`` into latent codes, left-padded with
         zeros to ``context_window_size`` codes (all zeros when it is None).
      2. Sample candidate action codes from ``prior``. For each one, sample
         ``n_expand`` follow-up codes.
      3. Decode the sequences with ``model``. Score each candidate by the mean of
         its discounted expansion values and its action value, minus a
         state-reconstruction MSE penalty.
      4. Record the scored expansions and the top-k indices in ``state_dict``,
         keyed by state.
      5. Build ``MCTS`` (from ``.mcts_expand``) on ``state_dict``, search, and
         decode the action it selects.

    Args:
        prior: latent prior, called as ``prior(codes_or_None, state)``. Must
            return ``(logits, _)``; ``logits[:, -1, :]`` is used as the
            next-code logits.
        model: latent model. Uses ``encode``, ``decode_from_indices``,
            ``decode_for_ood`` and ``observation_dim``.
        x: input tensor. ``x[:, 0, :prior.observation_dim]`` is the current state.
        context_matrix: trajectory to encode into context codes, or None.
        beam_width: number of distinct action codes sampled at the root.
        n_expand: follow-up codes sampled (with replacement) per action code.
        n_action: action codes sampled per node below the root. Also passed to
            MCTS as ``int(n_action*action_percent)``.
        b_percent: fraction of root candidates kept (at least one).
        action_percent: fraction of candidates kept below the root (at least one).
        mcts_itr: number of MCTS search iterations.

        denormalize_rew, denormalize_val, discount, steps, pw_alpha,
        prob_threshold, likelihood_weight, prob_acc and return_info are accepted
        but not read in this function.

    Returns:
        The output of ``model.decode_from_indices`` for the chosen action,
        leading dimension squeezed, as a numpy array.
    """
    state = x[:, 0, :prior.observation_dim]
    def tensor_to_tuple(tensor):
        # Hashable dict key for a state tensor: its flattened values on the CPU.
        return tuple(tensor.cpu().numpy().flatten())

    # Maps state key -> [expansion tensor, top-k candidate indices]; handed to MCTS below.
    state_dict = {}
    # Record the expansions for a state; only the first record per state is kept.
    def store_value(state, action_matrix, index):
        state_key = tensor_to_tuple(state)
        #print(state_key)
        if state_key not in state_dict:
            state_dict[state_key] = [action_matrix,index]
    #for step in range(steps//model.latent_step):
    import time  # redundant with the module-level import; kept as-is
    start = time.time()
    # Search hyper-parameters are hard-coded here rather than taken from the arguments.
    max_depth = 1  # with max_depth == 1 only the `step == 0` branches below run
    tree_gamma = 0.99
    action_sequence = 3  # expansion values are scaled by tree_gamma ** action_sequence
    mse_factor = 1  # weight of the state-reconstruction MSE penalty in the score
    batch_size = state.size(0)
    device = state.device
    context_window_size = 6

    # Initialize an empty context tensor
    # Starting with zeros as placeholders (or you could use a special start token index)
    context = torch.zeros((batch_size, context_window_size), dtype=torch.long, device=device)
    if context_matrix is not None:
        # Encode the given trajectory into latent codes and left-pad with zeros to
        # context_window_size codes.
        # NOTE(review): if encoding yields more than context_window_size codes, pad_size
        # is negative and torch.zeros raises. Confirm callers bound the context length.
        zero_tensor = torch.zeros(1, context_matrix.size(1), 1, device=context_matrix.device)
        context_codes = model.encode(context_matrix, zero_tensor)
        pad_size = context_window_size - context_codes.size(1)
        zero_pad = torch.zeros(context_codes.size(0), pad_size, device=context_codes.device, dtype=context_codes.dtype)
        context = torch.cat([zero_pad, context_codes], dim=1)
    #context_window
    for step in range(max_depth):
        # First prior call: distribution over the next action code given only the state.
        if step == 0:
            logits, _ = prior(None, state) # [B x t x K]
        else:
            #contex = None
            logits, _ = prior(None, state_for_next_prior)
        action_probs = torch.softmax(logits[:, -1, :], dim=-1) # [B x K]
        nb_samples = beam_width if step == 0 else n_action

        # Sample distinct candidate action codes.
        action_samples = torch.multinomial(action_probs, num_samples=nb_samples, replacement=False) # [B, M]
        # Gather the corresponding probabilities for the sampled actions
        action_probs_sampled = torch.gather(action_probs, 1, action_samples)
        action_contex = action_samples.reshape([-1, 1]) # [(B*M) x t]
        # Prepend the context window to every candidate: rows are [context..., action code].
        expanded_context = context.repeat_interleave(nb_samples, 0)
        action_contex = torch.cat([expanded_context, action_contex], dim=1)
        # Second prior call: distribution over the code that follows each candidate.
        if step == 0:
            logits, _ = prior(action_contex, state)
        else:
            #print(state_for_next_prior.shape)
            state_for_next_prior_expanded = state_for_next_prior.repeat_interleave(nb_samples, 0)
            logits, _ = prior(action_contex, state_for_next_prior_expanded)
        probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
        log_probs = torch.log(probs)  # not read later in this function
        samples = torch.multinomial(probs, num_samples=n_expand, replacement=True)  # [B, M]
        # Rows are [context..., action code, follow-up code], n_expand rows per candidate.
        contex = torch.cat([torch.repeat_interleave(action_contex, n_expand, 0), samples.reshape([-1, 1])], dim=1)
        if step == 0:
            prediction_raw = model.decode_from_indices(contex, state)
            # NOTE(review): this view and the expand(nb_samples, ...) calls below only work
            # for batch size 1, and assume decode_from_indices yields context_window_size+2
            # positions per sequence. Confirm.
            reshaped_prediction_raw = prediction_raw.view(nb_samples, n_expand, context_window_size+2, -1)
            #print("prediction_raw,",prediction_raw.shape, reshaped_prediction_raw.shape)
            #print(action_contex[:,-1:].shape, nb_samples, n_expand)
            # The sampled action code (last column of action_contex), broadcast per expansion/position.
            expanded_action_contex = action_contex[:,-1:].unsqueeze(1).unsqueeze(2).expand(nb_samples, n_expand, context_window_size+2, 1)
            #print("expanded_action_contex,",expanded_action_contex.shape)
            #print(expanded_action_contex.shape)
            # Compare the state decoded at position -2 (features 1:1+observation_dim) with the
            # current state. The per-sequence MSE is later subtracted from the candidate score.
            predicted_first_state = prediction_raw[:, -2, 1:1+model.observation_dim]
            decoded_state_compare = state.expand_as(predicted_first_state)

            mse_loss_per_element = F.mse_loss(predicted_first_state, decoded_state_compare, reduction='none')
            mse_loss_per_example = mse_loss_per_element.mean(dim=1)
            mse_loss_per_example = mse_loss_per_example.view(nb_samples, n_expand)
            expanded_mse_loss = mse_loss_per_example.unsqueeze(2).unsqueeze(3).expand(nb_samples, n_expand, context_window_size+2, 1)
            expanded_prior_probs = action_probs_sampled.reshape([-1, 1]).unsqueeze(2).unsqueeze(3).expand(nb_samples, n_expand, context_window_size+2, 1)
            # Last-dim layout of final_tensor: [decoded features..., prior prob, action code, mse].
            concatenated_tensor = torch.cat([reshaped_prediction_raw, expanded_prior_probs], dim=3)
            concatenated_tensor = torch.cat([concatenated_tensor,expanded_action_contex], dim=3)
            final_tensor = torch.cat([concatenated_tensor, expanded_mse_loss], dim=3)
            #print(final_tensor.shape)
            expansion_values = final_tensor[:, :, -1, 0]   #return to go for sampled state
            action_values = final_tensor[:, 0, -2, 0].view(-1, 1)  #return to go for bootstrapping Q value
            action_mse = final_tensor[:, 0, 0, -1]
            #print(final_tensor.shape)
            # expansion_values is a view of final_tensor, so this in-place scaling also changes
            # the final_tensor that store_value records below.
            expansion_values *= (tree_gamma ** action_sequence) #short back propagation

            #expansion_mean = expansion_values.mean(dim=1)
            #action_mean = action_values.mean(dim=1)
            #mean_values = action_mean + 0.1 * (expansion_mean - action_mean)

            # Score = mean over [n_expand expansion values, action value] - mse_factor * MSE.
            mean_values = torch.cat((expansion_values, action_values), dim=1)
            mean_values = mean_values.mean(dim=1)
            mean_values = mean_values - mse_factor*action_mse
            #print("mean_values,",mean_values)
            # Keep the top b_percent of candidates, but at least one.
            k = int(mean_values.size(0)*b_percent) if int(mean_values.size(0)*b_percent) >=1 else 1
            values_with_b, index = torch.topk(mean_values, k)
            store_value(state, final_tensor, index)

            # Unique states decoded at position 1 of the kept candidates; used for the next depth.
            state_for_next_prior = final_tensor[index,:,1,1:1+model.observation_dim]
            state_for_next_prior = state_for_next_prior.view(-1, state_for_next_prior.size(-1))
            state_for_next_prior = torch.unique(state_for_next_prior, dim=0)

        else:
            # NOTE(review): unreachable while max_depth == 1, so this branch is not exercised.

            state_for_next_prior_expanded = state_for_next_prior_expanded.repeat_interleave(n_expand, 0)
            prediction_raw = model.decode_for_ood(contex, state_for_next_prior_expanded)
            predicted_first_state = prediction_raw[:, 0,1:model.observation_dim+1]
            decoded_state_compare = state_for_next_prior_expanded


            mse_loss_per_element = F.mse_loss(predicted_first_state, decoded_state_compare, reduction='none')
            mse_loss_per_example = mse_loss_per_element.mean(dim=1)
            mse_loss_per_example = mse_loss_per_example.view(-1, nb_samples, n_expand)

            expanded_mse_loss = mse_loss_per_example.unsqueeze(3).unsqueeze(4).expand(-1, nb_samples, n_expand, 2, 1)
            #print(expanded_mse_loss[0])

            reshaped_prediction_raw = prediction_raw.view(-1,nb_samples, n_expand, 2, model.observation_dim + 3*model.action_dim+ 2)
            #print(reshaped_prediction_raw.shape)
            action_contex = action_contex.view(-1, nb_samples, 1)
            action_probs_sampled = action_probs_sampled.view(-1, nb_samples, 1)
            expanded_prior_probs = action_probs_sampled.unsqueeze(3).unsqueeze(4).expand(-1, nb_samples,
                                                                                             n_expand, 2, 1)

            #print("reshaped_prediction_raw", reshaped_prediction_raw.shape)
            expanded_action_contex = action_contex.unsqueeze(3).unsqueeze(4).expand(-1, nb_samples,
                                                                                             n_expand, 2, 1)

            concatenated_tensor = torch.cat([reshaped_prediction_raw, expanded_prior_probs], dim=4)
            concatenated_tensor = torch.cat([concatenated_tensor,expanded_action_contex], dim=4)
            # NOTE(review): expanded_mse_loss (computed above) is not used. The last channel is
            # zero-filled instead, so action_mse and the MSE penalty are 0 at this depth.
            # Confirm this is intended.
            zero_tensor = torch.zeros(concatenated_tensor.shape[0],
                                      concatenated_tensor.shape[1],
                                      concatenated_tensor.shape[2],
                                      concatenated_tensor.shape[3],
                                      1,
                                      device=concatenated_tensor.device)
            final_tensor = torch.cat([concatenated_tensor, zero_tensor], dim=4)
            expansion_values = final_tensor[:, :, :, 1, 0]
            action_values = final_tensor[:, :, 0, 0, 0].view(state_for_next_prior.shape[0], -1, 1)
            action_mse = final_tensor[:, :, 0, 0, -1]
            # In-place on a view: also scales these entries inside final_tensor.
            expansion_values *= (tree_gamma ** action_sequence)

            #average means
            #mean_values = torch.cat((expansion_values, action_values), dim=2)
            #mean_values = mean_values.mean(dim=2) -  mse_factor*action_mse
            #mean_values = mean_values.mean(dim=2)
            #expansion_mean = expansion_values.mean(dim=2)
            #action_mean = action_values.mean(dim=2)
            #mean_values = action_mean + 0.1 * (expansion_mean - action_mean)

            #average means
            mean_values = torch.cat((expansion_values, action_values), dim=2)
            mean_values = mean_values.mean(dim=2) -  mse_factor*action_mse



            # Keep the top action_percent of candidates per parent state, but at least one.
            k = max(int(mean_values.size(1) * action_percent), 1)
            all_selected_tensors = []
            values_with_b, index = torch.topk(mean_values, k)
            #start = time.time()
            #print(index.shape)
            #print()
            for i in range(state_for_next_prior.shape[0]):
                store_value(state_for_next_prior[i], final_tensor[i], index[i])
                all_selected_tensors.append(final_tensor[i][index[i]])


            final_selected_state = torch.cat(all_selected_tensors, dim=0)
            final_selected_state = final_selected_state.view(-1,final_selected_state.size(2), final_selected_state.size(3))
            #print(final_selected_tensor.shape, prediction_raw.shape)
            final_selected_state = final_selected_state[:,1,1:1+model.observation_dim]
            #print(state_for_next_prior.shape)

            state_for_next_prior = torch.unique(final_selected_state, dim=0)
            #print(state_for_next_prior.shape)
            #state_for_next_prior = prediction_raw[:,1,1:1+model.observation_dim]
            #state_for_next_prior = torch.unique(state_for_next_prior, dim=0)
    print("inference time,",time.time() - start)
    #mcts_instance = MCTS(state, state_dict, tree_gamma, prior, model, 1, 1, mse_factor, max_depth - 1)
    # NOTE(review): MCTS is constructed elsewhere in this file with four extra positional
    # arguments (D, denormalize_macro, denormalize_val, normalize_val). Confirm this
    # 9-argument call matches the MCTS signature in .mcts_expand.
    mcts_instance = MCTS(state, state_dict, tree_gamma, prior, model, int(n_action*action_percent), n_expand, mse_factor, max_depth)

    start_time = time.time()
    mcts_instance.search(mcts_itr)
    #print(mcts_instance.Qsa.values())
    # Summary statistics of the Q-values MCTS accumulated (logging only).
    # torch.stack raises if Qsa is empty; torch.std is nan for a single entry.
    values_list = list(mcts_instance.Qsa.values())

    # Stack the tensors into one tensor
    values_tensor = torch.stack(values_list)

    # Compute the mean and standard deviation
    value_mean = torch.mean(values_tensor)
    value_std = torch.std(values_tensor)
    value_max = torch.max(values_tensor)
    value_min = torch.min(values_tensor)
    # Print the results
    print("Mean:", value_mean.item(), "Std:", value_std.item(), "Max:", value_max.item(), "Min:", value_min.item())
    # Stop the timer
    end_time = time.time()

    # Calculate the running time
    running_time = end_time - start_time
    print("search time,", running_time)
    # NOTE(review): the other MCTS caller in this file unpacks best_action() as
    # (best_action, predicted_states). If it returns a tuple, `.long()` here raises. Verify.
    best_action = mcts_instance.best_action().long()
    print(best_action)
    # NOTE(review): only the chosen action code is decoded here, without the context window
    # that was prepended when the tree was expanded. Confirm this is intended.
    prediction_raw = model.decode_from_indices(best_action.view(1, -1), state).squeeze(0)
    return prediction_raw.cpu().numpy()

@torch.no_grad()
def beam_with_prior_MTP(prior, model, x, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand, n_action, b_percent, action_percent,
                    pw_alpha, mcts_itr, prob_threshold=0.05, likelihood_weight=5e2, prob_acc="product", return_info=False):
    """One-step planner driven by a multi-token-prediction (multi-head) prior.

    Samples ``beam_width`` distinct candidate codes from head 0 of ``prior`` and
    decodes them with ``model``. Each candidate is scored by the mean of its
    ``n_expand`` decoded values and its action value. Returns the first decoded
    expansion of the best candidate.

    The horizon is fixed to a single latent step. The ``steps`` argument is
    ignored, so ``model.latent_step`` must be 1.

    Args:
        prior: called as ``prior(codes_or_None, state)``. Must return
            ``(logits, _)``; ``logits[:, 0, -1, :]`` is read as head 0's
            next-code logits.
        model: must provide ``latent_step``, ``observation_dim``, ``action_dim``
            and ``decode_from_indices(codes, state)``.
        x: input tensor. ``x[:, 0, :prior.observation_dim]`` is the current state.
        denormalize_val: optional callable applied to the decoded values.
        beam_width: number of distinct candidate codes sampled.
        n_expand: number of decoded expansions per candidate.
        return_info: when True, also return a diagnostics dict.

        denormalize_rew, discount, n_action, b_percent, action_percent, pw_alpha,
        mcts_itr, prob_threshold, likelihood_weight and prob_acc are accepted
        for interface compatibility but not used.

    Returns:
        The best candidate's first expansion as a numpy array, or
        ``(array, info)`` when ``return_info`` is True.

    Raises:
        ValueError: if ``1 // model.latent_step`` is not exactly one step.
    """
    state = x[:, 0, :prior.observation_dim]
    acc_probs = torch.zeros([1]).to(x)  # no likelihood is accumulated by this planner
    info = {}

    # The planning horizon is deliberately fixed to one latent step.
    steps = 1
    n_steps = steps // model.latent_step
    if n_steps != 1:
        raise ValueError(
            "beam_with_prior_MTP plans exactly one latent step; "
            f"got model.latent_step={model.latent_step}"
        )

    logits, _ = prior(None, state)  # indexed as [batch, head, position, code] (assumed)
    # Sample distinct candidate codes from the first prediction head.
    head0_probs = torch.softmax(logits[:, 0, -1, :], dim=-1)
    samples = torch.multinomial(head0_probs, num_samples=beam_width, replacement=False)  # [B, M]
    contex = samples.reshape([-1, 1])  # [(B*M) x 1]

    prediction_raw = model.decode_from_indices(contex, state)
    feature_dim = 3 * model.action_dim + model.observation_dim + 2
    prediction = prediction_raw.reshape([-1, n_expand, 2, feature_dim])
    prediction_output = prediction[:, 0, :, :]  # first expansion of every candidate

    V_t = prediction_raw[:, 1, 0]  # value read at position 1, feature 0
    a_v = prediction_raw[:, 0, 0]  # action value read at position 0, feature 0
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
        a_v = denormalize_val(a_v)
    values = V_t.reshape([-1, n_expand])
    # Keep one action value per candidate: the first of its n_expand rows.
    a_v = a_v.view(beam_width, n_expand)[:, 0].view(beam_width, 1)

    # Candidate score: mean of its n_expand values and its action value.
    values_track = torch.cat((values, a_v), dim=1).mean(dim=1)
    # A one-step plan only needs the single best candidate.
    _, index = torch.topk(values_track, 1)

    if return_info:
        info[model.latent_step] = dict(
            predictions=prediction_raw.cpu(),
            returns=values.cpu(),
            latent_codes=contex.cpu(),
            log_probs=acc_probs.cpu(),
            # The score actually used for ranking (no likelihood bonus is applied here).
            objectives=values_track.cpu(),
            index=index.cpu(),
        )

    optimal = prediction_output[index[0]]
    print(f"predicted max value {values_track[index[0]]}")
    if return_info:
        return optimal.cpu().numpy(), info
    return optimal.cpu().numpy()



@torch.no_grad()
def beam_with_uniform(prior, model, x, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand,  prob_threshold=0.05):
    """Beam search over latent codes, sampling uniformly among the prior's likely codes.

    At each latent step, every beam row samples codes uniformly from those whose
    prior probability exceeds ``prob_threshold``. If no code in a row clears the
    threshold, that row falls back to its most likely code(s). The sequences are
    decoded and scored by discounted reward plus the bootstrapped final value.
    The best ``beam_width`` sequences are kept, and only the best one on the last
    step.

    Args:
        prior: called as ``prior(codes_or_None, state)``. Must return
            ``(logits, _)``; ``logits[:, -1, :]`` is used as the next-code logits.
        model: must provide ``latent_step``, ``observation_dim``,
            ``transition_dim`` and ``decode_from_indices``. Decoded rows hold the
            reward at column -3 and the value at column -2.
        x: input tensor. ``x[:, 0, :model.observation_dim]`` is the current state.
        denormalize_rew, denormalize_val: optional callables applied to the
            decoded rewards and values.
        discount: per-step discount factor.
        steps: planning horizon in environment steps.
        beam_width: number of sequences kept between latent steps.
        n_expand: codes sampled per kept sequence.
        prob_threshold: minimum prior probability for a code to be sampled.

    Returns:
        Decoded prediction of the best sequence, as a numpy array.

    Raises:
        ValueError: if ``steps // model.latent_step`` is less than 1.
    """
    state = x[:, 0, :model.observation_dim]
    n_steps = steps // model.latent_step
    if n_steps < 1:
        raise ValueError(
            f"beam_with_uniform needs at least one latent step; got steps={steps}, "
            f"model.latent_step={model.latent_step}"
        )

    contex = None
    acc_probs = torch.ones([1]).to(x)
    for step in range(n_steps):
        logits, _ = prior(contex, state)  # [B x t x K]
        probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
        nb_samples = beam_width * n_expand if step == 0 else n_expand

        valid = probs > prob_threshold
        # Rows where no code clears the threshold fall back to their most likely code(s).
        # Otherwise every sampling weight in the row would be zero and multinomial would fail.
        no_valid = ~valid.any(dim=-1, keepdim=True)
        valid = valid | (no_valid & (probs == probs.amax(dim=-1, keepdim=True)))
        weights = valid.to(probs.dtype)
        # keepdim=True normalizes each row by its own count. A bare [B] sum would be
        # broadcast against the K axis instead.
        weights = weights / weights.sum(dim=-1, keepdim=True)
        samples = torch.multinomial(weights, num_samples=nb_samples, replacement=True)  # [B, M]

        samples_prob = torch.gather(probs, 1, samples)  # prior probability of each drawn code
        acc_probs = acc_probs.repeat_interleave(nb_samples, 0) * samples_prob.reshape([-1])
        if contex is not None:
            contex = torch.cat([torch.repeat_interleave(contex, nb_samples, 0), samples.reshape([-1, 1])],
                               dim=1)
        else:
            contex = samples.reshape([-1, 1])  # [(B*M) x 1]

        prediction_raw = model.decode_from_indices(contex, state)
        prediction = prediction_raw.reshape([-1, model.transition_dim])
        r_t, V_t = prediction[:, -3], prediction[:, -2]
        if denormalize_rew is not None:
            r_t = denormalize_rew(r_t)
        if denormalize_val is not None:
            V_t = denormalize_val(V_t)
        # Always restore the per-sequence layout; the time-axis indexing below needs 2-D tensors.
        r_t = r_t.reshape([contex.shape[0], -1])
        V_t = V_t.reshape([contex.shape[0], -1])

        discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
        values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
        nb_top = beam_width if step < n_steps - 1 else 1
        values, index = torch.topk(values, nb_top)
        contex = contex[index]
        acc_probs = acc_probs[index]

    optimal = prediction_raw[index[0]]
    print(f"predicted max value {values[0]}")
    return optimal.cpu().numpy()

@torch.no_grad()
def beam_mimic(prior, model, x, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand,  prob_threshold=0.05):
    """Beam search that ranks latent-code sequences only by their prior probability.

    No decoded rewards or values are consulted during the search. Only the
    surviving best sequence is decoded, after the last step.
    ``denormalize_rew``, ``denormalize_val``, ``discount`` and ``prob_threshold``
    are accepted but unused.

    Returns:
        Decoded prediction of the most probable sequence found, as a numpy array.
    """
    observation = x[:, 0, :model.observation_dim]
    codes = None
    seq_probs = torch.ones([1]).to(x)
    horizon = steps // model.latent_step
    for depth in range(horizon):
        logits, _ = prior(codes, observation)  # [B x t x K]
        next_probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
        if depth == 0:
            n_draw = beam_width * n_expand
        else:
            n_draw = n_expand
        drawn = torch.multinomial(next_probs, num_samples=n_draw, replacement=True)  # [B, M]
        # Probability the prior assigned to each drawn code, row by row.
        drawn_probs = torch.gather(next_probs, 1, drawn)
        seq_probs = seq_probs.repeat_interleave(n_draw, 0) * drawn_probs.reshape([-1])

        new_column = drawn.reshape([-1, 1])
        if codes is None:
            codes = new_column  # [(B*M) x 1]
        else:
            codes = torch.cat([torch.repeat_interleave(codes, n_draw, 0), new_column], dim=1)

        # Keep a full beam until the last step, then only the single best sequence.
        keep = 1 if depth >= horizon - 1 else beam_width
        top_probs, order = torch.topk(seq_probs, keep)
        codes = codes[order]
        seq_probs = seq_probs[order]

    best = model.decode_from_indices(codes, observation)[0]
    print(f"value {top_probs[0]}, prob {seq_probs[0]}")
    return best.cpu().numpy()


@torch.no_grad()
def enumerate_all(model, x, denormalize_rew, denormalize_val, discount):
    """Decode every codebook entry as a one-code plan and return the best-scoring one.

    Each code 0..K-1 (``K = model.model.K``) is decoded from the current state
    as its own length-1 sequence. Each plan is scored by discounted reward plus
    the bootstrapped final value. Column layout matches the rest of this module:
    reward at ``[:, -3]``, value at ``[:, -2]``, terminal flag at ``[:, -1]``.

    Args:
        model: must provide ``model.K``, ``observation_dim``, ``transition_dim``
            and ``decode_from_indices``.
        x: input tensor. ``x[:, 0, :model.observation_dim]`` is the current state.
        denormalize_rew, denormalize_val: optional callables applied to the
            decoded rewards and values.
        discount: per-step discount factor.

    Returns:
        Decoded prediction of the best code, as a numpy array.
    """
    n_codes = model.model.K
    # One single-code sequence per codebook entry, shaped [K, 1] like the other
    # decode_from_indices call sites. torch.range is deprecated, so use torch.arange.
    indicies = torch.arange(n_codes, device=x.device, dtype=torch.int32).reshape([-1, 1])
    prediction_raw = model.decode_from_indices(indicies, x[:, 0, :model.observation_dim])
    prediction = prediction_raw.reshape([-1, model.transition_dim])

    r_t, V_t = prediction[:, -3], prediction[:, -2]
    if denormalize_rew is not None:
        r_t = denormalize_rew(r_t)
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
    # Always restore the per-code layout; the time-axis indexing below needs 2-D tensors.
    r_t = r_t.reshape([n_codes, -1])
    V_t = V_t.reshape([n_codes, -1])

    discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
    values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
    optimal = prediction_raw[values.argmax()]
    return optimal.cpu().numpy()


@torch.no_grad()
def propose_plan_continuous(model, x):
    """Decode the all-zero latent from the current state and return the flattened plan.

    The latent is allocated on ``x``'s device rather than a hard-coded ``"cuda"``.
    This makes the function work on CPU and on whichever GPU holds ``x``.

    Args:
        model: must provide ``trajectory_embd``, ``observation_dim``,
            ``transition_dim`` and ``decode(latent, state)``.
        x: input tensor. ``x[:, 0, :model.observation_dim]`` is the current state.

    Returns:
        The decoded prediction reshaped to ``[-1, model.transition_dim]``, as a
        numpy array.
    """
    latent = torch.zeros([1, model.trajectory_embd], device=x.device)
    prediction = model.decode(latent, x[:, 0, :model.observation_dim])
    prediction = prediction.reshape([-1, model.transition_dim])
    return prediction.cpu().numpy()