from __future__ import annotations

import math
from typing import List, Optional, Sequence, Tuple, Union, Iterable  # <-- added Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Iterable, Optional, List

def _resolve_exit_layers(
    num_layers: int,
    ee_layer_locations: Optional[Iterable[int]] = None,
    blks_to_exit: Optional[Iterable[int]] = None,  # alias
    no_of_exits: Optional[int] = None,             # optional truncate (legacy)
    exits_every_k: Optional[int] = None,
    exits_count: Optional[int] = None,
) -> List[int]:
    """
    Returns a deduped, ordered list of 0-based layer indices where exits are placed.
    Rules:
      - If explicit list is provided -> use it (accepts 1-based and -1).
      - Else if exits_every_k is given -> place exit every k layers (first at k-1).
      - Else if exits_count is given -> place ~uniformly spaced exits across [0..L-1].
      - Else default: every layer (0..L-1).
    Always forces the final exit to sit on the last layer (L-1).
    """
    L = int(num_layers)
    assert L >= 1, "num_layers must be >= 1"

    # 1) explicit list wins
    src = ee_layer_locations if ee_layer_locations is not None else blks_to_exit
    exits: List[int] = []

    def _push(v: int):
        if 0 <= v < L and (len(exits) == 0 or exits[-1] != v):
            exits.append(v)

    if src is not None:
        seen = set()
        for v in src:
            v = int(v)
            if v == -1:
                v = L - 1
            elif 1 <= v <= L:
                v = v - 1   # accept 1-based
            # else assume already 0-based
            if 0 <= v < L and v not in seen:
                exits.append(v)
                seen.add(v)

    # 2) every-k if no explicit
    if not exits and exits_every_k is not None and exits_every_k > 0:
        k = int(exits_every_k)
        start = min(L - 1, max(0, k - 1))
        for idx in range(start, L, k):
            _push(idx)

    # 3) uniform count if still empty
    if not exits and exits_count is not None and exits_count > 0:
        N = int(exits_count)
        if N == 1:
            _push(L - 1)
        else:
            for j in range(N):
                pos = round(j * (L - 1) / (N - 1))
                _push(int(pos))

    # 4) default: every layer
    if not exits:
        for i in range(L):
            _push(i)

    # 5) force last exit to be the last layer
    if exits[-1] != (L - 1):
        exits[-1] = L - 1

    # 6) optional truncate (legacy)
    if no_of_exits is not None:
        exits = exits[: int(no_of_exits)]

    # dedupe while keeping order
    out, seen = [], set()
    for v in exits:
        if v not in seen:
            out.append(v); seen.add(v)

    return out if out else [L - 1]

