import numpy as np
from joblib import Parallel, delayed
import matplotlib.pyplot as plt

class StageGeneratorEnv:
    '''
    Linear-bandit environment described in the paper; ``u`` is the optimal arm.

    Every call to :meth:`arms` builds a fresh arm set: the unit vector ``u``
    followed by ``K - 1`` unit vectors drawn uniformly at random from the unit
    sphere of the subspace orthogonal to ``u``. Pulling arm ``a`` returns
    ``Delta * <a, u> + sigma * N(0, 1)``, so ``u`` has mean reward ``Delta`` and
    every other arm has mean reward 0 (up to floating-point error).

    Args:
        d: ambient dimension.
        u: direction of the optimal arm; normalised to unit length here.
        Delta: mean reward of the optimal arm (the gap to every other arm).
        K: number of arms offered per round, including ``u``; must be >= 1.
        sigma: standard deviation of the Gaussian reward noise.
        random_state: seed for the environment's own ``RandomState``, which is
            shared by :meth:`arms` and :meth:`reward`.

    Raises:
        ValueError: if ``K`` is None or smaller than 1.
    '''
    def __init__(self, d, u, Delta=5.0, K=None, sigma=1.0, random_state=None):
        # K has no usable default; fail with a clear message instead of a
        # TypeError on ``K - 1``.
        if K is None:
            raise ValueError("K (number of arms per round) must be specified")
        if K < 1:
            raise ValueError(f"K must be >= 1, got {K}")
        self.d = d
        self.u = u / np.linalg.norm(u)
        self.Delta = Delta
        self.K = K
        # K = 1 (only the optimal arm) is allowed; no larger lower bound is enforced.
        self.noise_count = self.K - 1
        self.sigma = sigma
        self.rng = np.random.RandomState(random_state)

    def arms(self):
        '''Return a fresh list of ``K`` unit-norm arm vectors; the first one is ``u``.'''
        # Orthonormal basis Q (d x (d-1)) of the complement of u: project a
        # random Gaussian matrix off u, then orthonormalise its columns.
        X = self.rng.randn(self.d, self.d - 1)
        X -= np.outer(self.u, self.u.dot(X))
        Q, _ = np.linalg.qr(X)
        # Isotropic Gaussian coordinates in that basis, normalised -> uniform on
        # the unit sphere of the complement. Drawing all rows in one call
        # consumes the RNG stream exactly like drawing one row per arm, but
        # avoids a K-long Python loop on every call.
        coords = self.rng.randn(self.noise_count, self.d - 1)
        noise = coords.dot(Q.T)
        noise /= np.linalg.norm(noise, axis=1, keepdims=True)
        return [self.u, *noise]

    def reward(self, ctx, arm):
        '''Return ``Delta * <arm, u> + sigma * N(0, 1)``; ``ctx`` is ignored.'''
        mean = arm.dot(self.Delta * self.u)
        return mean + self.rng.randn() * self.sigma


class SimpleLinTS:
    '''
    Thompson-sampling-style learner evaluated by simple regret.

    Each round it samples ``theta_t ~ N(0, V^{-1})`` (zero mean, ridge
    covariance), pulls the offered arm maximising ``<phi, theta_t>``, and folds
    the observation into the ridge statistics. It then recommends the arm that
    maximises the least-squares estimate ``V^{-1} b`` on a freshly generated arm
    set and records that recommendation's simple regret.

    Args:
        env: environment exposing ``d``, ``u``, ``Delta``, ``arms()`` and
            ``reward(ctx, arm)``.
        lam: ridge regularisation strength (``V`` starts at ``lam * I``).
        random_state: seed for the learner's own ``RandomState``.
    '''

    def __init__(self, env, lam, random_state=None):
        self.env = env
        self.d = dim = env.d
        identity = np.eye(dim)
        self.V = lam * identity
        self.V_inv = (1.0 / lam) * identity
        self.b = np.zeros(dim)
        self.rng = np.random.RandomState(random_state)

    def run(self, T):
        '''Play ``T`` rounds and return a length-``T`` array of simple regrets.'''
        dim = self.d
        zero_mean = np.zeros(dim)
        # NOTE(review): this draw is never read (the estimate is recomputed
        # before use); it is kept only so seeded runs reproduce exactly.
        self.rng.multivariate_normal(zero_mean, np.eye(dim))

        regrets = np.zeros(T)
        for step in range(T):
            sample = self.rng.multivariate_normal(zero_mean, self.V_inv)
            candidates = self.env.arms()
            features = np.stack(candidates)
            chosen = np.argmax(features.dot(sample))
            observed = self.env.reward(None, candidates[chosen])
            self._absorb(features[chosen], observed)
            regrets[step] = self._simple_regret(self.V_inv.dot(self.b))
        return regrets

    def _absorb(self, phi, x):
        '''Add observation (phi, x) to V, V^{-1} (Sherman-Morrison) and b.'''
        self.V += np.outer(phi, phi)
        v = self.V_inv.dot(phi)
        self.V_inv -= np.outer(v, v) / (1 + phi.dot(v))
        self.b += phi * x

    def _simple_regret(self, estimate):
        '''Regret of the argmax of ``estimate`` over a freshly drawn arm set.'''
        fresh = self.env.arms()
        best = fresh[np.argmax(np.stack(fresh).dot(estimate))]
        return self.env.Delta * (1 - best.dot(self.env.u))

