import numpy as np
import matplotlib.pyplot as plt
import os

# --- 1. Variance-Weighted OFUL (VW-OFUL) Algorithm Class ---
class OFUL:
    """
    Variance-Weighted OFUL (VW-OFUL) agent for linear bandits.

    Every observation is weighted by the inverse of its noise variance
    (1 / sigma^2) in a regularised least-squares estimate of the unknown
    parameter vector. Arms are chosen optimistically by maximising an upper
    confidence bound built from that estimate.

    Attributes:
        n_features: Dimension of the context / parameter vectors.
        lambda_reg: Ridge regularisation strength (lambda).
        delta: Confidence parameter used in the exploration radius.
        T: Horizon passed in by the caller; stored but not read by any method
            of this class.
        A_t: Weighted design matrix, lambda*I + sum_s (1/sigma_s^2) x_s x_s^T.
        A_t_inv: Inverse of ``A_t``, maintained incrementally.
        b_t: Weighted reward-context vector, sum_s (1/sigma_s^2) r_s x_s.
        theta_hat: Current estimate, ``A_t_inv @ b_t``.
    """

    def __init__(self, n_features, lambda_reg, delta, T):
        """
        Initialise the agent with an empty history.

        Args:
            n_features: Dimension of the context vectors.
            lambda_reg: Ridge regularisation strength; the design matrix
                starts at ``lambda_reg * I``.
            delta: Confidence parameter used by ``_get_beta``.
            T: Total number of rounds (stored only).
        """
        self.n_features = n_features
        self.lambda_reg = lambda_reg
        self.delta = delta
        self.T = T

        # A_t is the weighted covariance matrix (lambda*I + sum_{s=1}^{t-1} (1/sigma_s^2) * x_s * x_s^T)
        self.A_t = lambda_reg * np.identity(n_features)
        self.A_t_inv = np.linalg.inv(self.A_t)
        # b_t is the weighted reward-context vector
        self.b_t = np.zeros(n_features)

        # With b_t = 0 the initial estimate is the zero vector.
        self.theta_hat = np.dot(self.A_t_inv, self.b_t)

    def _get_beta(self, t):
        """
        Return the exploration radius (beta) for round ``t``.

        Computes sqrt(lambda) * sqrt(2*log(1/delta) + d*log(1 + t/lambda)).
        The observation weights are not taken into account here.

        NOTE(review): textbook OFUL radii usually *add* a sqrt(lambda) * S
        term instead of multiplying by sqrt(lambda); the two coincide up to a
        constant only for lambda = 1. Confirm the product form is intended.
        """
        term1 = np.sqrt(self.lambda_reg)
        term2 = np.sqrt(2 * np.log(1 / self.delta) + self.n_features * np.log(1 + t / self.lambda_reg))
        return term1 * term2

    def select_action(self, contexts, t):
        """
        Select the arm with the largest upper confidence bound.

        Args:
            contexts: 2-D array with one row (feature vector) per arm of the
                current decision set.
            t: Current round index, used for the exploration radius.

        Returns:
            Row index of the selected arm. Ties go to the lowest index.
            ``-1`` is returned when no arm has a UCB greater than ``-inf``
            (e.g. ``contexts`` has zero rows).
        """
        beta_t = self._get_beta(t)
        best_ucb = -np.inf
        selected_arm_index = -1

        for a, x_a in enumerate(contexts):
            mean_reward_estimate = np.dot(x_a, self.theta_hat)
            # Width of the confidence interval along x_a: ||x_a||_{A_t^{-1}}.
            uncertainty_term = np.sqrt(np.dot(x_a, np.dot(self.A_t_inv, x_a)))
            ucb_a = mean_reward_estimate + beta_t * uncertainty_term

            # Strict comparison keeps the first arm among equal UCBs.
            if ucb_a > best_ucb:
                best_ucb = ucb_a
                selected_arm_index = a

        return selected_arm_index

    def update(self, selected_context, observed_reward, current_noise_std):
        """
        Update the model with one observation using weighted least squares.

        The observation receives weight w_t = 1 / sigma_t^2.

        Args:
            selected_context: 1-D feature vector of the arm that was played.
            observed_reward: Reward observed for that arm.
            current_noise_std: Standard deviation sigma_t of the reward noise.

        Raises:
            ValueError: If ``current_noise_std ** 2`` is not strictly positive
                (zero, underflowing to zero, or NaN). The check runs before any
                state is modified, so the agent stays usable afterwards.
        """
        variance = current_noise_std ** 2
        # A zero or NaN variance would yield an infinite/NaN weight and
        # permanently corrupt A_t, A_t_inv and theta_hat, so reject it up front.
        if not variance > 0:
            raise ValueError(
                f"current_noise_std must have a strictly positive square, got {current_noise_std!r}"
            )
        weight = 1.0 / variance

        x_t = selected_context.reshape(-1, 1)

        # --- 1. Update A_t (Weighted Covariance Matrix) ---
        # A_t = A_{t-1} + w_t * x_t * x_t^T
        self.A_t += weight * np.dot(x_t, x_t.T)

        # --- 2. Update A_t_inv (Sherman-Morrison rank-one update) ---
        # (A + w x x^T)^{-1} = A^{-1} - w A^{-1} x x^T A^{-1} / (1 + w x^T A^{-1} x)
        A_inv_prev = self.A_t_inv
        numerator = weight * np.dot(np.dot(A_inv_prev, x_t), np.dot(x_t.T, A_inv_prev))
        denominator = 1 + weight * np.dot(x_t.T, np.dot(A_inv_prev, x_t))
        self.A_t_inv = A_inv_prev - numerator / denominator

        # Re-symmetrise to stop round-off from accumulating asymmetry.
        self.A_t_inv = (self.A_t_inv + self.A_t_inv.T) / 2

        # --- 3. Update b_t (Weighted Reward-Context Vector) ---
        # b_t = b_{t-1} + w_t * r_t * x_t
        self.b_t += weight * observed_reward * x_t.flatten()

        # --- 4. Update Parameter Estimate ---
        self.theta_hat = np.dot(self.A_t_inv, self.b_t)


