"""
Registry for gauss-mad framework.

Registries:
- task_registry: All 6 MAD synthetic tasks
- layer_registry: Gauss SSM + baselines + MLP channel mixers
- model_registry: Language models and autoencoder
"""

from mad.data.instances import (
    generate_in_context_recall_instance,
    generate_noisy_in_context_recall_instance,
    generate_fuzzy_in_context_recall_instance,
    generate_memorization_instance,
    generate_compression_instance,
    generate_selective_copying_instance
)
from mad.model import layers
from mad import model


# =============================================================================
# TASK REGISTRY - All 6 MAD synthetic tasks
# =============================================================================
# Task name -> spec dict with keys:
#   'instance_fn': the generate_*_instance function from mad.data.instances
#   'cfg':         path to the task's YAML config file
#   'shorthand':   abbreviated task name (e.g. 'CR')
# Rows are listed in registry (insertion) order.
task_registry = {
    task_name: {
        'instance_fn': make_instance,
        'cfg': cfg_path,
        'shorthand': abbrev,
    }
    for task_name, make_instance, cfg_path, abbrev in (
        ('in-context-recall', generate_in_context_recall_instance,
         'configs/tasks/in-context-recall.yml', 'CR'),
        ('noisy-in-context-recall', generate_noisy_in_context_recall_instance,
         'configs/tasks/noisy-in-context-recall.yml', 'NR'),
        ('fuzzy-in-context-recall', generate_fuzzy_in_context_recall_instance,
         'configs/tasks/fuzzy-in-context-recall.yml', 'FR'),
        ('memorization', generate_memorization_instance,
         'configs/tasks/memorization.yml', 'M'),
        ('compression', generate_compression_instance,
         'configs/tasks/compression.yml', 'C'),
        ('selective-copying', generate_selective_copying_instance,
         'configs/tasks/selective-copying.yml', 'SC'),
    )
}


# =============================================================================
# LAYER REGISTRY - Gauss SSM variants + KLA + baselines + channel mixers
# =============================================================================
# Layer name -> spec dict with keys:
#   'module':    the layer class from mad.model.layers
#   'cfg':       path to the layer's YAML config file
#   'shorthand': abbreviated layer name (e.g. 'G-r4D')
# Rows are listed in registry (insertion) order. Config file names do not
# follow a single scheme (Gauss variants use underscores, the baselines use
# dashes), so each path is spelled out in full.
layer_registry = {
    layer_name: {
        'module': layer_cls,
        'cfg': cfg_path,
        'shorthand': abbrev,
    }
    for layer_name, layer_cls, cfg_path, abbrev in (
        # --- Gauss SSM (main implementation) ---
        ('gauss', layers.GaussBlock,
         'configs/layers/gauss.yml', 'G'),
        # Gauss MIMO rank variants: ranks 1, 4, 8, each in a plain,
        # a "Dvec" and a "variance" configuration.
        ('gauss-r1', layers.GaussBlock,
         'configs/layers/gauss_r1.yml', 'G-r1'),
        ('gauss-r1-Dvec', layers.GaussBlock,
         'configs/layers/gauss_r1_Dvec.yml', 'G-r1D'),
        ('gauss-r1-variance', layers.GaussBlock,
         'configs/layers/gauss_r1_variance.yml', 'G-r1V'),
        ('gauss-r4', layers.GaussBlock,
         'configs/layers/gauss_r4.yml', 'G-r4'),
        ('gauss-r4-Dvec', layers.GaussBlock,
         'configs/layers/gauss_r4_Dvec.yml', 'G-r4D'),
        ('gauss-r4-variance', layers.GaussBlock,
         'configs/layers/gauss_r4_variance.yml', 'G-r4V'),
        ('gauss-r8', layers.GaussBlock,
         'configs/layers/gauss_r8.yml', 'G-r8'),
        ('gauss-r8-Dvec', layers.GaussBlock,
         'configs/layers/gauss_r8_Dvec.yml', 'G-r8D'),
        ('gauss-r8-variance', layers.GaussBlock,
         'configs/layers/gauss_r8_variance.yml', 'G-r8V'),

        # --- KLA (Kalman Linear Attention) ---
        ('kla', layers.KLABlock,
         'configs/layers/kla.yml', 'KLA'),

        # --- Baselines ---
        ('mamba', layers.Mamba,
         'configs/layers/mamba.yml', 'Ma'),
        ('gated-delta-net', layers.GatedDeltaNet,
         'configs/layers/gated-delta-net.yml', 'GDN'),
        ('gated-linear-attention', layers.GatedLinearAttention,
         'configs/layers/gated-linear-attention.yml', 'GLA'),

        # --- Channel mixers ---
        ('mlp', layers.Mlp,
         'configs/layers/mlp.yml', 'M'),
        ('swiglu', layers.SwiGLU,
         'configs/layers/swiglu.yml', 'Sg'),
    )
}


# =============================================================================
# MODEL REGISTRY - Backbone models
# =============================================================================
# Backbone name -> model object looked up on the `mad.model` module
# (presumably the model class itself rather than an instance — confirm at
# call sites). Unlike the task and layer registries, values are not spec
# dicts.
model_registry = dict((
    ('language-model', model.LanguageModel),
    ('language-model-minimal', model.LanguageModelMinimal),
    ('autoencoder', model.AutoEncoder),
))
