from typing import Dict, Tuple

import torch
import torch.nn.functional as F


class Remax_mse:
    """ReMax-style REINFORCE loss for a decoder that emits scalar predictions.

    For every input, ``num_samples`` sequences are sampled from the model and
    converted to floats by the model itself. The reward of a sample is the
    negative squared error between its de-normalized prediction and the
    de-normalized target. The baseline is the reward of the model's greedy
    decode. An optional KL penalty towards a reference model is added with
    weight ``kl_weight``.

    Note: ``reward_scale`` and ``kernel_sigma`` are stored but not read by any
    method of this class.
    """

    # Additive logit given to tokens that the per-step constraint mask
    # disallows, so their softmax probability underflows to zero.
    _MASK_VALUE = -1e7

    def __init__(
        self,
        temperature: float = 1.0,
        num_samples: int = 8,
        reward_scale: float = 1.0,
        kernel_sigma: float = 1.0,
        kernel_reduction: str = "mean",
        kl_weight: float = 0.1,
    ):
        """Store the loss hyper-parameters.

        Args:
            temperature: passed to ``model.sample_with_logprobs``; also the
                divisor applied to logits in the entropy metrics.
            num_samples: number of sampled decodes per input.
            reward_scale: stored only; not used by this class.
            kernel_sigma: stored only; not used by this class.
            kernel_reduction: ``"sum"`` sums the squared error over the last
                dim of 3-D predictions; any other value averages it.
            kl_weight: weight of the KL penalty; the penalty is skipped when
                this is <= 0 or no reference model is given.
        """
        self.temperature = temperature
        self.num_samples = num_samples
        self.reward_scale = reward_scale
        self.kernel_sigma = kernel_sigma
        self.kernel_reduction = kernel_reduction
        self.kl_weight = kl_weight

    @staticmethod
    def _per_row_std(values: torch.Tensor) -> torch.Tensor:
        """Return the unbiased std of ``values`` along dim 1.

        With a single column, ``Tensor.std`` has zero degrees of freedom and
        returns NaN. A zero spread is returned instead so that logged metrics
        stay finite (e.g. ``num_samples == 1``).
        """
        if values.size(1) < 2:
            return torch.zeros(values.size(0), dtype=values.dtype, device=values.device)
        return values.std(dim=1)

    def compute_reward(
        self, predictions: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        """Return the negative squared error of each prediction vs. its target.

        Args:
            predictions: ``(batch, num_samples)`` predictions, or
                ``(batch, num_samples, dim)``. In the 3-D case every element
                along ``dim`` is compared to the same scalar target.
            targets: ``(batch,)`` targets, broadcast over the other dims.

        Returns:
            ``(batch, num_samples)`` rewards (<= 0). For 3-D predictions the
            squared error is summed over ``dim`` when
            ``kernel_reduction == "sum"`` and averaged otherwise.
        """
        if predictions.dim() == 3:
            targets_expanded = (
                targets.unsqueeze(1)
                .unsqueeze(2)
                .expand(-1, predictions.size(1), predictions.size(2))
            )
            diff2 = (predictions - targets_expanded) ** 2
            if self.kernel_reduction == "sum":
                dist2 = diff2.sum(dim=2)
            else:
                dist2 = diff2.mean(dim=2)
        else:
            targets_expanded = targets.unsqueeze(1).expand(-1, predictions.size(1))
            dist2 = (predictions - targets_expanded) ** 2
        return -dist2

    def compute_kl_divergence(
        self,
        current_logits: torch.Tensor,
        ref_logits: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Return the mean per-token KL(p_ref || q_current).

        ``F.kl_div(input=log q, target=p)`` computes ``sum p * (log p - log q)``.
        With ``q`` the current policy and ``p`` the reference policy, this is
        the KL divergence from the reference to the current policy.

        Args:
            current_logits: ``(..., vocab)`` logits of the policy being trained.
            ref_logits: reference-model logits with the same shape.
            mask: optional weights over ``current_logits.shape[:-1]``;
                positions with weight 0 are excluded from the average.

        Returns:
            Scalar KL, averaged over all tokens (no mask) or as a
            mask-weighted average. It is 0 rather than NaN when the mask
            selects no tokens.
        """
        current_log_probs = F.log_softmax(current_logits, dim=-1)
        ref_probs = F.softmax(ref_logits, dim=-1)

        kl_div = F.kl_div(current_log_probs, ref_probs, reduction="none")
        kl_div_per_token = kl_div.sum(dim=-1)

        if mask is None:
            return kl_div_per_token.mean()

        weights = mask.float()
        denom = weights.sum()
        # Clamp avoids 0/0 = NaN when every position is masked out; the
        # numerator is then 0 as well, so the result is 0.
        denom = denom.clamp_min(torch.finfo(denom.dtype).eps)
        return (kl_div_per_token * weights).sum() / denom

    def compute_reinforce_loss(
        self, log_probs: torch.Tensor, rewards: torch.Tensor, baseline: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Return the REINFORCE loss with a baseline, plus two diagnostics.

        Args:
            log_probs: per-token log-probs, summed over dim 2 to get sequence
                log-probs; presumably ``(batch, num_samples, seq_len)``.
            rewards: ``(batch, num_samples)`` rewards of the sampled sequences.
            baseline: baseline rewards broadcastable to ``rewards``.

        Returns:
            ``(loss, mean_advantage_std, num_nonzero_advantage)``. ``loss`` is
            ``-mean(seq_log_prob * advantage)``. ``mean_advantage_std`` is the
            per-input std of the advantage averaged over the batch (0 when
            there is a single sample per input). ``num_nonzero_advantage``
            counts samples whose reward differs from the baseline.
        """
        sequence_log_probs = log_probs.sum(dim=2)

        # The advantage only weights the policy gradient; it must not itself
        # receive gradient.
        advantage = (rewards - baseline).detach()
        num_nonzero_advantage = (advantage != 0).sum().item()
        mean_advantages_std = self._per_row_std(advantage).mean()

        reinforce_loss = -(sequence_log_probs * advantage).mean()

        return reinforce_loss, mean_advantages_std, num_nonzero_advantage

    def _masked_step_distribution(
        self, model, logits: torch.Tensor, step_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply step ``step_idx``'s constraint mask and the temperature.

        Returns ``(scaled_logits, probs, entropy)``. ``entropy`` is the per-row
        Shannon entropy of the masked, temperature-scaled distribution.
        """
        curr_mask = model.decoder_constraint_masks[step_idx, :].unsqueeze(0)
        # Assumes mask entries are 0/1 (TODO confirm): disallowed tokens get
        # _MASK_VALUE, allowed tokens keep their logit.
        masked_logits = (1.0 - curr_mask) * self._MASK_VALUE + curr_mask * logits

        scaled_logits = masked_logits / self.temperature
        probs = F.softmax(scaled_logits, dim=-1)
        log_probs = F.log_softmax(scaled_logits, dim=-1)

        # Zero out non-finite products (e.g. 0 * -inf) so masked tokens
        # contribute nothing to the entropy.
        p_log_p = torch.nan_to_num(probs * log_probs, nan=0.0, posinf=0.0, neginf=0.0)
        entropy = -torch.sum(p_log_p, dim=-1)
        return scaled_logits, probs, entropy

    def compute_policy_entropy(
        self, model, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """Return the mean per-step policy entropy along the greedy decode path.

        Decodes ``model.decode_len`` steps without gradient, always feeding
        back the argmax token. Returns the entropy averaged over batch and
        steps.
        """
        encoder_input = batch["encoder_input"]
        batch_size = encoder_input.shape[0]
        device = encoder_input.device

        with torch.no_grad():
            memory, memory_key_padding_mask = model.encoder_decoder.encode(encoder_input)

            current_tgt_ids = torch.full(
                (batch_size, 1),
                model.decoder_vocab.bos_pad_id,
                dtype=torch.long,
                device=device,
            )
            step_entropies = torch.zeros(
                (batch_size, model.decode_len),
                dtype=torch.float32,
                device=device,
            )

            for step_idx in range(model.decode_len):
                logits = model.encoder_decoder.next_token_logits(
                    current_tgt_ids, memory, memory_key_padding_mask
                )
                scaled_logits, _, entropy = self._masked_step_distribution(
                    model, logits, step_idx
                )
                step_entropies[:, step_idx] = entropy

                next_token_ids = torch.argmax(scaled_logits, dim=-1, keepdim=True)
                if step_idx < model.decode_len - 1:
                    current_tgt_ids = torch.cat([current_tgt_ids, next_token_ids], dim=1)

            policy_entropy = step_entropies.mean()

        return policy_entropy

    def compute_policy_entropy_sample(
        self, model, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """Return the mean per-step policy entropy along sampled decode paths.

        Encodes each input once and repeats the encoder memory
        ``num_samples`` times. It then samples ``model.decode_len`` tokens per
        row without gradient (consuming the global torch RNG). The entropy is
        averaged over steps, then samples, then the batch.
        """
        batch_size = batch["encoder_input"].shape[0]
        device = batch["encoder_input"].device

        with torch.no_grad():
            memory, memory_key_padding_mask = model.encoder_decoder.encode(batch["encoder_input"])

            num_samples = self.num_samples
            # Repeat each input's memory num_samples times; rows are ordered
            # (input 0 sample 0, input 0 sample 1, ..., input 1 sample 0, ...).
            expanded_memory = (
                memory.unsqueeze(1)
                .expand(-1, num_samples, -1, -1)
                .reshape(batch_size * num_samples, memory.size(1), memory.size(2))
            )

            if memory_key_padding_mask is not None:
                expanded_memory_key_padding_mask = (
                    memory_key_padding_mask.unsqueeze(1)
                    .expand(-1, num_samples, -1)
                    .reshape(batch_size * num_samples, memory_key_padding_mask.size(1))
                )
            else:
                expanded_memory_key_padding_mask = None

            current_tgt_ids = torch.full(
                (batch_size * num_samples, 1),
                model.decoder_vocab.bos_pad_id,
                dtype=torch.long,
                device=device,
            )
            step_entropies = torch.zeros(
                (batch_size * num_samples, model.decode_len),
                dtype=torch.float32,
                device=device,
            )

            for step_idx in range(model.decode_len):
                logits = model.encoder_decoder.next_token_logits(
                    current_tgt_ids,
                    expanded_memory,
                    expanded_memory_key_padding_mask,
                )
                _, probs, entropy = self._masked_step_distribution(model, logits, step_idx)
                step_entropies[:, step_idx] = entropy

                # Sample on every step (including the last) to keep RNG
                # consumption identical regardless of decode length.
                next_token_ids = torch.multinomial(probs, num_samples=1)
                if step_idx < model.decode_len - 1:
                    current_tgt_ids = torch.cat([current_tgt_ids, next_token_ids], dim=1)

            step_entropies = step_entropies.reshape(batch_size, num_samples, model.decode_len)
            sample_entropies = step_entropies.mean(dim=-1)
            batch_entropies = sample_entropies.mean(dim=-1)
            policy_entropy = batch_entropies.mean()

        return policy_entropy

    def __call__(
        self, model, batch: Dict[str, torch.Tensor], ref_model=None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the ReMax loss and logging metrics for one batch.

        ``model`` (and ``ref_model``, if given) is switched to eval mode while
        sampling and scoring. ``model`` is switched back to train mode before
        returning, also when an exception is raised.

        Args:
            model: policy model; must provide ``sample_with_logprobs``,
                ``greedy_decode``, ``encoder_decoder``, ``decoder_vocab``,
                ``decode_len``, ``decoder_constraint_masks``, ``eval`` and
                ``train``.
            batch: must contain ``encoder_input``, ``y``, ``y_min`` and
                ``y_max``.
            ref_model: optional reference model for the KL penalty.

        Returns:
            ``(total_loss, metrics)``. ``total_loss`` is differentiable;
            ``metrics`` holds detached tensors and Python ints.
        """
        model.eval()
        if ref_model is not None:
            ref_model.eval()
        try:
            return self._compute_loss(model, batch, ref_model)
        finally:
            # Restore training mode even if sampling or scoring raised.
            model.train()

    def _compute_loss(
        self, model, batch: Dict[str, torch.Tensor], ref_model
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Body of ``__call__``; assumes the models are already in eval mode."""
        device = batch["encoder_input"].device

        decoded_ids, output_floats, step_log_probs, step_logits = model.sample_with_logprobs(
            batch, self.num_samples, self.temperature, return_logits=True
        )
        predictions = torch.tensor(output_floats, device=device, dtype=torch.float32)

        # Map predictions and targets back from [y_min, y_max]-normalized
        # values to the original scale before scoring.
        targets = batch["y"]
        y_max = batch["y_max"]
        y_min = batch["y_min"]
        y_max_bs = y_max.unsqueeze(1)
        y_min_bs = y_min.unsqueeze(1)
        predictions_norm = predictions.squeeze(-1) * (y_max_bs - y_min_bs + 1e-8) + y_min_bs
        targets_norm = targets * (y_max - y_min + 1e-8) + y_min
        rewards = self.compute_reward(predictions_norm, targets_norm)
        mean_reward_std = self._per_row_std(rewards).mean()

        # Baseline: reward of the greedy decode, shared by all samples of an input.
        with torch.no_grad():
            _, greedy_output_floats = model.greedy_decode(batch, num_samples=1)
        greedy_predictions = torch.tensor(
            greedy_output_floats, device=device, dtype=torch.float32
        )
        greedy_predictions = greedy_predictions[:, :, 0].expand(-1, self.num_samples)
        greedy_predictions_norm = greedy_predictions * (y_max_bs - y_min_bs + 1e-8) + y_min_bs
        baseline = self.compute_reward(greedy_predictions_norm, targets_norm)
        mean_baseline = baseline.mean().detach()
        num_rewards_ge_baseline = (rewards >= baseline).sum().item()
        num_rewards_lt_baseline = (rewards < baseline).sum().item()

        reinforce_loss, mean_advantages_std, num_nonzero_advantage = self.compute_reinforce_loss(
            step_log_probs, rewards, baseline
        )

        total_loss = reinforce_loss
        kl_loss = torch.tensor(0.0, device=device)
        if ref_model is not None and self.kl_weight > 0:
            kl_loss = self._kl_penalty(model, ref_model, batch, decoded_ids, step_logits)
            total_loss = total_loss + self.kl_weight * kl_loss

        policy_entropy = self.compute_policy_entropy_sample(model, batch)
        metrics = {
            "reinforce_loss": reinforce_loss.detach(),
            "kl_loss": kl_loss.detach(),
            "total_loss": total_loss.detach(),
            "mean_reward": rewards.mean().detach(),
            "max_reward": rewards.max().detach(),
            "min_reward": rewards.min().detach(),
            "reward_std": rewards.std().detach(),
            "mean_reward_std": mean_reward_std.detach(),
            "mean_advantages_std": mean_advantages_std.detach(),
            "mean_baseline": mean_baseline,
            "num_nonzero_advantage": num_nonzero_advantage,
            "num_rewards_ge_baseline": num_rewards_ge_baseline,
            "num_rewards_lt_baseline": num_rewards_lt_baseline,
            "mean_prediction": predictions.mean().detach(),
            "prediction_std": predictions.std().detach(),
            "policy_entropy": policy_entropy.detach(),
        }
        return total_loss, metrics

    def _kl_penalty(
        self,
        model,
        ref_model,
        batch: Dict[str, torch.Tensor],
        decoded_ids: torch.Tensor,
        step_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Return KL(ref || current) over the sampled sequences, skipping pad ids.

        ``step_logits`` is unpacked as 4-D ``(batch, num_samples, seq_len,
        vocab)``. ``decoded_ids`` must hold ``batch * num_samples * seq_len``
        elements.

        NOTE(review): the reference logits from ``_get_logits_for_sequence``
        are neither constraint-masked nor temperature-scaled. Confirm that
        ``step_logits`` from ``sample_with_logprobs`` are likewise raw,
        otherwise the KL compares different distributions.
        """
        batch_size = batch["encoder_input"].shape[0]
        _, _, seq_len, vocab_size = step_logits.shape
        # reshape (not view) so non-contiguous sampler outputs are accepted.
        current_logits = step_logits.reshape(batch_size * self.num_samples, seq_len, vocab_size)
        flat_decoded_ids = decoded_ids.reshape(batch_size * self.num_samples, seq_len)

        mask = (flat_decoded_ids != model.decoder_vocab.bos_pad_id).float()

        with torch.no_grad():
            ref_logits = self._get_logits_for_sequence(
                ref_model, batch, flat_decoded_ids, self.num_samples
            )

        return self.compute_kl_divergence(current_logits, ref_logits, mask)

    def _get_logits_for_sequence(
        self, model, batch: Dict[str, torch.Tensor], decoded_ids: torch.Tensor, num_samples: int
    ) -> torch.Tensor:
        """Teacher-force ``decoded_ids`` through ``model`` and collect its logits.

        Each encoder input is repeated ``num_samples`` times (row order
        matches ``decoded_ids`` rows). At step ``t`` the decoder sees BOS plus
        ``decoded_ids[:, :t]``.

        Returns:
            float32 tensor ``(batch * num_samples, seq_len, len(model.decoder_vocab))``.
        """
        encoder_input = batch["encoder_input"]
        batch_size = encoder_input.shape[0]
        device = encoder_input.device

        expanded_encoder_input = encoder_input.repeat_interleave(num_samples, dim=0)
        memory, memory_key_padding_mask = model.encoder_decoder.encode(expanded_encoder_input)

        current_tgt_ids = torch.full(
            (batch_size * num_samples, 1),
            model.decoder_vocab.bos_pad_id,
            dtype=torch.long,
            device=device,
        )
        seq_len = decoded_ids.shape[1]
        step_logits = torch.zeros(
            (batch_size * num_samples, seq_len, len(model.decoder_vocab)),
            dtype=torch.float32,
            device=device,
        )

        for step_idx in range(seq_len):
            logits = model.encoder_decoder.next_token_logits(
                current_tgt_ids, memory, memory_key_padding_mask
            )
            step_logits[:, step_idx, :] = logits

            if step_idx < seq_len - 1:
                current_tgt_ids = torch.cat(
                    [current_tgt_ids, decoded_ids[:, step_idx:step_idx + 1]], dim=1
                )

        return step_logits