import torch
import torchvision
import torchvision.models as models
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
import timm

class TimmWrapper(torch.nn.Module):
    """Adapt a plain timm model to the ``{'feature': tensor}`` output format.

    The torchvision extractors in this module are built with
    ``create_feature_extractor(..., {node: 'feature'})`` and therefore return
    a dict; wrapping a timm model this way lets callers treat both backends
    uniformly.
    """

    def __init__(self, model):
        super().__init__()
        # Assigned as an attribute so it is registered as a submodule and
        # follows .eval()/.to()/parameters() of the wrapper.
        self.model = model

    def forward(self, x: torch.Tensor):
        features = self.model(x)
        return dict(feature=features)

# Maps a torchvision weights-enum class name to
# (name of the torchvision.models builder, graph node whose output is used
# as the feature). Node names are FX graph node names; list them with
# get_graph_node_names(model). Builders are stored by name and resolved
# lazily so a torchvision release lacking one of them only fails when that
# backbone is actually requested.
TORCHVISION_BACKBONES = {
    'AlexNet_Weights': ('alexnet', 'flatten'),
    'DenseNet121_Weights': ('densenet121', 'flatten'),
    'DenseNet161_Weights': ('densenet161', 'flatten'),
    'DenseNet169_Weights': ('densenet169', 'flatten'),
    'DenseNet201_Weights': ('densenet201', 'flatten'),
    'EfficientNet_B0_Weights': ('efficientnet_b0', 'flatten'),
    'EfficientNet_B1_Weights': ('efficientnet_b1', 'flatten'),
    'EfficientNet_B2_Weights': ('efficientnet_b2', 'flatten'),
    'EfficientNet_B3_Weights': ('efficientnet_b3', 'flatten'),
    'EfficientNet_B4_Weights': ('efficientnet_b4', 'flatten'),
    'EfficientNet_V2_S_Weights': ('efficientnet_v2_s', 'flatten'),
    'Inception_V3_Weights': ('inception_v3', 'flatten'),
    'MobileNet_V2_Weights': ('mobilenet_v2', 'flatten'),
    'MobileNet_V3_Large_Weights': ('mobilenet_v3_large', 'flatten'),
    'MobileNet_V3_Small_Weights': ('mobilenet_v3_small', 'flatten'),
    'RegNet_X_400MF_Weights': ('regnet_x_400mf', 'flatten'),
    'RegNet_X_800MF_Weights': ('regnet_x_800mf', 'flatten'),
    'RegNet_X_1_6GF_Weights': ('regnet_x_1_6gf', 'flatten'),
    'RegNet_X_3_2GF_Weights': ('regnet_x_3_2gf', 'flatten'),
    'RegNet_Y_400MF_Weights': ('regnet_y_400mf', 'flatten'),
    'RegNet_Y_800MF_Weights': ('regnet_y_800mf', 'flatten'),
    'RegNet_Y_1_6GF_Weights': ('regnet_y_1_6gf', 'flatten'),
    'RegNet_Y_3_2GF_Weights': ('regnet_y_3_2gf', 'flatten'),
    'ResNet18_Weights': ('resnet18', 'flatten'),
    'ResNet34_Weights': ('resnet34', 'flatten'),
    'ResNet50_Weights': ('resnet50', 'flatten'),
    'ResNet101_Weights': ('resnet101', 'flatten'),
    'ResNeXt50_32X4D_Weights': ('resnext50_32x4d', 'flatten'),
    'Wide_ResNet50_2_Weights': ('wide_resnet50_2', 'flatten'),
    'Swin_T_Weights': ('swin_t', 'flatten'),
    'Swin_V2_T_Weights': ('swin_v2_t', 'flatten'),
    'MNASNet0_5_Weights': ('mnasnet0_5', 'mean'),
    'MNASNet0_75_Weights': ('mnasnet0_75', 'mean'),
    'MNASNet1_0_Weights': ('mnasnet1_0', 'mean'),
    'MNASNet1_3_Weights': ('mnasnet1_3', 'mean'),
    'ShuffleNet_V2_X0_5_Weights': ('shufflenet_v2_x0_5', 'mean'),
    'ShuffleNet_V2_X1_0_Weights': ('shufflenet_v2_x1_0', 'mean'),
    'ShuffleNet_V2_X1_5_Weights': ('shufflenet_v2_x1_5', 'mean'),
    'ShuffleNet_V2_X2_0_Weights': ('shufflenet_v2_x2_0', 'mean'),
    'GoogLeNet_Weights': ('googlenet', 'dropout'),
    'ConvNeXt_Tiny_Weights': ('convnext_tiny', 'classifier.1'),
    'ConvNeXt_Small_Weights': ('convnext_small', 'classifier.1'),
    'MaxVit_T_Weights': ('maxvit_t', 'classifier.1'),
    # NOTE(review): 'features' looks like SqueezeNet's convolutional trunk,
    # so this output is presumably still (N, C, H, W) and output_dim holds
    # only C -- confirm that callers pool/flatten it.
    'SqueezeNet1_0_Weights': ('squeezenet1_0', 'features'),
    'SqueezeNet1_1_Weights': ('squeezenet1_1', 'features'),
    # NOTE(review): 'getitem_5' is an auto-generated FX node name (presumably
    # the class-token slice); it may change across torchvision versions.
    'ViT_B_16_Weights': ('vit_b_16', 'getitem_5'),
    'ViT_B_32_Weights': ('vit_b_32', 'getitem_5'),
}

