import torch
import os
import torch.nn.functional as F

from tqdm import tqdm
from torchvision.io import read_image, ImageReadMode
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision.models import *

import warnings
warnings.filterwarnings('ignore')


class ImagenetTrainClassDataset(Dataset):
    """Images of a single ImageNet class, read from the ``train`` split directory.

    ``path`` must point at the ``train`` directory whose sub-directories are the
    class folders. ``class_id`` indexes those folders in sorted-name order. Items
    are the images of that class in sorted file-name order, decoded as RGB and
    passed through ``transform``.

    Attributes:
        class_name: name of the selected class folder.
        class_path: full path of the selected class folder.
        img_names: sorted file names inside ``class_path``.
        transform: callable applied to every decoded image.

    Raises:
        ValueError: if ``path`` does not name a ``train`` directory.
        IndexError: if ``class_id`` is outside ``[0, number_of_class_folders)``.
    """

    def __init__(self, path: str, class_id: int, transform):
        # normpath strips a trailing separator, so '.../train/' is accepted too.
        if os.path.basename(os.path.normpath(path)) != 'train':
            raise ValueError(f"expected a path to the 'train' split directory, got {path!r}")
        super().__init__()
        # Only directories are classes: a stray file would otherwise take a
        # class index and shift every class after it.
        with os.scandir(path) as entries:
            class_names = sorted(entry.name for entry in entries if entry.is_dir())
        # Reject negative ids explicitly instead of letting list indexing wrap around.
        if not 0 <= class_id < len(class_names):
            raise IndexError(f"class_id {class_id} out of range for {len(class_names)} classes")
        self.class_name = class_names[class_id]
        self.class_path = os.path.join(path, self.class_name)

        self.img_names = sorted(os.listdir(self.class_path))
        self.transform = transform

    def __getitem__(self, idx):
        img_path = os.path.join(self.class_path, self.img_names[idx])
        # Decode as 3-channel RGB so single-channel images match the rest of the batch.
        image = read_image(img_path, ImageReadMode.RGB)
        return self.transform(image)

    def __len__(self):
        return len(self.img_names)


def model_generator():
    """Yield ``(name, (model, transform))`` for every enabled pretrained model.

    Each model is constructed only when the generator reaches it. Only one model
    is alive per iteration, and pretrained weights are only fetched (if not
    cached) for models that are actually reached. ``transform`` is the
    preprocessing pipeline bundled with the weights (``weights.transforms()``).
    Uncomment entries in ``model_specs`` to evaluate more models.
    """
    # name -> (torchvision model builder, pretrained weights enum member)
    model_specs = {
        'resnet18_v1': (resnet18, ResNet18_Weights.IMAGENET1K_V1),
        # 'resnet34_v1': (resnet34, ResNet34_Weights.IMAGENET1K_V1),
        # 'resnet50_v1': (resnet50, ResNet50_Weights.IMAGENET1K_V1),
        # 'resnet101_v1': (resnet101, ResNet101_Weights.IMAGENET1K_V1),
        # 'resnet152_v1': (resnet152, ResNet152_Weights.IMAGENET1K_V1),
        # 'resnet50_v2': (resnet50, ResNet50_Weights.IMAGENET1K_V2),
        # 'resnet101_v2': (resnet101, ResNet101_Weights.IMAGENET1K_V2),
        # 'resnet152_v2': (resnet152, ResNet152_Weights.IMAGENET1K_V2),
        # 'vgg13': (vgg13, VGG13_Weights.IMAGENET1K_V1),
        # 'vgg13_bn': (vgg13_bn, VGG13_BN_Weights.IMAGENET1K_V1),
        # 'vgg19': (vgg19, VGG19_Weights.IMAGENET1K_V1),
        # 'vgg19_bn': (vgg19_bn, VGG19_BN_Weights.IMAGENET1K_V1),
        # 'densenet121': (densenet121, DenseNet121_Weights.IMAGENET1K_V1),
        # 'densenet161': (densenet161, DenseNet161_Weights.IMAGENET1K_V1),
        # 'densenet169': (densenet169, DenseNet169_Weights.IMAGENET1K_V1),
        # 'densenet201': (densenet201, DenseNet201_Weights.IMAGENET1K_V1),
        # 'swin_t': (swin_t, Swin_T_Weights.IMAGENET1K_V1),
        # 'swin_b': (swin_b, Swin_B_Weights.IMAGENET1K_V1),
        # 'swin_v2_t': (swin_v2_t, Swin_V2_T_Weights.IMAGENET1K_V1),
        # 'swin_v2_b': (swin_v2_b, Swin_V2_B_Weights.IMAGENET1K_V1),
    }
    for name, (builder, weights) in model_specs.items():
        # torchvision builders declare `weights` keyword-only; passing it
        # positionally goes through a deprecated compatibility shim.
        yield name, (builder(weights=weights), weights.transforms())


