from typing import Literal

from tap import Tap


class TrainingConfig(Tap):
    """Typed command-line options for a training run.

    The properties derive a model identifier plus the experiment and run
    names that together make up the log directory. Fields that no property
    mentions are not used inside this class.
    """

    # Part of experiment_name; also selects which run_name layout is used.
    task: Literal["length", "hh", "pku"] = "length"
    # Short key; model_dir maps it to a longer identifier string.
    model: Literal["mistral7b", "alpaca7b"] = "mistral7b"
    # The string "None" (not the None object) stands for "no suffix".
    suffix: str = "None"
    # Included in every run_name layout.
    meta: bool = False
    max_completion_length: int = 256
    # In run_name for every case except "length" runs with bon_type "softbon".
    n: int = 4
    num_epochs: int = 1
    log_completions: bool = False
    # In run_name only when task is "length".
    beta: float = 0.1
    # In run_name only when task is not "length".
    target_kl: float = 0.1
    scale_rewards: Literal["none", "batch", "group"] = "none"
    # In run_name only when task is not "length"; shown with one decimal place.
    weights: list[float] = [0.5, 0.5]
    resume_from_checkpoint: bool = False
    vllm_mode: Literal["colocate", "server"] = "colocate"
    # For "length" runs, chooses whether run_name reports n or tau.
    bon_type: Literal["bon", "softbon"] = "bon"
    # In run_name only for "length" runs with bon_type "softbon".
    tau: float | None = 0.1
    learning_rate: float = 1e-06
    lr_scheduler_type: Literal["constant", "linear"] = "linear"

    @property
    def model_dir(self) -> str:
        """Return the identifier string registered for ``self.model``.

        Raises:
            KeyError: If ``self.model`` is not one of the known keys.
        """
        # NOTE(review): these look like Hugging Face Hub repo ids — confirm
        # against the code that loads the model.
        return {
            "mistral7b": "mistralai/Mistral-7B-Instruct-v0.2",
            "alpaca7b": "PKU-Alignment/alpaca-7b-reproduced",
        }[self.model]

    @property
    def experiment_name(self) -> str:
        """Return ``"<task>_<model>"``, plus ``"_<suffix>"`` when one is set."""
        base = f"{self.task}_{self.model}"
        # "None" is the string sentinel for "no suffix was given".
        return base if self.suffix == "None" else f"{base}_{self.suffix}"

    @property
    def run_name(self) -> str:
        """Return a label built from the hyperparameters relevant to the task.

        For the "length" task the label holds ``meta``, then ``n`` (bon) or
        ``tau`` (softbon), then ``beta``. For every other task it holds
        ``meta``, ``n``, ``target_kl`` and the weights, each rendered with
        one decimal place and joined by "-".

        Raises:
            ValueError: If the task is "length" and ``bon_type`` is neither
                "bon" nor "softbon".
        """
        if self.task != "length":
            weight_tag = "-".join(f"{w:.1f}" for w in self.weights)
            return f"meta={self.meta}_n={self.n}_kl={self.target_kl}_weights={weight_tag}"

        # "length" runs: only the middle field depends on the bon flavour.
        if self.bon_type == "bon":
            selector = f"n={self.n}"
        elif self.bon_type == "softbon":
            selector = f"tau={self.tau}"
        else:
            raise ValueError(f"Unknown bon_type: {self.bon_type}")
        return f"meta={self.meta}_{selector}_beta={self.beta}"

    @property
    def log_dir(self) -> str:
        """Return ``logs/<experiment_name>/<run_name>``."""
        return "/".join(("logs", self.experiment_name, self.run_name))


class RewardModelingConfig(Tap):
    """Typed command-line options for a reward-modeling run.

    The properties translate the short option values into longer identifier
    strings and derive the experiment/run names that form the log directory.
    """

    # Short key; dataset_dir maps it to a longer identifier string.
    dataset: Literal["hh"] = "hh"
    # Doubles as the run name; data_dir maps it to a sub-directory name.
    subset: Literal["harmless", "helpful"] = "harmless"
    # Short key; model_dir maps it to a longer identifier string.
    model: Literal["mistral7b", "qwen0.5b"] = "qwen0.5b"
    num_epochs: int = 3

    @property
    def dataset_dir(self) -> str:
        """Return the identifier string registered for ``self.dataset``.

        Raises:
            KeyError: If ``self.dataset`` is not one of the known keys.
        """
        # NOTE(review): looks like a Hugging Face Hub dataset id — confirm
        # against the code that loads the dataset.
        return {"hh": "Anthropic/hh-rlhf"}[self.dataset]

    @property
    def data_dir(self) -> str:
        """Return the directory name registered for ``self.subset``.

        Raises:
            KeyError: If ``self.subset`` is not one of the known keys.
        """
        # NOTE(review): presumably a sub-folder of the dataset above —
        # confirm at the call site.
        return {
            "harmless": "harmless-base",
            "helpful": "helpful-base",
        }[self.subset]

    @property
    def model_dir(self) -> str:
        """Return the identifier string registered for ``self.model``.

        Raises:
            KeyError: If ``self.model`` is not one of the known keys.
        """
        return {
            "mistral7b": "mistralai/Mistral-7B-Instruct-v0.2",
            "qwen0.5b": "Qwen/Qwen2-0.5B-Instruct",
        }[self.model]

    @property
    def experiment_name(self) -> str:
        """Return ``rewards/<dataset>_<model>``."""
        return f"rewards/{self.dataset}_{self.model}"

    @property
    def run_name(self) -> str:
        """Return the run label, which is simply the chosen subset."""
        return self.subset

    @property
    def log_dir(self) -> str:
        """Return ``logs/<experiment_name>/<run_name>``."""
        return "/".join(("logs", self.experiment_name, self.run_name))
