import numpy as np
import gym
from scipy.linalg import solve_discrete_are
from lqr_env import LQREnv
from IPython import embed
import torch
import scipy
import random
import itertools
from scipy.stats import norm

# Torch device used for every tensor created by the transformer controllers:
# the GPU when available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


# Unbounded box limits (+/- infinity).
UPPER_BOUND = np.inf
LOWER_BOUND = -UPPER_BOUND


def sample(dim, H, var, type='uniform'):
    """Draw a random ``BanditEnv`` with ``dim`` arms.

    Args:
        dim: number of arms.
        H: context horizon forwarded to ``BanditEnv``.
        var: reward-noise parameter forwarded to ``BanditEnv``.
        type: ``'uniform'`` draws arm means from U(0, 1); ``'bernoulli'``
            draws them from Beta(1, 1).

    Raises:
        NotImplementedError: for any other ``type``.
    """
    if type not in ('uniform', 'bernoulli'):
        raise NotImplementedError
    if type == 'uniform':
        arm_means = np.random.uniform(0, 1, dim)
    else:
        arm_means = np.random.beta(1, 1, dim)
    return BanditEnv(arm_means, H, var=var, type=type)


def sample_for_arm(arm, dim, H, var):
    """Draw a random ``BanditEnv`` whose best arm is ``arm``.

    Means are drawn from U(0, 1); if ``arm`` is not already the arg-max, its
    mean is redrawn from U(current max, 1) so that it becomes the best arm.
    """
    arm_means = np.random.uniform(0, 1, dim)
    current_best = arm_means.argmax()
    if current_best != arm:
        arm_means[arm] = np.random.uniform(arm_means.max(), 1)
    return BanditEnv(arm_means, H, var=var)

def sample_topk(dim, H, var, k=1):
    """Draw a ``TopKBanditEnv`` with ``dim`` arms whose means are U(0, 1)."""
    return TopKBanditEnv(np.random.uniform(0, 1, dim), H, var=var, k=k)
    

def sample_linear(arms, H, var):
    """Draw a ``LinearBanditEnv`` over the given arm features.

    ``theta`` has one entry per feature column of ``arms`` and is drawn from
    N(0, 1) scaled by ``1 / sqrt(d)``.
    """
    feature_dim = arms.shape[1]
    scale = np.sqrt(feature_dim)
    theta = np.random.normal(0, 1, feature_dim) / scale
    return LinearBanditEnv(theta, arms, H, var=var)


class BanditEnv(LQREnv):
    """Single-step multi-armed bandit environment.

    Actions are ``dim``-dimensional vectors; the pulled arm is ``argmax(u)``.
    Every episode lasts one step (``self.H == 1``); the requested context
    horizon is stored separately as ``self.H_context``.

    Args:
        means: per-arm mean rewards (1-D array).
        H: context horizon, stored as ``H_context``.
        var: passed as the ``scale`` (standard deviation) of the Gaussian
            reward noise when ``type == 'uniform'``; ignored for
            ``'bernoulli'``.
        type: reward model; ``'uniform'`` adds Gaussian noise to the arm mean,
            ``'bernoulli'`` returns a Bernoulli(mean) reward.
    """

    def __init__(self, means, H, var=0.0, type='uniform'):
        opt_a_index = np.argmax(means)
        self.means = means
        self.opt_a_index = opt_a_index
        # One-hot encoding of the best arm.
        self.opt_a = np.zeros(means.shape)
        self.opt_a[opt_a_index] = 1.0
        self.dim = len(means)
        self.observation_space = gym.spaces.Box(low=1, high=1, shape=(1,))
        self.action_space = gym.spaces.Box(low=0, high=1, shape=(self.dim,))
        # The observation is constant: the environment is stateless.
        self.state = np.array([1])
        self.var = var
        self.dx = 1
        self.du = self.dim
        self.topk = False
        self.type = type

        # `H` here is the context length; each episode is a single pull.
        self.H_context = H
        self.H = 1

    def get_arm_value(self, u):
        """Return the expected reward ``sum(means * u)`` of action vector ``u``."""
        return np.sum(self.means * u)

    def reset(self):
        """Start a new one-step episode and return the constant observation."""
        self.current_step = 0
        return self.state

    def transit(self, x, u):
        """Sample a reward for pulling arm ``argmax(u)``; ``x`` is unused."""
        a = np.argmax(u)
        if self.type == 'uniform':
            r = self.means[a] + np.random.normal(0, self.var)
        elif self.type == 'bernoulli':
            r = np.random.binomial(1, self.means[a])
        else:
            raise NotImplementedError
        return self.state.copy(), r

    def step(self, action):
        """Apply ``action``; returns ``(obs, reward, done, info)``.

        Raises:
            ValueError: if the episode has already finished.
        """
        if self.current_step >= self.H:
            raise ValueError("Episode has already ended")

        _, r = self.transit(self.state, action)
        self.current_step += 1
        done = (self.current_step >= self.H)

        return self.state.copy(), r, done, {}

    def deploy_eval(self, ctrl):
        """Run ``self.deploy(ctrl)`` with ``var`` temporarily set to 0.

        For ``type == 'bernoulli'`` rewards stay random because ``var`` is not
        used there. The previous ``var`` is restored even if ``deploy`` raises.
        """
        saved_var = self.var
        self.var = 0.0
        try:
            return self.deploy(ctrl)
        finally:
            self.var = saved_var


