from __future__ import annotations

from dataclasses import dataclass
from typing import List, Tuple, Optional

import os
import pickle
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit
from tqdm import trange


import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme(style="whitegrid")
from mpl_toolkits.mplot3d import Axes3D  # needed for get_proj hack

@dataclass
class EnvParams:
    """Parameters of the two-state public-goods environment simulated in this file."""
    # Scale of the probability of leaving state 0; `p_transition` multiplies
    # it by the fraction of defectors (n_d / N).
    p12: float = 0.8
    # Scale of the probability of leaving any other state (state 1 here);
    # `p_transition` multiplies it by the fraction of cooperators (n_c / N).
    p21: float = 0.8
    # Reward handed to every agent by `env_step_once` whenever the step is
    # not a state 0 -> state 0 transition.
    m: float = 0.0
    # Multiplication factor applied to the pooled contributions in `pgg_rewards`.
    r_mult: float = 2.0
    # Cost subtracted from the payoff of each agent that played action 0
    # (counted as a cooperator) in `pgg_rewards`.
    c_cost: float = 1.0




@jit
def softmax(Q: jnp.ndarray, beta: float) -> jnp.ndarray:
    """Return the softmax policy obtained from Q-values scaled by ``beta``.

    Args:
        Q: Array of Q-values; the last axis is normalised.
        beta: Factor multiplying the Q-values before the softmax
            (larger values give a more peaked distribution).

    Returns:
        Array with the same shape as ``Q`` whose last axis sums to one.
    """
    scaled_q = Q * beta
    return jax.nn.softmax(scaled_q, axis=-1)


@jit
def choose_actions_stateful(
    key: jax.Array,
    Q_state: jnp.ndarray,
    beta: float,
) -> jnp.ndarray:
    """Sample one action per row of ``Q_state`` from its softmax distribution.

    Args:
        key: JAX PRNG key consumed by the categorical sampler.
        Q_state: Q-values of the current state, one row per agent ([N, M]).
        beta: Factor multiplying the Q-values to form the sampling logits.

    Returns:
        int32 array of sampled action indices, one per row ([N]).
    """
    # `random.categorical` takes unnormalised log-probabilities, so the
    # scaled Q-values can be passed directly as logits.
    sampled = random.categorical(key, Q_state * beta, axis=-1)  # [N]
    return sampled.astype(jnp.int32)



def p_transition(s: int, n_c: int, n_d: int, N: int, params: EnvParams) -> float:
    """Return the probability of switching to the other state, clipped to [0, 1].

    Args:
        s: Current state; 0 selects the ``p12`` rule, anything else ``p21``.
        n_c: Number of cooperating agents.
        n_d: Number of defecting agents.
        N: Total number of agents (the divisor of both fractions).
        params: Object providing the ``p12`` and ``p21`` scale factors.

    Returns:
        The switch probability as a Python float in [0, 1].
    """
    # Leaving state 0 is driven by the share of defectors; leaving any
    # other state is driven by the share of cooperators.
    p = params.p12 * (n_d / N) if s == 0 else params.p21 * (n_c / N)
    capped = min(1.0, p)
    return float(max(0.0, capped))


def pgg_rewards(n_c: int, N: int, actions: jnp.ndarray, r_mult: float, c_cost: float) -> jnp.ndarray:
    """Return the public-goods payoff of every agent.

    Each agent receives an equal share ``r_mult * c_cost * n_c / N`` of the
    multiplied pot; agents whose action is 0 additionally pay ``c_cost``.

    Args:
        n_c: Number of contributing agents.
        N: Total number of agents sharing the pot.
        actions: Per-agent actions ([N]); action 0 is treated as contributing.
        r_mult: Multiplication factor of the pot.
        c_cost: Contribution cost.

    Returns:
        Per-agent payoffs ([N]).
    """
    contributed = jnp.equal(actions, 0).astype(jnp.float32)  # 1.0 where the agent paid
    equal_share = (r_mult * c_cost * n_c) / N
    return equal_share - c_cost * contributed  # [N]


