import numpy as np
import matplotlib.pyplot as plt
import os

# --- 1. OFUL Algorithm Class (Core Logic Unchanged) ---
# This class remains the same as its job is simply to take the observed data
# and update its single, best estimate (theta_hat) for the current environment.
class OFUL:
    """Linear bandit agent that selects arms by an upper confidence bound.

    The agent keeps a ridge-regression estimate ``theta_hat = A_t^{-1} b_t``
    and scores every arm ``x`` as
    ``x . theta_hat + beta_t * sqrt(x^T A_t^{-1} x)``.

    State is held in plain NumPy arrays:
    ``A_t`` (regularised Gram matrix), ``A_t_inv`` (its inverse, maintained
    incrementally), ``b_t`` (reward-weighted sum of the played contexts) and
    ``theta_hat`` (current parameter estimate).
    """

    def __init__(self, n_features, lambda_reg, delta, T):
        """Create an agent with no observations.

        Args:
            n_features: Dimension of the context vectors.
            lambda_reg: Ridge regularisation strength; ``A_t`` starts as
                ``lambda_reg * I``.
            delta: Confidence parameter used by the exploration radius.
            T: Horizon. Stored on the instance but not read by this class.
        """
        self.n_features = n_features
        self.lambda_reg = lambda_reg
        self.delta = delta
        self.T = T

        # Regularised Gram matrix and its inverse.
        self.A_t = lambda_reg * np.identity(n_features)
        self.A_t_inv = np.linalg.inv(self.A_t)
        # Reward-weighted sum of the contexts that were played.
        self.b_t = np.zeros(n_features)

        self.theta_hat = np.dot(self.A_t_inv, self.b_t)

    def _get_beta(self, t):
        """Return the exploration radius used in round ``t``.

        Computes ``sqrt(lambda) * sqrt(2*log(1/delta) + d*log(1 + t/lambda))``.

        NOTE(review): published OFUL radii add a ``sqrt(lambda) * S`` term
        instead of multiplying by ``sqrt(lambda)``; the two coincide in shape
        only for ``lambda_reg == 1``. Confirm the multiplicative form is
        intended before changing it.
        """
        term1 = np.sqrt(self.lambda_reg)
        term2 = np.sqrt(2 * np.log(1 / self.delta) + self.n_features * np.log(1 + t / self.lambda_reg))
        return term1 * term2

    def select_action(self, contexts, t):
        """Return the row index of the arm with the largest UCB score.

        Args:
            contexts: 2-D array with one context vector per row.
            t: Current round, forwarded to ``_get_beta``.

        Returns:
            Index of the first arm attaining the highest score, or ``-1`` when
            ``contexts`` has no rows.
        """
        beta_t = self._get_beta(t)
        best_ucb = -np.inf
        selected_arm_index = -1

        for a in range(contexts.shape[0]):
            x_a = contexts[a]
            mean_reward_estimate = np.dot(x_a, self.theta_hat)
            # A_t_inv is maintained by repeated rank-one updates, so round-off
            # can push this quadratic form marginally below zero. Clamp it so
            # sqrt() cannot produce NaN, which would make the arm lose every
            # comparison and could leave the -1 sentinel as the answer.
            squared_width = max(np.dot(x_a, np.dot(self.A_t_inv, x_a)), 0.0)
            uncertainty_term = np.sqrt(squared_width)
            ucb_a = mean_reward_estimate + beta_t * uncertainty_term

            if ucb_a > best_ucb:
                best_ucb = ucb_a
                selected_arm_index = a

        return selected_arm_index

    def update(self, selected_context, observed_reward):
        """Fold one (context, reward) observation into the model.

        ``A_t_inv`` is refreshed with the Sherman-Morrison rank-one formula
        instead of a full matrix inversion.

        Args:
            selected_context: 1-D array, the context of the arm that was played.
            observed_reward: Scalar reward observed for that arm.
        """
        x_t = selected_context.reshape(-1, 1)

        # 1. Gram matrix: A_t <- A_t + x x^T
        self.A_t += np.dot(x_t, x_t.T)

        # 2. Inverse: A^{-1} <- A^{-1} - (A^{-1} x x^T A^{-1}) / (1 + x^T A^{-1} x)
        A_inv_prev = self.A_t_inv
        numerator = np.dot(np.dot(A_inv_prev, x_t), np.dot(x_t.T, A_inv_prev))
        denominator = 1 + np.dot(x_t.T, np.dot(A_inv_prev, x_t))  # 1x1 array, broadcasts below
        self.A_t_inv = A_inv_prev - numerator / denominator
        # Re-symmetrise to stop round-off asymmetry from accumulating.
        self.A_t_inv = (self.A_t_inv + self.A_t_inv.T) / 2

        # 3. Reward-weighted context sum
        self.b_t += observed_reward * x_t.flatten()

        # 4. Ridge estimate
        self.theta_hat = np.dot(self.A_t_inv, self.b_t)


# --- 2. Simulation Environment and Setup ---