class BanditEnvVec(LQREnv):
    """
    Vectorized bandit environment.

    Steps a list of bandit environments in lockstep; observations, rewards
    and done flags are returned as lists with one entry per environment.
    """
    def __init__(self, envs):
        self._envs = envs
        self._num_envs = len(envs)

    def reset(self):
        """Reset every sub-environment and return the list of observations."""
        return [env.reset() for env in self._envs]

    def step(self, actions):
        """Step each environment with its own action (``actions[i]`` -> ``envs[i]``)."""
        next_obs, rews, dones = [], [], []
        for action, env in zip(actions, self._envs):
            next_ob, rew, done, _ = env.step(action)
            next_obs.append(next_ob)
            rews.append(rew)
            dones.append(done)
        return next_obs, rews, dones, {}

    @property
    def num_envs(self):
        return self._num_envs

    @property
    def envs(self):
        return self._envs

    def deploy_eval(self, ctrl):
        """Run ``self.deploy(ctrl)`` with every sub-environment's ``var`` set to 0.

        Each environment's original ``var`` is restored even if ``deploy`` raises.
        """
        saved_vars = [env.var for env in self._envs]
        for env in self._envs:
            env.var = 0.0
        try:
            return self.deploy(ctrl)
        finally:
            for env, var in zip(self._envs, saved_vars):
                env.var = var

    def deploy(self, ctrl):
        """Roll out ``ctrl.act_numpy_vec`` until every environment is done.

        Returns:
            ``(xs, us, xps, rs)``: observations, actions, next observations
            and rewards, each concatenated over the rollout steps.
        """
        x = self.reset()
        xs = []
        xps = []
        us = []
        rs = []
        done = False

        while not done:
            u = ctrl.act_numpy_vec(x)

            xs.append(x)
            us.append(u)

            x, r, done, _ = self.step(u)
            # Stop only once every sub-environment has finished.
            done = all(done)

            rs.append(r)
            xps.append(x)

        xs = np.concatenate(xs)
        us = np.concatenate(us)
        xps = np.concatenate(xps)
        rs = np.concatenate(rs)
        return xs, us, xps, rs

    def get_arm_value(self, us):
        """Return the expected reward of ``us[i]`` in ``envs[i]`` for each i."""
        values = [np.sum(env.means * u) for env, u in zip(self._envs, us)]
        return np.array(values)


class TopKBanditEnv(BanditEnv):
    """Bandit where each action selects a subset of exactly ``k`` arms.

    The reward is the sum of the selected arms' means plus Gaussian noise with
    scale ``var``. ``opt_a`` is the k-hot vector of the ``k`` largest means.
    """

    def __init__(self, means, H, var=0.0, k=1):
        super().__init__(means, H, var=var)
        self.topk = True
        self.k = k
        ranked = np.argsort(means)[::-1]  # arm indices, best first
        self.opt_a_indices = ranked[:k]
        self.opt_a = np.zeros(means.shape)
        self.opt_a[self.opt_a_indices] = 1.0
        self.opt_means = means[self.opt_a_indices]
        self.opt_mean_sum = self.opt_means.sum()

    def _check_subset(self, u):
        # The k-hot action must select exactly k arms.
        assert np.sum(u) == self.k, f"Subset must consist of k={self.k} actions"

    def get_arm_value(self, u):
        self._check_subset(u)
        return np.sum(self.means * u)

    def step(self, action):
        self._check_subset(action)
        return super().step(action)

    def transit(self, x, u):
        self._check_subset(u)
        reward = (self.means * u).sum() + np.random.normal(0, self.var)
        return self.state.copy(), reward



class LinearBanditEnv(BanditEnv):
    """Bandit whose arm means are linear in known features: ``means = arms @ theta``.

    Construction and dynamics are inherited from ``BanditEnv`` with
    ``type='uniform'``. Rewards are the selected arm's mean plus
    ``np.random.normal(0, var)``. The env therefore also exposes the ``topk`` /
    ``type`` attributes that the other bandit environments set.

    Args:
        theta: parameter vector with one entry per feature column of ``arms``.
        arms: 2-D feature matrix; row ``i`` describes arm ``i``.
        H: context horizon (stored as ``H_context``).
        var: noise scale passed to ``np.random.normal``.
    """

    def __init__(self, theta, arms, H, var=0.0):
        self.theta = theta
        self.arms = arms
        # NOTE: the action space is still a one-hot box over arms (as in
        # BanditEnv), not a box in feature space.
        super().__init__(arms @ theta, H, var=var, type='uniform')



