from argparse import ArgumentParser
from torchvision import models
import utils
from tqdm import tqdm
import torch
import os.path as path
from scipy import stats
import numpy as np


def load_in_activations(output_folder, steps):
    """Load the activation tensors saved at the first and last checkpoint steps.

    Args:
        output_folder: directory holding files named
            ``model_checkpoint_activations_step_<step>.pt``.
        steps: ordered sequence of checkpoint steps; only ``steps[0]`` and
            ``steps[-1]`` are used (an empty sequence raises ``IndexError``).

    Returns:
        ``(init_activations, final_activations)`` — whatever ``torch.load``
        returns for the first-step file and the last-step file respectively.
    """
    first_step, last_step = steps[0], steps[-1]
    loaded = []
    for step in (first_step, last_step):
        file_name = f"model_checkpoint_activations_step_{step}.pt"
        loaded.append(torch.load(path.join(output_folder, file_name)))
    return loaded[0], loaded[1]


def generate_visualization(activations, num_bins=10):
    """Print per-channel histogram bin counts, summed across channels.

    Each slice ``activations[channel]`` along dim 0 is histogrammed on its own
    with ``np.histogram`` (``num_bins`` equal-width bins spanning that slice's
    own min..max). The per-bin counts are then summed. Because the bin edges
    differ per channel, the totals describe where values fall *within* each
    channel's range, not a histogram of the pooled distribution.

    Args:
        activations: torch tensor whose first dimension is iterated as
            channels (it is moved to CPU with ``.cpu()``).
        num_bins: number of histogram bins per channel. The default of 10
            matches ``np.histogram``'s default, which the original relied on.

    Returns:
        numpy float array of length ``num_bins`` with the summed bin counts.
        The same array is also printed.
    """
    print(f'activations shape {activations.shape}')
    # Look at bin totals as opposed to binning the whole distribution
    totals = np.zeros(num_bins)
    for channel in range(activations.shape[0]):
        # detach() so tensors that still track gradients can be converted to numpy
        channel_values = activations[channel].detach().cpu().numpy()
        counts, _ = np.histogram(channel_values, bins=num_bins)
        totals += counts

    print('bin frequencies:\n', totals)
    return totals

def _read_config(config_path):
    """Parse a ``key: value`` configuration file into a dict of strings.

    Each non-blank line is split on its FIRST colon only, so values that
    themselves contain ':' (e.g. Windows paths, timestamps) stay intact.
    Keys and values are stripped of surrounding whitespace. A non-blank line
    with no colon raises ``ValueError``.
    """
    config = {}
    with open(config_path) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of crashing on unpack
            key, val = line.split(':', 1)
            config[key.strip()] = val.strip()
    return config


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--results-directory",
        type=str,
        required=True,  # path.join(None, ...) would otherwise fail opaquely
        help="run directory containing configuration.txt and results/",
    )
    args = parser.parse_args()

    config_dict = _read_config(path.join(args.results_directory, "configuration.txt"))

    # NOTE(review): channels and layers are not used further in this script;
    # they are still parsed here (possibly for planned per-channel/per-layer
    # analysis) — confirm or remove.
    channels = [int(i) for i in utils.get_tuple_from_config_dict(config_dict, 'channel')]
    output_folder = path.join(args.results_directory, 'results')
    # Checkpoints were saved every `save_interval` steps from 0 up to `nsteps`.
    steps = list(range(0, int(config_dict['nsteps']) + 1, int(config_dict['save_interval'])))
    layers = [layer.strip("'") for layer in utils.get_tuple_from_config_dict(config_dict, 'layer')]

    init_acts, final_acts = load_in_activations(output_folder, steps)
    generate_visualization(init_acts)
    generate_visualization(final_acts)