def _build_chunk_dataset(image_dir, class_ids, transform):
    """Concatenate the train datasets of ``class_ids``; also return their image ids.

    The image ids are the file names without extension, in the same order in
    which the returned ``ConcatDataset`` yields the images.
    """
    subsets = [
        ImagenetTrainClassDataset(image_dir + '/train', class_id=class_id, transform=transform)
        for class_id in class_ids
    ]
    img_ids = [
        os.path.splitext(img_name)[0]  # drop the .JPEG extension
        for subset in subsets
        for img_name in subset.img_names
    ]
    return ConcatDataset(subsets), img_ids


def _predict_probs(model, dataset, device, batch_size, num_workers, num_outputs, desc):
    """Run ``model`` over ``dataset`` in order; return softmax probabilities as a float16 CPU tensor."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    # Store float16 directly: casting each batch gives the same values as casting
    # the full float32 tensor afterwards, at half the peak host memory.
    all_probs = torch.empty((len(dataset), num_outputs), dtype=torch.float16)
    offset = 0
    with torch.no_grad():
        for images in tqdm(loader, desc=desc):
            logits = model(images.to(device))
            probs = F.softmax(logits, dim=1).half().cpu()
            all_probs[offset:offset + probs.size(0)] = probs
            offset += probs.size(0)
    return all_probs


def forward(image_dir, output_path, device=None, num_classes=1000, chunk_size=500,
            batch_size=100, num_workers=2):
    """Save softmax outputs of every model from ``model_generator`` on the ImageNet train split.

    Classes are processed in chunks of ``chunk_size``, which bounds host memory
    use. For each (model, chunk) a file ``{output_path}/{model_name}_train_{chunk_id}.pth``
    is written, containing a dict with:

    * ``'probs'``: float16 tensor of shape ``(num_images_in_chunk, num_classes)``;
    * ``'img_names'``: image file names without extension, row-aligned with ``'probs'``.

    Args:
        image_dir: dataset root containing the ``train`` directory.
        output_path: directory for the output files (created if missing).
        device: device to run the models on. Defaults to ``'cuda'`` when
            available, otherwise ``'cpu'``.
        num_classes: number of classes to process (and model output width).
        chunk_size: number of classes per output file. The last chunk may be smaller.
        batch_size: inference batch size.
        num_workers: ``DataLoader`` worker processes.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(output_path, exist_ok=True)
    # Ceiling division so a trailing partial chunk is not silently dropped.
    num_chunks = (num_classes + chunk_size - 1) // chunk_size

    for name, (model, transform) in model_generator():
        # Modules start in training mode. eval() makes BatchNorm use its running
        # statistics (instead of batch statistics, which it would also update)
        # and disables dropout. The weights must live on the same device as the inputs.
        model = model.to(device).eval()

        for chunk_id in range(num_chunks):
            first_class = chunk_id * chunk_size
            class_ids = range(first_class, min(first_class + chunk_size, num_classes))
            dataset, img_ids = _build_chunk_dataset(image_dir, class_ids, transform)
            probs = _predict_probs(model, dataset, device, batch_size, num_workers,
                                   num_classes, desc=f'{name} chunk {chunk_id}')
            output = {
                'probs': probs,
                'img_names': img_ids,
            }
            torch.save(output, os.path.join(output_path, f'{name}_train_{chunk_id}.pth'))

        # Release device memory in case the generator still holds a reference to this model.
        model.cpu()


if __name__ == '__main__':
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise SystemExit("cuda is not available")
    # Module-level name: forward() may read this global to pick the device.
    device = 'cuda'
    image_dir = "imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC"  # kaggle format
    if not os.path.exists(image_dir):
        raise SystemExit("ImageNet dataset is not available")
    forward(image_dir, output_path='imagenet_model_outputs')

