import os
from typing import Any, Optional

import torch

from verl.trainer.ppo.core_algos import register_policy_loss, agg_loss
import verl.utils.torch_functional as verl_F
from verl.workers.config import ActorConfig


@register_policy_loss("external_adjusted")
def compute_policy_loss_external_adjusted(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """Dual-clip PPO policy loss with an adjustment for "floored" tokens.

    The base loss is the clipped PPO objective (``clip_ratio_low`` /
    ``clip_ratio_high``). For negative-advantage tokens, it is additionally
    bounded by ``-advantages * clip_ratio_c``.

    If ``config.policy_loss.external_min_log_prob`` is set, some tokens are
    treated as "floored". These are tokens whose ``old_log_prob`` lies within
    ``1e-5`` of a threshold (and whose ``response_mask`` is positive). They are
    handled according to the sign of the setting:

    * ``external_min_log_prob > 0``: the threshold is
      ``-external_min_log_prob``. A floored token uses an SFT-style loss
      ``-log_prob`` instead of the PPO loss. This applies only when its
      advantage is positive and its current ``log_prob`` is still below the
      threshold.
    * ``external_min_log_prob <= 0``: the threshold is the value itself. The
      PPO loss of floored tokens is multiplied by ``1 / (0.01 + pi_theta)``.
      Here ``pi_theta = exp(log_prob)`` is computed from a detached
      ``log_prob``, so no gradient flows through the weight.

    Args:
        old_log_prob: Log-probs recorded for the sampled tokens. The code
            presumably expects the same shape as ``log_prob``, e.g.
            ``(batch, response_len)``. TODO confirm.
        log_prob: Log-probs of the same tokens under the current policy.
        advantages: Per-token advantages, broadcastable with ``log_prob``.
        response_mask: Non-zero for tokens that contribute to the loss.
        loss_agg_mode: Aggregation mode forwarded to ``agg_loss``.
        config: Actor config. Required despite the ``None`` default.
        rollout_is_weights: Optional per-token weights, multiplied into the
            loss matrix before aggregation.

    Returns:
        ``(pg_loss, pg_metrics)``. ``pg_metrics`` maps
        ``actor/pg_clipfrac``, ``actor/ppo_kl`` and
        ``actor/pg_clipfrac_lower`` to Python floats.
    """
    assert config is not None, "compute_policy_loss_external_adjusted requires an ActorConfig"

    clip_ratio = config.clip_ratio
    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
    clip_ratio_c = config.get("clip_ratio_c", 3.0)

    assert clip_ratio_c > 1.0, f"clip_ratio_c should be > 1.0, got {clip_ratio_c}"

    policy_loss_config = config.get("policy_loss", {})
    external_min_log_prob = policy_loss_config.get("external_min_log_prob", None) if policy_loss_config else None

    # Clamp the log-ratio so exp() cannot overflow or underflow.
    negative_approx_kl = torch.clamp(log_prob - old_log_prob, min=-20.0, max=20.0)
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    # Standard PPO clipping: take the pessimistic (larger) of the two losses.
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
    clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)

    # Dual clip: for negative advantages, cap the loss at -A * clip_ratio_c.
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    pg_clipfrac_lower = verl_F.masked_mean(
        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
    )

    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

    if external_min_log_prob is not None:
        floor_tolerance = 1e-5
        # NOTE(review): this near-exact match assumes old_log_prob keeps enough
        # precision (e.g. fp32) to represent the threshold. In bf16 the stored
        # value may differ from the Python float by more than 1e-5. Confirm
        # the dtype.
        # The explicit `> 0` / `<= 0` checks (rather than if/else) keep NaN
        # settings a no-op, as in the original.
        if external_min_log_prob > 0:
            threshold = -external_min_log_prob
            floor_mask = (
                (torch.abs(old_log_prob - threshold) < floor_tolerance)
                & (response_mask > 0)
                & (advantages > 0)
                & (log_prob < threshold)
            )
            # SFT-style term: push up the log-prob of floored, positive-advantage tokens.
            pg_losses = torch.where(floor_mask, -log_prob, pg_losses)
        elif external_min_log_prob <= 0:
            threshold = external_min_log_prob
            floor_mask = (torch.abs(old_log_prob - threshold) < floor_tolerance) & (response_mask > 0)
            # Detached weight: scales the gradient magnitude without
            # differentiating through the weight. It is at most 100x (pi -> 0).
            pi_theta = torch.exp(torch.clamp(log_prob.detach(), min=-20.0, max=20.0))
            weight_factor = 1.0 / (0.01 + pi_theta)
            pg_losses = torch.where(floor_mask, pg_losses * weight_factor, pg_losses)

    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    pg_loss = agg_loss(
        loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode,
        **config.global_batch_info
    )

    pg_metrics = {
        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
        "actor/ppo_kl": ppo_kl.detach().item(),
        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
    }
    return pg_loss, pg_metrics
