import torch
import torch.nn as nn
import torchvision.models as models
import types


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain.

    Computes ``scale * x / (rms(x) + eps)``, where ``rms(x)`` is the L2 norm of
    ``x`` along the last axis divided by the square root of that axis' length.

    Args:
        dim: size of the last axis; determines the shape of the learnable gain.
        eps: added to the RMS (not under the square root) so an all-zero
            input does not cause a division by zero.
    """

    def __init__(self, dim, eps=1e-4):
        super().__init__()
        self.eps = eps
        # Per-feature multiplicative gain, initialised to 1.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise ``x`` along its last axis and apply the learnable gain."""
        width = x.shape[-1]
        l2 = x.norm(2, dim=-1, keepdim=True)
        denominator = l2 / (width ** 0.5) + self.eps
        normalised = x / denominator
        return self.scale * normalised


class VGG16_MNIST(nn.Module):
    """Trimmed VGG-style classifier sized for 28x28 inputs such as MNIST.

    ``features`` contains three conv blocks (64, 128 and 256 channels; two
    3x3 convolutions with ReLU each), each followed by 2x2 max pooling. For a
    28x28 input the spatial size goes 28 -> 14 -> 7 -> 4 (the last pool uses
    ``padding=1``), so the flattened feature size is ``256 * 4 * 4``.
    ``classifier`` is ``Linear(256*4*4 -> 1000) -> ReLU -> Linear(1000 -> num_classes)``.

    NOTE: the classifier's input width is fixed for 28x28 inputs; other
    spatial sizes will not match it.

    Args:
        num_classes: number of output logits.
        in_channels: number of input image channels. Defaults to 1
            (grayscale), the value that was previously hard-coded.
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()

        self.features = nn.Sequential(
            # Block 1
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 28 -> 14

            # Block 2
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 14 -> 7

            # Block 3 (trimmed for MNIST)
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # floor((7 + 2*1 - 2) / 2) + 1 = 4
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),  # 7 -> 4
        )

        self.classifier = nn.Sequential(
            nn.Linear(256 * 4 * 4, 1000),
            nn.ReLU(inplace=True),
            nn.Linear(1000, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)


# Function to create ResNet18 model
def create_resnet18(input_channels, num_classes, layer_norm="rms"):
    """Create a ResNet18 whose ``forward`` returns ``(outputs, features)``.

    The stem convolution is replaced by a 3x3, stride-1 convolution taking
    ``input_channels`` channels. After global average pooling, the flattened
    embedding goes through ``norm_layer`` and a bias-free head
    ``fc1 (-> 1000) -> ReLU -> fc2 (-> num_classes)``.

    The returned model's ``forward(x)`` yields ``(outputs, features)``, where
    ``features`` is the post-ReLU ``fc1`` activation and ``outputs`` is
    ``fc2(features)``. ``get_classifier_weights()`` returns ``fc2.weight.data``.

    Args:
        input_channels: number of channels in the input images.
        num_classes: number of output logits.
        layer_norm: "none" (identity), "standard" (``nn.LayerNorm``) or
            "rms" (``RMSNorm``), applied to the pooled embedding.

    Returns:
        The customised torchvision ResNet18 module.

    Raises:
        ValueError: if ``layer_norm`` is not one of the options above.
    """
    # No arguments -> randomly initialised weights. This avoids the
    # ``pretrained=`` keyword (deprecated in torchvision >= 0.13 in favour of
    # ``weights=``) while remaining valid on older torchvision releases.
    model = models.resnet18()
    model.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)

    # Width of the pooled embedding (the input size of the stock ``fc`` layer).
    # NOTE(review): ``model.fc`` itself is never called by the forward below,
    # but it is left registered so existing checkpoints keep loading; its
    # parameters therefore still appear in ``model.parameters()``.
    in_features = model.fc.in_features

    if layer_norm == "none":
        model.norm_layer = nn.Identity()
    elif layer_norm == "standard":
        model.norm_layer = nn.LayerNorm(in_features)
    elif layer_norm == "rms":
        model.norm_layer = RMSNorm(in_features)
    else:
        raise ValueError(f"Unknown layer_norm option: {layer_norm}")

    model.fc1 = nn.Linear(in_features, 1000, bias=False)
    model.relu2 = nn.ReLU(True)
    model.fc2 = nn.Linear(1000, num_classes, bias=False)

    def new_forward(self, x):
        """Return ``(outputs, features)`` for a batch of images ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.norm_layer(x)
        features = self.relu2(self.fc1(x))
        outputs = self.fc2(features)
        return outputs, features

    def get_classifier_weights(self):
        """Return ``fc2``'s weight tensor (shares storage, not tracked by autograd)."""
        return self.fc2.weight.data

    # Bind as instance attributes; nn.Module.__call__ dispatches to self.forward.
    model.forward = types.MethodType(new_forward, model)
    model.get_classifier_weights = types.MethodType(get_classifier_weights, model)

    return model


