from collections import defaultdict
import torch
import time

#from scripts.plan import start_time

REWARD_DIM = VALUE_DIM = 1
#from .mcts import *
#from .mcts_beam import *
from .mcts_expand import *
import networkx as nx
import matplotlib.pyplot as plt
#from collections import Counter
import torch.nn.functional as F
import time

@torch.no_grad()
def model_rollout_continuous(model, x, latent, denormalize_rew, denormalize_val, discount, prob_penalty_weight=1e4):
    """Score continuous latent plans by decoding them and summing discounted rewards.

    The latent is decoded conditioned on the observation part of the last
    timestep of ``x``. The last three columns of every decoded transition are
    read as (reward, value, terminal flag).

    Args:
        model: object exposing ``decode(latent, state)``, ``observation_dim``
            and ``transition_dim``.
        x: conditioning trajectory; indexed as ``x[:, -1, :observation_dim]``.
        latent: latent plan(s) handed to ``model.decode``.
        denormalize_rew: optional callable applied to the flat reward vector.
        denormalize_val: optional callable applied to the flat value vector.
        discount: per-step discount factor.
        prob_penalty_weight: weight of the mean-squared-latent penalty.

    Returns:
        Tuple ``(objective, prediction)`` of numpy arrays: the penalised return
        per batch row and the flattened ``[-1, transition_dim]`` decoded output.
    """
    batch_size = x.shape[0]
    prediction = model.decode(latent, x[:, -1, :model.observation_dim])
    prediction = prediction.reshape([-1, model.transition_dim])

    r_t, V_t = prediction[:, -3], prediction[:, -2]
    if denormalize_rew is not None:
        r_t = denormalize_rew(r_t)
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
    # Reshape unconditionally: previously this only happened inside the
    # denormalize branches, so passing None left 1-D tensors and the 2-D
    # indexing below raised IndexError.
    r_t = r_t.reshape([batch_size, -1])
    V_t = V_t.reshape([batch_size, -1])

    # discounts with terminal flag: a terminal step zeroes every later factor
    terminal = prediction[:, -1].reshape([batch_size, -1])
    discounts = torch.cumprod(torch.ones_like(r_t) * discount * (1 - terminal), dim=-1)
    # Rewards of all but the last step plus the bootstrapped value of the last step.
    values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
    prob_penalty = prob_penalty_weight * torch.mean(torch.square(latent), dim=-1)
    objective = values - prob_penalty
    return objective.cpu().numpy(), prediction.cpu().numpy()


import numpy as np


@torch.no_grad()
def sample(model, x, denormalize_rew, denormalize_val, discount, steps, nb_samples=4096, rounds=8):
    """Pick the best of ``nb_samples`` uniformly random latent-code plans.

    Code indices are drawn uniformly from ``[0, model.model.K)``, decoded
    conditioned on the observation part of the first timestep of ``x``, and
    ranked by discounted reward plus the discounted value of the last step.

    Args:
        model: object exposing ``decode_from_indices``, ``observation_dim``,
            ``transition_dim``, ``latent_step`` and ``model.K``.
        x: conditioning trajectory; indexed as ``x[:, 0, :observation_dim]``.
        denormalize_rew: optional callable applied to the flat reward vector.
        denormalize_val: optional callable applied to the flat value vector.
        discount: per-step discount factor.
        steps: planning horizon; ``steps // model.latent_step`` codes are drawn.
        nb_samples: number of random plans to evaluate.
        rounds: unused; kept so the signature matches the other planners.

    Returns:
        The decoded trajectory of the highest-scoring plan as a numpy array.
    """
    n_latents = steps // model.latent_step
    # torch.randint's upper bound is exclusive, so use K (not K - 1) to make
    # the last code reachable and to keep K == 1 from raising.
    indices = torch.randint(0, model.model.K, size=[nb_samples, n_latents],
                            device=x.device, dtype=torch.int32)
    prediction_raw = model.decode_from_indices(indices, x[:, 0, :model.observation_dim])
    prediction = prediction_raw.reshape([-1, model.transition_dim])

    r_t, V_t = prediction[:, -3], prediction[:, -2]
    if denormalize_rew is not None:
        r_t = denormalize_rew(r_t)
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
    # Always reshape to [nb_samples, T]; skipping it when a denormalizer is
    # None made the 2-D indexing below fail.
    r_t = r_t.reshape([indices.shape[0], -1])
    V_t = V_t.reshape([indices.shape[0], -1])

    discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
    values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
    optimal = prediction_raw[values.argmax()]
    print(values.max().item())
    return optimal.cpu().numpy()


@torch.no_grad()
def sample_with_prior(prior, model, x, denormalize_rew, denormalize_val, discount, steps, nb_samples=4096, rounds=8,
                      likelihood_weight=5e2, prob_threshold=0.05, uniform=False, return_info=False):
    """Sample latent plans autoregressively from ``prior`` and keep the best one.

    Each round draws ``nb_samples // rounds`` first codes and extends every
    one with a single sampled code per later step. Plans are decoded and
    ranked by discounted return plus a clamped log-likelihood bonus; the best
    plan across all rounds (by return) is returned.

    Args:
        prior: callable ``prior(contex, state) -> (logits, _)`` with logits
            indexed as ``[B x t x K]``.
        model: object exposing ``decode_from_indices``, ``observation_dim``,
            ``transition_dim`` and ``latent_step``.
        x: conditioning trajectory; indexed as ``x[:, 0, :observation_dim]``.
        denormalize_rew: optional callable applied to the flat reward vector.
        denormalize_val: optional callable applied to the flat value vector.
        discount: per-step discount factor.
        steps: planning horizon; ``steps // model.latent_step`` codes per plan.
        nb_samples: total number of plans over all rounds.
        rounds: number of independent sampling rounds.
        likelihood_weight: weight of the clamped accumulated log-probability.
        prob_threshold: per-step probability above which the bonus saturates.
        uniform: if True, sample uniformly over codes with non-zero prior
            probability (log-probabilities still come from the prior).
        return_info: if True, also return a dict of per-candidate arrays.

    Returns:
        The best decoded trajectory as a numpy array, or ``(trajectory, info)``
        when ``return_info`` is set.
    """
    state = x[:, 0, :model.observation_dim]
    n_latents = steps // model.latent_step
    # Upper clamp bound: plans whose every step is at least `prob_threshold`
    # likely all receive the same (maximal) bonus.
    log_prob_cap = float(np.log(prob_threshold) * n_latents)
    optimals = []
    optimal_values = []
    info = defaultdict(list)
    for _ in range(rounds):
        contex = None
        acc_probs = torch.zeros([1]).to(x)
        for step in range(n_latents):
            logits, _ = prior(contex, state)  # [B x t x K]
            probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
            log_probs = torch.log(probs)
            if uniform:
                valid = probs > 0
                probs = valid / valid.sum(dim=-1)[:, None]
            # Fan out only at the first step; afterwards one code per plan.
            n_draws = nb_samples // rounds if step == 0 else 1
            samples = torch.multinomial(probs, num_samples=n_draws, replacement=True)  # [B, M]
            samples_prob = torch.gather(log_probs, 1, samples)  # [B, M]
            acc_probs = acc_probs + samples_prob.reshape([-1])
            if contex is None:
                contex = samples.reshape([-1, step + 1])  # [(B*M) x t]
            else:
                contex = torch.cat([contex, samples.reshape([-1, 1])], dim=1)
        prediction_raw = model.decode_from_indices(contex, state)
        prediction = prediction_raw.reshape([-1, model.transition_dim])

        r_t = prediction[:, -3]
        V_t = prediction[:, -2]
        if denormalize_rew is not None:
            r_t = denormalize_rew(r_t)
        if denormalize_val is not None:
            V_t = denormalize_val(V_t)
        # Always reshape to [n_plans, T]; skipping it when a denormalizer is
        # None made the 2-D indexing below fail.
        r_t = r_t.reshape([contex.shape[0], -1])
        V_t = V_t.reshape([contex.shape[0], -1])

        discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
        values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
        likelihood_bonus = likelihood_weight * torch.clamp(acc_probs, -1e5, log_prob_cap)
        info["log_probs"].append(acc_probs.cpu().numpy())
        info["returns"].append(values.cpu().numpy())
        info["predictions"].append(prediction_raw.cpu().numpy())
        info["objectives"].append(values.cpu().numpy() + likelihood_bonus.cpu().numpy())
        info["latent_codes"].append(contex.cpu().numpy())
        # Select with the bonus, but remember the plain return for the
        # comparison across rounds.
        max_idx = (values + likelihood_bonus).argmax()
        optimals.append(prediction_raw[max_idx])
        optimal_values.append(values[max_idx].item())

    for key, val in info.items():
        info[key] = np.concatenate(val, axis=0)
    max_idx = np.array(optimal_values).argmax()
    optimal = optimals[max_idx]
    print(f"predicted max value {optimal_values[max_idx]}")
    if return_info:
        return optimal.cpu().numpy(), info
    else:
        return optimal.cpu().numpy()


