import time
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader
from rich.progress import (
    Progress, SpinnerColumn, BarColumn, TaskProgressColumn,
    MofNCompleteColumn, TimeElapsedColumn, TimeRemainingColumn
)

def _unpack_batch(batch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Split one dataloader item into ``(input_ids, attention_mask)``.

    A bare tensor is treated as ``input_ids`` with no mask; for a list/tuple the
    first element is ``input_ids`` and the optional second element is the mask.
    """
    if isinstance(batch, (list, tuple)):
        return batch[0], (batch[1] if len(batch) > 1 else None)
    return batch, None


def _layer_hidden(output) -> torch.Tensor:
    """Return the hidden states produced by one decoder-layer call.

    Older HF decoder layers return a ``(hidden, ...)`` tuple while newer ones
    return the tensor itself. Indexing ``[0]`` on a bare tensor would silently
    select the first batch row, so branch on the type instead.
    """
    if isinstance(output, (tuple, list)):
        return output[0]
    return output


def _sync_if_cuda(device) -> None:
    """Wait for queued kernels on ``device``; a no-op for non-CUDA devices."""
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize(device)


@torch.inference_mode()
def stage_forward(
    model: nn.Module,
    dataloader: DataLoader,                 # yields (input_ids) or (input_ids, attention_mask)
    pruner,                                 # e.g., list-like; pruner[i].move_all_hessians_to_cpu()
    stage_size: int = 4,                    # 1 = layer-by-layer
    start_layer: int = 0,
    end_layer: Optional[int] = None,
    device: str = "cuda",
) -> None:
    """
    Stage-wise forward over decoder layers (synchronous copies; no CUDA streams).

    All batches are first copied to CPU. The layer range is then processed in
    stages of ``stage_size`` consecutive layers:
    1. The stage's layers are moved to ``device``.
    2. Every batch is pushed through them.
    3. Each batch's output hidden states overwrite its CPU cache slot, which
       becomes the next stage's input.
    4. The layers are moved back to CPU and ``pruner[li].move_all_hessians_to_cpu()``
       is called for each of them.

    - Uses model.model._update_causal_mask to mirror HF attention-mask handling.
    - Builds (cos, sin) RoPE via model.model.rotary_emb once per distinct
      sequence length and reuses it across batches and stages.
    - Batches may differ in batch size and sequence length.
    - Modules moved to ``device`` are returned to CPU even if an exception is raised.

    Args:
        model: HF-style model exposing ``model.model.layers``,
            ``model.model.embed_tokens``, ``model.model.rotary_emb`` and
            ``model.model._update_causal_mask``.
        dataloader: yields ``input_ids`` or ``(input_ids, attention_mask, ...)``;
            ``input_ids`` is expected to be 2D ``[B, S]``.
        pruner: indexable by absolute layer index; each ``pruner[li]`` must
            provide ``move_all_hessians_to_cpu()``.
        stage_size: consecutive layers resident on ``device`` at once (1 = layer-by-layer).
        start_layer: first layer to run (inclusive).
        end_layer: last layer to run (inclusive); ``None`` means the last layer.
        device: compute device; CUDA synchronization is skipped for non-CUDA devices.

    Side effects: puts the model in eval mode and sets
    ``model.config.use_cache = False`` (not restored).

    NOTE(review): the first stage always feeds raw ``embed_tokens`` output into
    ``layers[start_layer]``. Layers ``0..start_layer-1`` are never run, so with
    ``start_layer > 0`` the inputs to ``start_layer`` are embeddings, not the
    previous layer's output. Confirm that this is intended.
    """

    layers = model.model.layers
    n_layers = len(layers)
    L0 = start_layer
    L1 = n_layers - 1 if end_layer is None else end_layer
    assert 0 <= L0 <= L1 < n_layers
    assert stage_size >= 1

    model.eval()
    if hasattr(model, "config"):
        model.config.use_cache = False

    # ---- 1) Slurp dataloader onto CPU (contiguous). Keep optional 2D masks if provided.
    ids_cpu: List[torch.Tensor] = []
    mask_cpu: List[Optional[torch.Tensor]] = []
    for batch in dataloader:
        input_ids, attention_mask = _unpack_batch(batch)
        ids_cpu.append(input_ids.cpu().contiguous())
        mask_cpu.append(None if attention_mask is None else attention_mask.cpu().contiguous())

    num_batches = len(ids_cpu)
    assert num_batches > 0, "Empty dataloader."

    # ---- Inclusive (first, last) layer spans per stage; the final span may be shorter.
    stage_spans: List[Tuple[int, int]] = [
        (k0, min(k0 + stage_size - 1, L1)) for k0 in range(L0, L1 + 1, stage_size)
    ]
    n_stages_total = len(stage_spans)

    embed_tokens = model.model.embed_tokens
    rotary_emb = model.model.rotary_emb

    # ---- UI
    columns = (
        SpinnerColumn(),
        "[progress.description]{task.description}",
        BarColumn(),
        TaskProgressColumn(),
        "Progress:", MofNCompleteColumn(),
        "Elapsed:", TimeElapsedColumn(),
        "Remaining:", TimeRemainingColumn(),
    )

    # Layers currently on `device`; the `finally` clause sends these back to CPU.
    resident_layers = range(0)
    embed_tokens.to(device)
    rotary_emb.to(device)
    try:
        # ---- Warmup: discover the hidden dtype and width D from the first batch.
        hidden0 = embed_tokens(ids_cpu[0].to(device))        # [B,S,D]
        dtype = hidden0.dtype
        hidden_dim = hidden0.shape[-1]
        del hidden0

        # ---- CPU cache for stage outputs (sync copies): one buffer per batch,
        # shaped like that batch so ragged batches are supported.
        cached_hidden_cpu = [
            torch.empty((*ids.shape, hidden_dim), dtype=dtype) for ids in ids_cpu
        ]

        # seq_len -> (cache_position [S], position_ids [1,S], position_embeddings)
        positional_cache = {}

        def positional_inputs(hidden: torch.Tensor):
            """Return (cache_position, position_ids, position_embeddings) for hidden's length."""
            seq_len = hidden.shape[1]
            entry = positional_cache.get(seq_len)
            if entry is None:
                cache_position = torch.arange(seq_len, device=device, dtype=torch.long)
                # [1,S] broadcasts over any batch size (HF's model forward builds the same).
                position_ids = cache_position.unsqueeze(0)
                # Use the *real* hidden to drive rotary_emb (HF does this).
                entry = (cache_position, position_ids, rotary_emb(hidden, position_ids))
                positional_cache[seq_len] = entry
            return entry

        def run_stage(a: int, b: int, is_stage0: bool, progress: Progress) -> float:
            """Push every batch through layers a..b, caching outputs on CPU; return seconds."""
            _sync_if_cuda(device)
            t0 = time.perf_counter()

            t_batch = progress.add_task(
                description=f"[blue]Layers {a}-{b}: batches",
                total=num_batches, visible=True
            )
            progress.start_task(t_batch)

            for bidx in range(num_batches):
                # Prepare inputs on device (sync copies)
                if is_stage0:
                    hidden = embed_tokens(ids_cpu[bidx].to(device))  # [B,S,D]
                else:
                    hidden = cached_hidden_cpu[bidx].to(device)

                cache_position, position_ids, position_embeddings = positional_inputs(hidden)

                # (Optional) 2D attention mask from dataloader
                mask2d = None if mask_cpu[bidx] is None else mask_cpu[bidx].to(device)

                # Ask the model to build the *exact* causal mask it would use
                causal_mask = model.model._update_causal_mask(
                    attention_mask=mask2d,
                    input_tensor=hidden,
                    cache_position=cache_position,
                    past_key_values=None,
                    output_attentions=False,
                )

                # Layers a..b
                for li in range(a, b + 1):
                    hidden = _layer_hidden(layers[li](
                        hidden,
                        attention_mask=causal_mask,
                        position_ids=position_ids,
                        past_key_value=None,
                        output_attentions=False,
                        use_cache=False,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                    ))

                # D2H (sync): copy straight into the preallocated buffer, no temporary.
                cached_hidden_cpu[bidx].copy_(hidden)

                progress.update(t_batch, advance=1)

            _sync_if_cuda(device)
            t1 = time.perf_counter()
            progress.stop_task(t_batch)
            progress.update(t_batch, completed=num_batches, visible=False)
            progress.remove_task(t_batch)
            return t1 - t0

        # ---- Drive stages
        with Progress(*columns) as progress:
            t_stage = progress.add_task("[magenta]Stages", total=n_stages_total)
            progress.start_task(t_stage)

            for stage_idx, (k0, k1) in enumerate(stage_spans, start=1):
                progress.update(
                    t_stage,
                    description=f"[magenta]Stage {stage_idx}/{n_stages_total} (layers {k0}-{k1})",
                )
                is_stage0 = stage_idx == 1

                resident_layers = range(k0, k1 + 1)
                for li in resident_layers:
                    layers[li].to(device)
                _ = run_stage(k0, k1, is_stage0=is_stage0, progress=progress)
                for li in resident_layers:
                    layers[li].to('cpu')
                    pruner[li].move_all_hessians_to_cpu()
                resident_layers = range(0)

                if is_stage0:
                    # Only the first stage embeds tokens; release its device memory early.
                    embed_tokens.to('cpu')
                progress.advance(t_stage)
    finally:
        # ---- Return every module this function moved to `device`, even on error.
        for li in resident_layers:
            layers[li].to('cpu')
        embed_tokens.to('cpu')
        rotary_emb.to('cpu')