def linear_reward_function(context, theta_star, noise_std):
    """Sample a noisy linear reward for a single context.

    Returns a ``(reward, mean_reward)`` pair: ``mean_reward`` is the inner
    product of ``context`` and ``theta_star``; ``reward`` is that value plus
    one zero-mean Gaussian draw with standard deviation ``noise_std`` taken
    from the global ``np.random`` state.
    """
    expected = np.dot(context, theta_star)
    noise = np.random.normal(0, noise_std)
    return expected + noise, expected


def setup_simulation_parameters(d_total, seed=99):
    """Draw two true parameter vectors with disjoint supports.

    ``theta_1`` is non-zero only on the first ``d_total // 2`` coordinates and
    ``theta_2`` only on the remaining ones, so their inner product is zero.
    Both vectors always have length ``d_total``; when ``d_total`` is odd the
    extra coordinate belongs to ``theta_2``.

    The non-zero entries are uniform on [0, 1) and come from the global
    ``np.random`` state, so seed that state beforehand for reproducible
    vectors. This function no longer reseeds the global generator on exit,
    which used to discard whatever seed the caller had set.

    Args:
        d_total: Total feature dimension.
        seed: Accepted for backward compatibility; currently not applied.

    Returns:
        Tuple ``(theta_1, theta_2)`` of 1-D arrays of length ``d_total``.
    """
    d_first = d_total // 2
    d_second = d_total - d_first

    # Draw order (theta_1 first, then theta_2) is part of the reproducible stream.
    theta_1 = np.concatenate([np.random.rand(d_first), np.zeros(d_second)])
    theta_2 = np.concatenate([np.zeros(d_first), np.random.rand(d_second)])

    return theta_1, theta_2


def run_oful_simulation_rapid_switch(n_arms, d_total, T_total, theta_1, theta_2, noise_std_1, noise_std_2, T_switch,
                                     lambda_reg=1.0, delta=0.1):
    """Run one OFUL episode in which the active environment alternates every round.

    Odd rounds use ``theta_1`` and zero out the trailing feature block of every
    context; even rounds use ``theta_2`` and zero out the leading block. The
    reward noise is ``noise_std_1`` for rounds ``t <= T_switch`` and
    ``noise_std_2`` afterwards. All randomness comes from the global
    ``np.random`` state.

    Args:
        n_arms: Number of arms (context rows) offered each round.
        d_total: Feature dimension; the leading block has ``d_total // 2``
            coordinates and the trailing block has the rest.
        T_total: Number of rounds to play.
        theta_1: True parameter vector for odd rounds.
        theta_2: True parameter vector for even rounds.
        noise_std_1: Reward noise standard deviation up to round ``T_switch``.
        noise_std_2: Reward noise standard deviation after round ``T_switch``.
        T_switch: Last round (1-based) that uses ``noise_std_1``.
        lambda_reg: Ridge regularisation passed to the agent.
        delta: Confidence parameter passed to the agent.

    Returns:
        Tuple ``(cumulative_regret, instantaneous_regrets)`` where the second
        item is a list with one pseudo-regret value per round.
    """
    oful_agent = OFUL(d_total, lambda_reg, delta, T_total)
    d_instance = d_total // 2

    # The two feature masks never change, so build them once instead of
    # concatenating fresh arrays on every round. They are sized to d_total so
    # an odd dimension still broadcasts against the (n_arms, d_total) contexts.
    first_block_mask = np.zeros(d_total)
    first_block_mask[:d_instance] = 1.0
    second_block_mask = 1.0 - first_block_mask

    cumulative_regret = 0
    instantaneous_regrets = []

    for t in range(1, T_total + 1):
        # Noise level depends on which side of the switch point we are on.
        current_noise_std = noise_std_1 if t <= T_switch else noise_std_2

        # Fresh full-dimensional contexts, then keep only the active block.
        base_contexts = np.random.rand(n_arms, d_total)
        if t % 2 != 0:
            # Odd round: instance 1, only the leading features are visible.
            current_theta = theta_1
            contexts = base_contexts * first_block_mask
        else:
            # Even round: instance 2, only the trailing features are visible.
            current_theta = theta_2
            contexts = base_contexts * second_block_mask

        # Oracle: best achievable mean reward among this round's masked contexts.
        # NOTE: linear_reward_function still draws from np.random when the
        # noise level is 0, so this loop is kept as-is to leave the random
        # stream (and therefore seeded results) unchanged.
        optimal_mean_reward = -np.inf
        for a in range(n_arms):
            _, mean_reward = linear_reward_function(contexts[a], current_theta, 0)
            if mean_reward > optimal_mean_reward:
                optimal_mean_reward = mean_reward

        # The agent keeps a single d_total-dimensional estimate across both instances.
        a_index = oful_agent.select_action(contexts, t)
        x_t = contexts[a_index]

        # Noisy reward for the chosen arm, plus its noise-free mean for the regret.
        r_t, mean_reward_t = linear_reward_function(x_t, current_theta, current_noise_std)

        regret_t = optimal_mean_reward - mean_reward_t
        instantaneous_regrets.append(regret_t)
        cumulative_regret += regret_t

        # The agent learns from the masked context it actually played.
        oful_agent.update(x_t, r_t)

    return cumulative_regret, instantaneous_regrets


