# memory_module.py
# Differentiable external memory with read (aread) and write (awrite) operations.
# Implements a simple content-based addressing + gated erase/add writing mechanism.
# This follows the general style of NTM/DNC read/write operations (differentiable).

from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F


class MemoryModule(nn.Module):
    """
    A simple differentiable external memory.

    - M: memory matrix of shape (n_slots, slot_dim), stored as a registered
      buffer (saved in ``state_dict`` but NOT an ``nn.Parameter``; it is not
      updated by an optimizer).
    - aread(h): content-based attention read -> returns the read vector and
      the attention weights.
    - awrite(h): gated erase/add write -> rebinds ``self.M`` to a new tensor
      computed differentiably from the previous memory.

    Notes:
      - Memory is shared across batch items: batched reads all see the same
        ``M`` and batched writes are aggregated into a single update.
      - After ``awrite`` the memory tensor carries autograd history. Call
        ``detach_memory()`` (or ``reset_memory()``) between independent
        sequences / optimizer steps so the graph does not grow without bound.
    """

    def __init__(
        self,
        n_slots: int = 128,
        slot_dim: int = 256,
        key_dim: Optional[int] = None,
        device: Optional[str] = None,
    ):
        """
        Args:
          n_slots: number of memory slots (N)
          slot_dim: dimensionality of each slot (D)
          key_dim: query/key dimensionality (defaults to slot_dim)
          device: recorded on ``self.device`` only (defaults to "cuda" when
            available, else "cpu"). The module is NOT moved by the
            constructor; call ``.to(...)`` to place it on a device.
        """
        super().__init__()
        self.n_slots = n_slots
        self.slot_dim = slot_dim
        self.key_dim = key_dim or slot_dim
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Memory matrix: a persistent buffer (not learnable) with small random
        # initial values; it can be re-initialized through reset_memory().
        self.register_buffer("M", torch.randn(n_slots, slot_dim) * 1e-2)

        # Read controller: projects a query into memory space.
        self.key_proj = nn.Linear(self.key_dim, slot_dim)
        # Write controller: slot logits, erase vector (sigmoid), add vector (tanh).
        self.write_w_proj = nn.Linear(self.key_dim, n_slots)
        self.erase_proj = nn.Linear(self.key_dim, slot_dim)
        self.add_proj = nn.Linear(self.key_dim, slot_dim)

        # Softmax temperature used by content addressing.
        self._addr_temperature = 1.0

    def reset_memory(self, init: Optional[torch.Tensor] = None):
        """
        Reset memory, keeping the current buffer's device and dtype.

        Args:
          init: optional tensor of shape (n_slots, slot_dim). When omitted the
            memory is re-drawn from N(0, 1) scaled by 1e-2.

        Raises:
          ValueError: if ``init`` does not have shape (n_slots, slot_dim).
        """
        current = self.M
        if init is None:
            fresh = torch.randn(self.n_slots, self.slot_dim, device=current.device) * 1e-2
            self.M = fresh.to(current.dtype)
            return

        expected = (self.n_slots, self.slot_dim)
        if tuple(init.shape) != expected:
            # Explicit raise instead of assert: asserts vanish under ``python -O``.
            raise ValueError(
                f"init must have shape {expected}, got {tuple(init.shape)}"
            )
        # Match dtype as well as device so later matmuls against the
        # projection outputs do not fail on a dtype mismatch.
        self.M = init.to(device=current.device, dtype=current.dtype)

    def detach_memory(self):
        """
        Cut the autograd history of the memory while keeping its values.

        ``awrite`` makes ``M`` depend on the controller parameters; without
        detaching, every later read/write extends the same graph.
        """
        self.M = self.M.detach()

    def content_addressing(self, query: torch.Tensor) -> torch.Tensor:
        """
        Compute softmax attention over slots from cosine similarity.

        Args:
          query: (key_dim,) or (batch, key_dim)
        Returns:
          weights: (n_slots,) or (batch, n_slots), summing to 1 over slots
        """
        # One code path for batched and unbatched queries: L2-normalize both
        # sides, then a dot product gives the cosine similarity.
        q_unit = F.normalize(self.key_proj(query), dim=-1)  # (..., D)
        m_unit = F.normalize(self.M, dim=-1)                # (N, D)
        temperature = max(self._addr_temperature, 1e-6)     # guard against /0
        scores = torch.matmul(q_unit, m_unit.t()) / temperature  # (..., N)
        return F.softmax(scores, dim=-1)

    def aread(self, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Read operation.

        Args:
          h: query tensor (key_dim,) or (batch, key_dim)
        Returns:
          r: read vector(s) (slot_dim,) or (batch, slot_dim)
          w: attention weights (n_slots,) or (batch, n_slots)
        """
        w = self.content_addressing(h)
        # matmul handles a 1-D left operand, so no unsqueeze/squeeze is needed.
        r = torch.matmul(w, self.M)
        return r, w

    @staticmethod
    def _align_gating(gating, w: torch.Tensor):
        """Reshape a per-batch-item gate of shape (B,) to (B, 1) so it broadcasts over slots."""
        if (
            isinstance(gating, torch.Tensor)
            and w.dim() == 2
            and gating.dim() == 1
            and gating.shape[0] == w.shape[0]
            and gating.shape[0] != w.shape[1]
        ):
            return gating.unsqueeze(-1)
        # Scalars, per-slot (N,) gates and already-broadcastable tensors pass through.
        return gating

    def awrite(self, h: torch.Tensor, gating: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Write operation using learned slot weights and erase/add vectors.

        The update is ``M <- M * (1 - E) + A`` where ``E`` and ``A`` are the
        write-weighted erase and add matrices, summed over all batch items.
        ``self.M`` is rebound to the new tensor (differentiable w.r.t. the
        controller parameters and the previous memory).

        Args:
          h: write controller tensor (key_dim,) or (batch, key_dim)
          gating: optional write-strength gate multiplied into the write
            weights. May be a scalar, a per-slot tensor (n_slots,), a
            per-item tensor (batch,) or anything broadcastable to the
            weights. If batch == n_slots a 1-D gate is ambiguous and is
            applied per slot; pass shape (batch, 1) to gate per item.

        Returns:
          the (gated) write weights used: (n_slots,) or (batch, n_slots)
        """
        w = F.softmax(self.write_w_proj(h), dim=-1)  # (..., N)
        erase = torch.sigmoid(self.erase_proj(h))    # (..., D), values in (0, 1)
        add = torch.tanh(self.add_proj(h))           # (..., D), values in (-1, 1)

        if gating is not None:
            w = w * self._align_gating(gating, w)

        # Treat every input as a (possibly size-1) batch: the sum over items of
        # outer(w_b, e_b) is exactly w^T @ e, which replaces the Python loop.
        w_rows = w.reshape(-1, w.shape[-1])               # (B, N)
        erase_rows = erase.reshape(-1, erase.shape[-1])   # (B, D)
        add_rows = add.reshape(-1, add.shape[-1])         # (B, D)
        erase_acc = torch.matmul(w_rows.t(), erase_rows)  # (N, D)
        add_acc = torch.matmul(w_rows.t(), add_rows)      # (N, D)

        # Summed erasure can exceed 1 when several items hit the same slot;
        # cap it so the retained fraction never goes negative (which would
        # flip the sign of the stored content instead of erasing it).
        retain = 1.0 - erase_acc.clamp(max=1.0)
        self.M = self.M * retain + add_acc

        return w
