import sys
import torch
import numpy as np
from abc import abstractmethod
import time
from torch.nn.functional import softmax

# RL ENVIRONMENT OBJECT ============================================================

class BinaryRewardEnv:
    '''
    Non-contextual bandit environment with 0/1 rewards.

    `success_p` is indexed by arm to look up the success probability of each
    draw. Rewards come from `rng_gen.binomial(1, p)`. When no generator is
    passed, one is created from a fixed seed, which is then also stored as
    `self.seed`.
    '''

    def __init__(self, success_p, num_arms=100, rng_gen=None):
        self.success_p = success_p
        self.num_arms = num_arms

        if rng_gen is not None:
            # Caller-owned generator: use it as-is (no `seed` attribute is set).
            self.rng_gen = rng_gen
            return

        # No generator supplied: fall back to a fixed, reproducible seed.
        self.seed = 49248204
        self.rng_gen = np.random.default_rng(self.seed)

    def generate_reward(self, arms, t):
        # `t` is accepted for interface parity with the other envs; it is not used.
        return self.rng_gen.binomial(1, self.success_p[arms])

    def get_expected_reward(self, arm, t):
        # The expected reward is the success probability itself; `t` is unused.
        return self.success_p[arm]


class BinaryRewardEnv_impute:
    '''
    Bandit environment backed by pre-generated tables.

    `success_p[arm, t]` is returned as the expected reward and `Y[arm, t]` as
    the realised reward, so nothing is sampled at interaction time. `X`
    (optional) holds contexts and is sliced along its second axis by timestep.
    '''

    def __init__(self, success_p, T, X, Y, num_arms=100):
        # The probability table must have at least T columns (one per timestep).
        assert success_p.shape[1] >= T

        self.num_arms = num_arms
        self.success_p = success_p
        self.potential_outcomes = Y
        self.X = X  # indexed as X[:, [t], :] in get_context

    def get_context(self, t):
        # Keep the time axis (length 1) so the result stays 3-D.
        return None if self.X is None else self.X[:, [t], :]

    def generate_reward(self, arm, t):
        # Realised reward is read from the potential-outcome table.
        return self.potential_outcomes[arm, t]

    def get_expected_reward(self, arm, t):
        return self.success_p[arm, t]

        
# GET BANDIT ENVS ============================================================

def get_bandit_envs(num_arms, T, N_monte_carlo, success_p_all, seed=879437260):
    '''
    Build `N_monte_carlo` non-contextual environments from a pool of arms.

    Each environment draws `num_arms` indices into `success_p_all` (using
    `Generator.choice` with its default arguments) and wraps the selected
    probabilities in a `BinaryRewardEnv`. Returns a list of
    `(env, chosen_arms)` tuples.

    Arm selection is deterministic for a given seed, and running with
    N_monte_carlo = N1 < N2 yields the same first N1 arm selections.
    NOTE: every environment shares the single `rng_env` generator, which is
    also the one used to draw its rewards. `T` is not used here.
    '''
    rng_env = np.random.default_rng(seed)
    arm_pool = np.arange(len(success_p_all))

    envs_and_arms = []
    for _ in range(N_monte_carlo):
        picked = rng_env.choice(arm_pool, num_arms)
        env = BinaryRewardEnv(success_p_all[picked], num_arms, rng_env)
        envs_and_arms.append((env, picked))
    return envs_and_arms


def get_bandit_envs_from_data_dict(data_dict, num_arms, T, N_monte_carlo, seed=879437260, context=False):
    '''
    Build `N_monte_carlo` imputation environments by subsampling a data table.

    For every environment, `num_arms` distinct rows of
    `data_dict['click_rate']` are chosen, and for each chosen row `T` column
    indices are drawn (with replacement). The click rates at those positions
    become the environment's expected rewards; if `data_dict` has an `'X'`
    entry, the contexts at the same positions are gathered too.

    Returns a list of `(env, chosen_arms)` tuples, `chosen_arms` as a numpy
    array. `context` is accepted for signature parity and is not used.

    Fixes relative to the previous version:
      * `BinaryRewardEnv_impute` requires a potential-outcome table `Y`; it was
        not passed, so the call raised TypeError. Outcomes are now drawn as
        Bernoulli(click_rate) with the same seeded generator.
        NOTE(review): this assumes click rates are float probabilities in
        [0, 1] and that Bernoulli draws are the intended outcomes - confirm.
      * `num_arms` is forwarded to the environment instead of defaulting to 100.
      * The gather index for `X` is expanded over the feature axis so that all
        feature columns are kept (previously only the first one was).
        Assumes `X` is (rows, cols, features) - TODO confirm.
    '''
    generator = torch.Generator()
    generator.manual_seed(seed)

    click_rate = data_dict['click_rate']
    num_rows = click_rate.shape[0]
    num_cols = click_rate.shape[1]

    all_bandit_envs = []
    for _ in range(N_monte_carlo):
        # which arms/rows
        chosen_arms = torch.randperm(num_rows, generator=generator)[:num_arms]

        # which columns, after choosing arms/rows
        chosen_idxs = torch.randint(high=num_cols, size=(num_arms, T),
                                    dtype=torch.int64, generator=generator)

        subset_X = None
        if 'X' in data_dict:
            arm_X = data_dict['X'][chosen_arms]
            # torch.gather only reads index.shape entries, so the index must
            # span the whole feature axis to keep every feature.
            gather_idx = chosen_idxs.unsqueeze(-1).expand(-1, -1, arm_X.shape[-1])
            subset_X = torch.gather(arm_X, dim=1, index=gather_idx)

        subset_click_rates = torch.gather(click_rate[chosen_arms], dim=1, index=chosen_idxs)
        # One realised 0/1 outcome per (arm, timestep).
        potential_outcomes = torch.bernoulli(subset_click_rates, generator=generator)

        bandit_env_tmp = BinaryRewardEnv_impute(subset_click_rates, T, subset_X,
                                                potential_outcomes, num_arms=num_arms)
        all_bandit_envs.append((bandit_env_tmp, np.array(chosen_arms)))
    return all_bandit_envs

