import numpy as np
import gym
from scipy.linalg import solve_discrete_are
from lqr_env import LQREnv
from IPython import embed
import torch
import scipy
import random
import itertools
from scipy.stats import norm

# Run model computations on the GPU when torch can see one, else on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


# Unbounded sentinels (+/- infinity).
# NOTE(review): not referenced in the visible part of this file.
UPPER_BOUND = np.inf
LOWER_BOUND = -np.inf


def sample(dim, H, var, type='uniform'):
    """Create a BanditEnv with ``dim`` randomly drawn arm means.

    Args:
        dim: number of arms.
        H: horizon forwarded to BanditEnv.
        var: reward-noise level forwarded to BanditEnv.
        type: 'uniform' draws means from U(0, 1); 'bernoulli' draws them from
            Beta(1, 1). The same string is passed on as the env's reward type.

    Raises:
        NotImplementedError: for any other ``type``.
    """
    if type == 'uniform':
        return BanditEnv(np.random.uniform(0, 1, dim), H, var=var, type=type)
    if type == 'bernoulli':
        return BanditEnv(np.random.beta(1, 1, dim), H, var=var, type=type)
    raise NotImplementedError


def sample_for_arm(arm, dim, H, var):
    """Create a BanditEnv whose best arm is (almost surely) ``arm``.

    Means are drawn from U(0, 1); if another arm is currently the best,
    ``arm``'s mean is redrawn from U(current_max, 1).
    """
    arm_means = np.random.uniform(0, 1, dim)
    leader = arm_means.argmax()
    if leader != arm:
        # Lift the requested arm to at least the current maximum.
        arm_means[arm] = np.random.uniform(arm_means.max(), 1)
    return BanditEnv(arm_means, H, var=var)

def sample_topk(dim, H, var, k=1):
    """Create a TopKBanditEnv with ``dim`` U(0, 1) arm means and subset size ``k``."""
    return TopKBanditEnv(np.random.uniform(0, 1, dim), H, var=var, k=k)
    



