import torch
import torch.nn as nn
from torchvision.models import resnet18, vgg16
import timm

class CustomModel(nn.Module):
    """Thin wrapper that builds a classification backbone by name.

    Supported ``model_name`` values (case-insensitive):

    * ``'resnet18'`` -- ``torchvision.models.resnet18``
    * ``'vgg'`` / ``'vgg16'`` -- ``torchvision.models.vgg16``
    * ``'vit'`` -- ``timm`` model ``'vit_base_patch16_224'``

    The final linear layer of the backbone is replaced when its output size
    differs from ``num_classes``.

    Args:
        model_name: Name of the backbone to build (see list above).
        num_classes: Number of output classes of the final linear layer.
        pretrained: Forwarded to the backbone constructor's ``pretrained``
            argument.
        model_path: Optional path to a checkpoint; when truthy, it is loaded
            into the backbone after construction.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """

    # Keys under which a checkpoint dict may nest the actual state dict.
    _WRAPPER_KEYS = ('state_dict', 'model_state_dict', 'model')
    # Key prefixes added by nn.DataParallel ('module.') and by saving
    # ``CustomModel.state_dict()`` instead of the backbone's ('model.').
    _KEY_PREFIXES = ('module.', 'model.')

    def __init__(self, model_name='resnet18', num_classes=10, pretrained=False, model_path=None):
        super().__init__()
        self.model_name = model_name.lower()
        self.num_classes = num_classes
        self.pretrained = pretrained

        self.initialize_model()

        if model_path:
            self.load_custom_weights(model_path)

    def initialize_model(self):
        """Create ``self.model`` for ``self.model_name`` and resize its head.

        Raises:
            ValueError: If ``self.model_name`` is not supported.
        """
        if self.model_name == 'resnet18':
            self.model = resnet18(pretrained=self.pretrained)
            if self.model.fc.out_features != self.num_classes:
                self.model.fc = nn.Linear(self.model.fc.in_features, self.num_classes)
        elif self.model_name in ('vgg', 'vgg16'):
            self.model = vgg16(pretrained=self.pretrained)
            # Index 6 is the last layer of the VGG classifier head.
            if self.model.classifier[6].out_features != self.num_classes:
                self.model.classifier[6] = nn.Linear(self.model.classifier[6].in_features, self.num_classes)
        elif self.model_name == 'vit':
            self.model = timm.create_model('vit_base_patch16_224', pretrained=self.pretrained)
            if self.model.head.out_features != self.num_classes:
                self.model.head = nn.Linear(self.model.head.in_features, self.num_classes)
        else:
            raise ValueError(f"Unsupported model type: {self.model_name}")

    def load_custom_weights(self, model_path):
        """Load backbone weights from the checkpoint at ``model_path``.

        The checkpoint may be a plain state dict, or a dict that nests the
        state dict under ``'state_dict'``, ``'model_state_dict'`` or
        ``'model'``.  A ``'module.'`` and/or ``'model.'`` prefix shared by
        every key is stripped before loading.

        Raises:
            TypeError: If the checkpoint does not contain a dict of weights.
        """
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
        # hosts; load_state_dict copies values into the existing parameters,
        # so the model's current device is unaffected.
        checkpoint = torch.load(model_path, map_location='cpu')

        if isinstance(checkpoint, dict):
            for wrapper_key in self._WRAPPER_KEYS:
                nested = checkpoint.get(wrapper_key)
                if isinstance(nested, dict):
                    checkpoint = nested
                    break

        if not isinstance(checkpoint, dict):
            raise TypeError(
                f"Checkpoint at {model_path!r} does not contain a state dict "
                f"(got {type(checkpoint).__name__})"
            )

        state_dict = checkpoint
        for prefix in self._KEY_PREFIXES:
            # Strip only when *every* key carries the prefix, so a genuine
            # submodule that happens to share the name is not mangled.
            if state_dict and all(key.startswith(prefix) for key in state_dict):
                state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}

        self.model.load_state_dict(state_dict)

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output."""
        return self.model(x)