import json
import dataclasses
from .fq_models import SkillQuestionSet, CandidateSet

@dataclasses.dataclass
class TrainingConfig:
    """Training settings together with the data sets they refer to.

    Usually built with :meth:`from_file`. That method reads a JSON config,
    loads ``question_set`` / ``candidate_set`` from the paths stored in that
    config, and copies every other field verbatim from the JSON object.
    """

    llm_path: str
    output_dir: str
    # Built by from_file from question_set_path / candidate_set_path; these
    # two names are not accepted as keys in the JSON config itself.
    question_set: SkillQuestionSet
    candidate_set: CandidateSet
    personality_prediction_model_path: str

    target_skill: str

    question_set_path: str
    candidate_set_path: str
    tree_depth: int
    summary_prefix: str

    num_epochs: int
    fairness_batch_size: int
    rl_batch_size: int
    num_rollouts: int
    rollout_depth: int

    save_frequency: int
    # NOTE: "intial" is misspelled, but from_file uses field names as JSON
    # keys, so renaming this would break existing config files.
    intial_learning_rate: float

    fairness_epsilon: float

    random_seed: int

    # Optional fields for fairness distillation; None when absent from the
    # config file.
    num_distillation_steps: int | None = None
    distillation_lr: float | None = None

    @classmethod
    def from_file(cls, filepath: str, cache_dir: str | None = None) -> "TrainingConfig":
        """Build a TrainingConfig from a JSON config file.

        Args:
            filepath: Path to a JSON file whose top-level object maps field
                names of this class to values. It must provide every field
                without a default, except ``question_set`` and
                ``candidate_set``. Those two are loaded from
                ``question_set_path`` and ``candidate_set_path`` and must not
                appear in the file.
            cache_dir: Forwarded to ``CandidateSet.load_from_json``.

        Returns:
            A fully populated TrainingConfig.

        Raises:
            ValueError: If the top-level JSON value is not an object, if it
                contains a key that is not a settable field, or if it is
                missing a required field.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        if not isinstance(config_dict, dict):
            raise ValueError(f"Config file {filepath} must contain a JSON object")

        # Fields whose values are constructed here rather than read from JSON.
        loaded_fields = {"question_set", "candidate_set"}
        all_fields = dataclasses.fields(cls)
        settable_names = {fld.name for fld in all_fields} - loaded_fields

        # Validate the keys before the (potentially expensive) data loading.
        # Rejecting question_set / candidate_set also prevents raw JSON from
        # overwriting the objects loaded below.
        for key in config_dict:
            if key not in settable_names:
                raise ValueError(f"Unexpected field '{key}' in config file {filepath}")

        missing = [
            fld.name
            for fld in all_fields
            if fld.name in settable_names
            and fld.name not in config_dict
            and fld.default is dataclasses.MISSING
            and fld.default_factory is dataclasses.MISSING
        ]
        if missing:
            raise ValueError(
                f"Missing required field(s) {missing} in config file {filepath}"
            )

        # Every key is a known settable field at this point. Fields with
        # defaults that are absent from the file fall back to those defaults.
        instance_args = dict(config_dict)
        instance_args["question_set"] = SkillQuestionSet.load_from_json(
            config_dict["question_set_path"]
        )
        instance_args["candidate_set"] = CandidateSet.load_from_json(
            config_dict["candidate_set_path"], cache_dir
        )
        return cls(**instance_args)
