from __future__ import annotations

from dataclasses import dataclass, field
from functools import partial
from typing import Tuple, Sequence, Optional, Callable

import jax
import jax.numpy as jnp
from jax import random, jit
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
# Global plotting style applied at import time: seaborn "whitegrid" theme
# and Arial as the default font family.
# NOTE(review): assumes Arial is installed on the machine — confirm on the
# target system.
sns.set_theme(style="whitegrid")
plt.rcParams['font.family'] = 'Arial'

import os
# Output directory for the saved figure; created at import time and left
# alone if it already exists.
os.makedirs('figures', exist_ok=True)



@jit
def softmax(Q: jnp.ndarray, beta: float) -> jnp.ndarray:
    """Boltzmann (softmax) policy over the last axis of ``Q``.

    Args:
        Q: Q-value array; normalisation runs over its last axis.
        beta: Inverse temperature that scales ``Q`` before normalising.

    Returns:
        Array shaped like ``Q`` whose entries along the last axis sum to 1.
    """
    scaled = jnp.multiply(beta, Q)
    return jax.nn.softmax(scaled, axis=-1)

@jit
def choose_actions(key: jax.Array, Q: jnp.ndarray, beta: float) -> jnp.ndarray:
    """Sample one action index per row of ``Q`` from ``softmax(beta * Q)``.

    Args:
        key: PRNG key consumed by the categorical draw.
        Q: ``[N, M]`` Q-values; each row is one agent's action values.
        beta: Inverse temperature applied to ``Q`` to form the logits.

    Returns:
        ``[N]`` integer array of sampled action indices.
    """
    return random.categorical(key, logits=beta * Q, axis=-1)



@jit
def get_td_errors_batch(
    payoff_matrix: jnp.ndarray,  # [M, M]
    a_values:      jnp.ndarray,  # [B, N]
    Q_values:      jnp.ndarray,  # [N, M]
    gamma:         float
) -> jnp.ndarray:                # returns [B, N, N]
    a_idx = a_values.astype(jnp.int32)  # [B, N]

    r_matrix = payoff_matrix[a_idx[:, :, None], a_idx[:, None, :]]  # [B, N, N]

    agent_idx = jnp.arange(a_idx.shape[1], dtype=jnp.int32)[None, :]  # [1, N]
    q_sa = Q_values[agent_idx, a_idx]  # [B, N]

    q_next_max = jnp.max(Q_values, axis=1)  # [N]

    td_errors = r_matrix + gamma * q_next_max[None, :, None] - q_sa[:, :, None]  # [B, N, N]
    return td_errors


@jit
def get_td_errors(
    payoff_matrix: jnp.ndarray,  # [M, M]
    a_values:      jnp.ndarray,  # [N]
    Q_values:      jnp.ndarray,  # [N, M]
    gamma:         float
) -> jnp.ndarray:                # returns [N, N]
    """Pairwise TD errors for a single joint action profile.

    Entry ``[i, j]`` is
    ``payoff[a[i], a[j]] + gamma * max_a Q[i, a] - Q[i, a[i]]``.

    Args:
        payoff_matrix: ``[M, M]`` payoffs, indexed ``[own action, opponent action]``.
        a_values: ``[N]`` action indices (cast to int32 before indexing).
        Q_values: ``[N, M]`` current Q-table, one row per agent.
        gamma: Discount factor on the bootstrap (max-Q) term.

    Returns:
        ``[N, N]`` array of TD errors.
    """
    acts = a_values.astype(jnp.int32)

    rewards = payoff_matrix[acts[:, None], acts[None, :]]                 # [N, N]
    chosen_q = jnp.take_along_axis(Q_values, acts[:, None], axis=1)[:, 0]  # [N]
    bootstrap = gamma * Q_values.max(axis=1)                              # [N]

    return rewards + bootstrap[:, None] - chosen_q[:, None]



