import numpy as np
from numpy.random import default_rng
from scipy.special import log_expit,expit
from itertools import product
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
import copy
import os

# def expit(z):
#     z = np.clip(z, -15, 15)
#     return 1.0 / (1.0 + np.exp(-z))


def llog(w, x, y):
    """Negative log-likelihood of a logistic model plus a tiny ridge penalty.

    Computes -sum_i [y_i log sigma(x_i.w) + (1 - y_i) log sigma(-x_i.w)]
    + 0.5 * 1e-8 * ||w||^2, using ``log_expit`` for numerical stability.

    Args:
        w: weight vector.
        x: design matrix whose rows are feature vectors.
        y: binary labels, one per row of ``x``.

    Returns:
        The scalar penalised negative log-likelihood.
    """
    logits = x @ w
    log_lik = y * log_expit(logits) + (1 - y) * log_expit(-logits)
    ridge = 0.5 * 1e-8 * w.T @ w
    return -np.sum(log_lik) + ridge


def dllog(w, x, y):
    """Gradient of ``llog`` with respect to ``w``.

    Equals X^T (sigma(X w) - y) + 1e-8 * w.
    """
    residual = expit(x @ w) - y
    return residual @ x + 1e-8 * w


def hllog(w, x, y):
    """Hessian of ``llog`` with respect to ``w``.

    Returns X^T diag(p * (1 - p)) X + 1e-8 * I with p = sigmoid(X w).

    The diagonal weighting is applied by scaling the rows of ``x`` instead of
    materialising an n x n diagonal matrix, which would cost O(n^2) memory and
    O(n^2 d) time on every Newton iteration.

    Args:
        w: weight vector of shape (d,).
        x: design matrix of shape (n, d).
        y: unused; kept so the signature matches ``llog``/``dllog``.

    Returns:
        The (d, d) symmetric Hessian matrix.
    """
    p = expit(x @ w)  # sigmoid probabilities
    curvature = p * (1 - p)  # per-sample weights of the diagonal matrix
    hess = x.T @ (curvature[:, None] * x)
    return hess + 1e-8 * np.eye(hess.shape[0])
def logistic_mle(X, y, lam, init=None, tol=1e-8, max_iter=1000):
    """Fit logistic-regression weights with a damped Newton method.

    Minimises the objective whose gradient/Hessian are ``dllog``/``hllog``
    (negative log-likelihood plus a fixed 1e-8 ridge term). Each Newton step
    ``s = H^{-1} g`` is scaled by ``log(1 + ||s||) / ||s||``, so the length of
    the applied step is ``log(1 + ||s||)`` and long steps are shrunk.

    Args:
        X: design matrix of shape (n, d).
        y: binary labels of shape (n,).
        lam: accepted for API compatibility but NOT used; the regularisation
            strength is the constant 1e-8 inside ``dllog``/``hllog``.
        init: optional warm-start weights of shape (d,); zeros when None.
        tol: iteration stops once the gradient norm drops below this value.
        max_iter: maximum number of Newton iterations.

    Returns:
        The estimated weight vector of shape (d,).

    Raises:
        ValueError: if ``init`` is given and does not have shape (d,).
    """
    X = np.asarray(X)
    y = np.asarray(y)
    _, d = X.shape
    if init is None:
        w = np.zeros(d)
    else:
        w = np.asarray(init, dtype=float)
        if w.shape != (d,):
            raise ValueError("Initial weights must be a vector of shape (d,)")
    for _ in range(max_iter):
        grad = dllog(w, X, y)
        hess = hllog(w, X, y)
        step = np.linalg.solve(hess, grad)
        beta = np.linalg.norm(step)
        # A zero step (e.g. an exactly-zero gradient) would make
        # log1p(beta) / beta evaluate to 0/0 = NaN and poison w.
        if beta > 0:
            w = w - np.log1p(beta) / beta * step
        if np.linalg.norm(grad) < tol:
            break
    return w


    