# --- 2. Simulation Environment and Setup ---

def linear_reward_function(context, theta_star, noise_std):
    """
    Sample a reward from the linear model <context, theta_star> plus Gaussian noise.

    Args:
        context: Feature vector of the arm.
        theta_star: True parameter vector.
        noise_std: Standard deviation of the zero-mean Gaussian noise.

    Returns:
        Tuple ``(reward, mean_reward)``: the noisy sample and its noise-free mean.
    """
    expected = np.dot(context, theta_star)
    # One draw from the global NumPy RNG per call, also when noise_std is 0.
    gaussian_noise = np.random.normal(0, noise_std)
    return expected + gaussian_noise, expected


def setup_simulation_parameters(d_total, seed=99):
    """
    Build the two orthogonal true parameter vectors (theta_star).

    ``theta_1`` is non-zero only in the first ``d_total // 2`` coordinates and
    ``theta_2`` only in the remaining ones, so their supports are disjoint and
    their inner product is zero. Non-zero entries are uniform draws from [0, 1).

    Args:
        d_total: Total feature dimension; both returned vectors have this length.
        seed: Seed for the private generator used to draw the entries. The same
            seed always yields the same pair of vectors; ``None`` draws fresh
            entropy on every call.

    Returns:
        Tuple ``(theta_1, theta_2)`` of 1-D arrays of length ``d_total``.
    """
    # A private generator honours `seed` and leaves np.random's global state
    # untouched, so calling this never disturbs seeding done by the caller.
    rng = np.random.RandomState(seed)

    d_first = d_total // 2
    # The second block takes whatever is left so odd d_total still sums to d_total.
    d_second = d_total - d_first

    # Theta 1: non-zero only in the first block of dimensions
    theta_1 = np.concatenate([rng.rand(d_first), np.zeros(d_second)])

    # Theta 2: non-zero only in the second block of dimensions
    theta_2 = np.concatenate([np.zeros(d_first), rng.rand(d_second)])

    return theta_1, theta_2


