import abc

import torch


class AdvantageBase(abc.ABC):
    """Abstract callable whose ``_call_impl`` receives rewards and group statistics.

    NOTE(review): this module declares another class named ``AdvantageBase``
    further down. That later declaration rebinds the name, so the subclasses
    defined after it derive from the later class rather than from this one.
    """

    def __call__(self, *args, **kwargs):
        """
        Public entry point: forwards every argument to ``_call_impl`` while
        ``torch.inference_mode`` is active and returns its result.
        """
        with torch.inference_mode():
            return self._call_impl(*args, **kwargs)

    @abc.abstractmethod
    def _call_impl(self, rewards: torch.Tensor, mean_grouped_rewards: torch.Tensor,
                   std_grouped_rewards: torch.Tensor) -> torch.Tensor:
        """
        Core computation, to be provided by subclasses.
        ``__call__`` already enters ``torch.inference_mode`` before invoking
        this method, so implementations need not do so themselves.
        """


def get_advantage_function(name: str | None = None, kwargs: dict[str, float] = {}) -> AdvantageBase:
    if name == 'default' or name is None:
        return DefaultAdvantage(**kwargs)
    elif name == "no_std":
        return NoStdAdvantage(**kwargs)
    elif name == "std_scaled":
        return StdScaledAdvantage(**kwargs)
    elif name == "trimmed":
        return TrimmedAdvantage(**kwargs)
    elif name == "replace_zero":
        return ReplaceZeroAdvantage(**kwargs)
    elif name == "trimmed_abs":
        return TrimmedAbsAdvantage(**kwargs)
    elif name == "noisy_default":
        return NoisyDefaultAdvantage(**kwargs)
    elif name == "bon_max":
        return BonMaxAdvantage(**kwargs)
    elif name == "bon_max_binary_positive":
        return BinaryBonMaxAdvantage(**kwargs)
    elif name == "identity":
        return IdentityAdvantage(**kwargs)
    else:
        raise ValueError(f"Invalid advantage function name: {name}")


class AdvantageBase(abc.ABC):
    """Abstract callable that turns rewards and group statistics into advantages.

    Subclasses implement ``_call_impl``; callers invoke the instance directly,
    which routes through ``__call__`` with gradient tracking disabled.
    """

    @torch.no_grad()
    def __call__(self, *args, **kwargs):
        """
        This is the public "call", guaranteed to run under torch.no_grad.

        All positional and keyword arguments are forwarded unchanged to
        ``_call_impl`` and its result is returned. Note that this is
        ``torch.no_grad``, not ``torch.inference_mode``.
        """
        return self._call_impl(*args, **kwargs)

    @abc.abstractmethod
    def _call_impl(self, rewards: torch.Tensor, mean_grouped_rewards: torch.Tensor, std_grouped_rewards: torch.Tensor,
                   num_generations: int) -> torch.Tensor:
        """
        Subclasses implement the core logic here.
        The torch.no_grad context is handled by the base class.

        Args:
            rewards: Reward values.
            mean_grouped_rewards: Group mean associated with the rewards.
            std_grouped_rewards: Group standard deviation associated with the rewards.
            num_generations: Group size; the subclasses in this module that
                use it reshape ``rewards`` to ``(-1, num_generations)``.

        Returns:
            The advantages as a tensor.
        """
        pass


class DefaultAdvantage(AdvantageBase):
    """Advantage = (reward - group mean) / (group std + ``reg_coef``)."""

    def __init__(self, reg_coef: float = 1e-4, **kwargs):
        """Store ``reg_coef``, the constant added to the std before dividing.

        Any additional keyword arguments are accepted and ignored.
        """
        self.reg_coef = reg_coef

    def _call_impl(self, rewards, mean_grouped_rewards, std_grouped_rewards, num_generations):
        """Return the mean-centred rewards divided by the regularised std.

        ``num_generations`` is not used by this implementation.
        """
        centered = rewards - mean_grouped_rewards
        scale = std_grouped_rewards + self.reg_coef
        return centered / scale


class NoStdAdvantage(AdvantageBase):
    """Advantage = reward - group mean, with no division by the std."""

    def __init__(self, **kwargs):
        """Take no configuration; keyword arguments are accepted and ignored."""

    def _call_impl(self, rewards, mean_grouped_rewards, std_grouped_rewards, num_generations):
        """Return the mean-centred rewards.

        ``std_grouped_rewards`` and ``num_generations`` are not used.
        """
        centered = rewards - mean_grouped_rewards
        return centered


class StdScaledAdvantage(AdvantageBase):
    """Advantage = (reward - group mean) * group std."""

    def __init__(self, **kwargs):
        """Take no configuration; keyword arguments are accepted and ignored."""

    def _call_impl(self, rewards, mean_grouped_rewards, std_grouped_rewards, num_generations):
        """Return the mean-centred rewards multiplied (not divided) by the std.

        ``num_generations`` is not used by this implementation.
        """
        centered = rewards - mean_grouped_rewards
        return centered * std_grouped_rewards