@torch.no_grad()
def sample_with_prior_tree(prior, model, x, denormalize_rew, denormalize_val, discount, steps, samples_per_latent=16, likelihood_weight=0.0):
    """Expand a full sampling tree from ``prior`` and return the best leaf plan.

    Every step each partial plan is extended with ``samples_per_latent``
    sampled codes, so the number of candidates grows as
    ``samples_per_latent ** (steps // model.latent_step)``. Leaves are decoded
    and ranked by discounted return plus ``likelihood_weight * log(prob)``.

    Args:
        prior: callable ``prior(contex, state) -> (logits, _)`` with logits
            indexed as ``[B x t x K]``.
        model: object exposing ``decode_from_indices``, ``observation_dim``,
            ``transition_dim`` and ``latent_step``.
        x: conditioning trajectory; indexed as ``x[:, 0, :observation_dim]``.
        denormalize_rew: optional callable applied to the flat reward vector.
        denormalize_val: optional callable applied to the flat value vector.
        discount: per-step discount factor.
        steps: planning horizon.
        samples_per_latent: branching factor per step.
        likelihood_weight: weight of the accumulated log-probability.

    Returns:
        The decoded trajectory of the best leaf as a numpy array.
    """
    contex = None
    state = x[:, 0, :model.observation_dim]
    # Product of sampled probabilities along each path (not log-space here).
    acc_probs = torch.ones([1]).to(x)
    for step in range(steps // model.latent_step):
        logits, _ = prior(contex, state)  # [B x t x K]
        probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
        samples = torch.multinomial(probs, num_samples=samples_per_latent, replacement=True)  # [B, M]
        samples_prob = torch.gather(probs, 1, samples)  # [B, M]
        acc_probs = acc_probs.repeat_interleave(samples_per_latent, 0) * samples_prob.reshape([-1])
        if contex is None:
            contex = samples.reshape([-1, step + 1])  # [(B*M) x t]
        else:
            contex = torch.cat([torch.repeat_interleave(contex, samples_per_latent, 0), samples.reshape([-1, 1])],
                               dim=1)

    prediction_raw = model.decode_from_indices(contex, state)
    prediction = prediction_raw.reshape([-1, model.transition_dim])

    r_t, V_t = prediction[:, -3], prediction[:, -2]
    if denormalize_rew is not None:
        r_t = denormalize_rew(r_t)
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
    # Always reshape to [n_plans, T]; skipping it when a denormalizer is None
    # made the 2-D indexing below fail.
    r_t = r_t.reshape([contex.shape[0], -1])
    V_t = V_t.reshape([contex.shape[0], -1])

    discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
    values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
    likelihood_bonus = likelihood_weight * torch.log(acc_probs)
    max_idx = (values + likelihood_bonus).argmax()
    optimal = prediction_raw[max_idx]
    print(f"predicted max value {values[max_idx]}, likelihood {acc_probs[max_idx]} with bouns {likelihood_bonus[max_idx]}")
    return optimal.cpu().numpy()

@torch.no_grad()
def beam_with_prior_ood(prior, model, x, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand, prob_threshold=0.05, likelihood_weight=5e2, prob_acc="product",
                    return_info=False):
    """Beam search over latent codes that also tracks a state-consistency error.

    At every step each beam entry is extended with codes sampled from
    ``prior`` (512 at the first step, ``n_expand`` afterwards), the extended
    plans are decoded, and the top ``beam_width`` by discounted return are
    kept (top 1 at the last step). For ``prob_acc == "expect"`` the return is
    weighted by the plan probability before ranking.

    An MSE between a predicted state and a reference state is accumulated in
    ``acc_oods`` per candidate; it is carried along the beam but does not
    influence ranking or the returned values.

    Args:
        prior: callable ``prior(contex, state) -> (logits, _)`` with logits
            indexed as ``[B x t x K]``; must expose ``observation_dim``.
        model: object exposing ``decode_from_indices``, ``decode_for_ood``,
            ``observation_dim``, ``action_dim`` and ``latent_step``.
        x: conditioning trajectory; indexed as ``x[:, 0, :observation_dim]``.
        denormalize_rew: optional callable applied to the flat reward vector.
        denormalize_val: optional callable applied to the flat value vector.
        discount: per-step discount factor.
        steps: planning horizon; ``steps // model.latent_step`` beam steps.
        beam_width: number of plans kept after each non-final step.
        n_expand: codes sampled per plan after the first step.
        prob_threshold: per-step probability used for the bonus clamp.
        likelihood_weight: weight of the likelihood bonus (reported in
            ``info`` only; ranking uses the plain return).
        prob_acc: "product", "expect" or "min" — how log-probs accumulate.
        return_info: if True, also return per-step diagnostics.

    Returns:
        The best decoded trajectory as a numpy array, or ``(trajectory, info)``
        when ``return_info`` is set.
    """
    contex = None
    state = x[:, 0, :prior.observation_dim]
    acc_probs = torch.zeros([1]).to(x)
    acc_oods = torch.zeros([1]).to(x)
    info = {}
    n_latents = steps // model.latent_step
    for step in range(n_latents):
        start_time = time.time()
        # Was `prior(context, state)`: `context` is undefined in this scope.
        logits, _ = prior(contex, state)  # [B x t x K]
        prior_time = time.time() - start_time
        print("prior function 1:", prior_time)
        probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
        log_probs = torch.log(probs)
        nb_samples = 512 if step == 0 else n_expand
        samples = torch.multinomial(probs, num_samples=nb_samples, replacement=True)  # [B, M]
        samples_log_prob = torch.gather(log_probs, 1, samples)  # [B, M]

        if prob_acc in ["product", "expect"]:
            acc_probs = acc_probs.repeat_interleave(nb_samples, 0) + samples_log_prob.reshape([-1])
        elif prob_acc == "min":
            acc_probs = torch.minimum(acc_probs.repeat_interleave(nb_samples, 0), samples_log_prob.reshape([-1]))
        if contex is None:
            contex = samples.reshape([-1, step + 1])  # [(B*M) x t]
        else:
            contex = torch.cat([torch.repeat_interleave(contex, nb_samples, 0), samples.reshape([-1, 1])],
                               dim=1)
        prediction_raw = model.decode_from_indices(contex, state)
        prediction = prediction_raw.reshape([-1, model.action_dim + model.observation_dim + 3])
        r_t = prediction[:, -3]
        V_t = prediction[:, -2]

        if step >= 1:
            # Compare the state predicted at the start of the newest latent
            # segment with the state obtained by decoding that segment's code
            # on its own from the predicted state.
            predicted_first_state = prediction_raw[:, step * model.latent_step, :model.observation_dim]
            decoded_state = model.decode_for_ood(contex[:, step].reshape([-1, 1]), predicted_first_state)
            decoded_state_compare = decoded_state[:, 0, :model.observation_dim]
        else:
            predicted_first_state = prediction_raw[:, 0, :model.observation_dim]
            decoded_state_compare = state.expand_as(predicted_first_state)
        # Per-candidate MSE between the two states.
        mse_loss_per_element = F.mse_loss(predicted_first_state, decoded_state_compare, reduction='none')
        mse_loss_per_example = mse_loss_per_element.mean(dim=1)
        acc_oods = acc_oods.repeat_interleave(nb_samples, 0) + mse_loss_per_example

        if denormalize_rew is not None:
            r_t = denormalize_rew(r_t)
        if denormalize_val is not None:
            V_t = denormalize_val(V_t)
        # Always reshape to [n_plans, T]; skipping it when a denormalizer is
        # None made the 2-D indexing below fail.
        r_t = r_t.reshape([contex.shape[0], -1])
        V_t = V_t.reshape([contex.shape[0], -1])
        discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
        values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]

        if prob_acc == "product":
            likelihood_bonus = likelihood_weight * torch.clamp(
                acc_probs, -1e5, float(np.log(prob_threshold) * n_latents))
        elif prob_acc == "min":
            # NOTE(review): min (0) exceeds max (a negative log-prob), so this
            # clamp yields the constant log(prob_threshold) — confirm intent.
            likelihood_bonus = likelihood_weight * torch.clamp(acc_probs, 0, float(np.log(prob_threshold)))
        else:
            # "expect" has no additive bonus; define it so `info` can still be
            # built (it was unbound here and raised with return_info=True).
            likelihood_bonus = torch.zeros_like(values)
        nb_top = beam_width if step < (n_latents - 1) else 1
        if prob_acc == "expect":
            values_with_b, index = torch.topk(values * torch.exp(acc_probs), nb_top)
        else:
            # Ranking deliberately uses the plain return, not values + bonus.
            values_with_b, index = torch.topk(values, nb_top)
        if return_info:
            info[(step + 1) * model.latent_step] = dict(predictions=prediction_raw.cpu(), returns=values.cpu(),
                                                        latent_codes=contex.cpu(), log_probs=acc_probs.cpu(),
                                                        objectives=(values + likelihood_bonus).cpu(),
                                                        index=index.cpu())
        contex = contex[index]
        acc_probs = acc_probs[index]
        acc_oods = acc_oods[index]
    optimal = prediction_raw[index[0]]
    if return_info:
        return optimal.cpu().numpy(), info
    else:
        return optimal.cpu().numpy()


