#!/usr/bin/env python3

import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
import os
from argparse import ArgumentParser

# Command-line interface. The explanatory text must go in `description=`:
# the first positional parameter of ArgumentParser is `prog` (the program
# name shown in the usage line), not the help description.
parser = ArgumentParser(
    description="Determine the images that give the top 10 activations for a given channel of a given layer"
)
parser.add_argument("-i", "--input", help="Input directory of activations and indices", default="ceph/imagenetactiv")
parser.add_argument("-o", "--output", help="Output directory", default="ceph/imagenetactivpostprocessed")
args = parser.parse_args()

# Pick the compute device for the top-k selection: the first CUDA GPU when
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Create the output directory (including missing parents); an existing
# directory is left untouched.
output_dir = Path(args.output)
output_dir.mkdir(exist_ok=True, parents=True)

# Determine the layer names from the "<layer>.activations.npy" files in the
# input directory. Everything else in the directory (e.g. a "dataset" entry,
# the matching "<layer>.indices.npy" files, stray notes) is ignored, so a
# missing "dataset" entry or an unrelated file no longer aborts the script.
# Stripping the full suffix (instead of splitting on the first ".") keeps
# layer names that themselves contain dots intact.
ACTIVATIONS_SUFFIX = ".activations.npy"
layer_names = sorted(
    filename[: -len(ACTIVATIONS_SUFFIX)]
    for filename in os.listdir(args.input)
    if filename.endswith(ACTIVATIONS_SUFFIX)
)
print("The layer names: ", layer_names)

# Load, for every layer, the saved activation matrix and the companion index
# matrix from the input directory. All activation files are read first,
# then all index files.
activations = {}
for layer_name in layer_names:
    activation_path = os.path.join(args.input, f"{layer_name}.activations.npy")
    activations[layer_name] = np.load(activation_path)

indices = {}
for layer_name in layer_names:
    index_path = os.path.join(args.input, f"{layer_name}.indices.npy")
    indices[layer_name] = np.load(index_path)

# Number of strongest activations kept per channel.
TOP_K = 10

# For every layer, find the TOP_K largest activations of each channel over all
# samples and spatial locations, and save:
#   <layer>.activations.npy : the values, shape (k, C), sorted descending
#   <layer>.indices.npy     : the entry of the index matrix at the same
#                             locations, shape (k, C)
#   <layer>.positions.npy   : the (row, col) of each hit, shape (k, 2, C)
for layer in tqdm(layer_names):
    layer_activations = activations[layer]
    layer_indices = indices[layer]
    # The code treats axis 1 as the channel axis and axes 2/3 as the spatial
    # grid, i.e. it assumes an (N, C, H, W) layout -- TODO confirm against
    # the script that produced these files.
    num_samples, num_channels, height, width = layer_activations.shape

    # Move channels last and flatten, so column c holds every value of
    # channel c and flat row r corresponds to (n, h, w) in C order.
    channel_activations = layer_activations.transpose(0, 2, 3, 1).reshape(-1, num_channels)
    # Presumably the index matrix records which dataset image each activation
    # came from; it is assumed to share the activations' shape -- verify.
    channel_indices = layer_indices.transpose(0, 2, 3, 1).reshape(-1, num_channels)

    # torch.topk raises if k exceeds the number of rows, so clamp it.
    k = min(TOP_K, channel_activations.shape[0])

    # Per-channel (column-wise) top-k on the compute device, sorted descending.
    top_values, top_rows = torch.topk(
        torch.from_numpy(channel_activations).to(device), k, dim=0
    )
    top_values = top_values.cpu().numpy()
    top_rows = top_rows.cpu().numpy()

    # Pick the matching entries of the index matrix. Done in NumPy so the full
    # index matrix is never copied to the device just to read k*C entries.
    top_indices = np.take_along_axis(channel_indices, top_rows, axis=0)

    # Recover the spatial location (row, col) of each hit from its flat row.
    _, top_h, top_w = np.unravel_index(top_rows, (num_samples, height, width))
    top_positions = np.stack([top_h, top_w], axis=1)

    np.save(os.path.join(args.output, f"{layer}.activations.npy"), top_values)
    np.save(os.path.join(args.output, f"{layer}.indices.npy"), top_indices)
    np.save(os.path.join(args.output, f"{layer}.positions.npy"), top_positions)
