import timm
import torch
import torch.nn as nn

# Registry of locally stored ViT checkpoints: model name -> path of its ".pt"
# file. Every entry follows the layout ./Models/VIT/<variant dir>/<model name>.pt.
VIT_models = {
    # Base-size variants.
    "vit_base_patch16_224.augreg_in21k_ft_in1k": "./Models/VIT/p16_augreg_in21k_ft_in1k/vit_base_patch16_224.augreg_in21k_ft_in1k.pt",
    "vit_base_patch16_siglip_224": "./Models/VIT/siglip_224/vit_base_patch16_siglip_224.pt",
    "vit_base_patch32_224.augreg_in21k_ft_in1k": "./Models/VIT/p32_augreg_in21k_ft_in1k/vit_base_patch32_224.augreg_in21k_ft_in1k.pt",
    "vit_base_patch32_224.sam_in1k": "./Models/VIT/sam_in1k/vit_base_patch32_224.sam_in1k.pt",
    "vit_base_r50_s16_224.orig_in21k": "./Models/VIT/orig_in21k/vit_base_r50_s16_224.orig_in21k.pt",
    "vit_base_patch8_224.augreg2_in21k_ft_in1k": "./Models/VIT/augreg2_in21k_ft_in1k/vit_base_patch8_224.augreg2_in21k_ft_in1k.pt",
    "vit_base_patch16_224.mae": "./Models/VIT/mae/vit_base_patch16_224.mae.pt",
    "vit_base_patch32_clip_224.openai_ft_in1k": "./Models/VIT/openai_ft_in1k/vit_base_patch32_clip_224.openai_ft_in1k.pt",
    "vit_base_patch32_clip_224.laion2b_ft_in12k_in1k": "./Models/VIT/laion2b_ft_in12k_in1k/vit_base_patch32_clip_224.laion2b_ft_in12k_in1k.pt",
    # Tiny-size variants.
    "vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k": "./Models/VIT/r_in21k_ft_in1k/vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k.pt",
    "vit_tiny_r_s16_p8_224.augreg_in21k": "./Models/VIT/r_in21k/vit_tiny_r_s16_p8_224.augreg_in21k.pt",
    "vit_tiny_patch16_224.augreg_in21k": "./Models/VIT/in21k/vit_tiny_patch16_224.augreg_in21k.pt",
    "vit_tiny_patch16_224.augreg_in21k_ft_in1k": "./Models/VIT/in21k_ft_in1k/vit_tiny_patch16_224.augreg_in21k_ft_in1k.pt",
}

def create_vit_model(name):
    """Load the checkpoint registered under ``name`` in ``VIT_models``.

    Args:
        name: A key of ``VIT_models``.

    Returns:
        The object stored in the checkpoint file, as returned by
        ``torch.load``.

    Raises:
        KeyError: If ``name`` is not a key of ``VIT_models``.
    """
    checkpoint_path = VIT_models[name]
    # weights_only is spelled out because PyTorch 2.6 changed its default to
    # True, which refuses checkpoints that pickle a complete nn.Module.
    # Presumably these files hold whole models (the other helpers here call
    # .parameters() and .head on the result) -- confirm against how the
    # checkpoints were saved. The keyword exists from torch 1.13 onwards.
    # NOTE(security): full unpickling can run arbitrary code; only load
    # checkpoint files from a trusted source.
    return torch.load(checkpoint_path, weights_only=False)

def get_vit_name(path):
    """Extract a model name from a ``/``-separated path.

    For paths with at least seven ``/``-separated components the seventh one
    (index 6) is used, as before. Shorter paths used to raise ``IndexError``;
    they now fall back to their last component, so the checkpoint paths
    stored in ``VIT_models`` are accepted too.

    Args:
        path: A string using ``/`` as separator.

    Returns:
        The selected component. If it contains ``".pt"``, everything from its
        last ``.`` onwards is removed.
    """
    components = path.split("/")
    # Index 6 is the position existing callers rely on; only use the last
    # component when the path is too shallow to have one.
    model = components[6] if len(components) > 6 else components[-1]
    # Substring test kept as-is: it also matches ".pth" and names that merely
    # contain ".pt", and callers may depend on that.
    if ".pt" in model:
        model = model.rsplit(".", 1)[0]
    return model

def freeze_vit(model):
    """Freeze every parameter of ``model`` except those of ``model.head``.

    Args:
        model: Object exposing ``parameters()`` and a ``head`` attribute that
            also exposes ``parameters()``.

    Returns:
        The same ``model`` object, modified in place.
    """

    def _set_trainable(module, flag):
        # Assign the same requires_grad value to each parameter of ``module``.
        for tensor in module.parameters():
            tensor.requires_grad = flag

    # Order matters: freeze everything first, then re-enable the head, whose
    # parameters are normally part of model.parameters() as well.
    _set_trainable(model, False)
    _set_trainable(model.head, True)
    print("Vit model freezed")
    return model

def vit_change_head(model, num_cls):
    """Swap ``model.head`` for a new linear layer with ``num_cls`` outputs.

    Nothing happens when ``num_cls`` is 1 or less, or when the current head
    already has ``num_cls`` outputs. The new layer is freshly initialised and
    is moved to the device and dtype of the old head's weight, so a model that
    is already on an accelerator or in reduced precision stays consistent.

    Args:
        model: Object whose ``head`` exposes ``in_features`` and
            ``out_features``.
        num_cls: Desired number of output classes.

    Returns:
        The same ``model`` object, modified in place when the head changed.

    Raises:
        AttributeError: If ``model.head`` lacks ``in_features`` or
            ``out_features`` and ``num_cls`` is greater than 1.
    """
    if num_cls > 1:
        old_head = model.head
        in_features, out_features = old_head.in_features, old_head.out_features
        if out_features != num_cls:
            new_head = nn.Linear(in_features, num_cls)
            # nn.Linear is created on the CPU in the default dtype; follow the
            # head it replaces so the model is not left split across devices
            # or precisions.
            reference = getattr(old_head, "weight", None)
            if reference is not None:
                new_head = new_head.to(device=reference.device, dtype=reference.dtype)
            model.head = new_head
            print(f"head changed {in_features}x{out_features} -> {in_features}x{num_cls}")
    return model