#!/usr/bin/env python3
if __name__ == "__main__":
    # Announce the script version, but only when run directly (not when imported).
    version_number = '0.1.3'
    print(f'\nget_imagenet_activations_version_number: {version_number}\n')

from argparse import ArgumentParser
import torch
from torchvision import datasets, models, transforms
from pathlib import Path
from tqdm import tqdm
import sys

from scripts.utils import *
# version = '1.0.1'
# print("VERSION: ", version)
# Fail fast without a GPU. An explicit exit is used instead of `assert`,
# because asserts are stripped when Python runs with -O.
if not torch.cuda.is_available():
    sys.exit("This script isn't nearly fast enough on CPU")

# Pass the text as `description`: the first positional argument of
# ArgumentParser is `prog` (the program name shown in the usage line).
parser = ArgumentParser(description="Determines the images that correspond to the top N activations for each pixel of each channel of each layer of vgg19, run over ImageNet")
parser.add_argument("-f", "--folder", help="Folder to save activations to", default="ceph/imagenetactiv")
# type=int: N is used as a count (the top-k size); without it `-N 5` would arrive as the string "5".
parser.add_argument("-N", help="Number of top activations to keep for each pixel of each channel of each layer", type=int, default=10)
parser.add_argument("-b", "--batch", help="Batch size for images to run through vgg19, defaults to N", type=int, default=None)
parser.add_argument("-w", "--workers", help="Number of workers for torch DataLoader object", type=int, default=1)
parser.add_argument("-s", "--img-size", help="Side length of the (square) images fed through vgg19", type=int, default=224)
parser.add_argument("-d", "--data-dir", help="Path of ImageNet root directory", default="/data/imagenet_data")
parser.add_argument("--save-step", help="Number of batches after which to save the current state", type=int, default=1000)
parser.add_argument('-m', '--model', type=str, default='vgg19', help="'vgg19' (default); any other value selects AlexNet")
args = parser.parse_args()

# --batch falls back to N when not given.
if args.batch is None:
    args.batch = args.N

# Reject sizes that would make the top-k bookkeeping or the DataLoader fail later.
if args.N < 1:
    parser.error("-N must be a positive integer")
if args.batch < 1:
    parser.error("--batch must be a positive integer")

# Ensure that args.folder exists
Path(args.folder).mkdir(parents=True, exist_ok=True)


class ImageFolderWithIndices(datasets.ImageFolder):
    """ImageFolder whose items also carry their own dataset index.

    Each item is ``(img, label, index)`` rather than ``(img, label)``, so every
    sample in a batch can be traced back to its position in the dataset.
    """

    def __getitem__(self, index):
        image, target = super().__getitem__(index)
        return image, target, index


# Prefer the first CUDA device; CPU is only a fallback.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Any --model value other than 'vgg19' falls through to AlexNet.
if args.model != 'vgg19':
    print('using AlexNet!')
    model = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)
else:
    print('Using VGG19!')
    model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)

# nn.Module.to() moves the module in place and returns it. eval() switches
# layers such as dropout to inference behaviour.
model = model.to(device)
model.eval()


# Per-channel mean/std normalization expected by torchvision's ImageNet-pretrained models.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

traindir = str(Path(args.data_dir) / "train")

# Honour --img-size (previously parsed but ignored). The resize keeps the usual
# 256:224 resize-to-crop ratio, so the default size of 224 still gives
# Resize(256) + CenterCrop(224) exactly as before.
resize_size = args.img_size * 256 // 224

train_dataset = ImageFolderWithIndices(
    traindir,
    transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(args.img_size),
        transforms.ToTensor(),
        normalize,
    ]))

# No shuffling is requested, so batches follow dataset order; each item also
# carries its own dataset index.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch,
    num_workers=args.workers,
    pin_memory=torch.cuda.is_available(),  # page-locked host memory speeds up copies to the GPU
)

# Save the dataset's (image path, class index) list so stored image indices can be mapped back to files.
torch.save(train_dataset.imgs, f"{args.folder}/dataset")

# `activations` is filled when the model runs (presumably via forward hooks set up
# in scripts.utils; verify there). `layers` maps layer names to modules.
activations = get_model_activations(model)
layers = get_model_layers(model)
# None means the layer has no channel count and is skipped by the top-N tracking.
n_channels = {
    layer_name: determine_output_channels(module)
    for layer_name, module in layers.items()
}

