import argparse
from typing import Optional

import torch
import torch.nn as nn

try:
    import torchvision.models as tv_models
except Exception:
    tv_models = None


def create_resnet50(num_classes: int, pretrained: bool = True, freeze_backbone: bool = False) -> nn.Module:
    """Build a torchvision ResNet-50 with a fresh ``num_classes``-way linear head.

    Args:
        num_classes: Output size of the replacement ``fc`` layer.
        pretrained: When true, request ImageNet-pretrained weights from
            torchvision; otherwise the network is constructed without them.
        freeze_backbone: When true, set ``requires_grad = False`` on every
            parameter of the loaded network. The replacement ``fc`` layer is
            created afterwards, so it stays trainable.

    Returns:
        The ResNet-50 module with its ``fc`` layer replaced.

    Raises:
        RuntimeError: If torchvision could not be imported.
    """
    if tv_models is None:
        raise RuntimeError("torchvision is required to create resnet50.")

    if hasattr(tv_models, "ResNet50_Weights"):
        # Multi-weight API: pick the weights enum explicitly.
        weights = tv_models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        model = tv_models.resnet50(weights=weights)
    else:
        # Releases without the weights enum do not accept ``weights=``; passing
        # it there fails, and a requested ``pretrained=True`` must not be
        # silently dropped. Use the legacy flag instead.
        # NOTE(review): the legacy flag loads whatever single ImageNet
        # checkpoint that release ships, which may differ from IMAGENET1K_V2.
        model = tv_models.resnet50(pretrained=pretrained)

    if freeze_backbone:
        # The original head is discarded below, so freezing everything here is
        # equivalent to freezing all parameters outside ``fc``.
        for param in model.parameters():
            param.requires_grad = False

    # A newly constructed nn.Linear has requires_grad=True by default.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def get_resnet50(pretrained: bool = True, num_classes: int = 65, freeze_backbone: bool = False) -> nn.Module:
    """Return a ResNet-50 from ``create_resnet50``, defaulting to 65 classes.

    The 65-class default corresponds to the Office-Home setup this wrapper is
    meant for; all three arguments are forwarded unchanged.
    """
    # create_resnet50's positional order is (num_classes, pretrained, freeze_backbone).
    return create_resnet50(num_classes, pretrained, freeze_backbone)


def load_checkpoint(model: nn.Module, checkpoint_path: str, map_location: Optional[str] = None) -> nn.Module:
    """Load weights from ``checkpoint_path`` into ``model`` and return it.

    Besides a bare state dict, two common checkpoint layouts are accepted:

    * a training checkpoint that nests the weights under ``"state_dict"`` or
      ``"model_state_dict"``;
    * a state dict whose keys all carry the ``"module."`` prefix (as produced
      by saving a ``DataParallel``-style wrapper), when ``model`` itself has
      no such prefix.

    Args:
        model: Module that receives the weights (modified in place).
        checkpoint_path: Path handed to ``torch.load``.
        map_location: Device mapping for ``torch.load``; falls back to
            ``"cpu"`` when ``None`` or empty.

    Returns:
        The same ``model`` instance, after ``load_state_dict``.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when keys or shapes
            do not match the model.
    """
    # NOTE(security): torch.load unpickles the file. Only load checkpoints from
    # trusted sources, or pass ``weights_only=True`` where the installed torch
    # supports it.
    state = torch.load(checkpoint_path, map_location=map_location or "cpu")

    if isinstance(state, dict):
        # Unwrap training checkpoints that store the weights under a well-known key.
        for wrapper_key in ("state_dict", "model_state_dict"):
            nested = state.get(wrapper_key)
            if isinstance(nested, dict):
                state = nested
                break

        prefix = "module."
        model_has_prefix = any(key.startswith(prefix) for key in model.state_dict())
        checkpoint_all_prefixed = bool(state) and all(
            isinstance(key, str) and key.startswith(prefix) for key in state
        )
        # Strip the wrapper prefix only when the target model is not itself
        # wrapped; in that situation a strict load would otherwise fail.
        if checkpoint_all_prefixed and not model_has_prefix:
            state = {key[len(prefix):]: value for key, value in state.items()}

    model.load_state_dict(state)
    return model


if __name__ == "__main__":
    # Sanity check: build the model from command-line options and report its class.
    cli = argparse.ArgumentParser(description="Model factory sanity check")
    # --pretrained is an opt-in flag (False unless given on the command line).
    cli.add_argument("--pretrained", action="store_true")
    cli.add_argument("--num_classes", type=int, default=65)
    opts = cli.parse_args()

    net = get_resnet50(num_classes=opts.num_classes, pretrained=opts.pretrained)
    print(type(net).__name__, "created with", opts.num_classes, "classes")
