import os
import torch
import torch.utils.data
import torchvision
from tqdm import tqdm
import transformers
import timm
import timm.data

# Output root for cached evaluation results. The base directory can be
# overridden with the CHECKPOINTS_DIR environment variable (default
# "checkpoints/"); this script writes into its "ce/" subdirectory.
CHECKPOINTS_DIR = os.path.join(
    os.environ.get("CHECKPOINTS_DIR", "checkpoints/"),
    "ce/",
)


def main(model_name="resnet50", batch_size=256, num_workers=6):
    """Evaluate a pretrained timm model on the ImageNet validation split.

    Runs the model over every validation image in dataset order, saves the
    raw logits and the ground-truth labels to
    ``CHECKPOINTS_DIR/<model_name>/test_logits.pt`` and
    ``.../test_targets.pt``, and prints top-1 accuracy.

    Args:
        model_name: Name passed to ``timm.create_model`` (may include a
            ``timm/`` hub prefix and a pretrained tag).
        batch_size: Number of images per forward pass.
        num_workers: Number of DataLoader worker processes.

    Raises:
        KeyError: If the ``IMAGENET_ROOT`` environment variable is not set.
        RuntimeError: If the validation split does not contain 50000 images,
            or if the loader yields a different number of samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    model = timm.create_model(model_name, pretrained=True)
    model = model.to(device)
    model.eval()

    # Evaluation-time preprocessing derived from the model's pretrained config.
    data_config = timm.data.resolve_data_config({}, model=model)
    data_config["is_training"] = False
    transform = timm.data.create_transform(**data_config)

    imagenet_root = os.environ["IMAGENET_ROOT"]
    dataset = torchvision.datasets.ImageNet(root=imagenet_root, split="val", transform=transform)
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if len(dataset) != 50000:
        raise RuntimeError(f"Expected 50000 ImageNet validation images, found {len(dataset)}")
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep dataset order so saved rows line up with sample indices
        pin_memory=use_cuda,  # pinning only helps host->GPU copies
        num_workers=num_workers,
    )

    num_samples = len(dataset)
    # Fall back to 1000 (the previous hard-coded value) if the model does not
    # expose `num_classes`.
    num_classes = getattr(model, "num_classes", 1000)
    logits = torch.empty((num_samples, num_classes), dtype=torch.float32)
    targets = torch.empty((num_samples,), dtype=torch.int64)
    idx = 0
    with torch.no_grad():
        for x, y in tqdm(dataloader):
            n = x.shape[0]
            # non_blocking lets the copy overlap when the batch is pinned;
            # the `.cpu()` below synchronizes before the result is used.
            out = model(x.to(device, non_blocking=True)).cpu()
            logits[idx : idx + n] = out
            targets[idx : idx + n] = y
            idx += n
    if idx != num_samples:
        raise RuntimeError(f"Collected {idx} predictions, expected {num_samples}")

    out_dir = os.path.join(CHECKPOINTS_DIR, model_name)
    os.makedirs(out_dir, exist_ok=True)
    # Save targets first: the driver treats an existing logits file as
    # "already done", so logits must be the last file written.
    torch.save(targets, os.path.join(out_dir, "test_targets.pt"))
    torch.save(logits, os.path.join(out_dir, "test_logits.pt"))
    acc = (logits.argmax(dim=1) == targets).float().mean().item()
    print("Accuracy:", acc)


if __name__ == "__main__":
    import traceback

    model_names = [
        "resnet18",
        "resnet34",
        "resnet50",
        "resnet101",
        "densenet121",
        "densenet161",
        "densenet201",
        "timm/vit_base_patch16_224.augreg_in1k",
        "timm/vit_base_patch16_384.augreg_in1k",
        "timm/vit_base_patch16_224.orig_in21k_ft_in1k",
        "timm/vit_base_patch16_384.orig_in21k_ft_in1k",
    ]
    output_files = ("test_logits.pt", "test_targets.pt")
    failed = []
    for model_name in model_names:
        print(model_name)
        out_dir = os.path.join(CHECKPOINTS_DIR, model_name)
        # main() writes both files; require both so a partial run is redone.
        if all(os.path.isfile(os.path.join(out_dir, fname)) for fname in output_files):
            print("  outputs already present, skipping")
            continue
        try:
            main(model_name)
        except Exception:
            # Best effort: report the failure with its traceback and continue
            # the sweep with the next model.
            traceback.print_exc()
            failed.append(model_name)
    if failed:
        print("Failed models:", ", ".join(failed))
