"""Configuration management for the fine-tuning pipeline."""

import os
import random
from dataclasses import dataclass, field
from pathlib import Path

import yaml


# Local endpoints on consecutive ports 4000-4007, one URL per port.
DEFAULT_BASE_URLS = [f"http://localhost:{port}/v1" for port in range(4000, 4008)]


@dataclass
class LLMConfig:
    """LLM configuration for agent conversations.

    The base URL is not configurable: every call to get_base_url() draws one
    entry from DEFAULT_BASE_URLS (ports 4000-4007).

    The model may be set in YAML through ``model`` or ``models``. Each key
    accepts either a single name or a list of names; ``models`` takes
    precedence when both are set.
    """
    api_key: str = ""  # Will be loaded from env if empty
    model: str = ""  # Optional: single model
    models: list[str] = field(default_factory=list)  # Optional: list of models for load balancing

    def get_api_key(self) -> str:
        """Return the configured API key, else the LLM_API_KEY env var.

        Returns:
            The key, or "" when neither the config nor the environment has one.
        """
        return self.api_key or os.getenv("LLM_API_KEY", "")

    def get_base_url(self) -> str:
        """Return a base URL picked at random from DEFAULT_BASE_URLS."""
        return random.choice(DEFAULT_BASE_URLS)

    def get_model(self) -> str:
        """Return the model name to use for the next call.

        Resolution order:
            1. ``models`` - a random entry (a bare string counts as one entry).
            2. ``model``  - the name itself, or a random entry if it is a list.
            3. The LLM_MODEL env var, defaulting to "gpt-5".

        Returns:
            A single model name.
        """
        # A YAML scalar under `models` arrives as a str; passing that straight
        # to random.choice would return one character of the name.
        pool = [self.models] if isinstance(self.models, str) and self.models else self.models
        if pool:
            return random.choice(pool)

        if isinstance(self.model, str):
            return self.model or os.getenv("LLM_MODEL", "gpt-5")

        # `model` was given as a YAML list: pick from it rather than returning
        # the list itself. An empty list or None falls through to the env var.
        if self.model:
            return random.choice(self.model)
        return os.getenv("LLM_MODEL", "gpt-5")


@dataclass
class PipelineConfig:
    """Run-level settings for the pipeline.

    This class only stores the values; nothing here validates or enforces
    them, so their interpretation is up to the code that reads the config.
    """
    # Presumably the upper bound on pipeline iterations -- confirm against
    # the code that consumes it.
    max_iterations: int = 5
    # Global timeout in hours (0 = no timeout).
    timeout_hours: float = 12.0
    # Path to save the final summary (e.g., "results/final_summary.json").
    # A relative path is taken relative to the workspace.
    results_summary_path: str = "results/final_summary.json"


@dataclass
class DataConfig:
    """Settings that select the dataset and bound how much of it is used.

    Plain value holder: no lookup or validation happens in this class.
    """
    # Dataset name, used as a key in the datasets mapping file
    # (e.g., "chemcot" -> its path is looked up in configs/datasets.json).
    dataset: str = "chemcot_mol_und"
    # Path to the datasets mapping file, relative to the workspace.
    datasets_config: str = "configs/datasets.json"
    # Maximum number of samples to select for cleaning.
    max_samples: int = 2000


@dataclass
class TrainingConfig:
    """LlamaFactory training configuration.

    Only the base model and GPU allocation are configured here. Other
    training parameters (lora_rank, epochs, etc.) are decided by the
    Training Agent.
    """
    # Identifier of the model to fine-tune (the default looks like a
    # HuggingFace Hub repo id -- confirm how the consumer resolves it).
    base_model: str = "Qwen/Qwen2.5-7B-Instruct"
    # Number of available GPUs.
    num_gpus: int = 1
    # Explicit GPU ids to use; an empty list means auto-select.
    # NOTE(review): nothing here checks that len(gpu_ids) matches num_gpus.
    gpu_ids: list[int] = field(default_factory=list)


@dataclass
class EvaluationConfig:
    """OpenCompass evaluation configuration.

    Only benchmarks, data ranges and the LLM judge are configured here.
    Other evaluation parameters (mode, gpu_count, etc.) are hardcoded in
    the prompts sent to the Evaluation Agent.
    """
    # Names of the benchmarks to run.
    benchmarks: list[str] = field(default_factory=lambda: ["chemcotbench_mol_und"])
    # Slice expressions stored as plain strings. They reference `index_list`,
    # which is not defined in this module -- presumably they are evaluated by
    # the evaluation tooling; confirm before changing their format.
    # As written: validation takes a prefix and test takes a suffix, each
    # capped at 100 items or half of index_list, whichever is smaller.
    validation_range: str = "[:min(100, len(index_list)//2)]"
    test_range: str = "[-min(100, len(index_list)//2):]"
    # LLM Judge config (for benchmarks that need LLM scoring).
    judge_model: str = "Qwen/Qwen2.5-32B-Instruct"
    # Endpoint and credential for the judge; both default to empty strings
    # and no environment fallback is implemented in this class.
    judge_api_base: str = ""
    judge_api_key: str = ""


