import jax.numpy as jnp
import flax.linen as nn


def rrhf_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Ranking hinge on the rejected-vs-chosen gap, minus the chosen score.

    Evaluates ``relu(rejected - chosen) - chosen`` element-wise.

    Args:
        logit_chosen: Policy score of the preferred response (presumably a
            summed log-probability -- confirm with the caller).
        logit_rejected: Policy score of the dispreferred response.
        ref_logit_chosen: Unused; present for the shared loss signature.
        ref_logit_rejected: Unused; present for the shared loss signature.
        length_chosen: Divisor applied to ``logit_chosen`` when averaging.
        length_rejected: Divisor applied to ``logit_rejected`` when averaging.
        time: Unused.
        mirror_map: Unused.
        kto_z: Unused.
        config: Mapping that must be provided; only ``"AVG_LOGITS"`` is read.
            When it is true, each score is divided by its length first.

    Returns:
        The element-wise loss; no reduction is applied.
    """
    use_avg = config["AVG_LOGITS"]
    # Both variants are computed and selected with ``where`` rather than a
    # Python branch, so the flag is allowed to be an array value.
    score_w = jnp.where(use_avg, logit_chosen / length_chosen, logit_chosen)
    score_l = jnp.where(use_avg, logit_rejected / length_rejected, logit_rejected)
    ranking_penalty = nn.relu(score_l - score_w)
    return ranking_penalty - score_w


def slic_hf_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Margin hinge between the two responses, minus the chosen score.

    Evaluates ``relu(BETA - chosen + rejected) - chosen`` element-wise.

    Args:
        logit_chosen: Policy score of the preferred response.
        logit_rejected: Policy score of the dispreferred response.
        ref_logit_chosen: Unused; present for the shared loss signature.
        ref_logit_rejected: Unused; present for the shared loss signature.
        length_chosen: Divisor applied to ``logit_chosen`` when averaging.
        length_rejected: Divisor applied to ``logit_rejected`` when averaging.
        time: Unused.
        mirror_map: Unused.
        kto_z: Unused.
        config: Mapping that must be provided; reads ``"AVG_LOGITS"`` (divide
            the scores by their lengths when true) and ``"BETA"`` (the margin
            inside the hinge).

    Returns:
        The element-wise loss; no reduction is applied.
    """
    use_avg = config["AVG_LOGITS"]
    margin = config["BETA"]
    score_w = jnp.where(use_avg, logit_chosen / length_chosen, logit_chosen)
    score_l = jnp.where(use_avg, logit_rejected / length_rejected, logit_rejected)
    hinge = nn.relu(margin - score_w + score_l)
    return hinge - score_w


def kto_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """KTO-style loss built from sigmoid values of policy/reference log-ratios.

    For each pair the loss is::

        ALPHA1 * (1 - sigmoid(BETA * (chosen_logratio   - kto_z)))
      + ALPHA2 * (1 - sigmoid(BETA * (kto_z - rejected_logratio)))

    where a log-ratio is ``policy score - reference score``. The result is
    non-negative and decreases as the chosen log-ratio rises or the rejected
    log-ratio falls, so it is meant to be minimised like the other losses in
    this module. (An earlier revision returned the negation of this value,
    which rewarded the opposite behaviour.)

    Args:
        logit_chosen: Policy score of the preferred response.
        logit_rejected: Policy score of the dispreferred response.
        ref_logit_chosen: Reference-model score of the preferred response.
        ref_logit_rejected: Reference-model score of the dispreferred response.
        length_chosen: Divisor for the chosen scores when averaging.
        length_rejected: Divisor for the rejected scores when averaging.
        time: Unused; present for the shared loss signature.
        mirror_map: Unused; present for the shared loss signature.
        kto_z: Reference point subtracted from / compared against the
            log-ratios. Must be provided even though it defaults to ``None``.
        config: Mapping that must be provided; reads ``"AVG_LOGITS"``,
            ``"BETA"``, ``"ALPHA1"`` (chosen weight) and ``"ALPHA2"``
            (rejected weight).

    Returns:
        The element-wise loss; no reduction is applied.
    """
    use_avg = config["AVG_LOGITS"]
    beta = config["BETA"]

    # Optional length normalisation, selected with ``where`` so the flag may
    # be an array value.
    policy_w = jnp.where(use_avg, logit_chosen / length_chosen, logit_chosen)
    policy_l = jnp.where(use_avg, logit_rejected / length_rejected, logit_rejected)
    ref_w = jnp.where(use_avg, ref_logit_chosen / length_chosen, ref_logit_chosen)
    ref_l = jnp.where(
        use_avg, ref_logit_rejected / length_rejected, ref_logit_rejected
    )

    chosen_logratios = policy_w - ref_w
    rejected_logratios = policy_l - ref_l
    chosen_losses = 1 - nn.sigmoid(beta * (chosen_logratios - kto_z))
    rejected_losses = 1 - nn.sigmoid(beta * (kto_z - rejected_logratios))

    # Positive sign: each term is already a quantity to minimise.
    return config["ALPHA1"] * chosen_losses + config["ALPHA2"] * rejected_losses