class HypercubeLogisticBanditEnv:
    """Fixed-arm logistic bandit with d + 1 arms in R^d.

    Parameters
    ----------
    d : int
        Dimension of theta_star and of every arm.
    theta_norm : float
        The constant M from the paper: every coordinate of theta_star except
        the last equals M, and the last coordinate is fixed to 1.
    best, second_best : float
        Last-coordinate values of the two arms lying on the last axis
        (intended to be the best and second-best arms). The other d - 1 arms
        are the negated standard basis vectors -e_0, ..., -e_{d-2}.
    random_state : seed forwarded to ``numpy.random.default_rng``.
    """

    def __init__(self, d, theta_norm, best=0.1,second_best=-0.05,random_state=None):
        theta = np.full(d, theta_norm, dtype=float)
        theta[-1] = 1
        self.theta_star = theta

        # Rows 0..d-2: -e_i; row d-1: best * e_{d-1}; row d: second_best * e_{d-1}.
        basis_arms = -np.eye(d)
        basis_arms[-1, -1] = best
        runner_up = np.zeros((1, d))
        runner_up[0, -1] = second_best
        self.arms = np.concatenate((basis_arms, runner_up), axis=0)

        self.rng = default_rng(random_state)
        self.regrets = []
        self.current_arms = None
        self.d = d
        self.K = d + 1
        self.S = np.linalg.norm(self.theta_star)  # Euclidean norm of theta_star

    def sample_arms(self):
        """Expose the (fixed, shared) arm matrix as the current arm set."""
        self.current_arms = self.arms
        return self.arms

    def step(self, arm_index):
        """Pull ``arm_index`` from the current set.

        Returns a Bernoulli(sigmoid(x . theta_star)) reward and the
        instantaneous regret, which is also appended to ``self.regrets``.

        Raises:
            ValueError: if ``sample_arms`` has not been called yet.
        """
        arms = self.current_arms
        if arms is None:
            raise ValueError("Call sample_arms() before step().")
        chosen = arms[arm_index]
        p = expit(chosen.dot(self.theta_star))
        reward = self.rng.binomial(1, p)
        regret = np.max(expit(arms.dot(self.theta_star))) - p
        self.regrets.append(regret)
        return reward, regret

    def get_regret(self):
        """Return all recorded instantaneous regrets as a NumPy array."""
        return np.asarray(self.regrets)

class LinearThompsonSampler:
    """Linear Thompson-sampling explorer paired with a logistic MLE estimate.

    Each round draws theta_tilde ~ N(0, V^{-1}) and pulls the arm with the
    largest |x . theta_tilde|. V accumulates outer products of pulled arms, and
    its inverse is maintained with Sherman-Morrison. After every pull the
    logistic MLE is refit on all observations. The greedy arm under that MLE
    is then scored on fixed test arm sets to record the simple regret.
    """

    def __init__(self, d, env, lam=1.0, delta=0.1, random_state=None,n_samples=100):
        self.d = d
        self.lam = lam
        self.delta = delta  # stored but not used by the visible code
        if random_state is None:
            # Explicit int64: on Windows with NumPy < 2 the default integer is
            # 32-bit and an upper bound of 2**32 - 1 raises ValueError.
            random_state = np.random.randint(0, 2**32 - 1, dtype=np.int64)
            print(random_state)
        self.rng = default_rng(int(random_state))
        self.V = lam * np.eye(d)
        self.Vinv = np.eye(d) / lam
        self.X = []
        self.y = []
        self.simple = []
        self.last_theta_hat = np.zeros(d)
        self.env = env
        self.n_samples = n_samples
        # Evaluation arm sets are drawn from a deep copy so that sampling them
        # does not touch the live environment's state.
        self.test_env = copy.deepcopy(env)
        self.test_arms_sets = np.stack([
            self.test_env.sample_arms()
            for _ in range(self.n_samples)
        ], axis=0)
        P_true = expit(self.test_arms_sets.dot(env.theta_star))
        self.p_max = P_true.max(axis=1)
        self.P_true = P_true
        self.theta_hat_history = []

    def select_arm(self, arms):
        """Return the index of the arm maximising |x . theta_tilde|."""
        theta_tilde = self.rng.multivariate_normal(np.zeros(self.d), self.Vinv)
        return int(np.argmax(np.abs(arms.dot(theta_tilde))))

    def update(self, x, r):
        """Record (x, r), update V and V^{-1}, and refit the logistic MLE."""
        u = x.reshape(-1, 1)
        # update V_t with Sherman-Morrison
        self.V += u.dot(u.T)
        v = self.Vinv.dot(u)
        denom = 1.0 + (u.T.dot(v))[0, 0]
        self.Vinv -= (v.dot(v.T)) / denom
        self.X.append(x)
        self.y.append(r)
        X = np.vstack(self.X)
        y = np.array(self.y)
        theta_hat = logistic_mle(X, y, self.lam, init=self.last_theta_hat)
        self.theta_hat_history.append(theta_hat.copy())
        self.last_theta_hat = theta_hat

    def fit_and_act(self, arms):
        """Return the greedy arm index under the current MLE (zeros if no data)."""
        if not self.X:
            theta_hat = np.zeros(self.d)
        else:
            theta_hat = self.last_theta_hat
        return int(np.argmax(arms.dot(theta_hat)))

    def run(self, T):
        """Run T rounds.

        Returns:
            A tuple ``(simple_regret, theta_history)``. ``simple_regret`` has
            shape (T,), and ``theta_history`` has shape (T, d) with the
            post-update MLE of each round. Both are reset at the start of
            every call so their lengths always agree.
        """
        self.simple.clear()
        self.theta_hat_history.clear()
        rows = np.arange(self.n_samples)
        for _ in range(T):
            arms = self.env.sample_arms()
            idx_ts = self.select_arm(arms)
            r, _ = self.env.step(idx_ts)
            self.update(arms[idx_ts], r)
            greedy_ix = self.test_arms_sets.dot(self.last_theta_hat).argmax(axis=1)
            chosen_p = self.P_true[rows, greedy_ix]
            self.simple.append((self.p_max - chosen_p).mean())
        if self.theta_hat_history:
            history = np.vstack(self.theta_hat_history)
        else:
            history = np.empty((0, self.d))
        return np.array(self.simple), history