# Note that this returns tuples of (bandit_env, data_dict) rather than (bandit_env, arms)
def get_bandit_envs_from_dgp(dgp_fn, num_arms, T, N_monte_carlo, seed=879437260, context=False):
    '''
    Build `N_monte_carlo` imputation environments from a data-generating process.

    `dgp_fn(num_arms, T, generator)` is called once per environment with a
    shared, seeded `torch.Generator` and must return a dict with the keys
    'click_rate', 'X' and 'Y'. Returns a list of `(env, data_dict)` tuples.
    `context` is not used.
    '''
    generator = torch.Generator()
    generator.manual_seed(seed)

    def _make_pair():
        data_dict = dgp_fn(num_arms, T, generator)
        # TODO (carried over from the original): "do something about this weird default..."
        env = BinaryRewardEnv_impute(
            data_dict['click_rate'],
            T,
            data_dict['X'],
            data_dict['Y'],
            num_arms=num_arms,
        )
        return env, data_dict

    return [_make_pair() for _ in range(N_monte_carlo)]

def get_bandit_envs_from_dgp_withZ(dgp_fn, num_arms, T, N_monte_carlo, all_generate_Z, all_usable_Z, seed=879437260, context=False):
    '''
    Like `get_bandit_envs_from_dgp`, but each arm is paired with a Z feature.

    all_usable_Z: Z features to train on.
    all_generate_Z: Z features used to generate outcomes.
    Both must have the same length; the same randomly drawn indices (with
    replacement) select from each. `dgp_fn(num_arms, T, generator, generate_Z)`
    must return a dict with 'click_rate', 'X' and 'Y'. The returned dict is
    extended with 'Z' (usable) and 'generate_Z' and returned alongside the env
    as `(env, data_dict)`. `context` is not used.
    '''
    assert len(all_generate_Z) == len(all_usable_Z)
    pool_size = len(all_generate_Z)

    generator = torch.Generator()
    generator.manual_seed(seed)

    results = []
    for _ in range(N_monte_carlo):
        # Pick one Z per arm; the index draw precedes the DGP call on the shared generator.
        z_idx = torch.randint(pool_size, (num_arms,), generator=generator)
        z_usable = all_usable_Z[z_idx]
        z_generate = all_generate_Z[z_idx]

        data_dict = dgp_fn(num_arms, T, generator, z_generate)
        env = BinaryRewardEnv_impute(
            data_dict['click_rate'], T, data_dict['X'], data_dict['Y'], num_arms=num_arms
        )

        # Stored for evaluating bandit algorithms that use Z.
        data_dict['Z'] = z_usable
        data_dict['generate_Z'] = z_generate
        results.append((env, data_dict))
    return results



# ABSTRACT BANDIT ALGORITHM CLASS ============================================================

class BanditAlgorithm:
    '''
    Common base for bandit algorithms.

    Stores the number of arms and a seeded numpy generator (`rng_gen`).
    Subclasses are expected to provide `update_algorithm` and `sample_action`.
    NOTE: the class does not derive from `abc.ABC`, so the `abstractmethod`
    markers are not enforced at instantiation time.
    '''

    def __init__(self, num_arms=100, seed=None):
        self.num_arms = num_arms
        # Fall back to a fixed seed so runs are reproducible by default.
        self.seed = 457323948 if seed is None else seed
        self.rng_gen = np.random.default_rng(self.seed)

    @abstractmethod
    def update_algorithm(self, arms, rewards):
        '''Incorporate the observed reward(s) for the played arm(s).'''

    @abstractmethod
    def sample_action(self):
        '''Return the arm to play next.'''


class MarginalAlgWithContext(BanditAlgorithm):
    '''
    Greedy policy driven by a fixed prediction model.

    `model(Z_representation, X)` is called at every step and the arm with the
    highest prediction is played. The algorithm keeps no history, so
    `update_algorithm` is a no-op. The base-class initialiser is not invoked.
    '''

    def __init__(self, model, Z_representation, num_arms):
        # When provided, there must be exactly one Z representation per arm.
        if Z_representation is not None:
            assert len(Z_representation) == num_arms
        self.model = model
        self.Z_representation = Z_representation

    def update_algorithm(self, arm, reward, X):
        # Nothing to learn online.
        pass

    def sample_action(self, X, return_extra=False):
        scores = self.model(self.Z_representation, X)
        choice = torch.argmax(scores).item()
        return (choice, scores) if return_extra else choice

