import torch
import torch.nn as nn

from torchvision.models.resnet import Bottleneck as TVBottleneck, BasicBlock as TVBasicBlock
from modules import BasicBlockCompat

def split_into_k_groups(n, k):
    '''
    Split a count `n` into `k` near-equal integer sizes that sum to `n`.

    The first `n % k` sizes are one larger than the rest, so the sizes are
    non-increasing; when k > n the trailing sizes are 0. Both `n` and `k`
    must be >= 1 (checked with assert).

    NOTE: this name is defined again further down in this module; that later
    definition is the one bound once the module finishes importing.
    '''
    assert k >= 1 and n >= 1
    quotient, remainder = divmod(n, k)
    return [quotient + (1 if idx < remainder else 0) for idx in range(k)]

def bottleneck_stride(b) -> int:
    '''
    Return the (first-axis) stride of block `b`'s `conv2`.

    If `conv2` has no `stride` attribute, a stride of 1 is assumed. A tuple
    stride yields its first element; any other value is converted with int().
    '''
    stride = getattr(b.conv2, 'stride', (1, 1))
    if isinstance(stride, tuple):
        return stride[0]
    return int(stride)

def plan_groups_for_stage(stage_seq, k_groups=2):
    '''
    Partition the Bottleneck children of an RN50 stage into `k_groups`
    contiguous chunks and return them as inclusive (start_idx, end_idx) pairs.

    Indices are 0-based positions in the list of Bottleneck children. Chunk
    sizes come from `split_into_k_groups`.

    NOTE: this name is defined again further down in this module; that later
    definition is the one bound once the module finishes importing.
    '''
    bottlenecks = [child for child in stage_seq.children() if isinstance(child, TVBottleneck)]
    count = len(bottlenecks)
    assert count > 0, 'Stage must contain Bottleneck blocks.'
    ranges = []
    cursor = 0
    for size in split_into_k_groups(count, k_groups):
        stop = cursor + size
        ranges.append((cursor, stop - 1))
        cursor = stop
    return ranges

def stage_names():
    '''Return a fresh list of the four ResNet stage attribute names, in order.'''
    return [f'layer{idx}' for idx in range(1, 5)]

def make_basic_from_bottleneck_group(first_b, last_b):
    '''
    Build a BasicBlockCompat that stands in for the Bottleneck run first_b..last_b.

    Input channels come from `first_b.conv1`, the stride from `first_b.conv2`
    (via bottleneck_stride), output channels from `last_b.bn3`, and the norm
    layer class is the type of `first_b.bn1`.
    '''
    in_channels = first_b.conv1.in_channels
    group_stride = bottleneck_stride(first_b)
    out_channels = last_b.bn3.num_features
    return BasicBlockCompat(
        in_channels,
        out_channels,
        stride=group_stride,
        norm_layer=type(first_b.bn1),
    )

def replace_leftmost_k_bottlenecks_with_basic(model, layer_name, k):
    '''
    Collapse the leftmost `k` Bottleneck blocks of `model.<layer_name>` into a
    single BasicBlockCompat.

    The run starts at the first Bottleneck child of the stage and must consist
    of `k` consecutive Bottleneck children. The replacement is built from the
    first and last Bottleneck of the run (see make_basic_from_bottleneck_group).
    It is moved to the device/dtype of the first replaced block's parameters and
    given that block's train/eval mode. `model.<layer_name>` is then rebound to
    a new nn.Sequential holding the updated child list, with the old
    Sequential's own training flag.

    Args:
        model: module exposing the stage as attribute `layer_name`.
        layer_name: attribute name of the stage (an nn.Sequential).
        k: number of consecutive Bottlenecks to collapse; must be >= 1.

    Returns:
        The newly inserted BasicBlockCompat (useful for hooking/optim).

    Raises:
        ValueError: if k < 1.
        AssertionError: if no Bottleneck is left in the stage, if the span would
            run past the end of the stage, or if it contains a non-Bottleneck.
    '''
    if k < 1:
        # k <= 0 would otherwise insert a block without removing anything, built
        # from a wrong (possibly wrapped-around, mods[-1]) "last" module.
        raise ValueError(f'k must be >= 1, got k={k} for {layer_name}.')

    seq: nn.Sequential = getattr(model, layer_name)
    mods = list(seq.children())

    # Start of the leftmost Bottleneck run.
    start = next((i for i, m in enumerate(mods) if isinstance(m, TVBottleneck)), None)
    assert start is not None, f'No Bottleneck blocks left to replace in {layer_name}.'

    end = start + k - 1
    assert end < len(mods), f'Requested k={k} exceeds remaining modules in {layer_name}.'
    assert all(isinstance(m, TVBottleneck) for m in mods[start:end + 1]), \
        f'Non-Bottleneck encountered within replacement span in {layer_name}.'

    first_b, last_b = mods[start], mods[end]
    new_block = make_basic_from_bottleneck_group(first_b, last_b)

    # A freshly constructed module lives on the default device/dtype and in
    # training mode; align it with the block it replaces so the rebuilt stage
    # does not mix devices/dtypes or silently switch BN behaviour.
    ref_param = next(first_b.parameters(), None)
    if ref_param is not None:
        new_block.to(device=ref_param.device, dtype=ref_param.dtype)
    new_block.train(first_b.training)

    mods[start:end + 1] = [new_block]
    new_seq = nn.Sequential(*mods)
    # Carry over only the container's own flag; children keep their own modes.
    new_seq.training = seq.training
    setattr(model, layer_name, new_seq)
    return new_block

