#! /usr/bin/env python3

import torch

import os
import pdb
from util.args_loader import get_args
from util.model_loader import get_model
from util.data_loader import get_loader_in, get_loader_out
import numpy as np
import torch.nn.functional as F
from transformers import MobileNetV2ForImageClassification
import timm

# Seed every RNG this script touches (torch CPU, torch CUDA, NumPy) with 1,
# in that order, so repeated runs produce the same feature dumps.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(1)
del _seed_fn

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


args = get_args()

# os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# Build the in-distribution train/val loaders (config_type="eval") and pull
# out the two loaders plus the number of in-distribution classes.
loader_in_dict = get_loader_in(args, config_type="eval", split=("train", "val"))
trainloaderIn = loader_in_dict.train_loader
testloaderIn = loader_in_dict.val_loader
num_classes = loader_in_dict.num_classes

# model = get_model(args, num_classes, load_ckpt=True) # set true to load from ash_ckpt in their repo; essentially the same as torch ckpt
# model = MobileNetV2ForImageClassification.from_pretrained("google/mobilenet_v2_1.0_224")

# Registry of timm models this script can extract features from. Each alias
# maps to a spec whose "config" entry holds keyword arguments for
# ``timm.create_model``; a few specs carry extra run metadata.
#
# Most aliases only need a timm model name loaded with pretrained weights, so
# they are listed as (alias, timm model name) pairs. Anything else is given as
# a complete spec dict. Table order is preserved in the resulting dict.
_TIMM_MODEL_TABLE = [
    ("BiT_m", "resnetv2_101x1_bitm"),
    (
        "BiT_s",
        {
            "config": {
                "model_name": "resnetv2_101x1_bitm",
                "checkpoint_path": "./model_weights/checkpoints/BiT-S-R101x1.npz",
            }
        },
    ),
    ("vit_base_patch16_224_21kpre", "vit_base_patch16_224"),
    ("vit_base_patch16_384_21kpre", "vit_base_patch16_384"),
    ("convnext_base_in22ft1k", "convnext_base_in22ft1k"),
    ("convnext_base", "convnext_base"),
    ("convnext_tiny-22k", "convnext_tiny_384_in22ft1k"),
    ("deit3_base_patch16_224", "deit3_base_patch16_224"),
    ("deit3_base_patch16_224_in21ft1k", "deit3_base_patch16_224_in21ft1k"),
    ("tf_efficientnetv2_m", "tf_efficientnetv2_m"),
    ("tf_efficientnetv2_m_in21ft1k", "tf_efficientnetv2_m_in21ft1k"),
    ("swinv2-22k", "swinv2_base_window12to16_192to256_22kft1k"),
    ("swinv2-1k", "swinv2_base_window16_256"),
    ("deit3-384-22k", "deit3_base_patch16_384_in21ft1k"),
    ("deit3-384-1k", "deit3_base_patch16_384"),
    ("tf_efficientnet_b7_ns", "tf_efficientnet_b7_ns"),
    ("tf_efficientnet_b7", "tf_efficientnet_b7"),
    ("resnet50", "resnet50"),
    ("efficientnet_b0", "efficientnet_b0"),
    (
        "vit_base_patch16_384_laion2b_in12k_in1k",
        "vit_base_patch16_clip_384.laion2b_ft_in12k_in1k",
    ),
    (
        "vit_base_patch16_384_laion2b_in1k",
        "vit_base_patch16_clip_384.laion2b_ft_in1k",
    ),
    (
        "vit_base_patch16_384_openai_in12k_in1k",
        "vit_base_patch16_clip_384.openai_ft_in12k_in1k",
    ),
    (
        "vit_base_patch16_384_openai_in1k",
        "vit_base_patch16_clip_384.openai_ft_in1k",
    ),
    (
        "vit_base_patch16_384",
        {
            "config": {
                "model_name": "vit_base_patch16_384.augreg_in1k",
                "pretrained": True,
            },
            "batch_size": 128,
            "server": "curie",
        },
    ),
    ("xcit_medium_24_p16_224_dist", "xcit_medium_24_p16_224_dist"),
    ("xcit_medium_24_p16_224", "xcit_medium_24_p16_224"),
]

