import numpy as np
from joblib import Parallel, delayed
import matplotlib.pyplot as plt

class StageGeneratorEnv:
    '''
    Linear-bandit environment with one optimal arm and K - 1 distractors.

    Implements the environment described in the paper. ``u`` (rescaled to
    unit length) is the optimal arm. Each call to :meth:`arms` returns a fresh
    action set ``[u, a_1, ..., a_{K-1}]`` in which every distractor ``a_i`` is
    a random unit vector orthogonal to ``u``. The expected reward of an arm
    ``a`` is ``Delta * <a, u>``, so the optimal arm earns ``Delta`` and every
    distractor earns 0 in expectation.

    Parameters
    ----------
    d : int
        Ambient dimension; must be >= 2 whenever K > 1 (the distractors live
        in the (d - 1)-dimensional orthogonal complement of ``u``).
    u : ndarray of shape (d,)
        Direction of the optimal arm; any non-zero scaling is accepted.
    Delta : float
        Reward gap between the optimal arm and every distractor.
    K : int
        Number of arms per round, including the optimal one; must be >= 1.
        The ``None`` default is kept only for signature compatibility --
        a value is required.
    sigma : float
        Standard deviation of the Gaussian reward noise.
    random_state : None, int or array_like
        Seed for the environment's private ``np.random.RandomState``.

    Raises
    ------
    TypeError
        If ``K`` is not given.
    ValueError
        If ``K < 1``, if ``K > 1`` with ``d < 2``, or if ``u`` is the zero vector.
    '''
    def __init__(self, d, u, Delta=5.0, K=None, sigma=1.0, random_state=None):
        if K is None:
            raise TypeError("StageGeneratorEnv requires K (the number of arms per round)")
        if K < 1:
            raise ValueError(f"K must be >= 1, got {K}")
        if K > 1 and d < 2:
            # With d == 1 the orthogonal complement of u is empty, so every
            # distractor would be the zero vector and normalising it gives NaNs.
            raise ValueError(f"K > 1 requires d >= 2, got d={d}")
        norm_u = np.linalg.norm(u)
        if norm_u == 0:
            raise ValueError("u must be a non-zero vector")
        self.d = d
        self.u = u / norm_u
        self.Delta = Delta
        self.K = K
        self.noise_count = self.K - 1  # distractor arms per round
        self.sigma = sigma
        self.rng = np.random.RandomState(random_state)

    def arms(self):
        '''
        Draw a fresh action set of K unit vectors.

        Returns
        -------
        list of ndarray, each of shape (d,)
            ``self.u`` itself first (the same array object -- do not mutate
            it), followed by ``K - 1`` distractors, each a uniformly random
            unit direction orthogonal to ``u``.
        '''
        # Project a Gaussian (d, d-1) matrix onto the complement of u; its QR
        # factor Q is then (almost surely) an orthonormal basis of that
        # (d-1)-dimensional subspace.
        X = self.rng.randn(self.d, self.d-1)
        X -= np.outer(self.u, self.u.dot(X))
        Q, _ = np.linalg.qr(X)
        arms = [self.u]

        for _ in range(self.noise_count):
            # Isotropic Gaussian coordinates in that basis, normalised, give a
            # uniformly random direction orthogonal to u.
            coords = self.rng.randn(self.d - 1)
            a = Q.dot(coords)
            a /= np.linalg.norm(a)
            arms.append(a)
        return arms

    def reward(self, ctx, arm):
        '''
        Sample a noisy reward for pulling ``arm``.

        ``ctx`` is ignored. The reward is ``Delta * <arm, u> + sigma * N(0, 1)``,
        with the noise drawn from the environment's own RNG.
        '''
        mean = arm.dot(self.Delta * self.u)
        return mean + self.rng.randn() * self.sigma