class LinearGaussianContextTS(BanditAlgorithm):
    '''
    Thompson sampling with a separate Bayesian linear-Gaussian model per arm.

    Specify prior precision (inverse covariance) and noise scale through
    `hyparam_dict`: `lam` multiplies the identity to form the prior precision,
    and `sig` (sigma) enters squared as a scale on the sampling variance.

    Per-arm features are `[Z[arm], X[arm, t], 1]`; the Z part and the trailing
    constant are optional. `X` must be (num_arms, T, raw_Xdim) and is the full
    table of contexts the algorithm will see.
    '''
    def add_const_if_necessary(self, X):
        '''Append a constant-1 feature on the last axis of a 2-D or 3-D tensor.'''
        if not self.add_const_feature:
            return X
        if len(X.shape) == 2:
            A, D = X.shape
            return torch.cat([X, torch.ones(A, 1)], 1)
        elif len(X.shape) == 3:
            A, T, D = X.shape
            return torch.cat([X, torch.ones(A, T, 1)], 2)
        else:
            raise ValueError('something is wrong')

    def __init__(self, num_arms, X, hyparam_dict, seed=None, max_T=None, add_const_feature=True, prior_mean=None,
            Z=None):
        '''
        num_arms: number of arms; must equal X.shape[0] (and len(Z) if given).
        X: context table, shape (num_arms, T, raw_Xdim).
        hyparam_dict: dict with keys 'lam' and 'sig'.
        max_T: maximum number of observations stored per arm (defaults to T).
        prior_mean: optional prior mean for the constant-feature coefficient;
            only applied when `add_const_feature` is True.
        Z: optional per-arm features, prepended to the context features.
        '''
        super(LinearGaussianContextTS, self).__init__(num_arms, seed)

        self.t = 0
        self.Zdim = 0
        self.Z = None
        if Z is not None:
            self.Z = Z
            assert len(Z) == num_arms
            self.Zdim = Z.shape[-1]

        A, T, raw_Xdim = X.shape
        assert A == num_arms
        self.num_arms = num_arms

        if max_T is None:
            max_T = T
        self.max_T = max_T
        assert max_T <= T

        self.add_const_feature = add_const_feature
        # add a constant to X, for all features
        self.features = self.add_const_if_necessary(X)
        self.Xdim = self.features.shape[-1]

        self.feature_dim = self.Xdim + self.Zdim

        # LINEAR GAUSSIAN PARAMS
        # from https://arxiv.org/pdf/1802.09127
        self.lam = hyparam_dict['lam']          # lambda: prior precision \Gamma_0 = I * lambda
        self.sig = hyparam_dict['sig']          # sigma

        # one per arm: model arms separately
        self.prior_beta_mean = torch.zeros(num_arms, self.feature_dim)
        if prior_mean is not None and add_const_feature:
            # The constant feature is the last column, so this sets the intercept prior.
            self.prior_beta_mean[:, -1] = prior_mean

        identity = torch.eye(self.feature_dim).repeat(num_arms, 1, 1)
        self.prior_beta_cov = identity / self.lam
        self.prior_inv_beta_cov = identity * self.lam

        # Before any data is seen the posterior IS the prior. (Previously the
        # posterior started at zero mean / identity covariance, which silently
        # ignored `lam` and `prior_mean` for every arm not yet pulled.)
        self.beta_mean = self.prior_beta_mean.clone()
        self.beta_cov = self.prior_beta_cov.clone()
        self.inv_beta_cov = self.prior_inv_beta_cov.clone()

        self.prev_Ys = torch.zeros(num_arms, max_T)
        self.obs_timesteps = torch.zeros(num_arms, max_T).int() # timesteps corresponding to observations, per arm
        self.obs_count = torch.zeros(num_arms).int() # number of total observations, per arm

    def update_algorithm(self, arm, reward, this_X):
        '''
        Record `reward` for `arm` at the current timestep and recompute that
        arm's posterior from the prior plus its full history.

        `this_X` is not used: features are looked up in the stored table by
        timestep. Raises ValueError if `reward` is neither an int nor an
        object with a working `.item()`.
        '''
        if not isinstance(reward, int):
            try:
                reward = reward.item()
            except Exception as err:
                raise ValueError('reward must be int') from err
        slot = int(self.obs_count[arm])
        self.prev_Ys[arm, slot] = reward
        self.obs_timesteps[arm, slot] = self.t
        self.obs_count[arm] += 1
        n_obs = slot + 1

        # do a posterior update for this arm using prior and all history for this arm:
        X = torch.index_select(self.features[arm], 0, self.obs_timesteps[arm, :n_obs])
        Y = self.prev_Ys[arm, :n_obs]
        if self.Z is not None:
            arm_Z = self.Z[[arm]].repeat(len(X), 1)
            feats = torch.cat([arm_Z, X], dim=1)
        else:
            feats = X

        # The noise scale is not applied here; it multiplies the variance in sample_action.
        precision = self.prior_inv_beta_cov[arm] + feats.T @ feats
        post_cov = torch.inverse(precision)
        post_mean = post_cov @ (
            self.prior_inv_beta_cov[arm] @ self.prior_beta_mean[arm] + feats.T @ Y)

        self.inv_beta_cov[arm] = precision
        self.beta_cov[arm] = post_cov
        self.beta_mean[arm] = post_mean
        self.t += 1

    def sample_action(self, X, return_extra=False):
        '''
        Draw one Gaussian sample of the predicted reward per arm and return the
        argmax arm (plus the means/variances when `return_extra` is True).

        `X` is the current context, shape (num_arms, 1, raw_Xdim); it must
        match the stored feature table at the current timestep.
        NOTE: sampling uses torch's global RNG, not `self.rng_gen`.
        '''
        X = self.add_const_if_necessary(X[:, 0, :])
        assert torch.allclose(self.features[:, self.t, :], X)
        feat = X if self.Z is None else torch.cat([self.Z, X], 1)

        arm_means = torch.zeros(self.num_arms)
        arm_vars = torch.zeros(self.num_arms)
        for arm in range(self.num_arms):
            # 1-D @ 2-D @ 1-D already yields a scalar; no transpose needed.
            arm_vars[arm] = self.sig**2 * feat[arm] @ self.beta_cov[arm] @ feat[arm]
            arm_means[arm] = feat[arm] @ self.beta_mean[arm]
        # Guard against tiny negative variances from round-off before the sqrt.
        arm_vars = arm_vars.clamp(min=0)
        samples = torch.normal(mean=arm_means, std=arm_vars**0.5)

        best_arm = samples.argmax().item()
        if return_extra:
            return best_arm, {'arm_means':arm_means, 'arm_vars':arm_vars}
        return best_arm