class LinTS:
    '''
    Cumulative LinTS: linear Thompson sampling around the ridge estimate,
    evaluated by simple regret.

    Each round samples ``theta_t ~ N(theta_hat, V^{-1})``, plays the offered arm
    that is greedy for that sample, and updates ``V``, ``V^{-1}``
    (Sherman-Morrison) and ``b``. It then records the simple regret of the arm
    that the updated estimate ``V^{-1} b`` recommends on a fresh arm set.

    Args:
        env: environment exposing ``d``, ``u``, ``Delta``, ``arms()`` and
            ``reward(ctx, arm)``.
        lam: ridge regularisation strength (``V`` starts at ``lam * I``).
        random_state: seed for the learner's own ``RandomState``.
    '''

    def __init__(self, env, lam, random_state=None):
        self.env, self.d = env, env.d
        self.V = lam * np.eye(self.d)
        self.V_inv = (1.0 / lam) * np.eye(self.d)
        self.b = np.zeros(self.d)
        self.rng = np.random.RandomState(random_state)

    def run(self, T):
        '''Play ``T`` rounds and return a length-``T`` array of simple regrets.'''
        env, rng = self.env, self.rng
        history = np.zeros(T)
        # The posterior mean for the very first sample is itself a random
        # N(0, I) draw rather than the ridge estimate (which would be 0).
        estimate = rng.multivariate_normal(np.zeros(self.d), np.eye(self.d))
        for t in range(T):
            draw = rng.multivariate_normal(estimate, self.V_inv)
            pool = np.stack(env.arms())
            k = np.argmax(pool.dot(draw))
            phi = pool[k]
            x = env.reward(None, phi)

            # Rank-one ridge update; V^{-1} via Sherman-Morrison.
            self.V += np.outer(phi, phi)
            v_phi = self.V_inv.dot(phi)
            self.V_inv -= np.outer(v_phi, v_phi) / (1 + phi.dot(v_phi))
            self.b += phi * x

            # Recommend greedily from the least-squares estimate on a fresh set.
            estimate = self.V_inv.dot(self.b)
            fresh = np.stack(env.arms())
            pick = fresh[np.argmax(fresh.dot(estimate))]
            history[t] = env.Delta * (1 - pick.dot(env.u))
        return history


class UniformGreedySampler:
    """Uniform-exploration baseline evaluated by simple regret.

    Each round pulls an arm chosen uniformly at random from the offered set,
    folds the observation into the ridge-regression statistics (``V^{-1}``
    maintained with Sherman-Morrison), and records the simple regret of the arm
    that the least-squares estimate recommends on a freshly generated arm set.

    Args:
        env: environment exposing ``d``, ``u``, ``Delta``, ``arms()`` and
            ``reward(ctx, arm)``.
        lam: ridge regularisation strength (``V`` starts at ``lam * I``).
        random_state: seed for the sampler's own ``RandomState``.
    """

    def __init__(self, env, lam, random_state=None):
        self.env = env
        self.d = env.d
        eye = np.eye(self.d)
        self.V = lam * eye
        self.V_inv = (1.0 / lam) * eye
        self.b = np.zeros(self.d)
        self.rng = np.random.RandomState(random_state)

    def run(self, T):
        """Play ``T`` rounds and return a length-``T`` array of simple regrets."""
        out = np.zeros(T)
        for t in range(T):
            offered = self.env.arms()
            # Exploration: every offered arm is equally likely.
            phi = offered[self.rng.randint(len(offered))]
            x = self.env.reward(None, phi)
            self.V += np.outer(phi, phi)
            w = self.V_inv.dot(phi)
            self.V_inv -= np.outer(w, w) / (1 + phi.dot(w))
            self.b += phi * x
            # Least-squares estimate drives the (greedy) recommendation.
            out[t] = self._greedy_regret(self.V_inv.dot(self.b))
        return out

    def _greedy_regret(self, theta_hat):
        """Regret of the argmax of ``theta_hat`` over a freshly drawn arm set."""
        candidates = self.env.arms()
        scores = np.stack(candidates).dot(theta_hat)
        pick = candidates[np.argmax(scores)]
        return self.env.Delta * (1 - pick.dot(self.env.u))


