# Copyright (c) 2021 Copyright holder of the paper "Test-Time Adaptation to Distribution Shifts by Confidence Maximization and Input Transformation" submitted to NeurIPS 2021 for review
# All rights reserved.

"""Log the ResNet-50 DeepAugment+AugMix checkpoint to MLflow as a PyTorch model."""
import os

import mlflow
import torch
from torchvision import models

# Model name -> file name of its locally stored weights. main() joins the
# basename of each value onto ``repo`` to locate the checkpoint.
robust_models = dict(
    deepaugment_and_augmix="resnet50_deepaugment_and_augmix.pth",
)

# Before running: download the ResNet50-DeepAugment+AugMix weights from the
# public repository https://github.com/hendrycks/imagenet-r, store them in the
# directory below, and name the file "resnet50_deepaugment_and_augmix.pth".
# NOTE(review): this path has no leading "/", so it is resolved relative to
# the current working directory — confirm that is intended.
repo = "home/torchvision_imagenet_models"


def load_state_dict(model, weight_path, map_location="cpu") -> None:
    """Load checkpoint weights from ``weight_path`` into ``model`` in place.

    The checkpoint may be a dict holding the weights under a ``"state_dict"``
    key (e.g. a training checkpoint) or a bare state dict. Before loading,
    every key ending in ``num_batches_tracked`` is dropped and a *leading*
    ``"module."`` prefix is stripped from the remaining keys.

    Args:
        model: the ``torch.nn.Module`` to populate. ``model.load_state_dict``
            is called with its default strict key matching.
        weight_path: path of the checkpoint file read with ``torch.load``.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so
            checkpoints saved from GPU tensors can be read on machines
            without CUDA; ``model.load_state_dict`` copies the values into
            the model's existing parameters/buffers either way.

    Raises:
        RuntimeError: from ``model.load_state_dict`` when keys or shapes do
            not match the model.
    """
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(weight_path, map_location=map_location)
    # Use the nested "state_dict" entry when present, otherwise treat the
    # loaded object itself as the state dict.
    state_dict = checkpoint.get("state_dict", checkpoint)
    prefix = "module."
    suffix = "num_batches_tracked"
    # Snapshot the keys: the dict is edited in place while we walk it
    # (in-place editing also keeps any attributes the loaded dict carries).
    for key in list(state_dict):
        if key.endswith(suffix):
            del state_dict[key]
        elif key.startswith(prefix):
            # Strip only a leading prefix; a substring test would also match
            # names like "submodule.weight" and mangle them.
            state_dict[key[len(prefix):]] = state_dict.pop(key)

    model.load_state_dict(state_dict)


def main():
    """Log every checkpoint listed in ``robust_models`` to MLflow.

    For each entry a torchvision ResNet-50 is built, the weights file found
    under ``repo`` is loaded into it via ``load_state_dict``, and the model is
    logged as the ``"model"`` artifact of a new run (named after the entry)
    in the ``robust_pretrained_models`` experiment. The location of the
    logged model is printed.
    """
    mlflow.set_tracking_uri(
        "file:///home/test-time-robustness/mlruns"
    )
    mlflow.set_experiment("robust_pretrained_models")

    file_scheme = "file://"
    for model_name, model_url in robust_models.items():
        # The context manager ends the run even if loading or logging raises,
        # so a failure does not leave a dangling active run behind.
        with mlflow.start_run(run_name=model_name):
            weight_path = os.path.join(repo, os.path.basename(model_url))
            model = models.resnet50()

            load_state_dict(model, weight_path)

            mlflow.pytorch.log_model(model, "model")
            uri = mlflow.get_artifact_uri()
            # Drop the "file://" scheme only when it is actually present;
            # slicing a fixed 7 characters corrupts any other URI form.
            if uri.startswith(file_scheme):
                uri = uri[len(file_scheme):]
            print(f"{model_name}: {os.path.join(uri, 'model')}")


if __name__ == "__main__":
    # Run the export only when executed as a script, not on import.
    main()