@jit
def average_over_opponents_excluding_self(td_matrix_BNN: jnp.ndarray) -> jnp.ndarray:
    """Average each agent's TD errors over all *other* agents.

    Entry ``[..., i, j]`` of the input is agent ``i``'s TD error against
    opponent ``j``. The diagonal (agent against itself) is excluded, and the
    remaining ``N - 1`` entries of each row are averaged.

    Args:
        td_matrix_BNN: Array of shape ``[..., N, N]``. The usual input is
            ``[B, N, N]``, but any number of leading axes (including none)
            is accepted.

    Returns:
        Array of shape ``[..., N]``. With a single agent (``N == 1``) there
        are no opponents, so zeros are returned instead of ``0 / 0 = nan``.
    """
    N = td_matrix_BNN.shape[-1]
    # Zero the diagonal so an agent's entry against itself does not count.
    off_diagonal = 1.0 - jnp.eye(N)                       # broadcasts over leading axes
    summed = (td_matrix_BNN * off_diagonal).sum(axis=-1)  # [..., N]
    # N is a static shape, so this is a plain Python int under jit; the
    # floor of 1 avoids dividing by zero when there are no opponents.
    return summed / max(N - 1, 1)


@jit
def average_over_opponents_excluding_self_single(td_matrix_NN: jnp.ndarray) -> jnp.ndarray:
    """Average each agent's TD errors over all *other* agents (unbatched).

    Entry ``[i, j]`` of the input is agent ``i``'s TD error against opponent
    ``j``. The diagonal is excluded and the remaining ``N - 1`` entries of
    each row are averaged.

    Args:
        td_matrix_NN: ``[N, N]`` TD-error matrix.

    Returns:
        ``[N]`` array. With a single agent (``N == 1``) there are no opponents,
        so zeros are returned instead of ``0 / 0 = nan``.
    """
    N = td_matrix_NN.shape[-1]
    # Zero the diagonal so an agent's entry against itself does not count.
    off_diagonal = 1.0 - jnp.eye(N)
    summed = (td_matrix_NN * off_diagonal).sum(axis=-1)  # [N]
    # N is static under jit; the floor of 1 avoids 0 / 0 for a lone agent.
    return summed / max(N - 1, 1)


@jit
def update_Q_values_batch(
    Q: jnp.ndarray,              # [N, M]
    td_batch: jnp.ndarray,       # [B, N, N]
    a_batch: jnp.ndarray,        # [B, N]
    alpha: float
) -> jnp.ndarray:                # [N, M]
    """Apply one mini-batch Q-learning step.

    Each sample's pairwise TD errors are first averaged over opponents. Then
    each agent/action cell moves by ``alpha`` times the mean of those averages,
    taken over the batch samples in which that agent chose that action. Cells
    never chosen in the batch get a zero update.

    Args:
        Q: ``[N, M]`` current Q-table.
        td_batch: ``[B, N, N]`` pairwise TD errors per batch sample.
        a_batch: ``[B, N]`` action indices played in each sample.
        alpha: Learning rate.

    Returns:
        Updated ``[N, M]`` Q-table.
    """
    n_actions = Q.shape[1]

    per_sample_td = average_over_opponents_excluding_self(td_batch)   # [B, N]
    chosen = jax.nn.one_hot(a_batch, n_actions)                       # [B, N, M]

    # Sum of TD signals routed to the chosen (agent, action) cells.
    td_totals = (per_sample_td[..., None] * chosen).sum(axis=0)       # [N, M]
    # Times each cell was chosen; floored at 1 so unchosen cells divide 0 by 1.
    visits = jnp.maximum(chosen.sum(axis=0), 1.0)                     # [N, M]

    return Q + alpha * (td_totals / visits)


@jit
def update_Q_values_single(
    Q: jnp.ndarray,          # [N, M]
    avg_td: jnp.ndarray,     # [N]
    a_values: jnp.ndarray,   # [N]
    alpha: float
) -> jnp.ndarray:
    """Apply one single-sample Q-learning step.

    Only the cell ``Q[i, a_values[i]]`` of each agent changes, by
    ``alpha * avg_td[i]``.

    Args:
        Q: ``[N, M]`` current Q-table.
        avg_td: ``[N]`` TD error per agent (already averaged over opponents).
        a_values: ``[N]`` action index each agent played.
        alpha: Learning rate.

    Returns:
        Updated ``[N, M]`` Q-table.
    """
    chosen = jax.nn.one_hot(a_values, Q.shape[1])   # [N, M] mask of played actions
    increment = avg_td[:, None] * chosen
    return Q + alpha * increment