class LinearGaussianContextTS_general(BanditAlgorithm):
    '''
    Thompson sampling with a separate Bayesian linear-Gaussian model per arm.

    Specify prior mean and (full) variance (matrix) on coefs, and also the noise variance.
    The same prior is shared by all arms. Per-arm features are
    `[Z[arm], X[arm, t], 1]`; the Z part and the trailing constant are optional.
    '''
    def add_const_if_necessary(self, X):
        '''Append a constant-1 feature on the last axis of a 2-D or 3-D tensor.'''
        if not self.add_const_feature:
            return X
        if len(X.shape) == 2:
            A, D = X.shape
            return torch.cat([X, torch.ones(A, 1)], 1)
        elif len(X.shape) == 3:
            A, T, D = X.shape
            return torch.cat([X, torch.ones(A, T, 1)], 2)
        else:
            raise ValueError('something is wrong')

    def __init__(self, num_arms, X, prior_mean, prior_cov, noise_var, seed=None, max_T=None, add_const_feature=True,
            Z=None, cholesky=False, normalize_feats=False):
        '''
        num_arms: number of arms; must equal X.shape[0] (and len(Z) if given).
        X: context table, shape (num_arms, T, raw_Xdim).
        prior_mean: 1-D tensor of prior coefficient means (repeated per arm).
        prior_cov: square prior covariance matrix (inverted once here).
        noise_var: observation noise variance used in the posterior update.
        max_T: maximum number of observations stored per arm (defaults to T).
        Z: optional per-arm features, prepended to the context features.
        cholesky, normalize_feats: accepted but currently unused by this class.
        '''
        super(LinearGaussianContextTS_general, self).__init__(num_arms, seed)

        self.t = 0
        self.Zdim = 0
        self.Z = None
        if Z is not None:
            self.Z = Z
            assert len(Z) == num_arms
            self.Zdim = Z.shape[-1]

        A, T, raw_Xdim = X.shape
        assert A == num_arms
        self.num_arms = num_arms

        if max_T is None:
            max_T = T
        self.max_T = max_T
        assert max_T <= T

        # add a constant to X, for all features
        self.add_const_feature = add_const_feature
        self.features = self.add_const_if_necessary(X)
        self.Xdim = self.features.shape[-1]

        self.feature_dim = self.Xdim + self.Zdim

        self.prior_beta_mean = prior_mean
        self.prior_beta_cov = prior_cov
        self.prior_inv_beta_cov = torch.inverse(prior_cov)
        self.noise_var = noise_var

        # one per arm: model arms separately
        # With no data yet, each arm's posterior equals the shared prior.
        self.beta_mean = self.prior_beta_mean.unsqueeze(0).repeat((num_arms,1))
        self.beta_cov = self.prior_beta_cov.unsqueeze(0).repeat((num_arms,1,1))
        self.inv_beta_cov = self.prior_inv_beta_cov.unsqueeze(0).repeat((num_arms,1,1))

        self.prev_Ys = torch.zeros(num_arms, max_T)
        self.obs_timesteps = torch.zeros(num_arms, max_T).int() # timesteps corresponding to observations, per arm
        self.obs_count = torch.zeros(num_arms).int() # number of total observations, per arm

    def update_algorithm(self, arm, reward, this_X):
        '''
        Record `reward` for `arm` at the current timestep and recompute that
        arm's posterior from the prior plus its full history.

        `this_X` is not used: features are looked up in the stored table by
        timestep. Raises ValueError if `reward` is neither an int nor an
        object with a working `.item()`.
        '''
        if not isinstance(reward, int):
            try:
                reward = reward.item()
            except Exception as err:
                raise ValueError('reward must be int') from err
        slot = int(self.obs_count[arm])
        self.prev_Ys[arm, slot] = reward
        self.obs_timesteps[arm, slot] = self.t
        self.obs_count[arm] += 1
        n_obs = slot + 1

        # do a posterior update for this arm using prior and all history for this arm:
        X = torch.index_select(self.features[arm], 0, self.obs_timesteps[arm, :n_obs])
        Y = self.prev_Ys[arm, :n_obs]
        if self.Z is not None:
            arm_Z = self.Z[[arm]].repeat(len(X), 1)
            feats = torch.cat([arm_Z, X], dim=1)
        else:
            feats = X
        XtX = feats.T @ feats
        XtY = feats.T @ Y

        precision = self.prior_inv_beta_cov + 1/self.noise_var * XtX
        post_cov = torch.inverse(precision)
        post_mean = post_cov @ (
            self.prior_inv_beta_cov @ self.prior_beta_mean + 1/self.noise_var * XtY)

        # Keep the stored precision in sync with the stored covariance.
        self.inv_beta_cov[arm] = precision
        self.beta_cov[arm] = post_cov
        self.beta_mean[arm] = post_mean
        self.t += 1

    def sample_action(self, X, return_extra=False):
        '''
        Draw one Gaussian sample of the predicted reward per arm and return the
        argmax arm (plus the means/variances when `return_extra` is True).

        `X` is the current context, shape (num_arms, 1, raw_Xdim); it must
        match the stored feature table at the current timestep.
        NOTE: sampling uses torch's global RNG, not `self.rng_gen`.
        '''
        X = self.add_const_if_necessary(X[:, 0, :])
        assert torch.allclose(self.features[:, self.t, :], X)
        feat = X if self.Z is None else torch.cat([self.Z, X], 1)

        arm_means = torch.zeros(self.num_arms)
        arm_vars = torch.zeros(self.num_arms)
        for arm in range(self.num_arms):
            # 1-D @ 2-D @ 1-D already yields a scalar; no transpose needed.
            arm_vars[arm] = feat[arm] @ self.beta_cov[arm] @ feat[arm]
            arm_means[arm] = feat[arm] @ self.beta_mean[arm]
        # Guard against tiny negative variances from round-off before the sqrt.
        arm_vars = arm_vars.clamp(min=0)

        samples = torch.normal(mean=arm_means, std=arm_vars**0.5)

        best_arm = samples.argmax().item()
        if return_extra:
            return best_arm, {'arm_means':arm_means, 'arm_vars':arm_vars}
        return best_arm


