from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn.functional as F


class GRPO_loss:
    def __init__(
        self,
        temperature: float = 1.0,
        num_samples: int = 8,
        reward_scale: float = 1.0,
        kernel_sigma: float = 1.0,
        kernel_reduction: str = "mean",
        kl_weight: float = 0.1,
        ratio_clip: float = 0.2,
    ):
        """Store the hyper-parameters of the loss.

        Args:
            temperature: Divisor applied to logits before softmax in the
                entropy estimators; also forwarded to
                ``model.sample_with_logprobs`` when drawing rollouts.
            num_samples: Number of rollouts drawn per batch element.
            reward_scale: Stored as-is; not read by any method visible in
                this class.
            kernel_sigma: Stored as-is; not read by any method visible in
                this class.
            kernel_reduction: ``"sum"`` sums the squared error over the last
                axis of 3-D predictions in ``compute_reward``; any other
                value averages it.
            kl_weight: Multiplier of the KL term in ``__call__``; the term is
                only considered when this is ``> 0``.
            ratio_clip: Half-width ``eps`` of the ``[1 - eps, 1 + eps]``
                clipping range applied to the importance ratios.
        """
        # Sampling configuration.
        self.temperature, self.num_samples = temperature, num_samples
        # Reward shaping configuration.
        self.reward_scale, self.kernel_sigma = reward_scale, kernel_sigma
        self.kernel_reduction = kernel_reduction
        # Policy-update regularisation.
        self.kl_weight, self.ratio_clip = kl_weight, ratio_clip

    def compute_reward(
        self, predictions: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        if predictions.dim() == 3:
            targets_expanded = (
                targets.unsqueeze(1)
                .unsqueeze(2)
                .expand(-1, predictions.size(1), predictions.size(2))
            )
            diff2 = (predictions - targets_expanded) ** 2
            if self.kernel_reduction == "sum":
                dist2 = diff2.sum(dim=2)
            else:
                dist2 = diff2.mean(dim=2)
        else:
            targets_expanded = targets.unsqueeze(1).expand(-1, predictions.size(1))
            dist2 = (predictions - targets_expanded) ** 2
        rewards = dist2
        return -rewards

    def compute_kl_divergence(
        self, 
        current_logits: torch.Tensor, 
        ref_logits: torch.Tensor,
        mask: torch.Tensor = None
    ) -> torch.Tensor:
        current_log_probs = F.log_softmax(current_logits, dim=-1)
        ref_probs = F.softmax(ref_logits, dim=-1)
        
        kl_div = F.kl_div(
            current_log_probs, 
            ref_probs, 
            reduction='none'
        )
        
        kl_div_per_token = kl_div.sum(dim=-1)
        
        if mask is not None:
            kl_div_per_token = kl_div_per_token * mask.float()
            kl_loss = kl_div_per_token.sum() / mask.float().sum()
        else:
            kl_loss = kl_div_per_token.mean()
            
        return kl_loss

    def compute_reinforce_loss(
        self, log_probs: torch.Tensor, rewards: torch.Tensor, baseline: torch.Tensor
    ) -> torch.Tensor:
        sequence_log_probs = log_probs.sum(dim=2)
        advantage = rewards - baseline
        advantage = advantage.detach()
        reinforce_loss = -(sequence_log_probs * advantage).mean()
        return reinforce_loss

    def _sequence_log_prob_from_logits(
        self,
        logits: torch.Tensor,
        token_ids: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        log_probs = F.log_softmax(logits, dim=-1)
        gathered = torch.gather(log_probs, dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1)
        gathered = gathered * mask
        seq_log_prob = gathered.sum(dim=1)
        return seq_log_prob

    def compute_policy_entropy(
        self, model, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        batch_size = batch["encoder_input"].shape[0]
        device = batch["encoder_input"].device
        
        with torch.no_grad():
            memory, memory_key_padding_mask = model.encoder_decoder.encode(batch["encoder_input"])
            
            current_tgt_ids = torch.full(
                (batch_size, 1),
                model.decoder_vocab.bos_pad_id,
                dtype=torch.long,
                device=device,
            )
            
            step_entropies = torch.zeros(
                (batch_size, model.decode_len),
                dtype=torch.float32,
                device=device,
            )
            
            for step_idx in range(model.decode_len):
                logits = model.encoder_decoder.next_token_logits(
                    current_tgt_ids, memory, memory_key_padding_mask
                )
                
                curr_mask = model.decoder_constraint_masks[step_idx, :].unsqueeze(0)
                MASK_VALUE = -1e7
                masked_logits = (1.0 - curr_mask) * MASK_VALUE + curr_mask * logits
                
                scaled_logits = masked_logits / self.temperature
                probs = F.softmax(scaled_logits, dim=-1)
                log_probs = F.log_softmax(scaled_logits, dim=-1)
                
                p_log_p = probs * log_probs
                p_log_p = torch.nan_to_num(p_log_p, nan=0.0, posinf=0.0, neginf=0.0)
                entropy = -torch.sum(p_log_p, dim=-1)
                step_entropies[:, step_idx] = entropy
                
                next_token_ids = torch.argmax(scaled_logits, dim=-1, keepdim=True)
                
                if step_idx < model.decode_len - 1:
                    current_tgt_ids = torch.cat([current_tgt_ids, next_token_ids], dim=1)
            
            policy_entropy = step_entropies.mean()
        
        return policy_entropy

    def compute_policy_entropy_sample(
        self, model, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        batch_size = batch["encoder_input"].shape[0]
        device = batch["encoder_input"].device
        
        with torch.no_grad():
            memory, memory_key_padding_mask = model.encoder_decoder.encode(batch["encoder_input"])
            
            num_samples = self.num_samples
            expanded_memory = (
                memory.unsqueeze(1)
                .expand(-1, num_samples, -1, -1)
                .reshape(batch_size * num_samples, memory.size(1), memory.size(2))
            )
            
            if memory_key_padding_mask is not None:
                expanded_memory_key_padding_mask = (
                    memory_key_padding_mask.unsqueeze(1)
                    .expand(-1, num_samples, -1)
                    .reshape(batch_size * num_samples, memory_key_padding_mask.size(1))
                )
            else:
                expanded_memory_key_padding_mask = None
            
            current_tgt_ids = torch.full(
                (batch_size * num_samples, 1),
                model.decoder_vocab.bos_pad_id,
                dtype=torch.long,
                device=device,
            )
            
            step_entropies = torch.zeros(
                (batch_size * num_samples, model.decode_len),
                dtype=torch.float32,
                device=device,
            )
            
            for step_idx in range(model.decode_len):
                logits = model.encoder_decoder.next_token_logits(
                    current_tgt_ids,
                    expanded_memory,
                    expanded_memory_key_padding_mask
                )
                
                curr_mask = model.decoder_constraint_masks[step_idx, :].unsqueeze(0)
                MASK_VALUE = -1e7
                masked_logits = (1.0 - curr_mask) * MASK_VALUE + curr_mask * logits
                
                scaled_logits = masked_logits / self.temperature
                probs = F.softmax(scaled_logits, dim=-1)
                log_probs = F.log_softmax(scaled_logits, dim=-1)
                
                p_log_p = probs * log_probs
                p_log_p = torch.nan_to_num(p_log_p, nan=0.0, posinf=0.0, neginf=0.0)
                entropy = -torch.sum(p_log_p, dim=-1)
                step_entropies[:, step_idx] = entropy
                
                next_token_ids = torch.multinomial(probs, num_samples=1)
                
                if step_idx < model.decode_len - 1:
                    current_tgt_ids = torch.cat([current_tgt_ids, next_token_ids], dim=1)
            
            step_entropies = step_entropies.reshape(batch_size, num_samples, model.decode_len)
            sample_entropies = step_entropies.mean(dim=-1)
            batch_entropies = sample_entropies.mean(dim=-1)
            policy_entropy = batch_entropies.mean()
        
        return policy_entropy

    def __call__(
        self, model, batch: Dict[str, torch.Tensor], ref_model=None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        model.eval()
        if ref_model is not None:
            ref_model.eval()
        batch_size = batch["encoder_input"].shape[0]
        device = batch["encoder_input"].device

        with torch.no_grad():
            ref_decoded_ids, ref_output_floats, ref_step_log_probs, ref_step_logits = model.sample_with_logprobs(
                batch, self.num_samples, self.temperature, return_logits=True
            )
        predictions = torch.tensor(ref_output_floats, device=device, dtype=torch.float32)
        targets = batch["y"]
        y_max = batch["y_max"]
        y_min = batch["y_min"]
        y_max_bs = y_max.unsqueeze(1)
        y_min_bs = y_min.unsqueeze(1)
        predictions_true = predictions.squeeze(-1)
        predictions_norm = predictions_true * (y_max_bs - y_min_bs + 1e-8) + y_min_bs
        targets_norm = targets * (y_max - y_min + 1e-8) + y_min
        rewards = self.compute_reward(predictions_norm, targets_norm)
        reward_std_per_batch = rewards.std(dim=1)
        mean_reward_std = reward_std_per_batch.mean()

        baseline = rewards.mean(dim=1, keepdim=True)
        mean_baseline = baseline.mean().detach()
        std_baseline = rewards.std(dim=1, keepdim=True) + 1e-6
        std_baseline = std_baseline.expand_as(rewards)
        advantages = (rewards - baseline) / std_baseline
        advantages_detached = advantages.detach()
        num_nonzero_advantage = (advantages_detached != 0).sum().item()
        baseline_expanded = baseline.expand_as(rewards)
        num_rewards_ge_baseline = (rewards >= baseline_expanded).sum().item()
        num_rewards_lt_baseline = (rewards < baseline_expanded).sum().item()
        advantages_std_per_batch = advantages.std(dim=1)
        mean_advantages_std = advantages_std_per_batch.mean()

        B, S, L = ref_decoded_ids.shape
        vectorized_ref_ids = ref_decoded_ids.view(B * S, L)
        mask = (vectorized_ref_ids != model.decoder_vocab.bos_pad_id).float()
        current_logits_on_ref = self._get_logits_for_sequence(
            model, batch, vectorized_ref_ids, self.num_samples
        )
        current_log_probs_full = F.log_softmax(current_logits_on_ref, dim=-1)
        current_token_log_probs = torch.gather(
            current_log_probs_full, dim=-1, index=vectorized_ref_ids.unsqueeze(-1)
        ).squeeze(-1)
        ref_token_log_probs = ref_step_log_probs.reshape(B * S, L)

        token_log_ratio = current_token_log_probs - ref_token_log_probs
        token_is_ratios = torch.exp(token_log_ratio)

        ratios_reshaped = token_is_ratios.view(B, S, L)
        mask_reshaped = mask.view(B, S, L)
        lengths = mask_reshaped.sum(dim=2).clamp_min(1.0)
        advantages_per_token = advantages_detached.unsqueeze(-1).expand_as(mask_reshaped)
        ratios_clipped = torch.clamp(ratios_reshaped, 1.0 - self.ratio_clip, 1.0 + self.ratio_clip)
        surrogate1 = ratios_reshaped * advantages_per_token
        surrogate2 = ratios_clipped * advantages_per_token
        token_objective = torch.minimum(surrogate1, surrogate2)
        seq_objective = (token_objective * mask_reshaped).sum(dim=2) / lengths
        grpo_loss = - seq_objective.mean()

        total_loss = grpo_loss
        kl_loss = torch.tensor(0.0, device=device)

        clipped_mask = (ratios_reshaped < (1.0 - self.ratio_clip)) | (ratios_reshaped > (1.0 + self.ratio_clip))
        clipped_mask = clipped_mask & (mask_reshaped > 0)
        num_ratios_clipped = clipped_mask.sum()
        total_tokens = mask_reshaped.sum().clamp_min(1.0)
        frac_ratios_clipped = num_ratios_clipped.float() / total_tokens.float()

        if self.kl_weight > 0:
            with torch.no_grad():
                ref_logits_on_ref = self._get_logits_for_sequence(
                    ref_model, batch, vectorized_ref_ids, self.num_samples
                )
            kl_loss = self.compute_kl_divergence(current_logits_on_ref, ref_logits_on_ref, mask)
            total_loss = total_loss + self.kl_weight * kl_loss

        policy_entropy = self.compute_policy_entropy_sample(model, batch)

        median_rewards = torch.zeros(batch_size, device=device)
        for i in range(batch_size):
            sample_predictions = predictions[i]
            sample_rewards = rewards[i]
            
            median_pred = torch.median(sample_predictions)
            
            diff = torch.abs(sample_predictions - median_pred)
            median_idx = torch.argmin(diff)
            
            median_rewards[i] = sample_rewards[median_idx]
        
        median_reward_mean = median_rewards.mean()
        median_reward_min = median_rewards.min()

        metrics = {
            "grpo_loss": grpo_loss.detach(),
            "kl_loss": kl_loss.detach(),
            "total_loss": total_loss.detach(),
            "mean_reward": rewards.mean(),
            "max_reward": rewards.max(),
            "min_reward": rewards.min(),
            "reward_std": rewards.std(),
            "mean_reward_std": mean_reward_std.detach(),
            "mean_advantages_std": mean_advantages_std.detach(),
            "mean_baseline": mean_baseline,
            "num_nonzero_advantage": num_nonzero_advantage,
            "num_rewards_ge_baseline": num_rewards_ge_baseline,
            "num_rewards_lt_baseline": num_rewards_lt_baseline,
            "policy_entropy": policy_entropy.detach(),
            "median_reward_mean": median_reward_mean.detach(),
            "median_reward_min": median_reward_min.detach(),
            "frac_ratios_clipped": frac_ratios_clipped.detach(),
            "reward_std_per_sample": reward_std_per_batch.detach().cpu().numpy(),
            "advantages_per_rollout": advantages_detached.cpu().numpy(),
        }
        model.train()

        return total_loss, metrics

    def _get_logits_for_sequence(
        self, model, batch: Dict[str, torch.Tensor], decoded_ids: torch.Tensor, num_samples: int
    ) -> torch.Tensor:
        encoder_input = batch["encoder_input"]
        batch_size = encoder_input.shape[0]
        device = encoder_input.device
        
        expanded_encoder_input = encoder_input.repeat_interleave(num_samples, dim=0)
        
        memory, memory_key_padding_mask = model.encoder_decoder.encode(expanded_encoder_input)
        
        current_tgt_ids = torch.full(
            (batch_size * num_samples, 1),
            model.decoder_vocab.bos_pad_id,
            dtype=torch.long,
            device=device,
        )
        
        step_logits = torch.zeros(
            (batch_size * num_samples, decoded_ids.shape[1], len(model.decoder_vocab)),
            dtype=torch.float32,
            device=device,
        )
        
        for step_idx in range(decoded_ids.shape[1]):
            logits = model.encoder_decoder.next_token_logits(
                current_tgt_ids, memory, memory_key_padding_mask
            )
            step_logits[:, step_idx, :] = logits
            
            if step_idx < decoded_ids.shape[1] - 1:
                current_tgt_ids = torch.cat([current_tgt_ids, decoded_ids[:, step_idx:step_idx+1]], dim=1)
        
        return step_logits