import numpy as np
from numpy.random import default_rng
from scipy.special import log_expit, expit
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
import copy
import os

def llog(w, x, y):
    """Return the (tiny-ridge) negative log-likelihood of logistic regression.

    Computes
        -sum_i [ y_i * log sigma(x_i . w) + (1 - y_i) * log sigma(-x_i . w) ]
        + 0.5 * 1e-8 * ||w||^2
    using scipy's numerically stable ``log_expit`` for log sigma.
    """
    margins = x @ w
    log_lik = y * log_expit(margins) + (1 - y) * log_expit(-margins)
    # Same evaluation order as ((0.5 * 1e-8) * w.T) @ w.
    ridge = (0.5 * 1e-8 * w.T) @ w
    return -np.sum(log_lik) + ridge

def dllog(w, x, y):
    """Return the gradient of ``llog`` with respect to ``w``.

    grad = X^T (sigma(X w) - y) + 1e-8 * w
    """
    residuals = expit(x @ w) - y
    data_term = residuals @ x
    return data_term + 1e-8 * w

def hllog(w, x, y):
    """Return the Hessian of ``llog`` with respect to ``w``.

    H = X^T diag(p * (1 - p)) X + 1e-8 * I, where p = sigma(X w).
    ``y`` is unused (the logistic Hessian does not depend on the labels) but
    kept so the signature matches ``llog``/``dllog``.
    """
    p = expit(x @ w)
    weights = p * (1 - p)
    # Scale column j of X^T by weights[j] instead of materialising the dense
    # n x n diagonal matrix: same values, O(n d) memory instead of O(n^2).
    hess = (x.T * weights) @ x
    return hess + 1e-8 * np.eye(hess.shape[0])

def logistic_mle(X, y, lam, init=None, tol=1e-8, max_iter=1000):
    """Fit logistic-regression weights with a damped Newton method.

    Minimises ``llog`` (negative log-likelihood with a fixed 1e-8 ridge) using
    steps of size log(1 + ||s||) / ||s|| along the Newton direction s.

    Parameters
    ----------
    X : array_like, shape (n, d)
        Design matrix.
    y : array_like, shape (n,)
        Binary labels.
    lam : float
        Accepted for interface compatibility but NOT used: the ridge strength
        is hard-coded to 1e-8 inside llog/dllog/hllog.
    init : array_like, shape (d,), optional
        Warm-start weights (copied, never modified). Defaults to zeros.
    tol : float
        Stop once the gradient norm drops below this value.
    max_iter : int
        Maximum number of Newton iterations.

    Returns
    -------
    numpy.ndarray, shape (d,)

    Raises
    ------
    ValueError
        If ``init`` does not have shape (d,).
    """
    X = np.asarray(X)
    y = np.asarray(y)
    _, d = X.shape
    if init is None:
        w = np.zeros(d)
    else:
        # Copy so the returned array never aliases the caller's warm start.
        w = np.array(init, dtype=float)
        if w.shape != (d,):
            raise ValueError("Initial weights must be a vector of shape (d,)")
    for _ in range(max_iter):
        grad = dllog(w, X, y)
        # Test convergence on the *current* gradient before stepping; this
        # also prevents the 0/0 below when the gradient is exactly zero.
        if np.linalg.norm(grad) < tol:
            break
        hess = hllog(w, X, y)
        step = np.linalg.solve(hess, grad)
        beta = np.linalg.norm(step)
        if beta == 0.0:
            # No representable progress; log1p(0)/0 would poison w with NaN.
            break
        w = w - np.log1p(beta) / beta * step
    return w