class TrimmedAdvantage(AdvantageBase):
    """Default-style advantage whose baseline is zeroed where the mean is below ``trim_coef``."""

    def __init__(self, trim_coef: float = 0.0, reg_coef: float = 1e-4, **kwargs):
        """Store the trimming threshold and the std regulariser.

        Any additional keyword arguments are accepted and ignored.
        """
        self.reg_coef = reg_coef
        self.trim_coef = trim_coef

    def _call_impl(self, rewards, mean_grouped_rewards, std_grouped_rewards, num_generations):
        """Return (rewards - trimmed mean) / (std + ``reg_coef``).

        The trimmed mean equals ``mean_grouped_rewards`` except that entries
        strictly below ``trim_coef`` are replaced by 0. The caller's tensor is
        left untouched. ``num_generations`` is not used.
        """
        below_threshold = mean_grouped_rewards < self.trim_coef
        trimmed_mean = torch.where(
            below_threshold,
            torch.zeros_like(mean_grouped_rewards),
            mean_grouped_rewards,
        )
        scale = std_grouped_rewards + self.reg_coef
        return (rewards - trimmed_mean) / scale


class TrimmedAbsAdvantage(TrimmedAdvantage):
    """Variant of ``TrimmedAdvantage`` that trims on the absolute value of the mean."""

    def __init__(self, trim_coef: float = 0.1, reg_coef: float = 1e-4, **kwargs):
        """Forward ``trim_coef`` and ``reg_coef`` to the parent constructor.

        Any additional keyword arguments are accepted and ignored.
        """
        super().__init__(trim_coef, reg_coef)

    def _call_impl(self, rewards, mean_grouped_rewards, std_grouped_rewards, num_generations):
        """Return (rewards - trimmed mean) / (std + ``reg_coef``).

        Entries of the mean whose absolute value is strictly below
        ``trim_coef`` are replaced by 0 in the baseline. The caller's tensor
        is left untouched. ``num_generations`` is not used.
        """
        near_zero = mean_grouped_rewards.abs() < self.trim_coef
        trimmed_mean = torch.where(
            near_zero,
            torch.zeros_like(mean_grouped_rewards),
            mean_grouped_rewards,
        )
        scale = std_grouped_rewards + self.reg_coef
        return (rewards - trimmed_mean) / scale


class ReplaceZeroAdvantage(AdvantageBase):
    """Default advantage with a reward-based fallback when every advantage is ~0."""

    def __init__(self, replace_coef: float = 0.1,
                 reg_coef: float = 1e-4,
                 max_reward: float = 1.0,
                 min_reward: float = 0.0,
                 **kwargs):
        """Store the fallback scale, the std regulariser and the reward bounds.

        Any additional keyword arguments are accepted and ignored.
        """
        self.min_reward = min_reward
        self.max_reward = max_reward
        self.reg_coef = reg_coef
        self.replace_coef = replace_coef

    def _call_impl(self, rewards, mean_grouped_rewards, std_grouped_rewards, num_generations):
        """Return the standardised advantages, or the fallback if they sum to ~0.

        The fallback triggers when the sum of absolute advantages over the
        whole input tensor is below 1e-6. It linearly maps rewards from
        ``[min_reward, max_reward]`` onto ``[-1, 1]`` and scales the result by
        ``replace_coef``. ``num_generations`` is not used.
        """
        centered = rewards - mean_grouped_rewards
        advantages = centered / (std_grouped_rewards + self.reg_coef)

        if advantages.abs().sum() < 1e-6:
            reward_span = self.max_reward - self.min_reward
            rescaled = 2 * (rewards - self.min_reward) / reward_span - 1
            fallback_scale = torch.ones_like(advantages) * self.replace_coef
            return fallback_scale * rescaled

        return advantages


class NoisyDefaultAdvantage(DefaultAdvantage):
    """``DefaultAdvantage`` plus seeded Gaussian noise added to every advantage."""

    def __init__(self, reg_coef: float = 1e-4, noise_mean: float = 0., noise_std: float = 0.1, seed: int = 1234,
                 **kwargs):
        """Store the noise parameters and forward ``reg_coef`` to the parent.

        Args:
            reg_coef: Constant added to the std before dividing.
            noise_mean: Mean of the added Gaussian noise.
            noise_std: Standard deviation of the added Gaussian noise.
            seed: Seed used to (re)initialise the noise generator on every call.
            **kwargs: Accepted and ignored.
        """
        super().__init__(reg_coef)
        self.noise_mean = noise_mean
        self.noise_std = noise_std
        self.seed = seed

    def _call_impl(self, rewards, mean_grouped_rewards, std_grouped_rewards, num_generations):
        """Return the default advantages with Gaussian noise added.

        A fresh generator is seeded with ``self.seed`` on each call, so two
        calls with inputs of the same shape receive the same noise sample.
        """
        # Call the parent's implementation directly. Going through
        # ``super().__call__`` would dispatch back to this very method via
        # ``self._call_impl`` and fail (missing ``num_generations``) or recurse.
        advantages = super()._call_impl(rewards, mean_grouped_rewards, std_grouped_rewards, num_generations)

        _gen = torch.Generator()
        _gen.manual_seed(self.seed)
        # ``torch.randn`` takes a ``generator``. Sample with the CPU generator,
        # then move to the advantages' device and dtype so this also works for
        # non-CPU inputs.
        standard_normal = torch.randn(advantages.shape, generator=_gen)
        standard_normal = standard_normal.to(device=advantages.device, dtype=advantages.dtype)

        noise = standard_normal * self.noise_std + self.noise_mean
        return advantages + noise