def init_Q(key_Q: jax.Array, N: int, M: int) -> jnp.ndarray:
    """Draw random initial Q-values for ``N`` agents and two actions.

    Column 0 is drawn from N(0, 1) and column 1 from N(0, 0.5), each with its
    own independent PRNG stream. Both are then clipped to ``[-1, 1]``.

    Args:
        key_Q: PRNG key. It is split internally so the two columns are
            independent.
        N: Number of agents (rows).
        M: Number of actions. Only ``M == 2`` is supported.

    Returns:
        ``[N, 2]`` array of initial Q-values.

    Raises:
        ValueError: If ``M != 2``. Previously a ``[N, 2]`` table was returned
            silently regardless of ``M``.
    """
    if M != 2:
        raise ValueError(f"init_Q only supports M == 2 actions, got M={M}")
    # Reusing one key for both draws would make column 1 an exact rescaled
    # copy of column 0 (perfectly correlated); give each column its own key.
    key_C, key_D = jax.random.split(key_Q)
    QC = jax.random.normal(key_C, shape=(N,))
    QD = jax.random.normal(key_D, shape=(N,)) * jnp.sqrt(0.5)
    Q = jnp.stack([QC, QD], axis=-1)
    return jnp.clip(Q, -1, 1)

def sim_graph_batch(
    N: int,
    M: int,
    time_steps: int,
    alpha: float,
    beta: float,
    gamma: float,
    payoff_matrix: jnp.ndarray,
    batch_size: int = 10,
    num_reps: int = 10,
    init_key: int = 42,
    record_every: int = 10,
):
    """Run mini-batch Q-learning for ``N`` agents over ``num_reps`` replicates.

    Each step does three things. It draws ``batch_size`` independent joint
    action profiles from the current softmax policy. It computes pairwise TD
    errors for every profile. It applies one averaged update via
    ``update_Q_values_batch``.

    Args:
        N: Number of agents.
        M: Number of actions.
        time_steps: Update steps per replicate.
        alpha: Learning rate.
        beta: Softmax inverse temperature.
        gamma: Discount factor.
        payoff_matrix: ``[M, M]`` payoffs, indexed ``[own action, opponent action]``.
        batch_size: Joint action samples per update step (must be >= 1).
        num_reps: Number of independent replicates.
        init_key: Seed for replicate 0; replicate ``r`` uses ``init_key + r``.
        record_every: Recording interval in steps (must be >= 1).

    Returns:
        Tuple ``(Q_hist, X_hist, time_hist)``:
          * ``Q_hist``: ``[num_reps, T, N, M]`` recorded Q-values.
          * ``X_hist``: ``[num_reps, T, N, M]`` softmax policies of those Q-values.
          * ``time_hist``: list of length ``num_reps`` of ``[T]`` arrays with
            the step index of each record.

        Records are taken at step 0 (initial values), after step 1, and after
        every ``record_every``-th step.

    Raises:
        ValueError: If ``batch_size < 1`` or ``record_every < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if record_every < 1:
        raise ValueError(f"record_every must be >= 1, got {record_every}")

    # Sample every batch entry in one vectorised call instead of a Python loop
    # of `batch_size` separate dispatches. Each key still drives exactly one
    # draw, so results match the per-key calls.
    sample_batch_actions = jax.vmap(choose_actions, in_axes=(0, None, None))

    Q_history_all = []
    X_history_all = []
    time_history_all = []

    for rep in range(num_reps):
        key = random.PRNGKey(init_key + rep)
        key, key_Q = random.split(key, 2)

        Q = init_Q(key_Q, N, M)
        Q_history = [Q]
        X_history = [softmax(Q, beta)]
        time_history = [0]

        for t in range(time_steps):
            # keys[0] carries the PRNG chain forward; keys[1:] are one per sample.
            keys = random.split(key, batch_size + 1)
            key = keys[0]
            a_batch = sample_batch_actions(keys[1:], Q, beta)                  # [B, N]

            td_batch = get_td_errors_batch(payoff_matrix, a_batch, Q, gamma)   # [B, N, N]
            Q = update_Q_values_batch(Q, td_batch, a_batch, alpha)

            # Step 1 is always recorded so early dynamics stay visible even
            # when record_every is large.
            if (t + 1) % record_every == 0 or t == 0:
                Q_history.append(Q)
                X_history.append(softmax(Q, beta))
                time_history.append(t + 1)

        Q_history_all.append(jnp.stack(Q_history))
        X_history_all.append(jnp.stack(X_history))
        time_history_all.append(jnp.array(time_history))

    return (
        jnp.stack(Q_history_all),   # [num_reps, T, N, M]
        jnp.stack(X_history_all),   # [num_reps, T, N, M]
        time_history_all,           # list of length num_reps with jnp arrays
    )


def sim_graph_single(
    N: int,
    M: int,
    time_steps: int,
    alpha: float,
    beta: float,
    gamma: float,
    payoff_matrix: jnp.ndarray,
    num_reps: int = 10,
    init_key: int = 42,
    record_every: int = 10,
):
    """Run single-sample Q-learning for ``N`` agents over ``num_reps`` replicates.

    Each step draws one joint action profile from the softmax policy. It
    computes pairwise TD errors, averages them over opponents, and updates the
    chosen cell of each agent's Q-table.

    Args:
        N: Number of agents.
        M: Number of actions.
        time_steps: Update steps per replicate.
        alpha: Learning rate.
        beta: Softmax inverse temperature.
        gamma: Discount factor.
        payoff_matrix: ``[M, M]`` payoffs, indexed ``[own action, opponent action]``.
        num_reps: Number of independent replicates.
        init_key: Seed for replicate 0; replicate ``r`` uses ``init_key + r``.
        record_every: Recording interval in steps.

    Returns:
        ``(Q_hist, X_hist)``, each ``[num_reps, T, N, M]``. Records are the
        initial values, the values after step 1, and those after every
        ``record_every``-th step. No time axis is returned.
    """
    def run_one(seed: int):
        # One replicate: returns stacked (Q trace, policy trace).
        rng = random.PRNGKey(seed)
        rng, q_key = random.split(rng, 2)
        q = init_Q(q_key, N, M)

        q_trace = [q]
        x_trace = [softmax(q, beta)]

        for step in range(1, time_steps + 1):
            rng, act_key = random.split(rng, 2)
            acts = choose_actions(act_key, q, beta)                       # [N]

            pairwise_td = get_td_errors(payoff_matrix, acts, q, gamma)    # [N, N]
            per_agent_td = average_over_opponents_excluding_self_single(pairwise_td)
            q = update_Q_values_single(q, per_agent_td, acts, alpha)

            if step % record_every == 0 or step == 1:
                q_trace.append(q)
                x_trace.append(softmax(q, beta))

        return jnp.stack(q_trace), jnp.stack(x_trace)

    runs = [run_one(init_key + rep) for rep in range(num_reps)]
    q_runs = [q_hist for q_hist, _ in runs]
    x_runs = [x_hist for _, x_hist in runs]
    return jnp.stack(q_runs), jnp.stack(x_runs)


if __name__ == "__main__":
    # ---- Simulation parameters ------------------------------------------
    N: int          = 100   # agents
    M: int          = 2     # actions
    time_steps: int = 500
    alpha: float    = 0.01  # learning rate
    beta: float     = 1.0   # softmax inverse temperature
    gamma: float    = 0.0   # discount factor
    b: float        = 2.0   # payoff gained from a cooperating opponent
    c: float        = 1.0   # cost paid when cooperating
    # Indexed [own action, opponent action]; action 0 cooperates, action 1 defects.
    payoff_matrix = jnp.array([[b - c, -c],
                               [b,      0.0]], dtype=jnp.float32)
    batch_size: int   = 10
    num_reps: int     = 1
    record_every: int = 10

    Q_history_all, X_history_all, time_history_all = sim_graph_batch(
        N, M, time_steps, alpha, beta, gamma,
        payoff_matrix,
        batch_size=batch_size,
        num_reps=num_reps,
        init_key=42,
        record_every=record_every,
    )

    # Copy results to host memory once. Slicing the JAX array inside the
    # per-agent plotting loop would issue one small device op per slice.
    Q_np = np.asarray(Q_history_all)                # [num_reps, T, N, M]
    n_reps_run, _, n_agents, _ = Q_np.shape
    time_points = np.asarray(time_history_all[0])   # (T,)

    fs = 20
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    action_colors = {0: '#3D7EA6', 1: '#FFA5A5'}  # Action 0 (cooperate) and Action 1 (defect)

    # ---- Left panel: every agent's Q-value trajectories -----------------
    for rep in range(n_reps_run):
        for agent in range(n_agents):
            for action in (0, 1):
                axes[0].plot(time_points, Q_np[rep, :, agent, action],
                             color=action_colors[action], alpha=0.15, linewidth=0.5)

    # Empty proxy lines give the legend thick, opaque entries.
    axes[0].plot([], [], color=action_colors[0], linewidth=2, label='action $a_0$')
    axes[0].plot([], [], color=action_colors[1], linewidth=2, label='action $a_1$')

    axes[0].set_yticks([-1, -0.5, 0, 0.5, 1])
    axes[0].tick_params(axis='both', labelsize=fs-4)
    axes[0].set_xlabel('Time Step', fontsize=fs-2)
    axes[0].set_ylabel('Q-value', fontsize=fs-2)
    axes[0].set_title('Individual Q-values for All Agents', fontsize=fs)
    axes[0].legend(fontsize=fs-2)
    axes[0].grid(True, alpha=0.3)

    # ---- Right panel: Q-value variance vs. Var0 * exp(-2 * alpha * t) ---
    # Variance pooled over replicates and agents at each recorded time point.
    Q_var_action0 = Q_np[:, :, :, 0].var(axis=(0, 2))  # (T,)
    Q_var_action1 = Q_np[:, :, :, 1].var(axis=(0, 2))  # (T,)

    axes[1].plot(time_points, Q_var_action0, label='action $a_0$ (simulation)',
                 linewidth=2, color=action_colors[0], alpha=0.8)
    axes[1].plot(time_points, Q_var_action1, label='action $a_1$ (simulation)',
                 linewidth=2, color=action_colors[1], alpha=0.8)

    # Reference curve: the initial simulated variance decaying as exp(-2*alpha*t).
    Var0_action0 = Q_var_action0[0]
    Var0_action1 = Q_var_action1[0]
    theoretical_var_action0 = Var0_action0 * np.exp(-2 * alpha * time_points)
    theoretical_var_action1 = Var0_action1 * np.exp(-2 * alpha * time_points)

    # Show the reference curve as a few evenly spaced markers.
    n_points = 8
    indices = np.linspace(0, len(time_points) - 1, n_points, dtype=int)
    axes[1].scatter(time_points[indices], theoretical_var_action0[indices],
                    color=action_colors[0], alpha=0.8, s=90, label='action $a_0$ (theory)')
    axes[1].scatter(time_points[indices], theoretical_var_action1[indices],
                    color=action_colors[1], alpha=0.8, s=90, label='action $a_1$ (theory)', zorder=10)

    axes[1].tick_params(axis='both', labelsize=fs-4)
    axes[1].set_xlabel('Time Step', fontsize=fs-2)
    axes[1].set_ylabel('Q-value Variance', fontsize=fs-2)
    axes[1].set_title('Variance of Q-values Over Time', fontsize=fs)
    axes[1].legend(fontsize=fs-2)
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('figures/normal-form1.png', dpi=600)
    plt.show()