@torch.no_grad()
def mcts_with_prior(prior, model, x, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand, n_action, b_percent, action_percent,
                    pw_alpha, mcts_itr, prob_threshold=0.05, likelihood_weight=5e2, prob_acc="product", return_info=False):
    """Pre-expand a depth-3 tree of latent codes, then run MCTS over it.

    Phase 1 (the ``for step`` loop) builds ``state_dict``: for every state
    visited it stores ``[final_tensor, index]`` where ``final_tensor`` holds the
    decoded predictions for each sampled action code and each sampled
    expansion code, with three extra trailing feature columns appended
    (prior probability of the action code, the action code itself, and an MSE
    term), and ``index`` are the top-k action rows selected for that state.

    Phase 2 hands ``state_dict`` to ``MCTS`` (imported from ``.mcts_expand``),
    runs ``mcts_itr`` search iterations, prints statistics of ``Qsa`` and decodes
    the best action sequence.

    Args:
        prior: callable ``prior(contex, state) -> (logits, _)``; must expose
            ``observation_dim``.
        model: object exposing ``decode_from_indices``, ``decode_for_ood``,
            ``observation_dim`` and ``action_dim``.
        x: conditioning trajectory; indexed as ``x[:, 0, :observation_dim]``.
        beam_width: number of action codes sampled at the root.
        n_expand: number of second codes sampled per action code.
        n_action: number of action codes sampled per state below the root.
        b_percent: fraction of root action codes kept (at least one).
        action_percent: fraction of action codes kept per deeper state
            (at least one).
        mcts_itr: number of MCTS search iterations.
        denormalize_rew, denormalize_val, discount, steps, pw_alpha,
        prob_threshold, likelihood_weight, prob_acc, return_info: not read
            anywhere in this function body.

    Returns:
        Numpy array: the trajectory decoded from the best action sequence.
    """
    state = x[:, 0, :prior.observation_dim]
    def tensor_to_tuple(tensor):
        # Flattened float tuple, usable as a hashable dict key for a state.
        return tuple(tensor.cpu().numpy().flatten())

    # Maps state key -> [final_tensor, selected action indices]
    state_dict = {}
    # Record the expansion for a state; the first entry for a key wins.
    def store_value(state, action_matrix, index):
        state_key = tensor_to_tuple(state)
        #print(state_key)
        if state_key not in state_dict:
            state_dict[state_key] = [action_matrix,index]
    #for step in range(steps//model.latent_step):
    # NOTE(review): function-local import; `time` is also imported at module level.
    import time
    start = time.time()
    # Hard-coded search hyper-parameters.
    max_depth = 3
    tree_gamma = 0.99
    action_sequence = 3
    mse_factor = 1
    for step in range(max_depth):
        # --- sample candidate action codes from the prior with no code context ---
        if step == 0:
            logits, _ = prior(None, state) # [B x t x K]
        else:
            #contex = None
            logits, _ = prior(None, state_for_next_prior)
        action_probs = torch.softmax(logits[:, -1, :], dim=-1) # [B x K]
        nb_samples = beam_width if step == 0 else n_action

        # Sampling without replacement: the codes drawn for one row are distinct.
        action_samples = torch.multinomial(action_probs, num_samples=nb_samples, replacement=False) # [B, M]
        # Gather the corresponding probabilities for the sampled actions
        action_probs_sampled = torch.gather(action_probs, 1, action_samples)
        action_contex = action_samples.reshape([-1, 1]) # [(B*M) x t]

        # --- sample a second code conditioned on each action code ---
        if step == 0:
            logits, _ = prior(action_contex, state)
        else:
            #print(state_for_next_prior.shape)
            # One copy of each state per action code sampled for it.
            state_for_next_prior_expanded = state_for_next_prior.repeat_interleave(nb_samples, 0)
            logits, _ = prior(action_contex, state_for_next_prior_expanded)
        probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
        log_probs = torch.log(probs)  # not used below
        samples = torch.multinomial(probs, num_samples=n_expand, replacement=True)  # [B, M]
        # Two-code context: each action code repeated n_expand times, paired with a sampled second code.
        contex = torch.cat([torch.repeat_interleave(action_contex, n_expand, 0), samples.reshape([-1, 1])], dim=1)
        if step == 0:

            prediction_raw = model.decode_from_indices(contex, state)
            #print(prediction_raw.shape)
            # Assumes the decoder returns 2 timesteps per candidate and a single
            # root state, so rows factor as (nb_samples, n_expand) — TODO confirm.
            reshaped_prediction_raw = prediction_raw.view(nb_samples, n_expand, 2, -1)
            expanded_action_contex = action_contex.unsqueeze(1).unsqueeze(2).expand(nb_samples, n_expand, 2, 1)
            # Columns 1..observation_dim of timestep 0 are compared against the root state.
            predicted_first_state = prediction_raw[:, 0, 1:1+model.observation_dim]
            decoded_state_compare = state.expand_as(predicted_first_state)

            mse_loss_per_element = F.mse_loss(predicted_first_state, decoded_state_compare, reduction='none')
            mse_loss_per_example = mse_loss_per_element.mean(dim=1)
            mse_loss_per_example = mse_loss_per_example.view(nb_samples, n_expand)
            expanded_mse_loss = mse_loss_per_example.unsqueeze(2).unsqueeze(3).expand(nb_samples, n_expand, 2, 1)
            expanded_prior_probs = action_probs_sampled.reshape([-1, 1]).unsqueeze(2).unsqueeze(3).expand(nb_samples, n_expand, 2, 1)
            # Append three feature columns: prior prob, action code, MSE (in that order).
            concatenated_tensor = torch.cat([reshaped_prediction_raw, expanded_prior_probs], dim=3)
            concatenated_tensor = torch.cat([concatenated_tensor,expanded_action_contex], dim=3)
            final_tensor = torch.cat([concatenated_tensor, expanded_mse_loss], dim=3)
            # Feature column 0 is presumably the value estimate (judging by the variable names):
            # timestep 1 for each expansion, timestep 0 of the first expansion for the action.
            expansion_values = final_tensor[:, :, 1, 0]
            action_values = final_tensor[:, 0, 0, 0].view(-1, 1)
            action_mse = final_tensor[:, 0, 0, -1]
            # In-place scaling of a slice of final_tensor: the tensor stored in
            # state_dict below presumably carries the discounted values too.
            expansion_values *= (tree_gamma ** action_sequence)
            # Mean over the n_expand expansion values plus the action's own value.
            mean_values = torch.cat((expansion_values, action_values), dim=1)
            mean_values = mean_values.mean(dim=1)

            mean_values = mean_values - mse_factor*action_mse
            # Keep the top b_percent fraction of root actions, at least one.
            k = int(mean_values.size(0)*b_percent) if int(mean_values.size(0)*b_percent) >=1 else 1
            values_with_b, index = torch.topk(mean_values, k)
            store_value(state, final_tensor, index)

            # Next-level states: timestep-1 observation columns of every expansion
            # of the selected actions, flattened and de-duplicated.
            state_for_next_prior = final_tensor[index,:,1,1:1+model.observation_dim]
            state_for_next_prior = state_for_next_prior.view(-1, state_for_next_prior.size(-1))
            state_for_next_prior = torch.unique(state_for_next_prior, dim=0)

        else:

            # One copy of each state per (action code, expansion code) pair.
            state_for_next_prior_expanded = state_for_next_prior_expanded.repeat_interleave(n_expand, 0)
            prediction_raw = model.decode_for_ood(contex, state_for_next_prior_expanded)
            predicted_first_state = prediction_raw[:, 0,1:model.observation_dim+1]
            decoded_state_compare = state_for_next_prior_expanded


            mse_loss_per_element = F.mse_loss(predicted_first_state, decoded_state_compare, reduction='none')
            mse_loss_per_example = mse_loss_per_element.mean(dim=1)
            # Leading dimension indexes the states in state_for_next_prior.
            mse_loss_per_example = mse_loss_per_example.view(-1, nb_samples, n_expand)

            expanded_mse_loss = mse_loss_per_example.unsqueeze(3).unsqueeze(4).expand(-1, nb_samples, n_expand, 2, 1)
            #print(expanded_mse_loss[0])

            # Feature width is spelled out here (the root branch infers it with -1).
            reshaped_prediction_raw = prediction_raw.view(-1,nb_samples, n_expand, 2, model.observation_dim + 3*model.action_dim+ 2)
            #print(reshaped_prediction_raw.shape)
            action_contex = action_contex.view(-1, nb_samples, 1)
            action_probs_sampled = action_probs_sampled.view(-1, nb_samples, 1)
            expanded_prior_probs = action_probs_sampled.unsqueeze(3).unsqueeze(4).expand(-1, nb_samples,
                                                                                             n_expand, 2, 1)

            #print("reshaped_prediction_raw", reshaped_prediction_raw.shape)
            expanded_action_contex = action_contex.unsqueeze(3).unsqueeze(4).expand(-1, nb_samples,
                                                                                             n_expand, 2, 1)

            # Same three appended columns as at the root: prior prob, action code, MSE.
            concatenated_tensor = torch.cat([reshaped_prediction_raw, expanded_prior_probs], dim=4)
            concatenated_tensor = torch.cat([concatenated_tensor,expanded_action_contex], dim=4)
            final_tensor = torch.cat([concatenated_tensor, expanded_mse_loss], dim=4)

            expansion_values = final_tensor[:, :, :, 1, 0]
            action_values = final_tensor[:, :, 0, 0, 0].view(state_for_next_prior.shape[0], -1, 1)
            action_mse = final_tensor[:, :, 0, 0, -1]
            # In-place on a slice of final_tensor (see the root branch).
            expansion_values *= (tree_gamma ** action_sequence)
            mean_values = torch.cat((expansion_values, action_values), dim=2)

            mean_values = mean_values.mean(dim=2) -  mse_factor*action_mse
            #mean_values = mean_values.mean(dim=2)
            # Keep the top action_percent fraction of actions per state, at least one.
            k = max(int(mean_values.size(1) * action_percent), 1)
            all_selected_tensors = []
            values_with_b, index = torch.topk(mean_values, k)
            #start = time.time()
            #print(index.shape)
            #print()
            # Store each state's expansion and collect the rows of its selected actions.
            for i in range(state_for_next_prior.shape[0]):
                store_value(state_for_next_prior[i], final_tensor[i], index[i])
                all_selected_tensors.append(final_tensor[i][index[i]])


            final_selected_state = torch.cat(all_selected_tensors, dim=0)
            # Merge the (selected action, expansion) dimensions into one row dimension.
            final_selected_state = final_selected_state.view(-1,final_selected_state.size(2), final_selected_state.size(3))
            #print(final_selected_tensor.shape, prediction_raw.shape)
            # Timestep-1 observation columns become the states of the next level.
            final_selected_state = final_selected_state[:,1,1:1+model.observation_dim]
            #print(state_for_next_prior.shape)

            state_for_next_prior = torch.unique(final_selected_state, dim=0)
            #print(state_for_next_prior.shape)
            #state_for_next_prior = prediction_raw[:,1,1:1+model.observation_dim]
            #state_for_next_prior = torch.unique(state_for_next_prior, dim=0)
    print("inference time,",time.time() - start)
    # MCTS comes from the wildcard import of .mcts_expand; its argument contract is not visible here.
    mcts_instance = MCTS(state, state_dict, tree_gamma, prior, model, int(n_action*action_percent), n_expand, mse_factor, max_depth-1)

    start_time = time.time()
    mcts_instance.search(mcts_itr)
    #print(mcts_instance.Qsa.values())
    values_list = list(mcts_instance.Qsa.values())

    # Stack the tensors into one tensor
    values_tensor = torch.stack(values_list)

    # Summary statistics of the Q-values, for logging only
    value_mean = torch.mean(values_tensor)
    value_std = torch.std(values_tensor)
    value_max = torch.max(values_tensor)
    value_min = torch.min(values_tensor)
    # Print the results
    print("Mean:", value_mean.item(), "Std:", value_std.item(), "Max:", value_max.item(), "Min:", value_min.item())
    # Stop the timer (the interval also covers the statistics above)
    end_time = time.time()

    # Calculate the running time
    running_time = end_time - start_time
    print("search time,", running_time)
    best_action = mcts_instance.best_action().long()
    print(best_action)
    # Decode the chosen code sequence from the root state into a trajectory.
    prediction_raw = model.decode_from_indices(best_action.view(1, -1), state).squeeze(0)
    return prediction_raw.cpu().numpy()


