import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import (
    VGG11_Weights,
    VGG13_Weights,
    VGG16_Weights,
    VGG19_Weights,
    VGG11_BN_Weights,
    VGG13_BN_Weights,
    VGG16_BN_Weights,
    VGG19_BN_Weights,
)

from .base import BaseModel, MLP_Classifier, init_weights

# Short model name -> torchvision constructor for each supported VGG depth.
# Names ending in "bn" map to torchvision's batch-normalised "*_bn" variants.
vgg_dict = dict(
    vgg11=models.vgg11,
    vgg13=models.vgg13,
    vgg16=models.vgg16,
    vgg19=models.vgg19,
    vgg11bn=models.vgg11_bn,
    vgg13bn=models.vgg13_bn,
    vgg16bn=models.vgg16_bn,
    vgg19bn=models.vgg19_bn,
)


class VGGBase(nn.Module):
    """Pretrained torchvision VGG convolutional trunk used as a feature extractor.

    Only the torchvision model's ``features`` stack is kept. Its first
    convolution is adapted to the requested number of input channels, and an
    ``nn.Flatten`` layer is appended so ``forward`` returns a 2-D
    ``(batch, output_dim)`` tensor.

    Args:
        vgg_name: One of ``"vgg11"``, ``"vgg13"``, ``"vgg16"``, ``"vgg19"``
            or their ``"...bn"`` batch-norm counterparts.
        input_shape: Per-sample input shape. ``input_shape[0]`` is the channel
            count and ``input_shape[1]`` the height. Presumably ``(C, H, W)`` —
            TODO confirm with callers; the whole tuple is used for the dummy
            pass that sizes ``output_dim``.

    Raises:
        ValueError: If ``vgg_name`` is not a supported variant.

    Attributes:
        name: The ``vgg_name`` this trunk was built from.
        features: ``nn.Sequential`` holding the (possibly truncated) conv
            trunk followed by a module named ``"Flatten"``.
        output_dim: Length of the flattened feature vector for ``input_shape``.
    """

    # Default pretrained weights for every supported variant.
    _WEIGHTS = {
        "vgg11": VGG11_Weights.DEFAULT,
        "vgg13": VGG13_Weights.DEFAULT,
        "vgg16": VGG16_Weights.DEFAULT,
        "vgg19": VGG19_Weights.DEFAULT,
        "vgg11bn": VGG11_BN_Weights.DEFAULT,
        "vgg13bn": VGG13_BN_Weights.DEFAULT,
        "vgg16bn": VGG16_BN_Weights.DEFAULT,
        "vgg19bn": VGG19_BN_Weights.DEFAULT,
    }

    def __init__(self, vgg_name, input_shape):
        super().__init__()
        self.name = vgg_name
        if vgg_name not in self._WEIGHTS:
            raise ValueError(f"Unsupported VGG model: {vgg_name}")
        model_vgg = vgg_dict[vgg_name](weights=self._WEIGHTS[vgg_name])

        in_channels = input_shape[0]
        original_conv1 = model_vgg.features[0]
        if original_conv1.in_channels != in_channels:
            # The pretrained filters only fit the original channel count; for any
            # other count swap in a freshly initialised conv with the same
            # geometry. When the counts match, the pretrained conv is kept.
            model_vgg.features[0] = nn.Conv2d(
                in_channels=in_channels,
                out_channels=original_conv1.out_channels,
                kernel_size=original_conv1.kernel_size,
                stride=original_conv1.stride,
                padding=original_conv1.padding,
                bias=original_conv1.bias is not None,
            )

        layers = list(model_vgg.features.children())
        if input_shape[1] < 32:
            # For inputs shorter than 32 px, cut the trunk right after the
            # second-to-last max-pool (dropping the final conv block and its
            # pool), presumably so small images are not pooled down to zero size.
            pool_indices = [
                idx for idx, layer in enumerate(layers) if isinstance(layer, nn.MaxPool2d)
            ]
            layers = layers[: pool_indices[-2] + 1]
        features = nn.Sequential(*layers)
        features.add_module("Flatten", nn.Flatten())

        self.features = features
        # The flattened length depends on the input's spatial size (512 only
        # when the trunk reduces it to 1x1), so measure it instead of assuming.
        self.output_dim = self._infer_output_dim(features, input_shape)

    @staticmethod
    def _infer_output_dim(features, input_shape):
        """Return the flattened feature length ``features`` yields for one sample."""
        was_training = features.training
        # Eval mode keeps BatchNorm from updating running stats with the dummy
        # batch (and from rejecting a batch of one with 1x1 spatial size).
        features.eval()
        try:
            with torch.no_grad():
                out = features(torch.zeros(1, *input_shape))
        finally:
            features.train(was_training)
        return out.shape[1]

    def forward(self, x):
        """Return flattened trunk features of ``x``, shape ``(batch, output_dim)``."""
        return self.features(x)

def get_vgg_model(vgg_name, input_shape, num_classes, hidden_dim=1024):
    """Assemble a VGG feature extractor with an MLP classification head.

    Args:
        vgg_name: VGG variant understood by ``VGGBase``; the bare alias
            ``"vgg"`` is treated as ``"vgg11"``.
        input_shape: Per-sample input shape forwarded to ``VGGBase``.
        num_classes: Number of outputs of the MLP head.
        hidden_dim: Hidden width of the MLP head.

    Returns:
        A ``BaseModel`` wrapping the encoder and the ``init_weights``-initialised
        head, named ``"<vgg_name>_mlp<hidden_dim>"``.
    """
    if vgg_name == "vgg":
        vgg_name = "vgg11"  # plain "vgg" is shorthand for the smallest variant

    encoder = VGGBase(vgg_name, input_shape)
    head = MLP_Classifier(encoder.output_dim, num_classes, hidden_dim)
    head = head.apply(init_weights)
    return BaseModel(encoder, head, f"{vgg_name}_mlp{hidden_dim}")
