from __future__ import annotations

from dataclasses import dataclass, field, asdict, fields
from typing import Literal, Self, Sequence
import json
from pathlib import Path
from copy import deepcopy

from core.api import *


# Convenience magnitudes for sizes/counts expressed in plain integers.
BILLION = 1_000_000_000        # 10**9
TRILLION = 1_000 * BILLION     # 10**12
# Registry of every Config created so far, keyed by Config.name.
# Populated by Config.__post_init__ and read by get_config_from_name.
ALL_PARAMS: dict[str, "Config"] = dict()


@dataclass
class Config:
    """Named experiment configuration.

    Creating an instance, directly or through :meth:`derive`, registers it in
    the module-level ``ALL_PARAMS`` registry under ``name``. ``__post_init__``
    raises ``KeyError`` if that name is already registered.
    """

    @dataclass
    class TestArgs:
        """Arguments for a test (evaluation) run."""

        batch_size: int
        context_size: int
        evaluator: dict = field(default_factory=dict)  # arguments for evaluator
        reasoner: dict = field(default_factory=dict)   # arguments for reasoner
        # Default value is produced by calling ``param.inference.sampling`` (core.api).
        inference: dict = field(default_factory=param.inference.sampling)
        reflection: dict | None = None

    @dataclass
    class ReflectiveTraining:
        """Settings for reflective training.

        This class only stores the values; how each field is consumed is
        decided by the training code that reads it.
        """

        approach: ReflectionApproach
        collect: dict
        sft: NTPHyperParames | None = None
        sft_data: dict = field(default_factory=dict)
        n_train_samples: int | None = None
        n_train_instances: int | None = None
        n_val_samples: int | None = None
        n_val_instances: int | None = None
        rl: dict | None = None

    @dataclass
    class Reasoning:
        """Reasoning-task settings.

        Holds the task name, the thought implementation, and the test
        arguments. It also holds optional per-stage settings (``sft``,
        ``ppo``, ``grpo``, ``train_vf``, ``train_refl``), each ``None`` when
        that stage is not configured.
        """

        task: str
        impl: ThoughtImpl
        # Either one TestArgs shared by all test algorithms, or a mapping from
        # algorithm key to TestArgs. A ``None`` key acts as the fallback entry.
        test: Config.TestArgs | dict[str | None, Config.TestArgs]
        sft: NTPHyperParames | None = None
        ppo: param.rl.ppo | None = None
        grpo: param.rl.grpo | None = None
        data_options: dict = field(default_factory=dict)
        train_vf: param.rl.train_vf | None = None
        train_refl: Config.ReflectiveTraining | None = None

        def get_test_args(self, key: str | None):
            """Return the :class:`Config.TestArgs` to use for test algorithm ``key``.

            If ``test`` is a single ``TestArgs``, it is returned for any key.
            Otherwise ``test`` is a mapping. ``key`` is looked up first, and the
            ``None`` entry is used as a fallback.

            Raises:
                KeyError: if ``test`` is a mapping that contains neither ``key``
                    nor ``None``.
            """
            if isinstance(self.test, Config.TestArgs):
                return self.test
            if key in self.test:
                return self.test[key]
            if None in self.test:
                return self.test[None]
            supported = ", ".join(map(str, self.test.keys()))
            raise KeyError(
                f"{key} is not a supported test algorithm. The supported are: {supported}"
            )

    name: str  # unique key in ALL_PARAMS; also used to build data/output paths
    model: ModelConfig
    corpus: list[str | dict]  # the corpus for tokenizer and pretraining
    pretrain_args: NTPHyperParames
    # "{name}" is substituted with self.name by get_pretrain_data_path.
    pretrain_data_dir: str = "data/pretrain/{name}"
    special_tokens: param.vocabulary = field(default_factory=param.vocabulary)
    reasoning: Reasoning | None = None
    reserved_vocab_size: int = 0
    use_pretrained_checkpoint: bool = False
    precision: Precision | None = None

    def __post_init__(self):
        # Register into ALL_PARAMS; config names must be unique.
        if self.name in ALL_PARAMS:
            raise KeyError(f"\"{self.name}\" have already been used.")
        ALL_PARAMS[self.name] = self

    def get_pretrain_data_path(self, *filepath: str):
        """Return ``pretrain_data_dir`` joined with the ``filepath`` parts.

        ``{name}`` in ``pretrain_data_dir`` is first replaced by this config's
        name. With no parts, the directory itself is returned.
        """
        base = Path(self.pretrain_data_dir.format(name=self.name))
        return base.joinpath(*filepath)

    def out_path(self, *file_names: str):
        """Return ``out/<name>`` joined with the ``file_names`` parts."""
        return Path(f'out/{self.name}').joinpath(*file_names)

    def save(self, path: str | Path):
        """Write this config to ``path`` as JSON.

        The config is first converted with ``dataclasses.asdict``.
        ``json.dump`` raises ``TypeError`` if any resulting value is not
        JSON-serializable.
        """
        # The file must be opened for writing. A read-only mode ('rt') makes
        # json.dump fail with "not writable", or open() fail if the file is missing.
        with open(path, 'wt', encoding='utf-8') as f:
            json.dump(asdict(self), f)

    def derive(self, name: str) -> Self:
        """Return a copy of this config under a new ``name``.

        Init fields are deep-copied, except ``param.vocabulary`` values, which
        are copied with their own ``.copy()``. Fields that are not set on the
        instance (``AttributeError`` on access) are omitted, so their dataclass
        defaults apply. The new instance registers itself in ``ALL_PARAMS``,
        so ``name`` must be unused (``KeyError`` otherwise).
        """
        init_args = {}
        for field_ in fields(self):
            if not field_.init:
                continue
            try:
                value = getattr(self, field_.name)
            except AttributeError:
                continue
            if isinstance(value, param.vocabulary):
                init_args[field_.name] = value.copy()
            else:
                init_args[field_.name] = deepcopy(value)
        init_args['name'] = name
        return self.__class__(**init_args)


def get_config_from_name(name: str):
    """Return the registered :class:`Config` called ``name``.

    If ``name`` is not in ``ALL_PARAMS``, the registered names are printed to
    stderr and the process exits with status 1.

    Raises:
        SystemExit: (code 1) if ``name`` is not registered.
    """
    import sys  # local import keeps this block self-contained

    try:
        return ALL_PARAMS[name]
    except KeyError:
        supported = ', '.join(f'"{k}"' for k in ALL_PARAMS)
        print(
            f"\"{name}\" is not a supported configuration. The supported are {supported}.",
            file=sys.stderr,
        )
        # Non-zero status so shells and CI see the failure. The site-provided
        # exit() used before returned status 0 and is absent under `python -S`.
        sys.exit(1)