@torch.no_grad()
def beam_with_prior_copy_without_beam(prior, model, x, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand, prob_threshold=0.05, likelihood_weight=5e2, prob_acc="product", return_info=False):
    """Pre-expand a depth-3 tree of latent codes without pruning, then run MCTS.

    For every state visited, ``state_dict`` stores a tensor holding the decoded
    predictions for each sampled action code and each sampled expansion code,
    with two trailing feature columns appended (the action code and an MSE
    term). Unlike ``mcts_with_prior`` no top-k selection is applied: every
    decoded timestep-1 state is carried to the next level.

    Runs under ``torch.no_grad()`` like the other planners in this module, so
    no autograd graph is built for the prior / decoder calls.

    Args:
        prior: callable ``prior(contex, state) -> (logits, _)``; must expose
            ``observation_dim``.
        model: object exposing ``decode_from_indices``, ``decode_for_ood`` and
            ``observation_dim``.
        x: conditioning trajectory; indexed as ``x[:, 0, :observation_dim]``.
        beam_width: number of action codes sampled at the root (2 are sampled
            per state on deeper levels).
        n_expand: number of second codes sampled per action code.
        denormalize_rew, denormalize_val, discount, steps, prob_threshold,
        likelihood_weight, prob_acc, return_info: not read by this function.

    Returns:
        Numpy array: the trajectory decoded from the best action sequence.
    """
    state = x[:, 0, :prior.observation_dim]

    def tensor_to_tuple(tensor):
        # Flattened float tuple, usable as a hashable dict key for a state.
        return tuple(tensor.cpu().numpy().flatten())

    # Maps state key -> expansion tensor; the first entry for a key wins.
    state_dict = {}

    def store_value(state, action_matrix):
        state_key = tensor_to_tuple(state)
        if state_key not in state_dict:
            state_dict[state_key] = action_matrix

    start = time.time()
    max_depth = 3
    for step in range(max_depth):
        # Sample candidate action codes from the prior with no code context.
        if step == 0:
            logits, _ = prior(None, state)  # [B x t x K]
        else:
            logits, _ = prior(None, state_for_next_prior)
        action_probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
        nb_samples = beam_width if step == 0 else 2
        action_samples = torch.multinomial(action_probs, num_samples=nb_samples, replacement=False)  # [B, M]
        action_contex = action_samples.reshape([-1, 1])  # [(B*M) x t]

        # Sample a second code conditioned on each action code.
        if step == 0:
            logits, _ = prior(action_contex, state)
        else:
            state_for_next_prior_expanded = state_for_next_prior.repeat_interleave(nb_samples, 0)
            logits, _ = prior(action_contex, state_for_next_prior_expanded)
        probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
        samples = torch.multinomial(probs, num_samples=n_expand, replacement=True)  # [B, M]
        contex = torch.cat([torch.repeat_interleave(action_contex, n_expand, 0), samples.reshape([-1, 1])], dim=1)

        if step == 0:
            prediction_raw = model.decode_from_indices(contex, state)
            # Feature width taken from the decoder output instead of the
            # previously hard-coded 22; the view still requires 2 timesteps.
            feature_dim = prediction_raw.shape[-1]
            reshaped_prediction_raw = prediction_raw.view(nb_samples, n_expand, 2, feature_dim)
            expanded_action_contex = action_contex.unsqueeze(1).unsqueeze(2).expand(nb_samples, n_expand, 2, 1)
            predicted_first_state = prediction_raw[:, 0, 1:1 + model.observation_dim]
            decoded_state_compare = state.expand_as(predicted_first_state)

            mse_loss_per_element = F.mse_loss(predicted_first_state, decoded_state_compare, reduction='none')
            mse_loss_per_example = mse_loss_per_element.mean(dim=1)
            mse_loss_per_example = mse_loss_per_example.view(nb_samples, n_expand)
            expanded_mse_loss = mse_loss_per_example.unsqueeze(2).unsqueeze(3).expand(nb_samples, n_expand, 2, 1)

            # Append the action code and the MSE as extra feature columns.
            concatenated_tensor = torch.cat([reshaped_prediction_raw, expanded_action_contex], dim=3)
            final_tensor = torch.cat([concatenated_tensor, expanded_mse_loss], dim=3)

            store_value(state, final_tensor)
        else:
            state_for_next_prior_expanded = state_for_next_prior_expanded.repeat_interleave(n_expand, 0)
            prediction_raw = model.decode_for_ood(contex, state_for_next_prior_expanded)
            feature_dim = prediction_raw.shape[-1]
            predicted_first_state = prediction_raw[:, 0, 1:model.observation_dim + 1]
            decoded_state_compare = state_for_next_prior_expanded

            mse_loss_per_element = F.mse_loss(predicted_first_state, decoded_state_compare, reduction='none')
            mse_loss_per_example = mse_loss_per_element.mean(dim=1)
            # Leading dimension indexes the states in state_for_next_prior.
            mse_loss_per_example = mse_loss_per_example.view(-1, nb_samples, n_expand)
            expanded_mse_loss = mse_loss_per_example.unsqueeze(3).unsqueeze(4).expand(-1, nb_samples, n_expand, 2, 1)

            reshaped_prediction_raw = prediction_raw.view(-1, nb_samples, n_expand, 2, feature_dim)
            action_contex = action_contex.view(-1, nb_samples, 1)
            expanded_action_contex = action_contex.unsqueeze(3).unsqueeze(4).expand(-1, nb_samples,
                                                                                    n_expand, 2, 1)
            concatenated_tensor = torch.cat([reshaped_prediction_raw, expanded_action_contex], dim=4)
            final_tensor = torch.cat([concatenated_tensor, expanded_mse_loss], dim=4)
            for i in range(state_for_next_prior.shape[0]):
                # Keyed by the state's flattened tuple (see tensor_to_tuple).
                store_value(state_for_next_prior[i], final_tensor[i])

        # Next-level states: de-duplicated timestep-1 observation columns of
        # every decoded candidate (no pruning).
        state_for_next_prior = prediction_raw[:, 1, 1:1 + model.observation_dim]
        state_for_next_prior = torch.unique(state_for_next_prior, dim=0)

    print("inference time,", time.time() - start)
    tree_gamma = 0.99
    # NOTE(review): mcts_with_prior constructs MCTS with three more arguments
    # (action count, n_expand, mse_factor); confirm this 6-argument call still
    # matches the MCTS class in .mcts_expand.
    mcts_instance = MCTS(state, state_dict, tree_gamma, prior, model, max_depth-1)

    start_time = time.time()
    mcts_instance.search(200)
    # Stop the timer
    end_time = time.time()

    # Calculate the running time
    running_time = end_time - start_time
    print("search time,", running_time)
    best_action = mcts_instance.best_action().long()
    print(best_action)
    prediction_raw = model.decode_from_indices(best_action.view(1, -1), state).squeeze(0)
    return prediction_raw.cpu().numpy()