# Per-layer running state, keyed like `activations` and initially empty (None).
layer_ones = dict.fromkeys(activations)
top100activations = dict.fromkeys(activations)
top100indices = dict.fromkeys(activations)

#@profile
def updateTop100(layer, activ, indices, device):
    """Merge one batch of activations into the running top-N for ``layer``.

    After the call, ``top100activations[layer]`` holds along dim 0 the largest
    activations seen so far at every remaining position, sorted in descending
    order. ``top100indices[layer]`` holds, at the same positions, the dataset
    index of the image that produced each value. Despite the historical name,
    N is ``args.N`` (at most; fewer if fewer images have been seen).

    Args:
        layer: key into the module-level ``top100activations`` / ``top100indices`` dicts.
        activ: this batch's activations; dim 0 is the batch dimension, and any
            number of trailing dimensions is supported.
        indices: 1-D integer tensor of dataset indices, one per batch item.
        device: unused; kept so existing callers keep working.
    """
    n_keep = int(args.N)

    # Detach so the stored tensors do not keep autograd graphs from earlier batches alive.
    activ = activ.detach()

    # Broadcast each image's dataset index over all remaining dims of its activations
    # (a view, no allocation), for activations of any rank.
    index_shape = (-1,) + (1,) * (activ.dim() - 1)
    batch_index_tensor = indices.to(activ.device).reshape(index_shape).expand_as(activ)

    if top100activations[layer] is None:
        candidates = activ
        candidate_indices = batch_index_tensor
    else:
        # Stack the stored top-N with this batch, then reselect the best N.
        candidates = torch.cat([top100activations[layer], activ], dim=0)
        candidate_indices = torch.cat([top100indices[layer], batch_index_tensor], dim=0)

    # topk raises if k exceeds the number of candidates (e.g. early batches smaller than N).
    k = min(n_keep, candidates.shape[0])
    top_values, top_positions = torch.topk(candidates, k, dim=0)

    # Move the image indices the same way the activations were reordered.
    top100activations[layer] = top_values
    top100indices[layer] = torch.gather(candidate_indices, dim=0, index=top_positions)

def save_data():
    """Write the current top-N activations and image indices of each tracked layer.

    For every layer whose channel count is known, two files are written to
    ``args.folder``: ``<layer>.activations.npy`` and ``<layer>.indices.npy``.
    ``np.save`` appends the ``.npy`` suffix. Layers that have not received any
    data yet (e.g. when the loader produced no batches) are skipped instead of
    crashing on ``None``.
    """
    for layer in activations.keys():
        if n_channels[layer] is None:
            continue

        layer_activations = top100activations[layer]
        layer_indices = top100indices[layer]
        if layer_activations is None or layer_indices is None:
            continue  # nothing recorded for this layer yet

        # Extract data from torch
        activations_data = layer_activations.detach().cpu().numpy()
        index_data = layer_indices.detach().cpu().numpy()

        # Save to file
        np.save(f"{args.folder}/{layer}.activations", activations_data)
        np.save(f"{args.folder}/{layer}.indices", index_data)


# Main loop: one pass over the dataset, folding every batch into the running top-N.
for batch_number, (batch, _, indices) in enumerate(tqdm(train_loader)):

    # Send the images and their dataset indices to the model's device, where the
    # top-N bookkeeping also runs.
    batch = batch.to(device, non_blocking=True)
    indices = indices.to(device, non_blocking=True)

    # Inference only. no_grad stops autograd from recording graphs for the
    # activations we keep; otherwise those graphs could be retained across batches.
    with torch.no_grad():
        # Running the network (activations sent to activations dict)
        model(batch)

        for layer, activ in activations.items():
            # Layers without a known channel count are not tracked.
            if n_channels[layer] is None:
                continue
            updateTop100(layer, activ, indices, device)

    # Periodic checkpoint (including after the first batch). A non-positive
    # save_step disables it instead of raising ZeroDivisionError.
    if args.save_step > 0 and batch_number % args.save_step == 0:
        save_data()

# Save the final results
save_data()