def env_step_once(
    key: jax.Array,
    s_cur: int,
    actions: jnp.ndarray,  # [N] with {0,1}
    params: EnvParams,
) -> Tuple[int, jnp.ndarray, int, int]:
    """Advance the environment by one step.

    Args:
        key: JAX PRNG key used to draw the state switch.
        s_cur: Current state (0 or 1).
        actions: Joint action of all agents ([N]); 0 counts as cooperation.
        params: Environment parameters.

    Returns:
        ``(s_next, rewards, n_c, n_d)``: the next state, the per-agent rewards
        ([N], float32 in the non-PGG branch), and the numbers of cooperators
        and defectors.
    """
    N = actions.shape[0]
    n_c = int(jnp.sum(actions == 0))
    n_d = N - n_c

    # The key is split and only the second sub-key is consumed; the split is
    # kept so the random stream stays the same.
    _, key_trans = random.split(key, 2)
    switch_prob = p_transition(s_cur, n_c, n_d, N, params)
    switched = bool(random.bernoulli(key_trans, switch_prob))
    s_next = int(1 - s_cur) if switched else int(s_cur)

    # Public-goods payoffs are only paid when the system stays in state 0;
    # every other transition yields the flat reward `params.m`.
    stayed_in_state0 = s_cur == 0 and s_next == 0
    if not stayed_in_state0:
        return s_next, jnp.full((N,), params.m, dtype=jnp.float32), n_c, n_d

    payoffs = pgg_rewards(n_c, N, actions, params.r_mult, params.c_cost)
    return s_next, payoffs, n_c, n_d



def td_errors_markov_batch(
    Q: jnp.ndarray,              # [N, S, M]
    states: jnp.ndarray,         # [B] (int32)
    actions: jnp.ndarray,        # [B, N] (int32)
    rewards: jnp.ndarray,        # [B, N] (float32)
    next_states: jnp.ndarray,    # [B] (int32)
    gamma: float,
) -> jnp.ndarray:
    """Compute one-step Q-learning TD errors for a batch of joint transitions.

    For transition ``b`` and agent ``n`` the error is
    ``rewards[b, n] + gamma * max_a Q[n, next_states[b], a]
    - Q[n, states[b], actions[b, n]]``.

    The computation is a pure gather over ``Q`` and therefore works for any
    number of states ``S`` and actions ``M``; state and action indices are
    not validated here.

    Args:
        Q: Per-agent Q-tables ([N, S, M]).
        states: State in which each transition started ([B]).
        actions: Action taken by every agent in each transition ([B, N]).
        rewards: Reward received by every agent in each transition ([B, N]).
        next_states: State reached by each transition ([B]).
        gamma: Discount factor.

    Returns:
        TD errors ([B, N]).
    """
    B, N = actions.shape

    # Index grids that pair every (transition, agent) cell with the agent's
    # own Q-table and the shared environment state of that transition.
    agent_idx = jnp.broadcast_to(jnp.arange(N, dtype=jnp.int32), (B, N))            # [B, N]
    s_idx = jnp.broadcast_to(states[:, None].astype(jnp.int32), (B, N))             # [B, N]
    s_next_idx = jnp.broadcast_to(next_states[:, None].astype(jnp.int32), (B, N))   # [B, N]
    a_idx = actions.astype(jnp.int32)                                               # [B, N]

    q_sa = Q[agent_idx, s_idx, a_idx]                       # [B, N]
    q_next_max = jnp.max(Q, axis=2)[agent_idx, s_next_idx]  # [B, N]

    return rewards + gamma * q_next_max - q_sa  # [B, N]


def apply_batch_update_groupby_state_action(
    Q: jnp.ndarray,              # [N, S, M]
    deltas: jnp.ndarray,         # [B, N]
    states: jnp.ndarray,         # [B]
    actions: jnp.ndarray,        # [B, N]
    alpha: float,
) -> jnp.ndarray:
    """Apply a batch of TD errors, averaged per (agent, state, action) cell.

    Every Q entry is moved by ``alpha`` times the mean TD error of the batch
    transitions that visited that entry; entries not visited in the batch
    are returned unchanged.

    Args:
        Q: Per-agent Q-tables ([N, S, M]).
        deltas: TD errors ([B, N]).
        states: State of each transition ([B]).
        actions: Action of every agent in each transition ([B, N]).
        alpha: Learning rate.

    Returns:
        Updated Q-tables ([N, S, M]).
    """
    _num_transitions, _num_agents = actions.shape  # also requires a rank-2 action array
    num_states, num_actions = Q.shape[1], Q.shape[2]

    # visited[b, n, s, a] is 1 exactly when transition b started in state s
    # and agent n played action a, else 0.
    state_hot = jax.nn.one_hot(states.astype(jnp.int32), num_states)       # [B, S]
    action_hot = jax.nn.one_hot(actions.astype(jnp.int32), num_actions)    # [B, N, M]
    visited = state_hot[:, None, :, None] * action_hot[:, :, None, :]      # [B, N, S, M]

    delta_sum = jnp.sum(deltas[:, :, None, None] * visited, axis=0)        # [N, S, M]
    visit_count = jnp.sum(visited, axis=0)                                 # [N, S, M]

    # Clamp the divisor to 1 so unvisited cells (sum 0, count 0) yield 0
    # instead of 0/0.
    mean_delta = delta_sum / jnp.maximum(visit_count, 1.0)
    return Q + alpha * mean_delta