class LogisticBanditEnv:
    """Fixed-arm logistic bandit with Bernoulli rewards and regret tracking."""

    def __init__(self, d, theta_norm, best=0.1, second_best=-0.05, random_state=None):
        '''
        Build the instance.

        d: dimension of theta_star and of every arm; there are K = d + 1 arms.
        theta_norm: M in the paper. Every coordinate of theta_star is M except
            the last, which is fixed at 1.
        best / second_best: last-coordinate values of the two arms whose only
            non-zero entry is in the last coordinate (rows d-1 and d).
            Rows 0..d-2 are the negated standard basis vectors -e_i.
        random_state: seed for the reward generator.
        '''
        theta = np.full(d, theta_norm, dtype=float)
        theta[-1] = 1
        self.theta_star = theta

        basis = -1 * np.eye(d)
        basis[-1, -1] = best
        runner_up = np.zeros(d)
        runner_up[-1] = second_best
        self.arms = np.vstack([basis, runner_up])

        self.rng = default_rng(random_state)
        self.regrets = []
        self.current_arms = None
        self.d = d
        self.K = d + 1
        self.S = np.linalg.norm(self.theta_star)

    def sample_arms(self):
        """Expose the (fixed) arm set for this round and return it."""
        self.current_arms = self.arms
        return self.current_arms

    def step(self, arm_index):
        """Pull ``arm_index``; return (Bernoulli reward, instantaneous regret).

        Raises ValueError if sample_arms() has never been called.
        """
        if self.current_arms is None:
            raise ValueError("Call sample_arms() before step().")
        chosen = self.current_arms[arm_index]
        p = expit(chosen.dot(self.theta_star))
        reward = self.rng.binomial(1, p)
        all_means = expit(self.current_arms.dot(self.theta_star))
        regret = np.max(all_means) - p
        self.regrets.append(regret)
        return reward, regret

    def get_regret(self):
        """Return the per-step regrets recorded so far as a NumPy array."""
        return np.array(self.regrets)

class LinearThompsonSamplerPrecision:
    """Zero-mean Gaussian exploration with a logistic-MLE greedy recommendation.

    Each round draws theta_tilde ~ N(0, V^{-1}), plays the arm with the largest
    |x . theta_tilde|, refits the logistic MLE on all observations, and checks
    whether the arm that is greedy under that MLE is within ``precision`` of
    the best arm.
    """

    def __init__(self, d, env, lam=1.0, delta=0.1, random_state=None):
        """
        d: dimension of the arms.
        env: environment exposing sample_arms(), step(index) and theta_star.
        lam: initial design matrix is V = lam * I; also forwarded to logistic_mle.
        delta: stored for reference; not used by this class.
        random_state: seed for the sampler's Gaussian draws (None = fresh entropy).
        """
        self.d = d
        self.lam = lam
        self.delta = delta
        # default_rng(None) seeds from OS entropy. The previous fallback,
        # np.random.randint(0, 2**32 - 1), raised ValueError wherever NumPy's
        # default integer is 32-bit (e.g. Windows with NumPy < 2).
        self.rng = default_rng(None if random_state is None else int(random_state))
        self.V = lam * np.eye(d)
        # Inverse of V, maintained incrementally with Sherman-Morrison.
        self.Vinv = np.eye(d) / lam
        self.X = []
        self.y = []
        self.last_theta_hat = np.zeros(d)
        self.env = env

    def select_arm(self, arms):
        """Return the index of the arm maximising |x . theta_tilde|.

        theta_tilde is drawn from N(0, V^{-1}); the draw is centred at zero,
        not at the MLE, so the choice depends only on the design matrix.
        """
        theta_tilde = self.rng.multivariate_normal(np.zeros(self.d), self.Vinv)
        return int(np.argmax(np.abs(arms.dot(theta_tilde))))

    def update(self, x, r):
        """Add observation (x, r) and refit the MLE warm-started at the last fit."""
        u = x.reshape(-1, 1)
        # Rank-one update of V and Sherman-Morrison update of V^{-1}.
        self.V += u.dot(u.T)
        v = self.Vinv.dot(u)
        denom = 1.0 + (u.T.dot(v))[0, 0]
        self.Vinv -= (v.dot(v.T)) / denom
        self.X.append(x)
        self.y.append(r)
        X = np.vstack(self.X)
        y = np.array(self.y)
        theta_hat = logistic_mle(X, y, self.lam, init=self.last_theta_hat)
        self.last_theta_hat = theta_hat

    def run_to_precision(self, precision=0.1, max_T=100000):
        """Run until the MLE-greedy arm has simple regret <= ``precision``.

        Returns the 1-based round at which that first happens, or ``max_T`` if
        it never does. Note: a success in the final round also returns
        ``max_T``, so callers cannot distinguish it from a failure.
        """
        for t in range(max_T):
            arms = self.env.sample_arms()
            idx_ts = self.select_arm(arms)
            r, _ = self.env.step(idx_ts)
            self.update(arms[idx_ts], r)

            # Evaluate the recommendation that is greedy under the current MLE.
            test_arms = self.env.sample_arms()
            scores = test_arms.dot(self.last_theta_hat)
            greedy_idx = np.argmax(scores)
            chosen_p = expit(test_arms[greedy_idx].dot(self.env.theta_star))
            max_p = np.max(expit(test_arms.dot(self.env.theta_star)))
            simple_regret = max_p - chosen_p

            if simple_regret <= precision:
                return t + 1

        return max_T