class SimpleLinTS:
    '''
    Zero-mean Thompson-style exploration sampler.

    Each round draws ``theta_t ~ N(0, V^{-1})`` -- the mean is always zero, so
    which arm is pulled depends only on the design matrix, never on the
    rewards seen so far. Rewards are still folded into a ridge least-squares
    estimate ``V^{-1} b``, whose greedy pick on a freshly drawn action set is
    used to decide when to stop.
    '''
    def __init__(self, env, lam, random_state=None):
        self.env = env
        self.d = env.d
        identity = np.eye(self.d)
        self.V = lam * identity
        self.V_inv = (1.0 / lam) * identity
        self.b = np.zeros(self.d)
        self.rng = np.random.RandomState(random_state)

    def _observe(self, phi, x):
        '''Add the observation (phi, x) to V, b and -- via Sherman-Morrison -- V^{-1}.'''
        self.V += np.outer(phi, phi)
        V_inv_phi = self.V_inv.dot(phi)
        self.V_inv -= np.outer(V_inv_phi, V_inv_phi) / (1 + phi.dot(V_inv_phi))
        self.b += phi * x

    def _greedy_regret(self, theta):
        '''Regret Delta * (1 - <a, u>) of the arm maximising <a, theta> on a fresh action set.'''
        candidates = self.env.arms()
        pick = int(np.argmax(np.stack(candidates).dot(theta)))
        return self.env.Delta * (1 - candidates[pick].dot(self.env.u))

    def run_to_precision(self, precision=0.1, max_T=100000):
        '''
        Play until the greedy recommendation has regret <= ``precision``.

        Returns the 1-based round at which that first happens, or ``max_T``
        if it never does within ``max_T`` rounds.
        '''
        origin = np.zeros(self.d)
        for rounds_done in range(1, max_T + 1):
            theta_t = self.rng.multivariate_normal(origin, self.V_inv)
            action_set = self.env.arms()
            features = np.stack(action_set)
            chosen = int(np.argmax(features.dot(theta_t)))
            reward = self.env.reward(None, action_set[chosen])
            self._observe(features[chosen], reward)
            if self._greedy_regret(self.V_inv.dot(self.b)) <= precision:
                return rounds_done
        return max_T

class LinTS:
    '''
    Cumulative linear Thompson sampling.

    Each round samples ``theta_t ~ N(theta_hat, V^{-1})`` around the current
    ridge estimate ``theta_hat = V^{-1} b``, pulls the arm maximising
    ``<a, theta_t>``, and updates ``V``, ``V^{-1}`` and ``b``. Before the first
    round ``theta_hat`` is itself a draw from ``N(0, I)`` rather than the zero
    ridge estimate.
    '''
    def __init__(self, env, lam, random_state=None):
        self.env = env
        self.d = env.d
        identity = np.eye(self.d)
        self.V = lam * identity
        self.V_inv = (1.0 / lam) * identity
        self.b = np.zeros(self.d)
        self.rng = np.random.RandomState(random_state)

    def _observe(self, phi, x):
        '''Add the observation (phi, x) to V, b and -- via Sherman-Morrison -- V^{-1}.'''
        self.V += np.outer(phi, phi)
        V_inv_phi = self.V_inv.dot(phi)
        self.V_inv -= np.outer(V_inv_phi, V_inv_phi) / (1 + phi.dot(V_inv_phi))
        self.b += phi * x

    def _ridge_estimate(self):
        '''Current regularised least-squares estimate V^{-1} b.'''
        return self.V_inv.dot(self.b)

    def _greedy_regret(self, theta):
        '''Regret Delta * (1 - <a, u>) of the arm maximising <a, theta> on a fresh action set.'''
        candidates = self.env.arms()
        pick = int(np.argmax(np.stack(candidates).dot(theta)))
        return self.env.Delta * (1 - candidates[pick].dot(self.env.u))

    def run_to_precision(self, precision=0.1, max_T=100000):
        '''
        Play until the greedy recommendation has regret <= ``precision``.

        Returns the 1-based round at which that first happens, or ``max_T``
        if it never does within ``max_T`` rounds.
        '''
        theta_hat = self.rng.multivariate_normal(np.zeros(self.d), np.eye(self.d))
        for rounds_done in range(1, max_T + 1):
            theta_t = self.rng.multivariate_normal(theta_hat, self.V_inv)
            action_set = self.env.arms()
            features = np.stack(action_set)
            chosen = int(np.argmax(features.dot(theta_t)))
            reward = self.env.reward(None, action_set[chosen])
            self._observe(features[chosen], reward)
            theta_hat = self._ridge_estimate()
            if self._greedy_regret(theta_hat) <= precision:
                return rounds_done
        return max_T


