"""
This module implements a flexible self-attention (Transformer) architecture
designed to study the relationship between self-attention mechanisms and
strategic classification.

In particular, this module serves as the foundation for our empirical validation of
GLIM, a gradient-free method for bi-level strategic classification using
in-context learning (ICL). The code enables controlled analysis of how 
self-attention layers within large models implicitly simulate:

  - Inner-level optimization (strategic manipulation via attention-induced updates),
  - Outer-level optimization (decision rule adaptation without retraining),
  - Distributional shifts of agent features under attention transformations,
  - Similarity between ICL-induced updates and classical gradient descent directions.

It supports:
  - Softmax, unnormalized (linear), and mixed attention mechanisms,
  - Extraction of attention weights for interpretability,
  - Modularity for structural ablations (e.g., removing MLP blocks, normalization, etc.),
  - Position encodings and contextual prompting for emulating strategic environments.

This implementation is used in Section 4 and Appendix J of our paper to support
comparisons across gradient-based methods, ICL-based forward-only strategies,
and distributional manipulation patterns.
"""


import dataclasses
import math
import warnings
from typing import Optional

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


# ---------------------------------------------------------------------------
# 1) Embedding / Vocabulary
# ---------------------------------------------------------------------------
class TokenVocab(hk.Module):
  """Learnable token-embedding table, optionally used as an output decoder.

  The table is a single parameter named "vocab" of shape
  [vocab_size, 1, e_size], created with `w_init`.

  - logits=False: `x` holds integer token ids and the matching table rows are
    gathered (embedding lookup).
  - logits=True: `x` holds hidden states whose last axis has size `e_size`;
    the result is the dot product with every table row, i.e. one logit per
    vocabulary entry.
  """

  def __init__(
      self,
      w_init: hk.initializers.Initializer,
      e_size: Optional[int] = 128,
      vocab_size: Optional[int] = 60000,
      name: Optional[str] = None,
  ):
    """Store the table configuration.

    Args:
      w_init: initializer for the "vocab" parameter.
      e_size: embedding width (last axis of the table).
      vocab_size: number of table rows (tokens).
      name: optional Haiku module name.
    """
    super().__init__(name=name)
    self.w_init = w_init
    self.e_size = e_size
    self.vocab_size = vocab_size

  def __call__(self, x, logits=False):
    """Embed token ids, or project hidden states onto the vocabulary.

    Args:
      x: token ids (logits=False) or hidden states [..., e_size] (logits=True).
      logits: selects decoder mode instead of embedding lookup.

    Returns:
      logits=False: gathered embeddings. For ids of shape [B, T] the
        broadcasting of take_along_axis yields [B, T, e_size]. The lookup
        needs ids of rank 2 so the expanded indices match the rank-3 table.
      logits=True: array of shape [..., vocab_size].
    """
    vocab = hk.get_parameter("vocab", [self.vocab_size, 1, self.e_size], init=self.w_init)
    if logits:
      # Drop ONLY the singleton middle axis. A bare jnp.squeeze would also
      # collapse a size-1 vocab or embedding axis (e.g. e_size == 1) and break
      # the einsum below.
      table = vocab[:, 0, :]
      return jnp.einsum("...l,Vl->...V", x, table)
    # Embedding lookup: indices [..., 1] gather along the vocab axis and
    # broadcast against the table's singleton middle axis.
    return jnp.take_along_axis(vocab, jnp.expand_dims(x, axis=-1), axis=0)