# Function to create ResNet50 model
def create_resnet50(input_channels, num_classes, layer_norm="rms"):
    """Create a ResNet50 whose ``forward`` returns ``(outputs, features)``.

    The stem convolution is replaced by a 3x3, stride-1 convolution taking
    ``input_channels`` channels. After global average pooling, the flattened
    embedding goes through ``norm_layer`` and a bias-free head
    ``fc1 (-> 1000) -> ReLU -> fc2 (-> num_classes)``.

    The returned model's ``forward(x)`` yields ``(outputs, features)``, where
    ``features`` is the post-ReLU ``fc1`` activation and ``outputs`` is
    ``fc2(features)``. ``get_classifier_weights()`` returns ``fc2.weight.data``.

    Args:
        input_channels: number of channels in the input images.
        num_classes: number of output logits.
        layer_norm: "none" (identity), "standard" (``nn.LayerNorm``) or
            "rms" (``RMSNorm``), applied to the pooled embedding.

    Returns:
        The customised torchvision ResNet50 module.

    Raises:
        ValueError: if ``layer_norm`` is not one of the options above.
    """
    # No arguments -> randomly initialised weights. This avoids the
    # ``pretrained=`` keyword (deprecated in torchvision >= 0.13 in favour of
    # ``weights=``) while remaining valid on older torchvision releases.
    model = models.resnet50()
    model.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)

    # Width of the pooled embedding (the input size of the stock ``fc`` layer).
    # NOTE(review): ``model.fc`` itself is never called by the forward below,
    # but it is left registered so existing checkpoints keep loading; its
    # parameters therefore still appear in ``model.parameters()``.
    in_features = model.fc.in_features

    if layer_norm == "none":
        model.norm_layer = nn.Identity()
    elif layer_norm == "standard":
        model.norm_layer = nn.LayerNorm(in_features)
    elif layer_norm == "rms":
        model.norm_layer = RMSNorm(in_features)
    else:
        raise ValueError(f"Unknown layer_norm option: {layer_norm}")

    model.fc1 = nn.Linear(in_features, 1000, bias=False)
    model.relu2 = nn.ReLU(True)
    model.fc2 = nn.Linear(1000, num_classes, bias=False)

    def new_forward(self, x):
        """Return ``(outputs, features)`` for a batch of images ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.norm_layer(x)
        features = self.relu2(self.fc1(x))
        outputs = self.fc2(features)
        return outputs, features

    def get_classifier_weights(self):
        """Return ``fc2``'s weight tensor (shares storage, not tracked by autograd)."""
        return self.fc2.weight.data

    # Bind as instance attributes; nn.Module.__call__ dispatches to self.forward.
    model.forward = types.MethodType(new_forward, model)
    model.get_classifier_weights = types.MethodType(get_classifier_weights, model)

    return model


