import numpy as np
import torch
import pdb

from ..utils.arrays import to_torch

## Stand-in written into the value slot of every transition that update_context
## builds (rew_val = [reward, VALUE_PLACEHOLDER]). Per the comment there, input
## values are masked out by the model, so the exact magnitude presumably does
## not matter -- confirm before relying on it.
VALUE_PLACEHOLDER = 1e6

def make_prefix(discretizer, context, obs, transition_dim, prefix_context=True, device="cuda"):
    '''
        Build the model input prefix for the current observation.

        discretizer : object exposing `discretize(x, subslice=...)`, or None
            when the model consumes continuous (undiscretized) inputs
        context : list of previous transition tensors (presumably as built by
            update_context); only used when `prefix_context` is True
        obs : current observation (must expose `.size`, e.g. a numpy array)
        transition_dim : width of a full transition; in the continuous branch
            the observation is zero-padded up to this width
        prefix_context : whether to prepend `context` to the observation
        device : torch device for the tensors created here

        returns : the prefix tensor.
            discrete branch   -- context and the discretized observation
                                 concatenated on the last axis
            continuous branch -- context and the zero-padded observation
                                 concatenated on axis 1; a 1-D observation is
                                 first reshaped to (1, 1, transition_dim)

        raises : ValueError if (continuous branch) the observation is wider
            than `transition_dim`
    '''
    if discretizer is not None:
        observation_dim = obs.size
        obs_discrete = discretizer.discretize(obs, subslice=[0, observation_dim])
        obs_discrete = to_torch(obs_discrete, dtype=torch.long, device=device)
        if prefix_context:
            prefix = torch.cat(context + [obs_discrete], dim=-1)
        else:
            prefix = obs_discrete
    else:
        obs_continuous = to_torch(obs, dtype=torch.float32, device=device)

        ## zero-fill the remaining (non-observation) slots of the transition.
        ## The padding is sized from the LAST axis and shares the leading axes
        ## of the observation, so torch.cat below also works for obs with
        ## more than one dimension (a flat 1-D pad cannot be concatenated
        ## with a multi-dimensional tensor).
        pad_width = transition_dim - obs_continuous.shape[-1]
        if pad_width < 0:
            raise ValueError(
                f"observation has {obs_continuous.shape[-1]} features, "
                f"which exceeds transition_dim={transition_dim}"
            )
        pad_dims = to_torch(
            np.zeros(tuple(obs_continuous.shape[:-1]) + (pad_width,)),
            dtype=torch.float32,
            device=device,
        )

        if obs_continuous.ndim == 1:
            ## promote a single observation to (1, 1, dim), the same layout
            ## as the entries update_context appends in its continuous branch
            obs_continuous = obs_continuous.reshape(1, 1, -1)
            pad_dims = pad_dims.reshape(1, 1, -1)

        transition = torch.cat([obs_continuous, pad_dims], dim=-1)
        if prefix_context:
            prefix = torch.cat(context + [transition], dim=1)
        else:
            prefix = transition

    return prefix

def extract_actions(x, observation_dim, action_dim, t=None):
    '''
        Slice the action columns out of a two-axis batch of transitions.

        x : array/tensor indexed as x[rows, columns]; columns
            [observation_dim, observation_dim + action_dim) hold the action
        observation_dim : number of leading observation columns to skip
        action_dim : number of action columns to take
        t : optional row index; when given, only that row's action is returned

        returns : the action columns for every row, or for row `t` only
    '''
    first = observation_dim
    last = first + action_dim
    action_columns = x[:, first:last]
    return action_columns if t is None else action_columns[t]

def extract_actions_continuous(x, observation_dim, action_dim, t=None):
    '''
        Slice the action features out of a single-sequence batch.

        x : three-axis array/tensor indexed as x[batch, step, feature]; the
            batch axis must have length 1 (enforced by the assert)
        observation_dim : number of leading observation features to skip
        action_dim : number of action features to take
        t : optional step index; when given, only that step's action is returned

        returns : the action features for every step, or for step `t` only
    '''
    assert x.shape[0] == 1
    action_span = slice(observation_dim, observation_dim + action_dim)
    sequence_actions = x[0, :, action_span]
    if t is None:
        return sequence_actions
    return sequence_actions[t]

def update_context(context, discretizer, observation, action, reward, max_context_transitions, device):
    '''
        Append the latest transition to `context` and return the most recent
        `max_context_transitions` entries.

        context : list of previously built transition tensors. Each entry is
            what this function builds below: a torch.long tensor from
            discretizer.discretize(...) when a discretizer is given, otherwise
            a float32 tensor reshaped to (1, 1, transition_dim)
        discretizer : object exposing `discretize(x)`, or None for continuous
            (undiscretized) inputs
        observation, action : array-likes; concatenated with [reward, value]
            to form the transition
        reward : presumably a scalar reward for this step
        max_context_transitions : number of most recent transitions to keep.
            Values <= 0 keep none; None keeps all
        device : torch device for the new transition tensor

        returns : a new, cropped list of transitions.
            NOTE: the new transition is also appended IN PLACE to the caller's
            `context` list, and that list itself is not cropped; callers
            should rebind, e.g. `context = update_context(context, ...)`.
    '''
    ## use a placeholder for value because input values are masked out by model
    rew_val = np.array([reward, VALUE_PLACEHOLDER])
    transition = np.concatenate([observation, action, rew_val])

    ## discretize transition and convert to torch tensor
    if discretizer is not None:
        transition_discrete = discretizer.discretize(transition)
        transition_discrete = to_torch(transition_discrete, dtype=torch.long, device=device)
    else:
        transition_discrete = to_torch(transition, dtype=torch.float32, device=device)
        transition_discrete = transition_discrete.reshape(1, 1, -1)

    ## add new transition to context (in place; existing callers may rely on it)
    context.append(transition_discrete)

    ## crop context if necessary. The non-positive case is handled explicitly
    ## because context[-0:] is the WHOLE list, which would silently disable
    ## cropping and let the context grow without bound.
    if max_context_transitions is None:
        context = list(context)
    elif max_context_transitions <= 0:
        context = []
    else:
        context = context[-max_context_transitions:]

    return context
