from enum import Enum
from typing import Literal


from src.settings.base import ExtraFieldsNotAllowedBaseModel
from src.settings.cherry_pick import ChatCherryPickSettings
from src.settings.datasets.pair_preference import PairPreferenceMultiDatasetSettings
from src.settings.pipelines.train.base import BaseTrainExperimentSettings
from src.settings.tf.trainer import TrainerSettings


class DPOLossesType(str, Enum):
    """Names of the preference-optimisation losses that can be configured.

    Members are also ``str`` instances, so each one compares equal to its
    string value (for example ``DPOLossesType.IPO == 'ipo'``).
    """

    SIGMOID = 'sigmoid'
    IPO = 'ipo'
    KTO = 'kto'
    ORPO = 'orpo'
    SIMPO = 'simpo'
    APO_ZERO = 'apo_zero'
    APO_DOWN = 'apo_down'
    ASFT = 'asft'
    NCA = 'nca'
    # The member name has no underscore, but the string value does.
    CALDPO = 'cal_dpo'


class DPOLossSettings(ExtraFieldsNotAllowedBaseModel):
    """Common base for the per-loss settings classes in this module.

    Each subclass narrows ``loss_type`` to a single ``DPOLossesType`` member
    and declares the hyperparameters for that loss.
    """

    # Which loss these settings describe.
    # NOTE(review): the base class name suggests unknown fields are rejected;
    # confirm in src.settings.base.
    loss_type: DPOLossesType


class KTOLossSettings(DPOLossSettings):
    """Settings for the ``kto`` loss type."""

    # Only DPOLossesType.KTO is a valid value for this class.
    loss_type: Literal[DPOLossesType.KTO]
    # Loss coefficient; defaults to 0.1.
    # NOTE(review): its exact role in the loss is not visible in this file.
    beta: float = 0.1


class SigmoidLossSettings(DPOLossSettings):
    """Settings for the ``sigmoid`` loss type."""

    # Only DPOLossesType.SIGMOID is a valid value for this class.
    loss_type: Literal[DPOLossesType.SIGMOID]
    # Label-smoothing amount; defaults to 0.0 (presumably "no smoothing" --
    # confirm against the trainer). Written as a float literal so the default's
    # runtime type matches the annotation and the other float defaults in
    # this module.
    label_smoothing: float = 0.0
    # Loss coefficient; defaults to 0.1.
    # NOTE(review): its exact role in the loss is not visible in this file.
    beta: float = 0.1


class IPOLossSettings(DPOLossSettings):
    """Settings for the ``ipo`` loss type."""

    # Only DPOLossesType.IPO is a valid value for this class.
    loss_type: Literal[DPOLossesType.IPO]
    # Loss coefficient; defaults to 0.1.
    # NOTE(review): its exact role in the loss is not visible in this file.
    beta: float = 0.1


class CALDPOLossSettings(DPOLossSettings):
    """Settings for the ``cal_dpo`` loss type."""

    # Only DPOLossesType.CALDPO is a valid value for this class.
    loss_type: Literal[DPOLossesType.CALDPO]
    # Loss coefficient; defaults to 0.1.
    # NOTE(review): its exact role in the loss is not visible in this file.
    beta: float = 0.1


class SimPOLossSettings(DPOLossSettings):
    """Settings for the ``simpo`` loss type."""

    # Only DPOLossesType.SIMPO is a valid value for this class.
    loss_type: Literal[DPOLossesType.SIMPO]
    # Second SimPO hyperparameter; defaults to 0.5.
    # NOTE(review): presumably the target reward margin -- confirm against
    # the loss implementation.
    gamma: float = 0.5
    # Loss coefficient; defaults to 0.1.
    # NOTE(review): its exact role in the loss is not visible in this file.
    beta: float = 0.1


class ORPOLossSettings(DPOLossSettings):
    """Settings for the ``orpo`` loss type."""

    # Only DPOLossesType.ORPO is a valid value for this class.
    loss_type: Literal[DPOLossesType.ORPO]
    # Loss coefficient; defaults to 1.0 here, unlike the 0.1 default used by
    # most other loss settings in this module.
    # NOTE(review): its exact role in the loss is not visible in this file.
    beta: float = 1.0


class ASFTLossSettings(DPOLossSettings):
    """Settings for the ``asft`` loss type."""

    # Only DPOLossesType.ASFT is a valid value for this class.
    loss_type: Literal[DPOLossesType.ASFT]
    # Loss coefficient; defaults to 1.0 here, unlike the 0.1 default used by
    # most other loss settings in this module.
    # NOTE(review): its exact role in the loss is not visible in this file.
    beta: float = 1.0


