import numpy as np
import matplotlib.pyplot as plt
import os

# --- 1. OFUL Algorithm Class ---
class OFUL:
    """
    Implements OFUL (Optimism in the Face of Uncertainty Linear bandit) for
    linear contextual bandits.

    The agent keeps a ridge-regression estimate ``theta_hat = A_t^{-1} b_t`` and
    plays the arm maximising the upper confidence bound
    ``x^T theta_hat + beta_t * sqrt(x^T A_t^{-1} x)``. The inverse of ``A_t`` is
    maintained with the Sherman-Morrison formula, so an update costs O(d^2).

    Args:
        n_features: Dimension d of the context vectors.
        lambda_reg: Ridge regularisation strength; must be strictly positive.
        delta: Confidence parameter, used through log(1 / delta); must lie in
            (0, 1].
        T: Horizon. Only stored; no method of this class reads it.

    Raises:
        ValueError: If ``lambda_reg`` or ``delta`` is outside its valid range.
    """

    def __init__(self, n_features, lambda_reg, delta, T):
        if lambda_reg <= 0:
            raise ValueError("lambda_reg must be strictly positive")
        if not 0 < delta <= 1:
            raise ValueError("delta must lie in (0, 1]")

        self.n_features = n_features
        self.lambda_reg = lambda_reg
        self.delta = delta
        self.T = T

        # A_t = lambda * I + sum_{s=1}^{t-1} x_s * x_s^T (regularised design matrix)
        self.A_t = lambda_reg * np.identity(n_features)
        # The inverse of lambda * I is known in closed form.
        self.A_t_inv = np.identity(n_features) / lambda_reg
        # b_t = sum_{s=1}^{t-1} r_s * x_s (reward-weighted sum of played contexts)
        self.b_t = np.zeros(n_features)

        # With b_t = 0 the ridge estimate A_t^{-1} b_t is the zero vector.
        self.theta_hat = np.zeros(n_features)

    def _get_beta(self, t):
        """Return the multiplier beta_t of the exploration bonus in round ``t``.

        Computes sqrt(lambda) * sqrt(2 * log(1 / delta) + d * log(1 + t / lambda)).
        """
        # NOTE(review): the confidence radius in Abbasi-Yadkori et al. (2011) is
        # additive, R * sqrt(...) + sqrt(lambda) * S, whereas this is a product.
        # Kept as is so existing results do not change -- confirm the intent.
        term1 = np.sqrt(self.lambda_reg)
        term2 = np.sqrt(2 * np.log(1 / self.delta) + self.n_features * np.log(1 + t / self.lambda_reg))
        return term1 * term2

    def select_action(self, contexts, t):
        """Select the arm with the highest upper confidence bound.

        Args:
            contexts: Array of shape (n_arms, n_features); one row per arm of the
                decision set offered in this round.
            t: Current round, used only to compute the confidence multiplier.

        Returns:
            Index (int) of the chosen row of ``contexts``. Ties go to the lowest
            index.

        Raises:
            ValueError: If the decision set is empty.
        """
        contexts = np.asarray(contexts, dtype=float)
        if contexts.shape[0] == 0:
            raise ValueError("contexts must contain at least one arm")

        beta_t = self._get_beta(t)

        # Mean reward estimates x_a^T * theta_hat for all arms at once.
        mean_estimates = np.dot(contexts, self.theta_hat)

        # Quadratic forms x_a^T * A_inv * x_a, one per row.
        quad_forms = (np.dot(contexts, self.A_t_inv) * contexts).sum(axis=1)
        # Round-off can push a tiny quadratic form below zero; clamp it so the
        # square root never turns an arm's UCB into NaN.
        widths = np.sqrt(np.maximum(quad_forms, 0.0))

        ucb = mean_estimates + beta_t * widths
        return int(np.argmax(ucb))

    def update(self, selected_context, observed_reward):
        """
        Incorporate one (context, reward) observation into the model.

        Args:
            selected_context: Context vector of the arm that was played
                (length n_features).
            observed_reward: Scalar reward observed for that arm.
        """
        x_t = selected_context.reshape(-1, 1)  # Column vector (d x 1)

        # --- 1. Update A_t with the rank-one term x_t x_t^T ---
        self.A_t += np.dot(x_t, x_t.T)

        # --- 2. Update A_t_inv with the Sherman-Morrison formula ---
        # (A + x x^T)^{-1} = A^{-1} - (A^{-1} x x^T A^{-1}) / (1 + x^T A^{-1} x)
        A_inv_prev = self.A_t_inv
        numerator = np.dot(np.dot(A_inv_prev, x_t), np.dot(x_t.T, A_inv_prev))
        denominator = 1 + np.dot(x_t.T, np.dot(A_inv_prev, x_t))
        self.A_t_inv = A_inv_prev - numerator / denominator

        # Re-symmetrise so round-off asymmetry cannot accumulate over rounds.
        self.A_t_inv = (self.A_t_inv + self.A_t_inv.T) / 2

        # --- 3. Update b_t (reward-weighted context sum) ---
        self.b_t += observed_reward * x_t.flatten()

        # --- 4. Refresh the ridge estimate ---
        self.theta_hat = np.dot(self.A_t_inv, self.b_t)