class Controller:
    """Minimal base class for bandit policies.

    Stores the in-context history (``batch``) and the environment handle
    (``env``) that subclasses read in ``act`` / ``act_numpy_vec``.
    """

    def set_env(self, env):
        """Attach the environment this policy acts in."""
        self.env = env

    def set_batch(self, batch):
        """Store the in-context history dict."""
        self.batch = batch

    def set_batch_numpy_vec(self, batch):
        """Vectorized entry point; by default it just forwards to ``set_batch``."""
        self.set_batch(batch)

        

class OptPolicy(Controller):
    """Oracle policy: always plays the environment's ``opt_a`` action."""

    def __init__(self, env):
        super().__init__()
        self.env = env

    def reset(self):
        """No internal state to reset."""
        return None

    def act(self, x):
        """Return ``env.opt_a``; the observation ``x`` is ignored."""
        return self.env.opt_a


class GreedyOptPolicy(Controller):
    """Replays the action from the history that got the single highest reward."""

    def __init__(self, env):
        super().__init__()
        self.env = env

    def reset(self):
        return None

    def act(self, x):
        """Return the history action with the largest observed reward."""
        observed = self.batch['rollin_rs'].cpu().detach().numpy().flatten()
        best_step = np.argmax(observed)
        history_actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
        self.a = history_actions[best_step]
        return self.a



class EmpMeanPolicy(Controller):
    """Plays the arm with the highest empirical mean reward in the history.

    Args:
        env: environment exposing ``dim`` (number of arms).
        online: if True, an arm that was never pulled (the one with minimum
            count) is played before exploiting.
        batch_size: number of parallel environments for ``act_numpy_vec``.
    """

    def __init__(self, env, online=False, batch_size = 1):
        super().__init__()
        self.env = env
        self.online = online
        self.batch_size = batch_size

    def reset(self):
        return None

    def act(self, x):
        """Return a one-hot action for a single environment."""
        us = self.batch['rollin_us'].cpu().detach().numpy()[0]
        rs = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

        totals = np.zeros(self.env.dim)
        pulls = np.zeros(self.env.dim)
        for step, u in enumerate(us):
            arm = np.argmax(u)
            totals[arm] += rs[step]
            pulls[arm] += 1

        # Unpulled arms keep mean 0 (division by max(1, pulls)).
        means = totals / np.maximum(1, pulls)
        choice = np.argmax(means)
        least_pulled = np.argmin(pulls)
        if self.online and pulls[least_pulled] == 0:
            choice = least_pulled

        one_hot = np.zeros(self.env.dim)
        one_hot[choice] = 1.0
        self.a = one_hot
        return self.a

    def act_numpy_vec(self, x):
        """Return a ``(batch_size, dim)`` array of one-hot actions."""
        us = self.batch['rollin_us']
        rs = self.batch['rollin_rs']
        n_arms = self.env.dim

        totals = np.zeros((self.batch_size, n_arms))
        pulls = np.zeros((self.batch_size, n_arms))
        chosen_arms = np.argmax(us, axis=-1)
        for row in range(self.batch_size):
            row_arms = chosen_arms[row]
            row_rewards = rs[row]
            for arm in range(n_arms):
                hits = row_rewards[row_arms == arm]
                totals[row, arm] = np.sum(hits)
                pulls[row, arm] = len(hits)

        means = totals / np.maximum(1, pulls)
        rows = np.arange(self.batch_size)
        choice = np.argmax(means, axis=-1)
        least_pulled = np.argmin(pulls, axis=-1)
        if self.online:
            unseen = (pulls[rows, least_pulled] == 0)
            choice[unseen] = least_pulled[unseen]

        one_hot = np.zeros((self.batch_size, n_arms))
        one_hot[rows, choice] = 1.0
        self.a = one_hot
        return self.a


