"""Masking strategies for JEPA training."""

import math
import random
from dataclasses import dataclass
from typing import List, Tuple

import torch


@dataclass
class MaskConfig:
    """Configuration for one masking strategy.

    Each ``*_ratio`` field is a ``(low, high)`` range; a ratio is drawn
    uniformly from it every time a mask is generated.

    Raises:
        ValueError: if ``mask_type`` is not a supported strategy name, or if
            ``weight`` is negative (or NaN).
    """
    mask_type: str  # "time_slot", "full_day", "full_slot", "random"
    weight: float = 1.0  # relative sampling weight among the configured strategies
    # For time_slot: fraction of time steps covered by the hidden window
    time_ratio: Tuple[float, float] = (0.3, 0.7)
    # For time_slot and full_slot: fraction of slots hidden
    slot_ratio: Tuple[float, float] = (0.2, 0.4)
    # For full_day: fraction of days hidden
    day_ratio: Tuple[float, float] = (0.1, 0.3)
    # For random: i.i.d. mask ratio range
    mask_ratio: Tuple[float, float] = (0.4, 0.6)

    def __post_init__(self) -> None:
        # Fail at construction time instead of deep inside a training step:
        # both conditions below would otherwise only surface when the config
        # is actually sampled.
        valid_types = ("time_slot", "full_day", "full_slot", "random")
        if self.mask_type not in valid_types:
            raise ValueError(f"Unknown mask type: {self.mask_type}")
        # Written as "not >=" so that NaN is rejected as well.
        if not self.weight >= 0:
            raise ValueError(f"weight must be >= 0, got {self.weight}")