class BanditEnv(LQREnv):
    """Single-state multi-armed bandit exposed through the LQREnv interface.

    An action ``u`` is a length-``dim`` vector; ``transit`` pulls the arm at
    ``np.argmax(u)``. Each episode is one step long (``self.H == 1``); the
    requested horizon ``H`` is stored as ``self.H_context``.

    Args:
        means: 1-D array of per-arm expected rewards.
        H: context horizon, kept as ``H_context``.
        var: reward-noise level for ``type='uniform'``. NOTE(review): it is
            passed as the ``scale`` (standard deviation) argument of
            ``np.random.normal``, not as a variance.
        type: 'uniform' (mean plus Gaussian noise) or 'bernoulli' (0/1 reward
            with success probability equal to the arm's mean).
    """

    def __init__(self, means, H, var=0.0, type='uniform'):
        opt_a_index = np.argmax(means)
        self.means = means
        self.opt_a_index = opt_a_index
        self.opt_a = np.zeros(means.shape)
        self.opt_a[opt_a_index] = 1.0
        self.dim = len(means)
        self.observation_space = gym.spaces.Box(low=1, high=1, shape=(1,))
        self.action_space = gym.spaces.Box(low=0, high=1, shape=(self.dim,))
        self.state = np.array([1])
        self.var = var
        self.dx = 1
        self.du = self.dim
        self.topk = False
        self.type = type

        self.H_context = H
        self.H = 1
        # Start the step counter here so step() does not fail with an
        # AttributeError when reset() has not been called yet.
        self.current_step = 0

    def get_arm_value(self, u):
        """Return the expected reward of action vector ``u`` (``sum(means * u)``)."""
        return np.sum(self.means * u)

    def reset(self):
        """Restart the episode and return the (constant) state."""
        self.current_step = 0
        return self.state

    def transit(self, x, u):
        """Pull arm ``argmax(u)`` and return ``(state copy, reward)``.

        ``x`` is ignored: the state never changes.

        Raises:
            NotImplementedError: if ``self.type`` is not a known reward type.
        """
        a = np.argmax(u)
        if self.type == 'uniform':
            r = self.means[a] + np.random.normal(0, self.var)
        elif self.type == 'bernoulli':
            # var is ignored for Bernoulli rewards.
            r = np.random.binomial(1, self.means[a])
        else:
            raise NotImplementedError
        return self.state.copy(), r

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done, info)``.

        Raises:
            ValueError: if the episode (``self.H`` steps) has already ended.
        """
        if self.current_step >= self.H:
            raise ValueError("Episode has already ended")

        _, r = self.transit(self.state, action)
        self.current_step += 1
        done = (self.current_step >= self.H)

        return self.state.copy(), r, done, {}

    def deploy_eval(self, ctrl):
        """Run ``self.deploy(ctrl)`` with ``var`` temporarily set to 0.0.

        The previous ``var`` is restored even if ``deploy`` raises. ``deploy``
        is not defined in this class (presumably inherited from LQREnv).
        For ``type='bernoulli'`` rewards stay random, because ``transit``
        ignores ``var`` in that mode.
        """
        saved_var = self.var
        self.var = 0.0
        try:
            return self.deploy(ctrl)
        finally:
            self.var = saved_var

class TopKBanditEnv(BanditEnv):
    """Bandit whose action is a 0/1 vector selecting exactly ``k`` arms.

    The reward of subset ``u`` is ``sum(means * u)`` plus Gaussian noise drawn
    with ``np.random.normal(0, var)``. ``opt_a`` is the k-hot vector of the
    ``k`` largest means.
    """

    def __init__(self, means, H, var=0.0, k=1):
        super().__init__(means, H, var=var)
        self.topk = True
        self.k = k
        # Arm indices ordered from largest to smallest mean.
        ranked = np.argsort(means)[::-1]
        self.opt_a_indices = ranked[:k]
        self.opt_a = np.zeros(means.shape)
        self.opt_a[self.opt_a_indices] = 1.0
        self.opt_means = means[self.opt_a_indices]
        self.opt_mean_sum = self.opt_means.sum()

    def _check_subset(self, u):
        # Shared validation: the action must switch on exactly k arms.
        assert np.sum(u) == self.k, f"Subset must consist of k={self.k} actions"

    def get_arm_value(self, u):
        """Return the expected reward of subset ``u``."""
        self._check_subset(u)
        return np.sum(self.means * u)

    def step(self, action):
        """Validate ``action`` then delegate to BanditEnv.step."""
        self._check_subset(action)
        return super().step(action)

    def transit(self, x, u):
        """Return ``(state copy, noisy subset reward)``; ``x`` is ignored."""
        self._check_subset(u)
        noise = np.random.normal(0, self.var)
        return self.state.copy(), (self.means * u).sum() + noise




class Controller:
    """Minimal policy base class holding the in-context batch and the env."""

    def set_batch(self, batch):
        """Store ``batch`` on ``self.batch`` (subclasses read it in ``act``)."""
        self.batch = batch

    def set_env(self, env):
        """Store ``env`` on ``self.env``."""
        self.env = env

        

class OptPolicy(Controller):
    """Oracle policy that always plays the environment's optimal action."""

    def __init__(self, env):
        super().__init__()
        self.env = env

    def reset(self):
        """Nothing to reset between episodes."""
        return None

    def act(self, x):
        """Return ``env.opt_a`` (the env's own array) regardless of ``x``."""
        env = self.env
        return env.opt_a


class GreedyOptPolicy(Controller):
    """Replay the in-context action that received the single highest reward.

    Assumes ``batch['rollin_us']`` / ``batch['rollin_rs']`` are torch tensors
    whose first axis is a batch axis of size 1 — TODO confirm with callers.
    """

    def __init__(self, env):
        super().__init__()
        self.env = env

    def reset(self):
        """Nothing to reset between episodes."""
        return None

    def act(self, x):
        """Return the logged action at the argmax of the logged rewards."""
        batch = self.batch
        best_step = np.argmax(batch['rollin_rs'].cpu().detach().numpy().flatten())
        logged_actions = batch['rollin_us'].cpu().detach().numpy()[0]
        self.a = logged_actions[best_step]
        return self.a



class EmpMeanPolicy(Controller):
    """Greedy policy on per-arm empirical mean reward from the in-context batch.

    Each logged action is attributed to arm ``argmax(action)``; arms never
    pulled get an empirical mean of 0. Returns a one-hot vector.
    """

    def __init__(self, env):
        super().__init__()
        self.env = env

    def reset(self):
        """Nothing to reset between episodes."""
        return None

    def act(self, x):
        """Play the arm with the highest empirical mean."""
        played = self.batch['rollin_us'].cpu().detach().numpy()[0]
        observed = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

        totals = np.zeros(self.env.dim)
        pulls = np.zeros(self.env.dim)
        for step_idx, row in enumerate(played):
            arm = np.argmax(row)
            totals[arm] += observed[step_idx]
            pulls[arm] += 1

        averages = totals / np.maximum(1, pulls)

        choice = np.zeros(self.env.dim)
        choice[np.argmax(averages)] = 1.0
        self.a = choice
        return self.a