def init_Q(key_Q: jax.random.PRNGKey, N: int, S: int, M: int) -> jnp.ndarray:
    """Draw random initial Q-tables for all agents.

    Entries are standard-normal samples clipped to [-1, 1]; the entries of
    action 0 are then shifted up by 1.

    Args:
        key_Q: JAX PRNG key for the normal draw.
        N: Number of agents.
        S: Number of states.
        M: Number of actions.

    Returns:
        float32 array of shape [N, S, M].
    """
    bounded_noise = jnp.clip(jax.random.normal(key_Q, shape=(N, S, M)), -1.0, 1.0)
    # After the shift, action-0 values lie in [0, 2] while all other
    # actions stay in [-1, 1].
    return bounded_noise.at[..., 0].add(1.0).astype(jnp.float32)


def sim_markov_batch(
    N: int,
    time_steps: int,
    alpha: float,
    beta: float,
    gamma: float,
    env_params: EnvParams,
    batch_size: int = 10,
    num_reps: int = 1,
    init_key: int = 0,
    record_every: int = 10,
    init_state: int = 0,
) -> Tuple[jnp.ndarray, jnp.ndarray, list]:
    """Simulate N softmax Q-learners in the two-state game with batched updates.

    Each of the ``time_steps`` learning steps rolls the environment forward
    for ``batch_size`` consecutive transitions with the Q-tables held fixed,
    then applies one averaged TD update computed from that batch. The
    environment state carries over from one batch to the next.

    Args:
        N: Number of agents.
        time_steps: Number of batched learning steps per repetition.
        alpha: Learning rate.
        beta: Softmax scale used for action selection and recorded policies.
        gamma: Discount factor.
        env_params: Environment parameters.
        batch_size: Number of transitions collected per learning step.
        num_reps: Number of independent repetitions.
        init_key: Base seed; repetition ``rep`` uses ``init_key + rep``.
        record_every: Record after every ``record_every``-th learning step
            (the initial tables and the first step are always recorded).
        init_state: Environment state at the start of every repetition.

    Returns:
        ``(Q_history, Pi_history, times)`` where the first two are arrays of
        shape [num_reps, T_rec, N, S, M] and ``times`` is a list (one entry
        per repetition) of the recorded step indices.
    """
    S, M = 2, 2
    Q_history_all, Pi_history_all, time_history_all = [], [], []

    for rep in range(num_reps):
        # A distinct seed per repetition keeps runs independent yet reproducible.
        key, key_Q = random.split(random.PRNGKey(init_key + rep), 2)

        Q = init_Q(key_Q, N=N, S=S, M=M)
        Q_history, Pi_history, t_rec = [Q], [softmax(Q, beta)], [0]

        s = int(init_state)

        for t in trange(time_steps):
            # Collect `batch_size` consecutive transitions under the current Q.
            visited_states, joint_actions, payoffs, reached_states = [], [], [], []

            for _ in range(batch_size):
                key, k_act, k_env = random.split(key, 3)

                a = choose_actions_stateful(k_act, Q[:, s, :], beta)  # [N]
                s_nxt, r, _, _ = env_step_once(k_env, s, a, env_params)

                visited_states.append(s)
                joint_actions.append(a)
                payoffs.append(r)
                reached_states.append(s_nxt)

                s = s_nxt

            states_b = jnp.array(visited_states, dtype=jnp.int32)             # [B]
            actions_b = jnp.stack(joint_actions, axis=0).astype(jnp.int32)    # [B, N]
            rewards_b = jnp.stack(payoffs, axis=0).astype(jnp.float32)        # [B, N]
            next_states_b = jnp.array(reached_states, dtype=jnp.int32)        # [B]

            deltas = td_errors_markov_batch(
                Q=Q,
                states=states_b,
                actions=actions_b,
                rewards=rewards_b,
                next_states=next_states_b,
                gamma=gamma,
            )  # [B, N]

            Q = apply_batch_update_groupby_state_action(
                Q=Q,
                deltas=deltas,
                states=states_b,
                actions=actions_b,
                alpha=alpha,
            )

            step = t + 1
            if step % record_every == 0 or t == 0:
                Q_history.append(Q)
                Pi_history.append(softmax(Q, beta))
                t_rec.append(step)

        Q_history_all.append(jnp.stack(Q_history))     # [T_rec, N, S, M]
        Pi_history_all.append(jnp.stack(Pi_history))   # [T_rec, N, S, M]
        time_history_all.append(t_rec)

    stacked_Q = jnp.stack(Q_history_all)      # [num_reps, T_rec, N, S, M]
    stacked_Pi = jnp.stack(Pi_history_all)    # [num_reps, T_rec, N, S, M]
    return stacked_Q, stacked_Pi, time_history_all
    
    