timm_models = {
    alias: (
        {"config": {"model_name": spec, "pretrained": True}}
        if isinstance(spec, str)
        else spec
    )
    for alias, spec in _TIMM_MODEL_TABLE
}
del _TIMM_MODEL_TABLE

# NOTE(review): the network is chosen by this hard-coded registry key, while
# the feature-extraction code branches on args.model_arch. Keep the two in
# sync (e.g. a deit3 model needs --model_arch deit).
model_name = "deit3_base_patch16_224"
# nn.Module.eval() and .to() both return the module itself, so the calls chain.
model = timm.create_model(**timm_models[model_name]["config"]).eval().to(device)


batch_size = args.batch_size

# Width of the feature vector stored per sample for each supported
# architecture family. It sizes the second axis of every feat.mmap cache.
_FEAT_DIMS = {
    "resnet50": 2048,
    "resnet50-supcon": 2048,
    "vit": 768,
    "mobile": 1280,
    "convnext": 1024,
    "swin": 1024,
    "deit": 768,
}
try:
    featdim = _FEAT_DIMS[args.model_arch]
except KeyError:
    # Same exception type as a plain lookup, but it says what *is* supported.
    raise KeyError(
        f"Unsupported model_arch {args.model_arch!r}; "
        f"expected one of {sorted(_FEAT_DIMS)}"
    ) from None

# Pipeline switches:
#   FORCE_RUN - recompute and overwrite caches even if the cache dir exists.
#   ID_RUN    - run the in-distribution (val + train) extraction pass.
#   OOD_RUN   - run the extraction pass over each args.out_datasets entry.
FORCE_RUN = ID_RUN = OOD_RUN = True

if ID_RUN:
    # For each in-distribution split, either extract features, logits and
    # labels into fresh memmaps under `cache_dir`, or (when the cache dir
    # already exists and FORCE_RUN is off) reopen the cached memmaps read-only.
    for split, in_loader in [("val", testloaderIn), ("train", trainloaderIn)]:
        cache_dir = f"cache/{args.in_dataset}_{split}_{args.name}_in"
        num_samples = len(in_loader.dataset)
        if FORCE_RUN or not os.path.exists(cache_dir):
            os.makedirs(cache_dir, exist_ok=True)
            feat_log = np.memmap(
                f"{cache_dir}/feat.mmap",
                dtype=float,
                mode="w+",
                shape=(num_samples, featdim),
            )
            score_log = np.memmap(
                f"{cache_dir}/score.mmap",
                dtype=float,
                mode="w+",
                shape=(num_samples, num_classes),
            )
            # Labels are stored as float64, like the other logs.
            label_log = np.memmap(
                f"{cache_dir}/label.mmap",
                dtype=float,
                mode="w+",
                shape=(num_samples,),
            )

            model.to(device)
            model.eval()
            # Rows are addressed by a running offset over the actual batch
            # sizes instead of batch_idx * args.batch_size. This keeps the
            # memmaps aligned even if the loader's batch size differs from the
            # CLI value.
            offset = 0
            with torch.no_grad():
                for batch_idx, (inputs, targets) in enumerate(in_loader):
                    inputs, targets = inputs.to(device), targets.to(device)
                    start_ind = offset
                    end_ind = offset + inputs.size(0)
                    offset = end_ind

                    # Branches that leave `score` as None get their logits from
                    # model.fc below, after any spatial pooling.
                    score = None
                    if args.model_arch == "resnet50-supcon":
                        out = model.encoder(inputs)
                    elif args.model_arch == "convnext":
                        out = model.forward_features(inputs)
                        out = F.adaptive_avg_pool2d(out, 1)
                        out = out.view(out.size(0), -1)
                        score = model(inputs)
                    elif args.model_arch == "swin":
                        # permute(0, 3, 1, 2) moves the last axis to dim 1
                        # (presumably channels-last -> channels-first) so the
                        # map can be average-pooled.
                        out = model.forward_features(inputs).permute(0, 3, 1, 2)
                        out = F.adaptive_avg_pool2d(out, 1)
                        out = out.view(out.size(0), -1)
                        score = model(inputs)
                    elif args.model_arch == "deit":
                        # Replicate the timm head step by step so the
                        # pre-classifier feature can be captured.
                        out = model.forward_features(inputs)
                        out = model.pool(out)
                        out = model.fc_norm(out)
                        out = model.head_drop(out)
                        score = model.head(out)
                    elif args.model_arch == "mobile":
                        out = model.features(inputs)
                        out = F.adaptive_avg_pool2d(out, 1)
                        out = out.view(out.size(0), -1)
                        score = model.classifier(out)
                    else:
                        out = model.features(inputs)

                    if out.dim() > 2:
                        # Pool spatial maps to (batch, channels) BEFORE the
                        # linear head. Applying model.fc to an unpooled 4-D
                        # map would operate on the wrong axis.
                        out = F.adaptive_avg_pool2d(out, 1)
                        out = out.view(out.size(0), -1)
                        score = model.fc(out)
                    elif score is None:
                        score = model.fc(out)

                    feat_log[start_ind:end_ind, :] = out.cpu().numpy()
                    label_log[start_ind:end_ind] = targets.cpu().numpy()
                    score_log[start_ind:end_ind] = score.cpu().numpy()
                    if batch_idx % 100 == 0:
                        print(f"{batch_idx}/{len(in_loader)}")

            # Push everything written through the memmaps to disk now rather
            # than relying on the arrays being garbage-collected.
            for log in (feat_log, score_log, label_log):
                log.flush()
        else:
            feat_log = np.memmap(
                f"{cache_dir}/feat.mmap",
                dtype=float,
                mode="r",
                shape=(num_samples, featdim),
            )
            score_log = np.memmap(
                f"{cache_dir}/score.mmap",
                dtype=float,
                mode="r",
                shape=(num_samples, num_classes),
            )
            label_log = np.memmap(
                f"{cache_dir}/label.mmap",
                dtype=float,
                mode="r",
                shape=(num_samples,),
            )