class ThompsonSamplingPolicy(Controller):
    """Gaussian-posterior bandit policy rebuilt from the in-context batch.

    Each arm's mean reward has an independent N(prior_mean, prior_variance)
    prior (0.5, 1.0). Rewards are treated as Gaussian with known noise
    variance ``var``. ``set_batch`` resets and recomputes the posterior from
    scratch for every batch.

    Args:
        env: environment providing ``dim`` (number of arms).
        var: assumed reward-noise variance.
        sample: if False (default) ``act`` plays the arm with the highest
            posterior mean; if True it plays the argmax of one posterior draw
            per arm (Thompson sampling).
    """

    def __init__(self, env, var=.1, sample=False):
        super().__init__()
        self.env = env
        self.variance = var
        self.prior_mean = .5
        self.prior_variance = 1.0
        self.sample = sample
        self.reset()

    def reset(self):
        """Reset every arm to the prior and zero the pull counts."""
        self.means = np.ones(self.env.dim) * self.prior_mean
        self.variances = np.ones(self.env.dim) * self.prior_variance
        self.counts = np.zeros(self.env.dim)

    def set_batch(self, batch):
        """Store ``batch`` and recompute every arm's posterior from it."""
        self.reset()
        self.batch = batch
        actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
        rewards = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

        # Arm index of every logged step, computed once.
        arms = np.argmax(actions, axis=1)
        for c in arms:
            self.counts[c] += 1

        for c in range(self.env.dim):
            self.update_posterior(c, rewards[arms == c])

    def update_posterior(self, c, arm_rewards):
        """Conjugate Normal-Normal update of arm ``c`` from its ``arm_rewards``.

        With noise variance s2, prior variance t2 and n observations:
        posterior variance = 1 / (1/t2 + n/s2), and
        posterior mean = w * prior_mean + (1 - w) * sample_mean with
        w = s2 / (s2 + n * t2). Arms with no observations keep the prior.
        """
        n = self.counts[c]

        if n > 0:
            arm_mean = np.mean(arm_rewards)
            # Shrinkage toward the prior: noisier rewards (larger s2) or a
            # tighter prior (smaller t2) keep the estimate closer to the prior.
            prior_weight = self.variance / (self.variance + (n * self.prior_variance))
            new_mean = prior_weight * self.prior_mean + (1 - prior_weight) * arm_mean
            new_variance = 1 / (1 / self.prior_variance + n / self.variance)

            self.means[c] = new_mean
            self.variances[c] = new_variance

    def act(self, x):
        """Return a one-hot arm chosen from the current posterior."""
        if self.sample:
            scores = np.random.normal(self.means, np.sqrt(self.variances))
        else:
            scores = self.means
        a = np.zeros(self.env.dim)
        a[np.argmax(scores)] = 1.0
        self.a = a
        return self.a








class PessMeanPolicy(Controller):
    """Pessimistic arm choice: empirical mean minus a count-based penalty.

    Score per arm = mean - const / max(1, sqrt(pulls)). Arms never pulled
    have mean 0 and penalty ``const``. Returns a one-hot vector.
    """

    def __init__(self, env, const=1.0):
        super().__init__()
        self.env = env
        self.const = const

    def reset(self):
        """Nothing to reset between episodes."""
        return None

    def act(self, x):
        """Play the arm with the highest pessimistic score."""
        played = self.batch['rollin_us'].cpu().detach().numpy()[0]
        observed = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

        totals = np.zeros(self.env.dim)
        pulls = np.zeros(self.env.dim)
        for step_idx, row in enumerate(played):
            arm = np.argmax(row)
            totals[arm] += observed[step_idx]
            pulls[arm] += 1

        averages = totals / np.maximum(1, pulls)
        lower = averages - self.const / np.maximum(1, np.sqrt(pulls))

        choice = np.zeros(self.env.dim)
        choice[np.argmax(lower)] = 1.0
        self.a = choice
        return self.a



