from typing import Callable

import numpy as np
from loguru import logger


def calc_normalized_reward(rewards: np.ndarray) -> np.ndarray:
    """Map each reward to its mid-rank quantile ``(rank + 0.5) / N``.

    ``rank`` is the element's 0-based position in ascending sort order, so the
    results lie strictly inside ``(0, 1)`` for non-empty input.

    Args:
        rewards: 1-D array of rewards.

    Returns:
        Float array of the same length with the quantile of each reward.
        Tied rewards are ordered by ``np.argsort``'s default (non-stable)
        ordering, so they receive distinct quantiles.
    """
    sorted_idx = np.argsort(rewards)
    N = len(rewards)
    # Invert the permutation in O(N): the element at sorted position k has
    # rank k. The previous ``list.index`` lookup per element was O(N^2).
    ranks = np.empty(N, dtype=np.intp)
    ranks[sorted_idx] = np.arange(N)
    ranks_shifted = ranks + 0.5
    return ranks_shifted / N


def calc_BoN_reward(rewards: np.ndarray, n: int) -> np.ndarray:
    """Scale each reward by ``n * q ** (n - 1)``, where ``q`` is its quantile.

    ``q`` comes from ``calc_normalized_reward`` (mid-rank quantile in
    ``(0, 1)``); ``n * q ** (n - 1)`` is the derivative of ``q ** n``.

    Args:
        rewards: 1-D array of rewards.
        n: The N in Best-of-N.

    Returns:
        Float array ``n * rewards * q ** (n - 1)``.
    """
    quantiles = calc_normalized_reward(rewards)
    scaled_rewards = n * rewards
    return scaled_rewards * quantiles ** (n - 1)


def get_BoN_reward_func(reward_func: Callable, n: int = 4) -> Callable:
    """Wrap ``reward_func`` so its outputs are transformed by ``calc_BoN_reward``.

    Args:
        reward_func: Callable taking keyword arguments only and returning a
            sequence of per-sample rewards.
        n: The N in Best-of-N, forwarded to ``calc_BoN_reward``.

    Returns:
        A keyword-only callable returning the transformed rewards as a list;
        its ``__name__`` is ``"bon_"`` followed by ``reward_func.__name__``.
    """

    def bon_reward_func(**kwargs):
        raw_rewards = np.array(reward_func(**kwargs))
        return calc_BoN_reward(raw_rewards, n).tolist()

    bon_reward_func.__name__ = "bon_" + reward_func.__name__
    return bon_reward_func


def calc_BoN_linearized_reward(rewards: np.ndarray, n: int) -> np.ndarray:
    """Compute the linearized Best-of-N reward for each sample.

    Walking from the highest to the lowest reward, each gap between
    consecutive sorted rewards ``dr`` contributes ``n * q ** (n - 1) * dr`` to a
    running sum, where ``q = (position + 1) / M`` is the upper quantile of the
    gap. Each sample's linearized reward is minus the running sum at its
    position. The top reward therefore always maps to 0, and all others are
    non-positive. Tied rewards have ``dr == 0`` between them, so they receive
    equal values regardless of how ``argsort`` orders them.

    Args:
        rewards: 1-D array of rewards. Integer or boolean inputs are promoted
            to float64 so the fractional results are not truncated.
        n: The N in Best-of-N.

    Returns:
        Array of the same length. It keeps the input's float dtype if it has
        one, otherwise it is float64. An empty input returns an empty array.
    """
    rewards = np.asarray(rewards)
    # zeros_like on an int/bool array would silently truncate the float
    # results (e.g. -0.5 -> 0), and bool subtraction raises, so work in float.
    if not np.issubdtype(rewards.dtype, np.floating):
        rewards = rewards.astype(np.float64)
    M = len(rewards)
    linearized_rewards = np.zeros_like(rewards)
    if M == 0:
        return linearized_rewards

    sorted_idx = np.argsort(rewards)
    cumsum = 0.0
    next_reward = rewards[sorted_idx[-1]]
    for i in reversed(range(M)):
        idx = sorted_idx[i]
        dr = next_reward - rewards[idx]
        q = (i + 1) / M
        cumsum += n * q ** (n - 1) * dr
        linearized_rewards[idx] = -cumsum
        next_reward = rewards[idx]
    return linearized_rewards