@jax.jit
def _avg_max_diff_all_pairs_Q(Q_t: jnp.ndarray) -> jnp.ndarray:
    Q_i = Q_t[:, None, :, :]
    Q_j = Q_t[None, :, :, :]
    abs_diffs = jnp.abs(Q_i - Q_j)              # [N, N, S, M]
    max_diffs_per_pair = jnp.max(abs_diffs, axis=(2, 3))  # [N, N]

    N_agents = Q_t.shape[0]
    mask = jnp.triu(jnp.ones((N_agents, N_agents)), k=1)
    valid_diffs = max_diffs_per_pair * mask
    num_pairs = N_agents * (N_agents - 1) // 2
    return jnp.sum(valid_diffs) / num_pairs

# Location of the pickled output of the single demo run.
data_dir = "simulation_data"
demo_data_file = os.path.join(data_dir, "EPG_1.pkl")
os.makedirs(data_dir, exist_ok=True)



if __name__ == "__main__":
    # Single demo run: simulate once and pickle the full Q / policy history.
    N = 100
    time_steps = 5000
    alpha = 0.01
    beta = 1.0
    gamma = 0.9
    batch_size = 64
    num_reps = 1
    record_every = 1
    init_state = 0

    # p12: state0 -> state1, p21: state1 -> state0, m: payoff if not (0->0)
    env = EnvParams(p12=0.8, p21=0.8, m=0.0, r_mult=2.0, c_cost=1.0)

    # One dictionary feeds both the simulation call and the saved metadata,
    # so the two cannot drift apart.
    run_parameters = {
        'N': N,
        'time_steps': time_steps,
        'alpha': alpha,
        'beta': beta,
        'gamma': gamma,
        'env_params': env,
        'batch_size': batch_size,
        'num_reps': num_reps,
        'record_every': record_every,
        'init_state': init_state,
    }

    Q_hist, Pi_hist, t_hist = sim_markov_batch(init_key=42, **run_parameters)

    print("Shapes:", Q_hist.shape, Pi_hist.shape, len(t_hist[0]), "records")

    demo_data = {
        'Q_hist': Q_hist,
        'Pi_hist': Pi_hist,
        't_hist': t_hist,
        'parameters': run_parameters,
    }

    with open(demo_data_file, 'wb') as f:
        pickle.dump(demo_data, f)
    print(f"Demo results saved to {demo_data_file}")

# Defaults shared by the parameter sweeps below.
# Learning hyper-parameters.
alpha = 0.01
beta = 1.0
gamma_default = 0.9

# Population size and run length.
N_default = 100
time_steps = 5000
record_every = 1
init_state = 0

# Environment used by every sweep.
env = EnvParams(p12=0.8, p21=0.8, m=0.0, r_mult=2.0, c_cost=1.0)

@jax.jit
def _avg_max_diff_all_pairs_Q(Q_t: jnp.ndarray) -> jnp.ndarray:
    Q_i = Q_t[:, None, :, :]
    Q_j = Q_t[None, :, :, :]
    abs_diffs = jnp.abs(Q_i - Q_j)              # [N, N, S, M]
    max_diffs_per_pair = jnp.max(abs_diffs, axis=(2, 3))  # [N, N]

    N_agents = Q_t.shape[0]
    mask = jnp.triu(jnp.ones((N_agents, N_agents)), k=1)
    valid_diffs = max_diffs_per_pair * mask
    num_pairs = N_agents * (N_agents - 1) // 2
    return jnp.sum(valid_diffs) / num_pairs

