import torch
import os
from nesim.utils.dimensionality import EffectiveDimensionality
from nesim.utils.json_stuff import dict_to_json, load_json_as_dict
from tqdm import tqdm
import argparse
import sys


def get_tensor_size_string(tensor):
    """Return the in-memory size of ``tensor`` as a human-readable string.

    The size is ``tensor.element_size() * tensor.numel()`` bytes. It is
    reported in the largest binary unit (GB, MB, KB) that it reaches, with
    two decimals; anything below 1 KB is reported as a whole number of bytes.
    """
    total_bytes = tensor.numel() * tensor.element_size()

    # Largest unit first, so the first threshold that is reached wins.
    for unit, exponent in (("GB", 30), ("MB", 20), ("KB", 10)):
        unit_bytes = 2**exponent
        if total_bytes >= unit_bytes:
            return f"{total_bytes / unit_bytes:.2f} {unit}"

    return f"{total_bytes} bytes"


# Command-line interface: every option is a required string.
parser = argparse.ArgumentParser(description="get effective dimensionality")
for _flag, _help_text in (
    ("--hook-output-folder", "name of folder containing hook output pth files"),
    ("--layer-names-json", "filename of json containing layer names"),
    ("--result-filename", "filename of json containing results"),
):
    parser.add_argument(_flag, type=str, help=_help_text, required=True)
args = parser.parse_args()


# Device handed to EffectiveDimensionality for the computation.
device = "cuda:0"
# NOTE(review): presumably each hook-output file holds one batch of
# `batch_size` samples, so `num_samples / batch_size` files are expected --
# confirm against the script that dumped the hook outputs.
batch_size = 100
num_samples = 10_000

# Explicit check instead of `assert`, which is stripped under `python -O`.
if num_samples % batch_size != 0:
    raise ValueError(
        f"num_samples ({num_samples}) must be a multiple of batch_size ({batch_size})"
    )
num_batches = num_samples // batch_size
dataset_indices = list(range(num_batches))

effective_dimensionality = EffectiveDimensionality(
    flatten=True, batch_size=32, device=device
)

# Hook outputs are read from "<hook-output-folder>/<index>.pth".
filenames = [
    os.path.join(args.hook_output_folder, f"{dataset_idx}.pth")
    for dataset_idx in dataset_indices
]

# Validate every input once, up front, so a missing file is reported before
# any (potentially slow) loading starts rather than midway through a layer.
missing_filenames = [f for f in filenames if not os.path.exists(f)]
if missing_filenames:
    raise FileNotFoundError(
        f"Invalid filename: {missing_filenames[0]} "
        f"({len(missing_filenames)} of {len(filenames)} hook output files are missing)"
    )

layer_names = load_json_as_dict(filename=args.layer_names_json)

# Maps layer name -> value returned by EffectiveDimensionality.compute.
results = {}

for layer_name in layer_names:
    per_file_outputs = []

    for f in tqdm(filenames, desc=f"Loading hook outputs for layer: {layer_name}"):
        # NOTE(security): torch.load unpickles the file; only point this
        # script at hook outputs from a trusted source.
        # map_location="cpu" already places the loaded tensors on the CPU.
        tensor = torch.load(f, map_location="cpu")[layer_name]
        # tensor.shape: batch, *
        per_file_outputs.append(tensor)

    print("Concatenating all outputs...")
    all_outputs_for_single_layer = torch.cat(per_file_outputs, dim=0)
    # The per-file pieces are no longer needed once concatenated.
    del per_file_outputs
    print(f"Tensor size: {get_tensor_size_string(all_outputs_for_single_layer)}")

    print(
        f"Computing effective dimensionality for tensor of shape: {all_outputs_for_single_layer.shape}"
    )
    ed = effective_dimensionality.compute(all_outputs_for_single_layer)
    print(f"layer: {layer_name} Effective Dimensionality: {ed}")
    results[layer_name] = ed

    # Release this layer's activations before loading the next layer.
    del all_outputs_for_single_layer

dict_to_json(results, args.result_filename)
