import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import re

# Public API of this module: only the checkpoint-loading factory is exported
# by ``from <module> import *``.
__all__ = ['get_raresnet_model']


class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate.

    The input is pooled to a single value per channel, passed through two
    1x1 convolutions (``fc1`` then ``fc2``) with a ReLU in between, squashed
    by a sigmoid and used to rescale the input channel-wise.

    Args:
        in_planes: Accepted for interface compatibility; not used.
        planes: Channel count of the tensor passed to ``forward``.
        reduction: Divisor giving the width of the hidden 1x1 convolution
            (``planes // reduction``).
    """

    def __init__(self, in_planes, planes, reduction=16):
        super().__init__()
        hidden = planes // reduction
        # Attribute names fc1/fc2 are part of the state-dict key layout.
        self.fc1 = nn.Conv2d(planes, hidden, kernel_size=1)
        self.fc2 = nn.Conv2d(hidden, planes, kernel_size=1)

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel sigmoid gate."""
        squeezed = F.adaptive_avg_pool2d(x, 1)
        excited = self.fc2(F.relu(self.fc1(squeezed)))
        gate = torch.sigmoid(excited)
        return x * gate

class BottleneckUnit(nn.Module):
    """Residual branch: 1x1 conv -> 3x3 conv -> 1x1 conv -> SE gate.

    Sub-modules are named ``a``, ``b``, ``c`` and ``se``; these names and the
    positions inside each ``nn.Sequential`` determine the state-dict keys.

    Args:
        in_planes: Channels of the input tensor.
        planes: Channels used by the ``a`` and ``b`` convolutions.
        stride: Stride of the 3x3 convolution in ``b``.
        expansion: Multiplier giving the output width ``planes * expansion``.
        use_bn_on_a: When true, a BatchNorm2d sits between the conv and the
            ReLU of ``a``; otherwise ``a`` is just conv + ReLU.
    """

    def __init__(self, in_planes, planes, stride=1, expansion=4, use_bn_on_a=True):
        super().__init__()
        wide = planes * expansion

        # a: 1x1 channel reduction, BatchNorm optional.
        reduce_conv = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        if use_bn_on_a:
            stage_a = (reduce_conv, nn.BatchNorm2d(planes), nn.ReLU(inplace=True))
        else:
            stage_a = (reduce_conv, nn.ReLU(inplace=True))
        self.a = nn.Sequential(*stage_a)

        # b: 3x3 convolution carrying the stride.
        spatial_conv = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.b = nn.Sequential(spatial_conv, nn.BatchNorm2d(planes), nn.ReLU(inplace=True))

        # c: 1x1 channel expansion, no activation.
        expand_conv = nn.Conv2d(planes, wide, kernel_size=1, bias=False)
        self.c = nn.Sequential(expand_conv, nn.BatchNorm2d(wide))

        # se: channel gate applied to the expanded features.
        self.se = SEBlock(wide, wide)

        # Kept as an attribute; forward() does not call it.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run ``x`` through ``a``, ``b``, ``c`` and ``se`` in that order."""
        return self.se(self.c(self.b(self.a(x))))

class CustomBlock(nn.Module):
    """Residual block: ``relu(F(x) + shortcut(x))``.

    Args:
        in_planes: Channels of the input tensor.
        planes: Inner width passed to the ``BottleneckUnit``.
        stride: Stride passed to the ``BottleneckUnit``.
        expansion: Output-width multiplier passed to the ``BottleneckUnit``.
        downsample: Optional module applied to the input to form the
            shortcut; when ``None`` the input itself is the shortcut.
        use_bn_on_a: Passed through to the ``BottleneckUnit``.
    """

    def __init__(self, in_planes, planes, stride=1, expansion=4, downsample=None, use_bn_on_a=True):
        super().__init__()
        # Attribute names proj/F are part of the state-dict key layout.
        self.proj = downsample
        self.F = BottleneckUnit(
            in_planes, planes, stride, expansion, use_bn_on_a=use_bn_on_a
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Add the shortcut to the residual branch and apply ReLU."""
        residual = self.F(x)
        shortcut = x if self.proj is None else self.proj(x)
        # In-place accumulation into the branch output, as in the original.
        residual += shortcut
        return self.relu(residual)

class Normalization(nn.Module):
    """Per-channel input standardisation: ``(x - mean) / std``.

    Args:
        mean: Three per-channel means.
        std: Three per-channel standard deviations.

    Both are stored as buffers reshaped to ``(1, 3, 1, 1)`` so they broadcast
    against a 4-D input whose second dimension has size 3.
    """

    def __init__(self, mean, std):
        super().__init__()
        for buffer_name, values in (('mean', mean), ('std', std)):
            self.register_buffer(buffer_name, torch.tensor(values).view(1, 3, 1, 1))

    def forward(self, x):
        """Return ``x`` shifted by ``mean`` and scaled by ``1 / std``."""
        centered = x - self.mean
        return centered / self.std

class CustomRaResNet(nn.Module):
    """ResNet-style classifier built from ``CustomBlock`` stages.

    Pipeline: fixed per-channel input normalization -> 7x7 conv / BN / ReLU /
    max-pool stem -> one stage per entry of ``layers`` -> global average
    pooling -> linear classifier.

    Sub-module names (``normalization``, ``stem``, ``stages``, ``avgpool``,
    ``fc``) and the ``stage{N}`` / ``stage{N}-block{M}`` naming inside
    ``stages`` define the state-dict keys and must stay stable for checkpoint
    loading.

    Args:
        layers: Sequence giving the number of blocks in each stage. Stage
            ``i`` (0-based) uses an inner width of ``64 * 2**i``; the first
            stage keeps stride 1, every later stage uses stride 2 in its
            first block. Any number of stages is supported.
        use_bn_on_a: Forwarded to every block (see ``BottleneckUnit``).
        num_classes: Output features of the final linear layer.
    """

    def __init__(self, layers, use_bn_on_a=True, num_classes=1000):
        super(CustomRaResNet, self).__init__()

        self.normalization = Normalization([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        # The stem is deliberately a Sequential nested in a Sequential so its
        # parameters are keyed ``stem.0.<idx>.*``.
        self.stem = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            )
        )

        self.stages = nn.ModuleDict()
        # Running channel count; updated by _make_stage as stages are added.
        self.in_planes = 64

        for i, num_blocks in enumerate(layers):
            stage_name = f'stage{i+1}'
            stride = 1 if i == 0 else 2
            planes = 64 * (2**i)
            self.stages[stage_name] = self._make_stage(stage_name, planes, num_blocks, stride, use_bn_on_a)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Size the classifier from the actual width of the last stage instead
        # of a hard-coded 512 * 4; for the usual four stages this is the same
        # value (2048), and it stays correct for other stage counts.
        self.fc = nn.Linear(self.in_planes, num_classes)

    def _make_stage(self, stage_name, planes, num_blocks, stride, use_bn_on_a):
        """Build one stage as an ``nn.Sequential`` of named ``CustomBlock``s.

        The first block carries ``stride`` and, when the stride is not 1 or
        the channel count changes, a 1x1-conv + BatchNorm projection on the
        shortcut. Remaining blocks use stride 1 and an identity shortcut.
        The first block is always created, so a stage holds at least one
        block even if ``num_blocks`` is 0 or 1.

        Side effect: sets ``self.in_planes`` to the stage's output width.
        """
        expansion = 4
        out_planes = planes * expansion
        blocks = OrderedDict()

        downsample = None
        if stride != 1 or self.in_planes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        blocks[f'{stage_name}-block0'] = CustomBlock(self.in_planes, planes, stride, expansion, downsample, use_bn_on_a)
        self.in_planes = out_planes

        for i in range(1, num_blocks):
            blocks[f'{stage_name}-block{i}'] = CustomBlock(self.in_planes, planes, stride=1, expansion=expansion, use_bn_on_a=use_bn_on_a)

        return nn.Sequential(blocks)

    def forward(self, x):
        """Return the classifier output for input ``x``.

        ``x`` must broadcast against the ``(1, 3, 1, 1)`` normalization
        buffers and feed a 3-input-channel convolution.
        """
        x = self.normalization(x)
        x = self.stem(x)
        # ModuleDict preserves insertion order, so this visits stage1,
        # stage2, ... in construction order for any number of stages.
        for stage in self.stages.values():
            x = stage(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def analyze_structure(state_dict):
    """Infer the network configuration from checkpoint parameter names.

    Args:
        state_dict: Mapping whose keys are parameter names. Only the keys
            are inspected; values are ignored.

    Returns:
        A tuple ``(layers, use_bn_on_a)``:

        * ``layers`` - list of four ints, the number of blocks detected in
          stage1..stage4 (highest ``block`` index seen plus one, 0 when the
          stage does not appear).
        * ``use_bn_on_a`` - ``True`` when some key contains
          ``block0.F.a.1.weight`` (a weighted module at position 1 of
          ``F.a``), ``False`` otherwise.

    Side effects:
        Prints the detected configuration to stdout.
    """
    num_stages = 4
    layers = [0] * num_stages
    # Matches e.g. "stage2.stage2-block5"; the back-reference requires the
    # two stage numbers to agree.
    block_pattern = re.compile(r'stage(\d+)\.stage\1-block(\d+)')

    for key in state_dict:
        found = block_pattern.search(key)
        if found is None:
            continue
        stage_idx = int(found.group(1)) - 1  # 0-based index
        # Reject out-of-range stages on both sides: without the lower bound
        # a "stage0" key yields -1 and silently updates the last stage.
        if not 0 <= stage_idx < num_stages:
            continue
        layers[stage_idx] = max(layers[stage_idx], int(found.group(2)) + 1)

    use_bn_on_a = any('block0.F.a.1.weight' in key for key in state_dict)

    if not use_bn_on_a:
        # Both cases disable BN on F.a; only the diagnostic message differs.
        if any('block0.F.a.0.weight' in key for key in state_dict):
            print("Detected structure: F.a has NO BatchNorm.")
        else:
            print("Detected structure: F.a is single Conv (No BN).")

    print(f"Detected Layers Config: {layers}")
    return layers, use_bn_on_a

def clean_state_dict(state_dict):
    """Return a copy of ``state_dict`` with its keys normalised.

    Two rewrites are applied to every key, in this order:

    1. A single leading ``module.`` prefix is dropped.
    2. Every occurrence of ``stem.stem.stem.`` becomes ``stem.0.``.

    Values are carried over untouched and the input mapping is not modified.
    The result is an ``OrderedDict`` in the input's iteration order.
    """
    wrapper_prefix = 'module.'
    renamed = OrderedDict()
    for key, value in state_dict.items():
        stripped = key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key
        renamed[stripped.replace('stem.stem.stem.', 'stem.0.')] = value
    return renamed

def get_raresnet_model(checkpoint_path, num_classes=1000, device='cpu'):
    """Build a ``CustomRaResNet`` matching a checkpoint and load its weights.

    Args:
        checkpoint_path: Location passed to ``torch.load``.
        num_classes: Output size of the model's classifier.
        device: Device the finished model is moved to.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    The stage depths and the ``use_bn_on_a`` flag are inferred from the
    checkpoint's key names; if no stage blocks are detected the depths fall
    back to ``[3, 4, 6, 3]``. Weights are loaded with ``strict=False`` and the
    resulting report is printed.
    """
    print(f"Loading checkpoint from {checkpoint_path}...")
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    loaded = torch.load(checkpoint_path, map_location='cpu')

    # Accept either a wrapper holding 'state_dict' or the raw mapping itself.
    if 'state_dict' in loaded:
        raw_weights = loaded['state_dict']
    else:
        raw_weights = loaded
    weights = clean_state_dict(raw_weights)

    stage_depths, bn_on_a = analyze_structure(weights)
    if sum(stage_depths) == 0:
        print("Warning: Could not detect layers, defaulting to ResNet-50 [3, 4, 6, 3]")
        stage_depths = [3, 4, 6, 3]

    net = CustomRaResNet(stage_depths, use_bn_on_a=bn_on_a, num_classes=num_classes)

    load_report = net.load_state_dict(weights, strict=False)
    print("Load status:", load_report)

    net.to(device)
    net.eval()
    return net