# Shape of the random tensor used to measure the feature width.
_PROBE_INPUT_SHAPE = (1, 3, 224, 224)


def _attach_output_dim(model):
    """Put *model* in eval mode and set ``model.output_dim`` from one probe pass."""
    model.eval()
    # The probe only needs the output shape; no autograd graph is required.
    with torch.no_grad():
        out = model(torch.rand(*_PROBE_INPUT_SHAPE))
    model.output_dim = out['feature'].size(1)
    return model


def _build_torchvision_extractor(weights_class, weights_name):
    """Instantiate the backbone for *weights_class* and expose its 'feature' node."""
    builder_name, return_node = TORCHVISION_BACKBONES[weights_class]
    backbone = getattr(models, builder_name)(weights=weights_name)
    return create_feature_extractor(backbone, {return_node: 'feature'})


def get_feature_extractor(name):
    """Build a pretrained backbone whose forward returns ``{'feature': tensor}``.

    Args:
        name: Either a timm model spec starting with ``'timm'`` (the first
            five characters, e.g. ``'timm.'``, are stripped and the rest is
            passed to ``timm.create_model``), or ``'<WeightsClass>.<MEMBER>'``
            where ``<WeightsClass>`` is a key of ``TORCHVISION_BACKBONES`` and
            ``<MEMBER>`` names one of its weights, e.g.
            ``'ResNet18_Weights.IMAGENET1K_V1'``.

    Returns:
        A module in eval mode. Its ``output_dim`` attribute is
        ``out['feature'].size(1)`` measured on a random 1x3x224x224 input.

    Raises:
        ValueError: if the weights class is not supported, or the name has
            no weights member after the dot.
    """
    if name.startswith('timm'):
        backbone = timm.create_model(
            name[5:],
            pretrained=True,
            num_classes=0,  # remove classifier nn.Linear
        )
        return _attach_output_dim(TimmWrapper(backbone))

    weights_class, separator, weights_name = name.partition('.')
    if weights_class not in TORCHVISION_BACKBONES:
        raise ValueError(weights_class + ' is not supported')
    if not separator or not weights_name:
        raise ValueError(
            f"{name!r} has no weights member; expected e.g. "
            f"'{weights_class}.IMAGENET1K_V1'"
        )
    return _attach_output_dim(
        _build_torchvision_extractor(weights_class, weights_name)
    )


if __name__ == '__main__':
    # Command-line entry point: build the requested extractor and print the
    # width of its 'feature' output.
    import argparse

    cli = argparse.ArgumentParser(description='')
    cli.add_argument(
        '--pretrained_model',
        type=str,
        default='ResNet18_Weights.IMAGENET1K_V1',
    )
    options = cli.parse_args()

    extractor = get_feature_extractor(name=options.pretrained_model)
    print(extractor.output_dim)