class UCBPolicy(Controller):
    """Optimistic arm choice: empirical mean plus a count-based bonus.

    Score per arm = mean + const / max(1, sqrt(pulls)).
    NOTE(review): the bonus has no log(t) factor, unlike textbook UCB1.
    Returns a one-hot vector.
    """

    def __init__(self, env, const=1.0):
        super().__init__()
        self.env = env
        self.const = const

    def reset(self):
        """Nothing to reset between episodes."""
        return None

    def act(self, x):
        """Play the arm with the highest optimistic score."""
        played = self.batch['rollin_us'].cpu().detach().numpy()[0]
        observed = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

        totals = np.zeros(self.env.dim)
        pulls = np.zeros(self.env.dim)
        for step_idx, row in enumerate(played):
            arm = np.argmax(row)
            totals[arm] += observed[step_idx]
            pulls[arm] += 1

        averages = totals / np.maximum(1, pulls)
        upper = averages + self.const / np.maximum(1, np.sqrt(pulls))

        choice = np.zeros(self.env.dim)
        choice[np.argmax(upper)] = 1.0
        self.a = choice
        return self.a



class RandPolicy(OptPolicy):
    """Play a uniformly random one-hot arm, ignoring the observation."""

    def __init__(self, env):
        super().__init__(env)

    def act(self, x):
        """Return a fresh one-hot vector for a uniformly chosen arm."""
        pick = np.random.choice(np.arange(self.env.dim))
        one_hot = np.zeros(self.env.dim)
        one_hot[pick] = 1.0
        return one_hot


class BanditTransformerController(Controller):
    """Controller that turns a transformer's per-arm scores into a one-hot arm.

    ``act`` writes zero placeholders and the current state into
    ``self.batch`` (set via ``set_batch``), runs ``model`` on it, and reads the
    first output row as one score per arm. With ``sample=True`` the arm is
    drawn from ``softmax(scores)``; otherwise the highest score wins.

    Args:
        model: callable on the batch dict; must expose ``config['du']``,
            ``config['dx']`` and ``H``.
        sample: draw the arm stochastically instead of taking the argmax.
    """

    def __init__(self, model, sample=False):
        self.model = model
        self.du = model.config['du']
        self.dx = model.config['dx']
        self.H = model.H
        # Zero placeholders inserted into the batch before every forward pass.
        self.zeros = torch.zeros(1, self.dx**2 + self.du + 1).float().to(device)
        self.zerosQ = torch.zeros(1, self.H, self.dx**2).float().to(device)
        self.sample = sample

    def set_env(self, env):
        """The controller does not use the env; ignored."""
        return

    def reset(self):
        """No per-episode state; defined so this controller exposes reset() like the others."""
        return

    def act(self, x):
        """Return a one-hot action of length ``du`` for observation ``x``."""
        self.batch['zeros'] = self.zeros
        self.batch['zerosQ'] = self.zerosQ

        states = torch.tensor(x)[None, :].float().to(device)
        self.batch['states'] = states

        # Inference only: avoid building an autograd graph on every step.
        with torch.no_grad():
            a = self.model(self.batch)
        a = a.cpu().detach().numpy()[0]

        if self.sample:
            probs = scipy.special.softmax(a)
            i = np.random.choice(np.arange(self.du), p=probs)
        else:
            i = np.argmax(a)

        a = np.zeros(self.du)
        a[i] = 1.0
        return a


class TopKBanditTransformerController(BanditTransformerController):
    """Transformer controller that returns a k-hot subset of arms.

    With ``sample=True`` (and fewer than ``tmax`` rewards in context) the
    subset is drawn by ``order_sample``; otherwise the ``k`` highest-scoring
    arms are selected.
    """

    def __init__(self, model, k=1, sample=False):
        super().__init__(model, sample=sample)
        self.k = k
        # Sampling only applies while len(rollin_rs[0]) < tmax; np.inf means
        # "always" whenever sample=True.
        self.tmax = np.inf

    def act(self, x):
        """Return a k-hot action vector for observation ``x``."""
        self.batch['zeros'] = self.zeros
        self.batch['zerosQ'] = self.zerosQ
        self.batch['states'] = torch.tensor(x)[None, :].float().to(device)

        # Inference only: avoid building an autograd graph on every step.
        with torch.no_grad():
            scores = self.model(self.batch)
        scores = scores.cpu().detach().numpy()[0]

        if self.sample and len(self.batch['rollin_rs'][0]) < self.tmax:
            return order_sample(scores, self.k)

        indices = np.argsort(scores)[::-1][:self.k]
        a = np.zeros(self.du)
        a[indices] = 1.0
        return a