@torch.no_grad()
def beam_with_prior_original(prior, model, x, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand, prob_threshold=0.05, likelihood_weight=5e2, prob_acc="expect", return_info=False):
    """Single-step selection: sample codes from ``prior`` and keep the best by value.

    The ``steps`` argument is overridden to 1, so exactly
    ``1 // model.latent_step`` iterations run. ``beam_width * n_expand`` codes
    are sampled, decoded, and scored by the value read from feature column 0
    at the last decoded timestep, combined with the plan probability
    according to ``prob_acc``.

    Args:
        prior: callable ``prior(contex, state) -> (logits, _)``; must expose
            ``observation_dim``.
        model: object exposing ``decode_from_indices``, ``observation_dim``,
            ``action_dim`` and ``latent_step``.
        x: conditioning trajectory; indexed as ``x[:, 0, :observation_dim]``.
        denormalize_rew: unused.
        denormalize_val: optional callable applied to the flat value vector.
        discount: unused.
        steps: ignored (forced to 1).
        beam_width, n_expand: their product is the number of sampled codes.
        prob_threshold: probability used for the likelihood-bonus clamp.
        likelihood_weight: weight of the likelihood bonus.
        prob_acc: "expect" (value times probability), "product" or "min"
            (value plus clamped log-probability bonus).
        return_info: if True, also return the (always empty) ``info`` dict.

    Returns:
        The best decoded trajectory as a numpy array, or ``(trajectory, info)``.

    Raises:
        ValueError: for an unknown ``prob_acc`` or when the forced single step
            yields no iteration (``model.latent_step > 1``); both previously
            surfaced as unbound-variable errors.
    """
    if prob_acc not in ("expect", "product", "min"):
        raise ValueError(f"unknown prob_acc: {prob_acc!r}")
    contex = None
    state = x[:, 0, :prior.observation_dim]
    acc_probs = torch.zeros([1]).to(x)
    info = {}
    steps = 1
    n_latents = steps // model.latent_step
    if n_latents < 1:
        raise ValueError("beam_with_prior_original plans a single step and requires model.latent_step == 1")
    for step in range(n_latents):
        logits, _ = prior(contex, state)  # [B x t x K]
        probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
        log_probs = torch.log(probs)
        nb_samples = beam_width * n_expand if step == 0 else n_expand
        samples = torch.multinomial(probs, num_samples=nb_samples, replacement=True)  # [B, M]
        samples_log_prob = torch.gather(log_probs, 1, samples)  # [B, M]
        if prob_acc in ["product", "expect"]:
            acc_probs = acc_probs.repeat_interleave(nb_samples, 0) + samples_log_prob.reshape([-1])
        contex = samples.reshape([-1, step + 1])  # [(B*M) x t]
        prediction_raw = model.decode_from_indices(contex, state)
        prediction = prediction_raw.reshape([-1, 3*model.action_dim+model.observation_dim+2])
        # In this layout the value sits in feature column 0.
        V_t = prediction[:, 0]
        if denormalize_val is not None:
            V_t = denormalize_val(V_t)
        # Always reshape to [n_plans, T]; skipping it when denormalize_val is
        # None made `V_t[:, -1]` fail.
        V_t = V_t.reshape([contex.shape[0], -1])
        values = V_t[:, -1]

        nb_top = beam_width if step < (n_latents - 1) else 1
        if prob_acc == "expect":
            scores = values * torch.exp(acc_probs)
        elif prob_acc == "product":
            scores = values + likelihood_weight * torch.clamp(
                acc_probs, -1e5, float(np.log(prob_threshold) * n_latents))
        else:  # "min"
            # NOTE(review): min (0) exceeds max (a negative log-prob), so this
            # clamp yields the constant log(prob_threshold) — confirm intent.
            scores = values + likelihood_weight * torch.clamp(acc_probs, 0, float(np.log(prob_threshold)))
        values_with_b, index = torch.topk(scores, nb_top)
    optimal = prediction_raw[index[0]]
    # Report the value of the selected candidate (was values[0], i.e. simply
    # the first sample regardless of ranking).
    print(f"predicted max value {values[index[0]]}")
    if return_info:
        return optimal.cpu().numpy(), info
    else:
        return optimal.cpu().numpy()