# --- 3. Run and Plot Simulation ---
if __name__ == '__main__':
    # Seed the global NumPy RNG so the true parameter vectors drawn below are reproducible.
    GLOBAL_SEED = 42
    np.random.seed(GLOBAL_SEED)

    N_ARMS = 32  # Fixed size of the decision set
    D_TOTAL = 10  # Total feature dimension
    T_TOTAL = 4000  # Total number of rounds
    T_SWITCH = 2000  # Last round that uses the low-noise setting

    # Environment Parameters
    NOISE_STD_1 = 0.1  # Low noise for rounds t <= T_SWITCH
    NOISE_STD_2 = 1.0  # High noise for rounds t > T_SWITCH

    # OFUL Hyperparameters
    LAMBDA_REG = 1
    DELTA = 0.05
    N_SIMULATIONS = 100

    # Output locations; the instance name prefixes every file written below.
    INSTANCE_NAME = 'instance2'
    FOLDER = 'simulation_results'

    # Setup the true orthogonal parameter vectors (Theta_1 and Theta_2)
    THETA_1, THETA_2 = setup_simulation_parameters(D_TOTAL)

    print(f"Running OFUL with Rapid Switches (T={T_TOTAL}, d={D_TOTAL}).")
    print("Theta switches every round (Odd: T1, Even: T2).")
    print(f"Noise switches from {NOISE_STD_1} to {NOISE_STD_2} at T={T_SWITCH}.")

    # Run multiple simulations, each with its own deterministic seed.
    all_regrets = []
    for i in range(N_SIMULATIONS):
        np.random.seed(GLOBAL_SEED + i)
        _, regrets = run_oful_simulation_rapid_switch(
            N_ARMS, D_TOTAL, T_TOTAL,
            THETA_1, THETA_2,
            NOISE_STD_1, NOISE_STD_2,
            T_SWITCH, LAMBDA_REG, DELTA
        )
        all_regrets.append(regrets)

    # Average the per-round regret over simulations, then accumulate over time.
    avg_regrets = np.mean(all_regrets, axis=0)
    avg_cumulative_regret = np.cumsum(avg_regrets)

    print(f"\nAverage Cumulative Regret after {T_TOTAL} steps: {avg_cumulative_regret[-1]:.2f}")

    # Plotting the Results
    plt.figure(figsize=(10, 6))
    plt.plot(np.arange(1, T_TOTAL + 1), avg_cumulative_regret,
             label='OFUL Average Cumulative Regret')

    # Highlight the switch point; the label is derived from T_SWITCH so it
    # cannot drift from the configured value.
    plt.axvline(x=T_SWITCH, color='r', linestyle='--', label=f'Environment Switch (t={T_SWITCH})')
    plt.text(T_SWITCH - 50, avg_cumulative_regret[-1] * 0.9,
             'Change Point', color='red', rotation=90, verticalalignment='center')

    plt.title('OFUL: Piecewise-Stationary Linear Bandit (Change in $\\theta^*$ and $\\sigma$)')
    plt.xlabel('Time Steps (t)')
    plt.ylabel('Cumulative Regret')
    plt.grid(True, linestyle='--')
    plt.legend()

    # Make sure the output folder exists before anything is written to it.
    os.makedirs(FOLDER, exist_ok=True)

    # Save the plot to an image file and report the actual path.
    plot_filename = os.path.join(FOLDER, f"{INSTANCE_NAME}_cumulative_regret_plot.png")
    plt.savefig(plot_filename)
    print(f"✅ Plot saved successfully to: {plot_filename}")

    # Display the plot (optional, but standard for interactive scripting).
    plt.show()

    # Build a Markdown-style table of the average cumulative regret every 500 rounds.
    summary_text = []
    summary_text.append("## 📊 Cumulative Regret Summary (Every 500 Rounds)")
    summary_text.append("| Round (t) | Avg. Cumulative Regret | Phase Change |")
    summary_text.append("| :---: | :---: | :---: |")

    report_rounds = np.arange(500, T_TOTAL + 1, 500)

    for t in report_rounds:
        # Array indices are 0-based, so for round t, the index is t-1
        index = t - 1
        regret_value = avg_cumulative_regret[index]

        # Determine phase status
        phase_status = "Switch" if t == T_SWITCH else ("High Noise" if t > T_SWITCH else "Low Noise")

        summary_text.append(f"| {t:,} | {regret_value:.4f} | {phase_status} |")

    summary_text.append("-" * 50)

    # Save to TXT file. The header contains a non-ASCII character, so the
    # encoding is fixed to UTF-8 instead of relying on the platform default.
    txt_filename = os.path.join(FOLDER, f"{INSTANCE_NAME}_regret_summary.txt")
    with open(txt_filename, 'w', encoding='utf-8') as f:
        f.write("\n".join(summary_text))

    print(f"✅ Regret summary saved successfully to: {txt_filename}")