from argparse import ArgumentParser
from torchvision import models
import utils
from tqdm import tqdm
import torch
import os.path as path

# This code goes over the entire ImageNet training set and saves the target layer's activations,
# channel by channel, for each of the intermediate training models.
# It can also be used to get only the first and last versions of the model
# (calc_activations_by_channel below currently processes just the initial and final models).
#


def calc_activations_by_channel(
    initial_model, final_model, layers, get_atk_act, output_folder, results_directory, feature_name, device,
    data_loader, name = '', model_name = "alexnet"
):
    """Compute (or reuse cached) activations of the initial and final models and save them.

    For each model, every batch from ``data_loader`` is run through the model.
    The hooked layer output, which ``utils.register_hooks`` stores in
    ``activations_dict[feature_name]``, is reduced with ``get_atk_act``. The
    per-batch results are transposed, concatenated along dim 1, and written with
    ``torch.save``. If a model's output file already exists, that model is skipped.

    Save locations (``<suffix>`` is ``"_" + name`` when ``name`` is non-empty):
      * initial model: ``<model_name>/<layers[0]>_initial_model_activations<suffix>.pt``
      * final model:   ``<output_folder>/<layers[0]>_final_model_activations<suffix>.pt``
    Missing parent directories are created before any activations are computed.

    Args:
        initial_model: First model to evaluate. It is put in eval mode and moved to ``device``.
        final_model: Second model to evaluate. It is put in eval mode and moved to ``device``.
        layers: Layer identifiers forwarded to ``utils.register_hooks``; ``layers[0]``
            is also used in the output file names.
        get_atk_act: Callable that reduces the captured layer output for one batch.
        output_folder: Directory for the final-model file; it is ensured to exist.
        results_directory: Unused here; kept for interface compatibility.
        feature_name: Key under which the hook stores the layer output.
        device: torch device the models and input batches are moved to.
        data_loader: Iterable yielding ``(data, _, _)`` triples.
        name: Optional tag appended to the file names.
        model_name: Directory used for the initial-model file.

    Returns:
        list[str]: ``[initial_path, final_path]``, whether freshly computed or reused.
    """
    import os  # local import: the module only binds ``os.path`` (as ``path``)

    utils.ensure_dir(output_folder)
    print(f"getting activations for the model!")
    suffix = f"_{name}" if name != '' else ''

    # (model, save path, label used in the shape printout)
    targets = [
        (initial_model, path.join(model_name, f"{layers[0]}_initial_model_activations{suffix}.pt"), "Initial"),
        (final_model, path.join(output_folder, f"{layers[0]}_final_model_activations{suffix}.pt"), "Final"),
    ]

    saved_paths = []
    for model, save_path, label in targets:
        saved_paths.append(save_path)

        if path.exists(save_path):
            print(f"{save_path} already exists!")
            continue
        print(f"{save_path} not found, calculating activations")

        # The initial-model file lives under ``model_name`` rather than ``output_folder``;
        # make sure its directory exists *before* the long pass so torch.save cannot fail at the end.
        parent_dir = path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        model.eval()
        model = model.to(device)

        # The hook writes the target layer's output into this dict on every forward pass.
        # NOTE(review): hooks are never removed, so calling this again with the same model
        # object stacks another hook — confirm whether utils.register_hooks returns handles.
        activations_dict = {}
        utils.register_hooks(utils.get_model_layers(model), activations_dict, layers, feature_name)

        activations_list = []
        with torch.no_grad():
            for data, _, _ in tqdm(data_loader):
                model(data.to(device))
                # Per the original author's note, get_atk_act returns an (image, channels) tensor;
                # transposing gives (channels, image), so cat on dim=1 stacks images — TODO confirm.
                channel_activations = get_atk_act(activations_dict[feature_name])
                activations_list.append(channel_activations.T)
            # NOTE(review): results stay on ``device`` until saved; for large loaders this may be
            # a memory concern on GPU.
            activations_tensor = torch.cat(activations_list, dim=1)

        print(f"{label} activations shape for model: {activations_tensor.shape}")
        torch.save(activations_tensor, save_path)
        print(f"Activations for model saved!")
    return saved_paths


# def sort_activations_by_channel(steps, channels, output_folder, num_indices_to_track=10,):

#     results_dict = {}
#     for channel in tqdm(channels):
#         channel_indices = []
#         for step in steps:
#             with torch.no_grad():
#                 # Perform set up and grab the indices of certain images for each channel
#                 activations_save_path = path.join(output_folder, f"model_checkpoint_activations_step_{step}.pt")
#                 activations = torch.load(activations_save_path)[channel]
#                 # print(activations.shape)
#                 top_acts, top_indices = torch.topk(activations, k=activations.shape[0])
#                 # print(top_indices.shape)
#                 channel_indices.append(top_indices)
#                 # Track some specific
#                 if step == 0:
#                     top_samples = top_indices[0:num_indices_to_track]
#                     mid_samples = top_indices[
#                                   top_indices.shape[0] // 2:top_indices.shape[0] // 2 + num_indices_to_track]
#                     bot_samples = top_indices[-num_indices_to_track:]
#                     # print(f'top samples shape {top_samples.shape}')
#                     # print(mid_samples.shape)
#                     # print(bot_samples.shape)