def fdpo_fun(x):
    """Return ``log(2 * exp(x) / (exp(x) + 1))`` in a numerically stable form.

    The expression equals ``log(2) + log(sigmoid(x))``, which is evaluated
    here as ``log(2) - logaddexp(0, -x)``. Forming ``exp(x)`` directly
    overflows to ``inf`` for large ``x`` (giving ``inf / inf = nan``) and
    underflows to ``0`` for very negative ``x`` (giving ``log(0) = -inf``);
    the rewritten form stays finite in both regimes.

    Args:
        x: Scalar or array input.

    Returns:
        The transformed value, element-wise.
    """
    return jnp.log(2.0) - jnp.logaddexp(0.0, -x)


def fdpo_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Logistic loss on the gap between ``fdpo_fun``-transformed log-ratios.

    Evaluates ``-log_sigmoid(BETA * (f(chosen_logratio) - f(rejected_logratio)))``
    where ``f`` is ``fdpo_fun`` and a log-ratio is
    ``policy score - reference score``.

    Args:
        logit_chosen: Policy score of the preferred response.
        logit_rejected: Policy score of the dispreferred response.
        ref_logit_chosen: Reference-model score of the preferred response.
        ref_logit_rejected: Reference-model score of the dispreferred response.
        length_chosen: Divisor for the chosen scores when averaging.
        length_rejected: Divisor for the rejected scores when averaging.
        time: Unused; present for the shared loss signature.
        mirror_map: Unused; present for the shared loss signature.
        kto_z: Unused; present for the shared loss signature.
        config: Mapping that must be provided; reads ``"AVG_LOGITS"`` and
            ``"BETA"``.

    Returns:
        The element-wise loss; no reduction is applied.
    """
    use_avg = config["AVG_LOGITS"]
    policy_w = jnp.where(use_avg, logit_chosen / length_chosen, logit_chosen)
    policy_l = jnp.where(use_avg, logit_rejected / length_rejected, logit_rejected)
    ref_w = jnp.where(use_avg, ref_logit_chosen / length_chosen, ref_logit_chosen)
    ref_l = jnp.where(
        use_avg, ref_logit_rejected / length_rejected, ref_logit_rejected
    )

    transformed_w = fdpo_fun(policy_w - ref_w)
    transformed_l = fdpo_fun(policy_l - ref_l)
    return -nn.log_sigmoid(config["BETA"] * (transformed_w - transformed_l))


def rdpo_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Logistic loss on the log-ratio gap with an additive length term.

    Evaluates::

        -log_sigmoid(BETA * gap + ALPHA * (length_chosen - length_rejected))

    with ``gap = logit_chosen - ref_logit_chosen - logit_rejected
    + ref_logit_rejected``. The scores are used as given; this function does
    not read ``"AVG_LOGITS"``.

    Args:
        logit_chosen: Policy score of the preferred response.
        logit_rejected: Policy score of the dispreferred response.
        ref_logit_chosen: Reference-model score of the preferred response.
        ref_logit_rejected: Reference-model score of the dispreferred response.
        length_chosen: Length of the preferred response.
        length_rejected: Length of the dispreferred response.
        time: Unused; present for the shared loss signature.
        mirror_map: Unused; present for the shared loss signature.
        kto_z: Unused; present for the shared loss signature.
        config: Mapping that must be provided; reads ``"ALPHA"`` and
            ``"BETA"``.

    Returns:
        The element-wise loss; no reduction is applied.
    """
    length_term = config["ALPHA"] * (length_chosen - length_rejected)
    gap = logit_chosen - ref_logit_chosen - logit_rejected + ref_logit_rejected
    return -nn.log_sigmoid(config["BETA"] * gap + length_term)