class NeuralLinearGaussianContextTS_general(BanditAlgorithm):
    '''
    Specify prior mean and (full) variance (matrix) on coefs, and also the noise variance.
    Use learned features from neural network, rather than raw X's.

    `neural_model(Z, X)` is evaluated once on the full context table at
    construction and again on the current context in `sample_action`; a
    Bayesian linear-Gaussian model per arm is fitted on those embeddings
    (plus an optional constant feature).
    '''
    def add_const_if_necessary(self, X):
        '''Append a constant-1 feature on the last axis of a 2-D or 3-D tensor.'''
        if not self.add_const_feature:
            return X
        if len(X.shape) == 2:
            A, D = X.shape
            return torch.cat([X, torch.ones(A, 1)], 1)
        elif len(X.shape) == 3:
            A, T, D = X.shape
            return torch.cat([X, torch.ones(A, T, 1)], 2)
        else:
            raise ValueError('something is wrong')

    def __init__(self, num_arms, X, prior_mean, prior_cov, noise_var, seed=None, max_T=None, add_const_feature=True,
            Z=None, neural_model=None, cholesky=False):
        '''
        num_arms: number of arms; must equal X.shape[0] (and len(Z) if given).
        X: context table, shape (num_arms, T, raw_Xdim).
        prior_mean: 1-D tensor of prior coefficient means (repeated per arm).
        prior_cov: prior covariance matrix; must admit a Cholesky factorisation.
        noise_var: observation noise variance used in the posterior update.
        max_T: maximum number of observations stored per arm (defaults to T).
        Z: per-arm inputs forwarded to `neural_model`.
        neural_model: required callable `(Z, X) -> embeddings`.
        cholesky: if True, posterior updates use a Cholesky solve instead of
            an explicit matrix inverse.
        '''
        super(NeuralLinearGaussianContextTS_general, self).__init__(num_arms, seed)
        assert neural_model is not None
        self.neural_model = neural_model
        self.cholesky = cholesky  # more stable than directly calculating inverses?
        self.t = 0
        self.Zdim = 0
        self.Z = None
        if Z is not None:
            self.Z = Z
            assert len(Z) == num_arms
            self.Zdim = Z.shape[-1]

        A, T, raw_Xdim = X.shape
        assert A == num_arms
        self.num_arms = num_arms

        if max_T is None:
            max_T = T
        self.max_T = max_T
        assert max_T <= T

        self.add_const_feature = add_const_feature

        with torch.no_grad():
            phi_embeddings = neural_model(Z, X)

        # add a constant to X, for all features
        self.features = self.add_const_if_necessary(phi_embeddings)

        print('FEATURES SHAPE', self.features.shape)
        self.Xdim = self.features.shape[-1]

        self.prior_beta_mean = prior_mean
        L = torch.linalg.cholesky(prior_cov)
        self.prior_inv_beta_cov = torch.cholesky_inverse(L)
        self.prior_beta_cov = prior_cov
        self.noise_var = noise_var

        # one per arm: model arms separately
        # With no data yet, each arm's posterior equals the shared prior.
        self.beta_mean = self.prior_beta_mean.unsqueeze(0).repeat((num_arms,1))
        self.beta_cov = self.prior_beta_cov.unsqueeze(0).repeat((num_arms,1,1))
        self.inv_beta_cov = self.prior_inv_beta_cov.unsqueeze(0).repeat((num_arms,1,1))

        self.prev_Ys = torch.zeros(num_arms, max_T)
        self.obs_timesteps = torch.zeros(num_arms, max_T).int() # timesteps corresponding to observations, per arm
        self.obs_count = torch.zeros(num_arms).int() # number of total observations, per arm

    def update_algorithm(self, arm, reward, this_X):
        '''
        Record `reward` for `arm` at the current timestep and recompute that
        arm's posterior from the prior plus its full history.

        `this_X` is not used: embeddings are looked up in the stored table by
        timestep. Raises ValueError if `reward` is neither an int nor an
        object with a working `.item()`.
        '''
        if not isinstance(reward, int):
            try:
                reward = reward.item()
            except Exception as err:
                raise ValueError('reward must be int') from err
        slot = int(self.obs_count[arm])
        self.prev_Ys[arm, slot] = reward
        self.obs_timesteps[arm, slot] = self.t
        self.obs_count[arm] += 1
        n_obs = slot + 1

        # do a posterior update for this arm using prior and all history for this arm:
        feats = torch.index_select(self.features[arm], 0, self.obs_timesteps[arm, :n_obs])
        Y = self.prev_Ys[arm, :n_obs]
        XtX = feats.T @ feats
        XtY = feats.T @ Y

        precision = self.prior_inv_beta_cov + (1/self.noise_var) * XtX
        rhs = self.prior_inv_beta_cov @ self.prior_beta_mean + (1/self.noise_var) * XtY

        if self.cholesky:
            # Solve precision @ mean = rhs through the Cholesky factor.
            L = torch.linalg.cholesky(precision)
            post_mean = torch.cholesky_solve(rhs.unsqueeze(1), L).squeeze(1)
            post_cov = torch.cholesky_inverse(L)
        else:
            post_cov = torch.inverse(precision)
            post_mean = post_cov @ rhs

        # Keep the stored precision in sync with the stored covariance.
        self.inv_beta_cov[arm] = precision
        self.beta_cov[arm] = post_cov
        self.beta_mean[arm] = post_mean
        self.t += 1

    def sample_action(self, X, return_extra=False):
        '''
        Embed the current context, draw one Gaussian sample of the predicted
        reward per arm and return the argmax arm (plus the means/variances
        when `return_extra` is True).

        The embedding must match the stored feature table at the current
        timestep (checked with a tolerance).
        NOTE: sampling uses torch's global RNG, not `self.rng_gen`.
        '''
        with torch.no_grad():
            # The forward pass is only used for inference, so no autograd graph is needed.
            feat = self.add_const_if_necessary(self.neural_model(self.Z, X))[:, 0, :]
            assert torch.allclose(self.features[:, self.t, :], feat, atol=1e-6, rtol=1e-3)

            arm_means = torch.zeros(self.num_arms)
            arm_vars = torch.zeros(self.num_arms)
            for arm in range(self.num_arms):
                arm_vars[arm] = feat[arm] @ self.beta_cov[arm] @ feat[arm]
                arm_means[arm] = feat[arm] @ self.beta_mean[arm]
            # Guard against tiny negative variances from round-off before the sqrt.
            arm_vars = arm_vars.clamp(min=0)
            samples = torch.normal(mean=arm_means, std=arm_vars**0.5)

        best_arm = samples.argmax().item()
        if return_extra:
            return best_arm, {'arm_means':arm_means, 'arm_vars':arm_vars}
        return best_arm


# POSTERIOR BANDIT ALGORITHMS ============================================================


