import time
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader
from rich.progress import (
    Progress, SpinnerColumn, BarColumn, TaskProgressColumn,
    MofNCompleteColumn, TimeElapsedColumn, TimeRemainingColumn
)

# One dedicated copy stream per CUDA device for D2H flushes, created lazily.
# Keyed by device index so a call targeting one GPU never receives a stream that
# was created for a different GPU by an earlier call.
_COPY_STREAMS: dict = {}


def _get_copy_stream(device="cuda"):
    """Return the cached side stream used for D2H copies on ``device``.

    The stream is created on first use for each CUDA device and reused afterwards.

    Args:
        device: CUDA device spec (``"cuda"``, ``"cuda:1"``, ``torch.device`` or an
            integer index). A bare ``"cuda"`` resolves to the current CUDA device.

    Returns:
        torch.cuda.Stream bound to that device.

    Raises:
        ValueError: if ``device`` is not a CUDA device.
    """
    dev = torch.device("cuda", device) if isinstance(device, int) else torch.device(device)
    if dev.type != "cuda":
        raise ValueError(f"Expected a cuda device, but got: {dev}")
    # Resolve "cuda" (no index) to a concrete index so equivalent specs share one stream.
    index = dev.index if dev.index is not None else torch.cuda.current_device()
    stream = _COPY_STREAMS.get(index)
    if stream is None:
        stream = torch.cuda.Stream(device=index)
        _COPY_STREAMS[index] = stream
    return stream