class SequenceExitHead(nn.Module):
    """Exit classifier: optional LayerNorm -> optional Dropout -> Linear to vocab logits.

    Args:
        hidden_dim: feature size H of the incoming hidden states.
        vocab_size: number of output logits V.
        dropout: dropout probability; ``0`` (or falsy) disables dropout entirely.
        use_ln: apply ``nn.LayerNorm(hidden_dim)`` before the projection.
        last_token_only: for 3-D input, classify only the final time step.
    """

    def __init__(self, hidden_dim: int, vocab_size: int, dropout: float = 0.0,
                 use_ln: bool = False, last_token_only: bool = True):
        super().__init__()
        # Identity placeholders keep the forward path identical whether or not
        # normalisation/dropout are enabled (and keep state_dict keys stable).
        self.norm = nn.LayerNorm(hidden_dim) if use_ln else nn.Identity()
        self.drop = nn.Dropout(dropout) if dropout and dropout > 0 else nn.Identity()
        self.fc   = nn.Linear(hidden_dim, vocab_size, bias=True)
        self.last_token_only = bool(last_token_only)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Accepts:
          - (B, T, H): standard time-distributed features
          - (B, H):    already last-token features (prev-compatible)
        Returns:
          - (B, V)     if last_token_only=True or input was (B, H)
          - (B, T, V)  otherwise
        Raises:
          ValueError: if the input is neither 2-D nor 3-D.
        """
        if x.dim() not in (2, 3):
            raise ValueError(
                f"SequenceExitHead expects (B, H) or (B, T, H) input; got shape {tuple(x.shape)}"
            )

        if x.dim() == 3 and self.last_token_only:
            # LayerNorm normalises over the last dim and dropout is elementwise,
            # so selecting the last step first gives the same logits (same
            # dropout distribution) while skipping T-1 positions of work.
            x = x[:, -1, :]  # (B, H)

        # nn.Linear maps the last dim and keeps leading dims, so this yields
        # (B, V) for 2-D input and (B, T, V) for 3-D input without reshaping.
        return self.fc(self.drop(self.norm(x)))


class GRUEarlyExitLM(nn.Module):
    """Stack of single-layer GRUs with an early-exit classifier after selected layers.

    Each layer is its own ``nn.GRU(num_layers=1)`` so intermediate sequence
    outputs can feed exit heads. Exit placement is delegated to
    ``_resolve_exit_layers`` (see its docstring for accepted formats).

    Args:
        vocab_size: vocabulary size V (embedding rows and head outputs).
        embed_dim: embedding width E.
        hidden_dim: GRU hidden width H.
        num_layers: number of stacked GRU layers (>= 1).
        dropout: dropout applied to each GRU layer's output before it feeds
            the next layer (not after the last layer), active only in training.
        head_dropout: dropout inside every exit head.
        tie_final_weights: tie the last head's projection to the embedding
            matrix; silently disabled when the shapes differ (i.e. H != E).
        last_exit_only: make ``forward`` return only the final exit's logits.
        pad_idx: embedding padding index.
        ee_layer_locations / blks_to_exit / no_of_exits / exits_every_k /
            exits_count: exit placement, forwarded to ``_resolve_exit_layers``.
        **kwargs: accepted and ignored for call-site compatibility.
    """

    def __init__(
        self,
        *,
        vocab_size: int,
        embed_dim: int = 128,
        hidden_dim: int = 256,
        num_layers: int = 4,
        dropout: float = 0.1,
        head_dropout: float = 0.0,
        tie_final_weights: bool = False,
        last_exit_only: bool = False,
        pad_idx: int = 0,
        ee_layer_locations: Optional[Iterable[int]] = None,
        no_of_exits: Optional[int] = None,
        blks_to_exit: Optional[Iterable[int]] = None,
        # NEW (prev-compatible; defaults keep old behavior)
        exits_every_k: Optional[int] = None,
        exits_count: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__()
        assert num_layers >= 1, "num_layers must be >= 1"

        self.vocab_size = int(vocab_size)
        self.embed_dim  = int(embed_dim)
        self.hidden_dim = int(hidden_dim)
        self.num_layers = int(num_layers)
        self.pad_idx    = int(pad_idx)

        # ------ resolve early-exit locations (single source of truth) ------
        self.blks_to_exit: List[int] = _resolve_exit_layers(
            self.num_layers,
            ee_layer_locations=ee_layer_locations,
            blks_to_exit=blks_to_exit,
            no_of_exits=no_of_exits,
            exits_every_k=exits_every_k,
            exits_count=exits_count,
        )
        self.no_of_exits: int = len(self.blks_to_exit)

        # keep a "layers" attribute so strategies can probe depth (sum(len(s)) == num_layers)
        self.layers = [list(range(self.num_layers))]

        # ------ embedding / GRU stack ------
        self.embedding = nn.Embedding(self.vocab_size, self.embed_dim, padding_idx=self.pad_idx)

        self.gru_layers = nn.ModuleList()
        in_dim = self.embed_dim
        for _ in range(self.num_layers):
            self.gru_layers.append(nn.GRU(input_size=in_dim, hidden_size=self.hidden_dim, num_layers=1, batch_first=True))
            in_dim = self.hidden_dim

        # Inter-layer dropout probability read by forward(). Previously the
        # `dropout` argument was accepted but never stored, so it had no effect.
        self._inter_drop_p: float = float(dropout)

        # exit heads (only where we placed exits); head i is consumed by the
        # i-th exit layer encountered while walking the stack.
        self._exit_layers: List[int] = list(self.blks_to_exit)
        self.exit_heads = nn.ModuleList([
            SequenceExitHead(self.hidden_dim, self.vocab_size, dropout=head_dropout, use_ln=False)
            for _ in self._exit_layers
        ])

        # optional weight tying (only if shapes align)
        self.tie_final_weights = bool(tie_final_weights)
        if self.tie_final_weights and len(self.exit_heads) > 0:
            if self.exit_heads[-1].fc.weight.shape == self.embedding.weight.shape:
                self.exit_heads[-1].fc.weight = self.embedding.weight
            else:
                self.tie_final_weights = False  # skip if dims mismatch

        self.last_exit_only: bool = bool(last_exit_only)
        # When set, forward() returns only outs[active_exit].
        self.active_exit: Optional[int] = None

        # key lists for FL strategies
        self.all_state_dict_keys = list(self.state_dict().keys())
        self.trainable_state_dict_keys = [n for n, p in self.named_parameters() if p.requires_grad]

        # --- let the client/runtime know this model expects token ids ---
        self.expects_token_ids = True
        self.input_dtype = torch.long
        self.input_kind = "tokens"

    @staticmethod
    def _unwrap_inputs(x, h0: Optional[torch.Tensor]):
        """Unwrap ptflops / tuple-style inputs ``(x,)``, ``(x, h0)`` or ``((x, h0),)``.

        An ``h0`` found inside the tuple is used only when no explicit ``h0`` was given.
        Returns the (possibly unchanged) ``(x, h0)`` pair.
        """
        if not isinstance(x, (tuple, list)):
            return x, h0
        if len(x) >= 1 and torch.is_tensor(x[0]):
            packed = x
        elif (len(x) == 1 and isinstance(x[0], (tuple, list))
              and len(x[0]) >= 1 and torch.is_tensor(x[0][0])):
            packed = x[0]
        else:
            return x, h0
        if len(packed) >= 2 and h0 is None and torch.is_tensor(packed[1]):
            h0 = packed[1]
        return packed[0], h0

    def forward(
        self,
        x: torch.Tensor,
        h0: Optional[torch.Tensor] = None,
        return_h: bool = False,
    ):
        """Run the GRU stack and collect last-time-step logits at every exit.

        Args:
            x: token ids of shape (B, T), or a tuple/list wrapping them (see
                ``_unwrap_inputs``).
            h0: optional initial hidden states. Used only when 3-D; layer ``li``
                receives ``h0[li:li+1]`` if that slice exists, else starts from
                zeros. Presumably (num_layers, B, H) -- confirm with callers.
            return_h: also return the final hidden states of all layers,
                concatenated along dim 0.

        Returns:
            ``outs`` (list of (B, V) logits in layer order), or a single (B, V)
            tensor when ``active_exit`` is set or ``last_exit_only`` is True;
            wrapped as ``(result, h_stack)`` when ``return_h`` is True.

        Raises:
            TypeError: if the (unwrapped) input is not a tensor.
            ValueError: if the token ids are not 2-D.
        """
        x, h0 = self._unwrap_inputs(x, h0)
        if not torch.is_tensor(x):
            raise TypeError(f"expected a tensor of token ids; got {type(x).__name__}")
        if x.dim() != 2:
            raise ValueError(f"expected token ids of shape (B, T); got {tuple(x.shape)}")

        h_prev = self.embedding(x)
        inter_p = float(getattr(self, "_inter_drop_p", 0.0))
        last_layer = self.num_layers - 1

        outs: List[torch.Tensor] = []
        h_all: List[torch.Tensor] = []
        head_ptr = 0
        for li, gru in enumerate(self.gru_layers):
            h0i = None
            if h0 is not None and h0.dim() == 3 and h0.size(0) > li:
                h0i = h0[li:li + 1]

            y, h_last = gru(h_prev, h0i)  # y: (B,T,H), h_last: (1,B,H)
            h_all.append(h_last)

            if li in self._exit_layers:
                logits = self.exit_heads[head_ptr](y)  # may be (B,T,V) or (B,V)
                if not isinstance(logits, torch.Tensor):
                    raise RuntimeError("exit head did not return a Tensor")
                if logits.dim() == 3:
                    logits = logits[:, -1, :]  # take last token → (B, V)
                elif logits.dim() != 2:
                    raise RuntimeError(f"unexpected head output dim={logits.dim()}")
                outs.append(logits)
                head_ptr += 1

            # Dropout between layers only (never after the top layer).
            if li < last_layer and inter_p > 0:
                h_prev = F.dropout(y, p=inter_p, training=self.training)
            else:
                h_prev = y

        h_stack = torch.cat(h_all, dim=0) if return_h else None

        if self.active_exit is not None:
            out = outs[int(self.active_exit)]
        elif self.last_exit_only:
            out = outs[-1]
        else:
            out = outs
        return (out, h_stack) if return_h else out

def exit_shakespeare_gru(
    *,
    vocab_size: int,
    embed_dim: Optional[int] = None,
    embedding_dim: Optional[int] = None,   # alias
    hidden_dim: Optional[int] = None,
    hidden_size: Optional[int] = None,     # alias
    num_layers: int = 4,
    dropout: float = 0.1,
    head_dropout: float = 0.0,
    tie_final_weights: Optional[bool] = None,
    tie_weights: Optional[bool] = None,    # alias
    last_exit_only: bool = False,
    pad_idx: int = 0,
    ee_layer_locations: Optional[Iterable[int]] = None,
    blks_to_exit: Optional[Iterable[int]] = None,             # accept strategy arg
    no_of_exits: Optional[int] = None,
    exits_every_k: Optional[int] = None,
    exits_count: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    **kwargs,  # tolerate extra args like width_scale/depth, etc.
) -> GRUEarlyExitLM:
    """Build a ``GRUEarlyExitLM``, reconciling alias keyword arguments.

    Alias resolution (the canonical name wins when both are given):
      * ``embed_dim`` / ``embedding_dim`` -> embedding width, default 32
      * ``hidden_dim`` / ``hidden_size``  -> GRU width, default 64
      * ``tie_final_weights`` / ``tie_weights`` -> weight tying, default False

    These factory defaults (32 / 64) are smaller than the class defaults
    (128 / 256). Remaining keyword arguments, including ``**kwargs``, are
    forwarded unchanged. If ``device`` is given, the model is moved there
    before being returned.
    """
    def _first_given(preferred, fallback_alias, default):
        # Pick the first argument that was actually supplied (not None).
        if preferred is not None:
            return preferred
        if fallback_alias is not None:
            return fallback_alias
        return default

    emb_width = _first_given(embed_dim, embedding_dim, 32)
    rnn_width = _first_given(hidden_dim, hidden_size, 64)
    # bool(None) is False, so an absent alias simply means "do not tie".
    if tie_final_weights is not None:
        tie_flag = bool(tie_final_weights)
    else:
        tie_flag = bool(tie_weights)

    model = GRUEarlyExitLM(
        vocab_size=int(vocab_size),
        embed_dim=int(emb_width),
        hidden_dim=int(rnn_width),
        num_layers=int(num_layers),
        dropout=float(dropout),
        head_dropout=float(head_dropout),
        tie_final_weights=tie_flag,
        last_exit_only=bool(last_exit_only),
        pad_idx=int(pad_idx),
        ee_layer_locations=ee_layer_locations,
        blks_to_exit=blks_to_exit,
        no_of_exits=no_of_exits,
        exits_every_k=exits_every_k,
        exits_count=exits_count,
        **kwargs,
    )
    return model if device is None else model.to(device)