import os
import numpy as np
import torch
from tqdm import tqdm
import pandas as pd

# Folder that holds the saved attention-map arrays
directory = "./medium/attnmaps/eval_iters_100"

# Run on the GPU when CUDA is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Names of all .npy files in that folder, in sorted order
npy_files = [entry for entry in os.listdir(directory) if entry.endswith('.npy')]
npy_files.sort()


# Function to extract layer, head, and iteration from filename
def extract_layer_head_iter(file_name):
    """Parse the layer, head and iteration numbers out of a file name.

    The name (with '.npy' removed) is split on underscores; the first three
    pieces are read as 'layer<N>', 'head<N>' and 'iter<N>'. Any further
    pieces are ignored.

    Returns:
        A ``(layer, head, iteration)`` tuple of ints, or
        ``(None, None, None)`` after printing a message when a piece is
        missing or is not an integer.
    """
    pieces = file_name.replace('.npy', '').split('_')
    numbers = []
    try:
        # Pieces are converted strictly left to right so the first malformed
        # one is the one that gets reported
        for position, prefix in enumerate(('layer', 'head', 'iter')):
            numbers.append(int(pieces[position].replace(prefix, '')))
    except (IndexError, ValueError) as e:
        print(f"Error parsing file name {file_name}: {e}")
        return None, None, None
    return tuple(numbers)


# Function to compute rank approximation
def compute_rank(att_map, threshold=0.90, compute_device=None):
    """Estimate the number of singular components needed to reach ``threshold``.

    The input is averaged over its first axis, the singular values of the
    resulting matrix are computed, and the smallest number of leading
    components whose squared singular values add up to at least ``threshold``
    of their total is returned.

    Args:
        att_map: Array-like whose mean over axis 0 is a 2-D matrix (i.e. a
            3-D input). Non-floating data is cast to float32.
        threshold: Fraction of the total squared singular values that must be
            covered; expected to lie in (0, 1]. Defaults to 0.90.
        compute_device: Device on which the SVD runs. When ``None`` the
            module-level ``device`` is used.

    Returns:
        The rank as an ``int`` (never larger than the number of singular
        values), or ``None`` after printing a message when the rank cannot
        be computed: wrong number of axes, a zero or non-finite matrix, or
        any error raised along the way.
    """
    try:
        # Work on a floating-point NumPy array regardless of the input type
        if not isinstance(att_map, np.ndarray):
            att_map = np.array(att_map)
        if not np.issubdtype(att_map.dtype, np.floating):
            att_map = att_map.astype(np.float32)

        # Fall back to the module-level device unless the caller chose one
        target_device = device if compute_device is None else compute_device

        # Average over the leading axis, then move the matrix to the device
        att_matrix = torch.tensor(att_map, dtype=torch.float32).mean(0).to(target_device)

        # Exactly two axes are required: a batched SVD would mix the singular
        # values of unrelated matrices in the cumulative sum below
        if att_matrix.dim() < 2:
            print(f"Attention map has less than 2 dimensions: {att_matrix.shape}")
            return None
        if att_matrix.dim() > 2:
            print(f"Attention map has more than 2 dimensions after averaging: {att_matrix.shape}")
            return None

        singular_values = torch.linalg.svdvals(att_matrix)
        squared = singular_values ** 2
        total_variance = torch.sum(squared).item()

        # A zero (or NaN/inf) total makes the explained-variance ratios
        # undefined, so no meaningful rank exists
        if not np.isfinite(total_variance) or total_variance <= 0:
            print("Cannot compute rank: total variance is zero or not finite")
            return None

        cumulative_explained_variance = torch.cumsum(squared / total_variance, dim=0)

        # Index of the first component at which the cumulative ratio reaches
        # the threshold; +1 turns the index into a component count
        first_index = torch.searchsorted(cumulative_explained_variance, threshold)

        # Round-off can leave the final cumulative value just below the
        # threshold, which would yield n + 1; clamp to the available count
        return min(int(first_index.item()) + 1, singular_values.numel())

    except Exception as e:
        # Deliberately best-effort: one unusable input must not abort a batch
        print(f"Error computing rank: {e}")
        return None


# Gather every head's rank under its (layer, iteration) pair so that the heads
# of one layer are aggregated instead of overwriting one another
head_ranks = {}

for file_name in tqdm(npy_files, desc="Processing files"):
    # Parse the name first: files whose name cannot be parsed are skipped
    # without loading their array from disk. Only the layer/head/iteration
    # tokens are inspected; any trailing suffix in the name is not checked.
    layer_num, head_num, iter_num = extract_layer_head_iter(file_name)
    if layer_num is None or head_num is None or iter_num is None:
        continue

    file_path = os.path.join(directory, file_name)
    try:
        att_map = np.load(file_path)
    except (OSError, ValueError, EOFError) as e:
        # An unreadable file is skipped like any other unusable input
        print(f"Error loading {file_path}: {e}")
        continue

    rank_approx = compute_rank(att_map)

    # Skip if rank could not be computed
    if rank_approx is None:
        continue

    head_ranks.setdefault((layer_num, iter_num), []).append(rank_approx)

# Collapse the per-head ranks into one value per layer and iteration.
# NOTE(review): the mean over heads is used as the per-layer summary; swap in
# another statistic here if a different aggregation is intended.
layer_iter_ranks = {}
for (layer_num, iter_num), ranks in head_ranks.items():
    layer_iter_ranks.setdefault(layer_num, {})[f"iter_{iter_num}"] = float(np.mean(ranks))

# Order the iteration columns numerically; file names sort lexicographically,
# which would otherwise place e.g. iter_1000 before iter_200
iteration_numbers = sorted({iteration for _, iteration in head_ranks})
columns = ['Layer'] + [f"iter_{iteration}" for iteration in iteration_numbers]

# One row per layer, each iteration's rank in its own column
rows = [
    {'Layer': layer_num, **iter_ranks}
    for layer_num, iter_ranks in sorted(layer_iter_ranks.items())
]

if not rows:
    print("No ranks were computed; the CSV will contain only the header")

# Layer/iteration combinations with no rank are left as NaN by pandas
df = pd.DataFrame(rows, columns=columns)

# Create directory for saving if it does not exist
output_dir = "./out/medium/save"

os.makedirs(output_dir, exist_ok=True)

# Save the DataFrame to a CSV file
csv_output_path = os.path.join(output_dir, "med_ranks2.csv")
df.to_csv(csv_output_path, index=False)

print(f'CSV file with Layer and iteration ranks saved to {csv_output_path}')