@torch.inference_mode()
def stage_forward(
    model: nn.Module,
    dataloader: DataLoader,                 # yields (input_ids) or (input_ids, attention_mask)
    pruner,
    stage_size: int = 4,
    start_layer: int = 0,
    end_layer: Optional[int] = None,
    device: str = "cuda",
) -> None:
    """
    Stage-wise forward with **overlapped D2H** (copy stream + pinned host buffers + non_blocking copies).

    Decoder layers ``start_layer..end_layer`` (inclusive) are processed in stages of
    ``stage_size`` consecutive layers. For each stage the layers are moved to ``device``,
    every cached batch is pushed through them, the stage outputs are copied back into
    pinned host buffers (which feed the next stage), and the layers are moved back to
    CPU followed by ``pruner[li].move_all_hessians_to_cpu()``.

    - Uses model.model._update_causal_mask to mirror HF masking logic.
    - Builds shared (cos, sin) RoPE per stage via model.model.rotary_emb.
    - Correct stream ordering (wait_stream) and lifetime (record_stream).

    Args:
        model: module exposing ``model.model.layers``, ``embed_tokens``, ``rotary_emb``
            and ``_update_causal_mask``. Sets ``model.config.use_cache = False`` if a
            config is present.
        dataloader: yields ``input_ids`` or ``(input_ids, attention_mask)``. All batches
            must share the same ``input_ids`` shape, because one set of pinned buffers,
            position ids and RoPE tables is reused for every batch.
        pruner: indexable by layer index; each entry must provide
            ``move_all_hessians_to_cpu()``.
        stage_size: number of consecutive layers resident on ``device`` at once.
        start_layer: first layer to run (inclusive).
        end_layer: last layer to run (inclusive); ``None`` means the last layer.
        device: CUDA device to run on.

    Raises:
        AssertionError: invalid layer range / stage size, empty dataloader, or a fully
            masked attention row.
        ValueError: batches with differing ``input_ids`` shapes.
        RuntimeError: non-finite activations produced by a stage.

    NOTE(review): the first stage always starts from ``embed_tokens`` output, even when
    ``start_layer > 0`` (layers below ``start_layer`` are not run). Confirm that callers
    only use ``start_layer == 0`` or intend this behavior.
    """

    layers = model.model.layers
    n_layers = len(layers)
    L0 = start_layer
    L1 = n_layers - 1 if end_layer is None else end_layer
    assert 0 <= L0 <= L1 < n_layers
    assert stage_size >= 1

    model.eval()
    if hasattr(model, "config"):
        model.config.use_cache = False

    # ---- 1) CPU pinned inputs; keep optional masks (pin also for cheaper H2D)
    ids_cpu:  List[torch.Tensor] = []
    mask_cpu: List[Optional[torch.Tensor]] = []
    for batch in dataloader:
        if isinstance(batch, (list, tuple)):
            input_ids = batch[0]
            attention_mask = batch[1] if len(batch) > 1 else None
        else:
            input_ids, attention_mask = batch, None
        ids_cpu.append(input_ids.contiguous().pin_memory())
        mask_cpu.append(None if attention_mask is None else attention_mask.contiguous().pin_memory())

    num_batches = len(ids_cpu)
    assert num_batches > 0, "Empty dataloader."

    # Buffers, position ids and RoPE below are sized from batch 0. A differently shaped
    # batch (e.g. a short last batch) would otherwise fail deep inside the model or be
    # silently broadcast by copy_() into the cache, so reject it here.
    ref_shape = ids_cpu[0].shape
    for bidx, ids in enumerate(ids_cpu):
        if ids.shape != ref_shape:
            raise ValueError(
                f"All batches must share the same input_ids shape: batch 0 has "
                f"{tuple(ref_shape)} but batch {bidx} has {tuple(ids.shape)} "
                f"(use drop_last=True and fixed-length padding)."
            )

    # ---- Warmup for shapes/dtype
    # Ensure embedding and rotary modules are on the correct device before warmup
    model.model.embed_tokens.to(device)
    model.model.rotary_emb.to(device)
    _ids0 = ids_cpu[0].to(device, non_blocking=True)
    hidden0 = model.model.embed_tokens(_ids0)
    dtype = hidden0.dtype
    B, S, D = hidden0.shape
    del hidden0, _ids0

    # ---- Pinned CPU caches for stage outputs
    cached_hidden_cpu = [
        torch.empty((B, S, D), dtype=dtype, pin_memory=True)
        for _ in range(num_batches)
    ]

    # ---- Shared positions; RoPE per stage
    position_ids   = torch.arange(S, device=device, dtype=torch.long).unsqueeze(0).expand(B, -1)
    cache_position = torch.arange(S, device=device, dtype=torch.long)

    copy_stream = _get_copy_stream(device=device)

    def _stage_span(k0: int) -> Tuple[int, int]:
        # Inclusive layer span of the stage starting at k0, clipped to L1.
        k1 = min(k0 + stage_size - 1, L1)
        return k0, k1

    n_stages_total = ((L1 - L0 + 1) + stage_size - 1) // stage_size

    @torch.inference_mode()
    def run_stage(a: int, b: int, is_stage0: bool, progress: Progress) -> float:
        """Run layers a..b over every batch; return wall-clock seconds for the stage."""
        # Fence only at stage boundary
        torch.cuda.synchronize()
        t0 = time.perf_counter()

        # Use the *real* hidden to drive rotary_emb (HF does this)
        # For stage 0 we have hidden from embed_tokens; for later stages it's loaded from cache.
        # So: grab a sample 'hidden_for_rope' once per stage from batch 0.
        if is_stage0:
            hidden_for_rope = model.model.embed_tokens(ids_cpu[0].to(device, non_blocking=True))
        else:
            hidden_for_rope = cached_hidden_cpu[0].to(device, non_blocking=True)

        position_embeddings = model.model.rotary_emb(hidden_for_rope, position_ids)
        del hidden_for_rope

        t_batch = progress.add_task(
            description=f"[blue]Layers {a}-{b}: batches",
            total=num_batches, visible=True
        )
        progress.start_task(t_batch)

        for bidx in range(num_batches):
            # H2D inputs (async if stage0; else from pinned cache)
            if is_stage0:
                input_ids = ids_cpu[bidx].to(device, non_blocking=True)
                hidden = model.model.embed_tokens(input_ids)
            else:
                hidden = cached_hidden_cpu[bidx].to(device, non_blocking=True)

            # Optional mask H2D (cheap; small)
            mask2d = None if mask_cpu[bidx] is None else mask_cpu[bidx].to(device, non_blocking=True)

            # Ask model to build its own causal mask for this hidden/dtype
            causal_mask = model.model._update_causal_mask(
                attention_mask=mask2d,
                input_tensor=hidden,
                cache_position=cache_position,
                past_key_values=None,
                output_attentions=False,
            )
            # The mask depends only on this batch's attention mask and the hidden
            # dtype/shape, which are the same in every stage; validating it in the first
            # stage is enough and avoids a blocking .item() per batch in later stages.
            if is_stage0 and causal_mask is not None:
                row_all_neg_inf = torch.isneginf(causal_mask).all(dim=-1).any().item()
                assert not row_all_neg_inf, "Found fully masked attention row"

            # Layers a..b
            for li in range(a, b + 1):
                out = layers[li](
                    hidden,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=None,
                    output_attentions=False,
                    use_cache=False,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )
                # Older transformers releases return a tuple from decoder layers, newer
                # ones return the hidden-state tensor itself; indexing a bare tensor with
                # [0] would silently select batch element 0.
                hidden = out[0] if isinstance(out, tuple) else out

            # check the actual output of this stage
            if not torch.isfinite(hidden).all():
                raise RuntimeError(f"Non-finite detected inside stage [{a}-{b}] on batch {bidx}")
            # D2H async copy (order + lifetime safety)
            hidden = hidden.detach()
            cur = torch.cuda.current_stream()
            # The copy must not start before the compute that produced `hidden` finishes.
            copy_stream.wait_stream(cur)
            with torch.cuda.stream(copy_stream):
                cached_hidden_cpu[bidx].copy_(hidden, non_blocking=True)
            # Keep `hidden`'s memory from being reused by the allocator until the copy is done.
            hidden.record_stream(copy_stream)

            progress.update(t_batch, advance=1)

        # Stage fence: all D2H copies must land before the next stage reads the cache.
        copy_stream.synchronize()
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        progress.stop_task(t_batch)
        progress.update(t_batch, completed=num_batches, visible=False)
        progress.remove_task(t_batch)
        return t1 - t0

    # ---- UI
    columns = (
        SpinnerColumn(),
        "[progress.description]{task.description}",
        BarColumn(),
        TaskProgressColumn(),
        "Progress:", MofNCompleteColumn(),
        "Elapsed:", TimeElapsedColumn(),
        "Remaining:", TimeRemainingColumn(),
    )

    # ---- Drive stages
    with Progress(*columns) as progress:
        t_stage = progress.add_task("[magenta]Stages", total=n_stages_total)
        progress.start_task(t_stage)

        k0 = L0
        for stage_idx in range(1, n_stages_total + 1):
            k0, k1 = _stage_span(k0)
            progress.update(t_stage, description=f"[magenta]Stage {stage_idx}/{n_stages_total} (layers {k0}-{k1})")
            for li in range(k0, k1 + 1):
                layers[li].to(device)
            # Only the first stage starts from token embeddings; later stages read the
            # previous stage's outputs from cached_hidden_cpu.
            run_stage(k0, k1, is_stage0=(stage_idx == 1), progress=progress)
            for li in range(k0, k1 + 1):
                layers[li].to('cpu')
                pruner[li].move_all_hessians_to_cpu()
            progress.advance(t_stage)
            k0 = k1 + 1

        model.model.embed_tokens.to('cpu')
        model.model.rotary_emb.to('cpu')