class TryHardThompsonSampler:
    """Thompson-sampling-style explorer with curvature-weighted design matrix.

    Each round:
      * samples theta_tilde ~ N(0, L^{-1});
      * fits the logistic MLE theta_bar on all data so far;
      * pulls the arm maximising expit(x . theta_bar) * |x . theta_tilde|;
      * adds mu_dot(x . theta_prime) * x x^T to L, with theta_prime = S x / ||x||.
    L^{-1} is kept in sync via Sherman-Morrison. After each round the MLE is
    refit and the greedy arm under it is scored on fixed test arm sets to
    record the simple regret.
    """

    def __init__(self, env, lam=1.0, S=10.0, delta=0.1, random_state=None, n_samples=1000):
        self.env = env
        self.lam = lam
        self.S = S  # radius of the ball B_d(S) used by _theta_prime
        self.delta = delta  # stored but not used by the visible code
        self.rng = default_rng(random_state)
        self.d = env.d
        self.L = lam * np.eye(self.d)
        self.Linv = np.eye(self.d) / lam
        self.X_data = []
        self.y_data = []
        self.last_theta_bar = np.zeros(self.d)
        self.t = 0
        self.cached_theta = None
        self.last_checked = 0
        self.conf_sets = []
        self.projection_count = 0
        self.simple = []

        self.n_samples = n_samples
        # Evaluation arm sets come from a deep copy so the live env is untouched.
        self.test_env = copy.deepcopy(env)
        self.test_arms_sets = np.stack([
            self.test_env.sample_arms()
            for _ in range(self.n_samples)
        ], axis=0)
        P_true = expit(self.test_arms_sets.dot(env.theta_star))
        self.p_max = P_true.max(axis=1)
        self.P_true = P_true
        self.cnt_prime = 0
        self.theta_bar_history = []

    @staticmethod
    def _mu_dot(z):
        """Logistic derivative sigma(z) * (1 - sigma(z)), with z clipped to [-10, 10]."""
        z = np.clip(z, -10, 10)
        p = 1.0 / (1.0 + np.exp(-z))
        return p * (1 - p)

    def _loss(self, theta, X, y):
        """Penalised negative log-likelihood (not called by the visible code)."""
        if X.size == 0:
            return self.lam * np.dot(theta, theta)

        eps = 1e-8
        p = expit(X.dot(theta))
        ll = y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps)
        return -np.sum(ll) + self.lam * np.dot(theta, theta)

    def _theta_prime(self, x, x0):
        """Return the point of the ball B_d(S) that maximises x . theta.

        This is S * x / ||x||. Since mu_dot is even and decreasing in |z|,
        mu_dot(x . theta_prime) is the smallest curvature attainable for x
        over the ball. ``x0`` is accepted but unused. For an all-zero arm,
        x . theta = 0 for every theta, so the zero vector is returned instead
        of the 0/0 NaN that would otherwise poison L^{-1}.
        """
        norm = np.linalg.norm(x)
        if norm == 0:
            return np.zeros_like(x, dtype=float)
        return x / norm * self.S

    def select_arm(self, arms):
        """Refit theta_bar, sample theta_tilde and return (arm index, theta_bar).

        theta_bar is also appended to ``theta_bar_history``. That history
        therefore holds the estimate *before* this round's observation.
        """
        self.t += 1
        theta_tilde = self.rng.multivariate_normal(np.zeros(self.d), self.Linv)
        if not self.X_data:
            theta_bar = np.zeros(self.d)
        else:
            theta_bar = logistic_mle(np.vstack(self.X_data), np.array(self.y_data), self.lam, init=self.last_theta_bar)
        self.last_theta_bar = theta_bar
        self.theta_bar_history.append(theta_bar.copy().reshape(1, -1))
        # NOTE(review): named mu_dots, but this is expit (mu), not
        # self._mu_dot (mu'); confirm which weighting the algorithm intends.
        mu_dots = expit(arms.dot(theta_bar))
        inners = np.abs(arms.dot(theta_tilde))
        scores = mu_dots * inners
        return int(np.argmax(scores)), theta_bar

    def update(self, arm_vec, reward, theta_bar):
        """Record the observation and add mu_dot(x . theta_prime) x x^T to L.

        ``theta_bar`` is accepted but unused.
        """
        self.X_data.append(arm_vec)
        self.y_data.append(reward)
        theta_prime = self._theta_prime(arm_vec, x0=self.last_theta_bar)
        # update L_t with Sherman-Morrison
        u = arm_vec.reshape(-1)
        w = self._mu_dot(u.dot(theta_prime))
        self.L += w * np.outer(u, u)
        v = self.Linv.dot(u)
        denom = 1 + w * u.dot(v)
        self.Linv -= (w * np.outer(v, v)) / denom

    def run(self, T):
        """Run T rounds.

        Returns:
            A tuple ``(simple, theta_history)``. ``simple`` is the list of
            per-round simple regrets accumulated so far. ``theta_history``
            stacks the recorded pre-update estimates into shape (rounds, d).
        """
        rows = np.arange(self.n_samples)
        for _ in range(T):
            arms = self.env.sample_arms()
            idx, theta_bar = self.select_arm(arms)
            reward, _ = self.env.step(idx)
            self.update(arms[idx], reward, theta_bar)
            # theta_bar is the unconstrained MLE, refit with this round's data
            theta_bar = logistic_mle(np.vstack(self.X_data), np.array(self.y_data), self.lam, init=theta_bar)
            greedy_ix = self.test_arms_sets.dot(theta_bar).argmax(axis=1)
            chosen_p = self.P_true[rows, greedy_ix]
            self.simple.append((self.p_max - chosen_p).mean())

        if self.theta_bar_history:
            history = np.vstack(self.theta_bar_history)
        else:
            history = np.empty((0, self.d))
        return self.simple, history