class ThompsonSamplingPolicy2(Controller):
    """Gaussian Thompson-sampling baseline with a conjugate normal prior per arm.

    Each arm's reward is modelled as N(mu, std**2) with prior
    mu ~ N(prior_mean, prior_var). ``set_batch`` / ``set_batch_numpy_vec``
    rebuild the posterior from the in-context history. ``act`` /
    ``act_numpy_vec`` pick an arm from that posterior.

    Args:
        env: environment exposing ``dim`` (number of arms).
        std: assumed observation-noise standard deviation.
        sample: if True, act on a single posterior draw (Thompson sampling).
            Otherwise, play the arm that wins most often over 100 draws.
        prior_mean, prior_var: Gaussian prior on every arm mean.
        warm_start: in ``act`` with ``sample=True``, play an arm that was never
            pulled in the history first.
        batch_size: number of parallel environments. With ``batch_size > 1``,
            the posterior arrays are ``(batch_size, dim)``; otherwise they are
            ``(dim,)``.
    """

    def __init__(self, env, std=.1, sample=False, prior_mean=.5, prior_var=1/12.0, warm_start=False, batch_size=1):
        super().__init__()
        self.env = env
        self.variance = std**2
        self.prior_mean = prior_mean
        self.prior_variance = prior_var
        self.batch_size = batch_size

        self.reset()
        self.sample = sample
        self.warm_start = warm_start

    def reset(self):
        """Reset every arm's posterior to the prior and zero the pull counts."""
        if self.batch_size > 1:
            shape = (self.batch_size, self.env.dim)
        else:
            shape = (self.env.dim,)
        self.means = np.ones(shape) * self.prior_mean
        self.variances = np.ones(shape) * self.prior_variance
        self.counts = np.zeros(shape)

    def _posterior(self, n, sample_mean):
        """Return the posterior ``(mean, variance)`` after ``n`` observations.

        This is the normal-normal conjugate update. Posterior precision is
        ``1/prior_var + n/noise_var``. The prior mean therefore carries weight
        ``noise_var / (noise_var + n * prior_var)``. Works element-wise on arrays.
        """
        prior_weight = self.variance / (self.variance + n * self.prior_variance)
        new_mean = prior_weight * self.prior_mean + (1 - prior_weight) * sample_mean
        new_variance = 1 / (1 / self.prior_variance + n / self.variance)
        return new_mean, new_variance

    def set_batch(self, batch):
        """Rebuild the (1-D, single-environment) posterior from ``batch``."""
        self.reset()
        self.batch = batch
        actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
        rewards = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

        pulled = np.argmax(actions, axis=1)
        for arm in pulled:
            self.counts[arm] += 1

        for c in range(self.env.dim):
            self.update_posterior(c, rewards[pulled == c])

    def set_batch_numpy_vec(self, batch):
        """Rebuild the posterior of every parallel environment from numpy ``batch``."""
        self.reset()
        self.batch = batch
        actions = self.batch['rollin_us']
        rewards = self.batch['rollin_rs'][:, :, 0]
        pulled = np.argmax(actions, axis=-1)  # (batch_size, T)

        # 2-D view of the counts. When batch_size == 1, reset() allocates 1-D
        # arrays; the view lets that case share this code path, and in-place
        # writes go through to self.counts.
        counts = self.counts.reshape(self.batch_size, self.env.dim)
        arm_means = np.zeros((self.batch_size, self.env.dim))
        for idx in range(self.batch_size):
            counts[idx] = np.bincount(pulled[idx], minlength=self.env.dim)
            for c in range(self.env.dim):
                if counts[idx, c] > 0:
                    arm_means[idx, c] = np.mean(rewards[idx][pulled[idx] == c])

        self.update_posterior_all(arm_means.reshape(self.counts.shape))

    def update_posterior(self, c, arm_rewards):
        """Update arm ``c``'s posterior from its rewards (no-op if never pulled)."""
        n = self.counts[c]
        if n > 0:
            self.means[c], self.variances[c] = self._posterior(n, np.mean(arm_rewards))

    def update_posterior_all(self, arm_means):
        """Vectorized posterior update; ``arm_means`` matches ``self.counts`` in shape."""
        new_mean, new_variance = self._posterior(self.counts, arm_means)
        # Arms that were never pulled keep the prior.
        mask = (self.counts > 0)
        self.means[mask] = new_mean[mask]
        self.variances[mask] = new_variance[mask]

    def act(self, x):
        """Return a one-hot action for a single environment."""
        if self.sample:
            # One posterior draw per arm (Thompson sampling).
            values = np.random.normal(self.means, np.sqrt(self.variances))
            i = np.argmax(values)

            if self.warm_start:
                actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
                counts = np.zeros(self.env.dim)
                for u in actions:
                    counts[np.argmax(u)] += 1
                j = np.argmin(counts)
                if counts[j] == 0:
                    i = j
        else:
            # Arm that is the arg-max most often over 100 posterior draws.
            values = np.random.normal(self.means, np.sqrt(self.variances), size=(100, self.env.dim))
            freqs = np.bincount(np.argmax(values, axis=1), minlength=self.env.dim)
            i = np.argmax(freqs)

        a = np.zeros(self.env.dim)
        a[i] = 1.0
        self.a = a
        return self.a

    def act_numpy_vec(self, x):
        """Return a ``(batch_size, dim)`` array of one-hot actions."""
        # 2-D views so that batch_size == 1 (1-D posterior arrays) also works.
        means = self.means.reshape(self.batch_size, self.env.dim)
        stds = np.sqrt(self.variances).reshape(self.batch_size, self.env.dim)

        if self.sample:
            values = np.random.normal(means, stds)
            action_indices = np.argmax(values, axis=-1)
        else:
            draws = np.random.normal(means, stds, size=(100,) + means.shape)
            winners = np.argmax(draws, axis=-1)  # (100, batch_size)
            freqs = np.array([np.bincount(w, minlength=self.env.dim) for w in winners.T])
            action_indices = np.argmax(freqs, axis=-1)

        actions = np.zeros((self.batch_size, self.env.dim))
        actions[np.arange(self.batch_size), action_indices] = 1.0
        self.a = actions
        return self.a