# --- 2. Simulation Environment and Setup ---

def linear_reward_function(context, theta_star, noise_std):
    """Draw one noisy reward from the linear model ``context . theta_star``.

    Args:
        context: Feature vector of the arm being evaluated.
        theta_star: True parameter vector of the environment.
        noise_std: Standard deviation of the additive zero-mean Gaussian noise.

    Returns:
        Tuple ``(reward, mean_reward)``: the noisy observation and its
        noise-free expectation.
    """
    expected = np.dot(context, theta_star)
    # The noise is drawn from NumPy's global RNG (also when noise_std is 0).
    noise = np.random.normal(0, noise_std)
    return expected + noise, expected


def setup_simulation_parameters(d_total, seed=99):
    """
    Builds the two orthogonal true parameter vectors (theta_star), one per
    phase of the environment.

    The entries are drawn uniformly from [0, 1) with a private random
    generator, so the result depends only on ``seed`` and NumPy's global RNG
    state is left untouched.

    Args:
        d_total: Total feature dimension. The first ``d_total // 2``
            coordinates belong to theta_1 and the remaining ones to theta_2, so
            both vectors have length ``d_total`` even when it is odd.
        seed: Seed for the private generator. ``None`` seeds it from OS
            entropy, which gives different vectors on every call.

    Returns:
        Tuple ``(theta_1, theta_2)`` of 1-D arrays of length ``d_total`` with
        disjoint supports (hence a zero inner product).
    """
    # RandomState(seed) yields the same stream as np.random.seed(seed) followed
    # by np.random.rand, but without reseeding the caller's global generator.
    rng = np.random.RandomState(seed)

    d_first = d_total // 2
    d_second = d_total - d_first

    # Theta 1: non-zero only in the first d_first dimensions
    theta_1 = np.concatenate([rng.rand(d_first), np.zeros(d_second)])

    # Theta 2: non-zero only in the last d_second dimensions
    theta_2 = np.concatenate([np.zeros(d_first), rng.rand(d_second)])

    return theta_1, theta_2