def get_BoN_linearized_reward_func(reward_func: Callable, n: int = 4) -> Callable:
    """Wrap ``reward_func`` so its outputs go through ``calc_BoN_linearized_reward``.

    Args:
        reward_func: Callable taking keyword arguments only and returning a
            sequence of per-sample rewards.
        n: The N in Best-of-N, forwarded to ``calc_BoN_linearized_reward``.

    Returns:
        A keyword-only callable returning the linearized rewards as a list;
        its ``__name__`` is ``"bon_linearized_"`` followed by
        ``reward_func.__name__``.
    """

    def bon_linearized_reward_func(**kwargs):
        raw_rewards = np.array(reward_func(**kwargs))
        return calc_BoN_linearized_reward(raw_rewards, n).tolist()

    bon_linearized_reward_func.__name__ = "bon_linearized_" + reward_func.__name__
    return bon_linearized_reward_func


def calc_soft_BoN_reward(rewards: np.ndarray, tau: float = 0.1) -> np.ndarray:
    """Reweight rewards by the exponential tilt ``exp(rewards / tau)``.

    Computes ``rewards * exp(rewards / tau) / mean(exp(rewards / tau))``.

    Args:
        rewards: 1-D array of rewards.
        tau: Temperature; smaller values concentrate weight on the largest
            rewards.

    Returns:
        Float array of the same length as ``rewards``.
    """
    scaled = rewards / tau
    if scaled.size == 0:
        return scaled
    # Subtract the max before exponentiating. The common factor exp(-max)
    # cancels between numerator and mean, so the result is unchanged, but
    # exp() can no longer overflow to inf (which previously produced nan for
    # rewards / tau above ~709, e.g. reward 71 at the default tau=0.1).
    tilt = np.exp(scaled - np.max(scaled))
    normalizing_constant = np.mean(tilt)
    return rewards * tilt / normalizing_constant


def get_soft_BoN_reward_func(reward_func: Callable, tau: float = 0.1) -> Callable:
    """Wrap ``reward_func`` so its outputs are transformed by ``calc_soft_BoN_reward``.

    Args:
        reward_func: Callable taking keyword arguments only and returning a
            sequence of per-sample rewards.
        tau: Temperature forwarded to ``calc_soft_BoN_reward``.

    Returns:
        A keyword-only callable returning the transformed rewards as a list;
        its ``__name__`` is ``"bon_"`` followed by ``reward_func.__name__``.

    NOTE(review): ``get_BoN_reward_func`` assigns the very same ``bon_<name>``
    ``__name__``, so the soft and hard wrappers of one reward function are
    indistinguishable to anything keyed on ``__name__``. Consider a
    ``soft_bon_`` prefix if both can be used together.
    """

    def soft_bon_reward_func(**kwargs):
        raw_rewards = np.array(reward_func(**kwargs))
        return calc_soft_BoN_reward(raw_rewards, tau).tolist()

    soft_bon_reward_func.__name__ = "bon_" + reward_func.__name__
    return soft_bon_reward_func


def calc_soft_BoN_linearized_reward(
    rewards: np.ndarray, tau: float = 0.1
) -> np.ndarray:
    """Compute the linearized soft Best-of-N reward for each sample.

    With weights ``w = exp(rewards / tau) / mean(exp(rewards / tau))`` and
    ``expected_reward = mean(rewards * w)``, returns
    ``rewards * w - w * expected_reward``. Inputs and outputs are logged at
    INFO level.

    Args:
        rewards: 1-D array of rewards.
        tau: Temperature of the exponential tilt.

    Returns:
        Float array of the same length as ``rewards``.
    """
    scaled = rewards / tau
    # Shift by the max before exponentiating: the factor exp(-max) cancels in
    # the normalized weights, but exp() can no longer overflow to inf (which
    # previously turned every output into nan for rewards / tau above ~709).
    if scaled.size:
        scaled = scaled - np.max(scaled)
    tilt = np.exp(scaled)
    normalizing_constant = np.mean(tilt)
    weights = tilt / normalizing_constant
    expected_reward = np.mean(rewards * weights)
    linearized_rewards = rewards * weights - weights * expected_reward
    logger.info(f"rewards: {rewards}")
    logger.info(f"Soft-BoN linearized rewards: {linearized_rewards}")
    return linearized_rewards