def run_vw_oful_simulation_piecewise(n_arms, d_total, T_total, theta_1, theta_2, noise_std_1, noise_std_2, T_switch,
                                     lambda_reg=1.0, delta=0.1):
    """
    Run one Variance-Weighted OFUL episode on a two-phase environment.

    Rounds 1..T_switch use ``theta_1``, ``noise_std_1`` and contexts restricted
    to the first half of the coordinates; later rounds use ``theta_2``,
    ``noise_std_2`` and contexts restricted to the second half.

    Args:
        n_arms: Number of arms offered each round.
        d_total: Total feature dimension.
        T_total: Number of rounds to simulate.
        theta_1: True parameter vector of the first phase.
        theta_2: True parameter vector of the second phase.
        noise_std_1: Reward-noise standard deviation in the first phase.
        noise_std_2: Reward-noise standard deviation in the second phase.
        T_switch: Last round (inclusive) of the first phase.
        lambda_reg: Regularisation strength handed to the agent.
        delta: Confidence parameter handed to the agent.

    Returns:
        Tuple ``(cumulative_regret, instantaneous_regrets)``: the summed regret
        and the list of per-round regrets.
    """
    agent = OFUL(d_total, lambda_reg, delta, T_total)

    # 0/1 masks that zero out the half of the coordinates not used in a phase.
    half = d_total // 2
    first_half_mask = np.concatenate([np.ones(half), np.zeros(half)])
    second_half_mask = np.concatenate([np.zeros(half), np.ones(half)])

    regret_history = []
    total_regret = 0

    for t in range(1, T_total + 1):
        raw_contexts = np.random.rand(n_arms, d_total)

        # Pick the phase-specific parameter vector, noise level and mask.
        if t <= T_switch:
            theta, sigma, mask = theta_1, noise_std_1, first_half_mask
        else:
            theta, sigma, mask = theta_2, noise_std_2, second_half_mask
        contexts = raw_contexts * mask

        # Oracle: best noise-free mean reward over the offered arms, obtained
        # through linear_reward_function with zero noise, one call per arm.
        best_mean = -np.inf
        for arm_features in contexts:
            _, noiseless_mean = linear_reward_function(arm_features, theta, 0)
            if noiseless_mean > best_mean:
                best_mean = noiseless_mean

        # Agent picks an arm and receives a noisy reward for it.
        chosen = agent.select_action(contexts, t)
        played_context = contexts[chosen]
        reward, played_mean = linear_reward_function(played_context, theta, sigma)

        # Regret is measured on noise-free means.
        round_regret = best_mean - played_mean
        regret_history.append(round_regret)
        total_regret += round_regret

        # The agent is told the current noise level so it can weight the sample.
        agent.update(played_context, reward, sigma)

    return total_regret, regret_history