#         channel_indices = torch.vstack(channel_indices)
#         # print(f'channel indices shape: {channel_indices.shape}')
#         top_positions = utils.track_images(channel_indices, top_samples)
#         mid_positions = utils.track_images(channel_indices, mid_samples)
#         bot_positions = utils.track_images(channel_indices, bot_samples)
#         # print(f'top positions shape: {top_positions.shape}')
#         results_dict[f'channel_{channel}_top_positions'] = top_positions
#         results_dict[f'channel_{channel}_mid_positions'] = mid_positions
#         results_dict[f'channel_{channel}_bot_positions'] = bot_positions
#         return results_dict


def get_and_sort_activations(args, init_model, final_model, data_loader=None, model_name="alexnet", do_sorting=False, name=''):
    """Compute initial/final model activations for the experiment in ``args.results_directory``.

    Thin wrapper around ``get_and_sort_activations_v2`` that takes the results
    directory from a parsed-arguments object instead of a plain path.

    Args:
        args: Object with a ``results_directory`` attribute (e.g. argparse namespace).
        init_model: Initial model whose activations are computed.
        final_model: Final model whose activations are computed.
        data_loader: Iterable of ``(data, _, _)`` batches. If None (the default), a
            fresh loader is built with ``utils.get_topk_dataset_loader()``. It is built
            at call time rather than at import time, as the old default argument did.
        model_name: Directory name used for the initial-model activations file.
        do_sorting: Currently unused; kept for interface compatibility.
        name: Optional tag appended to the activation file names.

    Returns:
        The list of activation file paths from ``calc_activations_by_channel``.
    """
    if data_loader is None:
        data_loader = utils.get_topk_dataset_loader()
    # The rest of the pipeline is identical to the v2 entry point; delegate instead of duplicating it.
    return get_and_sort_activations_v2(
        args.results_directory, init_model, final_model, data_loader=data_loader,
        model_name=model_name, do_sorting=do_sorting, name=name,
    )

def read_config_dict(results_directory):
    """Parse ``<results_directory>/configuration.txt`` into a ``{key: value}`` dict.

    Each non-blank line is split on its FIRST ``:`` only, so values that contain
    colons themselves (paths, timestamps, ...) are kept intact. Values are
    whitespace-stripped; keys are kept verbatim. Blank lines are skipped.

    Raises:
        ValueError: if a non-blank line contains no ``:``.
    """
    config_dict = {}
    with open(path.join(results_directory, "configuration.txt")) as f:
        for line in f:
            if not line.strip():
                continue
            key, val = line.split(":", 1)
            config_dict[key] = val.strip()
    return config_dict


def get_and_sort_activations_v2(results_directory, init_model, final_model, data_loader=None, model_name="alexnet", do_sorting=False, name=''):
    """Compute initial/final model activations for the experiment stored in ``results_directory``.

    Reads ``configuration.txt`` (written by the do_optimization script). The
    ``attack_obj`` entry selects the activation-reduction function, and the
    ``layer`` entry gives the hooked layer(s). Activations are then delegated to
    ``calc_activations_by_channel``, and the final-model file goes to
    ``<results_directory>/results``.

    Args:
        results_directory: Experiment directory containing ``configuration.txt``.
        init_model: Initial model whose activations are computed.
        final_model: Final model whose activations are computed.
        data_loader: Iterable of ``(data, _, _)`` batches. If None (the default), a
            fresh loader is built with ``utils.get_topk_dataset_loader()``. It is built
            at call time rather than at import time, as the old default argument did.
        model_name: Directory name used for the initial-model activations file.
        do_sorting: Currently unused; kept for interface compatibility.
        name: Optional tag appended to the activation file names.

    Returns:
        The list of activation file paths from ``calc_activations_by_channel``.
    """
    if data_loader is None:
        data_loader = utils.get_topk_dataset_loader()
    print(f'dataloader length {len(data_loader)}')

    config_dict = read_config_dict(results_directory)
    output_folder = path.join(results_directory, "results")
    get_atk_act = utils.get_attack_activations_function(config_dict["attack_obj"])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Feature name is arbitrary but necessary for the hook
    feature_name = "activations"
    # Layer names are stored quoted in the config (e.g. 'features.8'); drop the quotes.
    layers = [layer.strip("'") for layer in utils.get_tuple_from_config_dict(config_dict, "layer")]

    return calc_activations_by_channel(
        init_model, final_model, layers, get_atk_act, output_folder, results_directory, feature_name, device,
        data_loader, model_name=model_name, name=name
    )


if __name__ == "__main__":
    # Command-line entry point: parse options and compute activations for the
    # experiment stored in --results-directory.
    parser = ArgumentParser()
    # Experiment directory; get_and_sort_activations reads configuration.txt from it.
    parser.add_argument("--results-directory", type=str)
    # NOTE(review): the following options are parsed but not read anywhere in the visible code.
    parser.add_argument("--num-img-to-save", type=int, default=10)
    parser.add_argument("--num-top-classes", type=int, default=6)
    parser.add_argument("--num-indices-to-track", type=int, default=10)
    # TODO: Implement this
    parser.add_argument("--track-intermediate-steps", action="store_true")

    args = parser.parse_args()
    # results_directory = args.results_directory

    # NOTE(review): get_and_sort_activations requires positional init_model and
    # final_model arguments, which are not supplied here, so this call raises
    # TypeError as written. The models need to be built/loaded and passed in.
    get_and_sort_activations(args, do_sorting=False)