@torch.no_grad()
def beam_with_prior_expectation(prior, model, x, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand, prob_threshold=0.05, likelihood_weight=5e2, prob_acc="expect", return_info=False):
    """Choose a first latent code by the mean decoded value of its sampled continuations.

    Procedure (a single round):
      1. draw ``beam_width`` distinct first codes from ``prior(None, state)``;
      2. for every first code draw ``n_expand`` second codes from
         ``prior(first_code, state)``;
      3. decode every (first, second) pair with ``model.decode_from_indices``;
      4. score each first code by the mean of its ``n_expand`` second-step values
         plus the first-step value of its first continuation;
      5. return the decoded trajectory of the first continuation of the
         best-scoring first code.

    Feature 0 of each decoded step is the quantity passed through
    ``denormalize_val`` and used as the value.

    Args:
        prior: callable ``prior(codes_or_None, state) -> (logits, _)`` exposing
            ``observation_dim``; only ``logits[:, -1, :]`` is used.
        model: decoder exposing ``latent_step``, ``action_dim``,
            ``observation_dim`` and ``decode_from_indices(codes, state)``.
        x: tensor indexed as ``x[:, 0, :prior.observation_dim]`` to get the state.
        denormalize_val: optional callable applied to the raw values.
        steps: ignored -- the search is hard-wired to one round (see below).
        beam_width: number of distinct first codes to evaluate.
        n_expand: number of continuations sampled per first code.
        return_info: when true, also return a dict of search diagnostics.
        denormalize_rew, discount, prob_threshold, likelihood_weight, prob_acc:
            accepted for signature compatibility; they do not affect the result.

    Returns:
        ``optimal`` as a numpy array, or ``(optimal, info)`` if ``return_info``.

    Raises:
        ValueError: if ``model.latent_step`` is larger than 1, in which case the
            single round would not run at all.
    """
    state = x[:, 0, :prior.observation_dim]
    # No likelihood is accumulated in this search; the zero tensor is kept only
    # so that the ``log_probs`` entry of ``info`` is still populated.
    acc_probs = torch.zeros([1]).to(x)
    info = {}

    # The incoming ``steps`` is overridden with 1, exactly as before: further
    # rounds would need the next state fed back into the prior, which this
    # function does not implement.
    steps = 1
    if steps // model.latent_step < 1:
        raise ValueError(
            "beam_with_prior_expectation runs a single search round and needs "
            f"model.latent_step <= 1, got {model.latent_step}"
        )
    step = 0

    # Stage 1: distinct candidate first codes.
    logits, _ = prior(None, state)  # [B x t x K]
    probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
    first_codes = torch.multinomial(probs, num_samples=beam_width, replacement=False)  # [B, beam_width]
    contex = first_codes.reshape([-1, 1])  # [beam_width x 1]

    # Stage 2: n_expand continuations per first code.
    logits, _ = prior(contex, state)
    probs = torch.softmax(logits[:, -1, :], dim=-1)
    next_codes = torch.multinomial(probs, num_samples=n_expand, replacement=True)  # [beam_width, n_expand]
    contex = torch.cat([torch.repeat_interleave(contex, n_expand, 0), next_codes.reshape([-1, 1])], dim=1)

    prediction_raw = model.decode_from_indices(contex, state)
    prediction = prediction_raw.reshape([-1, n_expand, 2, 3*model.action_dim+model.observation_dim+2])
    # Trajectory of the first continuation of every first code.
    prediction_output = prediction[:, 0, :, :]
    V_t = prediction_raw[:, 1, 0]  # value after the second latent step
    a_v = prediction_raw[:, 0, 0]  # value after the first latent step
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
        a_v = denormalize_val(a_v)

    values = V_t.reshape([-1, n_expand])
    # NOTE(review): assumes the prior returned a single row for the root state,
    # so there are exactly ``beam_width`` first codes -- confirm with callers.
    a_v = a_v.reshape(beam_width, n_expand)[:, :1]

    # Score of a first code: mean over its continuation values and its own value.
    values_track = torch.cat((values, a_v), dim=1).mean(dim=1)
    values_with_b, index = torch.topk(values_track, 1)

    if return_info:
        info[(step+1)*model.latent_step] = dict(predictions=prediction_raw.cpu(), returns=values.cpu(),
                                                latent_codes=contex.cpu(), log_probs=acc_probs.cpu(),
                                                objectives=values_track.cpu(), index=index.cpu())
    optimal = prediction_output[index[0]]
    print(f"predicted max value {values_track[index[0]]}")
    if return_info:
        return optimal.cpu().numpy(), info
    else:
        return optimal.cpu().numpy()