# Function to create VGG16 model
def create_vgg16(input_channels, num_classes, layer_norm="rms", input_size=28):
    """Create a VGG16_MNIST model whose ``forward`` returns ``(outputs, features)``.

    The stock classifier is replaced by
    ``norm_layer -> Linear(-> 1000, no bias) -> ReLU -> Linear(-> num_classes, no bias)``.
    ``features`` is the post-ReLU hidden activation and ``outputs`` is the
    final linear layer applied to it. ``get_classifier_weights()`` returns
    the final layer's ``weight.data``.

    Args:
        input_channels: number of channels in the input images. The first
            convolution is rebuilt when this is not 1 (VGG16_MNIST's default).
        num_classes: number of output logits.
        layer_norm: "none" (identity), "standard" (``nn.LayerNorm``) or
            "rms" (``RMSNorm``), applied to the flattened conv features.
        input_size: height/width of the square input images; used only to
            size the classifier's first linear layer. Defaults to 28 (MNIST).

    Returns:
        The customised VGG16_MNIST module.

    Raises:
        ValueError: if ``layer_norm`` is not one of the options above.
    """
    # Constructed with defaults on purpose: its classifier is replaced below,
    # and keeping this call unchanged keeps the RNG stream (hence seeded
    # initialisation of the new layers) identical to earlier versions.
    model = VGG16_MNIST()

    # VGG16_MNIST's first conv expects a single channel. Previously
    # ``input_channels`` was ignored and non-grayscale inputs crashed; rebuild
    # the stem conv instead, mirroring how the ResNet builders swap ``conv1``.
    if input_channels != 1:
        model.features[0] = nn.Conv2d(input_channels, 64, kernel_size=3, padding=1)

    # Infer the flattened size of ``features`` with a dummy forward pass.
    with torch.no_grad():
        dummy_input = torch.zeros(1, input_channels, input_size, input_size)
        in_features = model.features(dummy_input).view(1, -1).shape[1]

    if layer_norm == "none":
        model.norm_layer = nn.Identity()
    elif layer_norm == "standard":
        model.norm_layer = nn.LayerNorm(in_features)
    elif layer_norm == "rms":
        model.norm_layer = RMSNorm(in_features)
    else:
        raise ValueError(f"Unknown layer_norm option: {layer_norm}")

    # The same norm module object is registered both as ``model.norm_layer``
    # and as ``classifier[0]``, so its parameters are shared, not duplicated.
    model.classifier = nn.Sequential(
        model.norm_layer,
        nn.Linear(in_features, 1000, bias=False),  # hidden / feature layer
        nn.ReLU(True),
        nn.Linear(1000, num_classes, bias=False),  # output layer
    )

    def new_forward(self, x):
        """Return ``(outputs, features)`` for a batch of images ``x``."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier[0](x)  # normalisation
        x = self.classifier[1](x)  # hidden linear layer
        features = self.classifier[2](x)  # ReLU
        # Feed the ReLU output explicitly. Passing ``x`` here only worked
        # because ReLU(inplace=True) had mutated ``x``; this form does not
        # depend on that side effect.
        outputs = self.classifier[3](features)
        return outputs, features

    def get_classifier_weights(self):
        """Return the output layer's weight tensor (shares storage, not tracked by autograd)."""
        return self.classifier[3].weight.data

    # Bind as instance attributes; nn.Module.__call__ dispatches to self.forward.
    model.forward = types.MethodType(new_forward, model)
    model.get_classifier_weights = types.MethodType(get_classifier_weights, model)

    return model

# Function to select model
def get_model(name, input_channels, num_classes, device, weight_init_variance=None, layer_norm="rms"):
    """Build the architecture called ``name`` and move it to ``device``.

    Args:
        name: one of "resnet18", "resnet50" or "vgg16".
        input_channels: number of channels in the input images.
        num_classes: number of output logits.
        device: target passed to ``Module.to``.
        weight_init_variance: accepted for interface compatibility; it is
            currently not used by any builder.
        layer_norm: normalisation option forwarded to the chosen builder.

    Returns:
        The constructed model, already moved to ``device``.

    Raises:
        ValueError: if ``name`` is not a known architecture.
    """
    registry = (
        ('resnet18', create_resnet18),
        ('resnet50', create_resnet50),
        ('vgg16', create_vgg16),
    )
    for key, build in registry:
        if name == key:
            return build(input_channels, num_classes, layer_norm).to(device)
    raise ValueError(f"Unknown model name: {name}")