def run_one_setting_and_measure(
    *,
    N: int,
    time_steps: int,
    alpha: float,
    beta: float,
    gamma: float,
    env_params: EnvParams,
    batch_size: int,
    record_every: int,
    init_state: int,
    init_key: int,
):
    """Run one single-repetition simulation and measure Q-table disagreement.

    All arguments are forwarded unchanged to ``sim_markov_batch`` (with
    ``num_reps=1``).

    Returns:
        ``(time_points, avg_max_diffs, delta0)``: the recorded step indices
        as an integer NumPy array, the average pairwise max Q difference at
        each recorded step as a NumPy array, and its value at the first
        record as a Python float.
    """
    Q_hist, _, t_hist = sim_markov_batch(
        N=N,
        time_steps=time_steps,
        alpha=alpha,
        beta=beta,
        gamma=gamma,
        env_params=env_params,
        batch_size=batch_size,
        num_reps=1,
        init_key=init_key,
        record_every=record_every,
        init_state=init_state,
    )

    # With a single repetition the mean over axis 0 simply removes that axis.
    Q_per_record = jnp.mean(jnp.array(Q_hist), axis=0)  # [T_rec, N, S, M]

    # Evaluate the disagreement measure independently at every recorded step.
    measure_each_record = jax.vmap(_avg_max_diff_all_pairs_Q, in_axes=0)
    avg_max_diffs = measure_each_record(Q_per_record)   # [T_rec]

    time_points = np.array(t_hist[0], dtype=int)
    return time_points, np.array(avg_max_diffs), float(avg_max_diffs[0])

# Output files of the three parameter sweeps (batch size, gamma, N).
data_dir = "simulation_data"
batch_data_file = os.path.join(data_dir, "batch_size_results.pkl")
gamma_data_file = os.path.join(data_dir, "gamma_results.pkl")
N_data_file = os.path.join(data_dir, "N_results.pkl")
os.makedirs(data_dir, exist_ok=True)


# --- Sweep 1: vary the batch size B at fixed N and gamma -------------------
batch_list = [16, 64, 128]
results_batch = []
delta0_list = []
time_axis_ref = None

# Settings shared by every run of this sweep.
batch_sweep_kwargs = dict(
    N=N_default,
    time_steps=time_steps,
    alpha=alpha,
    beta=beta,
    gamma=gamma_default,
    env_params=env,
    record_every=record_every,
    init_state=init_state,
)

for i, bsz in enumerate(batch_list):
    print(f"Running batch_size = {bsz}...")
    t, diffs, d0 = run_one_setting_and_measure(
        batch_size=bsz,
        init_key=100 + i,
        **batch_sweep_kwargs,
    )
    results_batch.append((bsz, t, diffs))
    delta0_list.append(d0)

    # All runs must share one time axis so they can be drawn against a
    # single theory curve.
    if time_axis_ref is None:
        time_axis_ref = t
        continue
    assert np.array_equal(time_axis_ref, t), "Recording intervals are inconsistent across different experiments!"

# The theory curve starts from the largest initial disagreement of the sweep.
delta0_for_bound = max(delta0_list)
theory_batch = delta0_for_bound * np.exp(-alpha * (1.0 - gamma_default) * time_axis_ref)

batch_parameters = {
    'N_default': N_default,
    'time_steps': time_steps,
    'alpha': alpha,
    'beta': beta,
    'gamma_default': gamma_default,
    'record_every': record_every,
    'init_state': init_state,
    'env_params': env,
}
batch_data = {
    'results_batch': results_batch,
    'delta0_list': delta0_list,
    'time_axis_ref': time_axis_ref,
    'theory_batch': theory_batch,
    'gamma_default': gamma_default,
    'alpha': alpha,
    'parameters': batch_parameters,
}

with open(batch_data_file, 'wb') as f:
    pickle.dump(batch_data, f)
print(f"Batch size results saved to {batch_data_file}")

# --- Sweep 2: vary the discount factor gamma at fixed N and B --------------
gamma_list = [0.5, 0.7, 0.9]
batch_for_gamma = 64
results_gamma = []

# Settings shared by every run of this sweep.
gamma_sweep_kwargs = dict(
    N=N_default,
    time_steps=time_steps,
    alpha=alpha,
    beta=beta,
    env_params=env,
    batch_size=batch_for_gamma,
    record_every=record_every,
    init_state=init_state,
)

for j, g in enumerate(gamma_list):
    print(f"Running gamma = {g}...")
    t_g, diffs_g, d0_g = run_one_setting_and_measure(
        gamma=g,
        init_key=200 + j,
        **gamma_sweep_kwargs,
    )
    # Each gamma gets its own theory curve, anchored at that run's initial value.
    theory_g = d0_g * np.exp(-alpha * (1.0 - g) * t_g)
    results_gamma.append((g, t_g, diffs_g, theory_g))