class PessMeanPolicy(Controller):
    """Pessimistic (lower-confidence-bound) empirical-mean policy.

    Each arm is scored by ``mean - const / max(1, sqrt(pulls))`` over the
    in-context history, and the arg-max is played.
    """

    def __init__(self, env, const=1.0, batch_size=1):
        super().__init__()
        self.env = env
        self.const = const
        self.batch_size = batch_size

    def reset(self):
        return None

    def act(self, x):
        """Return a one-hot action for a single environment."""
        us = self.batch['rollin_us'].cpu().detach().numpy()[0]
        rs = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

        totals = np.zeros(self.env.dim)
        pulls = np.zeros(self.env.dim)
        for step, u in enumerate(us):
            arm = np.argmax(u)
            totals[arm] += rs[step]
            pulls[arm] += 1

        # The penalty denominator is clipped to at least one.
        lower_bounds = totals / np.maximum(1, pulls) - self.const / np.maximum(1, np.sqrt(pulls))

        one_hot = np.zeros(self.env.dim)
        one_hot[np.argmax(lower_bounds)] = 1.0
        self.a = one_hot
        return self.a

    def act_numpy_vec(self, x):
        """Return a ``(batch_size, dim)`` array of one-hot actions."""
        us = self.batch['rollin_us']
        rs = self.batch['rollin_rs']
        n_arms = self.env.dim

        totals = np.zeros((self.batch_size, n_arms))
        pulls = np.zeros((self.batch_size, n_arms))
        chosen_arms = np.argmax(us, axis=-1)
        for row in range(self.batch_size):
            row_arms = chosen_arms[row]
            row_rewards = rs[row]
            for arm in range(n_arms):
                hits = row_rewards[row_arms == arm]
                totals[row, arm] = np.sum(hits)
                pulls[row, arm] = len(hits)

        lower_bounds = totals / np.maximum(1, pulls) - self.const / np.maximum(1, np.sqrt(pulls))

        one_hot = np.zeros((self.batch_size, n_arms))
        one_hot[np.arange(self.batch_size), np.argmax(lower_bounds, axis=-1)] = 1.0
        self.a = one_hot
        return self.a



class UCBPolicy(Controller):
    """Optimistic (upper-confidence-bound) empirical-mean policy.

    Each arm is scored by ``mean + const / max(1, sqrt(pulls))`` over the
    in-context history. Arms never pulled are played first.

    Args:
        env: environment exposing ``dim`` (number of arms).
        const: bonus coefficient.
        batch_size: number of parallel environments for ``act_numpy_vec``.
    """

    def __init__(self, env, const=1.0, batch_size=1):
        super().__init__()
        self.env = env
        self.const = const
        self.batch_size = batch_size

    def reset(self):
        return

    def act(self, x):
        """Return a one-hot action for a single environment."""
        actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
        rewards = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

        b = np.zeros(self.env.dim)
        counts = np.zeros(self.env.dim)
        for i in range(len(actions)):
            c = np.argmax(actions[i])
            b[c] += rewards[i]
            counts[c] += 1

        b_mean = b / np.maximum(1, counts)

        # Bonus denominator clipped so that it is at least one.
        bons = self.const / np.maximum(1, np.sqrt(counts))
        bounds = b_mean + bons

        i = np.argmax(bounds)
        j = np.argmin(counts)
        if counts[j] == 0:
            i = j

        a = np.zeros(self.env.dim)
        a[i] = 1.0
        self.a = a
        return self.a

    def act_numpy_vec(self, x):
        """Return a ``(batch_size, dim)`` array of one-hot actions."""
        actions = self.batch['rollin_us']
        rewards = self.batch['rollin_rs']

        b = np.zeros((self.batch_size, self.env.dim))
        counts = np.zeros((self.batch_size, self.env.dim))
        action_indices = np.argmax(actions, axis=-1)
        for idx in range(self.batch_size):
            actions_idx = action_indices[idx]
            rewards_idx = rewards[idx]
            for c in range(self.env.dim):
                arm_rewards = rewards_idx[actions_idx == c]
                b[idx, c] = np.sum(arm_rewards)
                counts[idx, c] = len(arm_rewards)

        b_mean = b / np.maximum(1, counts)

        bons = self.const / np.maximum(1, np.sqrt(counts))
        bounds = b_mean + bons

        rows = np.arange(self.batch_size)
        i = np.argmax(bounds, axis=-1)
        j = np.argmin(counts, axis=-1)
        # Index rows by the actual batch size (was hard-coded to 200).
        unpulled = (counts[rows, j] == 0)
        i[unpulled] = j[unpulled]

        a = np.zeros((self.batch_size, self.env.dim))
        a[rows, i] = 1.0
        self.a = a
        return self.a


class RandPolicy(OptPolicy):
    """Plays a uniformly random arm on every call."""

    def __init__(self, env):
        super().__init__(env)

    def act(self, x):
        """Return a one-hot vector for a uniformly chosen arm."""
        arm = np.random.choice(np.arange(self.env.dim))
        one_hot = np.zeros(self.env.dim)
        one_hot[arm] = 1.0
        return one_hot