class TryHardThompsonSamplerPrecision:
    """Thompson-style exploration whose design matrix uses worst-case curvature.

    Each round draws theta_tilde ~ N(0, L^{-1}) and plays the arm maximising
    expit(x . theta_bar) * |x . theta_tilde|, where theta_bar is the logistic
    MLE on all observations so far. Each played arm x is added to L with
    weight mu_dot(S * ||x||) (argument clipped to [-10, 10]), i.e. the
    smallest logistic curvature over parameters of norm at most S.
    """

    def __init__(self, env, lam=1.0, S=10.0, delta=0.1, random_state=None):
        """
        env: environment exposing d, sample_arms(), step(index) and theta_star.
        lam: initial L = lam * I; also forwarded to logistic_mle.
        S: radius of the parameter ball used for the curvature weight.
        delta: stored for reference; not used by this class.
        random_state: seed for the sampler's Gaussian draws.
        """
        self.env = env
        self.lam = lam
        self.S = S
        self.delta = delta
        self.rng = default_rng(random_state)
        self.d = env.d
        self.L = lam * np.eye(self.d)
        # Inverse of L, maintained incrementally with Sherman-Morrison.
        self.Linv = np.eye(self.d) / lam
        self.X_data = []
        self.y_data = []
        self.last_theta_bar = np.zeros(self.d)
        self.t = 0
        # (number of observations, MLE) for the most recent fit, so the same
        # data set is never fitted twice with the same warm start.
        self._mle_cache = None

    @staticmethod
    def _mu_dot(z):
        """Logistic derivative sigma(z) * (1 - sigma(z)), with z clipped to [-10, 10]."""
        z = np.clip(z, -10, 10)
        p = 1.0 / (1.0 + np.exp(-z))
        return p * (1 - p)

    def _theta_prime(self, x, x0):
        """Return the point of the S-ball maximising x . theta, i.e. S * x / ||x||.

        (Not a projection of theta_bar; ``x0`` is accepted but unused.)
        For a zero arm every point of the ball is a maximiser; return zeros
        instead of dividing by zero, which would fill L with NaN.
        """
        norm = np.linalg.norm(x)
        if norm == 0.0:
            return np.zeros_like(x, dtype=float)
        return x / norm * self.S

    def _fitted_theta_bar(self):
        """Return the MLE on all observations, reusing the cached fit if the data is unchanged.

        A fresh fit is warm-started at ``last_theta_bar``. Because the fit is
        deterministic, a cache hit returns exactly what a refit with that same
        warm start would.
        """
        n_obs = len(self.X_data)
        if n_obs == 0:
            return np.zeros(self.d)
        if self._mle_cache is not None and self._mle_cache[0] == n_obs:
            return self._mle_cache[1]
        theta_bar = logistic_mle(np.vstack(self.X_data), np.array(self.y_data),
                                 self.lam, init=self.last_theta_bar)
        self._mle_cache = (n_obs, theta_bar)
        return theta_bar

    def select_arm(self, arms):
        """Choose an arm; return (index, theta_bar used for the scores)."""
        self.t += 1
        theta_tilde = self.rng.multivariate_normal(np.zeros(self.d), self.Linv)
        theta_bar = self._fitted_theta_bar()
        self.last_theta_bar = theta_bar
        # NOTE(review): this is mu = expit(x . theta_bar), not its derivative
        # mu_dot (see _mu_dot), although the original variable was named
        # "mu_dots". Confirm which one the algorithm intends.
        mus = expit(arms.dot(theta_bar))
        inners = np.abs(arms.dot(theta_tilde))
        scores = mus * inners
        return int(np.argmax(scores)), theta_bar

    def update(self, arm_vec, reward, theta_bar):
        """Record (arm_vec, reward) and add the curvature-weighted arm to L.

        ``theta_bar`` is accepted for interface compatibility but unused.
        """
        self.X_data.append(arm_vec)
        self.y_data.append(reward)
        theta_prime = self._theta_prime(arm_vec, x0=self.last_theta_bar)
        u = arm_vec.reshape(-1)
        w = self._mu_dot(u.dot(theta_prime))
        # Weighted rank-one update of L and Sherman-Morrison update of L^{-1}.
        self.L += w * np.outer(u, u)
        v = self.Linv.dot(u)
        denom = 1 + w * u.dot(v)
        self.Linv -= (w * np.outer(v, v)) / denom

    def run_to_precision(self, precision=0.1, max_T=100000):
        """Run until the MLE-greedy arm has simple regret <= ``precision``.

        Returns the 1-based round at which that first happens, or ``max_T`` if
        it never does (a success in the final round also returns ``max_T``).
        """
        for t in range(max_T):
            arms = self.env.sample_arms()
            idx, theta_bar = self.select_arm(arms)
            reward, _ = self.env.step(idx)
            self.update(arms[idx], reward, theta_bar)
            # Unconstrained MLE on the data including this round; cached so the
            # next select_arm does not refit the identical problem.
            theta_bar = self._fitted_theta_bar()

            test_arms = self.env.sample_arms()
            scores = test_arms.dot(theta_bar)
            greedy_idx = np.argmax(scores)
            chosen_p = expit(test_arms[greedy_idx].dot(self.env.theta_star))
            max_p = np.max(expit(test_arms.dot(self.env.theta_star)))
            simple_regret = max_p - chosen_p

            if simple_regret <= precision:
                return t + 1

        return max_T


