import torch
from ofa.utils import download_url
from ofa.model_zoo import ofa_specialized


# Ids of the OnceForAll specialized sub-networks this module can build.
# Each id has the form "<constraint>_top1@<accuracy>_finetune@<N>", where
# <constraint> is either a FLOPs budget ("flops@...") or a target-platform
# latency ("<platform>_lat@..." / "<platform>_gpu<batch>@..."). The ids are
# passed to `ofa_specialized` as `net_id` and are also used as path
# components of the download URL / cache directory in `load_test_models`.
test_models = [
    # FLOPs-constrained
    "flops@595M_top1@80.0_finetune@75",
    "flops@482M_top1@79.6_finetune@75",
    "flops@389M_top1@79.1_finetune@75",
    # LG-G8 latency
    "LG-G8_lat@24ms_top1@76.4_finetune@25",
    "LG-G8_lat@16ms_top1@74.7_finetune@25",
    "LG-G8_lat@11ms_top1@73.0_finetune@25",
    "LG-G8_lat@8ms_top1@71.1_finetune@25",
    # s7edge latency
    "s7edge_lat@88ms_top1@76.3_finetune@25",
    "s7edge_lat@58ms_top1@74.7_finetune@25",
    "s7edge_lat@41ms_top1@73.1_finetune@25",
    "s7edge_lat@29ms_top1@70.5_finetune@25",
    # note8 latency
    "note8_lat@65ms_top1@76.1_finetune@25",
    "note8_lat@49ms_top1@74.9_finetune@25",
    "note8_lat@31ms_top1@72.8_finetune@25",
    "note8_lat@22ms_top1@70.4_finetune@25",
    # note10 latency
    "note10_lat@64ms_top1@80.2_finetune@75",
    "note10_lat@50ms_top1@79.7_finetune@75",
    "note10_lat@41ms_top1@79.3_finetune@75",
    "note10_lat@30ms_top1@78.4_finetune@75",
    "note10_lat@22ms_top1@76.6_finetune@25",
    "note10_lat@16ms_top1@75.5_finetune@25",
    "note10_lat@11ms_top1@73.6_finetune@25",
    "note10_lat@8ms_top1@71.4_finetune@25",
    # pixel1 latency
    "pixel1_lat@143ms_top1@80.1_finetune@75",
    "pixel1_lat@132ms_top1@79.8_finetune@75",
    "pixel1_lat@79ms_top1@78.7_finetune@75",
    "pixel1_lat@58ms_top1@76.9_finetune@75",
    "pixel1_lat@40ms_top1@74.9_finetune@25",
    "pixel1_lat@28ms_top1@73.3_finetune@25",
    "pixel1_lat@20ms_top1@71.4_finetune@25",
    # pixel2 latency
    "pixel2_lat@62ms_top1@75.8_finetune@25",
    "pixel2_lat@50ms_top1@74.7_finetune@25",
    "pixel2_lat@35ms_top1@73.4_finetune@25",
    "pixel2_lat@25ms_top1@71.5_finetune@25",
    # 1080ti_gpu64 latency
    "1080ti_gpu64@27ms_top1@76.4_finetune@25",
    "1080ti_gpu64@22ms_top1@75.3_finetune@25",
    "1080ti_gpu64@15ms_top1@73.8_finetune@25",
    "1080ti_gpu64@12ms_top1@72.6_finetune@25",
    # v100_gpu64 latency
    "v100_gpu64@11ms_top1@76.1_finetune@25",
    "v100_gpu64@9ms_top1@75.3_finetune@25",
    "v100_gpu64@6ms_top1@73.0_finetune@25",
    "v100_gpu64@5ms_top1@71.6_finetune@25",
    # tx2_gpu16 latency
    "tx2_gpu16@96ms_top1@75.8_finetune@25",
    "tx2_gpu16@80ms_top1@75.4_finetune@25",
    "tx2_gpu16@47ms_top1@72.9_finetune@25",
    "tx2_gpu16@35ms_top1@70.3_finetune@25",
    # cpu latency
    "cpu_lat@17ms_top1@75.7_finetune@25",
    "cpu_lat@15ms_top1@74.6_finetune@25",
    "cpu_lat@11ms_top1@72.0_finetune@25",
    "cpu_lat@10ms_top1@71.1_finetune@25",
]  # 50 models in total


def load_test_models(net_id: int = 0, n_classes=10, trained_weights=None):
    """Build one specialized OFA network from ``test_models`` and load weights into it.

    Args:
        net_id: Index into ``test_models`` choosing which specialized network
            to build.
        n_classes: Forwarded unchanged to ``ofa_specialized``.
        trained_weights: Optional path to a local checkpoint. When truthy, the
            checkpoint's ``"state_dict"`` entry is loaded into the network.
            Otherwise the model's ``init`` checkpoint is downloaded from the
            OnceForAll ``ofa_specialized`` URL and its ``"state_dict"`` is
            loaded instead.

    Returns:
        ``(net, image_size)`` as returned by ``ofa_specialized``, with the
        selected weights loaded into ``net``.
    """
    model_name = test_models[net_id]
    net, image_size = ofa_specialized(
        net_id=model_name, n_classes=n_classes, pretrained=False
    )

    if trained_weights:
        checkpoint_path = trained_weights
    else:
        # Pull the published "init" checkpoint for this model, storing it
        # under .torch/ofa_specialized/<model_name>/.
        zoo_url = "https://hanlab.mit.edu/files/OnceForAll/ofa_specialized/"
        checkpoint_path = download_url(
            zoo_url + model_name + "/init",
            model_dir=".torch/ofa_specialized/%s/" % model_name,
        )

    # NOTE(review): torch.load unpickles the file. For checkpoints fetched over
    # the network, consider weights_only=True if the installed torch supports it.
    # NOTE(review): if ofa_specialized sizes the classifier from n_classes, the
    # downloaded "init" state_dict may not match its shape and this strict
    # load_state_dict would raise. Confirm against ofa_specialized.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
    net.load_state_dict(state_dict)
    return net, image_size