@dataclass
class ModelPoolConfig:
    """Model pool configuration for data processing LLM calls.

    Data Agent can use these models in cleaning scripts to:
    - Generate CoT reasoning
    - Data augmentation
    - Quality assessment

    Plain value holder: unlike LLMConfig, this class has no accessor that
    falls back to environment variables.
    """
    # Two tiers of model names; how a caller chooses between the tiers is
    # not defined in this module.
    strong_models: list[str] = field(default_factory=lambda: ["gpt-4o"])
    weak_models: list[str] = field(default_factory=lambda: ["gpt-4o-mini"])
    # Endpoint and credential for the pool; both default to empty strings.
    api_base: str = ""
    api_key: str = ""
    # Presumably the cap on concurrent workers issuing LLM calls -- confirm.
    max_workers: int = 50
    # Presumably a per-request timeout in seconds -- the unit is not stated
    # anywhere in this file; confirm against the caller.
    timeout: int = 120


@dataclass
class HuggingFaceConfig:
    """Credentials for the HuggingFace Hub."""
    hf_token: str = ""  # Falls back to environment variables when empty

    def get_token(self) -> str:
        """Resolve the HF token.

        Lookup order: the configured ``hf_token``, then the HF_TOKEN env var,
        then the HUGGING_FACE_HUB_TOKEN env var.

        Returns:
            The first non-empty value found, or "" when none is set.
        """
        if self.hf_token:
            return self.hf_token

        from_env = os.getenv("HF_TOKEN", "")
        if from_env:
            return from_env

        return os.getenv("HUGGING_FACE_HUB_TOKEN", "")


@dataclass
class Config:
    """Main configuration container.

    Holds one settings object per pipeline area plus the workspace path, and
    converts the whole tree to and from a YAML document.
    """
    llm: LLMConfig = field(default_factory=LLMConfig)
    model_pool: ModelPoolConfig = field(default_factory=ModelPoolConfig)
    huggingface: HuggingFaceConfig = field(default_factory=HuggingFaceConfig)
    pipeline: PipelineConfig = field(default_factory=PipelineConfig)
    data: DataConfig = field(default_factory=DataConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    workspace: str = ""

    @classmethod
    def from_yaml(cls, path: str | Path) -> "Config":
        """Load configuration from a YAML file.

        Sections that are missing from the file, or present but null (e.g.
        ``llm:`` with no body), keep their defaults. An empty file yields an
        all-default Config.

        Args:
            path: Location of the YAML file.

        Returns:
            A Config populated from the file.

        Raises:
            ValueError: If the top level of the document is not a mapping.
            TypeError: If a section is not a mapping, or contains a key that
                its dataclass does not accept.
        """
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # safe_load returns None for an empty document.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Config file {path} must contain a mapping at the top level, "
                f"got {type(data).__name__}"
            )

        config = cls()

        # YAML section name -> dataclass that holds it; the section name is
        # also the attribute name on Config.
        section_types = {
            "llm": LLMConfig,
            "model_pool": ModelPoolConfig,
            "huggingface": HuggingFaceConfig,
            "pipeline": PipelineConfig,
            "data": DataConfig,
            "training": TrainingConfig,
            "evaluation": EvaluationConfig,
        }
        for name, section_cls in section_types.items():
            section = data.get(name)
            # A null section means "nothing overridden", not "construct from None".
            if section is not None:
                setattr(config, name, section_cls(**section))

        workspace = data.get("workspace")
        if workspace is not None:
            config.workspace = workspace

        return config

    def to_yaml(self, path: str | Path) -> None:
        """Save configuration to a YAML file.

        Every field of every section is written, except that the LLM model
        keys are only written when non-empty. Credentials (api_key, hf_token,
        judge_api_key) are written exactly as stored, in plain text.

        Args:
            path: Destination file; it is overwritten if it exists.
        """
        llm_data = {"api_key": self.llm.api_key}
        # base_url is hardcoded, only save model config. Both keys are written
        # when both are set so a save/load round trip does not drop `model`.
        if self.llm.models:
            llm_data["models"] = self.llm.models
        if self.llm.model:
            llm_data["model"] = self.llm.model

        data = {
            "llm": llm_data,
            "model_pool": {
                "strong_models": self.model_pool.strong_models,
                "weak_models": self.model_pool.weak_models,
                "api_base": self.model_pool.api_base,
                "api_key": self.model_pool.api_key,
                "max_workers": self.model_pool.max_workers,
                "timeout": self.model_pool.timeout,
            },
            "huggingface": {
                "hf_token": self.huggingface.hf_token,
            },
            "pipeline": {
                "max_iterations": self.pipeline.max_iterations,
                "timeout_hours": self.pipeline.timeout_hours,
                "results_summary_path": self.pipeline.results_summary_path,
            },
            "data": {
                "dataset": self.data.dataset,
                "datasets_config": self.data.datasets_config,
                "max_samples": self.data.max_samples,
            },
            "training": {
                "base_model": self.training.base_model,
                "num_gpus": self.training.num_gpus,
                "gpu_ids": self.training.gpu_ids,
            },
            "evaluation": {
                "benchmarks": self.evaluation.benchmarks,
                "validation_range": self.evaluation.validation_range,
                "test_range": self.evaluation.test_range,
                "judge_model": self.evaluation.judge_model,
                "judge_api_base": self.evaluation.judge_api_base,
                "judge_api_key": self.evaluation.judge_api_key,
            },
            "workspace": self.workspace,
        }

        with open(path, "w", encoding="utf-8") as f:
            yaml.dump(data, f, default_flow_style=False)