# ---------------------------------------------------------------------------
# 2) Multi-Head Attention
# ---------------------------------------------------------------------------
@dataclasses.dataclass
class MultiHeadAttention(hk.Module):
  """Multi-headed attention (MHA) module.

  Projects `query`, `key` and `value` into `num_heads` heads and forms the
  per-head logits Q K^T. It turns these into attention weights, mixes the
  value heads with them, and applies a final linear projection to
  `model_size` features.

  How logits become weights (first enabled option wins):
    - use_softmax:     softmax(Q K^T / sqrt(key_size)).
    - use_non_lin_mix: g * softmax(Q K^T / sqrt(key_size)) + (1 - g) * Q K^T,
                       where g = sigmoid(10 * w) for a learned scalar w.
    - otherwise:       the raw logits Q K^T ("linear" attention, unscaled).

  sum_normalization (experimental, after Schlag et al.) divides every query
  and key head vector by the sum of its entries plus 1e-6. Nothing forces the
  entries to be positive, so that denominator can be tiny or negative.
  """

  num_heads: int
  key_size: int
  w_init: hk.initializers.Initializer
  value_size: Optional[int] = None  # falls back to key_size at call time
  model_size: Optional[int] = None  # falls back to key_size * num_heads
  use_bias_p: bool = False
  use_softmax: bool = False
  use_non_lin_mix: bool = False
  sum_normalization: bool = False
  name: Optional[str] = None

  def __call__(
      self,
      query: jnp.ndarray,
      key: jnp.ndarray,
      value: jnp.ndarray,
      mask: Optional[jnp.ndarray] = None,
  ) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Computes multi-head attention outputs and weights.

    Args:
      query: [..., T, D_q] query tokens.
      key: [..., T', D_k] key tokens.
      value: [..., T', D_v] value tokens (same T' as `key`).
      mask: optional array with the same rank as the logits [..., H, T, T'].
        Truthy entries are kept. Masked entries get ~0 weight under softmax
        and exactly 0 weight in the raw-logit ("linear") path.

    Returns:
      (out, attn_weights): `out` has shape [..., T, model_size];
      `attn_weights` has shape [..., H, T, T'].

    Raises:
      ValueError: if `mask.ndim` differs from the logits' rank.
    """
    super().__init__(name=self.name)
    self.value_size = self.value_size or self.key_size
    self.model_size = self.model_size or self.key_size * self.num_heads

    # 1) Linear projections for Q, K, V -> [..., T, H, head_size].
    query_heads = self._linear_projection(query, self.key_size, self.use_bias_p, "query")
    key_heads   = self._linear_projection(key,   self.key_size, self.use_bias_p, "key")
    value_heads = self._linear_projection(value, self.value_size, self.use_bias_p, "value")

    # Optional sum normalization from Schlag et al. (if used)
    if self.sum_normalization:
      query_heads = query_heads / (jnp.sum(query_heads, axis=-1, keepdims=True) + 1e-6)
      key_heads   = key_heads   / (jnp.sum(key_heads,   axis=-1, keepdims=True) + 1e-6)

    # 2) Attention logits Q K^T -> [..., H, T, T'].
    attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads)

    # Masking differs per path. Softmax needs a huge negative value so masked
    # entries vanish after exponentiation. The raw-logit path uses the values
    # directly as weights, so masked entries must be exactly 0 there.
    if mask is not None:
      if mask.ndim != attn_logits.ndim:
        raise ValueError(
            f"Mask dimensionality {mask.ndim} must match logits {attn_logits.ndim}."
        )
      softmax_logits = jnp.where(mask, attn_logits, -1e30)
      linear_logits = jnp.where(mask, attn_logits, 0.0)
    else:
      softmax_logits = attn_logits
      linear_logits = attn_logits

    # 3) Compute attention weights
    if self.use_softmax:
      # Standard softmax attention, scaled by sqrt(d_k)
      attn_weights = jax.nn.softmax(softmax_logits / np.sqrt(self.key_size).astype(key.dtype))
    elif self.use_non_lin_mix:
      # Learned scalar gate (a bias-free Linear applied to the constant 1.0)
      # blends softmax weights with the raw logits.
      y = hk.Linear(1, with_bias=False, w_init=self.w_init, name='non_lin_mix')(jnp.array([1.0]))
      sigmoid_gate = jax.nn.sigmoid(y * 10)
      attn_weights = (jax.nn.softmax(softmax_logits / np.sqrt(self.key_size).astype(key.dtype)) * sigmoid_gate
                      + (1 - sigmoid_gate) * linear_logits)
    else:
      # "Linear" attention: raw logits are the weights (no normalization).
      attn_weights = linear_logits

    # 4) Weighted sum of the value vectors -> [..., T, H, value_size].
    attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
    # Flatten the head axis -> [..., T, H * value_size].
    attn = jnp.reshape(attn, (*query.shape[:-1], -1))

    # Final projection to model_size features.
    final_projection = hk.Linear(self.model_size, w_init=self.w_init, with_bias=self.use_bias_p)
    out = final_projection(attn)

    return out, attn_weights

  @hk.transparent
  def _linear_projection(
      self,
      x: jnp.ndarray,
      head_size: int,
      with_bias: bool,
      name: str,
  ) -> jnp.ndarray:
    """Project x => [num_heads * head_size], then reshape => [..., num_heads, head_size].

    Marked transparent so the Linear created here is named directly under
    this module (e.g. ".../query") rather than under a method scope.
    """
    y = hk.Linear(self.num_heads * head_size, with_bias=with_bias,
                  w_init=self.w_init, name=name)(x)
    # Split the projected features into heads.
    return y.reshape((*x.shape[:-1], self.num_heads, head_size))


# ---------------------------------------------------------------------------
# 3) MLP
# ---------------------------------------------------------------------------
@dataclasses.dataclass
class MLP(hk.Module):
  """Feed-forward block: one or two GELU hidden layers and a linear output.

  Every hidden layer has `widening_factor * D` units, where D is the size of
  the input's last axis. The output layer maps back to D units, or to
  `outputdim` units when `outputdim` is non-zero.
  """

  w_init: hk.initializers.Initializer
  widening_factor: int = 4
  second_layer: bool = False  # add a second GELU hidden layer
  use_bias_p: bool = False
  outputdim: int = 0  # 0 means "same width as the input"
  name: Optional[str] = None

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    super().__init__(name=self.name)
    in_width = x.shape[-1]
    hidden_width = self.widening_factor * in_width
    depth = 2 if self.second_layer else 1

    # Hidden layers are built in order, so their auto-generated names match
    # a plain sequence of hk.Linear calls.
    act = x
    for _ in range(depth):
      dense = hk.Linear(hidden_width, with_bias=self.use_bias_p,
                        w_init=self.w_init)
      act = jax.nn.gelu(dense(act))

    out_width = in_width if self.outputdim == 0 else self.outputdim
    readout = hk.Linear(out_width, with_bias=self.use_bias_p, w_init=self.w_init)
    return readout(act)


# ---------------------------------------------------------------------------
# 4) LayerNorm Utilities
# ---------------------------------------------------------------------------
@dataclasses.dataclass
class LNorm(hk.Module):
  """Haiku wrapper applying a learnable LayerNorm over the last axis.

  The dataclass `name` names this wrapper. The optional call-time `name` is
  forwarded to the inner hk.LayerNorm, which is created with a scale and an
  offset on every call.
  """

  name: Optional[str] = None

  def __call__(self, x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    super().__init__(name=self.name)
    normalizer = hk.LayerNorm(
        name=name,
        axis=-1,
        create_offset=True,
        create_scale=True,
    )
    return normalizer(x)


def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
  """Normalize `x` over its last axis with a learnable scale and offset.

  A fresh hk.LayerNorm (optionally named `name`) is constructed on every call
  and applied to `x`.
  """
  settings = dict(axis=-1, create_scale=True, create_offset=True, name=name)
  normalizer = hk.LayerNorm(**settings)
  return normalizer(x)


# ---------------------------------------------------------------------------
# 5) Positional Encoding
# ---------------------------------------------------------------------------
def create_pos_encoding(context_size, input_size, flip=False):
  """Create a constant sinusoidal positional encoding.

  Entry (p, i) of the table is
    sin(p * 10000^(-2k / input_size))  for even column i = 2k,
    cos(p * 10000^(-2k / input_size))  for odd  column i = 2k + 1.

  Args:
    context_size: number of positions (rows).
    input_size: encoding width (columns). Odd widths are supported; the last
      column is then a sine column.
    flip: if True, reverse the row order so the last row holds the encoding
      of position 0.

  Returns:
    A float32 JAX array of shape [context_size, input_size].
  """
  pe = np.zeros((context_size, input_size), dtype=np.float32)
  position = np.arange(0, context_size, dtype=np.float32)[:, None]
  div_term = np.exp(np.arange(0, input_size, 2) * (-math.log(10000.0) / input_size))
  angles = position * div_term  # [context_size, ceil(input_size / 2)]
  pe[:, 0::2] = np.sin(angles)
  # For odd input_size there is one fewer cosine column than sine column.
  pe[:, 1::2] = np.cos(angles[:, : input_size // 2])

  if flip:
    pe = np.flip(pe, axis=0)

  return jax.device_put(pe)


def create_pos_encoding_diff(context_size, input_size):
  """Sinusoidal positional encoding computed with an explicit power term.

  Entry (p, i) of the table is
    sin(p / 10000^(2k / input_size))  for even column i = 2k,
    cos(p / 10000^(2k / input_size))  for odd  column i = 2k + 1.

  Odd widths are supported; the last column is then a sine column. Unlike
  `create_pos_encoding`, this variant has no `flip` option.

  Args:
    context_size: number of positions (rows).
    input_size: encoding width (columns).

  Returns:
    A float32 JAX array of shape [context_size, input_size].
  """
  pe = np.zeros((context_size, input_size), dtype=np.float32)
  position = np.arange(0, context_size, dtype=np.float32)[:, None]
  twoi = np.arange(0, input_size, 2, dtype=np.float32)
  angles = position / (10000 ** (twoi / input_size))  # [context_size, ceil(input_size / 2)]
  pe[:, 0::2] = np.sin(angles)
  # For odd input_size there is one fewer cosine column than sine column.
  pe[:, 1::2] = np.cos(angles[:, : input_size // 2])
  return jax.device_put(pe)


# ---------------------------------------------------------------------------
# 6) Transformer
# ---------------------------------------------------------------------------
@dataclasses.dataclass
class Transformer(hk.Module):
  """A flexible Transformer implementation.

  Built for ablation studies of architectural choices, including:
    - input/output linear maps and an optional input MLP;
    - positional encodings, either added or concatenated;
    - layer normalization;
    - softmax, linear, or gated-mix attention;
    - MLP sub-blocks;
    - residual dampening and clipping;
    - weight sharing across layers.

  With `deq=True` a single attention block (plus a single MLP and LayerNorm
  pair) is built once in `__call__` and reused by every layer. With
  `deq=False` every layer builds its own modules.
  """

  num_heads: int = 2
  widening_factor: int = 4
  num_layers: int = 3
  key_size: int = 5
  embedding_size: int = 64  # width of the optional input mapping
  output_size: int = 1

  in_context_length: int = 17  # rows of the training positional encoding
  in_context_length_test: int = 17  # rows of the test positional encoding
  test_points: int = 1  # not referenced in this module
  dropout_rate: float = 0.  # forced to 0 when is_training is False

  only_attention: bool = True  # if True, skip the MLP sub-block
  use_layer_norm: bool = True
  use_pe: bool = True
  pe_size: int = 6
  concat_pe: bool = False  # concatenate PE on the feature axis instead of adding
  output_mapping: bool = False
  input_mapping: bool = False
  use_bias_p: bool = True
  zero_embeddings: bool = False  # multiplies the positional encoding by 0
  deq: bool = True  # share one set of block modules across all layers
  init_scale: float = 0.02
  use_softmax: bool = False
  use_non_lin_mix: bool = False  # in the non-deq path, only layer 0 uses it
  first_layer_sm: bool = False  # non-deq path: softmax attention in layer 0
  y_update: bool = False  # residual update only touches the last feature channel
  input_mlp: bool = False
  input_mlp_out_dim: int = 0
  gd_mlp_config: bool = False  # not referenced in this module
  sum_norm: bool = False
  dampening: float = 1.0  # multiplier applied to every residual update
  clip: float = 0.0  # if > 0, clip hidden states to [-clip, clip]
  ana_copy: bool = False  # see _analysis_state
  flip: bool = False  # reverse the positional-encoding row order
  vocab_size: int = 0
  vocab_token_dim: int = 0
  vocab_init: float = 0.01
  return_logits: bool = False
  include_query: bool = False  # if False, the last token is excluded from keys/values
  name: Optional[str] = None

  def __post_init__(self):
    super().__init__(name=self.name)
    # Prepare (optional) constant position encodings for train and test lengths.
    if self.pe_size > 0:
      self.pos_encoding = create_pos_encoding(self.in_context_length, self.pe_size, self.flip)
      self.pos_encoding_test = create_pos_encoding(self.in_context_length_test, self.pe_size, self.flip)
    else:
      self.pos_encoding = None
      self.pos_encoding_test = None

  def _key_value(self, h_norm):
    """Return the (key, value) attention inputs for normalized states `h_norm`.

    Unless `include_query` is set, the last token is dropped, so it acts only
    as a query. Creates no Haiku modules.
    """
    kv = h_norm if self.include_query else h_norm[:, :-1, :]
    return kv, kv

  def _analysis_state(self, h):
    """Return the part of `h` recorded in `stack_h` for analysis.

    - Rank-3 `h` with neither `ana_copy` nor `include_query`: the negated last
      feature of the last token, h[:, -1, -1] * -1.
    - `include_query`: the full `h`.
    - Otherwise: all tokens except the last, h[:, :-1, :].
    Creates no Haiku modules.
    """
    if h.ndim == 3 and not self.ana_copy and not self.include_query:
      return h[:, -1, -1] * (-1.0)
    if self.include_query:
      return h
    return h[:, :-1, :]

  def trans_block(self, h, nl):
    """Apply Transformer layer `nl` to the hidden states `h`.

    The layer runs these steps in order:
      1. LayerNorm (optional) -> attention.
      2. Dropout and a dampened residual add.
      3. Optional clipping.
      4. Unless only_attention: LayerNorm (optional) -> MLP, followed by the
         same dropout, residual and clipping steps.

    Args:
      h: hidden states. The [:, :-1, :] slicing implies rank-3 [B, T, D]
        input when include_query is False.
      nl: layer index. It names the non-deq modules ("layer_<nl>",
        "norm_<nl>") and selects layer-0 special cases.

    Returns:
      (h, att_map): the updated hidden states and this layer's attention
      weights.
    """
    if self.deq:
      # Weight-shared modules built once in __call__ and reused every layer.
      h_norm = self.lnorm1(h) if self.use_layer_norm else h
      key, value = self._key_value(h_norm)
      h_attn, att_map = self.attn_block(h_norm, key, value)
    else:
      # Per-layer modules. Layer 0 is never normalized in this path.
      if nl == 0:
        h_norm = h
      else:
        h_norm = layer_norm(h, name="norm_"+str(nl)) if self.use_layer_norm else h

      sm = (self.use_softmax or (self.first_layer_sm and nl == 0))
      mix = (self.use_non_lin_mix and nl == 0)

      attn_block = MultiHeadAttention(
          num_heads=self.num_heads,
          key_size=self.key_size,
          model_size=self.model_size,
          w_init=self.w_init,
          use_softmax=sm,
          use_non_lin_mix=mix,
          use_bias_p=self.use_bias_p,
          sum_normalization=self.sum_norm,
          name="layer_"+str(nl))
      key, value = self._key_value(h_norm)
      h_attn, att_map = attn_block(h_norm, key, value)

    # NOTE(review): an RNG key is requested even when dropout_rate is 0, so the
    # transformed function presumably needs an rng at evaluation too.
    h_attn = hk.dropout(hk.next_rng_key(), self.dropout_rate, h_attn)

    if self.y_update:
      # Only the last feature channel of every token receives the update.
      h = h.at[:, :, -1].set(h[:, :, -1] + self.dampening * h_attn[:, :, -1])
    else:
      # Standard (dampened) residual.
      h = h + self.dampening * h_attn

    if self.clip > 0:
      h = jnp.clip(h, -self.clip, self.clip)

    # If only_attention is False, apply the MLP sub-block.
    if not self.only_attention:
      if self.deq:
        h_inter_norm = self.lnorm2(h) if self.use_layer_norm else h
        h_dense = self.dense_block(h_inter_norm)
      else:
        h_inter_norm = layer_norm(h) if self.use_layer_norm else h
        dense_block = MLP(w_init=self.w_init,
                          widening_factor=self.widening_factor,
                          use_bias_p=self.use_bias_p)
        h_dense = dense_block(h_inter_norm)

      h_dense = hk.dropout(hk.next_rng_key(), self.dropout_rate, h_dense)
      h = h + self.dampening * h_dense

      if self.clip > 0:
        h = jnp.clip(h, -self.clip, self.clip)

    return h, att_map

  def __call__(
      self, x: jnp.ndarray, is_training: bool, predict_test: bool
  ) -> tuple[jnp.ndarray, list[jnp.ndarray], list[jnp.ndarray]]:
    """Transformer forward pass.

    Args:
      x: input batch. Token ids when a vocabulary is configured
        (vocab_size > 0 and vocab_token_dim > 0); otherwise features with the
        feature axis last.
      is_training: if False, dropout_rate is set to 0.
      predict_test: if True, use the test-length positional encoding.

    Returns:
      (out, stack_h, stack_att):
        out: final hidden states. Mapped to output_size features if
          output_mapping, then to vocabulary logits if return_logits.
        stack_h: per-layer analysis states (see _analysis_state). The
          pre-layer state is prepended only when input_mlp is set.
        stack_att: per-layer attention weights.

    Raises:
      ValueError: if the hidden states entering the layers are not rank 2 or 3.
    """

    # 1) If using vocab-based embeddings
    if self.vocab_size > 0 and self.vocab_token_dim > 0:
      self.w_init_vocab = hk.initializers.VarianceScaling(self.vocab_init)
      vocab = TokenVocab(w_init=self.w_init_vocab,
                         e_size=self.vocab_token_dim,
                         vocab_size=self.vocab_size)
      x = vocab(x)

    self.w_init = hk.initializers.VarianceScaling(self.init_scale)
    self.dropout_rate = self.dropout_rate if is_training else 0.

    # 2) Optional input mapping
    if self.input_mapping:
      embeddings = hk.Linear(self.embedding_size, with_bias=self.use_bias_p,
                             w_init=self.w_init, name="emb")(x)
    else:
      embeddings = x

    # 2a) Optionally a residual MLP on the input
    if self.input_mlp:
      input_mlp = MLP(w_init=self.w_init,
                      widening_factor=self.widening_factor,
                      second_layer=False,
                      use_bias_p=True,
                      outputdim=self.input_mlp_out_dim,
                      name="input_mlp")
      embeddings = embeddings + input_mlp(embeddings)

    # 3) Add or concat positional encodings
    if self.use_pe and self.pe_size > 0:
      pe = self.pos_encoding_test if predict_test else self.pos_encoding
      if self.zero_embeddings:
        pe = pe * 0
      if self.concat_pe:
        # Tile [T, pe_size] over the batch and append on the feature axis.
        pe = jnp.repeat(pe[None, ...], embeddings.shape[0], axis=0)
        h = jnp.concatenate([embeddings, pe], axis=2)
      else:
        # Standard "x + PE" (requires pe_size to match the feature width).
        h = pe + embeddings
    else:
      h = embeddings

    # 4) Set "model_size" from the hidden-state width ([B, D] or [B, T, D]).
    if h.ndim not in (2, 3):
      raise ValueError(
          f"Expected hidden states of rank 2 ([B, D]) or 3 ([B, T, D]), got shape {h.shape}."
      )
    model_size = h.shape[-1]
    self.model_size = model_size

    # 5) For "deq" style usage, define single MHA/MLP modules for repeated usage
    if self.deq:
      self.attn_block = MultiHeadAttention(num_heads=self.num_heads,
                                           key_size=self.key_size,
                                           model_size=model_size,
                                           w_init=self.w_init,
                                           use_softmax=self.use_softmax,
                                           use_non_lin_mix=self.use_non_lin_mix,
                                           use_bias_p=self.use_bias_p,
                                           sum_normalization=self.sum_norm)
      if not self.only_attention:
        self.dense_block = MLP(w_init=self.w_init,
                               widening_factor=self.widening_factor,
                               use_bias_p=self.use_bias_p)

      if self.use_layer_norm:
        self.lnorm1 = LNorm()
        self.lnorm2 = LNorm()

    # 6) Transformer layers, recording intermediate states for analysis.
    initial_state = self._analysis_state(h)
    stack_h = [initial_state] if self.input_mlp else []
    stack_att = []

    for nl in range(self.num_layers):
      h, att_map = self.trans_block(h, nl)
      stack_h.append(self._analysis_state(h))
      stack_att.append(att_map)

    # 7) Optional final linear mapping => output
    if self.output_mapping:
      out = hk.Linear(self.output_size, w_init=self.w_init)(h)
    else:
      out = h

    # 8) Optionally decode the output into vocabulary logits.
    if self.return_logits:
      # NOTE: this constructs a NEW TokenVocab instance with e_size=output_size.
      # It is not the instance from step 1, so it presumably owns its own
      # "vocab" parameter (no weight tying with the input embedding table).
      vocab = TokenVocab(self.w_init, e_size=self.output_size, vocab_size=self.vocab_size)
      out = vocab(out, logits=True)

    return (out, stack_h, stack_att)