def cpo_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Reference-free logistic preference loss minus the chosen score.

    Evaluates ``-log_sigmoid(BETA * (chosen - rejected)) - chosen``.

    Args:
        logit_chosen: Policy score of the preferred response.
        logit_rejected: Policy score of the dispreferred response.
        ref_logit_chosen: Unused; present for the shared loss signature.
        ref_logit_rejected: Unused; present for the shared loss signature.
        length_chosen: Divisor applied to ``logit_chosen`` when averaging.
        length_rejected: Divisor applied to ``logit_rejected`` when averaging.
        time: Unused.
        mirror_map: Unused.
        kto_z: Unused.
        config: Mapping that must be provided; reads ``"AVG_LOGITS"`` and
            ``"BETA"``.

    Returns:
        The element-wise loss; no reduction is applied.
    """
    use_avg = config["AVG_LOGITS"]
    score_w = jnp.where(use_avg, logit_chosen / length_chosen, logit_chosen)
    score_l = jnp.where(use_avg, logit_rejected / length_rejected, logit_rejected)
    preference_term = -nn.log_sigmoid(config["BETA"] * (score_w - score_l))
    return preference_term - score_w


def d_basic_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Logistic preference loss minus a scaled chosen score, fixed coefficients.

    Evaluates ``-log_sigmoid(0.62 * (chosen - rejected)) - 0.48 * chosen``.
    The two coefficients are hard-coded here rather than read from ``config``.

    Args:
        logit_chosen: Policy score of the preferred response.
        logit_rejected: Policy score of the dispreferred response.
        ref_logit_chosen: Unused; present for the shared loss signature.
        ref_logit_rejected: Unused; present for the shared loss signature.
        length_chosen: Divisor applied to ``logit_chosen`` when averaging.
        length_rejected: Divisor applied to ``logit_rejected`` when averaging.
        time: Unused.
        mirror_map: Unused.
        kto_z: Unused.
        config: Mapping that must be provided; only ``"AVG_LOGITS"`` is read.

    Returns:
        The element-wise loss; no reduction is applied.
    """
    averaged = config["AVG_LOGITS"]
    pos = jnp.where(averaged, logit_chosen / length_chosen, logit_chosen)
    neg = jnp.where(averaged, logit_rejected / length_rejected, logit_rejected)
    pairwise = -nn.log_sigmoid(0.62 * (pos - neg))
    return pairwise - 0.48 * pos


def d_shuffled_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Logistic preference loss minus a scaled chosen score, fixed coefficients.

    Evaluates ``-log_sigmoid(0.16 * (chosen - rejected)) - 0.01 * chosen``.
    The two coefficients are hard-coded here rather than read from ``config``.

    Args:
        logit_chosen: Policy score of the preferred response.
        logit_rejected: Policy score of the dispreferred response.
        ref_logit_chosen: Unused; present for the shared loss signature.
        ref_logit_rejected: Unused; present for the shared loss signature.
        length_chosen: Divisor applied to ``logit_chosen`` when averaging.
        length_rejected: Divisor applied to ``logit_rejected`` when averaging.
        time: Unused.
        mirror_map: Unused.
        kto_z: Unused.
        config: Mapping that must be provided; only ``"AVG_LOGITS"`` is read.

    Returns:
        The element-wise loss; no reduction is applied.
    """
    normalise = config["AVG_LOGITS"]
    preferred = jnp.where(normalise, logit_chosen / length_chosen, logit_chosen)
    dispreferred = jnp.where(
        normalise, logit_rejected / length_rejected, logit_rejected
    )
    diff = preferred - dispreferred
    return (-nn.log_sigmoid(0.16 * diff)) - 0.01 * preferred


def d_noisy_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Logistic preference loss minus a scaled chosen score, fixed coefficients.

    Evaluates ``-log_sigmoid(0.12 * (chosen - rejected)) - 0.05 * chosen``.
    The two coefficients are hard-coded here rather than read from ``config``.

    Args:
        logit_chosen: Policy score of the preferred response.
        logit_rejected: Policy score of the dispreferred response.
        ref_logit_chosen: Unused; present for the shared loss signature.
        ref_logit_rejected: Unused; present for the shared loss signature.
        length_chosen: Divisor applied to ``logit_chosen`` when averaging.
        length_rejected: Divisor applied to ``logit_rejected`` when averaging.
        time: Unused.
        mirror_map: Unused.
        kto_z: Unused.
        config: Mapping that must be provided; only ``"AVG_LOGITS"`` is read.

    Returns:
        The element-wise loss; no reduction is applied.
    """
    flag = config["AVG_LOGITS"]
    winner = jnp.where(flag, logit_chosen / length_chosen, logit_chosen)
    loser = jnp.where(flag, logit_rejected / length_rejected, logit_rejected)
    scaled_gap = 0.12 * (winner - loser)
    chosen_bonus = 0.05 * winner
    return (-nn.log_sigmoid(scaled_gap)) - chosen_bonus


