from typing import Dict
from argparse import ArgumentParser
import torch
from dataset import SubImageFolder
from utils.eval_utils import cmc_evaluate
from torchvision.models import resnet18, ResNet18_Weights, resnet50, ResNet50_Weights, alexnet, AlexNet_Weights
from torchvision.models import resnet152, ResNet152_Weights, RegNet_X_3_2GF_Weights, regnet_x_3_2gf
from torchvision.models import vgg16, VGG16_Weights, MaxVit_T_Weights, maxvit_t
from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

# Location of the ImageNet copy used for evaluation (hard-coded for this setup).
_IMAGENET_ROOT = "/home/DATASETS/ImageNet"

# Supported backbones: name -> (constructor, pretrained weights, return_nodes).
# ``return_nodes`` maps a graph node of the torchvision model to the output key
# handed to ``cmc_evaluate``: 'flatten' is the pre-classifier feature node and
# 'fc' is the final linear (logit) node. The keys are presumably what
# cmc_evaluate selects on depending on the feature flags -- confirm there.
_MODEL_SPECS: Dict[str, tuple] = {
    'alexnet': (alexnet, AlexNet_Weights.IMAGENET1K_V1,
                {'flatten': 'flatten', 'classifier.6': 'fc'}),
    'resnet50': (resnet50, ResNet50_Weights.IMAGENET1K_V1,
                 {'flatten': 'flatten', 'fc': 'fc'}),
    'regnet_x_3_2gf': (regnet_x_3_2gf, RegNet_X_3_2GF_Weights.IMAGENET1K_V2,
                       {'flatten': 'flatten', 'fc': 'fc'}),
    'resnet152': (resnet152, ResNet152_Weights.IMAGENET1K_V2,
                  {'flatten': 'flatten', 'fc': 'fc'}),
    # MaxViT has no top-level 'flatten'/'fc' nodes; its head lives in `classifier`.
    'maxvit_t': (maxvit_t, MaxVit_T_Weights.IMAGENET1K_V1,
                 {'classifier.1': 'flatten', 'classifier.5': 'fc'}),
}


def _build_feature_extractor(name):
    """Return pretrained model ``name`` in eval mode, wrapped as a feature extractor.

    Raises:
        ValueError: if ``name`` is not a key of ``_MODEL_SPECS``.
    """
    try:
        constructor, weights, return_nodes = _MODEL_SPECS[name]
    except KeyError:
        raise ValueError(
            "Unsupported model {!r}; choose one of {}".format(name, sorted(_MODEL_SPECS))
        ) from None
    model = constructor(weights=weights).eval()
    # Pass a copy so the shared registry entry can never be mutated downstream.
    return create_feature_extractor(model, return_nodes=dict(return_nodes))


def _output_path(args):
    """Return the results file name for the selected feature type.

    Precedence matches the original sequential checks:
    ``penultimate_layer`` > ``logits`` > ``softmax``.

    Raises:
        ValueError: if none of the three feature flags is set.
    """
    if args.penultimate_layer:
        prefix = "output_penultimate"
    elif args.logits:
        prefix = "output_logit"
    elif args.softmax:
        prefix = "output_softmax"
    else:
        raise ValueError("One of --softmax, --logits or --penultimate_layer must be set.")
    return "./{}_G_{}_Q_{}.txt".format(prefix, args.gallery_model, args.query_model)


def main(args):
    """Evaluate gallery/query model compatibility on the ImageNet validation set.

    Builds the two feature extractors named by ``args.gallery_model`` and
    ``args.query_model``, runs ``cmc_evaluate`` with cosine distance, prints the
    CMC Top-1/Top-5 result, and writes CMC and mAP to a text file whose name
    encodes the feature type and both model names.

    Raises:
        ValueError: for an unsupported model name or when no feature flag is set.
    """
    # Validate the configuration before any expensive work, so a bad setup
    # fails immediately instead of after a full evaluation pass.
    name_txt = _output_path(args)
    gallery_model = _build_feature_extractor(args.gallery_model)
    query_model = _build_feature_extractor(args.query_model)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data = SubImageFolder(name="imagenet",
                          data_root=_IMAGENET_ROOT,
                          num_classes=1000,
                          num_workers=8,
                          batch_size=64)

    cmc_out, mean_ap_out = cmc_evaluate(args,
                                        gallery_model,
                                        query_model,
                                        data.val_loader,
                                        device,
                                        distance_metric='cosine')

    print('CMC Top-1 = {}, CMC Top-5 = {}'.format(*cmc_out))

    with open(name_txt, 'w') as file:
        file.write('CMC Top-1 = {}, CMC Top-5 = {}\n'.format(*cmc_out))
        file.write('mAP = {}\n'.format(mean_ap_out))


if __name__ == "__main__":
    # Models that main() knows how to construct.
    supported_models = ('alexnet', 'resnet50', 'regnet_x_3_2gf', 'resnet152', 'maxvit_t')

    parser = ArgumentParser(description='Cross-model CMC/mAP retrieval evaluation on ImageNet.')
    parser.add_argument('--query_model', type=str, required=True,
                        choices=supported_models, help='Query model.')
    parser.add_argument('--gallery_model', type=str, required=True,
                        choices=supported_models, help='Gallery model.')

    # Exactly one feature type must be chosen: the output file name (and the
    # features used) depend on it, so "none" or "several" is rejected up front.
    feature_group = parser.add_mutually_exclusive_group(required=True)
    feature_group.add_argument('--logits', action='store_true', help='Use logits as features.')
    feature_group.add_argument('--softmax', action='store_true', help='Use softmax output as features.')
    feature_group.add_argument('--penultimate_layer', action='store_true', help='Use penultimate layer as features.')

    args = parser.parse_args()
    main(args)
