# Copyright (c) 2021 Copyright holder of the paper "Test-Time Adaptation to Distribution Shifts by Confidence Maximization and Input Transformation" submitted to NeurIPS 2021 for review
# All rights reserved.

"""Log pretrained torchvision ImageNet models (loaded from local checkpoint files) to MLflow."""
import os
import re

import mlflow
import torch
from torchvision import models

# Official checkpoint URLs of the ImageNet-pretrained torchvision models to
# export. Each key is also the name of the torchvision.models constructor that
# rebuilds the matching architecture.
model_urls = dict(
    mobilenet_v2="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
    resnet50="https://download.pytorch.org/models/resnet50-19c8e357.pth",
    resnext50_32x4d="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
    densenet121="https://download.pytorch.org/models/densenet121-a639ec97.pth",
)

# Download all of the above weight files beforehand and store them in this
# folder. Keep each file's original basename from its URL
# (e.g. mobilenet_v2-b0353104.pth).
# NOTE: the path has no leading slash, so it is resolved relative to the
# current working directory.
repo = "home/torchvision_imagenet_models"


def load_state_dict(model: torch.nn.Module, model_url: str) -> None:
    """Load a legacy DenseNet checkpoint file into ``model``, renaming old keys.

    Older DenseNet checkpoints (such as the one listed in ``model_urls``) store
    ``_DenseLayer`` tensors under keys like ``...denselayer1.norm.1.weight``.
    '.'s are no longer allowed in module names, so current models expect
    ``...denselayer1.norm1.weight``. Every key matching that legacy pattern is
    renamed before the state dict is loaded (strictly) into ``model``.

    Args:
        model: Module whose parameters and buffers are overwritten.
        model_url: Path of the checkpoint file. Despite the name, it is passed
            straight to ``torch.load``, so it must be a local path, not a URL.
    """
    # Group 1 is the "<prefix>denselayerN.norm|relu|conv" part and group 2 is the
    # "1|2.<tensor name>" part; concatenating them drops the offending dot.
    pattern = re.compile(
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )

    # map_location="cpu" lets a checkpoint whose tensors were saved on a GPU be
    # read on a CPU-only machine as well.
    state_dict = torch.load(model_url, map_location="cpu")

    # Rename in place rather than building a new dict, so any attributes carried
    # by the loaded mapping (PyTorch state dicts may have `_metadata`) are kept.
    # Iterate over a snapshot of the keys because the dict changes in the loop.
    for key in list(state_dict):
        match = pattern.match(key)
        if match is None:
            continue
        state_dict[match.group(1) + match.group(2)] = state_dict.pop(key)
    model.load_state_dict(state_dict)


def main():
    """Log every checkpoint listed in ``model_urls`` to MLflow.

    Each weight file is read from ``repo`` under the basename of its URL and
    loaded into a freshly constructed ``torchvision.models.<name>()``
    architecture. The model is then logged under the artifact path ``"model"``
    in a new run, named after the model, of the ``torchvision_models``
    experiment. The local path of each logged model is printed.
    """
    from urllib.parse import urlparse

    mlflow.set_tracking_uri("file:///home/test-time-robustness/mlruns")
    mlflow.set_experiment("torchvision_models")

    for model_name, model_url in model_urls.items():
        weight_path = os.path.join(repo, os.path.basename(model_url))
        model = getattr(models, model_name)()
        if "densenet" in model_name:
            # DenseNet checkpoints use legacy dotted key names that need renaming.
            load_state_dict(model, weight_path)
        else:
            # map_location="cpu": the checkpoint can be read even on a machine
            # without the device its tensors were saved from.
            model.load_state_dict(torch.load(weight_path, map_location="cpu"))

        # The model is built before the run is opened, so a missing or broken
        # checkpoint does not leave an empty run behind. The context manager ends
        # the run even if logging raises, whereas a trailing manual end_run()
        # would be skipped by the exception.
        with mlflow.start_run(run_name=model_name):
            mlflow.pytorch.log_model(model, "model")
            # The artifact URI looks like "file:///.../artifacts". Take its path
            # component instead of slicing off a hard-coded 7-character scheme.
            artifact_dir = urlparse(mlflow.get_artifact_uri()).path
            print(f"{model_name}: {os.path.join(artifact_dir, 'model')}")


if __name__ == "__main__":
    # Run the export only when executed as a script, not when imported.
    main()