class GreedySequentialWithContext(BanditAlgorithm):
    '''
    Greedy (optionally epsilon-greedy or softmax) policy on top of a
    sequential imputation model with context.

    The model must expose `x_suff_encoder`, `get_state` and `top_layer`; they
    are called in `sample_action` on the accumulated history. The base-class
    initialiser is not invoked.
    '''
    def __init__(self, model, encoded_Z, num_arms, T, X, epsilon=0, tau=None):
        '''
        model: sequential imputation model (see class docstring).
        encoded_Z: per-arm (already encoded) Z features; first dim must be num_arms.
        T: horizon; sizes the history buffers.
        X: context table; only its last dimension (feature size) is used.
        epsilon: exploration probability for epsilon-greedy (0 disables it).
        tau: softmax temperature (None disables it). Cannot be combined with epsilon > 0.
        '''
        # epsilon is for epsilon-greedy
        # tau is for softmax version
        assert not ((tau is not None) and (epsilon > 0))
        # imputation sequential model with context
        self.model = model
        # we are hoping for encoded Z's here.
        self.Z = encoded_Z
        self.T = T
        self.num_arms = num_arms
        assert self.Z.shape[0] == self.num_arms
        # not really used, currently?
        self.X = X
        self.Xdim = X.shape[-1]

        self.epsilon = epsilon
        self.tau = tau
        # things that increment / accumulate
        self.t = 0
        # accumulate observations observed by bandit algorithm
        self.hist_X = torch.zeros(self.num_arms, self.T, self.Xdim)
        self.hist_Y = torch.zeros(self.num_arms, self.T)
        self.hist_mask = torch.zeros(self.num_arms, self.T)

    def update_algorithm(self, arm, reward, X):
        '''Store the context and reward observed for `arm` at the current timestep.'''
        # todo check shapes here
        self.hist_X[arm,[self.t],:] = X[arm]
        self.hist_Y[arm,self.t] = reward
        self.hist_mask[arm,self.t] = 1
        self.t += 1

    def sample_action(self, X, return_extra=False):
        '''
        Score every arm for the current context `X` and pick one.

        Returns the chosen arm, or `(arm, scores)` when `return_extra` is True.
        Selection is argmax by default, uniformly random with probability
        `epsilon`, or a softmax draw at temperature `tau`.
        '''
        # put history into model, generate for current X, output that directly
        # might want to refactor some of this + put it into model instead
        curr_state = self.model.get_state(self.model.x_suff_encoder(self.hist_X), self.hist_Y)
        input_ = torch.cat([self.Z.unsqueeze(1), X, curr_state.unsqueeze(1)], 2)
        p_hat_pred = self.model.top_layer(input_).squeeze(2)
        best_arm = torch.argmax(p_hat_pred).item()
        if self.epsilon > 0:
            randomly_choose = torch.rand(1).item()
            if randomly_choose < self.epsilon:
                best_arm = torch.randint(self.num_arms, size=(1,)).item()
        if self.tau is not None:
            # The scores are flattened to 1-D, so normalise over dim 0 explicitly
            # (the implicit-dim form of softmax is deprecated).
            probs = softmax(p_hat_pred.flatten()/self.tau, dim=0)
            best_arm = torch.multinomial(probs, num_samples=1).item()
        if return_extra:
            return best_arm, p_hat_pred
        return best_arm
    