def simpo_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Reference-free logistic loss on length-averaged scores with a margin.

    Evaluates::

        -log_sigmoid(BETA * (chosen / length_chosen
                             - rejected / length_rejected) - ALPHA)

    Unlike the other losses in this module, the division by length is
    unconditional: ``"AVG_LOGITS"`` is not consulted.

    Args:
        logit_chosen: Policy score of the preferred response.
        logit_rejected: Policy score of the dispreferred response.
        ref_logit_chosen: Unused; present for the shared loss signature.
        ref_logit_rejected: Unused; present for the shared loss signature.
        length_chosen: Divisor applied to ``logit_chosen``.
        length_rejected: Divisor applied to ``logit_rejected``.
        time: Unused.
        mirror_map: Unused.
        kto_z: Unused.
        config: Mapping that must be provided; reads ``"BETA"`` (scale) and
            ``"ALPHA"`` (margin subtracted inside the sigmoid).

    Returns:
        The element-wise loss; no reduction is applied.
    """
    avg_chosen = logit_chosen / length_chosen
    avg_rejected = logit_rejected / length_rejected
    margin_logits = config["BETA"] * (avg_chosen - avg_rejected) - config["ALPHA"]
    return -nn.log_sigmoid(margin_logits)


def ipo_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Squared distance between the log-ratio gap and ``1 / (2 * BETA)``.

    With ``h = (chosen - rejected) - (ref_chosen - ref_rejected)`` the result
    is ``(h - 1 / (2 * BETA)) ** 2``.

    Args:
        logit_chosen: Policy score of the preferred response.
        logit_rejected: Policy score of the dispreferred response.
        ref_logit_chosen: Reference-model score of the preferred response.
        ref_logit_rejected: Reference-model score of the dispreferred response.
        length_chosen: Divisor for the chosen scores when averaging.
        length_rejected: Divisor for the rejected scores when averaging.
        time: Unused; present for the shared loss signature.
        mirror_map: Unused; present for the shared loss signature.
        kto_z: Unused; present for the shared loss signature.
        config: Mapping that must be provided; reads ``"AVG_LOGITS"`` and
            ``"BETA"``.

    Returns:
        The element-wise loss; no reduction is applied.
    """
    use_avg = config["AVG_LOGITS"]
    policy_w = jnp.where(use_avg, logit_chosen / length_chosen, logit_chosen)
    policy_l = jnp.where(use_avg, logit_rejected / length_rejected, logit_rejected)
    ref_w = jnp.where(use_avg, ref_logit_chosen / length_chosen, ref_logit_chosen)
    ref_l = jnp.where(
        use_avg, ref_logit_rejected / length_rejected, ref_logit_rejected
    )

    gap = (policy_w - policy_l) - (ref_w - ref_l)
    target = 1 / (2 * config["BETA"])
    return jnp.square(gap - target)


