import argparse
import gc
import json
import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm

from torchvision.models import (
    alexnet, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn,
    vgg19,vgg19_bn, densenet121, densenet161, densenet169, densenet201,
    resnet18, resnet34, resnet50, resnet101, resnet152,
    convnext_base, convnext_large,
    swin_t, swin_b, swin_v2_t, swin_v2_b,
    regnet_y_128gf, regnet_y_32gf,
    vit_b_16, vit_h_14, vit_l_16
)
from torchvision.models import (
    AlexNet_Weights, VGG11_Weights, VGG11_BN_Weights,
    VGG13_Weights, VGG13_BN_Weights, VGG16_Weights, VGG16_BN_Weights,VGG19_BN_Weights,
    VGG19_Weights, DenseNet121_Weights, DenseNet161_Weights,
    DenseNet169_Weights, DenseNet201_Weights,
    ResNet18_Weights, ResNet34_Weights, ResNet50_Weights,
    ResNet101_Weights, ResNet152_Weights,
    ConvNeXt_Base_Weights, ConvNeXt_Large_Weights,
    Swin_T_Weights, Swin_B_Weights, Swin_V2_T_Weights, Swin_V2_B_Weights,
    RegNet_Y_128GF_Weights, RegNet_Y_32GF_Weights,          
    ViT_B_16_Weights, ViT_H_14_Weights, ViT_L_16_Weights    
)


class ImagenetDataset(ImageFolder):
    """``ImageFolder`` variant that yields ``(filename, sample, class_id)``.

    Instead of the integer target that ``ImageFolder`` normally returns, each
    item carries the name of the image's parent directory (``class_id``) and a
    per-image identifier derived from the file name.
    """

    # Prefix stripped from file names when present; the remaining stem is
    # used as the per-image identifier.
    _VAL_PREFIX = "ILSVRC2012_val_"

    def __getitem__(self, index):
        """Load one image and return its identifier, tensor and class id.

        Args:
            index: Position of the sample in ``self.samples``.

        Returns:
            A ``(filename, sample, class_id)`` tuple where ``filename`` is
            ``"<class_id>_<stem>"``, ``sample`` is the loaded (and, if a
            transform is set, transformed) image, and ``class_id`` is the
            name of the directory that contains the image file.
        """
        path, _ = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)

        # os.path handles the platform separator; splitting on '/' breaks on
        # Windows, where ImageFolder joins paths with backslashes.
        class_id = os.path.basename(os.path.dirname(path))
        stem = os.path.splitext(os.path.basename(path))[0]

        # Strip the prefix only when it is actually there. A fixed-width slice
        # collapses shorter names to an empty string, which makes every image
        # of a class share the same identifier.
        if stem.startswith(self._VAL_PREFIX):
            stem = stem[len(self._VAL_PREFIX):]

        filename = class_id + "_" + stem
        return filename, sample, class_id


def _build_model_registry():
    """Return the mapping of CLI model names to ``(constructor, weights)`` pairs."""
    return {
        "alexnet": (alexnet, AlexNet_Weights.IMAGENET1K_V1),
        "vgg11": (vgg11, VGG11_Weights.IMAGENET1K_V1),
        "vgg11_bn": (vgg11_bn, VGG11_BN_Weights.IMAGENET1K_V1),
        "vgg13": (vgg13, VGG13_Weights.IMAGENET1K_V1),
        "vgg13_bn": (vgg13_bn, VGG13_BN_Weights.IMAGENET1K_V1),
        "vgg16": (vgg16, VGG16_Weights.IMAGENET1K_V1),
        "vgg16_bn": (vgg16_bn, VGG16_BN_Weights.IMAGENET1K_V1),
        "vgg19": (vgg19, VGG19_Weights.IMAGENET1K_V1),
        "vgg19_bn": (vgg19_bn, VGG19_BN_Weights.IMAGENET1K_V1),
        "densenet121": (densenet121, DenseNet121_Weights.IMAGENET1K_V1),
        "densenet161": (densenet161, DenseNet161_Weights.IMAGENET1K_V1),
        "densenet169": (densenet169, DenseNet169_Weights.IMAGENET1K_V1),
        "densenet201": (densenet201, DenseNet201_Weights.IMAGENET1K_V1),
        "resnet18": (resnet18, ResNet18_Weights.IMAGENET1K_V1),
        "resnet34": (resnet34, ResNet34_Weights.IMAGENET1K_V1),
        "resnet50": (resnet50, ResNet50_Weights.IMAGENET1K_V1),
        "resnet101": (resnet101, ResNet101_Weights.IMAGENET1K_V1),
        "resnet152": (resnet152, ResNet152_Weights.IMAGENET1K_V1),
        # ResNet with the V2 weight recipe
        "resnet50v2": (resnet50, ResNet50_Weights.IMAGENET1K_V2),
        "resnet101v2": (resnet101, ResNet101_Weights.IMAGENET1K_V2),
        "resnet152v2": (resnet152, ResNet152_Weights.IMAGENET1K_V2),
        # ConvNeXt
        "convnext_base": (convnext_base, ConvNeXt_Base_Weights.IMAGENET1K_V1),
        "convnext_large": (convnext_large, ConvNeXt_Large_Weights.IMAGENET1K_V1),
        # Swin Transformer
        "swin_t": (swin_t, Swin_T_Weights.IMAGENET1K_V1),
        "swin_b": (swin_b, Swin_B_Weights.IMAGENET1K_V1),
        "swin_t_v2": (swin_v2_t, Swin_V2_T_Weights.IMAGENET1K_V1),
        "swin_b_v2": (swin_v2_b, Swin_V2_B_Weights.IMAGENET1K_V1),
        # RegNet
        "regnet_y_128gf_e2e": (regnet_y_128gf, RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_E2E_V1),
        "regnet_y_128gf_linear": (regnet_y_128gf, RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_LINEAR_V1),
        "regnet_y_32gf_e2e": (regnet_y_32gf, RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1),
        "regnet_y_32gf_linear": (regnet_y_32gf, RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_LINEAR_V1),
        "regnet_y_32gf_v2": (regnet_y_32gf, RegNet_Y_32GF_Weights.IMAGENET1K_V2),
        # ViT
        "vit_b_16_linear": (vit_b_16, ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1),
        "vit_b_16_v1": (vit_b_16, ViT_B_16_Weights.IMAGENET1K_V1),
        "vit_h_14_linear": (vit_h_14, ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1),
        "vit_l_16_linear": (vit_l_16, ViT_L_16_Weights.IMAGENET1K_SWAG_LINEAR_V1),
        "vit_l_16_v1": (vit_l_16, ViT_L_16_Weights.IMAGENET1K_V1),
    }