class SequentialAlgWithContext(BanditAlgorithm):
    '''
    Imputation bandit algorithm for sequential models with context.

    At every step the sequential model fills in a table of imagined outcomes
    (`fill_table_naive` or `fill_table_naive_finite`), a separate
    "table-to-policy" (ttp) classifier is fitted per arm on that table, and the
    arm with the highest predicted value for the current context is played.
    The base-class initialiser is not invoked.
    '''
    # imputation bandit algorithm for sequential models with context
    def __init__(self, model, Z, num_arms, T, X,
                 get_ttp=None, train_ttp=None,
                 t2p='logistic',
                 ignore_context=False,
                 num_imagined = 500,
                 finite_horizon_alg=False,
                 no_shuffle_boot=False,
                 extra_xgb_params=None, timer=True,
                 X_samples=None):
        '''
        model: sequential imputation model providing the fill_table_* methods.
        Z: per-arm features; first dim must be num_arms.
        T: horizon; sizes the history buffers.
        X: context table used for table imputation (stored as `eval_X`).
        get_ttp, train_ttp: hooks for a custom ttp model (see NOTE below).
        t2p: one of 'logistic', 'xgb', 'gp', 'mlp' - which sklearn classifier to fit.
        ignore_context: if True, score arms by the mean imputed label instead.
        num_imagined: number of imagined timesteps to train the ttp on.
        finite_horizon_alg: use the finite-horizon imputation path.
        no_shuffle_boot: use contexts in order instead of resampling them.
        extra_xgb_params: kwargs for the 'xgb' and 'mlp' classifiers (default: none).
        timer: whether to measure generation / fitting time.
        X_samples: optional pool of contexts to sample future contexts from.

        NOTE: because of the `t2p` assertion, the custom get_ttp/train_ttp
        branch in `sample_action` is only reachable when assertions are disabled.
        '''
        assert t2p in ['logistic','xgb','gp','mlp']
        # imputation sequential model with context
        self.model = model
        self.Z = Z
        self.T = T
        self.num_arms = num_arms
        assert Z.shape[0] == self.num_arms
        self.num_imagined = num_imagined
        self.finite_horizon_alg = finite_horizon_alg
        self.no_shuffle_boot = no_shuffle_boot

        self.X = X
        self.eval_X = X # this is what you do table imputation over. We could make this different.

        # ttp = table to policy
        self.ignore_context = ignore_context
        self.t2p = t2p
        # A fresh dict per instance avoids sharing one mutable default across instances.
        self.extra_xgb_params = {} if extra_xgb_params is None else extra_xgb_params

        self.get_ttp = get_ttp
        self.train_ttp = train_ttp
        self.Xdim = X.shape[-1]
        self.Zdim = Z.shape[-1]

        # things that increment / accumulate
        self.t = 0
        # accumulate observations observed by bandit algorithm
        self.hist_X = torch.zeros(self.num_arms, self.T, self.Xdim)
        self.hist_Y = torch.zeros(self.num_arms, self.T)
        self.hist_mask = torch.zeros(self.num_arms, self.T)
        # Honour the caller's choice (this was previously hard-coded to True).
        self.timer = timer
        self.X_samples = X_samples
        if self.X_samples is not None:
            assert self.eval_X.shape[0] == self.X_samples.shape[0]
            assert self.eval_X.shape[2] == self.X_samples.shape[2]

    def update_algorithm(self, arm, reward, X):
        '''Store the context and reward observed for `arm` at the current timestep.'''
        # todo check shapes here
        self.hist_X[arm,[self.t],:] = X[arm]
        self.hist_Y[arm,self.t] = reward
        self.hist_mask[arm,self.t] = 1
        self.t += 1

    def _sklearn_ttp_factory(self):
        '''Return a zero-argument callable building a fresh sklearn classifier for `self.t2p`, or None.'''
        if self.t2p == 'logistic':
            from sklearn.linear_model import LogisticRegression
            return LogisticRegression
        if self.t2p == 'xgb':
            from sklearn.ensemble import GradientBoostingClassifier
            return lambda: GradientBoostingClassifier(**self.extra_xgb_params)
        if self.t2p == 'mlp':
            from sklearn.neural_network import MLPClassifier
            return lambda: MLPClassifier(**self.extra_xgb_params)
        if self.t2p == 'gp':
            from sklearn.gaussian_process import GaussianProcessClassifier
            return GaussianProcessClassifier
        return None

    def _fit_sklearn_preds(self, make_model, ttp_inputs, ttp_labels, X):
        '''Fit one classifier per arm on the imputed table and predict for the current context.'''
        preds = []
        for idx in range(self.num_arms):
            if ttp_labels[idx].min() == ttp_labels[idx].max():
                # Only one class present: a classifier cannot be fitted, so
                # predict that constant label directly.
                preds.append(torch.ones(len(X[0])) * ttp_labels[idx].max())
            else:
                ttp_model = make_model()
                ttp_model.fit(ttp_inputs[idx].numpy(), ttp_labels[idx].numpy())
                preds.append(ttp_model.predict_proba(X[0])[:,1])
        return preds

    def sample_action(self, X, return_extra=False):
        '''
        Impute a table of outcomes, fit the per-arm ttp models and return the
        arm with the highest prediction for the current context `X`.

        With `return_extra=True` returns `(arm, (p_pred, gen_time, fit_time))`;
        the times stay 0 when `timer` is False (and `gen_time` is only measured
        on the finite-horizon path).
        '''
        # generate TTP training data.
        # inputs are just X's, since we have a separate model per row/arm.
        # eval_X: data to generate on
        # X (in arguments): current user features (apply learned model to this X)
        gen_time = 0
        fit_time = 0
        if self.finite_horizon_alg:
            if self.timer:
                start = time.time()
            if self.X_samples is not None:
                # draw samples from X_samples for current + future timesteps
                idxs = torch.randint(0, self.X_samples.shape[1], (self.eval_X.shape[1]-self.t,))
                sampled_X = torch.gather(self.X_samples, 1, idxs.unsqueeze(0).unsqueeze(-1).repeat((1,1,self.Xdim))).repeat(self.eval_X.shape[0],1,1)
                ttp_inputs = torch.cat((self.eval_X[:,:self.t], sampled_X), dim=1)
                assert ttp_inputs.shape[1] == self.eval_X.shape[1]
            else:
                ttp_inputs = self.eval_X
            with torch.no_grad():
                # NOTE(review): labels are imputed on self.eval_X even when
                # ttp_inputs was rebuilt from X_samples above - confirm this is
                # intended (the non-finite branch passes ttp_inputs).
                ttp_labels = self.model.fill_table_naive_finite(self.Z, self.hist_X, self.hist_Y, self.hist_mask, self.eval_X).detach()
            if self.num_imagined + self.t < self.T:
                ttp_inputs = ttp_inputs[:,:self.t + self.num_imagined]
                ttp_labels = ttp_labels[:,:self.t + self.num_imagined]
            if self.timer:
                end = time.time()
                gen_time += end - start
        else:
            # subsample inputs up to length self.num_imagined:
            if self.no_shuffle_boot:
                assert self.eval_X.shape[1] == self.num_imagined
                idxs = torch.arange(self.eval_X.shape[1])
            else:
                idxs = torch.randint(0, self.eval_X.shape[1], (self.num_imagined,))
            ttp_inputs = torch.gather(self.eval_X, 1, idxs.unsqueeze(0).unsqueeze(-1).repeat((1,1,self.Xdim))).repeat(self.eval_X.shape[0],1,1)
            with torch.no_grad():
                ttp_labels = self.model.fill_table_naive(self.Z, self.hist_X, self.hist_Y, self.hist_mask, ttp_inputs).detach()
        if self.timer:
            start = time.time()
        if self.ignore_context:
            preds = ttp_labels.mean(1)
        else:
            make_model = self._sklearn_ttp_factory()
            if make_model is not None:
                preds = self._fit_sklearn_preds(make_model, ttp_inputs, ttp_labels, X)
            else:
                ttps_per_arm = []
                for _ in range(self.num_arms):
                    ttps_per_arm.append(self.get_ttp(self, in_dim=self.Xdim))

                # train
                for idx in range(self.num_arms):
                    ttp_model, ttp_criterion, ttp_optimizer = ttps_per_arm[idx]
                    ttp_model = self.train_ttp(ttp_model, ttp_criterion, ttp_optimizer, ttp_inputs[idx], ttp_labels[idx])

                # eval
                preds = []
                for idx in range(self.num_arms):
                    ttp_model, ttp_criterion, ttp_optimizer = ttps_per_arm[idx]
                    ttp_model.eval()
                    preds.append(ttp_model(X[0]).detach().item())

        p_pred = torch.tensor(np.array(preds))
        best_arm = torch.argmax(p_pred).item()
        if self.timer:
            end = time.time()
            fit_time += end - start
        if return_extra:
            return best_arm, (p_pred, gen_time, fit_time)
        return best_arm