class TopKRandCommitPolicy(OptPolicy):
    """Explore-then-commit over k-subsets.

    For the first ``tmax = int(horizon * 3/4)`` calls to ``act`` a uniformly
    random subset of ``env.k`` arms is played. On the next call the subset
    with the best empirical mean in the batch (``best_emp_mean``) is chosen,
    and every later call replays it. With ``immediate=True`` exploration is
    skipped and the best empirical subset is recomputed on every call.
    """

    def __init__(self, env, k, horizon, immediate=False):
        super().__init__(env)
        self.rand = True
        self.t = 0
        self.k = k
        self.horizon = horizon
        self.tmax = int(horizon * 3/4)
        self.imm = immediate

    def act(self, x):
        """Return the action for the current phase (explore / commit / replay)."""
        if self.t < self.tmax and not self.imm:
            hot_vector = np.zeros(self.env.dim)
            indices = np.random.choice(self.env.dim, size=self.env.k, replace=False)
            hot_vector[indices] = 1
            self.t += 1
            return hot_vector
        elif self.t == self.tmax or self.imm:
            self.a = best_emp_mean(self.batch)
            # Advance past tmax so later calls take the replay branch below
            # and keep the committed subset. Without this increment that
            # branch is unreachable. With immediate=True this branch is taken
            # on every call regardless of t.
            self.t += 1
            return self.a
        else:
            self.t += 1
            return self.a



class TopKEpsGreedy(OptPolicy):
    """Epsilon-greedy over k-subsets.

    Plays a uniformly random subset of ``env.k`` arms when the batch holds no
    rewards yet, or with probability ``eps``. Otherwise it plays the subset
    with the best empirical mean (``best_emp_mean``).

    Args:
        env: environment providing ``dim`` and ``k``.
        k: subset size (stored; the random branch uses ``env.k``).
        horizon: stored but not used by this policy.
        immediate: accepted but unused here; the signature matches
            TopKRandCommitPolicy.
        eps: exploration probability (default 0.1).
    """

    def __init__(self, env, k, horizon, immediate=False, eps=0.1):
        super().__init__(env)
        self.rand = True
        self.t = 0
        self.k = k
        self.horizon = horizon
        self.eps = eps

    def act(self, x):
        """Return a random or greedy k-hot action."""
        # NOTE: the exploration coin uses Python's `random`, while the subset
        # draw uses np.random; seed both for reproducible runs.
        if len(self.batch['rollin_rs'][0]) < 1 or random.random() < self.eps:
            hot_vector = np.zeros(self.env.dim)
            indices = np.random.choice(self.env.dim, size=self.env.k, replace=False)
            hot_vector[indices] = 1
            self.t += 1
            return hot_vector
        else:
            self.a = best_emp_mean(self.batch)
            return self.a



class LinUCB(OptPolicy):
    """LinUCB over k-hot subset actions with a ridge-regularised linear model.

    The reward of subset ``a`` is modelled as ``theta @ a``. ``theta`` is the
    ridge estimate from the in-context batch, with regulariser ``self.cov``
    (the identity). If ``const == 0``, the ``k`` arms with the largest
    ``theta`` are played. Otherwise every k-hot vector is scored by
    ``theta @ a + const * sqrt(a @ cov_inv @ a)`` and the best one is played.
    With an empty batch, a uniformly random subset of ``env.k`` arms is used.
    """

    def __init__(self, env, k, const=1.0):
        super().__init__(env)
        self.rand = True
        self.t = 0
        self.k = k
        self.theta = np.zeros(self.env.dim)
        self.cov = 1.0 * np.eye(self.env.dim)
        self.const = const
        # Lazily built cache of all k-hot candidates, keyed on (dim, k).
        self._cand_key = None
        self._cands = None

    def _candidate_arms(self):
        """Return every k-hot vector of length env.dim, rebuilding only if (dim, k) changed."""
        key = (self.env.dim, self.k)
        if self._cand_key != key:
            self._cands = generate_k_hot_vectors(self.env.dim, self.k)
            self._cand_key = key
        return self._cands

    def act(self, x):
        """Return a k-hot action chosen by the rule described on the class."""
        if len(self.batch['rollin_rs'][0]) < 1:
            hot_vector = np.zeros(self.env.dim)
            indices = np.random.choice(self.env.dim, size=self.env.k, replace=False)
            hot_vector[indices] = 1
            self.t += 1
            return hot_vector

        actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
        rewards = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

        cov = self.cov + actions.T @ actions
        cov_inv = np.linalg.inv(cov)

        theta = cov_inv @ actions.T @ rewards

        if self.const == 0:
            indices = np.argsort(theta)[::-1][:self.k]
            self.a = np.zeros(self.env.dim)
            self.a[indices] = 1.0
            return self.a

        best_arm = None
        best_value = -np.inf
        for a in self._candidate_arms():
            value = theta @ a + self.const * np.sqrt(a @ cov_inv @ a)
            if value > best_value:
                best_value = value
                best_arm = a

        # Copy so a caller mutating the returned action cannot corrupt the
        # cached candidate array.
        self.a = None if best_arm is None else best_arm.copy()
        return self.a