def run_oful_simulation_piecewise(n_arms, d_total, T_total, theta_1, theta_2, noise_std_1, noise_std_2, T_switch,
                                  lambda_reg=1.0, delta=0.1):
    """
    Runs one OFUL simulation on a two-phase (piecewise-stationary) environment.

    Rounds 1..T_switch use ``theta_1`` / ``noise_std_1`` and contexts whose last
    ``d_total - d_total // 2`` coordinates are zeroed; later rounds use
    ``theta_2`` / ``noise_std_2`` and contexts whose first ``d_total // 2``
    coordinates are zeroed. A single agent is used across both phases.

    Args:
        n_arms: Number of arms offered in every round.
        d_total: Total feature dimension.
        T_total: Number of rounds to simulate.
        theta_1: True parameter vector of phase 1 (length ``d_total``).
        theta_2: True parameter vector of phase 2 (length ``d_total``).
        noise_std_1: Reward-noise standard deviation in phase 1.
        noise_std_2: Reward-noise standard deviation in phase 2.
        T_switch: Last round of phase 1.
        lambda_reg: Ridge regularisation passed to the agent.
        delta: Confidence parameter passed to the agent.

    Returns:
        Tuple ``(cumulative_regret, instantaneous_regrets)``: the total
        pseudo-regret and the list of per-round pseudo-regrets.
    """
    oful_agent = OFUL(d_total, lambda_reg, delta, T_total)

    # Complementary 0/1 masks of length d_total. Building the second one as the
    # complement keeps the shapes right when d_total is odd.
    d_first = d_total // 2
    mask_1 = np.zeros(d_total)
    mask_1[:d_first] = 1.0
    mask_2 = 1.0 - mask_1

    cumulative_regret = 0
    instantaneous_regrets = []

    for t in range(1, T_total + 1):

        # 1. Generate base contexts, uniform on [0, 1), shape (n_arms, d_total)
        base_contexts = np.random.rand(n_arms, d_total)

        # --- PHASE SWITCH LOGIC & CONTEXT MASKING ---
        if t <= T_switch:
            current_theta, current_noise_std, mask = theta_1, noise_std_1, mask_1
        else:
            current_theta, current_noise_std, mask = theta_2, noise_std_2, mask_2
        contexts = base_contexts * mask

        # 2. Oracle: noise-free mean reward of every arm in one product. No
        #    random numbers are drawn here.
        mean_rewards = np.dot(contexts, current_theta)
        optimal_mean_reward = mean_rewards.max()

        # 3. Agent selects an arm
        a_index = oful_agent.select_action(contexts, t)
        x_t = contexts[a_index]

        # 4. Observe the noisy reward of the played arm
        r_t, _ = linear_reward_function(x_t, current_theta, current_noise_std)

        # 5. Pseudo-regret, taken from the same vector as the oracle value so
        #    that playing the best arm gives exactly zero.
        regret_t = optimal_mean_reward - mean_rewards[a_index]
        instantaneous_regrets.append(regret_t)
        cumulative_regret += regret_t

        # 6. Agent updates its model with the masked context
        oful_agent.update(x_t, r_t)

    return cumulative_regret, instantaneous_regrets