# --- 3. Run and Plot Simulation ---
if __name__ == '__main__':
    # Seed the global RNG so the whole experiment is reproducible.
    GLOBAL_SEED = 42
    np.random.seed(GLOBAL_SEED)

    N_ARMS = 32  # Fixed size of the decision set
    D_TOTAL = 10  # Total feature dimension
    T_TOTAL = 4000  # Total number of rounds
    T_SWITCH = 2000  # Round where the environment changes

    # Environment Parameters
    NOISE_STD_1 = 0.1  # Low noise for the first phase (Weight = 1/(0.1^2) = 100)
    NOISE_STD_2 = 1.0  # High noise for the second phase (Weight = 1/(1.0^2) = 1)

    # OFUL Hyperparameters
    LAMBDA_REG = 1.0
    DELTA = 0.05
    N_SIMULATIONS = 100

    # Output locations; INSTANCE_NAME prefixes every file written below.
    FOLDER = 'simulation_results'
    INSTANCE_NAME = 'instance3'
    REPORT_EVERY = 500  # Spacing (in rounds) of the rows in the text summary

    # Setup the two fixed orthogonal parameter vectors
    THETA_1, THETA_2 = setup_simulation_parameters(D_TOTAL)

    print("--- Variance-Weighted OFUL (VW-OFUL) Simulation Setup ---")
    print(f"Total Rounds (T): {T_TOTAL}, Dimension (d): {D_TOTAL}")
    print("Agent uses Inverse Variance Weighting (VW-OLS).")
    print(f"Switch at T={T_SWITCH}. Noise: {NOISE_STD_1} (High Weight) -> {NOISE_STD_2} (Low Weight)")

    # Run multiple simulations, each with its own deterministic seed.
    all_regrets = []
    for i in range(N_SIMULATIONS):
        np.random.seed(GLOBAL_SEED + i)
        _, regrets = run_vw_oful_simulation_piecewise(
            N_ARMS, D_TOTAL, T_TOTAL,
            THETA_1, THETA_2,
            NOISE_STD_1, NOISE_STD_2,
            T_SWITCH, LAMBDA_REG, DELTA
        )
        all_regrets.append(regrets)

    # Average the per-round regret across runs, then accumulate over time.
    avg_regrets = np.mean(all_regrets, axis=0)
    avg_cumulative_regret = np.cumsum(avg_regrets)

    print(f"\nAverage Cumulative Regret after {T_TOTAL} steps: {avg_cumulative_regret[-1]:.2f}")

    # Plotting the Results
    plt.figure(figsize=(10, 6))
    plt.plot(np.arange(1, T_TOTAL + 1), avg_cumulative_regret,
             label='OFUL Average Cumulative Regret')

    # Highlight the switch point; the label is derived from T_SWITCH so it
    # cannot drift out of sync with the configured change round.
    plt.axvline(x=T_SWITCH, color='r', linestyle='--', label=f'Environment Switch (t={T_SWITCH})')
    plt.text(T_SWITCH - 50, avg_cumulative_regret[-1] * 0.9,
             'Change Point', color='red', rotation=90, verticalalignment='center')

    plt.title('OFUL: Piecewise-Stationary Linear Bandit (Change in $\\theta^*$ and $\\sigma$)')
    plt.xlabel('Time Steps (t)')
    plt.ylabel('Cumulative Regret')
    plt.grid(True, linestyle='--')
    plt.legend()

    # Create the output folder once; exist_ok avoids the check-then-create race.
    os.makedirs(FOLDER, exist_ok=True)

    # Save Plot to Image File
    plot_filename = os.path.join(FOLDER, f"{INSTANCE_NAME}_cumulative_regret_plot.png")
    plt.savefig(plot_filename)
    print(f"✅ Plot saved successfully to: {plot_filename}")

    # Build a Markdown table of the averaged cumulative regret.
    summary_text = []
    summary_text.append(f"## 📊 Cumulative Regret Summary (Every {REPORT_EVERY} Rounds)")
    summary_text.append("| Round (t) | Avg. Cumulative Regret | Phase Change |")
    summary_text.append("| :---: | :---: | :---: |")

    report_rounds = np.arange(REPORT_EVERY, T_TOTAL + 1, REPORT_EVERY)

    for t in report_rounds:
        # Array indices are 0-based, so for round t, the index is t-1
        regret_value = avg_cumulative_regret[t - 1]

        # Determine phase status
        if t == T_SWITCH:
            phase_status = "Switch"
        elif t > T_SWITCH:
            phase_status = "High Noise"
        else:
            phase_status = "Low Noise"

        summary_text.append(f"| {t:,} | {regret_value:.4f} | {phase_status} |")

    summary_text.append("-" * 50)

    # Save to TXT file. The summary contains non-ASCII characters, so the
    # encoding is pinned instead of relying on the platform default.
    txt_filename = os.path.join(FOLDER, f"{INSTANCE_NAME}_regret_summary.txt")
    with open(txt_filename, 'w', encoding='utf-8') as f:
        f.write("\n".join(summary_text))

    print(f"✅ Regret summary saved successfully to: {txt_filename}")