from nesim.utils.dimensionality import EffectiveDimensionality
from nesim.utils.tensor_size import get_tensor_size_string
import argparse
import os
from nesim.utils.json_stuff import dict_to_json, load_json_as_dict
from tqdm import tqdm
import torch
from einops import rearrange
from lightning import seed_everything
import numpy as np

# Fix all RNGs (torch, numpy, python) so the random subsets below are reproducible.
seed_everything(0)

# Number of random row-subsets drawn per layer; each subset has
# (total rows // NUM_RANDOM_SUBSETS) rows.
NUM_RANDOM_SUBSETS = 50

parser = argparse.ArgumentParser(description="get effective dimensionality")
parser.add_argument(
    "--hook-output-folder",
    type=str,
    help="name of folder containing hook output pth files",
    required=True,
)
parser.add_argument(
    "--layer-names-json",
    type=str,
    help="filename of json containing layer names",
    required=True,
)
parser.add_argument(
    "--result-filename",
    type=str,
    help="filename of json containing results",
    required=True,
)
# Previously hard-coded values are now overridable; defaults keep old behavior.
parser.add_argument(
    "--config",
    type=str,
    default="config.json",
    help="path of json config (must provide num_samples and batch_size)",
)
parser.add_argument(
    "--device",
    type=str,
    default="cuda:0",
    help="device passed to EffectiveDimensionality (default: cuda:0)",
)
args = parser.parse_args()

device = args.device
# Loaded after argument parsing so `--help` works without a config file present.
config = load_json_as_dict(args.config)


import torch

def random_sample_along_first_dim(matrix: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Return ``num_samples`` distinct rows of ``matrix`` chosen uniformly at random.

    Rows are drawn without replacement via ``torch.randperm`` using the global
    torch RNG, so results are reproducible under a fixed seed. If
    ``num_samples`` exceeds the number of rows, every row is returned (in a
    random order).

    Args:
        matrix: tensor whose first dimension is sampled.
        num_samples: number of rows to draw; must be non-negative.

    Returns:
        Tensor of shape ``(min(num_samples, matrix.size(0)), *matrix.shape[1:])``.

    Raises:
        ValueError: if ``num_samples`` is negative. Without this check, the
            negative slice bound would silently drop rows from the end of
            the permutation instead of failing.
    """
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")

    num_rows = matrix.size(0)
    # A random permutation truncated to num_samples gives sampling without replacement.
    indices = torch.randperm(num_rows)[:num_samples]
    return matrix[indices]


# One hook-output file per batch, named "<dataset_idx>.pth"; only full batches
# are read (num_samples // batch_size).
filenames = [
    os.path.join(args.hook_output_folder, f"{dataset_idx}.pth")
    for dataset_idx in range(config["num_samples"] // config["batch_size"])
]
# Iterated below to obtain layer names; presumably a list of hook keys — confirm.
layer_names = load_json_as_dict(filename=args.layer_names_json)

effective_dimensionality = EffectiveDimensionality(
    flatten=True, device=device, batch_size=512, progress=True
)

# layer_name -> mean effective dimensionality over the random subsets.
results = {}

# results_ratan[i, j]: effective dimensionality of random subset j for layer_names[i].
results_ratan = np.zeros(shape=(len(layer_names), NUM_RANDOM_SUBSETS))

# The per-subset array is saved as results/<hook-output-folder-name>.npy.
# normpath strips a trailing separator; otherwise basename("hook_outputs/")
# would be "" and the file would be "results/.npy".
npy_filename = os.path.join(
    "results", f"{os.path.basename(os.path.normpath(args.hook_output_folder))}.npy"
)
# Create the output directory before the expensive loop so a missing
# directory cannot throw away all computed results at the very end.
os.makedirs(os.path.dirname(npy_filename), exist_ok=True)

# All batches are held in host memory so each layer can be gathered without
# re-reading the files from disk.
loaded_hook_output_files = [
    torch.load(f, map_location="cpu")
    for f in tqdm(filenames, desc="loading hook outputs")
]

for layer_index, layer_name in enumerate(layer_names):
    all_outputs_for_single_layer = []
    for loaded_hook_output in tqdm(
        loaded_hook_output_files, desc=f"Computing stuff for layer: {layer_name}"
    ):
        tensor = loaded_hook_output[layer_name].cpu()
        # tensor.shape: batch, *
        all_outputs_for_single_layer.append(tensor)

    print("Concatenating all outputs...")
    # Flatten (batch, seq, emb) -> (batch * seq, emb) so each position becomes
    # one row; rearrange raises if a tensor is not rank 3.
    all_outputs_for_single_layer = torch.cat(
        [
            rearrange(x, "batch seq emb -> (batch seq) emb")
            for x in all_outputs_for_single_layer
        ],
        dim=0,
    )
    print(f"Tensor size: {get_tensor_size_string(all_outputs_for_single_layer)}")
    print(
        f"Computing effective dimensionality for tensor of shape: {all_outputs_for_single_layer.shape}"
    )

    # Same size for every subset, so compute it once per layer.
    num_random_samples = all_outputs_for_single_layer.shape[0] // NUM_RANDOM_SUBSETS
    ed_across_random_subsets = []
    for subset_index in range(NUM_RANDOM_SUBSETS):
        # Each subset is drawn independently, so rows may repeat across subsets.
        random_subset = random_sample_along_first_dim(
            matrix=all_outputs_for_single_layer, num_samples=num_random_samples
        )
        ed = effective_dimensionality.compute(random_subset)
        print(f"layer: {layer_name} subset_index: {subset_index} num_random_samples: {num_random_samples} Effective Dimensionality: {ed}")
        ed_across_random_subsets.append(ed)
    results_ratan[layer_index] = np.array(ed_across_random_subsets)

    # Summarize over all subsets instead of keeping only the last subset's value.
    results[layer_name] = float(results_ratan[layer_index].mean())
    del all_outputs_for_single_layer

np.save(file=npy_filename, arr=results_ratan)
print(f"Saved: {npy_filename}")
dict_to_json(results, args.result_filename)
print(f"Saved: {args.result_filename}")
"""
Example usage:
python3 compute_effective_dimensionality.py --hook-output-folder hook_outputs --layer-names-json layer_names.json --result-filename results.json
"""