# --- 3. Run and Plot Simulation ---
if __name__ == '__main__':
    # Seed the global RNG so the whole script is reproducible.
    GLOBAL_SEED = 42
    np.random.seed(GLOBAL_SEED)

    N_ARMS = 32  # Fixed size of the decision set
    D_TOTAL = 10  # Total feature dimension
    T_TOTAL = 4000  # Total number of rounds
    T_SWITCH = 2000  # Last round of phase 1; phase 2 starts at T_SWITCH + 1

    # Environment Parameters
    NOISE_STD_1 = 0.1  # Low noise for the first phase (t=1 to T_SWITCH)
    NOISE_STD_2 = 1.0  # High noise for the second phase (t=T_SWITCH+1 to T_TOTAL)

    # OFUL Hyperparameters
    LAMBDA_REG = 1
    DELTA = 0.05
    N_SIMULATIONS = 100

    # Output settings
    INSTANCE_NAME = 'instance1'
    FOLDER = 'simulation_results'
    REPORT_EVERY = 500  # Spacing (in rounds) of the rows in the text summary

    # Size of the block of coordinates belonging to Theta_1
    D_INSTANCE = D_TOTAL // 2

    # Setup the true orthogonal parameter vectors (Theta_1 and Theta_2)
    THETA_1, THETA_2 = setup_simulation_parameters(D_TOTAL)

    print(f"Running Piecewise-Stationary OFUL Simulation (T={T_TOTAL}, d={D_TOTAL}).")
    print(f"Phase 1 (t=1 to {T_SWITCH}): Theta={THETA_1[:D_INSTANCE].round(2)}... "
          f"(d={D_INSTANCE}), Noise={NOISE_STD_1}")
    print(f"Phase 2 (t={T_SWITCH + 1} to {T_TOTAL}): Theta={THETA_2[D_INSTANCE:].round(2)}... "
          f"(d={D_TOTAL - D_INSTANCE}), Noise={NOISE_STD_2}")

    # Run multiple simulations, each with its own seed derived from GLOBAL_SEED
    all_regrets = []
    for i in range(N_SIMULATIONS):
        np.random.seed(GLOBAL_SEED + i)
        _, regrets = run_oful_simulation_piecewise(
            N_ARMS, D_TOTAL, T_TOTAL,
            THETA_1, THETA_2,
            NOISE_STD_1, NOISE_STD_2,
            T_SWITCH, LAMBDA_REG, DELTA
        )
        all_regrets.append(regrets)

    # Average the per-round regret over runs, then accumulate over time
    avg_regrets = np.mean(all_regrets, axis=0)
    avg_cumulative_regret = np.cumsum(avg_regrets)

    print(f"\nAverage Cumulative Regret after {T_TOTAL} steps: {avg_cumulative_regret[-1]:.2f}")

    # Plotting the Results
    plt.figure(figsize=(10, 6))
    plt.plot(np.arange(1, T_TOTAL + 1), avg_cumulative_regret,
             label='OFUL Average Cumulative Regret')

    # Highlight the switch point; the label follows T_SWITCH instead of a
    # hard-coded round number.
    plt.axvline(x=T_SWITCH, color='r', linestyle='--', label=f'Environment Switch (t={T_SWITCH})')
    plt.text(T_SWITCH - 50, avg_cumulative_regret[-1] * 0.9,
             'Change Point', color='red', rotation=90, verticalalignment='center')

    plt.title('OFUL: Piecewise-Stationary Linear Bandit (Change in $\\theta^*$ and $\\sigma$)')
    plt.xlabel('Time Steps (t)')
    plt.ylabel('Cumulative Regret')
    plt.grid(True, linestyle='--')
    plt.legend()

    # Create the output folder once; exist_ok avoids the check-then-create race
    os.makedirs(FOLDER, exist_ok=True)

    # Save the plot to an image file
    plot_filename = os.path.join(FOLDER, f"{INSTANCE_NAME}_cumulative_regret_plot.png")
    plt.savefig(plot_filename)
    print(f"✅ Plot saved successfully to: {plot_filename}")

    # Build the text summary (a Markdown table)
    summary_text = []
    summary_text.append(f"## 📊 Cumulative Regret Summary (Every {REPORT_EVERY} Rounds)")
    summary_text.append("| Round (t) | Avg. Cumulative Regret | Phase Change |")
    summary_text.append("| :---: | :---: | :---: |")

    for t in range(REPORT_EVERY, T_TOTAL + 1, REPORT_EVERY):
        # Array indices are 0-based, so round t lives at index t - 1
        regret_value = avg_cumulative_regret[t - 1]

        # Determine phase status
        phase_status = "Switch" if t == T_SWITCH else ("High Noise" if t > T_SWITCH else "Low Noise")

        summary_text.append(f"| {t:,} | {regret_value:.4f} | {phase_status} |")

    summary_text.append("-" * 50)

    # Save to TXT file. The header contains an emoji, so the encoding is fixed
    # to UTF-8 rather than left to the platform default.
    txt_filename = os.path.join(FOLDER, f"{INSTANCE_NAME}_regret_summary.txt")
    with open(txt_filename, 'w', encoding='utf-8') as f:
        f.write("\n".join(summary_text))

    print(f"✅ Regret summary saved successfully to: {txt_filename}")

    # Display the plot last: plt.show() may block until the window is closed,
    # and by now every output file has been written.
    plt.show()