class UniformGreedySampler:
    '''
    Uniform-exploration baseline with a greedy recommendation.

    Every round pulls an arm chosen uniformly at random from the current
    action set, folds the noisy reward into a ridge least-squares estimate
    ``V^{-1} b``, and then measures the regret of the arm that estimate would
    pick from a freshly drawn action set.
    '''
    def __init__(self, env, lam, random_state=None):
        self.env = env
        self.d = env.d
        identity = np.eye(self.d)
        self.V = lam * identity
        self.V_inv = (1.0 / lam) * identity
        self.b = np.zeros(self.d)
        self.rng = np.random.RandomState(random_state)

    def _observe(self, phi, x):
        '''Add the observation (phi, x) to V, b and -- via Sherman-Morrison -- V^{-1}.'''
        self.V += np.outer(phi, phi)
        V_inv_phi = self.V_inv.dot(phi)
        self.V_inv -= np.outer(V_inv_phi, V_inv_phi) / (1 + phi.dot(V_inv_phi))
        self.b += phi * x

    def _greedy_regret(self, theta):
        '''Regret Delta * (1 - <a, u>) of the arm maximising <a, theta> on a fresh action set.'''
        candidates = self.env.arms()
        pick = int(np.argmax(np.stack(candidates).dot(theta)))
        return self.env.Delta * (1 - candidates[pick].dot(self.env.u))

    def run_to_precision(self, precision=0.1, max_T=100000):
        '''
        Explore uniformly until the least-squares greedy pick has regret <= ``precision``.

        Returns the 1-based round at which that first happens, or ``max_T``
        if it never does within ``max_T`` rounds.
        '''
        for rounds_done in range(1, max_T + 1):
            action_set = self.env.arms()
            pulled = action_set[self.rng.randint(len(action_set))]
            reward = self.env.reward(None, pulled)
            self._observe(pulled, reward)
            # Least-squares estimate drives the recommendation.
            if self._greedy_regret(self.V_inv.dot(self.b)) <= precision:
                return rounds_done
        return max_T


def run_single_experiment(run_id, env, lam, precision, max_T):
    '''
    Run SimpleLinTS, UniformGreedySampler and LinTS once each on ``env``.

    Each algorithm runs against its own ``copy.deepcopy`` of ``env`` (random
    state included). All three therefore see the same sequence of action sets
    and reward noise. No algorithm's result depends on how many rounds the
    algorithms before it consumed from a shared environment RNG.

    The agents are seeded with ``[run_id, 1]`` instead of the bare integer
    ``run_id``. The ``__main__`` driver seeds the environment with that same
    integer, and equal integer seeds would give agent and environment the
    identical Mersenne-Twister stream.

    Parameters
    ----------
    run_id : int
        Non-negative run index used to derive the agents' seed.
    env : environment object
        Exposes ``d``, ``u``, ``Delta``, ``arms()`` and ``reward(ctx, arm)``.
        It is copied for each algorithm, so the caller's object is untouched.
    lam : float
        Ridge regularisation passed to every algorithm.
    precision, max_T
        Forwarded to each ``run_to_precision`` call.

    Returns
    -------
    tuple of int
        ``(rounds_ts, rounds_ug, rounds_cumuts)``.
    '''
    import copy

    # An array seed goes through init_by_array, giving a different stream
    # from RandomState(run_id), which the environment uses.
    agent_seed = [run_id, 1]
    ts = SimpleLinTS(copy.deepcopy(env), lam, random_state=agent_seed)
    ug = UniformGreedySampler(copy.deepcopy(env), lam, random_state=agent_seed)
    cumuts = LinTS(copy.deepcopy(env), lam, random_state=agent_seed)

    rounds_ts = ts.run_to_precision(precision, max_T)
    rounds_ug = ug.run_to_precision(precision, max_T)
    rounds_cumuts = cumuts.run_to_precision(precision, max_T)

    return rounds_ts, rounds_ug, rounds_cumuts