class MarketMaskCollator:
    """Generates masks for JEPA training. Returns [B, L, K] bool (True=visible)."""

    def __init__(
        self,
        clip_length: int = 21,
        num_tokens: int = 24,
        mask_configs: List[MaskConfig] = None,
        min_visible_ratio: float = 0.3,
        causal_ratio: float = 0.0,
    ):
        if clip_length <= 0:
            raise ValueError(f"clip_length must be > 0, got {clip_length}")
        if num_tokens <= 0:
            raise ValueError(f"num_tokens must be > 0, got {num_tokens}")
        if not (0.0 <= min_visible_ratio <= 1.0):
            raise ValueError(f"min_visible_ratio must be in [0, 1], got {min_visible_ratio}")
        if not (0.0 <= causal_ratio <= 1.0):
            raise ValueError(f"causal_ratio must be in [0, 1], got {causal_ratio}")

        self.L = clip_length
        self.K = num_tokens
        self.min_visible_ratio = min_visible_ratio
        self.causal_ratio = causal_ratio

        if mask_configs is None:
            mask_configs = [
                MaskConfig("time_slot", weight=0.5, time_ratio=(0.3, 0.7), slot_ratio=(0.2, 0.4)),
                MaskConfig("full_day", weight=0.25, day_ratio=(0.1, 0.3)),
                MaskConfig("full_slot", weight=0.25, slot_ratio=(0.1, 0.2)),
            ]
        if len(mask_configs) == 0:
            raise ValueError("mask_configs must be non-empty")
        self.configs = mask_configs

        # Normalize weights
        total_weight = sum(c.weight for c in self.configs)
        if total_weight <= 0:
            raise ValueError(f"Sum of mask config weights must be > 0, got {total_weight}")
        self.weights = [c.weight / total_weight for c in self.configs]
        
        # Check if all configs are random (enables fast path)
        self.all_random = all(c.mask_type == "random" for c in self.configs)
        if self.all_random and len(self.configs) == 1:
            self.random_ratio = self.configs[0].mask_ratio

    def __call__(self, batch_size: int) -> torch.Tensor:
        """
        Generate masks for a batch (VECTORIZED).

        Args:
            batch_size: number of masks to generate

        Returns:
            [B, L, K] bool tensor, True = visible
        """
        # Fast path for pure random masking
        if self.all_random and hasattr(self, 'random_ratio'):
            return self._batch_random_mask(batch_size, self.random_ratio)
        
        # Fast path for structured masking (still vectorized)
        return self._batch_structured_mask(batch_size)
    
    def _batch_random_mask(self, B: int, mask_ratio: Tuple[float, float]) -> torch.Tensor:
        # Sample different ratio for each sample
        ratios = torch.empty(B).uniform_(mask_ratio[0], mask_ratio[1])
        # Generate all masks at once: True = visible
        masks = torch.rand(B, self.L, self.K) > ratios[:, None, None]
        
        # Handle causal ratio
        if self.causal_ratio > 0:
            causal_samples = torch.rand(B) < self.causal_ratio
            if causal_samples.any():
                causal_masks = self._batch_causal_mask(causal_samples.sum().item())
                masks[causal_samples] = causal_masks
        
        # Enforce minimum visibility (vectorized)
        masks = self._batch_enforce_min_visible(masks)
        return masks
    
    def _batch_structured_mask(self, B: int) -> torch.Tensor:
        # Determine which mask type each sample uses
        config_indices = torch.multinomial(
            torch.tensor(self.weights), B, replacement=True
        )
        
        # Handle causal ratio first
        if self.causal_ratio > 0:
            causal_samples = torch.rand(B) < self.causal_ratio
        else:
            causal_samples = torch.zeros(B, dtype=torch.bool)
        
        # Start with all visible
        masks = torch.ones(B, self.L, self.K, dtype=torch.bool)
        
        # Generate masks per type (vectorized within each type)
        for cfg_idx, config in enumerate(self.configs):
            # Which samples use this config (and aren't causal)
            use_this = (config_indices == cfg_idx) & (~causal_samples)
            count = use_this.sum().item()
            
            if count == 0:
                continue
            
            if config.mask_type == "time_slot":
                type_masks = self._batch_time_slot_mask(count, config.time_ratio, config.slot_ratio)
            elif config.mask_type == "full_day":
                type_masks = self._batch_full_day_mask(count, config.day_ratio)
            elif config.mask_type == "full_slot":
                type_masks = self._batch_full_slot_mask(count, config.slot_ratio)
            elif config.mask_type == "random":
                type_masks = self._batch_random_mask_simple(count, config.mask_ratio)
            else:
                raise ValueError(f"Unknown mask type: {config.mask_type}")
            
            masks[use_this] = type_masks
        
        # Handle causal samples
        if causal_samples.any():
            causal_count = causal_samples.sum().item()
            masks[causal_samples] = self._batch_causal_mask(causal_count)
        
        # Enforce minimum visibility
        masks = self._batch_enforce_min_visible(masks)
        return masks
    
    def _batch_time_slot_mask(self, B: int, time_ratio: Tuple[float, float], 
                               slot_ratio: Tuple[float, float]) -> torch.Tensor:
        masks = torch.ones(B, self.L, self.K, dtype=torch.bool)
        
        # Sample time spans
        t_ratios = torch.empty(B).uniform_(time_ratio[0], time_ratio[1])
        t_lens = (self.L * t_ratios).int().clamp(min=1)
        t_starts = (torch.rand(B) * (self.L - t_lens.float())).int().clamp(min=0)
        
        # Sample slot counts
        k_ratios = torch.empty(B).uniform_(slot_ratio[0], slot_ratio[1])
        k_counts = (self.K * k_ratios).int().clamp(min=1)
        
        # Apply masks (vectorized per-sample via broadcasting where possible)
        for i in range(B):
            t_start, t_len = t_starts[i].item(), t_lens[i].item()
            k_count = k_counts[i].item()
            k_indices = torch.randperm(self.K)[:k_count]
            masks[i, t_start:t_start + t_len, k_indices] = False
        
        return masks
    
    def _batch_full_day_mask(self, B: int, day_ratio: Tuple[float, float]) -> torch.Tensor:
        masks = torch.ones(B, self.L, self.K, dtype=torch.bool)
        
        ratios = torch.empty(B).uniform_(day_ratio[0], day_ratio[1])
        num_days = (self.L * ratios).int().clamp(min=1)
        
        for i in range(B):
            n = num_days[i].item()
            days = torch.randperm(self.L)[:n]
            masks[i, days, :] = False
        
        return masks
    
    def _batch_full_slot_mask(self, B: int, slot_ratio: Tuple[float, float]) -> torch.Tensor:
        masks = torch.ones(B, self.L, self.K, dtype=torch.bool)
        
        ratios = torch.empty(B).uniform_(slot_ratio[0], slot_ratio[1])
        num_slots = (self.K * ratios).int().clamp(min=1)
        
        for i in range(B):
            n = num_slots[i].item()
            slots = torch.randperm(self.K)[:n]
            masks[i, :, slots] = False
        
        return masks
    
    def _batch_random_mask_simple(self, B: int, mask_ratio: Tuple[float, float]) -> torch.Tensor:
        ratios = torch.empty(B).uniform_(mask_ratio[0], mask_ratio[1])
        masks = torch.rand(B, self.L, self.K) > ratios[:, None, None]
        return masks
    
    def _batch_causal_mask(self, B: int) -> torch.Tensor:
        masks = torch.ones(B, self.L, self.K, dtype=torch.bool)
        
        future_ratios = torch.empty(B).uniform_(0.2, 0.4)
        num_futures = (self.L * future_ratios).int().clamp(min=1)
        
        for i in range(B):
            n = num_futures[i].item()
            masks[i, -n:, :] = False
        
        return masks
    
    def _batch_enforce_min_visible(self, masks: torch.Tensor) -> torch.Tensor:
        if self.min_visible_ratio <= 0:
            return masks

        B = masks.size(0)
        total = self.L * self.K
        min_visible = int(math.ceil(self.min_visible_ratio * total))
        if min_visible <= 0:
            return masks

        visible_counts = masks.sum(dim=(1, 2))  # bool -> int64
        need_fix = visible_counts < min_visible
        
        if not need_fix.any():
            return masks
        
        # Fix samples that need it
        for i in torch.where(need_fix)[0]:
            i = int(i.item())
            need = int((min_visible - visible_counts[i]).item())
            if need <= 0:
                continue

            masked_indices = (~masks[i]).nonzero(as_tuple=False)
            if masked_indices.numel() == 0:
                continue

            need = min(need, masked_indices.size(0))
            perm = torch.randperm(masked_indices.size(0))[:need]
            tks = masked_indices[perm]
            masks[i, tks[:, 0], tks[:, 1]] = True
        
        return masks
    
    def _generate_mask(self) -> torch.Tensor:
        config = random.choices(self.configs, weights=self.weights, k=1)[0]

        if config.mask_type == "time_slot":
            mask = self._time_slot_mask(config.time_ratio, config.slot_ratio)
        elif config.mask_type == "full_day":
            mask = self._full_day_mask(config.day_ratio)
        elif config.mask_type == "full_slot":
            mask = self._full_slot_mask(config.slot_ratio)
        elif config.mask_type == "random":
            mask = self._random_mask(config.mask_ratio)
        else:
            raise ValueError(f"Unknown mask type: {config.mask_type}")

        mask = self._enforce_min_visible(mask)
        return mask

    def _time_slot_mask(self, time_ratio, slot_ratio) -> torch.Tensor:
        mask = torch.ones(self.L, self.K, dtype=torch.bool)
        t_ratio = random.uniform(*time_ratio)
        t_len = max(1, int(self.L * t_ratio))
        t_start = random.randint(0, self.L - t_len)
        k_ratio = random.uniform(*slot_ratio)
        k_count = max(1, int(self.K * k_ratio))
        k_indices = random.sample(range(self.K), k_count)
        for k in k_indices:
            mask[t_start:t_start + t_len, k] = False
        return mask

    def _full_day_mask(self, day_ratio) -> torch.Tensor:
        mask = torch.ones(self.L, self.K, dtype=torch.bool)
        ratio = random.uniform(*day_ratio)
        num_days = max(1, int(self.L * ratio))
        days = random.sample(range(self.L), num_days)
        for t in days:
            mask[t, :] = False
        return mask

    def _full_slot_mask(self, slot_ratio) -> torch.Tensor:
        mask = torch.ones(self.L, self.K, dtype=torch.bool)
        ratio = random.uniform(*slot_ratio)
        num_slots = max(1, int(self.K * ratio))
        slots = random.sample(range(self.K), num_slots)
        for k in slots:
            mask[:, k] = False
        return mask

    def _random_mask(self, mask_ratio) -> torch.Tensor:
        ratio = random.uniform(*mask_ratio)
        mask = torch.rand(self.L, self.K) > ratio
        return mask

    def _generate_causal_mask(self) -> torch.Tensor:
        mask = torch.ones(self.L, self.K, dtype=torch.bool)
        future_ratio = random.uniform(0.2, 0.4)
        num_future = max(1, int(self.L * future_ratio))
        mask[-num_future:, :] = False
        return mask

    def _enforce_min_visible(self, mask: torch.Tensor) -> torch.Tensor:
        if self.min_visible_ratio <= 0:
            return mask
        total = self.L * self.K
        min_visible = int(math.ceil(self.min_visible_ratio * total))
        if min_visible <= 0:
            return mask

        visible = int(mask.sum().item())
        if visible >= min_visible:
            return mask

        need = min_visible - visible
        masked_indices = (~mask).nonzero(as_tuple=False)
        if masked_indices.numel() == 0:
            return mask

        need = min(need, masked_indices.size(0))
        perm = torch.randperm(masked_indices.size(0))[:need]
        tks = masked_indices[perm]
        mask[tks[:, 0], tks[:, 1]] = True
        return mask