class BonMaxAdvantage(AdvantageBase):
    """Advantage that is non-zero only at the maximum-reward entries of each group.

    ``baseline`` selects what is subtracted from the rewards before masking:
    ``'with_repetition'``, ``'without_repetition'``, ``'mean'`` or ``'none'``.
    """

    def __init__(self, baseline: str = None, **kwargs):
        """Store the baseline mode; ``None`` is treated as ``'none'``.

        Any additional keyword arguments are accepted and ignored.
        """
        self.baseline = 'none' if baseline is None else baseline

    def _call_impl(self, rewards, mean_grouped_rewards, std_grouped_rewards, num_generations):
        """Return flattened advantages that are zero outside each group's maxima.

        ``rewards`` is viewed as ``(-1, num_generations)``, one row per group.
        ``std_grouped_rewards`` is not used.

        Raises:
            ValueError: If ``self.baseline`` is not a recognised mode.
        """
        grouped = rewards.view(-1, num_generations)
        best, best_idx = grouped.max(dim=1, keepdim=True)
        is_best = (grouped == best)

        mode = self.baseline
        if mode == 'with_repetition' or mode == 'without_repetition':
            candidates = grouped.clone()
            if mode == 'with_repetition':
                # Hide only the single arg-max entry per row, so a tied
                # maximum can still be picked as the runner-up.
                row_ids = torch.arange(grouped.shape[0]).unsqueeze(1)
                candidates[row_ids, best_idx] = float('-inf')
            else:
                # Hide every entry equal to the maximum, so the runner-up is
                # strictly smaller than the maximum.
                candidates[is_best] = float('-inf')
            runner_up = candidates.max(dim=1, keepdim=True).values
            # Rows with nothing left after hiding fall back to the row maximum.
            nothing_left = runner_up == float('-inf')
            runner_up[nothing_left] = best[nothing_left]
            shifted = grouped - runner_up
        elif mode == 'mean':
            shifted = grouped - mean_grouped_rewards.view(-1, num_generations)
        elif mode == 'none':
            shifted = grouped
        else:
            raise ValueError(f"Invalid baseline: {self.baseline}")

        advantages = shifted * is_best
        return advantages.view(-1)


class BinaryBonMaxAdvantage(BonMaxAdvantage):
    """Assigns a per-group weight to the entries whose reward equals 1; 0 elsewhere."""

    def __init__(self, baseline: bool = True, k: int = 1, **kwargs):
        """Store ``baseline`` and the exponent ``k``.

        The parent constructor runs first with its defaults, after which
        ``self.baseline`` is overwritten with the value given here.
        ``_call_impl`` below does not read ``self.baseline``.
        Any additional keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.baseline = baseline
        self.k = k

    def _call_impl(self, rewards, mean_grouped_rewards, std_grouped_rewards, num_generations):
        """Return flattened advantages for rewards viewed as ``(-1, num_generations)``.

        With ``p`` = 1 - (row mean of the rewards), each row gets the weight
        ``k * p**(k-1) * (1 - p) / (1 - p**k)``. Entries whose reward equals 1
        receive their row's weight and all others receive 0.
        ``mean_grouped_rewards`` and ``std_grouped_rewards`` are not used.
        """
        grouped = rewards.view(-1, num_generations)
        p_fail = 1 - grouped.mean(dim=1)

        # One weight per row.
        numerator = self.k * p_fail.pow(self.k - 1) * (1 - p_fail)
        denominator = 1 - p_fail.pow(self.k)
        row_weight = numerator / denominator

        # Select the row weight where the reward is exactly 1, zero otherwise.
        is_one = (grouped == 1)
        advantages = torch.where(is_one, row_weight.view(-1, 1), torch.zeros_like(grouped))

        return advantages.view(-1)


class IdentityAdvantage(AdvantageBase):
    """Uses the rewards themselves as the advantages."""

    def __init__(self, **kwargs):
        """Take no configuration; keyword arguments are accepted and ignored."""
        super().__init__()

    def _call_impl(self, rewards, mean_grouped_rewards, std_grouped_rewards, num_generations):
        """Return ``rewards`` as-is; the other arguments are not used."""
        advantages = rewards
        return advantages
