import timm
import torch
import torch.nn as nn

def create_model(name):
    """Instantiate the timm architecture called ``name`` without pretrained weights.

    Args:
        name: Any model identifier accepted by ``timm.create_model``.

    Returns:
        The freshly constructed, randomly initialised model.
    """
    net = timm.create_model(name, pretrained=False)
    return net

def create_model_EMNIST_Letters():
    """Build an untrained ResNet-18 with a 1-channel stem and a 37-way classifier.

    Starts from timm's ``resnet18.a1_in1k`` architecture with
    ``pretrained=False``. The stem is first converted to a single input
    channel via ``Resnet_change_tail``, then the head is resized to 37 outputs
    via ``Resnet_change_head``.

    NOTE(review): the EMNIST "letters" split has 26 letter classes (27 in
    torchvision's labelling, which includes an N/A slot). The 37 outputs are
    kept unchanged here; confirm this count is intended for the label set in use.
    """
    base = timm.create_model("resnet18.a1_in1k", pretrained=False)
    # Stem first, then head: same order as the other factories in this module.
    return Resnet_change_head(Resnet_change_tail(base, 1), 37)

def create_model_FMNIST():
    """Build an untrained ResNet-18 with a 1-channel stem and a 10-way classifier.

    Uses timm's ``resnet18.a1_in1k`` architecture with ``pretrained=False``.
    ``Resnet_change_tail`` adapts the stem to one input channel, and
    ``Resnet_change_head`` resizes the head to 10 outputs (presumably the
    Fashion-MNIST classes, going by the function name).
    """
    backbone = timm.create_model("resnet18.a1_in1k", pretrained=False)
    single_channel = Resnet_change_tail(backbone, 1)
    return Resnet_change_head(single_channel, 10)

def create_model_sketch():
    """Build an untrained 3-channel ResNet-18 with a 100-way classifier.

    The stem of timm's ``resnet18.a1_in1k`` (``pretrained=False``) is left as
    is. Only the head is resized to 100 outputs via ``Resnet_change_head``.
    """
    return Resnet_change_head(
        timm.create_model("resnet18.a1_in1k", pretrained=False),
        100,
    )

def get_model_name(path, *, index=6):
    """Extract a model name from one component of a ``/``-separated path.

    The path is split on ``"/"`` and the component at position ``index`` is
    returned. The default of 6 preserves the original behaviour; note that a
    leading ``/`` produces an empty first component. A trailing ``.pth``
    extension on that component is stripped.

    Args:
        path: Checkpoint path using ``/`` as the separator.
        index: Which ``/``-separated component holds the model name.

    Returns:
        The selected component, without a trailing ``.pth``.

    Raises:
        IndexError: If ``path`` has too few components for ``index``.
    """
    parts = path.split("/")
    try:
        model = parts[index]
    except IndexError:
        raise IndexError(
            f"path {path!r} has {len(parts)} '/'-separated components; "
            f"cannot take component {index}"
        ) from None
    # Strip only a genuine ".pth" suffix. A bare substring test would also
    # chop names such as "exp.pth_runs" or "net.pth.tar" at their last dot.
    if model.endswith(".pth"):
        model = model[: -len(".pth")]
    return model

def Resnet_change_tail(model, num_channels):
    """Adapt the stem convolution ``model.conv1`` to ``num_channels`` inputs.

    For ``num_channels == 1``, ``conv1`` is replaced with a new ``nn.Conv2d``
    that has one input channel. Its weights are the mean of the old weights
    over the input-channel dimension. The following are copied from the old
    layer rather than hard-coded, so the helper is not tied to the ResNet-18
    stem (64 filters, 7x7, stride 2, padding 3):

    - output channels, kernel size, stride, padding, dilation and padding mode;
    - bias presence and values;
    - device and dtype.

    For ``num_channels == 3`` the model is left untouched.

    Args:
        model: Network exposing a ``conv1`` attribute that is an ``nn.Conv2d``.
        num_channels: Desired number of input channels (1 or 3).

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        NotImplementedError: For any ``num_channels`` other than 1 or 3.
    """
    if num_channels == 3:
        return model
    if num_channels != 1:
        raise NotImplementedError(
            f"Resnet_change_tail supports num_channels 1 or 3, got {num_channels}"
        )

    old_conv = model.conv1
    new_conv = nn.Conv2d(
        1,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        dilation=old_conv.dilation,
        bias=old_conv.bias is not None,
        padding_mode=old_conv.padding_mode,
        device=old_conv.weight.device,
        dtype=old_conv.weight.dtype,
    )
    with torch.no_grad():
        # nn.Conv2d weights are (out, in/groups, kH, kW); average over dim 1
        # to collapse the input channels into one.
        new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)
    model.conv1 = new_conv
    return model

def Resnet_change_head(model, num_cls):
    """Replace ``model.fc`` with a fresh ``nn.Linear`` that has ``num_cls`` outputs.

    The head is swapped only when ``num_cls > 1`` and the current ``fc`` does
    not already produce ``num_cls`` outputs; otherwise the model is returned
    unchanged. The new layer keeps the old input width and is created on the
    old layer's device and dtype. This way, calling it on a model that was
    already moved (e.g. ``.cuda()`` / ``.half()``) does not leave a CPU/float32
    head behind. A message describing the change is printed.

    Args:
        model: Network exposing an ``fc`` attribute that is an ``nn.Linear``.
        num_cls: Desired number of output classes.

    Returns:
        The same ``model`` object, modified in place if the head was replaced.
    """
    if num_cls <= 1:
        # NOTE(review): num_cls == 1 (single-logit head) is silently ignored,
        # as in the original code; confirm this is intended.
        return model
    old_fc = model.fc
    in_features, out_features = old_fc.in_features, old_fc.out_features
    if out_features == num_cls:
        return model
    model.fc = nn.Linear(
        in_features,
        num_cls,
        device=old_fc.weight.device,
        dtype=old_fc.weight.dtype,
    )
    print(f"fc changed {in_features}x{out_features} -> {in_features}x{num_cls}")
    return model