import yaml
import json
import os
from pathlib import Path
from typing import Optional, Dict, Any, Type, TypeVar, List
from pydantic import BaseModel, ValidationError
import logging
from .models import Settings, CrossTranslateConfig
T = TypeVar('T', bound=BaseModel)
logger = logging.getLogger(__name__)
class ConfigLoader:
    """Load, cache and persist pydantic configuration models stored as YAML.

    The most recently loaded ``Settings`` instance is remembered and can be
    retrieved later through :meth:`get_config`.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Bind the loader to a configuration file.

        Args:
            config_path: Path of the YAML file. Falls back to
                ``"settings.yaml"`` when omitted or empty.
        """
        self.config_path = config_path or "settings.yaml"
        self._config: Optional[Settings] = None

    def load_config(self, config_class: Type[T] = Settings) -> T:
        """Read the YAML file and build a ``config_class`` instance from it.

        Args:
            config_class: Pydantic model to instantiate with the file's
                top-level mapping as keyword arguments.

        Returns:
            The validated model. It is also cached on the loader when
            ``config_class`` is ``Settings``.

        Raises:
            FileNotFoundError: The configuration file does not exist.
            ValueError: The file is not valid YAML or fails model validation.
            RuntimeError: Any other failure while reading or building.
        """
        config_file = Path(self.config_path)
        if not config_file.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_file}")
        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                config_data = yaml.safe_load(f)
            # safe_load yields None for an empty document; treat that as "no
            # overrides" so the model's own defaults/validation decide the
            # outcome instead of an opaque "** must be a mapping" TypeError.
            if config_data is None:
                config_data = {}
            config = config_class(**config_data)
            if config_class == Settings:
                self._config = config
            logger.info(f"Configuration loaded from {config_file}")
            return config
        except yaml.YAMLError as e:
            raise ValueError(f"Invalid YAML in config file: {e}") from e
        except ValidationError as e:
            raise ValueError(f"Configuration validation failed: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to load configuration: {e}") from e

    def get_config(self) -> Optional[Settings]:
        """Return the cached ``Settings``, or ``None`` if none was loaded yet."""
        return self._config

    def save_config(self, config: BaseModel, path: Optional[str] = None) -> None:
        """Serialise ``config`` to YAML.

        Args:
            config: Model to persist.
            path: Destination file; defaults to the loader's ``config_path``.
                Missing parent directories are created first (errors from
                that step propagate unchanged).

        Raises:
            RuntimeError: Dumping the model or writing the file failed.
        """
        save_path = Path(path or self.config_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            # Prefer the pydantic v2 API used elsewhere in this module and
            # fall back to the v1 name so older models keep working.
            dump = getattr(config, "model_dump", None)
            config_dict = dump() if callable(dump) else config.dict()
            with open(save_path, 'w', encoding='utf-8') as f:
                yaml.dump(config_dict, f, default_flow_style=False, indent=2)
            logger.info(f"Configuration saved to {save_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to save configuration: {e}") from e

    def create_default_config(self, path: Optional[str] = None) -> Settings:
        """Build the default ``Settings``, write it to disk and return it.

        Args:
            path: Destination file; defaults to the loader's ``config_path``.

        Returns:
            The freshly created default settings.
        """
        default_cross_translate = CrossTranslateConfig(
            train_dataset="arguana",
            test_dataset="fiqa",
            model={
                "source_model": "mistral",
                "target_model": "nv-embed"
            },
            mapper={
                "mapper_name": "gating-moe"
            }
        )
        default_settings = Settings(
            cross_translate=default_cross_translate
        )
        self.save_config(default_settings, path)
        logger.info(f"Default configuration created and saved to {path or self.config_path}")
        return default_settings
def load_settings(config_path: Optional[str] = None) -> Settings:
    """Load ``Settings`` from ``config_path`` via a throwaway ``ConfigLoader``.

    Args:
        config_path: YAML file to read; the loader substitutes
            ``"settings.yaml"`` when this is omitted or empty.

    Returns:
        The result of ``ConfigLoader.load_config()``.
    """
    return ConfigLoader(config_path).load_config()
def create_default_settings(config_path: str = "settings.yaml") -> Settings:
    """Write the default configuration to ``config_path`` and return it.

    Args:
        config_path: Destination YAML file handed to ``ConfigLoader``.

    Returns:
        The result of ``ConfigLoader.create_default_config()``.
    """
    return ConfigLoader(config_path).create_default_config()
def get_cross_translate_config(config_path: Optional[str] = None) -> CrossTranslateConfig:
    """Load the settings file and return its ``cross_translate`` section.

    Args:
        config_path: YAML file to read, forwarded to ``load_settings``.

    Returns:
        The ``cross_translate`` attribute of the loaded settings.
    """
    return load_settings(config_path).cross_translate
def to_env_key(dot_key: str, prefix: str = "APP") -> str:
    """Translate a dotted key into an environment-variable name.

    The key is upper-cased, every ``.`` becomes a double underscore, and the
    result is prefixed with ``"<prefix>_"``; for example ``"model.name"``
    becomes ``"APP_MODEL__NAME"``.

    Args:
        dot_key: Dotted configuration key.
        prefix: Prefix placed before the converted key.

    Returns:
        The environment-variable name.
    """
    segments = dot_key.upper().split(".")
    return f"{prefix}_" + "__".join(segments)
def _set_nested_value(d: Dict[str, Any], keys: List[str], value: Any) -> None:
    for i, key in enumerate(keys[:-1]):
        next_key = keys[i + 1]
        is_next_index = next_key.isdigit()
        if key.isdigit():
            key = int(key)
            while len(d) <= key:
                d.append({} if not is_next_index else [])
            if d[key] is None:
                d[key] = [] if is_next_index else {}
            d = d[key]
        else:
            if key not in d:
                d[key] = [] if is_next_index else {}
            d = d[key]
    final_key = keys[-1]
    if final_key.isdigit():
        final_key = int(final_key)
        while len(d) <= final_key:
            d.append(None)
        d[final_key] = value
    else:
        d[final_key] = value
def _has_list_index(key: str) -> bool:
    parts = key.split(".")
    return any(part.isdigit() for part in parts)
def transfer_args_to_env(
    args: List[str],
    prefix: str = "APP",
    auto_parse_types: bool = True
) -> None:
    """Export ``key=value`` arguments as environment variables.

    Arguments whose dotted key has a nested path containing a numeric
    segment (e.g. ``items.0.name=x``) are gathered per root key into a
    nested dict/list structure and exported as one JSON-encoded variable
    named ``<prefix>_<ROOT>``. All other arguments are exported one by one
    under the name produced by ``to_env_key``.

    Entries without ``=`` or with a blank value are skipped with a warning.

    Args:
        args: Raw ``key=value`` strings.
        prefix: Environment-variable prefix.
        auto_parse_types: When true, each value is first tried as JSON;
            values that do not parse are kept as plain strings.
    """
    list_args: Dict[str, Any] = {}
    regular_args: List[tuple] = []
    for kv in args:
        if "=" not in kv:
            logger.warning(f"Skipping invalid argument (no '=' found): {kv}")
            continue
        k, v = kv.split("=", 1)
        if not v.strip():
            logger.warning(f"Skipping empty value for key: {k}")
            continue
        parsed_v = v
        if auto_parse_types:
            try:
                parsed_v = json.loads(v)
            except (json.JSONDecodeError, ValueError):
                parsed_v = v
        parts = k.split(".")
        # A bare numeric key such as "0=5" has no path below its root;
        # sending it down the structured route would hand an empty path to
        # _set_nested_value and crash with IndexError, so it is treated as
        # a regular argument instead.
        if len(parts) > 1 and _has_list_index(k):
            root_key = parts[0]
            if root_key not in list_args:
                # The root container is a list when the first nested
                # segment is an index, otherwise a dict.
                list_args[root_key] = [] if parts[1].isdigit() else {}
            _set_nested_value(list_args[root_key], parts[1:], parsed_v)
        else:
            regular_args.append((k, v, parsed_v))
    for root_key, structure in list_args.items():
        env_key = f"{prefix}_{root_key.upper()}"
        json_value = json.dumps(structure)
        os.environ[env_key] = json_value
        logger.info(f"Set environment variable: {env_key}={json_value}")
    for k, v, parsed_v in regular_args:
        env_key = to_env_key(k, prefix)
        if isinstance(parsed_v, (dict, list)):
            os.environ[env_key] = json.dumps(parsed_v)
        else:
            # Strings keep their original text (including any JSON quotes);
            # other scalars use their Python string form.
            os.environ[env_key] = str(parsed_v) if not isinstance(parsed_v, str) else v
        logger.info(f"Set environment variable: {env_key}={os.environ[env_key]}")
def _deep_merge(base: Any, override: Any) -> Any:
    if isinstance(base, dict) and isinstance(override, dict):
        result = base.copy()
        for key, value in override.items():
            if key in result:
                result[key] = _deep_merge(result[key], value)
            else:
                result[key] = value
        return result
    elif isinstance(base, list) and isinstance(override, list):
        result = list(base)
        for i, value in enumerate(override):
            if value is not None:
                if i < len(result):
                    result[i] = _deep_merge(result[i], value)
                else:
                    while len(result) < i:
                        result.append(None)
                    result.append(value)
        return result
    else:
        return override if override is not None else base
def _parse_args_to_overrides(
    args: List[str],
    auto_parse_types: bool = True
) -> tuple[Dict[str, Any], List[tuple]]:
    """Split ``key=value`` arguments into structured and regular overrides.

    Arguments whose dotted key has a nested path containing a numeric
    segment (e.g. ``items.0.name=x``) are folded into a nested dict/list
    structure keyed by their root segment. Everything else is returned as
    a regular argument.

    Entries without ``=`` or with a blank value are skipped with a warning.

    Args:
        args: Raw ``key=value`` strings.
        auto_parse_types: When true, each value is first tried as JSON;
            values that do not parse are kept as plain strings.

    Returns:
        A pair ``(list_args, regular_args)``: ``list_args`` maps each root
        key to its nested structure, and ``regular_args`` is a list of
        ``(key, raw_value, parsed_value)`` tuples.
    """
    list_args: Dict[str, Any] = {}
    regular_args: List[tuple] = []
    for kv in args:
        if "=" not in kv:
            logger.warning(f"Skipping invalid argument (no '=' found): {kv}")
            continue
        k, v = kv.split("=", 1)
        if not v.strip():
            logger.warning(f"Skipping empty value for key: {k}")
            continue
        parsed_v = v
        if auto_parse_types:
            try:
                parsed_v = json.loads(v)
            except (json.JSONDecodeError, ValueError):
                parsed_v = v
        parts = k.split(".")
        # A bare numeric key such as "0=5" has no path below its root;
        # sending it down the structured route would hand an empty path to
        # _set_nested_value and crash with IndexError, so it is treated as
        # a regular argument instead.
        if len(parts) > 1 and _has_list_index(k):
            root_key = parts[0]
            if root_key not in list_args:
                # The root container is a list when the first nested
                # segment is an index, otherwise a dict.
                list_args[root_key] = [] if parts[1].isdigit() else {}
            _set_nested_value(list_args[root_key], parts[1:], parsed_v)
        else:
            regular_args.append((k, v, parsed_v))
    return list_args, regular_args
def load_config_with_args(
    config_class: Type[T],
    args: List[str],
    prefix: str = "APP",
) -> T:
    """Instantiate ``config_class`` with command-line overrides applied.

    Steps, in order:

    1. Blank environment variables starting with ``"<prefix>_"`` are removed
       from ``os.environ``.
    2. ``args`` are parsed; regular overrides are written to ``os.environ``
       under the name given by ``to_env_key``.
    3. ``config_class()`` is constructed with no arguments.
    4. If there are structured (index-bearing) overrides, they are
       deep-merged into the model's dump and the result is re-validated.

    Args:
        config_class: Model class; must support ``model_dump`` and
            ``model_validate`` when structured overrides are present.
        args: Raw ``key=value`` override strings.
        prefix: Environment-variable prefix.

    Returns:
        The configured model instance.

    Raises:
        Exception: Whatever construction or validation raises is logged,
            together with the arguments and prefixed environment variables,
            and then re-raised unchanged.
    """
    env_prefix = f"{prefix}_"

    # Collect first, delete afterwards, so the environment is not modified
    # while it is being scanned.
    blank_names = [
        name
        for name, current in list(os.environ.items())
        if name.startswith(env_prefix) and not current.strip()
    ]
    for name in blank_names:
        logger.debug(f"Removing empty environment variable: {name}")
        del os.environ[name]

    structured, scalars = _parse_args_to_overrides(args)

    for dotted, raw, parsed in scalars:
        env_key = to_env_key(dotted, prefix)
        if isinstance(parsed, (dict, list)):
            rendered = json.dumps(parsed)
        elif isinstance(parsed, str):
            # Keep the text exactly as supplied on the command line.
            rendered = raw
        else:
            rendered = str(parsed)
        os.environ[env_key] = rendered
        logger.info(f"Set environment variable: {env_key}={os.environ[env_key]}")

    try:
        config = config_class()
        if not structured:
            return config

        merged = config.model_dump()
        for root_key, override_structure in structured.items():
            if root_key not in merged:
                merged[root_key] = override_structure
                logger.info(f"Set new key '{root_key}' from overrides")
                continue
            merged[root_key] = _deep_merge(merged[root_key], override_structure)
            logger.info(f"Deep-merged overrides for '{root_key}'")
        return config_class.model_validate(merged)
    except Exception as e:
        logger.error(f"Failed to load configuration: {e}")
        logger.error(f"Args provided: {args}")
        logger.error(f"Environment variables with prefix {prefix}_:")
        for name in sorted(os.environ):
            if name.startswith(env_prefix):
                logger.error(f"  {name}={os.environ[name]}")
        raise