@torch.no_grad()
def beam_with_uniform(prior, model, x, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand,  prob_threshold=0.05):
    """Beam search over latent codes, sampling uniformly among codes the prior accepts.

    At every round the prior is queried for next-code probabilities; every code
    whose probability exceeds ``prob_threshold`` is an equally likely candidate.
    Candidates are decoded, scored by discounted reward plus discounted final
    value, and the best ``beam_width`` sequences are kept (only the best one in
    the last round).

    Args:
        prior: callable ``prior(codes_or_None, state) -> (logits, _)``; only
            ``logits[:, -1, :]`` is used.
        model: decoder exposing ``observation_dim``, ``latent_step``,
            ``transition_dim`` and ``decode_from_indices(codes, state)``.
        x: tensor indexed as ``x[:, 0, :model.observation_dim]`` to get the state.
        denormalize_rew: optional callable applied to the reward column.
        denormalize_val: optional callable applied to the value column.
        discount: per-step discount factor.
        steps: planning horizon; ``steps // model.latent_step`` rounds are run.
        beam_width: number of sequences kept between rounds.
        n_expand: samples drawn per kept sequence (``beam_width * n_expand``
            in the first round).
        prob_threshold: minimum prior probability for a code to be a candidate.

    Returns:
        The decoded trajectory of the best sequence as a numpy array.
    """
    contex = None
    state = x[:, 0, :model.observation_dim]
    nb_rounds = steps//model.latent_step
    for step in range(nb_rounds):
        logits, _ = prior(contex, state) # [B x t x K]
        probs = torch.softmax(logits[:, -1, :], dim=-1) # [B x K]
        nb_samples = beam_width * n_expand if step == 0 else n_expand

        valid = probs > prob_threshold
        # If no code of a row clears the threshold, fall back to that row's most
        # likely code so the sampler never receives an all-zero row. Rows that
        # already have a candidate are unaffected (their maximum is already valid).
        valid = valid | (probs >= probs.max(dim=-1, keepdim=True).values)
        # torch.multinomial normalises each row itself, so the 0/1 mask is
        # enough to sample uniformly over the valid codes of every row.
        samples = torch.multinomial(valid.float(), num_samples=nb_samples, replacement=True) # [B, M]

        new_column = samples.reshape([-1, 1])
        if contex is None:
            contex = new_column # [(B*M) x 1]
        else:
            contex = torch.cat([torch.repeat_interleave(contex, nb_samples, 0), new_column], dim=1)

        prediction_raw = model.decode_from_indices(contex, state)
        prediction = prediction_raw.reshape([-1, model.transition_dim])
        r_t, V_t = prediction[:, -3], prediction[:, -2]

        if denormalize_rew is not None:
            r_t = denormalize_rew(r_t)
        if denormalize_val is not None:
            V_t = denormalize_val(V_t)
        # One row per candidate sequence, also when no denormaliser is supplied.
        r_t = r_t.reshape([contex.shape[0], -1])
        V_t = V_t.reshape([contex.shape[0], -1])

        discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
        values = torch.sum(r_t[:,:-1] * discounts[:, :-1], dim=-1) + V_t[:,-1] * discounts[:,-1]
        nb_top = beam_width if step < (nb_rounds-1) else 1
        values, index = torch.topk(values, nb_top)
        contex = contex[index]

    optimal = prediction_raw[index[0]]
    print(f"predicted max value {values[0]}")
    return optimal.cpu().numpy()

