import re
from dataclasses import dataclass, field
from typing import Literal

import cvxpy as cp
from cvxpy.settings import ERROR, INF_OR_UNB
from trl import DPOConfig, PPOConfig, RewardConfig

# Closed set of dataset version names; used as the type of
# DRConfig.dataset_version below.
DatasetVersion = Literal[
    "full",
    "400k",
    "40k",
    "20k",
    "5k",
]


@dataclass
class DRConfig:
    eps: float = field(
        default=0,
        metadata={"help": "Regularization strength (epsilon value for the q optimization problem)"},
    )
    dataset_version: DatasetVersion = field(
        default="400k",
        metadata={"help": "Version of the dataset to use (full, 400k, 40k, 20k)"},
    )
    subset_to_remove: str = field(
        default="",
        metadata={"help": "Subset to remove from the training data"},
    )
    dist_fn: Literal["tv", "chi2o"] = field(
        default="tv",
        metadata={"help": "Distance function to use (tv, chi2o)"},
    )
    log_completions_interval: int | None = field(
        default=None,
        metadata={"help": "How often to log the completions from the eval dataset (only for RLHF training). Default: None (do not log completions)"},
    )


@dataclass
class PPOLossType:
    """Holds the choice of DR loss type; per its help text, only used for PPO."""

    # Allowed values: "reward" (default), "pilossgrad_pi", "pilossgrad_all".
    loss_type: Literal["reward", "pilossgrad_pi", "pilossgrad_all"] = field(
        metadata=dict(help="Which DR loss type to use (only for PPO)"),
        default="reward",
    )


def get_rew_run_name(dataset_name: str, model_name_or_path: str, dr_config: DRConfig, config: RewardConfig):
    """Build the run name for a reward-model training run.

    The name is a sequence of "_"-separated segments: ``reward_<dataset>``,
    the dataset version, then optional tags, and finally the model name.
    Optional tags only appear when the corresponding setting differs from its
    "off" value (empty subset, eps 0, dist_fn "tv", missing/zero weight decay,
    missing/zero center-rewards coefficient).

    Args:
        dataset_name: Name inserted right after the ``reward_`` prefix.
        model_name_or_path: Model identifier; "/" becomes "_" and the result
            is lower-cased.
        dr_config: Supplies dataset_version, subset_to_remove, eps, dist_fn.
        config: Supplies learning_rate, weight_decay and
            center_rewards_coefficient.

    Returns:
        The assembled run name.
    """
    segments = [f"reward_{dataset_name}", f"{dr_config.dataset_version}"]

    if dr_config.subset_to_remove != "":
        segments.append(f"no-{dr_config.subset_to_remove}")

    segments.append(f"lr{config.learning_rate}")

    if dr_config.eps != 0:
        segments.append(f"dr_eps{dr_config.eps}")
    if dr_config.dist_fn != "tv":
        segments.append(f"{dr_config.dist_fn}dist")

    weight_decay = config.weight_decay
    if weight_decay is not None and weight_decay != 0.0:
        segments.append(f"wd{weight_decay}")

    center_coef = config.center_rewards_coefficient
    if center_coef is not None and center_coef != 0:
        segments.append(f"crc{center_coef}")

    # The model name always comes last and is the only lower-cased segment.
    segments.append(model_name_or_path.replace("/", "_").lower())
    return "_".join(segments)


def get_rlhf_run_name(method: str, rm_name: str | None, model_name_or_path: str, dr_config: DRConfig, config: PPOConfig | DPOConfig, ppo_loss_type: PPOLossType | None = None):
    run_name = f"{method}_"
    if rm_name is not None:
        if rm_name.startswith("models/"):
            rm_name = rm_name[7:]
        run_name += f"{rm_name.replace('/', '_')}_"
    run_name += f"{dr_config.dataset_version}_"
    if dr_config.subset_to_remove != "":
        run_name += f"no-{dr_config.subset_to_remove}_"
    if dr_config.dist_fn != "tv":
        run_name += f"{dr_config.dist_fn}dist_"
    run_name += f"lr{config.learning_rate}_"
    if isinstance(config, PPOConfig) and config.kl_coef != 0:
        run_name += f"kl{config.kl_coef}_"
    if dr_config.eps != 0:
        run_name += f"dr_eps{dr_config.eps}_"
    if ppo_loss_type is not None:
        run_name += f"solve{ppo_loss_type.loss_type}_"
    run_name += model_name_or_path.replace("/", "_").lower()

    return run_name


# Number written after "dr_eps" in a run name. The run-name builders in this
# module format epsilon with an f-string, so the tag can be an integer ("2"),
# a decimal ("0.1") or an exponent form ("1e-05", "1.5e-05").
_EPS_PATTERN = re.compile(r"dr_eps(\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)")


def get_eps_from_model_name(model_name: str) -> str:
    """Extract the epsilon value encoded in a run/model name.

    Looks for the first ``dr_eps<number>`` tag in ``model_name``.

    Args:
        model_name: Name that may contain a ``dr_eps<number>`` tag.

    Returns:
        The number exactly as it appears in the name (a string, not a float),
        or "0" when the name has no ``dr_eps`` tag followed by a number.
    """
    found = _EPS_PATTERN.search(model_name)
    return found[1] if found is not None else "0"


# Solver selection (SCS) and its "eps" setting, kept together as keyword
# arguments. NOTE(review): presumably forwarded to cp.Problem.solve — confirm
# at the call site.
SOLVER_KWARGS = dict(solver=cp.SCS, eps=1e-5)

# Extra status string used alongside the cvxpy status constants below.
# NOTE(review): presumably set when solving raises — confirm at the call site.
EXCEPTION = "exception"

# All statuses treated as failures: cvxpy's infeasible/unbounded statuses,
# its error statuses, and the local EXCEPTION marker.
ERROR_STATUSES = [*INF_OR_UNB, *ERROR, EXCEPTION]

# Type alias for a (variable, parameter, problem) triple of cvxpy objects.
QParams = tuple[cp.Variable, cp.Parameter, cp.Problem]
