import torch.nn as nn
from torchvision import models
from torchvision.models import (
    ResNet18_Weights,
    ResNet34_Weights,
    ResNet50_Weights,
    ResNet101_Weights,
    ResNet152_Weights,
    ResNeXt50_32X4D_Weights,
    ResNeXt101_32X8D_Weights,
)
from .base import BaseModel, MLP_Classifier, init_weights

# Registry of supported backbone names -> torchvision model constructors.
# The short "resnext50" / "resnext101" keys select the 32x4d / 32x8d variants.
res_dict = dict(
    resnet18=models.resnet18,
    resnet34=models.resnet34,
    resnet50=models.resnet50,
    resnet101=models.resnet101,
    resnet152=models.resnet152,
    resnext50=models.resnext50_32x4d,
    resnext101=models.resnext101_32x8d,
)


class ResBase(nn.Module):
    """Pretrained ResNet/ResNeXt trunk used as a feature encoder.

    The torchvision model selected by ``res_name`` is instantiated with its
    ``DEFAULT`` pretrained weights. Its stem convolution is rebuilt to accept
    ``input_channel`` channels, and its ``fc`` head is discarded. The module
    ends at the pooling layer followed by ``nn.Flatten``.

    Args:
        res_name: One of ``"resnet18"``, ``"resnet34"``, ``"resnet50"``,
            ``"resnet101"``, ``"resnet152"``, ``"resnext50"`` or
            ``"resnext101"``.
        input_channel: Number of channels of the input tensor.

    Attributes:
        output_dim: Width of the flattened feature vector (the pretrained
            model's ``fc.in_features``).
        backbone: ``nn.Sequential`` running the stem, residual stages,
            pooling and flatten, in that order.

    Raises:
        ValueError: If ``res_name`` is not a supported name.
    """

    def __init__(self, res_name, input_channel):
        super().__init__()

        # Reject unknown names before any weights enum or constructor is used.
        if res_name not in (
            "resnet18",
            "resnet34",
            "resnet50",
            "resnet101",
            "resnet152",
            "resnext50",
            "resnext101",
        ):
            raise ValueError(f"Unsupported ResNet model: {res_name}")

        weight_enums = {
            "resnet18": ResNet18_Weights,
            "resnet34": ResNet34_Weights,
            "resnet50": ResNet50_Weights,
            "resnet101": ResNet101_Weights,
            "resnet152": ResNet152_Weights,
            "resnext50": ResNeXt50_32X4D_Weights,
            "resnext101": ResNeXt101_32X8D_Weights,
        }
        pretrained = res_dict[res_name](weights=weight_enums[res_name].DEFAULT)

        # Rebuild the stem so it accepts `input_channel` channels while keeping
        # the pretrained layer's kernel size, stride and padding.
        old_stem = pretrained.conv1
        new_stem = nn.Conv2d(
            input_channel,
            old_stem.out_channels,
            kernel_size=old_stem.kernel_size,
            stride=old_stem.stride,
            padding=old_stem.padding,
            bias=False,
        )
        # Copy pretrained filters into the first min(input_channel, 3) input
        # channels. The cap of 3 assumes an RGB pretrained stem. Any channels
        # beyond that keep nn.Conv2d's default initialization.
        copied = min(input_channel, 3)
        new_stem.weight.data[:, :copied, :, :] = (
            old_stem.weight.data[:, :input_channel, :, :]
        )

        self.output_dim = pretrained.fc.in_features

        # Reuse the pretrained trunk as-is; `fc` is intentionally left out.
        # Element order fixes the Sequential indices, and thus the state_dict keys.
        trunk = [
            getattr(pretrained, part)
            for part in (
                "bn1",
                "relu",
                "maxpool",
                "layer1",
                "layer2",
                "layer3",
                "layer4",
                "avgpool",
            )
        ]
        self.backbone = nn.Sequential(new_stem, *trunk, nn.Flatten())

    def forward(self, x):
        """Run ``x`` through the backbone and return the flattened features.

        NOTE(review): presumably ``x`` is an NCHW image batch with
        ``input_channel`` channels, giving an output of shape
        ``(batch, output_dim)``. Confirm against callers.
        """
        features = self.backbone(x)
        return features

def get_res_model(res_name, input_channel, num_classes, hidden_dim=1024):
    """Build a pretrained ResNet/ResNeXt encoder topped with an MLP classifier.

    Args:
        res_name: Backbone name accepted by ``ResBase``. The bare alias
            ``"resnet"`` is treated as ``"resnet18"``.
        input_channel: Number of input channels for the encoder stem.
        num_classes: Number of output classes for the classifier.
        hidden_dim: Hidden width passed to ``MLP_Classifier``.

    Returns:
        A ``BaseModel`` wrapping the encoder and the classifier, which has had
        ``init_weights`` applied. The model is named
        ``"<backbone>_mlp<hidden_dim>"`` using the resolved backbone name.
    """
    if res_name == "resnet":
        res_name = "resnet18"

    encoder = ResBase(res_name, input_channel=input_channel)
    classifier = MLP_Classifier(encoder.output_dim, num_classes, hidden_dim)
    model_name = f"{res_name}_mlp{hidden_dim}"
    return BaseModel(encoder, classifier.apply(init_weights), model_name)