class BanditTransformerController(Controller):
    """Policy that queries a trained transformer for per-arm logits.

    Args:
        model: callable taking the batch dict and returning logits. It must
            expose ``config['du']``, ``config['dx']`` and ``H``.
        sample: if True, sample the arm from ``softmax(logits)``. Otherwise,
            play ``argmax(logits)``.
        batch_size: number of parallel environments for ``act_numpy_vec``.
    """

    def __init__(self, model, sample=False, batch_size=1):
        self.model = model
        self.du = model.config['du']
        self.dx = model.config['dx']
        self.H = model.H
        self.batch_size = batch_size
        # Placeholder tensors written into the batch before every forward pass.
        self.zeros = torch.zeros(batch_size, self.dx**2 + self.du + 1).float().to(device)
        self.zerosQ = torch.zeros(batch_size, self.H, self.dx**2).float().to(device)
        self.sample = sample
        self.t = 0

    def set_env(self, env):
        # The model does not need the environment.
        return

    def set_batch_numpy_vec(self, batch):
        """Convert every value of ``batch`` to a float tensor on ``device`` and store it.

        NOTE: the caller's dict is modified in place.
        """
        for key in batch.keys():
            batch[key] = torch.tensor(batch[key]).float().to(device)
        self.set_batch(batch)

    def _logits(self):
        """Run the model on ``self.batch`` without building an autograd graph."""
        with torch.no_grad():
            out = self.model(self.batch)
        return out.cpu().detach().numpy()

    def act(self, x):
        """Return a one-hot action for a single observation ``x``.

        NOTE(review): ``self.zeros`` has ``batch_size`` rows while ``states``
        has one; presumably this path is only used with ``batch_size == 1``.
        """
        self.batch['zeros'] = self.zeros
        self.batch['zerosQ'] = self.zerosQ
        self.batch['states'] = torch.tensor(x)[None, :].float().to(device)

        logits = self._logits()[0]

        if self.sample:
            probs = scipy.special.softmax(logits)
            i = np.random.choice(np.arange(self.du), p=probs)
        else:
            i = np.argmax(logits)

        a = np.zeros(self.du)
        a[i] = 1.0
        return a

    def act_numpy_vec(self, x):
        """Return a ``(batch_size, du)`` array of one-hot actions."""
        self.batch['zeros'] = self.zeros
        self.batch['zerosQ'] = self.zerosQ

        states = torch.tensor(np.array(x))
        if self.batch_size == 1:
            states = states[None, :]
        self.batch['states'] = states.float().to(device)

        # Keep one row of logits per environment. Squeezing to 1-D when
        # batch_size == 1 made the sampling loop below iterate over scalars.
        logits = self._logits().reshape(self.batch_size, -1)

        if self.sample:
            probs = scipy.special.softmax(logits, axis=-1)
            action_indices = np.array([np.random.choice(np.arange(self.du), p=p) for p in probs])
        else:
            action_indices = np.argmax(logits, axis=-1)

        actions = np.zeros((self.batch_size, self.du))
        actions[np.arange(self.batch_size), action_indices] = 1.0
        return actions


class TopKBanditTransformerController(BanditTransformerController):
    """Transformer policy that returns a k-hot action (a subset of ``k`` arms).

    With ``sample=True`` (and fewer than ``tmax`` history steps), arms are
    drawn sequentially without replacement via ``order_sample``. Otherwise, the
    ``k`` largest logits are selected.
    """

    def __init__(self, model, k=1, sample=False):
        super().__init__(model, sample=sample)
        self.k = k
        self.tmax = np.inf

    def act(self, x):
        """Return a k-hot action vector for observation ``x``."""
        self.batch['zeros'] = self.zeros
        self.batch['zerosQ'] = self.zerosQ
        self.batch['states'] = torch.tensor(x)[None, :].float().to(device)

        logits = self.model(self.batch).cpu().detach().numpy()[0]

        explore = self.sample and len(self.batch['rollin_rs'][0]) < self.tmax
        if explore:
            return order_sample(logits, self.k)

        top_k = np.argsort(logits)[::-1][:self.k]
        k_hot = np.zeros(self.du)
        k_hot[top_k] = 1.0
        return k_hot


class TopKRandCommitPolicy(OptPolicy):
    """Explore-then-commit for top-k (subset) bandits.

    For the first ``tmax = int(horizon * 3/4)`` calls, it plays a uniformly
    random subset of ``k`` arms. On the next call, it commits to the action
    with the best empirical mean (``best_emp_mean(self.batch)``) and replays
    it on later calls. With ``immediate=True``, exploration is skipped and the
    empirical-best action is recomputed on every call.
    """

    def __init__(self, env, k, horizon, immediate=False):
        super().__init__(env)
        self.rand = True
        self.t = 0
        self.k = k
        self.horizon = horizon
        self.tmax = int(horizon * 3/4)
        self.imm = immediate

    def act(self, x):
        if self.t < self.tmax and not self.imm:
            hot_vector = np.zeros(self.env.dim)
            # Use the subset size this policy was constructed with.
            indices = np.random.choice(self.env.dim, size=self.k, replace=False)
            hot_vector[indices] = 1
            self.t += 1
            return hot_vector
        if self.t == self.tmax or self.imm:
            self.a = best_emp_mean(self.batch)
            # Advance past tmax so subsequent calls replay the committed action
            # (previously t stayed at tmax and this branch re-ran forever).
            self.t += 1
            return self.a
        self.t += 1
        return self.a


