import torch
from torch import Tensor
import torch.nn as nn
from typing import Optional, Callable


class ConvNet(nn.Module):
    """Small three-stage convolutional classifier.

    Each stage is ``Conv2d(3x3, stride 1, padding 1, no bias) -> norm -> ReLU
    -> MaxPool2d(2x2, stride 2)`` with 64, 128 and 256 output channels
    respectively. The stages are followed by global average pooling and a
    single linear classifier.

    Args:
        num_classes: Default size of the classifier output.
        input_channels: Number of channels of the input images.
        norm_layer: Callable that builds a normalization module from a channel
            count. Defaults to ``nn.BatchNorm2d``.
        **kwargs: Only ``num_classifier`` is read; when present it overrides
            ``num_classes`` as the classifier output size. Any other keyword
            is ignored.
    """

    def __init__(
        self,
        num_classes: int = 10,
        input_channels: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs
    ) -> None:
        super().__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        # Stage 1: Conv + norm + ReLU + MaxPool
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm_layer(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: Conv + norm + ReLU + MaxPool
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = norm_layer(128)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 3: Conv + norm + ReLU + MaxPool
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = norm_layer(256)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Global average pooling collapses the spatial dimensions to 1x1.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.num_features = 256

        # Classifier width: taken from kwargs['num_classifier'] when given,
        # otherwise it falls back to num_classes.
        num_classifier = kwargs.get('num_classifier', num_classes)
        self.fc = nn.Linear(self.num_features, num_classifier)

        self._init_weights()

    def _init_weights(self) -> None:
        """Kaiming-init every conv and reset BatchNorm2d affine terms to (1, 0)."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has weight/bias set to None;
                # skip them instead of crashing inside nn.init.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor, only_fc: bool = False, only_feat: bool = False, **kwargs):
        """Run the network, or only a part of it.

        Args:
            x: Input tensor. An image batch by default; when ``only_fc`` is
                set it must already be the pooled feature tensor.
            only_fc: Apply only the classifier to ``x`` and return the logits.
            only_feat: Return only the pooled, flattened features.
            **kwargs: Accepted and ignored.

        Returns:
            The logits tensor if ``only_fc``; the feature tensor if
            ``only_feat``; otherwise a dict with keys ``'logits'`` and
            ``'feat'``. ``only_fc`` takes precedence over ``only_feat``.
        """
        if only_fc:
            return self.fc(x)

        x = self.extract(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        if only_feat:
            return x

        out = self.fc(x)
        return {'logits': out, 'feat': x}

    def extract(self, x: Tensor) -> Tensor:
        """Apply the three conv stages and return the un-pooled feature map."""
        stages = (
            (self.conv1, self.bn1, self.relu1, self.maxpool1),
            (self.conv2, self.bn2, self.relu2, self.maxpool2),
            (self.conv3, self.bn3, self.relu3, self.maxpool3),
        )
        for conv, norm, act, pool in stages:
            x = pool(act(norm(conv(x))))
        return x

    def group_matcher(self, coarse=False, prefix=''):
        """Return regex patterns grouping parameter names into stem and blocks.

        Kept for compatibility with existing training code. ``coarse`` is
        accepted but does not affect the result. ``prefix`` is inserted after
        each ``^`` anchor.
        """
        return dict(
            stem=r'^{}conv1|^{}bn1'.format(prefix, prefix),
            blocks=r'^{}conv[23]|^{}bn[23]'.format(prefix, prefix),
        )

    def no_weight_decay(self):
        """Return names of parameters that should not have weight decay.

        A parameter is selected when its name contains ``'bn'`` or ``'bias'``.
        """
        return [n for n, _ in self.named_parameters() if 'bn' in n or 'bias' in n]


def conv3net(pretrained=False, pretrained_path=None, **kwargs):
    """Build a ConvNet whose classifier has exactly ``num_classes`` outputs.

    ``pretrained`` and ``pretrained_path`` are accepted but not used by this
    function. ``num_classes`` defaults to 10 when absent from ``kwargs``; any
    caller-supplied ``num_classifier`` is overwritten with that value. All
    keywords are forwarded to ``ConvNet``.
    """
    head_width = kwargs.get('num_classes', 10)
    return ConvNet(**{**kwargs, 'num_classifier': head_width})

def conv3net_sco(pretrained=False, pretrained_path=None, **kwargs):
    """Build a ConvNet with ``num_classes + 1`` outputs for SCOMatch OOD detection.

    ``pretrained`` and ``pretrained_path` are accepted but not used by this
    function. ``num_classes`` defaults to 10 when absent from ``kwargs``, which
    gives 11 outputs; any caller-supplied ``num_classifier`` is overwritten.
    All keywords are forwarded to ``ConvNet``.
    """
    # One extra output on top of the requested number of classes.
    kwargs['num_classifier'] = kwargs.get('num_classes', 10) + 1
    return ConvNet(**kwargs)