gamma_parameters = {
    'N_default': N_default,
    'time_steps': time_steps,
    'alpha': alpha,
    'beta': beta,
    'record_every': record_every,
    'init_state': init_state,
    'env_params': env,
}
gamma_data = {
    'results_gamma': results_gamma,
    'batch_for_gamma': batch_for_gamma,
    'parameters': gamma_parameters,
}

with open(gamma_data_file, 'wb') as f:
    pickle.dump(gamma_data, f)
print(f"Gamma results saved to {gamma_data_file}")

# --- Sweep 3: vary the population size N, five seeds per size --------------
N_list = [population for population in (10, 100, 1000) for _ in range(5)]
batch_for_N = 64
results_N = []

# Settings shared by every run of this sweep.
N_sweep_kwargs = dict(
    time_steps=time_steps,
    alpha=alpha,
    beta=beta,
    gamma=gamma_default,
    env_params=env,
    batch_size=batch_for_N,
    record_every=record_every,
    init_state=init_state,
)

for k, n in enumerate(N_list):
    print(f"Running N = {n}...")
    t_n, diffs_n, d0_n = run_one_setting_and_measure(
        N=n,
        init_key=300 + k,
        **N_sweep_kwargs,
    )
    # Theory curve anchored at this run's initial value.
    theory_n = d0_n * np.exp(-alpha * (1.0 - gamma_default) * t_n)
    results_N.append((n, t_n, diffs_n, theory_n))

N_parameters = {
    'time_steps': time_steps,
    'alpha': alpha,
    'beta': beta,
    'gamma_default': gamma_default,
    'record_every': record_every,
    'init_state': init_state,
    'env_params': env,
}
N_data = {
    'results_N': results_N,
    'batch_for_N': batch_for_N,
    'gamma_default': gamma_default,
    'parameters': N_parameters,
}

with open(N_data_file, 'wb') as f:
    pickle.dump(N_data, f)
print(f"N results saved to {N_data_file}")





# Reload the three sweep results from disk so the plotting code below works
# from the saved files.
data_dir = "simulation_data"
batch_data_file = os.path.join(data_dir, "batch_size_results.pkl")
gamma_data_file = os.path.join(data_dir, "gamma_results.pkl")
N_data_file = os.path.join(data_dir, "N_results.pkl")
os.makedirs(data_dir, exist_ok=True)

# Batch-size sweep.
with open(batch_data_file, 'rb') as f:
    batch_data = pickle.load(f)
results_batch, theory_batch, time_axis_ref, gamma_default = (
    batch_data[field]
    for field in ('results_batch', 'theory_batch', 'time_axis_ref', 'gamma_default')
)

# Gamma sweep.
with open(gamma_data_file, 'rb') as f:
    gamma_data = pickle.load(f)
results_gamma, batch_for_gamma = (
    gamma_data[field] for field in ('results_gamma', 'batch_for_gamma')
)

# Population-size sweep.
with open(N_data_file, 'rb') as f:
    N_data = pickle.load(f)
results_N, batch_for_N = (
    N_data[field] for field in ('results_N', 'batch_for_N')
)






def truncate_colormap(cmap, minval: float = 0.0, maxval: float = 1.0, n: int = 256):
    """Build a colormap from the ``[minval, maxval]`` sub-range of ``cmap``.

    Args:
        cmap: Source colormap; called with an array of positions and read
            for its ``name`` attribute.
        minval: Lower end of the sampled range.
        maxval: Upper end of the sampled range.
        n: Number of evenly spaced sample positions.

    Returns:
        A ``LinearSegmentedColormap`` built from the sampled colours.
    """
    from matplotlib.colors import LinearSegmentedColormap

    sample_positions = np.linspace(minval, maxval, n)
    sampled_colors = cmap(sample_positions)
    truncated_name = f"trunc({cmap.name},{minval:.2f},{maxval:.2f})"
    return LinearSegmentedColormap.from_list(truncated_name, sampled_colors)

def custom_colormap(base_cmap, data: np.ndarray, breakpoints: np.ndarray):
    """Resample ``base_cmap`` at ``breakpoints`` and build a norm spanning ``data``.

    Args:
        base_cmap: Source colormap, called with ``breakpoints``.
        data: Values whose minimum and maximum define the normalisation range.
        breakpoints: Positions at which the base colormap is sampled
            (expected in [0, 1]).

    Returns:
        ``(cmap, norm)``: the resampled ``LinearSegmentedColormap`` and a
        ``Normalize`` instance covering ``[min(data), max(data)]``.
    """
    from matplotlib.colors import LinearSegmentedColormap, Normalize

    resampled = LinearSegmentedColormap.from_list("custom_cmap", base_cmap(breakpoints))
    lowest, highest = np.min(data), np.max(data)
    return resampled, Normalize(vmin=lowest, vmax=highest)

