# src/models/resnet_exit_template.py
from __future__ import annotations

import copy
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.model_utils import prune, Classifier
from ptflops import get_model_complexity_info

# --------------------------------------------------------------------------- #
#  Helper blocks                                                              #
# --------------------------------------------------------------------------- #

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm layers and a skip connection.

    The first convolution applies ``stride``; the second always uses stride 1.
    When ``downsample`` is supplied it is applied to the block input to produce
    the skip branch, otherwise the input itself is added back unchanged.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion: int = 1

    def __init__(
        self,
        in_planes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # Settings shared by both 3x3 convolutions.
        conv3x3_cfg = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv2d(in_planes, planes, stride=stride, **conv3x3_cfg)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv3x3_cfg)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + skip(x))``."""
        # Main branch: conv -> bn -> relu -> conv -> bn.
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        residual = self.bn2(self.conv2(hidden))

        # Skip branch: projected input when a downsample module is present.
        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)


def _make_layer(
    block: type, in_planes: int, planes: int, blocks: int, stride: int = 1
) -> nn.Sequential:
    """Assemble one ResNet stage as an ``nn.Sequential`` of ``block`` instances.

    The first block receives ``stride`` and, when the stride is not 1 or the
    input width differs from ``planes * block.expansion``, a 1x1 conv +
    batch-norm projection for its skip branch. Every following block maps
    ``planes * block.expansion`` channels to the same width with default stride.
    At least one block is always created, even if ``blocks`` is below 1.
    """
    out_planes = planes * block.expansion

    # A projection is needed whenever the skip branch would not match the
    # main branch in spatial size or channel count.
    if stride == 1 and in_planes == out_planes:
        shortcut = None
    else:
        shortcut = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_planes),
        )

    stage: List[nn.Module] = [block(in_planes, planes, stride, shortcut)]
    stage.extend(block(out_planes, planes) for _ in range(1, blocks))
    return nn.Sequential(*stage)


# --------------------------------------------------------------------------- #
#  Main network with early exits                                              #
# --------------------------------------------------------------------------- #
class ResNetEarlyExit(nn.Module):
    """
    A ResNet model that supports early exits, depth truncation, and width scaling.
    This version includes robust logic to handle conflicting configuration arguments.

    Args:
        block: Residual block class. It is instantiated through ``_make_layer``
            and must expose an ``expansion`` class attribute.
        layers: Number of blocks per stage.
        num_classes: Forwarded to every exit head.
        depth: Maximum total number of residual blocks to build. Stages are
            filled in order and truncated once this budget is used up.
            ``None`` builds ``sum(layers)`` blocks.
        no_of_exits: When ``blks_to_exit`` is not given, place this many exits
            at evenly spaced block indices.
        blks_to_exit: Explicit 0-based block indices that receive an exit
            head. Negative values count from the last built block and all
            values are clamped into the range of blocks actually built.
            Duplicates are allowed.
        width_scale: Multiplier applied to every channel count (truncated
            with ``int``).
        last_exit_only: If true, ``forward`` returns only the last exit's
            logits (unless ``active_exit`` is set).
        freeze_base: If true, ``requires_grad`` is cleared on every parameter
            whose name does not start with ``"exit_heads"``.
        cifar: If true, use a 3x3 stride-1 stem and stage widths 16/32/64;
            otherwise a 7x7 stride-2 stem followed by max-pooling and stage
            widths 64/128/256/512.

    Notes:
        Regardless of how exits were requested, the last entry of
        ``blks_to_exit`` is always rewritten to the index of the final built
        block, so the deepest block always has an exit and the number of
        heads stays unchanged.
    """
    def __init__(
        self,
        block: type,
        layers: Sequence[int],
        num_classes: int = 10,
        *,
        depth: Optional[int] = None,
        no_of_exits: Optional[int] = None,
        blks_to_exit: Optional[Sequence[int]] = None,
        width_scale: float = 1.0,
        last_exit_only: bool = False,
        freeze_base: bool = False,
        cifar: bool = True,
    ) -> None:
        super().__init__()

        # ---- stem --------------------------------------------------------------
        self.cifar = cifar
        self.width_scale = float(width_scale)
        self._return_last_only = bool(last_exit_only)
        self.last_exit_only = self._return_last_only

        base_planes_init = 16 if cifar else 64
        self.in_planes = int(base_planes_init * self.width_scale)

        if cifar:
            self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        else:
            self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=7, stride=2, padding=3, bias=False)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.bn1 = nn.BatchNorm2d(self.in_planes)
        self.relu = nn.ReLU(inplace=True)

        # ---- build trunk (respect depth) --------------------------------------
        stage_planes = [16, 32, 64] if cifar else [64, 128, 256, 512]
        self.layers: nn.ModuleList = nn.ModuleList()
        # Output channel count of each built stage, aligned with self.layers.
        stage_widths: List[int] = []

        max_blocks = int(depth) if depth is not None else int(sum(layers))
        blocks_built = 0
        for s, n_blocks in enumerate(layers):
            if blocks_built >= max_blocks:
                break
            # Truncate the stage if the depth budget runs out inside it.
            take = min(n_blocks, max_blocks - blocks_built)
            planes = int(stage_planes[s] * self.width_scale)
            stride = 1 if s == 0 else 2
            self.layers.append(_make_layer(block, self.in_planes, planes, take, stride))
            self.in_planes = planes * block.expansion
            stage_widths.append(self.in_planes)
            blocks_built += take

        # count how many blocks we actually built
        block_dims: List[int] = []
        for stage, stage_width in zip(self.layers, stage_widths):
            for blk in stage:
                # channel dim after the block is given by its second BN
                if hasattr(blk, "bn2"):
                    block_dims.append(int(blk.bn2.num_features))
                else:
                    # Blocks without ``bn2``: fall back to the output width of
                    # the stage that contains the block (not the final stage).
                    block_dims.append(int(stage_width))

        self.total_blocks = len(block_dims)

        # ---- exit placement (normalized to *actual* built blocks) --------------
        def _norm_exits(cands: Sequence[int], total: int) -> List[int]:
            """Resolve negative indices and clamp every index into [0, total - 1]."""
            out = []
            for b in cands:
                b = int(b)
                if b < 0:
                    b = total + b  # -1 -> last built block
                b = max(0, min(b, total - 1))
                out.append(b)
            return out

        if blks_to_exit is not None:
            self.blks_to_exit = _norm_exits(list(blks_to_exit), self.total_blocks)
        else:
            if no_of_exits and self.total_blocks > 0:
                # Evenly spaced exits; at most ``no_of_exits`` of them.
                step = max(1, self.total_blocks // int(no_of_exits))
                self.blks_to_exit = list(range(step - 1, self.total_blocks, step))[: int(no_of_exits)]
            else:
                self.blks_to_exit = [self.total_blocks - 1] if self.total_blocks > 0 else []

        # ensure the final built block has an exit
        if self.total_blocks > 0:
            if not self.blks_to_exit:
                self.blks_to_exit = [self.total_blocks - 1]
            else:
                # Overwrite (not append) so the number of heads is preserved.
                self.blks_to_exit[-1] = self.total_blocks - 1

        # ---- exit heads --------------------------------------------------------
        # One head per entry of ``blks_to_exit`` (head i is attached to block
        # ``blks_to_exit[i]``). ``Classifier`` comes from
        # ``src.models.model_utils``; how it uses ``reduction``/``scale`` is not
        # visible in this file.
        self.exit_heads = nn.ModuleList()
        if self.total_blocks > 0 and self.blks_to_exit:
            self.exit_heads = nn.ModuleList([
                Classifier(
                    in_planes=block_dims[b],
                    num_classes=num_classes,
                    reduction=block.expansion,
                    scale=self.width_scale,
                )
                for b in self.blks_to_exit
            ])

        # ---- optional freeze base ---------------------------------------------
        if freeze_base:
            for name, p in self.named_parameters():
                if not name.startswith("exit_heads"):
                    p.requires_grad = False

        # ---- key lists for FL --------------------------------------------------
        # include buffers in "all" (BN running stats), only params in "trainable"
        self.all_state_dict_keys = list(self.state_dict().keys())
        self.trainable_state_dict_keys = [n for n, p in self.named_parameters() if p.requires_grad]
        # Index of the single head whose logits ``forward`` should return;
        # ``None`` means "return every head" (or the last, see last_exit_only).
        self.active_exit: Optional[int] = None

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor], None]:
        """
        - Computes all attached heads at their blocks.
        - If `self.active_exit` is not None, returns only that head's logits.
        - Else, returns a list of logits for all exits (in order).
        - `last_exit_only` remains a legacy fallback (only used if active_exit is None).

        Returns:
            A single tensor when ``active_exit`` or ``last_exit_only`` selects
            one head, otherwise the list of computed head outputs in head
            order. ``None`` is returned in ``last_exit_only`` mode when no
            head produced an output.

        Raises:
            IndexError: ``active_exit`` is outside the range of exit heads.
            RuntimeError: the head selected by ``active_exit`` is attached to
                a block index that the trunk never reached.
        """
        # Build/cache mapping: block_idx -> [head_indices] (supports duplicates)
        # The cache is rebuilt whenever ``blks_to_exit`` has changed.
        exit_sig = tuple(self.blks_to_exit)
        if (not hasattr(self, "_heads_at_block")) or (getattr(self, "_heads_sig", None) != exit_sig):
            heads_at = {}
            for i, b in enumerate(self.blks_to_exit):
                heads_at.setdefault(int(b), []).append(i)
            self._heads_at_block = heads_at
            self._heads_sig = exit_sig

        # outs[i] holds the logits of head i, or None if it was never reached.
        outs: List[Optional[torch.Tensor]] = [None] * len(self.blks_to_exit)

        x = self.relu(self.bn1(self.conv1(x)))
        if not getattr(self, "cifar", True):
            x = self.maxpool(x)

        blk_idx = 0
        for stage in self.layers:
            for blk in stage:
                x = blk(x)
                idxs = self._heads_at_block.get(blk_idx)
                if idxs:
                    for i in idxs:
                        outs[i] = self.exit_heads[i](x)
                blk_idx += 1

        # Prefer explicit active_exit if set
        if self.active_exit is not None:
            i = int(self.active_exit)
            if not -len(outs) <= i < len(outs):
                raise IndexError(
                    f"active_exit={i} is out of range for {len(outs)} exit head(s)"
                )
            # Index the head-aligned list: indexing the compacted list of
            # computed outputs would silently select a different head
            # whenever an earlier head was not computed.
            selected = outs[i]
            if selected is None:
                raise RuntimeError(
                    f"exit head {i} is attached to block {self.blks_to_exit[i]}, "
                    f"which was not reached (only {blk_idx} block(s) were run)"
                )
            return selected

        valid = [o for o in outs if o is not None]

        # Fallback legacy switch
        last_only = getattr(self, "last_exit_only", getattr(self, "_return_last_only", False))
        if last_only:
            return valid[-1] if valid else None
        return valid
    
def exit_resnet_template(
    *,
    model_name: str = "resnet110",
    depth: Optional[int] = None,
    num_classes: int = 10,
    no_of_exits: Optional[int] = None,
    blks_to_exit: Optional[Sequence[int]] = None,
    ee_layer_locations: Optional[Sequence[int]] = None,
    width_scale: float = 1.0,
    last_exit_only: bool = False,
    freeze_base: bool = False,
    cifar: Optional[bool] = None,
    device: Optional[Union[str, torch.device]] = None,
    **kwargs,
) -> ResNetEarlyExit:
    """
    Construct a ``ResNetEarlyExit`` built from ``BasicBlock`` for a named layout.

    Supported ``model_name`` values are ``"resnet18"`` (stages of 2/2/2/2
    blocks) and ``"resnet110"`` (stages of 18/18/18 blocks); anything else
    raises ``NotImplementedError``.

    ``ee_layer_locations`` is used as ``blks_to_exit`` only when the latter is
    not given. When ``cifar`` is ``None`` it is inferred from the model name
    (true if the name contains ``"110"`` or ``"32"``). Extra keyword arguments
    are accepted and ignored. The model is moved to ``device`` when one is
    provided.
    """
    # Alias: explicit ``blks_to_exit`` takes precedence over the legacy name.
    exit_blocks = blks_to_exit
    if exit_blocks is None and ee_layer_locations is not None:
        exit_blocks = ee_layer_locations

    use_cifar = cifar
    if use_cifar is None:
        use_cifar = "110" in model_name or "32" in model_name

    # (name, blocks-per-stage) pairs for the layouts this template knows.
    known_layouts = (
        ("resnet18", (2, 2, 2, 2)),
        ("resnet110", (18, 18, 18)),
    )
    stage_layout = None
    for known_name, blocks_per_stage in known_layouts:
        if model_name == known_name:
            stage_layout = blocks_per_stage
            break
    if stage_layout is None:
        raise NotImplementedError(f"Model name '{model_name}' is not supported by this template.")

    model = ResNetEarlyExit(
        block=BasicBlock,
        layers=stage_layout,
        num_classes=num_classes,
        depth=depth,
        no_of_exits=no_of_exits,
        blks_to_exit=exit_blocks,
        width_scale=width_scale,
        last_exit_only=last_exit_only,
        freeze_base=freeze_base,
        cifar=use_cifar,
    )
    return model if device is None else model.to(device)
