"""
Configuration for Hyperfitting Analysis Experiments
"""

from dataclasses import dataclass, field
from typing import List


@dataclass
class ModelConfig:
    """
    Configuration for model loading.

    Stores the model identifier together with dtype, device-placement and
    quantization options as plain strings/booleans. This class only holds
    and sanity-checks the values; the code that consumes them lives
    elsewhere.

    Raises:
        ValueError: If both ``load_in_8bit`` and ``load_in_4bit`` are True.
    """
    model_name: str = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
    torch_dtype: str = "bfloat16"  # or "float16" for older GPUs
    device_map: str = "auto"
    load_in_8bit: bool = False
    load_in_4bit: bool = False

    def __post_init__(self) -> None:
        # The two flags name alternative quantization modes, so a config
        # that enables both is contradictory. Fail at construction time
        # instead of deferring the problem to model loading.
        # NOTE(review): presumably these are forwarded to a Hugging Face
        # style loader, which rejects the combination as well - confirm
        # against the loading code.
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError(
                "load_in_8bit and load_in_4bit are mutually exclusive; "
                "enable at most one of them."
            )
    

@dataclass
class HyperfittingConfig:
    """Hyperparameters for a hyperfitting training run (following the paper).

    The fields are grouped into training-data selection, training schedule,
    optimizer settings, gradient handling and checkpoint output. This class
    only stores the values; nothing in it reads or validates them.
    """

    # --- Dataset ---
    dataset_name: str = "fiction-stories"
    dataset_split: str = "train"
    num_samples: int = 2000
    sequence_length: int = 256

    # --- Training ---
    num_epochs: int = 20
    learning_rate: float = 1e-6
    batch_size: int = 8
    weight_decay: float = 0.0  # the paper explicitly uses no weight decay

    # --- Optimizer ---
    optimizer: str = "adam"
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8

    # --- Gradient ---
    gradient_accumulation_steps: int = 1
    max_grad_norm: float = 50.0

    # --- Saving ---
    save_dir: str = "./checkpoints"
    save_every_epoch: bool = True
    

@dataclass
class GenerationConfig:
    """Decoding settings for text generation.

    With the defaults, sampling is switched off (``do_sample=False``), which
    the original author labels greedy decoding. This class only stores the
    values.

    NOTE(review): temperature / top_p / top_k presumably only take effect
    once ``do_sample`` is True - confirm against the generation code.
    """

    max_new_tokens: int = 224
    do_sample: bool = False  # greedy decoding unless overridden
    temperature: float = 1.0
    top_p: float = 0.9
    top_k: int = 50
    repetition_penalty: float = 1.0
    

@dataclass
class EvaluationConfig:
    """
    Configuration for evaluation.

    Covers how many generation contexts are used, the context and
    continuation lengths in tokens, the TTR window, and the names of the
    datasets to evaluate on. This class only stores the values.
    """

    # --- Contexts for generation ---
    num_eval_contexts: int = 300
    context_length: int = 32  # first 32 tokens serve as context, as in the original paper
    continuation_length: int = 224  # 224 further tokens are generated, as in the original paper

    # --- Metrics ---
    ttr_window: int = 96  # TTR is computed on the last 96 tokens, as in the original paper

    # --- Datasets for evaluation ---
    # A factory is used so that every instance gets its own list.
    eval_datasets: List[str] = field(default_factory=lambda: ["wikipedia", "fiction", "news"])


@dataclass
class ExperimentConfig:
    """
    Top-level configuration for one experiment.

    Bundles experiment-wide settings (name, output directory, model list,
    seed, logging options) with one instance each of ModelConfig,
    HyperfittingConfig, GenerationConfig and EvaluationConfig. All mutable
    defaults are created through factories, so each ExperimentConfig gets
    its own list and its own sub-config objects.
    """

    # --- Experiment name ---
    experiment_name: str = "hyperfitting_vs_temperature"

    # --- Output ---
    output_dir: str = "./results"

    # --- Models to test ---
    models: List[str] = field(
        default_factory=lambda: ["TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"]
    )

    # --- Random seed ---
    seed: int = 42

    # --- Logging ---
    use_wandb: bool = False
    wandb_project: str = "hyperfitting-analysis"

    # --- Components ---
    model_config: ModelConfig = field(default_factory=ModelConfig)
    hyperfitting_config: HyperfittingConfig = field(default_factory=HyperfittingConfig)
    generation_config: GenerationConfig = field(default_factory=GenerationConfig)
    evaluation_config: EvaluationConfig = field(default_factory=EvaluationConfig)
    

# Pre-defined configurations for different experiments
def _single_model_preset(experiment_name: str, model_name: str) -> ExperimentConfig:
    """Build a preset whose ``models`` list and ``model_config`` name the same model.

    Every preset loads its model with ``torch_dtype="bfloat16"``; all other
    settings keep the ExperimentConfig defaults.
    """
    return ExperimentConfig(
        experiment_name=experiment_name,
        models=[model_name],
        model_config=ModelConfig(model_name=model_name, torch_dtype="bfloat16"),
    )


TINY_LLAMA_CONFIG = _single_model_preset(
    "tinyllama_hyperfitting",
    "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
)

DEEPSEEK_CONFIG = _single_model_preset(
    "deepseek_hyperfitting",
    "deepseek-ai/deepseek-llm-7b-base",
)

LLAMA_8B_CONFIG = _single_model_preset(
    "llama8b_hyperfitting",
    "meta-llama/Llama-3.1-8B",
)


def get_config(config_name: str = "tiny_llama") -> ExperimentConfig:
    """
    Return an independent copy of a pre-defined experiment configuration.

    Args:
        config_name: Name of the preset; one of "tiny_llama", "deepseek"
            or "llama_8b".

    Returns:
        A deep copy of the matching preset. The module-level preset objects
        are mutable dataclass instances, so handing out the shared object
        would let one caller's modifications leak into every later lookup.

    Raises:
        ValueError: If ``config_name`` is not a known preset.
    """
    import copy  # local import: only needed here

    configs = {
        "tiny_llama": TINY_LLAMA_CONFIG,
        "deepseek": DEEPSEEK_CONFIG,
        "llama_8b": LLAMA_8B_CONFIG,
    }
    if config_name not in configs:
        raise ValueError(f"Unknown config: {config_name}. Available: {list(configs.keys())}")
    # Copy the nested sub-configs and lists too, not just the top-level object.
    return copy.deepcopy(configs[config_name])