def _as_gamma_records(raw_results):
    """Coerce ``(gamma, t, diffs, theory)`` entries to ``(float, ndarray, ndarray, ndarray)``."""
    return [
        (float(g), np.array(t_g), np.array(diffs_g), np.array(theory_g))
        for g, t_g, diffs_g, theory_g in raw_results
    ]


def load_results_gamma_if_needed(
    default_path: str = "simulation_data/gamma_results.pkl"
) -> List[Tuple[float, np.ndarray, np.ndarray, np.ndarray]]:
    """Return the gamma-sweep results, preferring the in-memory copy.

    If this module has a global named ``results_gamma`` it is used;
    otherwise the results are read from the pickle at ``default_path``
    (under its ``"results_gamma"`` key).

    Args:
        default_path: Pickle file to fall back to when no global is present.

    Returns:
        A new list of ``(gamma, time_points, diffs, theory)`` tuples with
        ``gamma`` as a float and the other entries as NumPy arrays.

    Raises:
        FileNotFoundError: If neither the global nor the file is available.
    """
    module_globals = globals()
    if "results_gamma" in module_globals:
        return _as_gamma_records(module_globals["results_gamma"])

    if os.path.exists(default_path):
        with open(default_path, "rb") as f:
            # NOTE: unpickling can execute arbitrary code; only point this at
            # result files produced by this script.
            gamma_data = pickle.load(f)
        return _as_gamma_records(gamma_data["results_gamma"])

    raise FileNotFoundError(
        "results_gamma not found in memory and file "
        f"'{default_path}' does not exist. Please run your simulation code first."
    )

def plot_gamma_3d_from_results(
    threshold: float = 0.0,
    title: str = "Average max Q-value differences at selected gamma",
):
    """Draw simulated and theoretical disagreement curves per gamma in 3-D.

    For every gamma in the gamma-sweep results, the simulated curve (solid)
    and the theory curve (dashed, same colour) are drawn over time at
    ``y = gamma``. The figure is saved to ``figures/EPG_gamma_plot.png``
    (the directory is created if missing) and then shown.

    Args:
        threshold: Points whose value is below this threshold are masked
            out of the corresponding curve.
        title: Axes title.
    """
    rg = sorted(load_results_gamma_if_needed(), key=lambda record: record[0])
    gamma_vals = [g for g, _, _, _ in rg]

    curve_colors = truncate_colormap(plt.cm.viridis, 0.3, 0.8)(np.linspace(0, 1, len(rg)))

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection="3d")

    def _masked_line(t_vals, y_vals, z_vals):
        """Return (x, y, z) masked arrays hiding points with z below the threshold."""
        hidden = z_vals < threshold
        return (
            np.ma.array(t_vals, mask=hidden),
            np.ma.array(y_vals, mask=hidden),
            np.ma.array(z_vals, mask=hidden),
        )

    for i, (g, t_g, diffs_g, theory_g) in enumerate(rg):
        gamma_line = np.ones_like(t_g) * g  # constant y-coordinate of this gamma

        ax.plot(
            *_masked_line(t_g, gamma_line, diffs_g),
            color=curve_colors[i],
            alpha=0.5,
            linewidth=2.0,
            label=f"gamma={g:.1f} (sim)" if i == 0 else None,  # keep legend compact
        )
        ax.plot(
            *_masked_line(t_g, gamma_line, theory_g),
            color=curve_colors[i],  # same color as simulation line
            linestyle="--",
            linewidth=1.8,
            alpha=0.7,
            label="theory upper bound" if i == 0 else None,
        )

    ax.set_xlabel("Time Step")
    ax.set_ylabel("$\\gamma$", labelpad=20)
    ax.set_zlabel("Average $\\Delta_t$")
    ax.set_title(title)

    t_min = min(int(np.min(t_g)) for _, t_g, _, _ in rg)
    t_max = max(int(np.max(t_g)) for _, t_g, _, _ in rg)
    ax.set_xlim(t_min, t_max)
    ax.set_xticks([t_min, t_max])

    # Ticks and limits follow the gammas actually present in the results
    # (for 0.5 / 0.7 / 0.9 this gives the range 0.4-1.0).
    ax.set_ylim(min(gamma_vals) - 0.1, max(gamma_vals) + 0.1)
    ax.set_yticks(gamma_vals)

    _, z_top = ax.get_zlim()
    ax.set_zlim(0, z_top)
    ax.set_zticks([0, 1.5])

    ax.legend()

    ax.view_init(elev=10, azim=-70)
    ax.grid(False)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor('gray')
        axis._axinfo['grid']['color'] = (1, 1, 1, 0)  # transparent

    # HACK: rescale the projection matrix to stretch the gamma axis; this
    # overrides the bound method on this axes instance only.
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([1, 3, 1, 2.6]))

    plt.tight_layout()
    # savefig does not create missing directories, so make sure the target
    # folder exists before writing.
    output_path = "figures/EPG_gamma_plot.png"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path, dpi=600)
    plt.show()