class MaskCollator:
    """
    Collator that wraps MarketMaskCollator for use with DataLoader.

    Stacks the per-sample "tokens" tensors of a batch and attaches one freshly
    generated visibility mask per sample.

    Usage:
        collator = MaskCollator(clip_length=21, num_tokens=24)
        loader = DataLoader(dataset, collate_fn=collator)
    """

    def __init__(
        self,
        clip_length: int = 21,
        num_tokens: int = 24,
        mask_configs: List[MaskConfig] = None,
        min_visible_ratio: float = 0.3,
        causal_ratio: float = 0.0,
    ):
        """Build the underlying mask generator; all arguments are forwarded unchanged."""
        self.mask_gen = MarketMaskCollator(
            clip_length=clip_length,
            num_tokens=num_tokens,
            mask_configs=mask_configs,
            min_visible_ratio=min_visible_ratio,
            causal_ratio=causal_ratio,
        )

    def __call__(self, batch):
        """
        Collate batch and generate masks.

        Args:
            batch: list of dicts; only the "tokens" and "stride" entries are read.

        Returns:
            dict with "tokens" (the stacked token tensors), "mask" (output of
            the mask generator for this batch size) and "stride" (the stride
            shared by all samples)

        Raises:
            AssertionError: if the samples do not all share the same stride.
        """
        tokens = torch.stack([b["tokens"] for b in batch])
        strides = [b["stride"] for b in batch]

        stride = strides[0]
        # Explicit raise rather than an `assert` statement: asserts are
        # removed when Python runs with -O, which would let mixed-stride
        # batches through silently. The exception type and message are kept
        # so existing handlers are unaffected.
        if not all(s == stride for s in strides):
            raise AssertionError("Mixed strides in batch not supported")

        masks = self.mask_gen(tokens.size(0))

        return {
            "tokens": tokens,
            "mask": masks,
            "stride": stride,
        }