def dpo_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Logistic loss on the policy-vs-reference log-ratio gap.

    Evaluates ``-log_sigmoid(BETA * (chosen - rejected + ref_rejected
    - ref_chosen))``.

    Args:
        logit_chosen: Policy score of the preferred response.
        logit_rejected: Policy score of the dispreferred response.
        ref_logit_chosen: Reference-model score of the preferred response.
        ref_logit_rejected: Reference-model score of the dispreferred response.
        length_chosen: Divisor for the chosen scores when averaging.
        length_rejected: Divisor for the rejected scores when averaging.
        time: Unused; present for the shared loss signature.
        mirror_map: Unused; present for the shared loss signature.
        kto_z: Unused; present for the shared loss signature.
        config: Mapping that must be provided; reads ``"AVG_LOGITS"`` and
            ``"BETA"``.

    Returns:
        The element-wise loss; no reduction is applied.
    """
    use_avg = config["AVG_LOGITS"]
    policy_w = jnp.where(use_avg, logit_chosen / length_chosen, logit_chosen)
    policy_l = jnp.where(use_avg, logit_rejected / length_rejected, logit_rejected)
    ref_w = jnp.where(use_avg, ref_logit_chosen / length_chosen, ref_logit_chosen)
    ref_l = jnp.where(
        use_avg, ref_logit_rejected / length_rejected, ref_logit_rejected
    )

    margin = policy_w - policy_l + ref_l - ref_w
    return -nn.log_sigmoid(config["BETA"] * margin)


def cdpo_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Label-smoothed DPO: blend the DPO loss with its flipped-label version.

    Evaluates::

        (1 - ALPHA) * dpo_loss(chosen, rejected) + ALPHA * dpo_loss(rejected, chosen)

    The second term scores the pair as if the preference label were
    reversed, so ``ALPHA`` acts as the assumed label-flip probability.
    (Previously both terms were called with identical arguments, which made
    the result equal to plain ``dpo_loss`` regardless of ``ALPHA``.)

    Args:
        logit_chosen: Policy score of the preferred response.
        logit_rejected: Policy score of the dispreferred response.
        ref_logit_chosen: Reference-model score of the preferred response.
        ref_logit_rejected: Reference-model score of the dispreferred response.
        length_chosen: Divisor for the chosen scores when averaging.
        length_rejected: Divisor for the rejected scores when averaging.
        time: Forwarded to ``dpo_loss`` unchanged.
        mirror_map: Forwarded to ``dpo_loss`` unchanged.
        kto_z: Forwarded to ``dpo_loss`` unchanged.
        config: Mapping that must be provided; reads ``"ALPHA"`` here and is
            forwarded to ``dpo_loss``.

    Returns:
        The element-wise loss; no reduction is applied.
    """
    flip_weight = config["ALPHA"]
    loss_as_labelled = dpo_loss(
        logit_chosen,
        logit_rejected,
        ref_logit_chosen,
        ref_logit_rejected,
        length_chosen,
        length_rejected,
        time,
        mirror_map,
        kto_z,
        config,
    )
    # Swap every chosen/rejected argument (scores, reference scores and
    # lengths) so the pair is evaluated with the opposite label.
    loss_as_flipped = dpo_loss(
        logit_rejected,
        logit_chosen,
        ref_logit_rejected,
        ref_logit_chosen,
        length_rejected,
        length_chosen,
        time,
        mirror_map,
        kto_z,
        config,
    )
    return (1 - flip_weight) * loss_as_labelled + flip_weight * loss_as_flipped


