import os
import torch
import torch.nn.functional as F
from torch import nn

class DiscreteFM:
    """
    Masked (absorbing-state) discrete flow model over categorical fragment IDs.

    Data tokens are integers in [0, vocab_size); one extra index, equal to
    vocab_size, is used as the mask (absorbing) token. Time runs from
    t=0 (clean data) to t=1 (everything masked).
    """

    # Sampling strategies understood by `reverse`.
    _SAMPLING_METHODS = ("argmax", "multinomial")

    def __init__(self, cat_conf):
        """
        Args:
            cat_conf: config object exposing `vocab_size` as an attribute and
                'sampling_method', 'stochasticity', 'min_temperature' and
                'max_temperature' via item access.
        """
        print('Using Masked Discrete Flow Model')
        self._conf = cat_conf
        self.vocab_size = int(cat_conf.vocab_size)
        # The mask token sits one past the last data token id, i.e. at index
        # vocab_size (data ids are 0 .. vocab_size - 1).
        self.mask_token_idx = self.vocab_size
        self.sampling_method = cat_conf['sampling_method']
        self.stochasticity = cat_conf['stochasticity']
        self.min_temp = cat_conf['min_temperature']
        self.max_temp = cat_conf['max_temperature']

    def sample_ref(self, n_samples: int, n_fragments: int, device: torch.device, frag_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Samples the reference state (t=1), which is the pure mask state.

        Args:
            n_samples: number of rows (batch size) of the returned tensor.
            n_fragments: number of columns (fragments per sample).
            device: device on which to allocate the tensor.
            frag_mask: accepted for interface compatibility; not used.

        Returns:
            [n_samples, n_fragments] LongTensor filled with the mask token index.
        """
        return torch.full(
            (n_samples, n_fragments),
            self.mask_token_idx,
            device=device,
            dtype=torch.long,
        )

    def forward_marginal(self, x_0: torch.Tensor, t: torch.Tensor, frag_mask: torch.Tensor):
        """
        Samples xt | x0.
        P(xt = mask) = t
        P(xt = x0) = 1 - t

        Args:
            x_0: [B, N] LongTensor of true fragment IDs.
            t: [B] FloatTensor time in [0, 1].
            frag_mask: optional [B, N] mask; positions where it is False/0
                (padding) are never replaced by the mask token. May be None.

        Returns:
            [B, N] tensor: a copy of x_0 with the selected positions set to
            the mask token index. x_0 itself is not modified.
        """
        B, N = x_0.shape

        # Broadcast the per-sample time over fragments: [B] -> [B, N]
        t_exp = t.reshape(B, 1).expand(B, N)

        # Independent Bernoulli(t) draw per position: True -> replace by mask
        mask_decision = torch.bernoulli(t_exp).bool()

        # Padding positions keep their original value
        if frag_mask is not None:
            mask_decision = mask_decision & frag_mask.bool()

        x_t = x_0.clone()
        x_t[mask_decision] = self.mask_token_idx

        return x_t

    def compute_loss(self, model_out: dict, batch: dict) -> torch.Tensor:
        """
        Computes the 1/t-weighted cross entropy on currently-masked positions.

        Args:
            model_out: dict with "cat_logits", a [B, N, V] tensor of logits.
            batch: dict with
                "cat_t":     [B, N] noised tokens (x_t),
                "frag_ids":  [B, N] true fragment IDs (x_0),
                "frag_mask": [B, N] validity mask (False/0 = padding),
                "t":         per-sample time, B elements.

        Returns:
            [B] tensor: per-sample loss, summed over positions that are both
            valid and masked in x_t, divided by the number of valid positions.
        """
        x_t = batch["cat_t"]             # [B, N]
        x_0 = batch["frag_ids"].long()   # [B, N]
        logits = model_out["cat_logits"] # [B, N, V]
        frag_mask = batch["frag_mask"].bool()
        t = batch["t"]

        B, N, V = logits.shape

        # 1. Positions that contribute: valid fragment AND currently masked
        loss_mask = frag_mask & (x_t == self.mask_token_idx)  # [B, N]
        ignored = ~loss_mask

        # 2. Neutralise ignored positions *before* the cross entropy.
        #    Multiplying by a 0/1 mask afterwards is not enough: a non-finite
        #    logit would give NaN * 0 = NaN, and an out-of-range padding id in
        #    frag_ids would make cross_entropy raise. masked_fill also blocks
        #    any gradient from flowing into the ignored logits.
        safe_logits = logits.masked_fill(ignored.unsqueeze(-1), 0.0)
        safe_targets = x_0.masked_fill(ignored, 0)

        # 3. Per-token CE loss (unreduced) -> [B, N]
        ce_loss = F.cross_entropy(
            safe_logits.reshape(-1, V),
            safe_targets.reshape(-1),
            reduction='none',
        ).view(B, N)

        # 4. Weighting: 1 / t (inverse probability of being masked).
        #    t is clamped for stability. Shape [B, 1]
        weights = 1.0 / torch.clamp(t, min=1e-4).view(B, 1)

        # 5. Keep only the contributing positions and sum per sample -> [B]
        weighted_loss = ce_loss * weights * loss_mask.float()
        loss_per_sample = weighted_loss.sum(dim=1)

        # 6. Normalize by the number of valid positions (masked or not) to
        #    keep the scale consistent; dividing only by the masked count
        #    causes high variance when t -> 0. The clamp guards against
        #    samples with no valid position (their numerator is 0 anyway).
        num_valid = frag_mask.sum(dim=1).float().clamp(min=1.0)

        return loss_per_sample / num_valid

    def reverse(
        self,
        s_t: torch.Tensor,
        logits: torch.Tensor,
        t: float,
        dt: float,
        flow_mask: torch.Tensor = None
    ):
        """
        Reverse step with stochastic error correction and temperature annealing.

        Args:
            s_t: [B, N] current tokens (mask token index marks masked positions).
            logits: [B, N, V] model logits over x0.
            t: current time; 1.0 is fully masked, 0.0 is clean data.
            dt: step size.
            flow_mask: optional [B, N] mask; only positions where it is
                True/non-zero may change. None means no restriction.

        Returns:
            [B, N] tensor with the updated tokens (s_t is not modified).

        Raises:
            ValueError: if `self.sampling_method` is not "argmax" or
                "multinomial".
        """
        # Fail fast, before consuming any random numbers, instead of hitting
        # an unbound local further down.
        if self.sampling_method not in self._SAMPLING_METHODS:
            raise ValueError(
                f"Unknown sampling_method {self.sampling_method!r}; "
                f"expected one of {self._SAMPLING_METHODS}"
            )

        device = s_t.device
        B, N = s_t.shape

        # 1. Current temperature from a linear schedule in t:
        #    t=1.0 -> max_temp, t=0.0 -> min_temp
        curr_temp = self.min_temp + (self.max_temp - self.min_temp) * t

        # Safety clamp to prevent division by zero
        curr_temp = max(curr_temp, 1e-4)

        # 2. Probabilities over x0
        probs_x0 = F.softmax(logits / curr_temp, dim=-1)  # [B, N, V]

        # 3. Transition probabilities for this step (denominators clamped)
        curr_t = max(t, 1e-5)
        comp_t = max(1.0 - t, 1e-5)

        # Unmasking probability, boosted by stochasticity
        p_unmask = min((dt / curr_t) * (1.0 + self.stochasticity), 1.0)
        # Re-masking probability (error correction)
        p_remask = min((dt / comp_t) * self.stochasticity, 1.0)

        # 4. Which positions are currently masked / hold data
        is_masked = (s_t == self.mask_token_idx)
        is_data = ~is_masked

        # 5. Sample transitions
        # A. Mask -> Data
        update_unmask = is_masked & (torch.rand((B, N), device=device) < p_unmask)
        # B. Data -> Mask (error correction)
        update_remask = is_data & (torch.rand((B, N), device=device) < p_remask)

        # Positions outside the flow mask stay frozen
        if flow_mask is not None:
            mask_bool = flow_mask.bool()
            update_unmask = update_unmask & mask_bool
            update_remask = update_remask & mask_bool

        # 6. Draw candidate x0 values for every position
        if self.sampling_method == "argmax":
            # Note: temperature has no effect on argmax, so this disables annealing
            new_x0_samples = torch.argmax(probs_x0, dim=-1)
        else:  # "multinomial"
            # Annealing happens here: high temp flattens probs, low temp sharpens them.
            # Flatten using the logits' own class dimension rather than assuming
            # it equals vocab_size; reshape also tolerates non-contiguous input.
            flat_probs = probs_x0.reshape(-1, probs_x0.shape[-1])
            new_x0_samples = torch.multinomial(flat_probs, 1).view(B, N)

        # 7. Apply updates
        s_next = s_t.clone()
        s_next[update_unmask] = new_x0_samples[update_unmask]
        s_next[update_remask] = self.mask_token_idx

        return s_next