@torch.no_grad()
def beam_mimic(prior, model, x, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand,  prob_threshold=0.05):
    """Beam search that keeps the latent-code sequences the prior finds most probable.

    Each round samples next codes from the prior, multiplies their probabilities
    into the running sequence probability, and keeps the ``beam_width`` most
    probable sequences (only the single best one in the final round). The
    surviving sequence is decoded once, after the search.

    ``denormalize_rew``, ``denormalize_val``, ``discount`` and
    ``prob_threshold`` are accepted but not used by this function.

    Returns:
        The first decoded trajectory as a numpy array.
    """
    state = x[:, 0, :model.observation_dim]
    total_rounds = steps // model.latent_step
    contex = None
    acc_probs = torch.ones([1]).to(x)

    for round_idx in range(total_rounds):
        logits, _ = prior(contex, state)  # [B x t x K]
        next_probs = logits[:, -1, :].softmax(dim=-1)  # [B x K]

        # The root is expanded beam_width * n_expand times, later rounds n_expand times.
        fanout = n_expand if round_idx != 0 else beam_width * n_expand
        drawn = torch.multinomial(next_probs, num_samples=fanout, replacement=True)  # [B, M]

        # Probability of every drawn code, row by row.
        drawn_probs = torch.gather(next_probs, 1, drawn)  # [B, M]
        acc_probs = acc_probs.repeat_interleave(fanout, 0) * drawn_probs.reshape([-1])

        if contex is None:
            # Only reachable in the first round, where each sequence has length 1.
            contex = drawn.reshape([-1, round_idx + 1])  # [(B*M) x t]
        else:
            parents = torch.repeat_interleave(contex, fanout, 0)
            contex = torch.cat([parents, drawn.reshape([-1, 1])], dim=1)

        is_last_round = round_idx >= total_rounds - 1
        keep = 1 if is_last_round else beam_width
        values, index = torch.topk(acc_probs, keep)
        contex = contex[index]
        acc_probs = acc_probs[index]

    prediction_raw = model.decode_from_indices(contex, state)
    optimal = prediction_raw[0]
    print(f"value {values[0]}, prob {acc_probs[0]}")
    return optimal.cpu().numpy()


@torch.no_grad()
def enumerate_all(model, x, denormalize_rew, denormalize_val, discount):
    """Decode every latent code once and return the trajectory with the best return.

    All ``model.model.K`` code indices are decoded from the state
    ``x[:, 0, :model.observation_dim]``. Each trajectory is scored by its
    discounted rewards (all steps but the last) plus its discounted final value.

    Args:
        model: decoder exposing ``model.K``, ``observation_dim``,
            ``transition_dim`` and ``decode_from_indices(indices, state)``.
        x: tensor indexed as ``x[:, 0, :model.observation_dim]`` to get the state.
        denormalize_rew: optional callable applied to the reward column.
        denormalize_val: optional callable applied to the value column.
        discount: per-step discount factor.

    Returns:
        The best decoded trajectory as a numpy array.
    """
    # torch.arange(K) yields 0..K-1, the same values the deprecated
    # torch.range(0, K-1) produced.
    indicies = torch.arange(model.model.K, device=x.device, dtype=torch.int32)
    prediction_raw = model.decode_from_indices(indicies, x[:, 0, :model.observation_dim])
    prediction = prediction_raw.reshape([-1, model.transition_dim])

    # NOTE(review): reward/value are read from columns -2/-1 here, whereas
    # beam_with_uniform in this file uses -3/-2 -- confirm which layout applies.
    r_t, V_t = prediction[:, -2], prediction[:, -1]
    if denormalize_rew is not None:
        r_t = denormalize_rew(r_t)
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
    # One row per code, also when no denormaliser is supplied.
    r_t = r_t.reshape([indicies.shape[0], -1])
    V_t = V_t.reshape([indicies.shape[0], -1])

    discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
    values = torch.sum(r_t[:,:-1] * discounts[:, :-1], dim=-1) + V_t[:,-1] * discounts[:,-1]
    optimal = prediction_raw[values.argmax()]
    return optimal.cpu().numpy()


@torch.no_grad()
def propose_plan_continuous(model, x):
    """Decode the all-zero latent vector from the state in ``x``.

    Args:
        model: decoder exposing ``trajectory_embd``, ``observation_dim``,
            ``transition_dim`` and ``decode(latent, state)``.
        x: tensor indexed as ``x[:, 0, :model.observation_dim]`` to get the state.

    Returns:
        The decoded prediction reshaped to ``[-1, model.transition_dim]`` as a
        numpy array.
    """
    # Allocate the latent on the input's device instead of a hard-coded "cuda",
    # so the function also works on CPU-only hosts and non-default GPUs.
    latent = torch.zeros([1, model.trajectory_embd], device=x.device)
    prediction = model.decode(latent, x[:, 0, :model.observation_dim])
    prediction = prediction.reshape([-1, model.transition_dim])
    return prediction.cpu().numpy()