if __name__=="__main__":
    GLOBAL_SEED = 42
    d, T = 10, 1500
    K = d + 1
    norm_theta_star = 4.0
    bp = 0.3
    sbp = -0.3
    env = HypercubeLogisticBanditEnv(d, theta_norm=norm_theta_star, best=bp, second_best=sbp, random_state=1)
    print(env.arms)
    print(expit(env.arms.dot(env.theta_star)))
    print(env.arms.dot(env.theta_star))
    runs = 100
    lam1 = 1.0
    lam2 = 1.0

    def run_lin(rnd):
        """Run LinTS+MLE for T rounds on a fresh env seeded with rnd + GLOBAL_SEED."""
        env = HypercubeLogisticBanditEnv(d, theta_norm=norm_theta_star, best=bp, second_best=sbp, random_state=rnd + GLOBAL_SEED)
        lin = LinearThompsonSampler(d, env=env, lam=lam1, delta=0.05, random_state=rnd, n_samples=1)
        return lin.run(T)

    def run_th(rnd):
        """Run TryHardTS for T rounds on a fresh env seeded with rnd + GLOBAL_SEED."""
        env = HypercubeLogisticBanditEnv(d, theta_norm=norm_theta_star, best=bp, second_best=sbp, random_state=rnd + GLOBAL_SEED)
        th = TryHardThompsonSampler(env, lam=lam2, S=env.S + 1, delta=0.05, random_state=rnd, n_samples=1)
        return th.run(T)

    # Each run returns (per-round simple regret, per-round theta estimates).
    data_lin = Parallel(n_jobs=-1)(delayed(run_lin)(i) for i in range(runs))
    data_th = Parallel(n_jobs=-1)(delayed(run_th)(i) for i in range(runs))
    all_sr_lin = np.vstack([sr for sr, _ in data_lin])
    all_sr_th = np.vstack([sr for sr, _ in data_th])
    all_theta_lin = np.array([np.array(theta) for _, theta in data_lin])
    all_theta_th = np.array([np.array(theta) for _, theta in data_th])

    std_sr_lin = np.std(all_sr_lin, axis=0)
    std_sr_th = np.std(all_sr_th, axis=0)
    std_theta_lin = np.std(all_theta_lin, axis=0)
    std_theta_th = np.std(all_theta_th, axis=0)

    avg_theta_lin = np.mean(all_theta_lin, axis=0)
    avg_theta_th = np.mean(all_theta_th, axis=0)
    print(avg_theta_th.shape)
    avg_sr_lin = np.mean(all_sr_lin, axis=0)
    avg_sr_th = np.mean(all_sr_th, axis=0)
    rounds = np.arange(1, T + 1)

    theta_star = env.theta_star
    print(theta_star)

    dir_path = f"./{norm_theta_star}_K{K}_d{d}_lamLin_{lam1}_lamLog_{lam2}_rescale_{T}_run_{runs}_bp_{bp}_sbp_{sbp}"
    os.makedirs(dir_path, exist_ok=True)

    # Plot simple regret: mean +/- std, lower band clipped at 0 (regret >= 0).
    plt.figure(figsize=(8, 5))
    plt.tick_params(labelsize=14)
    plt.plot(rounds, avg_sr_lin, label='LinTS+MLE', color='blue')
    plt.fill_between(rounds, np.maximum(avg_sr_lin - std_sr_lin, 0.0), avg_sr_lin + std_sr_lin,
                     alpha=0.3, color='blue')
    plt.plot(rounds, avg_sr_th, label='TryHardTS', color='green')
    plt.fill_between(rounds, np.maximum(avg_sr_th - std_sr_th, 0.0), avg_sr_th + std_sr_th,
                     alpha=0.3, color='green')
    plt.xlabel('Round', fontsize=14)
    plt.ylabel('Simple Regret', fontsize=14)
    plt.title(f'Average Simple Regret over {runs} runs, d={d}, M={norm_theta_star}')
    plt.legend(fontsize=12)
    plt.grid(True)
    # Save the figure with a descriptive filename.
    filename = f"{dir_path}/log_simple_regret_norm{norm_theta_star}_K{K}_d{d}_lamLin_{lam1}_lamLog_{lam2}_rescale_{T}run_{runs}_bp_{bp}_sbp_{sbp}.png"
    plt.savefig(filename)
    plt.close()  # release the figure; pyplot otherwise keeps every figure alive
    print(f"Saved figure to {filename}")

    # Plot the estimate of every coordinate theta_hat[i] against theta_star[i].
    for i in range(d):
        plt.figure(figsize=(8, 5))
        plt.tick_params(labelsize=14)
        plt.plot(rounds, avg_theta_lin[:, i], label='LinTS+MLE', color='blue')
        plt.fill_between(rounds, avg_theta_lin[:, i] - std_theta_lin[:, i],
                         avg_theta_lin[:, i] + std_theta_lin[:, i],
                         alpha=0.3, color='blue')
        plt.plot(rounds, avg_theta_th[:, i], label='TryHardTS', color='green')
        plt.fill_between(rounds, avg_theta_th[:, i] - std_theta_th[:, i],
                         avg_theta_th[:, i] + std_theta_th[:, i],
                         alpha=0.3, color='green')
        plt.plot(rounds, theta_star[i] * np.ones(T), label='Ground Truth',
                 linestyle='--', color='black')
        plt.xlabel('Round', fontsize=14)
        plt.ylabel(f'theta_hat[{i}]', fontsize=14)
        plt.title(f'Average estimation of theta_star[{i}] over {runs} runs (with std), M={norm_theta_star}')
        plt.legend()
        plt.grid(True)
        filename = f"{dir_path}/log_theta_{i}_norm{norm_theta_star}_K{K}_d{d}_lamLin_{lam1}_lamLog_{lam2}_rescale_hor_{T}_bp_{bp}_sbp_{sbp}.png"
        plt.savefig(filename)
        plt.close()
        print(f"Saved figure to {filename}")