# Both figures below are written into ./figures; savefig does not create
# missing directories, so make sure it exists before any plotting call.
os.makedirs("figures", exist_ok=True)

plot_gamma_3d_from_results(
    threshold=0.0,
    title="Average $\\Delta_t$ for different $\\gamma$",
)


# --- Two-panel figure: batch-size sweep (left) and N sweep (right) ---------
fs = 20
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

num_curves = 3
curve_colors = truncate_colormap(plt.cm.viridis, 0.3, 0.8)(np.linspace(0, 1, num_curves))

# Left panel: one simulated curve per batch size plus the shared theory curve.
ax1 = axes[0]
for color_idx, (bsz, t, diffs) in enumerate(results_batch):
    ax1.plot(
        t, diffs, linewidth=2, label=f'$B={bsz}$ (simulation)',
        color=curve_colors[color_idx % num_curves], alpha=0.6, zorder=10,
    )
ax1.plot(time_axis_ref, theory_batch, linestyle='--', linewidth=3, label='upper bound (theory)', color="#7F55B1", alpha=0.5, dashes=(5, 3))
ax1.tick_params(axis='both', labelsize=fs-4)
ax1.set_xlabel('Time Step', fontsize=fs-2)
ax1.set_ylabel('Average $\\Delta_t$', fontsize=fs-2)
# Take N and gamma from the loaded results instead of hard-coding them.
n_agents_batch = batch_data['parameters']['N_default']
ax1.set_title(f'$N={n_agents_batch}, \\gamma={gamma_default}$', fontsize=fs)
ax1.grid(True, alpha=0.3)
ax1.legend(fontsize=fs-2)

# Right panel: simulated curves averaged over the repeated runs of each N.
ax2 = axes[1]

n_groups = {}
theory_data = {}
for n, t_n, diffs_n, theory_n in results_N:
    n_groups.setdefault(n, []).append(diffs_n)
    # Keep the time axis / theory curve of the first run seen for each N.
    theory_data.setdefault(n, (t_n, theory_n))

max_n = max(n_groups)
max_n_avg_initial = None

for color_idx, n in enumerate(sorted(n_groups)):
    avg_diffs = np.mean(n_groups[n], axis=0)
    t_n, _ = theory_data[n]

    if n == max_n:
        max_n_avg_initial = avg_diffs[0]

    ax2.plot(
        t_n, avg_diffs, linewidth=2, label=f'$N={n}$ (simulation)',
        color=curve_colors[color_idx % num_curves], alpha=0.6, zorder=10,
    )

# Theory curve for the largest population (looked up via max_n rather than a
# hard-coded 1000 so it is still drawn if the sweep sizes change).
t_n, theory_n = theory_data[max_n]
if max_n_avg_initial is not None and theory_n[0] != 0:
    # The bound is proportional to its initial value, so re-anchoring it to
    # the averaged initial disagreement must rescale the curve; an additive
    # shift would leave a non-zero offset as t grows.
    theory_adjusted = theory_n * (max_n_avg_initial / theory_n[0])
else:
    theory_adjusted = theory_n
ax2.plot(t_n, theory_adjusted, linestyle='--', linewidth=3, label='upper bound (theory)', color="#7F55B1", alpha=0.5, dashes=(5, 3))

ax2.tick_params(axis='both', labelsize=fs-4)
ax2.set_xlabel('Time Step', fontsize=fs-2)
ax2.set_ylabel('Average $\\Delta_t$', fontsize=fs-2)
# Take B and gamma from the loaded results instead of hard-coding them.
gamma_for_N = N_data['gamma_default']
ax2.set_title(f'$B={batch_for_N}, \\gamma={gamma_for_N}$', fontsize=fs)
ax2.grid(True, alpha=0.3)
ax2.legend(fontsize=fs-2)

plt.tight_layout()
plt.savefig("figures/EPG_B_N.png", dpi=600)
plt.show()