class APOZeroLossSettings(DPOLossSettings):
    """Settings for the ``apo_zero`` loss type."""

    # Only DPOLossesType.APO_ZERO is a valid value for this class.
    loss_type: Literal[DPOLossesType.APO_ZERO]
    # Loss coefficient; defaults to 0.1.
    # NOTE(review): its exact role in the loss is not visible in this file.
    beta: float = 0.1


class APODownLossSettings(DPOLossSettings):
    """Settings for the ``apo_down`` loss type."""

    # Only DPOLossesType.APO_DOWN is a valid value for this class.
    loss_type: Literal[DPOLossesType.APO_DOWN]
    # Loss coefficient; defaults to 0.1.
    # NOTE(review): its exact role in the loss is not visible in this file.
    beta: float = 0.1


class NCALossSettings(DPOLossSettings):
    """Settings for the ``nca`` loss type."""

    # Only DPOLossesType.NCA is a valid value for this class.
    loss_type: Literal[DPOLossesType.NCA]
    # Loss coefficient; defaults to 0.1.
    # NOTE(review): its exact role in the loss is not visible in this file.
    beta: float = 0.1


class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel):
    """Options for synchronising the reference model during training.

    NOTE(review): the sync mechanics live in the trainer, not in this file;
    the field comments below describe only what is declared here.
    """

    # Master switch; synchronisation is off by default.
    sync_ref_model: bool = False
    # Defaults to 1.0. Presumably the mixing weight applied when syncing --
    # confirm against the trainer.
    alpha: float = 1.0
    # Defaults to 1. Presumably the number of training steps between
    # syncs -- confirm against the trainer.
    sync_steps: int = 1


class LogpType(str, Enum):
    """Selects which kind of sequence log-probability is used.

    Members are also ``str`` instances, so each one compares equal to its
    string value.
    """

    # NOTE(review): the names suggest summed vs. length-averaged token
    # log-probs -- confirm against the code that consumes this value.
    CUM_LOG_PROB = 'cum_log_prob'
    AVG_LOG_PROB = 'avg_log_prob'


class DPOTrainerSettings(TrainerSettings):
    """Trainer settings for DPO-style training, extending ``TrainerSettings``.

    ``loss_settings``, ``sync_ref_settings`` and ``logp_type`` have no default
    in this class; every other field declared here does.
    """

    # Exactly one of the per-loss settings classes. Each member pins
    # ``loss_type`` to a different ``DPOLossesType`` value.
    # NOTE(review): presumably ``loss_type`` in the config is what selects the
    # matching class during validation -- confirm how the base model resolves
    # unions before reordering the members.
    loss_settings: (
        SigmoidLossSettings
        | IPOLossSettings
        | KTOLossSettings
        | ORPOLossSettings
        | ASFTLossSettings
        | SimPOLossSettings
        | APOZeroLossSettings
        | APODownLossSettings
        | NCALossSettings
        | CALDPOLossSettings
    )
    # Reference-model synchronisation options (see SyncRefModelSettings).
    sync_ref_settings: SyncRefModelSettings
    # Defaults to True. Presumably toggles use of a separate reference
    # model -- confirm against the trainer.
    use_ref_model: bool = True
    # Defaults to False. Presumably toggles use of an SFT model -- confirm
    # against the trainer.
    use_sft_model: bool = False
    # Defaults to 0.0. Presumably the weight of an auxiliary cross-entropy
    # term -- confirm against the trainer.
    ce_coef: float = 0.0
    # Defaults to 0.0. Presumably the weight of an auxiliary unlikelihood
    # term -- confirm against the trainer.
    unll_coef: float = 0.0
    # Defaults to 1.0.
    # NOTE(review): meaning is not visible in this file; document it once
    # confirmed against the trainer.
    lam: float = 1.0
    # Which log-probability variant to use (see LogpType).
    logp_type: LogpType


class DPOTrainExperimentSettings(BaseTrainExperimentSettings):
    """Experiment settings for a DPO training run.

    Extends ``BaseTrainExperimentSettings`` with pair-preference datasets,
    chat cherry-pick settings and ``DPOTrainerSettings``. None of the fields
    declared here has a default.
    """

    # Train and validation data use the same settings type.
    train_dataset_settings: PairPreferenceMultiDatasetSettings
    val_dataset_settings: PairPreferenceMultiDatasetSettings

    # NOTE(review): presumably configures sample generations inspected during
    # training -- confirm in src.settings.cherry_pick.
    cherry_pick_settings: ChatCherryPickSettings

    # Loss, reference-model and log-prob options for the trainer.
    trainer_settings: DPOTrainerSettings
