"""
Differentiable Chunking Module for DNAChunker.

This module provides a differentiable alternative to hard boundary selection,
enabling gradient flow from the MLM loss to the routing module.

Key insight: Standard chunking uses hard selection which breaks gradients.
We replace this with soft segment assignment that:
1. Uses Straight-Through Estimator (STE) for boundary decisions
2. Uses differentiable soft pooling based on cumulative segment IDs
3. Allows MLM loss to influence boundary placement decisions
"""



import torch

import torch.nn as nn

import torch.nn.functional as F

from typing import Optional, Tuple





class StraightThroughBoundary(torch.autograd.Function):
    """
    Hard thresholding with a straight-through (identity) backward pass.

    Forward:  1.0 where ``p > threshold``, else 0.0 (float32 output).
    Backward: the incoming gradient is handed to ``p`` unchanged; the
              non-tensor ``threshold`` receives no gradient.
    """

    @staticmethod
    def forward(ctx, p: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        above = torch.gt(p, threshold)
        return above.to(dtype=torch.float32)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        # Identity for ``p``; ``threshold`` is a plain hyperparameter.
        grad_p = grad_output
        grad_threshold = None
        return grad_p, grad_threshold





def ste_boundary(p: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """
    Binarise boundary probabilities using ``StraightThroughBoundary``.

    Forward: 1.0 where ``p > threshold``, else 0.0.
    Backward: the gradient w.r.t. the output is passed to ``p`` unchanged.
    """
    hard_with_ste = StraightThroughBoundary.apply(p, threshold)
    return hard_with_ste



class Downsampler(nn.Module):
    """
    Downsampler that mean-pools tokens into segments delimited by boundaries.

    Each valid token is assigned (one-hot) to the segment given by the running
    count of boundaries up to and including it; tokens before the first
    boundary are merged into the first segment. The assignment is built from
    integer cumulative sums, so no gradient flows from the pooled output back
    into ``boundaries``. Gradients w.r.t. ``p_original`` flow only through the
    returned ``segment_confidence``.

    Besides the pooled representations it also returns
      token_to_chunk: (B, L_in, L_out) row-normalised token-to-chunk
      assignment (e.g. for attention restriction).
    """

    def __init__(
        self,
    ):
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,
        p_original: torch.Tensor,
        boundaries: torch.Tensor,
        mask_locations: torch.Tensor,
        pad_locations: torch.Tensor,
        input_original_indices_l0: Optional[torch.Tensor] = None,
    ):
        """
        Pool ``x`` into chunks delimited by hard boundaries.

        Args:
            x: Token representations (B, L_in, D), any floating dtype.
            p_original: Boundary probabilities (B, L_in); only used to compute
                ``segment_confidence``.
            boundaries: Boundary indicators (B, L_in). Expected to be hard
                0/1 values; they are clamped to [0, 1] and zeroed at padding.
            mask_locations: (B, L_in) special mask token positions.
            pad_locations: Bool (B, L_in) padding positions; padded tokens are
                excluded from every chunk.
            input_original_indices_l0: Optional (B, L_in) original positions;
                defaults to 0..L_in-1.

        Returns:
            pooled_output: (B, L_out, D) mean of member tokens, in ``x.dtype``;
                zero for padded chunk slots.
            chunk_lengths: (B,) int32 number of chunks per row.
            comp_mask_loc: (B, L_out) True where a chunk contains a mask token.
            comp_pad_loc:  (B, L_out) True for slots beyond ``chunk_lengths``.
            output_original_indices_l0: (B, L_out) float32 mean original position.
            token_to_chunk: (B, L_in, L_out) float32 row-normalised assignment.
            segment_confidence: (B, L_out) float32 per-chunk mean of (2p - 1)^2.

        L_out equals max(chunk_lengths), with a minimum of 1.
        """
        B, L, D = x.shape
        device = x.device
        dtype = x.dtype
        # The one-hot assignment is float32 and torch.bmm requires matching
        # dtypes, so pooling runs in at least float32 and is cast back to
        # ``dtype`` on return. Building the boundary mask in this dtype also
        # keeps boundary counts exact for half-precision inputs (bf16 cannot
        # represent every integer above 256).
        compute_dtype = torch.promote_types(dtype, torch.float32)
        valid_mask = ~pad_locations

        boundaries_hard = boundaries.to(compute_dtype).clamp(0, 1) * valid_mask.to(compute_dtype)
        chunk_lengths = boundaries_hard.sum(dim=1).int()
        max_chunks = max(int(chunk_lengths.max().item()) if chunk_lengths.numel() > 0 else 1, 1)

        # Chunk slots past a row's own chunk count are padding.
        comp_pad_loc = torch.arange(max_chunks, device=device).unsqueeze(0) >= chunk_lengths.unsqueeze(1)
        chunk_keep = (~comp_pad_loc).float()  # (B, L_out)

        # Segment id = number of boundaries seen so far; tokens before the
        # first boundary (id 0) are folded into chunk 0 by the clamp.
        segment_ids = torch.cumsum(boundaries_hard.int(), dim=1)
        chunk_idx = (segment_ids - 1).clamp(min=0, max=max_chunks - 1)

        token_to_chunk_raw = F.one_hot(chunk_idx, num_classes=max_chunks).float()
        token_to_chunk_raw = token_to_chunk_raw * valid_mask.unsqueeze(-1).float()
        token_to_chunk_raw = token_to_chunk_raw * chunk_keep.unsqueeze(1)

        assign_t = token_to_chunk_raw.transpose(1, 2)  # (B, L_out, L_in)
        counts = token_to_chunk_raw.sum(dim=1).clamp(min=1.0)  # (B, L_out)

        # Mean-pool token features per chunk.
        x_masked = x * valid_mask.unsqueeze(-1).to(dtype)
        pooled_sum = torch.bmm(assign_t.to(compute_dtype), x_masked.to(compute_dtype))
        pooled_output = pooled_sum / counts.unsqueeze(-1)
        pooled_output = pooled_output * chunk_keep.unsqueeze(-1).to(compute_dtype)

        # A chunk is "masked" if any of its tokens is a mask token.
        mask_sum = torch.bmm(assign_t, mask_locations.float().unsqueeze(-1)).squeeze(-1)
        comp_mask_loc = (mask_sum > 0.0) & (~comp_pad_loc)

        # Mean original position of the tokens in each chunk.
        if input_original_indices_l0 is None:
            positions = torch.arange(L, device=device, dtype=torch.float32).unsqueeze(0).expand(B, -1)
        else:
            positions = input_original_indices_l0.to(torch.float32)
        pos_sum = torch.bmm(assign_t, positions.unsqueeze(-1)).squeeze(-1)
        output_original_indices_l0 = pos_sum / counts * chunk_keep

        # (2p - 1)^2 is 0 at p = 0.5 and 1 at p in {0, 1}; averaged per chunk.
        confidence = ((2 * p_original - 1).pow(2) * valid_mask.float()).to(torch.float32)
        conf_sum = torch.bmm(assign_t, confidence.unsqueeze(-1)).squeeze(-1)
        segment_confidence = conf_sum / counts * chunk_keep

        token_to_chunk = token_to_chunk_raw / token_to_chunk_raw.sum(dim=-1, keepdim=True).clamp(min=1e-8)

        return (
            pooled_output.to(dtype),
            chunk_lengths,
            comp_mask_loc,
            comp_pad_loc,
            output_original_indices_l0.to(torch.float32),
            token_to_chunk.to(torch.float32),
            segment_confidence.to(torch.float32),
        )





class RoutingModule(nn.Module):
    """
    Routing module that predicts chunk boundaries from adjacent-token similarity.

    - For position t > 0 the boundary probability is
      0.5 * (1 - cos(W_q x_t, W_k x_{t-1})); position 0 always gets 1.
    - Hard decisions b = 1[p >= 0.5] use a straight-through estimator: the
      forward value is exactly 0/1 while the backward pass treats db/dp as 1.
    - Optionally returns the mean binary entropy of p for regularization.
    """

    def __init__(self, hid_size: int):
        super().__init__()
        # Query projection is applied to token t, key projection to token t-1.
        self.w_q = nn.Linear(hid_size, hid_size, bias=False)
        self.w_k = nn.Linear(hid_size, hid_size, bias=False)

    def _get_protection_boundaries(self, mask_locations: torch.Tensor) -> torch.Tensor:
        """Mark every mask token and the token immediately after it as a boundary."""
        shifted = torch.cat([
            torch.zeros_like(mask_locations[:, :1]),
            mask_locations[:, :-1]
        ], dim=1)
        return mask_locations | shifted

    def forward(
        self,
        x: torch.Tensor,
        mask_locations: torch.Tensor,
        pad_locations: torch.Tensor,
        temperature: float = 1.0,
        return_entropy: bool = False,
        enforce_mask_boundaries: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Compute boundary probabilities and decisions.

        Args:
            x: Input tensor (B, L, D)
            mask_locations: Special mask token locations (B, L)
            pad_locations: Bool padding mask (B, L)
            temperature: Currently unused by this implementation; accepted
                for interface compatibility.
            return_entropy: If True, return boundary entropy for regularization
            enforce_mask_boundaries: If True, force boundaries at each mask
                token and the token after it (stage 1); False for stage 2.

        Returns:
            p: Boundary probabilities (B, L), zero at padding - differentiable
            b: Hard 0/1 boundary decisions (B, L); straight-through gradient
               to p in the backward pass
            entropy: Mean binary entropy of the (unforced) probabilities over
               valid positions and batch, or None if not requested
        """
        B = x.shape[0]
        device = x.device
        eps = 1e-6

        # Cosine similarity between projected token t (query) and t-1 (key).
        q_aligned = self.w_q(x[:, 1:, :])
        k_aligned = self.w_k(x[:, :-1, :])
        dot = torch.sum(q_aligned * k_aligned, dim=-1)
        q_norm = torch.linalg.vector_norm(q_aligned, dim=-1)
        k_norm = torch.linalg.vector_norm(k_aligned, dim=-1)
        norm_prod = (q_norm * k_norm).clamp(min=eps)
        sim = dot / norm_prod

        # Dissimilar neighbours -> high boundary probability.
        p_vals = 0.5 * (1 - sim)

        # The first position always starts a chunk.
        first_p = torch.ones(B, 1, device=device, dtype=p_vals.dtype)
        p = torch.cat([first_p, p_vals], dim=1)
        p = p * (~pad_locations).float()

        # Straight-through estimator: for finite p, ``p - p.detach()`` is
        # exactly zero in the forward pass, so ``b`` holds the same hard 0/1
        # values as the threshold, but its gradient w.r.t. ``p`` is identity.
        b_hard = (p >= 0.5).float()
        b = b_hard + (p - p.detach()).to(b_hard.dtype)

        if enforce_mask_boundaries:
            special_b = self._get_protection_boundaries(mask_locations).float()
            b_final = torch.max(b, special_b)
            p_final = torch.max(p, special_b)
        else:
            b_final = b
            p_final = p

        b_final = b_final * (~pad_locations).float()
        p_final = p_final * (~pad_locations).float()

        entropy = None
        if return_entropy:
            # Entropy is taken on the router's own probabilities (before
            # mask-boundary forcing), averaged over valid positions per row.
            p_clamped = p.clamp(min=eps, max=1-eps)
            entropy_per_pos = -(p_clamped * torch.log(p_clamped) +
                                (1 - p_clamped) * torch.log(1 - p_clamped))
            valid_mask = (~pad_locations).float()
            entropy = (entropy_per_pos * valid_mask).sum(dim=1) / valid_mask.sum(dim=1).clamp(min=1)
            entropy = entropy.mean()

        return p_final, b_final, entropy





def test_differentiable_chunking():
    """
    Smoke-test gradient flow through RoutingModule followed by Downsampler.

    Routes and pools a random batch, applies a toy cross-entropy loss on the
    non-padded chunks minus 0.01 * the router entropy, back-propagates, prints
    diagnostics, and checks that both router projections got a gradient with
    norm above 1e-8.

    Note: the entropy term depends on the router weights directly, so a PASSED
    result does not by itself show that gradient reaches the router through
    the pooled output.

    Returns:
        Truthy when both ``router.w_q`` and ``router.w_k`` received gradients.
    """
    print("Testing Downsampler gradient flow...\n")

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    batch, seq_len, hidden = 2, 16, 32
    x = torch.randn(batch, seq_len, hidden, device=device, requires_grad=True)
    router = RoutingModule(hidden).to(device)
    downsampler = Downsampler().to(device)
    downsampler.train()

    # One mask token at position 5; the last three positions are padding.
    mask_locations = torch.zeros(batch, seq_len, dtype=torch.bool, device=device)
    mask_locations[:, 5] = True
    pad_locations = torch.zeros_like(mask_locations)
    pad_locations[:, -3:] = True

    p, b, entropy = router(x, mask_locations, pad_locations, return_entropy=True)
    print(f"p range: [{p.min().item():.3f}, {p.max().item():.3f}]")
    print(f"Boundaries per item: {b.sum(dim=1).tolist()}")
    print(f"Entropy: {entropy.item():.4f}")

    pooled, chunk_lengths, _, comp_pad, _, _, _ = downsampler(
        x, p, b, mask_locations, pad_locations
    )
    print(f"\nPooled shape: {pooled.shape}")
    print(f"Chunk lengths: {chunk_lengths.tolist()}")

    # Toy MLM head over the pooled chunks with random targets.
    decoder = nn.Linear(hidden, 10, device=device)
    logits = decoder(pooled)
    target = torch.randint(0, 10, (batch, pooled.shape[1]), device=device)

    keep = ~comp_pad
    if not keep.any():
        mlm_loss = torch.tensor(0.0, device=device)
    else:
        mlm_loss = F.cross_entropy(logits[keep], target[keep])

    total_loss = mlm_loss - 0.01 * entropy
    print(f"\nMLM Loss: {mlm_loss.item():.4f}")
    print(f"Total Loss: {total_loss.item():.4f}")

    total_loss.backward()

    print("\n--- Gradient Analysis ---")
    x_grad = x.grad
    print(f"x.grad is not None: {x_grad is not None}")
    if x_grad is not None:
        print(f"x.grad norm: {x_grad.norm().item():.6e}")

    wq_grad = router.w_q.weight.grad
    print(f"router.w_q.weight.grad is not None: {wq_grad is not None}")
    if wq_grad is not None:
        print(f"router.w_q.weight.grad norm: {wq_grad.norm().item():.6e}")

    wk_grad = router.w_k.weight.grad
    print(f"router.w_k.weight.grad is not None: {wk_grad is not None}")
    if wk_grad is not None:
        print(f"router.w_k.weight.grad norm: {wk_grad.norm().item():.6e}")

    success = (
        wq_grad is not None and wq_grad.norm() > 1e-8
        and wk_grad is not None and wk_grad.norm() > 1e-8
    )

    print(f"\n{'='*50}")
    print(f"GRADIENT FLOW TEST: {'PASSED' if success else 'FAILED'}")
    print(f"{'='*50}")

    return success





if __name__ == "__main__":
    # Running this file directly executes the gradient-flow smoke test.
    _passed = test_differentiable_chunking()