def _run_inference(model, dataloader, class_map, device, num_classes=1000):
    """Run ``model`` over ``dataloader`` and collect per-image results.

    Args:
        model: Module already moved to ``device`` and switched to eval mode.
        dataloader: Yields ``(names, images, class_ids)`` batches in a fixed
            order (no shuffling), so rows can be written sequentially.
        class_map: Maps each class id string to its integer label index.
        device: Device string the batches are moved to.
        num_classes: Width of the probability matrix.

    Returns:
        ``(probs, img_names, accuracies)`` where ``probs`` is a float16 CPU
        tensor of shape ``(len(dataset), num_classes)``, ``img_names`` lists
        the image identifiers in row order, and ``accuracies`` holds 1 for a
        correct top-1 prediction and 0 otherwise.

    Raises:
        KeyError: If a class id from the dataset is missing in ``class_map``.
    """
    probs = torch.zeros((len(dataloader.dataset), num_classes), dtype=torch.float16)
    img_names = []
    accuracies = []
    offset = 0
    with torch.no_grad():
        for names, images, class_ids in tqdm(dataloader, total=len(dataloader)):
            # Build the label tensor directly on the target device.
            targets = torch.tensor([class_map[c] for c in class_ids], device=device)
            images = images.to(device)
            batch_probs = F.softmax(model(images), dim=1)
            preds = batch_probs.argmax(dim=1)

            # Per-image accuracy: 1 if the top-1 prediction is correct, else 0.
            accuracies.extend((preds == targets).int().tolist())

            # A running offset keeps rows aligned even for a short last batch.
            end = offset + batch_probs.size(0)
            probs[offset:end] = batch_probs.cpu()
            offset = end
            img_names.extend(names)
    return probs, img_names, accuracies


def main(argv=None):
    """Run a pretrained torchvision classifier over an image folder.

    Reads ``class_list.json`` from the working directory to map class ids to
    label indices, evaluates the selected model on every image under
    ``--data-folder`` and saves probabilities, image names and per-image
    accuracies to ``<model>_imagenet_<split>.pth``.

    Args:
        argv: Optional list of command-line arguments; ``None`` makes
            argparse read ``sys.argv``.
    """
    model_registry = _build_model_registry()

    parser = argparse.ArgumentParser()
    # Choices come from the registry so the two lists cannot drift apart.
    parser.add_argument(
        "--model", type=str, required=True, choices=list(model_registry)
    )
    parser.add_argument("--data-folder", type=str, required=True)
    # NOTE: --split only labels the output file; the images themselves are
    # taken from --data-folder.
    parser.add_argument("--split", type=str, default="val", choices=["train", "val"])
    parser.add_argument("--batch-size", type=int, default=50)
    parser.add_argument("--num-workers", type=int, default=2)
    args = parser.parse_args(argv)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Mapping of class_id to idx
    with open("class_list.json", "r", encoding="utf-8") as f:
        class_list = json.load(f)
    class_map = {c: i for i, c in enumerate(class_list)}

    # Load model together with the preprocessing its weights expect
    model_func, weights = model_registry[args.model]
    model = model_func(weights=weights)
    transform = weights.transforms()

    # Load data
    dataset = ImagenetDataset(args.data_folder, transform)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    # Run inference
    model = model.to(device)
    model.eval()
    probs, img_names, accuracies = _run_inference(model, dataloader, class_map, device)

    # Save output
    output_filename = f"{args.model}_imagenet_{args.split}.pth"
    torch.save({"probs": probs, "img_names": img_names, "accuracies": accuracies}, output_filename)

    # Clean up: drop the model before emptying the CUDA cache so its memory
    # can actually be released.
    del model, probs
    gc.collect()
    torch.cuda.empty_cache()


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