## LinUCB disjoint
# from https://github.com/kfoofw/bandit_simulations/blob/master/python/contextual_bandits/analysis/linUCB%20disjoint%20implementation%20and%20analysis.md
# which is based on https://arxiv.org/pdf/1003.0146
# Create class object for a single linear ucb disjoint arm
class LinUCBDisjointArm():
    '''
    State and scoring for a single arm of disjoint LinUCB.

    Keeps the ridge-regression sufficient statistics `A` (d x d, starts at the
    identity) and `b` (d x 1, starts at zero) and scores a context by
    `theta^T x + alpha * sqrt(x^T A^{-1} x)` with `theta = A^{-1} b`.
    '''

    def __init__(self, arm_index, d, alpha):
        '''
        arm_index: identifier stored on the arm.
        d: feature dimension.
        alpha: multiplier on the uncertainty term of the UCB.
        '''
        # Track arm index
        self.arm_index = arm_index

        # Keep track of alpha
        self.alpha = alpha

        # A: (d x d) matrix = D_a.T * D_a + I_d.
        # The inverse of A is used in ridge regression
        self.A = torch.eye(d)

        # b: (d x 1) corresponding response vector.
        # Equals to D_a.T * c_a in ridge regression formulation
        self.b = torch.zeros([d,1])

    def calc_UCB(self, x_array):
        '''Return the (1 x 1) UCB score of context `x_array` (any shape with d elements).'''
        # Find A inverse for ridge regression
        A_inv = torch.linalg.inv(self.A)

        # Perform ridge regression to obtain estimate of covariate coefficients theta
        # theta is (d x 1) dimension vector
        self.theta = A_inv @ self.b

        # Reshape covariates input into (d x 1) shape vector
        x = x_array.reshape([-1,1])

        # Find ucb based on p formulation (mean + std_dev)
        # p is (1 x 1) dimension vector
        # torch.sqrt keeps the computation in torch (np.sqrt forced a NumPy
        # round-trip and failed for tensors that require grad).
        p = self.theta.T @ x + self.alpha * torch.sqrt(x.T @ (A_inv @ x))

        return p

    def reward_update(self, reward, x_array):
        '''Fold one (context, reward) observation into `A` and `b`.'''
        # Reshape covariates input into (d x 1) shape vector
        x = x_array.reshape([-1,1])

        # Update A which is (d * d) matrix.
        self.A += x @ x.T

        # Update b which is (d x 1) vector
        # reward is scalar
        self.b += reward * x

class LinUCBDisjoint():
    '''
    Disjoint LinUCB policy: one independent `LinUCBDisjointArm` per arm.

    Adapted from
    https://github.com/kfoofw/bandit_simulations/blob/master/python/contextual_bandits/analysis/linUCB%20disjoint%20implementation%20and%20analysis.md
    '''
    def __init__(self, K_arms, d, alpha):
        '''
        K_arms: number of arms.
        d: dimension of X.
        alpha: exploration multiplier passed to every arm.
        '''
        self.K_arms = K_arms
        # Each arm records its own index (previously every arm was given index 1).
        self.linucb_arms = [LinUCBDisjointArm(arm_index = i, d = d, alpha = alpha) for i in range(K_arms)]

    def sample_action(self, x_array, return_extra=True):
        '''
        Return the arm with the highest UCB for the current context, breaking
        exact ties uniformly at random (numpy's global RNG).

        `x_array` is indexed as [arm, 0, :]; the context of the first arm is
        used for all arms. With `return_extra=True` (the default here) returns
        `(arm, {'candidate_arms': [...]})`.
        '''
        # assume the contexts are the same across arms
        # (only checkable when there is more than one row)
        if x_array.shape[0] > 1:
            assert torch.allclose(x_array[0,0,:], x_array[1,0,:])
        x_array = x_array[0,0,:]

        # Start below every possible score so the first arm always becomes a
        # candidate, even when all UCBs are negative.
        highest_ucb = float('-inf')

        # Track index of arms to be selected on if they have the max UCB.
        candidate_arms = []
        for arm_index in range(self.K_arms):
            # Calculate ucb based on each arm using current covariates at time t
            arm_ucb = float(self.linucb_arms[arm_index].calc_UCB(x_array))
            if arm_ucb > highest_ucb:
                # New strict maximum: restart the candidate list with this arm
                highest_ucb = arm_ucb
                candidate_arms = [arm_index]
            elif arm_ucb == highest_ucb:
                # Exact tie with the current maximum. `elif` matters: a plain
                # `if` listed the new leader twice and biased the tie-break.
                candidate_arms.append(arm_index)

        # Choose based on candidate_arms randomly (tie breaker)
        chosen_arm = np.random.choice(candidate_arms)
        if return_extra:
            return chosen_arm, {'candidate_arms':candidate_arms}
        return chosen_arm

    def update_algorithm(self, arm, reward, X):
        '''Update the played arm with the reward, using the context in X[0].'''
        self.linucb_arms[arm].reward_update(reward, X[0])



def run_bandit(env, alg, T, num_round_robin=0, context=False, return_extra=False, verbose=False):
    '''
    Simulate `alg` on `env` for `T` timesteps.

    The first `num_round_robin * env.num_arms` steps cycle through the arms in
    order; afterwards the algorithm chooses. When `context` is True the
    environment's context for the step is passed to both `sample_action` and
    `update_algorithm`.

    Returns a dict with numpy arrays 'rewards', 'expected_rewards' and
    'action_arms', plus 'extras' (a list of the non-None extras returned by the
    algorithm) when `return_extra` is True. With `verbose`, the wall-clock time
    of every step is printed.
    '''
    rewards_log = []
    expected_log = []
    arms_log = []
    extras_log = []
    X = None
    tick = time.time()

    for t in range(T):
        extras = None
        if context:
            # This is returning one copy of the context per action
            X = env.get_context(t)
        # Positional context argument(s), present only in contextual mode.
        ctx_args = (X,) if context else ()

        if t < num_round_robin * env.num_arms:
            # Warm-up phase: play arms in a fixed cycle.
            arm = t % env.num_arms
        elif return_extra:
            arm, extras = alg.sample_action(*ctx_args, return_extra=True)
        else:
            arm = alg.sample_action(*ctx_args)

        reward = env.generate_reward(arm, t)
        alg.update_algorithm(arm, reward, *ctx_args)
        rewards_log.append(reward)

        # Get a less noisy estimate of the reward
        expected_log.append(env.get_expected_reward(arm, t))
        arms_log.append(arm)
        if extras is not None:
            extras_log.append(extras)

        if verbose:
            now = time.time()
            print(f'timestep {t}, took {now-tick:.1f} seconds = {(now-tick)/60:.1f} minutes')
            sys.stdout.flush()
            tick = now

    res = {
        'rewards': np.array(rewards_log),
        'expected_rewards': np.array(expected_log),
        'action_arms': np.array(arms_log),
    }
    if return_extra:
        res['extras'] = extras_log
    return res