class ETC(OptPolicy):
    """Explore-then-exploit baseline.

    The first ``tmax = int(horizon / 4)`` calls play a uniformly random arm.
    Every later call returns ``best_emp_mean(self.batch)``.
    """

    def __init__(self, env, horizon):
        super().__init__(env)
        self.rand = True
        self.t = 0
        self.horizon = horizon
        self.tmax = int(horizon * 1 / 4)

    def reset(self):
        """Restart the exploration phase."""
        self.t = 0

    def act(self, x):
        if self.t >= self.tmax:
            self.a = best_emp_mean(self.batch)
            return self.a

        arm = np.random.choice(np.arange(self.env.dim))
        one_hot = np.zeros(self.env.dim)
        one_hot[arm] = 1.0
        self.t += 1
        return one_hot

class TopKEpsGreedy(OptPolicy):
    """Epsilon-greedy policy over subsets of ``k`` arms.

    Plays a uniformly random subset of ``k`` arms when the history is empty, or
    with probability ``eps``. Otherwise, it plays ``best_emp_mean(self.batch)``.

    Args:
        env: environment exposing ``dim``.
        k: subset size.
        horizon: stored only.
        immediate: accepted for interface compatibility; not used.
        eps: exploration probability (default 0.1, the previous hard-coded value).
    """

    def __init__(self, env, k, horizon, immediate=False, eps=0.1):
        super().__init__(env)
        self.rand = True
        self.t = 0
        self.k = k
        self.horizon = horizon
        self.eps = eps

    def act(self, x):
        no_history = len(self.batch['rollin_rs'][0]) < 1
        # `random.random()` is only drawn when there is history (short-circuit).
        if no_history or random.random() < self.eps:
            hot_vector = np.zeros(self.env.dim)
            indices = np.random.choice(self.env.dim, size=self.k, replace=False)
            hot_vector[indices] = 1
            self.t += 1
            return hot_vector
        self.a = best_emp_mean(self.batch)
        return self.a


# K-armed bandit version
# class LinUCB(OptPolicy):
#     def __init__(self, env, k, const=1.0):
#         super().__init__(env)
#         self.rand = True
#         self.t = 0
#         self.k = k
#         self.theta = np.zeros(self.env.dim)
#         self.cov = 1.0 * np.eye(self.env.dim)
#         self.const = const

#     def act(self, x):
#         if len(self.batch['rollin_rs'][0]) < 1:
#             hot_vector = np.zeros(self.env.dim)
#             indices = np.random.choice(self.env.dim, size=self.env.k, replace=False)
#             hot_vector[indices] = 1
#             self.t += 1
#             return hot_vector

#         else:
#             actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
#             rewards = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

#             cov = self.cov + actions.T @ actions
#             cov_inv = np.linalg.inv(cov)

#             theta = cov_inv @ actions.T @ rewards

#             if self.const == 0:
#                 indices = np.argsort(theta)[::-1][:self.k]    
#                 self.a = np.zeros(self.env.dim)
#                 self.a[indices] = 1.0
#                 return self.a
#             else:
#                 k_hot_vectors = generate_k_hot_vectors(self.env.dim, self.k)
#                 best_arm = None
#                 best_value = -np.inf
#                 for a in k_hot_vectors:
#                     value = theta @ a + self.const * np.sqrt(a @ cov_inv @ a)
#                     if value > best_value:
#                         best_value = value
#                         best_arm = a

#                 self.a = best_arm
#                 return self.a


class LinUCB(OptPolicy):
    """LinUCB over the finite arm set ``env.arms`` (one feature row per arm).

    Args:
        env: environment exposing ``arms`` (2-D feature matrix) and ``dim``.
        k: stored but not used by this policy.
        const: coefficient on the confidence width ``sqrt(x^T V^-1 x)``.
    """

    def __init__(self, env, k, const=1.0):
        super().__init__(env)
        self.rand = True
        self.t = 0
        self.k = k
        self.const = const
        self.arms = env.arms
        self.d = self.arms.shape[1]
        self.dim = env.dim
        self.theta = np.zeros(self.d)
        self.init_cov = 1.0 * np.eye(self.d)

    def act(self, x):
        if len(self.batch['rollin_rs'][0]) < 1:
            # No history yet: pull a uniformly random arm.
            i = np.random.choice(np.arange(self.dim))
            hot_vector = np.zeros(self.dim)
            hot_vector[i] = 1
            return hot_vector

        actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
        rewards = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

        # Feature vectors of the arms pulled so far.
        pulled = self.arms[np.argmax(actions, axis=1)]

        # Ridge-regression estimate with regularizer init_cov.
        cov = self.init_cov + pulled.T @ pulled
        cov_inv = np.linalg.inv(cov)
        theta = cov_inv @ pulled.T @ rewards

        # UCB score for all arms at once: theta.x + const * sqrt(x^T cov^-1 x).
        # argmax always yields a valid index; the former loop left the index
        # at None if no score beat -inf (e.g. NaN), which set every arm to 1.
        widths = np.sqrt(np.einsum('ij,jk,ik->i', self.arms, cov_inv, self.arms))
        scores = self.arms @ theta + self.const * widths
        best_arm_index = np.argmax(scores)

        hot_vector = np.zeros(self.dim)
        hot_vector[best_arm_index] = 1
        self.a = hot_vector
        return hot_vector



