from argparse import ArgumentParser
from torchvision import models
import utils
from tqdm import tqdm
import torch
import os.path as path

# This code goes over the entire ImageNet training set and saves each channel of the target layer's activations
# for each of the intermediate training models.
# It can also be used to get only the first and last versions of the model.
#

def get_activations_by_channel(steps, layers, get_atk_act, output_folder, results_directory, feature_name, device):
    """Compute and cache the hooked activations for every checkpoint in ``steps``.

    For each step, ``model_checkpoint_step_{step}.pt`` is loaded from
    ``results_directory`` into a fresh torchvision AlexNet, every batch of
    ``utils.get_topk_dataset_loader()`` is run through it, and the
    concatenated result is saved to
    ``{output_folder}/model_checkpoint_activations_step_{step}.pt``.
    Steps whose activation file already exists are skipped.

    Args:
        steps: Checkpoint step numbers to process. Every step given is
            processed; callers that only want the first and last checkpoint
            must pass only those two.
        layers: Layer selection forwarded to ``utils.register_hooks``.
        get_atk_act: Callable applied to the hooked activations of one batch.
            Per the original comment it returns an (image, channels) tensor
            -- TODO confirm against ``utils.get_attack_activations_function``.
        output_folder: Directory the activation files are written to.
        results_directory: Directory holding the model checkpoints.
        feature_name: Key under which the hook stores activations in the
            dict handed to ``utils.register_hooks``.
        device: Device the model and the input batches are moved to.
    """
    utils.ensure_dir(output_folder)
    print(f'getting activations for models at steps: {steps}')
    for step in steps:
        activations_save_path = path.join(output_folder, f"model_checkpoint_activations_step_{step}.pt")
        if path.exists(activations_save_path):
            print(f'{activations_save_path} already exists!')
            continue

        print(f'{activations_save_path} not found, calculating activations')
        checkpoint_path = path.join(results_directory, f"model_checkpoint_step_{step}.pt")
        loaded_model = models.alexnet()
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
        loaded_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        loaded_model.eval()
        loaded_model = loaded_model.to(device)

        # The hooks write the most recent forward pass into this dict under `feature_name`.
        activations_dict = {}
        model_layers = utils.get_model_layers(loaded_model)
        utils.register_hooks(model_layers, activations_dict, layers, feature_name)

        activations_list = []
        data_loader = utils.get_topk_dataset_loader()
        with torch.no_grad():
            for data, _, _ in tqdm(data_loader):
                loaded_model(data.to(device))
                activations = activations_dict[feature_name]

                # Returns an (image, channels) shaped tensor (per the original comment).
                channel_activations = get_atk_act(activations)

                # Transpose so batches can be concatenated along the second axis.
                activations_list.append(channel_activations.T)
            activations_tensor = torch.cat(activations_list, dim=1)

        print(f'Final activations shape for model at step {step}: {activations_tensor.shape}')
        torch.save(activations_tensor, activations_save_path)
        print(f'Activations for model_{step} saved!')


def sort_activations_by_channel(steps, channels, output_folder, num_indices_to_track=10):
    """Rank the images per channel at every step and track reference images.

    For each channel, the activation file of every step in ``steps`` is
    loaded from ``output_folder`` and the image indices are ordered from
    highest to lowest activation. The first, middle and last
    ``num_indices_to_track`` indices of the step-0 ordering are used as
    reference images, and ``utils.track_images`` is called with the stacked
    per-step orderings and each reference set.

    Args:
        steps: Checkpoint step numbers whose activation files exist in
            ``output_folder``. Must contain step 0.
        channels: Channel indices (rows of the saved activation tensor).
        output_folder: Directory holding the
            ``model_checkpoint_activations_step_{step}.pt`` files.
        num_indices_to_track: Size of each reference group.

    Returns:
        dict with the keys ``channel_{c}_top_positions``,
        ``channel_{c}_mid_positions`` and ``channel_{c}_bot_positions`` for
        every channel ``c`` in ``channels`` (empty when ``channels`` is
        empty).

    Raises:
        ValueError: If step 0 is not part of ``steps``.
    """
    results_dict = {}
    for channel in tqdm(channels):
        channel_indices = []
        top_samples = mid_samples = bot_samples = None
        for step in steps:
            with torch.no_grad():
                activations_save_path = path.join(output_folder, f"model_checkpoint_activations_step_{step}.pt")
                activations = torch.load(activations_save_path)[channel]
                # k == number of images, so this is a full descending ordering.
                _, top_indices = torch.topk(activations, k=activations.shape[0])
                channel_indices.append(top_indices)
                # The reference images are chosen from the step-0 ordering.
                if step == 0:
                    middle = top_indices.shape[0] // 2
                    top_samples = top_indices[0:num_indices_to_track]
                    mid_samples = top_indices[middle:middle + num_indices_to_track]
                    bot_samples = top_indices[-num_indices_to_track:]

        if top_samples is None:
            raise ValueError("steps must include step 0, which supplies the reference images to track")

        channel_indices = torch.vstack(channel_indices)
        results_dict[f'channel_{channel}_top_positions'] = utils.track_images(channel_indices, top_samples)
        results_dict[f'channel_{channel}_mid_positions'] = utils.track_images(channel_indices, mid_samples)
        results_dict[f'channel_{channel}_bot_positions'] = utils.track_images(channel_indices, bot_samples)
    # Return only after every channel has been processed.
    return results_dict