if OOD_RUN:
    # For each OOD dataset, either extract features and logits into fresh
    # memmaps under `cache_dir`, or (when the cache dir already exists and
    # FORCE_RUN is off) reopen the cached memmaps read-only.
    for ood_dataset in args.out_datasets:
        # NOTE(review): ("val") is the plain string "val", not a 1-tuple like
        # the ("train", "val") passed for the ID loaders. Confirm that
        # get_loader_out accepts a bare string here.
        loader_test_dict = get_loader_out(
            args, dataset=(None, ood_dataset), split=("val")
        )
        out_loader = loader_test_dict.val_ood_loader
        num_samples = len(out_loader.dataset)

        cache_dir = f"cache/{ood_dataset}vs{args.in_dataset}_{args.name}_out"
        if FORCE_RUN or not os.path.exists(cache_dir):
            os.makedirs(cache_dir, exist_ok=True)
            ood_feat_log = np.memmap(
                f"{cache_dir}/feat.mmap",
                dtype=float,
                mode="w+",
                shape=(num_samples, featdim),
            )
            ood_score_log = np.memmap(
                f"{cache_dir}/score.mmap",
                dtype=float,
                mode="w+",
                shape=(num_samples, num_classes),
            )
            model.eval()
            # Rows are addressed by a running offset over the actual batch
            # sizes instead of batch_idx * args.batch_size. This keeps the
            # memmaps aligned even if the loader's batch size differs from the
            # CLI value.
            offset = 0
            with torch.no_grad():
                for batch_idx, (inputs, _) in enumerate(out_loader):
                    inputs = inputs.to(device).float()
                    start_ind = offset
                    end_ind = offset + inputs.size(0)
                    offset = end_ind

                    # Branches that leave `score` as None get their logits from
                    # model.fc below, after any spatial pooling.
                    score = None
                    if args.model_arch == "resnet50-supcon":
                        out = model.encoder(inputs)
                    elif args.model_arch == "mobile":
                        out = model.features(inputs)
                        out = F.adaptive_avg_pool2d(out, 1)
                        out = out.view(out.size(0), -1)
                        score = model.classifier(out)
                    elif args.model_arch == "convnext":
                        out = model.forward_features(inputs)
                        out = F.adaptive_avg_pool2d(out, 1)
                        out = out.view(out.size(0), -1)
                        score = model(inputs)
                    elif args.model_arch == "swin":
                        # permute(0, 3, 1, 2) moves the last axis to dim 1
                        # (presumably channels-last -> channels-first) so the
                        # map can be average-pooled.
                        out = model.forward_features(inputs).permute(0, 3, 1, 2)
                        out = F.adaptive_avg_pool2d(out, 1)
                        out = out.view(out.size(0), -1)
                        score = model(inputs)
                    elif args.model_arch == "deit":
                        # Replicate the timm head step by step so the
                        # pre-classifier feature can be captured.
                        out = model.forward_features(inputs)
                        out = model.pool(out)
                        out = model.fc_norm(out)
                        out = model.head_drop(out)
                        score = model.head(out)
                    else:
                        out = model.features(inputs)

                    if out.dim() > 2:
                        # Pool spatial maps to (batch, channels) BEFORE the
                        # linear head. Applying model.fc to an unpooled 4-D
                        # map would operate on the wrong axis.
                        out = F.adaptive_avg_pool2d(out, 1)
                        out = out.view(out.size(0), -1)
                        score = model.fc(out)
                    elif score is None:
                        score = model.fc(out)

                    ood_feat_log[start_ind:end_ind, :] = out.cpu().numpy()
                    ood_score_log[start_ind:end_ind] = score.cpu().numpy()
                    if batch_idx % 100 == 0:
                        print(f"{batch_idx}/{len(out_loader)}")

            # Push everything written through the memmaps to disk now rather
            # than relying on the arrays being garbage-collected.
            for log in (ood_feat_log, ood_score_log):
                log.flush()

        else:
            ood_feat_log = np.memmap(
                f"{cache_dir}/feat.mmap",
                dtype=float,
                mode="r",
                shape=(num_samples, featdim),
            )
            ood_score_log = np.memmap(
                f"{cache_dir}/score.mmap",
                dtype=float,
                mode="r",
                shape=(num_samples, num_classes),
            )