def build_progressive_schedule_and_orig_hooks(orig):
    '''
    Build a global progressive schedule across layer1..layer4 of an RN50 model.

    Each stage is split into 2 contiguous Bottleneck groups, one per RN18
    BasicBlock in the corresponding stage.

    Args:
        orig: model exposing layer1..layer4, each a pure sequence of Bottlenecks.

    Returns:
      - schedule: list of (layer_name, group_size), in stage order
      - orig_hook_modules: list of original *end-of-group* Bottleneck modules
        (for hooking), aligned 1:1 with `schedule`

    Raises:
        AssertionError: if a stage contains a non-Bottleneck child.
    '''
    schedule = []
    orig_hook_modules = []
    for lname in stage_names():
        stage_seq: nn.Sequential = getattr(orig, lname)
        # Materialize the children once; the group indices below index into it.
        bn_list = list(stage_seq.children())
        # Sanity: an RN50 stage should be all Bottlenecks. This also makes the
        # Bottleneck-only group indices coincide with child positions.
        assert all(isinstance(m, TVBottleneck) for m in bn_list), \
            f'{lname} is not a pure Bottleneck sequence.'

        for s, e in plan_groups_for_stage(stage_seq, k_groups=2):  # RN18 has 2 per stage
            schedule.append((lname, e - s + 1))
            orig_hook_modules.append(bn_list[e])  # hook at the group *output*
    return schedule, orig_hook_modules

def split_into_k_groups(n, k):
    '''
    Split a count `n` into `k` near-equal integer sizes that sum to `n`.

    The first `n % k` sizes are one larger than the rest, so the result is
    non-increasing, e.g. split_into_k_groups(7, 3) == [3, 2, 2]. When k > n the
    trailing sizes are 0.

    Raises:
        ValueError: if n < 1 or k < 1.
    '''
    # Validate explicitly (not via assert, which -O strips). Without this,
    # k == 0 fails with ZeroDivisionError and k < 0 silently returns [].
    if k < 1 or n < 1:
        raise ValueError(f'split_into_k_groups needs n >= 1 and k >= 1, got n={n}, k={k}')
    base, r = divmod(n, k)
    return [base + 1 if i < r else base for i in range(k)]

def plan_groups_for_stage(stage_seq, k_groups=2):
    '''
    For an RN50 stage (nn.Sequential of Bottlenecks), return k contiguous groups
    as inclusive (start_idx, end_idx) pairs. Used to pair with the 2 RN18 blocks
    in the stage.

    Indices count only the Bottleneck children of `stage_seq`. They equal child
    positions only when every child of the stage is a Bottleneck.

    Args:
        stage_seq: the stage module; its Bottleneck children are grouped.
        k_groups: number of groups (default 2, one per RN18 block per stage).

    Returns:
        List of `k_groups` (start_idx, end_idx) tuples covering all Bottlenecks,
        with sizes from split_into_k_groups.

    Raises:
        AssertionError: if the stage has no Bottleneck children.
        ValueError: if k_groups < 1 or k_groups exceeds the number of
            Bottlenecks. An oversized k would otherwise produce empty, inverted
            (p, p - 1) ranges whose end index repeats the previous group's end.
    '''
    blocks = [m for m in stage_seq.children() if isinstance(m, TVBottleneck)]
    n = len(blocks)
    assert n > 0, 'Stage must contain Bottleneck blocks.'
    if not 1 <= k_groups <= n:
        raise ValueError(
            f'k_groups must be in [1, {n}] for a stage with {n} Bottlenecks, got {k_groups}.'
        )
    out = []
    p = 0
    for gsz in split_into_k_groups(n, k_groups):
        out.append((p, p + gsz - 1))
        p += gsz
    return out

def enumerate_rn18_blocks(student):
    '''
    Collect the BasicBlocks of an RN18-style student in stage order.

    Walks layer1..layer4 and returns their children as one flat list.

    Raises:
        AssertionError: if any stage child is not a torchvision BasicBlock.
    '''
    collected = []
    for stage in stage_names():
        children = list(getattr(student, stage).children())
        assert all(isinstance(child, TVBasicBlock) for child in children), \
            f'{stage} contains non-BasicBlock in RN18'
        collected.extend(children)
    return collected

def build_alignment_rn50_rn18(teacher, student):
    '''
    Build a 1:1 alignment between:
      - teacher groups (end-of-group Bottlenecks) and
      - student blocks (TVBasicBlock) in stage order.

    Each teacher stage is split into 2 contiguous Bottleneck groups. They are
    paired, in order, with the 2 student blocks of the same stage.

    Returns: schedule, teacher_hook_modules, student_hook_modules
      - schedule: group ids 0..len-1 (0..7 for the four stages)
      - teacher_hook_modules: last Bottleneck of each teacher group
      - student_hook_modules: the student block paired with each group

    Raises:
        AssertionError: if a teacher stage is not a pure Bottleneck sequence, or
            a student stage does not have exactly 2 blocks.
    '''
    teacher_hook_modules = []
    student_hook_modules = []
    for lname in stage_names():
        t_seq = getattr(teacher, lname)
        s_seq = getattr(student, lname)
        bn_list = list(t_seq.children())
        # Group indices count Bottlenecks only. They index bn_list (all
        # children) correctly only if every child is a Bottleneck.
        assert all(isinstance(m, TVBottleneck) for m in bn_list), \
            f'{lname} is not a pure Bottleneck sequence.'
        groups = plan_groups_for_stage(t_seq, k_groups=2)
        s_blocks = list(s_seq.children())
        assert len(s_blocks) == 2, f'{lname} in RN18 must have 2 blocks'
        for (_, end), s_block in zip(groups, s_blocks):
            teacher_hook_modules.append(bn_list[end])
            student_hook_modules.append(s_block)
    schedule = list(range(len(teacher_hook_modules)))
    return schedule, teacher_hook_modules, student_hook_modules