if __name__ == "__main__":
    # Experiment driver: for every K in K_values, run n_runs trials of the
    # three samplers in parallel (joblib), record how many rounds each needs
    # to reach `precision`, then plot and save summary statistics.

    # problem & algorithm parameters
    d = 8               # ambient dimension
    lam = 1.0           # ridge regularisation: V starts at lam * I
    Delta = 5.0         # expected reward of the optimal arm (gap to every distractor)
    sigma = 5.0         # std of the Gaussian reward noise
    precision = 0.1     # stop once the recommended arm's regret is <= precision
    max_T = 100000      # round budget per run; run_to_precision returns max_T if exhausted
    n_runs = 100        # trials per K

    # Generate K values: 1, 2, 4, ..., 2^d, 2^(d+1), ..., 2^(1.5d)
    # (the top exponent is int(1.5*d), so for d=8 the largest K is 2**12 = 4096)
    K_values = [1, 2] + [2**i for i in range(2, int(1.5*d) + 1)]
    
    print(f"Running experiments for d={d}, precision={precision}")
    print(f"K values: {K_values}")

    # Direction of the optimal arm, fixed (seed 0) across all K and all runs.
    rng = np.random.RandomState(0)
    u = rng.randn(d)
    u /= np.linalg.norm(u)

    results_summary = []
    
    for K in K_values:
        print(f"\nRunning experiments for K={K}...")
        # One job per run: run i gets an environment seeded with i and also
        # passes i as run_id. Each result is (rounds_ts, rounds_ug, rounds_cumuts).
        results = Parallel(n_jobs=-1)(
            delayed(run_single_experiment)(
                i,
                StageGeneratorEnv(d, u=u, Delta=Delta, K=K, sigma=sigma, random_state=i),
                lam, precision, max_T
            )
            for i in range(n_runs)
        )

        rounds_ts = [r[0] for r in results]
        rounds_ug = [r[1] for r in results]
        rounds_cumuts = [r[2] for r in results]

        # Failed runs enter these statistics with the capped value max_T.
        mean_rounds_ts = np.mean(rounds_ts)
        std_rounds_ts = np.std(rounds_ts)
        mean_rounds_ug = np.mean(rounds_ug)
        std_rounds_ug = np.std(rounds_ug)
        mean_rounds_cumuts = np.mean(rounds_cumuts)
        std_rounds_cumuts = np.std(rounds_cumuts)

        # A run counts as a success if it stopped before the round budget.
        # NOTE(review): run_to_precision returns max_T both when the budget is
        # exhausted and when precision is first reached in round max_T, so the
        # latter case is counted as a failure here.
        success_ts = sum(1 for r in rounds_ts if r < max_T)
        success_ug = sum(1 for r in rounds_ug if r < max_T)
        success_cumuts = sum(1 for r in rounds_cumuts if r < max_T)

        results_summary.append({
            'K': K,
            'mean_rounds_ts': mean_rounds_ts,
            'std_rounds_ts': std_rounds_ts,
            'mean_rounds_ug': mean_rounds_ug,
            'std_rounds_ug': std_rounds_ug,
            'mean_rounds_cumuts': mean_rounds_cumuts,
            'std_rounds_cumuts': std_rounds_cumuts,
            'success_rate_ts': success_ts / n_runs,
            'success_rate_ug': success_ug / n_runs,
            'success_rate_cumuts': success_cumuts / n_runs
        })
        
        print(f"K={K}: TS={mean_rounds_ts:.1f}±{std_rounds_ts:.1f} (success: {success_ts}/{n_runs}), "
              f"UG={mean_rounds_ug:.1f}±{std_rounds_ug:.1f} (success: {success_ug}/{n_runs}), "
              f"CumuTS={mean_rounds_cumuts:.1f}±{std_rounds_cumuts:.1f} (success: {success_cumuts}/{n_runs})")

    # Create plots
    K_vals = [r['K'] for r in results_summary]
    ts_means = [r['mean_rounds_ts'] for r in results_summary]
    ts_stds = [r['std_rounds_ts'] for r in results_summary]
    ug_means = [r['mean_rounds_ug'] for r in results_summary]
    ug_stds = [r['std_rounds_ug'] for r in results_summary]
    cumuts_means = [r['mean_rounds_cumuts'] for r in results_summary]
    cumuts_stds = [r['std_rounds_cumuts'] for r in results_summary]

    plt.figure(figsize=(12, 8))
    
    # Plot mean rounds with error bars (+/- one std) for all three algorithms
    plt.errorbar(K_vals, ts_means, yerr=ts_stds, label='Thompson Sampling', 
                marker='o', capsize=5, capthick=2)
    plt.errorbar(K_vals, ug_means, yerr=ug_stds, label='Uniform-Greedy', 
                marker='s', capsize=5, capthick=2)
    plt.errorbar(K_vals, cumuts_means, yerr=cumuts_stds, label='LinTS (Cumulative)', 
                marker='^', capsize=5, capthick=2)
    
    plt.xlabel('K (Number of Arms)')
    plt.ylabel(f'Mean Rounds to Precision {precision}')
    plt.title(f'Rounds to Achieve Precision {precision}\n(d={d}, Δ={Delta}, σ={sigma}, {n_runs} runs)')
    plt.xscale('log', base=2)
    plt.yscale('log')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    
    # NOTE(review): the "Khigh_2^{d}" tag suggests the largest K is 2**d, but
    # K_values runs up to 2**int(1.5*d); confirm the intended label.
    filename = f'rounds_to_precision_{precision}_d{d}_Khigh_2^{d}_Delta{Delta}_sigma{sigma}_runs{n_runs}.png'
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"\nSaved plot: {filename}")

    # Save detailed results as tab-separated text file
    results_file = f'precision_results_cumu_d{d}_Delta{Delta}_sigma{sigma}_runs{n_runs}.txt'
    with open(results_file, 'w') as f:
        f.write(f"Rounds to achieve precision {precision}\n")
        f.write(f"Parameters: d={d}, Delta={Delta}, sigma={sigma}, runs={n_runs}\n\n")
        f.write("K\tTS_mean\tTS_std\tUG_mean\tUG_std\tLinTS_mean\tLinTS_std\tTS_success\tUG_success\tLinTS_success\n")
        for r in results_summary:
            f.write(f"{r['K']}\t{r['mean_rounds_ts']:.2f}\t{r['std_rounds_ts']:.2f}\t"
                   f"{r['mean_rounds_ug']:.2f}\t{r['std_rounds_ug']:.2f}\t"
                   f"{r['mean_rounds_cumuts']:.2f}\t{r['std_rounds_cumuts']:.2f}\t"
                   f"{r['success_rate_ts']:.3f}\t{r['success_rate_ug']:.3f}\t"
                   f"{r['success_rate_cumuts']:.3f}\n")

    # Save raw data as numpy arrays for further analysis.
    # np.save stores this dict as a pickled 0-d object array; read it back with
    # np.load(raw_data_file, allow_pickle=True).item().
    data_dict = {
        'K_values': np.array(K_vals),
        'ts_means': np.array(ts_means),
        'ts_stds': np.array(ts_stds),
        'ug_means': np.array(ug_means),
        'ug_stds': np.array(ug_stds),
        'cumuts_means': np.array(cumuts_means),
        'cumuts_stds': np.array(cumuts_stds),
        'ts_success_rates': np.array([r['success_rate_ts'] for r in results_summary]),
        'ug_success_rates': np.array([r['success_rate_ug'] for r in results_summary]),
        'cumuts_success_rates': np.array([r['success_rate_cumuts'] for r in results_summary]),
        'all_results': results_summary
    }
    raw_data_file = f'precision_raw_data_cumu_d{d}_Delta{Delta}_sigma{sigma}_runs{n_runs}.npy'
    np.save(raw_data_file, data_dict)
    
    # Also save as CSV for easy import into other tools
    import csv
    csv_file = f'precision_results_cumu_d{d}_Delta{Delta}_sigma{sigma}_runs{n_runs}.csv'
    with open(csv_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['K', 'TS_mean', 'TS_std', 'UG_mean', 'UG_std', 'LinTS_mean', 'LinTS_std', 
                         'TS_success_rate', 'UG_success_rate', 'LinTS_success_rate'])
        for r in results_summary:
            writer.writerow([r['K'], r['mean_rounds_ts'], r['std_rounds_ts'],
                           r['mean_rounds_ug'], r['std_rounds_ug'],
                           r['mean_rounds_cumuts'], r['std_rounds_cumuts'],
                           r['success_rate_ts'], r['success_rate_ug'], r['success_rate_cumuts']])
    
    print(f"Saved results in multiple formats:")
    print(f"  - Text file: {results_file}")
    print(f"  - NumPy dictionary: {raw_data_file}")
    print(f"  - CSV file: {csv_file}")
    
    # Print summary
    print(f"\n=== SUMMARY ===")
    print(f"Experiment: Rounds to achieve precision {precision}")
    print(f"Parameters: d={d}, Delta={Delta}, sigma={sigma}")
    print(f"Runs per K: {n_runs}")
    print(f"K values tested: {len(K_vals)}")
    print("\nBest performance (lowest mean rounds):")
    # argmin returns the first minimum; K_vals is ascending, so ties go to the smaller K.
    best_ts_idx = np.argmin(ts_means)
    best_ug_idx = np.argmin(ug_means)
    best_cumuts_idx = np.argmin(cumuts_means)
    print(f"Thompson Sampling: K={K_vals[best_ts_idx]}, {ts_means[best_ts_idx]:.1f} rounds")
    print(f"Uniform-Greedy: K={K_vals[best_ug_idx]}, {ug_means[best_ug_idx]:.1f} rounds")
    print(f"LinTS (Cumulative): K={K_vals[best_cumuts_idx]}, {cumuts_means[best_cumuts_idx]:.1f} rounds")