from typing import Dict, Tuple

import torch


class ReinforceLoss:
    """REINFORCE (score-function) loss with a Gaussian-kernel reward.

    Calling an instance samples predictions from ``model``, normalises
    predictions and targets with the per-example median and inter-quartile
    range found in the batch, scores every sample with
    ``exp(-0.5 * dist^2 / sigma^2) * reward_scale`` and returns the
    baseline-corrected policy-gradient loss together with a few metrics.

    Args:
        temperature: Forwarded to ``model.sample_with_logprobs``.
        num_samples: Forwarded to ``model.sample_with_logprobs``.
        reward_scale: Multiplier applied to the kernel value.
        baseline_type: ``"mean"`` or ``"min"``; the statistic taken over
            dim 1 of the rewards and subtracted from them.
        kernel_sigma: Bandwidth of the Gaussian kernel.
        kernel_reduction: For 3-D predictions, ``"sum"`` sums the squared
            differences over the last dim; any other value averages them.
        iqr_eps: IQR magnitudes below this value are replaced by it before
            dividing, so a degenerate (zero) IQR cannot produce inf/nan.
            ``0`` disables the guard.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        num_samples: int = 8,
        reward_scale: float = 1.0,
        baseline_type: str = "mean",
        kernel_sigma: float = 1.0,
        kernel_reduction: str = "mean",
        iqr_eps: float = 1e-8,
    ):
        self.temperature = temperature
        self.num_samples = num_samples
        self.reward_scale = reward_scale
        self.baseline_type = baseline_type
        self.kernel_sigma = kernel_sigma
        self.kernel_reduction = kernel_reduction
        self.iqr_eps = iqr_eps

    @staticmethod
    def _per_example(stat: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Reshape a per-example statistic so it broadcasts over ``reference``.

        A statistic holding exactly one value per leading-dim entry of
        ``reference`` is viewed as ``(batch, 1, ..., 1)``. Anything else
        (scalars, single-element tensors, other layouts) is returned
        untouched and left to ordinary broadcasting.
        """
        if (
            stat.dim() >= 1
            and reference.dim() >= 1
            and stat.numel() == reference.shape[0]
        ):
            return stat.reshape(-1, *([1] * (reference.dim() - 1)))
        return stat

    @staticmethod
    def _safe_scale(scale: torch.Tensor, eps: float) -> torch.Tensor:
        """Return ``scale`` with magnitudes below ``eps`` replaced by ``eps``."""
        if not scale.is_floating_point():
            # True division would promote to the default float dtype anyway;
            # converting first lets the epsilon floor be represented.
            scale = scale.to(torch.get_default_dtype())
        floor = torch.full_like(scale, eps)
        return torch.where(scale.abs() < floor, floor, scale)

    def _robust_scale(
        self, values: torch.Tensor, center: torch.Tensor, scale: torch.Tensor
    ) -> torch.Tensor:
        """Compute ``(values - center) / scale`` with per-example alignment."""
        center = self._per_example(center, values)
        scale = self._per_example(scale, values)
        return (values - center) / scale

    def compute_reward(
        self, predictions: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        """Score predictions against targets with a Gaussian kernel.

        Args:
            predictions: 2-D tensor (dim 0 is matched with ``targets``,
                dim 1 indexes samples) or 3-D tensor whose last dim is
                reduced according to ``kernel_reduction``.
            targets: One value per leading-dim entry, shaped ``(batch,)`` or
                ``(batch, 1)``.

        Returns:
            Rewards of shape ``(batch, samples)``; each entry is
            ``reward_scale`` for an exact match and decays toward zero as
            the distance grows.

        Raises:
            ValueError: If ``predictions`` has fewer than two dimensions.
        """
        if predictions.dim() < 2:
            raise ValueError(
                "predictions must have at least 2 dimensions, "
                f"got {predictions.dim()}"
            )

        # Small constant keeps the division defined when kernel_sigma == 0.
        sigma2 = float(self.kernel_sigma) ** 2 + 1e-12
        batch_size = targets.shape[0]

        if predictions.dim() == 3:
            diff2 = (predictions - targets.reshape(batch_size, 1, 1)) ** 2
            if self.kernel_reduction == "sum":
                dist2 = diff2.sum(dim=2)
            else:
                dist2 = diff2.mean(dim=2)
        else:
            dist2 = (predictions - targets.reshape(batch_size, 1)) ** 2

        return torch.exp(-0.5 * dist2 / sigma2) * self.reward_scale

    def compute_baseline(self, rewards: torch.Tensor) -> torch.Tensor:
        """Reduce rewards over dim 1 according to ``baseline_type``.

        Raises:
            ValueError: If ``baseline_type`` is neither ``"mean"`` nor
                ``"min"``.
        """
        if self.baseline_type == "mean":
            return rewards.mean(dim=1)
        if self.baseline_type == "min":
            return rewards.min(dim=1).values
        raise ValueError(f"Unknown baseline type: {self.baseline_type}")

    def compute_reinforce_loss(
        self, log_probs: torch.Tensor, rewards: torch.Tensor
    ) -> torch.Tensor:
        """Return ``-mean(sequence_log_prob * advantage)``.

        Args:
            log_probs: Per-step log-probabilities; dim 2 is summed to get
                one value per sample. A 2-D tensor is taken to be already
                summed.
            rewards: Rewards with samples along dim 1.

        Returns:
            Scalar loss. Gradients reach ``log_probs`` only.
        """
        if log_probs.dim() > 2:
            sequence_log_probs = log_probs.sum(dim=2)
        else:
            sequence_log_probs = log_probs

        baseline = self.compute_baseline(rewards).unsqueeze(1)
        # The advantage is a fixed weight in the score-function estimator;
        # detaching keeps gradients out of the rewards and the baseline.
        advantage = (rewards - baseline).detach()

        return -(sequence_log_probs * advantage).mean()

    def __call__(
        self, model, batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Sample from ``model`` and compute the loss for ``batch``.

        Args:
            model: Object exposing
                ``sample_with_logprobs(batch, num_samples, temperature)``
                whose second and third return values are the sampled
                numeric outputs and their per-step log-probabilities.
            batch: Must contain ``"encoder_input"`` (only its device is
                used), ``"y"``, ``"y_median"``, ``"q1"`` and ``"q3"``.
                NOTE(review): the statistics are presumably one value per
                example, shaped ``(batch,)`` or ``(batch, 1)`` - confirm
                against the data loader.

        Returns:
            ``(loss, metrics)`` where ``metrics`` holds detached scalars.
        """
        device = batch["encoder_input"].device

        _, output_floats, step_log_probs = model.sample_with_logprobs(
            batch, self.num_samples, self.temperature
        )

        predictions = torch.tensor(
            output_floats, device=device, dtype=torch.float32
        )
        center = batch["y_median"]
        scale = self._safe_scale(batch["q3"] - batch["q1"], self.iqr_eps)

        # Per-example statistics are aligned to the leading dim explicitly:
        # default broadcasting would line them up with the trailing (sample)
        # axis instead.
        predictions = self._robust_scale(predictions, center, scale)
        targets = self._robust_scale(batch["y"], center, scale)
        rewards = self.compute_reward(predictions, targets)

        reinforce_loss = self.compute_reinforce_loss(step_log_probs, rewards)

        metrics = {
            "reinforce_loss": reinforce_loss.detach(),
            "mean_reward": rewards.mean().detach(),
            "max_reward": rewards.max().detach(),
            "min_reward": rewards.min().detach(),
            "reward_std": rewards.std().detach(),
            "mean_prediction": predictions.mean().detach(),
            "prediction_std": predictions.std().detach(),
        }

        return reinforce_loss, metrics