def sft_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Negated chosen score, optionally averaged over its length.

    Args:
        logit_chosen: Policy score of the preferred response.
        logit_rejected: Unused; present for the shared loss signature.
        ref_logit_chosen: Unused.
        ref_logit_rejected: Unused.
        length_chosen: Divisor applied to ``logit_chosen`` when averaging.
        length_rejected: Unused.
        time: Unused.
        mirror_map: Unused.
        kto_z: Unused.
        config: Mapping that must be provided; only ``"AVG_LOGITS"`` is read.

    Returns:
        ``-logit_chosen`` (or ``-logit_chosen / length_chosen`` when
        ``"AVG_LOGITS"`` is true), element-wise.
    """
    per_token = logit_chosen / length_chosen
    selected = jnp.where(config["AVG_LOGITS"], per_token, logit_chosen)
    return -selected


def orpo_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time=None,
    mirror_map=None,
    kto_z=None,
    config=None,
):
    """Negated chosen score plus a weighted log-odds-ratio penalty.

    Each score ``s`` is mapped to ``s - log1p(-exp(s))``, i.e. the log-odds
    of ``exp(s)``; this is only finite for ``s < 0``. The result is::

        -chosen - BETA * log_sigmoid(log_odds(chosen) - log_odds(rejected))

    Args:
        logit_chosen: Policy score of the preferred response; treated as a
            log-probability by the log-odds transform.
        logit_rejected: Policy score of the dispreferred response; treated
            the same way.
        ref_logit_chosen: Unused; present for the shared loss signature.
        ref_logit_rejected: Unused; present for the shared loss signature.
        length_chosen: Divisor applied to ``logit_chosen`` when averaging.
        length_rejected: Divisor applied to ``logit_rejected`` when averaging.
        time: Unused.
        mirror_map: Unused.
        kto_z: Unused.
        config: Mapping that must be provided; reads ``"AVG_LOGITS"`` and
            ``"BETA"``.

    Returns:
        The element-wise loss; no reduction is applied.
    """

    def _log_odds(log_prob):
        # log(p / (1 - p)) written in terms of log p.
        return log_prob - jnp.log1p(-jnp.exp(log_prob))

    use_avg = config["AVG_LOGITS"]
    score_w = jnp.where(use_avg, logit_chosen / length_chosen, logit_chosen)
    score_l = jnp.where(use_avg, logit_rejected / length_rejected, logit_rejected)

    odds_gap = _log_odds(score_w) - _log_odds(score_l)
    return -score_w - config["BETA"] * nn.log_sigmoid(odds_gap)



def mpo_loss(
    logit_chosen,
    logit_rejected,
    ref_logit_chosen,
    ref_logit_rejected,
    length_chosen,
    length_rejected,
    time,
    mirror_map,
    kto_z=None,
    config=None,
):
    """Loss assembled from three outputs of a learned ``mirror_map``.

    ``mirror_map.apply_fn({"params": mirror_map.params}, x, time)`` is
    expected to return an indexable result with at least three entries.
    Writing ``m(x)`` for that call, the loss is::

        -m(exp(chosen))[0]
        - BETA * m(m(exp(chosen))[1] - m(exp(rejected))[1])[2]

    The forward pass on ``exp(chosen)`` is evaluated once and both of its
    needed outputs are taken from that single result; this assumes
    ``apply_fn`` is a pure function of its arguments.

    Args:
        logit_chosen: Policy score of the preferred response; exponentiated
            before being passed to the mirror map.
        logit_rejected: Policy score of the dispreferred response;
            exponentiated before being passed to the mirror map.
        ref_logit_chosen: Unused; present for the shared loss signature.
        ref_logit_rejected: Unused; present for the shared loss signature.
        length_chosen: Divisor applied to ``logit_chosen`` when averaging.
        length_rejected: Divisor applied to ``logit_rejected`` when averaging.
        time: Passed through as the last argument of every ``apply_fn`` call.
        mirror_map: Object exposing ``apply_fn`` and ``params``.
        kto_z: Unused.
        config: Mapping that must be provided; reads ``"AVG_LOGITS"`` and
            ``"BETA"``.

    Returns:
        The element-wise loss; no reduction is applied.
    """

    def _mirror(inputs):
        # A fresh variables dict per call, as in the original call sites.
        return mirror_map.apply_fn({"params": mirror_map.params}, inputs, time)

    use_avg = config["AVG_LOGITS"]
    score_w = jnp.where(use_avg, logit_chosen / length_chosen, logit_chosen)
    score_l = jnp.where(use_avg, logit_rejected / length_rejected, logit_rejected)

    chosen_out = _mirror(jnp.exp(score_w))
    rejected_out = _mirror(jnp.exp(score_l))
    gap_out = _mirror(chosen_out[1] - rejected_out[1])

    return -chosen_out[0] - config["BETA"] * gap_out[2]


# Registry mapping each loss function's name to the function itself. Keys are
# taken from ``__name__``, so they are exactly the identifiers listed below,
# in this order.
loss_dict = {
    loss_fn.__name__: loss_fn
    for loss_fn in (
        rrhf_loss,
        slic_hf_loss,
        kto_loss,
        fdpo_loss,
        rdpo_loss,
        cpo_loss,
        d_basic_loss,
        d_shuffled_loss,
        d_noisy_loss,
        simpo_loss,
        ipo_loss,
        dpo_loss,
        cdpo_loss,
        sft_loss,
        orpo_loss,
        mpo_loss,
    )
}
