from __future__ import annotations

import ast
from pathlib import Path
from typing import TypeVar, cast

from typing_extensions import Self

from mow.common.trainer import CustomTrainerConfig

# Config class type accepted by ``_load_config_from_file``: an instance of either
# mixin below. The bound is a string because both classes are defined later in
# this module.
T = TypeVar("T", bound="ConfigMixin | TrainConfigMixin")


def _load_config_from_file(cls: type[T], path: str | Path) -> T:
    path = Path(path)
    with open(path, "r") as f:
        node = ast.parse(f.read())
    obj = compile(cast(str, node), "<ast>", "exec")
    loc = {}
    exec(obj, globals(), loc)
    config = loc[cls.key]
    if not isinstance(config, cls):
        raise TypeError(
            f"Expected config of type {cls.__name__}, got {type(config).__name__}"
        )
    return cast(T, config)


class ConfigMixin:
    """Mixin for config classes that can be loaded from a Python config file.

    Subclasses must pass a ``key`` class keyword naming the top-level variable
    a config file assigns the config object to::

        class MyConfig(ConfigMixin, key="my_config"):
            ...
    """

    def __init_subclass__(cls, key: str, **kwargs: object) -> None:
        """Record *key* on the subclass and forward any other class keywords.

        Forwarding ``**kwargs`` keeps this mixin cooperative with other bases
        that accept their own ``__init_subclass__`` keywords.
        """
        cls.key = key
        super().__init_subclass__(**kwargs)

    @classmethod
    def from_file(cls, path: str | Path) -> Self:
        """Load an instance of this class from the Python config file at *path*.

        The file is executed (only load trusted files) and the object it binds
        to ``cls.key`` is returned. A KeyError is raised if the file does not
        define that name, and a TypeError if the object is not an instance of
        ``cls``.
        """
        return _load_config_from_file(cls, path)


class TrainConfigMixin[TrainConfigType: CustomTrainerConfig]:
    def __init_subclass__(cls, key: str) -> None:
        cls.key = key
        return super().__init_subclass__()

    def __init__(
        self,
        *,
        train_config: TrainConfigType,
    ):
        self.train_config = train_config

    @classmethod
    def from_file(cls, path: str | Path) -> Self:
        config = _load_config_from_file(cls, path)
        if config.train_config.output_dir is None:
            config.train_config.output_dir = Path(path).parent
        return config