def generate_k_hot_vectors(n, k):
    """Return every length-``n`` 0/1 vector with exactly ``k`` ones.

    Rows follow ``itertools.combinations(range(n), k)`` order. The result
    always has shape ``(C(n, k), n)``, including ``(0, n)`` when ``k > n``.
    """
    combos = list(itertools.combinations(range(n), k))
    k_hot_vectors = np.zeros((len(combos), n), dtype=int)
    for row, combination in enumerate(combos):
        k_hot_vectors[row, list(combination)] = 1
    return k_hot_vectors


def best_emp_mean(batch):
    """Return the distinct history action with the highest mean reward.

    Actions are grouped by exact value. Ties go to the action that appears
    first in the history.
    """
    history_actions = batch['rollin_us'].cpu().detach().numpy()[0]
    history_rewards = batch['rollin_rs'].cpu().detach().numpy().flatten()

    rewards_by_action = {}
    for action, reward in zip(history_actions, history_rewards):
        rewards_by_action.setdefault(tuple(action), []).append(reward)

    means_by_action = {key: np.mean(vals) for key, vals in rewards_by_action.items()}
    return np.array(max(means_by_action, key=means_by_action.get))


def order_sample(logits, k):
    """Draw ``k`` distinct arms sequentially, proportional to ``softmax(2 * logits)``.

    After each draw, the chosen arm's probability is zeroed and the rest are
    renormalized. Returns a float k-hot vector of length ``len(logits)``.
    """
    weights = scipy.special.softmax(logits * 2)
    picked = np.zeros(len(logits))
    for _ in range(k):
        idx = np.random.choice(np.arange(len(weights)), p=weights)
        picked[idx] = 1.0
        weights[idx] = 0.0
        weights = weights / np.sum(weights)
    return picked

def gumbel_softmax_sample_logits(n, k, a, tau=1.0):
    """Return an int k-hot vector of the top-``k`` Gumbel-perturbed logits.

    The noise is standard Gumbel scaled by 0.5. ``tau`` is accepted but unused.
    """
    uniforms = np.random.uniform(0, 1, n)
    gumbel = -np.log(-np.log(uniforms))
    perturbed = a + .5 * gumbel

    winners = np.argsort(perturbed)[::-1][:k]
    k_hot = np.zeros(n, dtype=int)
    k_hot[winners] = 1
    return k_hot




# class ThompsonSamplingPolicy(Controller):
#     def __init__(self, env, std=.1):
#         super().__init__()
#         self.env = env
#         self.variance = std**2
#         self.prior_mean = .5
#         self.prior_variance = 1.0
#         self.reset()

#     def reset(self):
#         self.means = np.ones(self.env.dim) * self.prior_mean
#         self.variances = np.ones(self.env.dim) * self.prior_variance
#         self.counts = np.zeros(self.env.dim)

#     def set_batch(self, batch):
#         self.reset()
#         self.batch = batch
#         actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
#         rewards = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

#         for i in range(len(actions)):
#             c = np.argmax(actions[i])
#             self.counts[c] += 1

#         for c in range(self.env.dim):
#             arm_rewards = rewards[np.argmax(actions, axis=1) == c]
#             self.update_posterior(c, arm_rewards)


#     def update_posterior(self, c, arm_rewards):
#         n = self.counts[c]

#         if n > 0:
#             arm_mean = np.mean(arm_rewards)
#             prior_weight = self.prior_variance / (self.prior_variance + (n * self.variance))
#             new_mean = prior_weight * self.prior_mean + (1 - prior_weight) * arm_mean
#             new_variance = 1 / (1 / self.prior_variance + n / self.variance)

#             self.means[c] = new_mean
#             self.variances[c] = new_variance

#     def act(self, x):

#         i = np.argmax(self.means)
#         a = np.zeros(self.env.dim)
#         a[i] = 1.0
#         self.a = a

#         return self.a





if __name__ == '__main__':
    # Interactive sanity check: build a 3-armed bandit plus a random policy and
    # drop into IPython. `sample` requires H and var (previously omitted,
    # which raised TypeError).
    env = sample(3, H=100, var=0.1)
    ctrl = RandPolicy(env)
    embed()