def run_experiment(run_id, env, lam, T):
    """Run the three learners for one seed and return their simple-regret curves.

    Each learner gets its own ``RandomState`` seeded with ``run_id``. All three
    share the same ``env`` object, so ``env``'s internal RNG keeps advancing from
    one learner to the next. Their arm/noise draws are therefore not paired.

    Returns:
        Tuple ``(SimpleLinTS regrets, UniformGreedySampler regrets, LinTS regrets)``.
    """
    learners = (
        SimpleLinTS(env, lam, random_state=run_id),
        UniformGreedySampler(env, lam, random_state=run_id),
        LinTS(env, lam, random_state=run_id),
    )
    return tuple(learner.run(T) for learner in learners)


if __name__ == "__main__":
    # Problem & algorithm parameters.
    d, lam, Delta, sigma, T = 8, 1.0, 5.0, 5, 10000
    n_runs = 100
    # Arm-set sizes to sweep (full sweep: [1, 2, 8, 16, 32, 64, 128, 256]).
    K_values = [256]

    # (label, colour) for each entry of the tuple returned by run_experiment,
    # in the same order: SimpleLinTS, UniformGreedySampler, LinTS.
    curves = [
        ('SimpLinTS', 'blue'),
        ('Uniform-Greedy', 'orange'),
        ('CumuLinTS', 'green'),
    ]

    for K in K_values:
        # Same optimal direction u for every K (fixed seed 0).
        rng = np.random.RandomState(0)
        u = rng.randn(d)
        u /= np.linalg.norm(u)

        # One environment per run, seeded by the run index.
        results = Parallel(n_jobs=-1)(
            delayed(run_experiment)(
                i,
                StageGeneratorEnv(d, u=u, Delta=Delta, K=K, sigma=sigma, random_state=i),
                lam, T,
            )
            for i in range(n_runs)
        )

        # Mean regret curve and standard error of the mean for each algorithm.
        # The sample std (ddof=1) gives the usual standard error.
        ddof = 1 if n_runs > 1 else 0
        stats = []
        for runs in zip(*results):  # one tuple of n_runs regret arrays per algorithm
            regs = np.vstack(runs)  # (n_runs, T)
            stats.append((regs.mean(axis=0), regs.std(axis=0, ddof=ddof) / np.sqrt(n_runs)))

        time_steps = np.arange(1, T + 1)
        # About 20 error bars; max(1, ...) keeps the step positive when T < 20.
        error_indices = np.arange(0, T, max(1, T // 20))

        plt.figure(figsize=(5, 3))
        plt.tick_params(axis='both', which='major', labelsize=10)
        # Draw all mean curves first, then the error bars on top of them.
        for (label, color), (mean, _) in zip(curves, stats):
            plt.plot(time_steps, mean, label=label, color=color, linewidth=1.5)
        for (_, color), (mean, se) in zip(curves, stats):
            plt.errorbar(time_steps[error_indices], mean[error_indices],
                         yerr=se[error_indices], fmt='none', ecolor=color,
                         alpha=0.6, capsize=3, capthick=1, linewidth=1)

        plt.xlabel('Rounds', fontsize=10)
        plt.ylabel(f'Avg. Simple Regret, d={d}, K={K}', fontsize=10)
        plt.legend(fontsize=10, loc='best')
        plt.grid(True)

        plt.tight_layout()
        filename = f'Lin_avg_cumu_compare_regret_d{d}_K{K}_mean{Delta}_sigma{sigma}_T{T}.png'
        plt.savefig(filename)
        # Release the figure so a sweep over many K values does not pile up open figures.
        plt.close()
        # Report the file that was actually written.
        print("Saved " + filename)