# loader_test_dict = get_loader_out(args, dataset=(None, 'noise'), split=('val'))
# out_loader = loader_test_dict.val_ood_loader
#
# cache_dir = f"cache/{'noise'}vs{args.in_dataset}_{args.name}_out"
# if FORCE_RUN or not os.path.exists(cache_dir):
#     os.makedirs(cache_dir, exist_ok=True)
#     ood_feat_log = np.memmap(f"{cache_dir}/feat.mmap", dtype=float, mode='w+', shape=(len(out_loader.dataset), featdim))
#     ood_score_log = np.memmap(f"{cache_dir}/score.mmap", dtype=float, mode='w+', shape=(len(out_loader.dataset), num_classes))
#     model.eval()
#     for batch_idx, (inputs, _) in enumerate(out_loader):
#         inputs = inputs.to(device).float()
#         start_ind = batch_idx * batch_size
#         end_ind = min((batch_idx + 1) * batch_size, len(out_loader.dataset))
#
#         if args.model_arch == 'resnet50-supcon':
#             out = model.encoder(inputs)
#         else:
#             out = model.features(inputs)
#         if len(out.shape) > 2:
#             out = F.adaptive_avg_pool2d(out, 1)
#             out = out.view(out.size(0), -1)
#         score = model.fc(out)
#         # score = net(inputs)
#         ood_feat_log[start_ind:end_ind, :] = out.data.cpu().numpy()
#         ood_score_log[start_ind:end_ind] = score.data.cpu().numpy()
#         if batch_idx % 100 == 0:
#             print(f"{batch_idx}/{len(out_loader)}")
#
# else:
#     ood_feat_log = np.memmap(f"{cache_dir}/feat.mmap", dtype=float, mode='r', shape=(len(out_loader.dataset), featdim))
#     ood_score_log = np.memmap(f"{cache_dir}/score.mmap", dtype=float, mode='r', shape=(len(out_loader.dataset), num_classes))