def generate_k_hot_vectors(n, k):
    """Return every length-``n`` 0/1 vector with exactly ``k`` ones.

    Rows follow ``itertools.combinations(range(n), k)`` order. The result is
    always a 2-D int array of shape ``(comb(n, k), n)``. For ``k > n`` that is
    ``(0, n)``; for ``k == 0`` it is a single all-zero row.

    Raises:
        ValueError: if ``k`` is negative (from itertools.combinations).
    """
    combos = list(itertools.combinations(range(n), k))
    k_hot_vectors = np.zeros((len(combos), n), dtype=int)
    for row, combination in enumerate(combos):
        k_hot_vectors[row, list(combination)] = 1
    return k_hot_vectors


def best_emp_mean(batch):
    """Return the logged action whose rewards have the highest mean.

    Rewards are grouped by exact action vector (as a tuple), each group is
    averaged, and the first action (in order of first appearance) attaining
    the largest average is returned as a numpy array.

    Raises:
        ValueError: if the batch contains no actions.
    """
    actions = batch['rollin_us'].cpu().detach().numpy()[0]
    rewards = batch['rollin_rs'].cpu().detach().numpy().flatten()

    groups = {}
    for a, r in zip(actions, rewards):
        groups.setdefault(tuple(a), []).append(r)
    if not groups:
        raise ValueError("best_emp_mean: batch contains no actions")

    # Every group is non-empty by construction.
    emp_means = {key: np.mean(vals) for key, vals in groups.items()}
    best = max(emp_means, key=emp_means.get)
    return np.array(best)


def order_sample(logits, k, scale=2):
    """Sample a k-hot float vector by drawing ``k`` distinct indices in sequence.

    Indices are drawn one at a time from ``softmax(logits * scale)``. After
    each draw the chosen index's probability is zeroed and the remainder is
    renormalised (sequential sampling without replacement).

    Args:
        logits: 1-D array of scores.
        k: number of indices to select.
        scale: multiplier applied to ``logits`` before the softmax (default 2).

    Raises:
        ValueError: if ``k`` exceeds ``len(logits)``.
    """
    n = len(logits)
    if k > n:
        raise ValueError(f"cannot pick k={k} distinct indices from {n} logits")
    probs = scipy.special.softmax(logits * scale)
    b = np.zeros(n)
    for _ in range(k):
        j = np.random.choice(np.arange(n), p=probs)
        b[j] = 1.0
        probs[j] = 0.0
        remaining = np.sum(probs)
        # Skip the 0/0 renormalisation after the last index is taken.
        if remaining > 0:
            probs = probs / remaining
    return b


def gumbel_softmax_sample_logits(n, k, a, tau=1.0, noise_scale=0.5):
    """Return a k-hot int vector: the top-``k`` entries of Gumbel-perturbed ``a``.

    ``z = a + noise_scale * g`` with ``g`` i.i.d. standard Gumbel noise, and
    the ``k`` largest entries of ``z`` are set to 1.

    Args:
        n: length of ``a`` and of the returned vector.
        k: number of ones in the result.
        a: 1-D array of ``n`` logits.
        tau: unused; kept so existing calls keep working.
        noise_scale: multiplier on the Gumbel noise (default 0.5).
    """
    # Draw u from [tiny, 1) instead of [0, 1): a draw of exactly 0 would
    # otherwise produce -inf noise (plus a divide-by-zero warning).
    u = np.random.uniform(np.finfo(float).tiny, 1, n)
    gumbel_noise = -np.log(-np.log(u))
    z = a + noise_scale * gumbel_noise
    top_k_indices = np.argsort(z)[::-1][:k]

    k_hot_vector = np.zeros(n, dtype=int)
    k_hot_vector[top_k_indices] = 1

    return k_hot_vector



if __name__ == '__main__':
    # sample() has no defaults for H and var, so both must be supplied;
    # the values below are just illustrative.
    env = sample(3, H=100, var=0.1)
    ctrl = RandPolicy(env)
    embed()