def run_single_experiment(run_id, d, theta_norm, bp, sbp, lam1, lam2, precision, max_T):
    """Run LinTS+MLE and TryHardTS once each on identically seeded environments.

    Returns (rounds_lin, rounds_th): the values returned by each sampler's
    run_to_precision(precision, max_T).
    """
    # Both environments share run_id, so the two algorithms face the same
    # instance and the same reward-generator seed (common random numbers).
    env_lin = LogisticBanditEnv(d, theta_norm=theta_norm, best=bp,
                                second_best=sbp, random_state=run_id)
    env_th = LogisticBanditEnv(d, theta_norm=theta_norm, best=bp,
                               second_best=sbp, random_state=run_id)

    # The samplers previously also used run_id, which made each sampler's
    # Gaussian draws and its environment's Bernoulli draws come from
    # bit-identical PCG64 streams (correlated agent/environment noise).
    # Derive an independent, still reproducible, seed from a spawned child.
    agent_seed = int(np.random.SeedSequence(run_id).spawn(1)[0].generate_state(1)[0])

    lin = LinearThompsonSamplerPrecision(d, env=env_lin, lam=lam1, delta=0.05,
                                         random_state=agent_seed)
    th = TryHardThompsonSamplerPrecision(env_th, lam=lam2, S=env_th.S + 1, delta=0.05,
                                         random_state=agent_seed)

    rounds_lin = lin.run_to_precision(precision, max_T)
    rounds_th = th.run_to_precision(precision, max_T)

    return rounds_lin, rounds_th