def get_soft_BoN_linearized_reward_func(
    reward_func: Callable, tau: float = 0.1
) -> Callable:
    """Wrap ``reward_func`` so its outputs go through ``calc_soft_BoN_linearized_reward``.

    Args:
        reward_func: Callable taking keyword arguments only and returning a
            sequence of per-sample rewards.
        tau: Temperature forwarded to ``calc_soft_BoN_linearized_reward``.

    Returns:
        A keyword-only callable returning the linearized rewards as a list;
        its ``__name__`` is ``"bon_linearized_"`` followed by
        ``reward_func.__name__``.

    NOTE(review): ``get_BoN_linearized_reward_func`` assigns the very same
    ``bon_linearized_<name>`` ``__name__``, so the soft and hard wrappers are
    indistinguishable to anything keyed on ``__name__``.
    """

    def soft_bon_linearized_reward_func(**kwargs):
        raw_rewards = np.array(reward_func(**kwargs))
        return calc_soft_BoN_linearized_reward(raw_rewards, tau).tolist()

    soft_bon_linearized_reward_func.__name__ = (
        "bon_linearized_" + reward_func.__name__
    )
    return soft_bon_linearized_reward_func


def get_soft_min_linearized_reward_func(
    transformed_reward_funcs, linearized_reward_funcs, gamma: float
) -> Callable:
    """Combine several linearized rewards with soft-min weights.

    For each index ``i``, ``transformed_reward_funcs[i]`` and
    ``linearized_reward_funcs[i]`` are both called with the same keyword
    arguments. The mean of the transformed rewards gives
    ``expected_reward_i``, and the weights are
    ``softmax(-gamma * expected_reward)``. Larger ``gamma`` puts more weight
    on the function with the lowest expected reward. The result is the
    weighted sum of the linearized rewards.

    Args:
        transformed_reward_funcs: Sized sequence of keyword-only callables
            returning per-sample rewards; only their means are used.
        linearized_reward_funcs: Sized sequence of keyword-only callables, one
            per transformed function, returning equally long per-sample
            rewards.
        gamma: Soft-min sharpness.

    Returns:
        A keyword-only callable named ``"soft_min_linearized_reward"`` that
        returns the combined rewards as a list of floats. The logged
        ``weight`` is the normalized weight (the weights sum to 1). Calling it
        raises ``ValueError`` if no functions were given or the two sequences
        differ in length.
    """

    def soft_min_linearized_reward_func(**kwargs):
        n_reward = len(transformed_reward_funcs)
        if n_reward == 0:
            raise ValueError("soft-min requires at least one reward function")
        if len(linearized_reward_funcs) != n_reward:
            raise ValueError(
                "transformed_reward_funcs and linearized_reward_funcs must have "
                f"the same length, got {n_reward} and {len(linearized_reward_funcs)}"
            )
        transformed_rewards = [
            transformed_reward_func(**kwargs)
            for transformed_reward_func in transformed_reward_funcs
        ]
        linearized_rewards = [
            linearized_reward_func(**kwargs)
            for linearized_reward_func in linearized_reward_funcs
        ]

        expected_rewards = np.array(
            [np.mean(rewards) for rewards in transformed_rewards], dtype=float
        )
        logits = -gamma * expected_rewards
        # Shift by the largest logit before exponentiating. Normalization
        # cancels the shift, but it prevents exp() overflow (inf / inf) and
        # total underflow (0 / 0), both of which yielded nan for large gamma.
        weights = np.exp(logits - np.max(logits))
        weights /= np.sum(weights)

        # Accumulate in float64, sized from the linearized rewards being summed.
        # The previous zeros_like(transformed_rewards[0]) raised on `+=` when
        # the transformed rewards happened to be integers.
        soft_min_linearized_rewards = np.zeros(
            np.shape(linearized_rewards[0]), dtype=float
        )
        for i in range(n_reward):
            logger.info(
                f"Reward function {i}: expected_reward={expected_rewards[i]}, weight={weights[i]}"
            )
            logger.info(f"Linearized rewards: {linearized_rewards[i]}")
            soft_min_linearized_rewards += weights[i] * np.asarray(
                linearized_rewards[i], dtype=float
            )

        return soft_min_linearized_rewards.tolist()

    soft_min_linearized_reward_func.__name__ = "soft_min_linearized_reward"
    return soft_min_linearized_reward_func
