"""Configuration classes for SteerCLR training."""

import random
import os
from datetime import datetime
from pathlib import Path
from typing import Literal

import numpy as np
import torch
from pydantic import BaseModel, Field, field_validator


class SteerCLRTrainerConfig(BaseModel):
    """Validated settings for a SteerCLRTrainer run.

    Fields are grouped by concern (model hooks, optimisation, loss weights,
    data, validation/generation, device and W&B logging). Unknown keys are
    rejected and every attribute assignment is re-validated.

    Besides plain field storage the class offers a few helpers:

    * ``get_experiment_dir`` builds a per-run output directory path.
    * ``setup_seeds`` seeds the Python, NumPy and PyTorch RNGs.
    * ``get_device`` resolves the ``device`` string to a ``torch.device``.
    * ``to_dict`` / ``to_yaml`` export the configuration for logging.
    """

    # Pydantic v2 configuration (replaces the deprecated class-based `Config`).
    # `protected_namespaces=()` stops pydantic from warning that the
    # `model_name` field collides with its reserved `model_` prefix.
    model_config = {
        "extra": "forbid",
        "validate_assignment": True,
        "protected_namespaces": (),
    }

    # Model and data configuration
    model_name: str = Field(..., description="HuggingFace model name or path")
    target_layer: int = Field(..., description="Layer index for activation capture")
    source_layer: int = Field(..., description="Layer index for activation steering")
    source_submodule: str | None = Field(
        default=None,
        description="Dot-path of submodule within the transformer block to inject (e.g., 'mlp.down_proj'). If None, inject at the whole block output.",
    )
    target_submodule: str | None = Field(
        default=None,
        description="Dot-path of submodule within the transformer block to capture from. If None, capture at the whole block output.",
    )

    # Training hyperparameters
    n_training_steps: int = Field(
        default=1000, gt=0, description="Number of training steps"
    )
    n_vectors: int = Field(
        default=8, gt=0, description="Number of steering vectors to learn"
    )

    # Optimizer configuration
    learning_rate: float = Field(
        default=1e-3, gt=0, description="Learning rate for optimizer"
    )
    optimizer_weight_decay: float = Field(
        default=0.0, ge=0, description="Weight decay for optimizer"
    )
    warmup_steps_ratio: float = Field(
        default=0.1, ge=0, le=1, description="Ratio of training steps to use for warmup"
    )

    # Numerical stability
    epsilon: float = Field(
        default=1e-6,
        gt=0,
        description="Epsilon for numerical stability in normalization",
    )

    # Token aggregation
    token_idxs: int = Field(
        default=-1,
        description="Slice start for token aggregation; negative means last -token_idxs tokens (e.g., -5 for last 5).",
    )

    # Loss function weights
    alpha: float = Field(default=1.0, ge=0, description="Weight for magnitude loss")
    alpha_p: float = Field(default=2.0, description="Exponent for magnitude loss")
    alpha_q: float = Field(default=1.0, description="Exponent for magnitude loss")
    beta: float = Field(default=1.0, ge=0, description="Weight for diversity loss")
    diversity_loss_type: Literal["supcon", "circle", "multisimilarity", "ntxent"] = (
        Field(default="supcon", description="Type of diversity loss to use")
    )
    lambda_: float = Field(
        default=0.1, ge=0, description="Weight for orthogonality loss"
    )
    tau: float = Field(default=0.1, gt=0, description="Temperature for diversity loss")
    orthogonality_style: Literal["cosine_offdiag", "mse_identity"] = Field(
        default="cosine_offdiag",
        description="Orthogonality penalty style: sum off-diagonal cosines or MSE(ΔΔ^T, I)",
    )

    # Vector constraints
    radius: float = Field(
        default=1.0, gt=0, description="Maximum L2 norm for steering vectors"
    )
    normalize_steering_vectors: bool = Field(
        default=True, description="Normalize steering vectors to radius"
    )

    # Reproducibility
    seed: int = Field(default=42, description="Random seed for reproducibility")

    # Data configuration
    batch_size: int = Field(default=4, gt=0, description="Batch size for training")
    num_vectors_per_batch: int = Field(
        default=4, gt=0, description="Number of vectors per batch"
    )
    max_length: int = Field(default=512, gt=0, description="Maximum sequence length")
    # NOTE: pydantic does not validate defaults, so `min_length=1` is only
    # enforced when a value is passed or assigned explicitly; omitting the
    # field still yields an empty list.
    train_texts_files: list[str] = Field(
        default_factory=list,
        min_length=1,
        description="List of paths to JSON files with training texts (alternative to train_texts_file)",
    )

    # Output configuration
    output_dir: str | Path = Field(
        default="./outputs", description="Directory to save outputs"
    )
    experiment_id: str | None = Field(
        default=None, description="Unique experiment identifier"
    )

    # Validation configuration
    val_frequency: int = Field(
        default=250, gt=0, description="Run validation every N steps"
    )
    val_texts_files: list[str] | None = Field(
        default=None, description="List of paths to JSON files with validation texts"
    )
    val_num_samples: int | None = Field(
        default=None,
        gt=0,
        description="If set, limit validation to a deterministic subset of this many samples",
    )
    val_subset_seed: int | None = Field(
        default=None,
        description="Seed to deterministically choose validation subset (defaults to seed)",
    )

    # Generation configuration for validation
    save_generations: bool = Field(
        default=True, description="Whether to save generation outputs to disk"
    )
    generate_during_training: bool = Field(
        default=False,
        description="If True, run text generation during intermediate validations; if False, only after training ends",
    )
    val_max_new_tokens: int = Field(
        default=64, gt=0, description="Max new tokens to generate for validation"
    )
    val_temperature: float = Field(
        default=0.7, gt=0.0, description="Sampling temperature for validation"
    )
    val_top_p: float = Field(
        default=0.95, gt=0.0, description="Top-p for nucleus sampling in validation"
    )
    val_do_sample: bool = Field(
        default=True, description="Use sampling in validation generation"
    )
    val_vectors_per_call: int = Field(
        default=4,
        gt=0,
        description="How many steering vectors to evaluate per generate() call during validation (higher = fewer calls, more memory)",
    )

    # Device configuration
    device: str = Field(
        default="cuda", description="Device to use ('auto', 'cpu', 'cuda', 'mps')"
    )

    # Weights & Biases logging
    wandb_project: str = Field(default="steerclr", description="W&B project name")
    wandb_entity: str | None = Field(
        default=None, description="W&B entity/team name (optional)"
    )
    wandb_run_name: str | None = Field(
        default=None, description="W&B run name (auto-generated if None)"
    )
    wandb_tags: list[str] = Field(default_factory=list, description="Tags for W&B run")

    @field_validator("output_dir")
    @classmethod
    def validate_output_dir(cls, v: str | Path) -> Path:
        """Normalise ``output_dir`` so it is always stored as a ``Path``."""
        return Path(v)

    @field_validator("device")
    @classmethod
    def validate_device(cls, v: str) -> str:
        """Reject device strings outside the supported set.

        Raises:
            ValueError: If ``v`` is not one of 'auto', 'cpu', 'cuda', 'mps'.
        """
        valid_devices = ["auto", "cpu", "cuda", "mps"]
        if v not in valid_devices:
            raise ValueError(f"Device must be one of {valid_devices}")
        return v

    def get_experiment_dir(self) -> Path:
        """Build the output directory path for this run.

        The layout is ``<output_dir>/<model>/<timestamp>[_<experiment_id>]_<pid>``,
        where ``<model>`` is the last path component of ``model_name`` with
        '.' and '-' replaced by '_'.

        The timestamp is taken at call time, so successive calls can return
        different paths; call once and reuse the result. No directory is
        created here.

        Returns:
            The experiment directory path.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Strip trailing slashes first: for a local path such as "models/llama/"
        # the last component would otherwise be empty and the per-model
        # subdirectory would silently collapse into `output_dir`.
        leaf = self.model_name.rstrip("/").split("/")[-1]
        model_dir = Path(self.output_dir) / leaf.replace(".", "_").replace("-", "_")

        # The PID suffix keeps concurrent runs started within the same second
        # from colliding on the same directory name.
        name_parts = [timestamp]
        if self.experiment_id:
            name_parts.append(str(self.experiment_id).replace(" ", "_"))
        name_parts.append(f"{os.getpid():06d}")
        return model_dir / "_".join(name_parts)

    def setup_seeds(self) -> None:
        """Seed the Python, NumPy and PyTorch random number generators.

        Only the RNG seeds are set; no deterministic-algorithm or cuDNN flags
        are changed here.
        """
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            # Seeds every visible CUDA device, including the current one.
            torch.cuda.manual_seed_all(self.seed)

    def get_device(self) -> torch.device:
        """Resolve the configured device string to a ``torch.device``.

        With ``device == "auto"`` the preference order is CUDA, then MPS,
        then CPU. Any other value is passed to ``torch.device`` as-is.
        """
        if self.device != "auto":
            return torch.device(self.device)
        if torch.cuda.is_available():
            return torch.device("cuda")
        # Older torch builds have no `torch.backends.mps` attribute at all.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def to_dict(self) -> dict:
        """Return the configuration as a dict of Python objects.

        Values keep their Python types, so ``output_dir`` is a ``Path``.
        """
        return self.model_dump()

    def to_yaml(self) -> str:
        """Return the configuration as a YAML string for logging.

        The model is dumped in JSON mode first so that non-primitive values
        (notably the ``Path`` in ``output_dir``) become plain strings instead
        of Python-specific YAML tags that ``yaml.safe_load`` cannot read.
        """
        import yaml

        return yaml.safe_dump(self.model_dump(mode="json"), default_flow_style=False)