if __name__ == "__main__":
    import csv

    # Problem parameters
    GLOBAL_SEED = 42
    d = 50
    K = d + 1
    bp = 0.3  # best arm parameter
    sbp = -0.3  # second best arm parameter
    lam1 = 1.0  # regularization for LinearTS
    lam2 = 1.0  # regularization for TryHardTS

    # Experiment parameters
    precision = 1e-4  # target suboptimality gap
    max_T = 5000  # maximum rounds
    n_runs = 100  # number of runs per theta_norm
    theta_norm_values = np.arange(10, 100, 5)/10  # theta_norm from 1.0 to 9.5 in steps of 0.5

    print(f"Running experiments for d={d}, K={K}, precision={precision}")
    print(f"Best arm: {bp}, Second best: {sbp}")
    print(f"Theta norm values: {theta_norm_values}")
    print(f"Number of runs per setting: {n_runs}")

    results_summary = []

    for theta_norm in theta_norm_values:
        print(f"\nRunning experiments for theta_norm={theta_norm}...")

        # Run experiments in parallel
        results = Parallel(n_jobs=-1)(
            delayed(run_single_experiment)(
                i + GLOBAL_SEED, d, theta_norm, bp, sbp, lam1, lam2, precision, max_T
            )
            for i in range(n_runs)
        )

        rounds_lin = [r[0] for r in results]
        rounds_th = [r[1] for r in results]

        # Runs that never reach the precision contribute max_T, so these
        # means/stds are of right-censored data (a lower bound on the truth).
        mean_rounds_lin = np.mean(rounds_lin)
        std_rounds_lin = np.std(rounds_lin)
        mean_rounds_th = np.mean(rounds_th)
        std_rounds_th = np.std(rounds_th)

        # run_to_precision returns max_T both on failure and on success in the
        # final round, so the latter is (conservatively) counted as a failure.
        success_lin = sum(1 for r in rounds_lin if r < max_T)
        success_th = sum(1 for r in rounds_th if r < max_T)

        results_summary.append({
            'theta_norm': theta_norm,
            'mean_rounds_lin': mean_rounds_lin,
            'std_rounds_lin': std_rounds_lin,
            'mean_rounds_th': mean_rounds_th,
            'std_rounds_th': std_rounds_th,
            'success_rate_lin': success_lin / n_runs,
            'success_rate_th': success_th / n_runs,
            'rounds_lin': rounds_lin,  # Keep raw data
            'rounds_th': rounds_th
        })

        print(f"θ_norm={theta_norm}: LinTS={mean_rounds_lin:.1f}±{std_rounds_lin:.1f} "
              f"(success: {success_lin}/{n_runs}), "
              f"TryHardTS={mean_rounds_th:.1f}±{std_rounds_th:.1f} "
              f"(success: {success_th}/{n_runs})")

    dir_path = f"./precision_exp_d{d}_K{K}_bp{bp}_sbp{sbp}_precision{precision}_runs{n_runs}"
    os.makedirs(dir_path, exist_ok=True)

    theta_norms = [r['theta_norm'] for r in results_summary]
    lin_means = [r['mean_rounds_lin'] for r in results_summary]
    lin_stds = [r['std_rounds_lin'] for r in results_summary]
    th_means = [r['mean_rounds_th'] for r in results_summary]
    th_stds = [r['std_rounds_th'] for r in results_summary]

    # Create plot
    plt.figure(figsize=(12, 8))
    plt.errorbar(theta_norms, lin_means, yerr=lin_stds, label='LinTS+MLE',
                 marker='o', capsize=5, capthick=2, color='blue')
    plt.errorbar(theta_norms, th_means, yerr=th_stds, label='TryHardTS',
                 marker='s', capsize=5, capthick=2, color='green')

    plt.xlabel('M', fontsize=14)
    plt.ylabel(f'Mean Rounds to Precision {precision}', fontsize=14)
    plt.title(f'Rounds to Achieve Precision {precision}\n'
              f'(d={d}, K={K}, bp={bp}, sbp={sbp}, {n_runs} runs)', fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plot_file = f'{dir_path}/rounds_to_precision.png'
    plt.savefig(plot_file, dpi=300, bbox_inches='tight')
    plt.close()  # release the figure; it is not shown interactively
    print(f"\nSaved plot: {plot_file}")

    # Create success rate plot
    plt.figure(figsize=(12, 8))
    lin_success = [r['success_rate_lin'] for r in results_summary]
    th_success = [r['success_rate_th'] for r in results_summary]

    plt.plot(theta_norms, lin_success, label='LinTS+MLE', marker='o', color='blue')
    plt.plot(theta_norms, th_success, label='TryHardTS', marker='s', color='green')
    plt.xlabel('M', fontsize=14)
    plt.ylabel('Success Rate', fontsize=14)
    plt.title(f'Success Rate (Achieving Precision {precision} within {max_T} rounds)\n'
              f'(d={d}, K={K}, bp={bp}, sbp={sbp}, {n_runs} runs)', fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.ylim([0, 1.05])
    plt.tight_layout()

    success_plot_file = f'{dir_path}/success_rates.png'
    plt.savefig(success_plot_file, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved plot: {success_plot_file}")

    # Save detailed results as text file. The header contains 'θ', so force
    # UTF-8 rather than the platform default (e.g. cp1252 cannot encode it).
    results_file = f'{dir_path}/precision_results.txt'
    with open(results_file, 'w', encoding='utf-8') as f:
        f.write(f"Rounds to achieve precision {precision}\n")
        f.write(f"Parameters: d={d}, K={K}, bp={bp}, sbp={sbp}, runs={n_runs}\n")
        f.write(f"Regularization: lam_LinTS={lam1}, lam_TryHardTS={lam2}\n\n")
        f.write("θ_norm\tLinTS_mean\tLinTS_std\tTryHardTS_mean\tTryHardTS_std\t"
                "LinTS_success\tTryHardTS_success\n")
        for r in results_summary:
            f.write(f"{r['theta_norm']}\t{r['mean_rounds_lin']:.2f}\t"
                    f"{r['std_rounds_lin']:.2f}\t{r['mean_rounds_th']:.2f}\t"
                    f"{r['std_rounds_th']:.2f}\t{r['success_rate_lin']:.3f}\t"
                    f"{r['success_rate_th']:.3f}\n")

    # Save raw data. np.save stores this dict as a pickled 0-d object array:
    # reload with np.load(path, allow_pickle=True).item().
    data_dict = {
        'theta_norm_values': np.array(theta_norms),
        'lin_means': np.array(lin_means),
        'lin_stds': np.array(lin_stds),
        'th_means': np.array(th_means),
        'th_stds': np.array(th_stds),
        'lin_success_rates': np.array(lin_success),
        'th_success_rates': np.array(th_success),
        'all_results': results_summary,
        'parameters': {
            'd': d, 'K': K, 'bp': bp, 'sbp': sbp,
            'precision': precision, 'max_T': max_T, 'n_runs': n_runs,
            'lam1': lam1, 'lam2': lam2
        }
    }
    raw_data_file = f'{dir_path}/raw_data.npy'
    np.save(raw_data_file, data_dict)

    # Save as CSV for easy import
    csv_file = f'{dir_path}/precision_results.csv'
    with open(csv_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['theta_norm', 'LinTS_mean', 'LinTS_std', 'TryHardTS_mean',
                         'TryHardTS_std', 'LinTS_success_rate', 'TryHardTS_success_rate'])
        for r in results_summary:
            writer.writerow([r['theta_norm'], r['mean_rounds_lin'], r['std_rounds_lin'],
                             r['mean_rounds_th'], r['std_rounds_th'],
                             r['success_rate_lin'], r['success_rate_th']])

    print("\nSaved results in multiple formats:")
    print(f"  - Text file: {results_file}")
    print(f"  - NumPy dictionary: {raw_data_file}")
    print(f"  - CSV file: {csv_file}")

    # Print summary
    print("\n=== SUMMARY ===")
    print(f"Experiment: Rounds to achieve precision {precision}")
    print(f"Parameters: d={d}, K={K}, bp={bp}, sbp={sbp}")
    print(f"Runs per theta_norm: {n_runs}")
    print(f"Theta norm values tested: {len(theta_norms)}")

    # Find best performance
    best_lin_idx = np.argmin(lin_means)
    best_th_idx = np.argmin(th_means)
    print("\nBest performance (lowest mean rounds):")
    print(f"LinTS+MLE: θ_norm={theta_norms[best_lin_idx]}, "
          f"{lin_means[best_lin_idx]:.1f} rounds")
    print(f"TryHardTS: θ_norm={theta_norms[best_th_idx]}, "
          f"{th_means[best_th_idx]:.1f} rounds")

    # Compare algorithms
    print("\nAlgorithm comparison:")
    lin_better = sum(1 for lm, tm in zip(lin_means, th_means) if lm < tm)
    print(f"LinTS+MLE better in {lin_better}/{len(theta_norms)} settings")
    print(f"TryHardTS better in {len(theta_norms)-lin_better}/{len(theta_norms)} settings")