def _read_configuration(config_path):
    """Parse the ``key: value`` lines of ``config_path`` into a dict of stripped strings."""
    config_dict = {}
    with open(config_path) as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at the end of the file).
                continue
            # Split on the first colon only so values may themselves contain ':'.
            parts = line.split(':', 1)
            if len(parts) != 2:
                raise ValueError(f"malformed configuration line (expected 'key: value'): {line!r}")
            key, val = parts
            config_dict[key.strip()] = val.strip()
    return config_dict


def get_and_sort_activations(args, do_sorting=True):
    """Compute the checkpoint activations for a run and optionally sort them.

    Reads ``configuration.txt`` from ``args.results_directory`` (the file
    the do_optimization script generates, per the original comment) to
    obtain the channels, layers, attack objective and the saved step
    numbers, then calls ``get_activations_by_channel`` with
    ``<results_directory>/results`` as the output folder.

    Args:
        args: Parsed arguments providing ``results_directory``,
            ``track_intermediate_steps`` and ``num_indices_to_track``.
        do_sorting: When true, also run ``sort_activations_by_channel`` and
            return its result.

    Returns:
        The dict returned by ``sort_activations_by_channel`` when
        ``do_sorting`` is true, otherwise ``None``.
    """
    config_dict = _read_configuration(path.join(args.results_directory, "configuration.txt"))

    channels = [int(i) for i in utils.get_tuple_from_config_dict(config_dict, 'channel')]
    output_folder = path.join(args.results_directory, 'results')
    get_atk_act = utils.get_attack_activations_function(config_dict['attack_obj'])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Steps at which a checkpoint was saved: 0, save_interval, ... up to nsteps.
    steps = list(range(0, int(config_dict['nsteps']) + 1, int(config_dict['save_interval'])))

    # If you're not tracking intermediate steps, take only the first and final versions of the model
    if not args.track_intermediate_steps:
        steps = [steps[0], steps[-1]]

    # Feature name is arbitrary but necessary for the hook
    feature_name = 'activations'
    layers = [layer.strip("\'") for layer in utils.get_tuple_from_config_dict(config_dict, 'layer')]

    get_activations_by_channel(steps, layers, get_atk_act, output_folder, args.results_directory, feature_name, device)

    # This may be unnecessary
    if do_sorting:
        return sort_activations_by_channel(steps, channels, output_folder, args.num_indices_to_track)
    return None


if __name__ == "__main__":
    # Command-line entry point: computes (and caches) the activations without sorting them.
    parser = ArgumentParser()
    # Required: without it the configuration path cannot be built and the run
    # would otherwise fail later with an unrelated-looking TypeError.
    parser.add_argument("--results-directory", type=str, required=True,
                        help="directory containing configuration.txt and the model checkpoints")
    # NOTE(review): the next two options are not read anywhere in this script.
    parser.add_argument("--num-img-to-save", type=int, default=10)
    parser.add_argument("--num-top-classes", type=int, default=6)
    parser.add_argument("--num-indices-to-track", type=int, default=10,
                        help="number of top/middle/bottom images to track per channel when sorting")
    #TODO: Implement this
    parser.add_argument("--track-intermediate-steps", action="store_true",
                        help="process every saved checkpoint instead of only the first and last")

    args = parser.parse_args()

    get_and_sort_activations(args, do_sorting=False)
