import os
import numpy as np
import torch
from tqdm import tqdm
import pandas as pd

# Folder that holds the attention-map dumps (.npy files) to analyse.
directory = "./attnmaps_init/eval_iters_100"

# Run tensor work on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Names of every .npy entry in `directory`, in lexicographic order.
npy_files = sorted(
    entry for entry in os.listdir(directory) if entry.endswith('.npy')
)


# Function to extract layer, head, and iteration from filename
def extract_layer_head_iter(file_name):
    """Parse ``layer<L>_head<H>_iter<I>_<tag>.npy`` into ``(L, H, I)``.

    Every ``.npy`` occurrence is stripped and the rest is split on ``_``. The
    first three pieces must carry the ``layer``, ``head`` and ``iter`` prefixes
    followed by an integer. The trailing tag (e.g. ``att`` or ``rank``) is
    ignored.

    Returns:
        A ``(layer, head, iteration)`` tuple of ints, or ``(None, None, None)``
        (after printing the problem) when the name has too few pieces or a
        piece is not an integer once its prefix is removed.
    """
    stem = file_name.replace('.npy', '')
    pieces = stem.split('_')
    prefixes = ('layer', 'head', 'iter')
    try:
        # Evaluated lazily in order, so the first bad piece raises first.
        layer_num, head_num, iter_num = (
            int(pieces[position].replace(prefix, ''))
            for position, prefix in enumerate(prefixes)
        )
    except (IndexError, ValueError) as e:
        print(f"Error parsing file name {file_name}: {e}")
        return None, None, None
    return layer_num, head_num, iter_num


def compute_element_mass(att_map, threshold=0.90, target_device=None):
    """Count how many columns hold ``threshold`` of an attention map's squared mass.

    The map is averaged over its first axis. Each column's sum of squares
    (summed over rows) is then treated as that column's share of the total
    squared magnitude. Columns are ranked by share, largest first, and the
    result is the smallest number of top columns whose cumulative share
    reaches ``threshold``.

    Args:
        att_map: Array-like attention data. Non-float data is cast to float32.
            NOTE(review): presumably shaped (batch, query, key) for one head;
            after averaging axis 0 it must be 2-D for the column statistics
            to be meaningful. Confirm against the dump format.
        threshold: Target cumulative share, in (0, 1]. Defaults to 0.90, the
            value this script has always used.
        target_device: torch device for the column statistics. ``None`` means
            the module-level ``device``.

    Returns:
        The column count as an int. Returns None (after printing why) when
        the averaged map has fewer than 2 dimensions, its total squared
        magnitude is zero or non-finite, or any other error occurs.

    Raises:
        ValueError: if ``threshold`` is outside (0, 1]. This is a caller
            error, not a per-file data problem, so it is not swallowed.
    """
    if not 0.0 < threshold <= 1.0:
        raise ValueError(f"threshold must be in (0, 1], got {threshold}")

    try:
        if target_device is None:
            target_device = device  # module-level default

        # Ensure the attention map is a NumPy array of floating-point type
        if not isinstance(att_map, np.ndarray):
            att_map = np.array(att_map)
        if not np.issubdtype(att_map.dtype, np.floating):
            att_map = att_map.astype(np.float32)

        # Average over axis 0 on the CPU, then move only the reduced result.
        att_matrix = torch.tensor(att_map, dtype=torch.float32).mean(0).to(target_device)

        if att_matrix.dim() < 2:
            print(f"Attention map has less than 2 dimensions: {att_matrix.shape}")
            return None

        # Squared magnitude of each column, and the overall total.
        sum_squares_per_column = torch.sum(att_matrix ** 2, dim=0)
        total_sum_squares = torch.sum(sum_squares_per_column)

        if total_sum_squares == 0:
            print("Total sum of squares is zero.")
            return None
        # A NaN/inf entry would make every proportion NaN and the count meaningless.
        if not torch.isfinite(total_sum_squares):
            print("Total sum of squares is not finite.")
            return None

        # Each column's share of the total, largest first. Dividing by a
        # positive constant preserves order, so sorting the shares directly
        # is equivalent to sorting the raw sums of squares.
        proportions = sum_squares_per_column / total_sum_squares
        sorted_proportions, _ = torch.sort(proportions, descending=True)
        cumulative_proportions = torch.cumsum(sorted_proportions, dim=0)

        # First position whose running share reaches the threshold; +1 turns it into a count.
        mass_index = torch.searchsorted(cumulative_proportions, threshold, right=False)
        # Float rounding can leave the final cumulative share slightly below
        # 1.0, which for thresholds near 1 would put the index past the end.
        return min(int(mass_index.item()) + 1, sorted_proportions.shape[0])

    except Exception as e:
        print(f"Error computing element mass: {e}")
        return None


# Per-(layer, iteration) lists of mass_approx values, one entry per head.
temp_layer_iter_ranks = {}

# Compute mass_approx for every parseable, loadable .npy file.
# NOTE(review): files are not filtered by their trailing tag (names are
# expected to end in e.g. "_att" or "_rank"), so any non-attention dump in
# `directory` is also passed to compute_element_mass. Confirm that is intended.
for file_name in tqdm(npy_files, desc="Processing files"):
    # Parse the name before loading, so files that would be skipped anyway
    # are never read from disk.
    layer_num, head_num, iter_num = extract_layer_head_iter(file_name)
    if layer_num is None or head_num is None or iter_num is None:
        continue

    file_path = os.path.join(directory, file_name)
    try:
        att_map = np.load(file_path)
    except (OSError, ValueError) as e:
        # Skip-and-report, like every other per-file failure here, rather than
        # aborting the whole run on one bad file.
        print(f"Error loading {file_path}: {e}")
        continue

    mass_approx = compute_element_mass(att_map)

    # Skip if mass_approx could not be computed
    if mass_approx is None:
        continue

    temp_layer_iter_ranks.setdefault((layer_num, iter_num), []).append(mass_approx)

# Average the per-head values of each (layer, iteration) group and nest the
# results as {layer_num: {"iter_<iter_num>": mean over heads}}.
layer_iter_avg_ranks = {}

for (layer_num, iter_num), mass_approx_list in temp_layer_iter_ranks.items():
    column_label = f"iter_{iter_num}"
    per_layer = layer_iter_avg_ranks.setdefault(layer_num, {})
    per_layer[column_label] = np.mean(mass_approx_list)

# Build one row per layer (ascending layer number). Each "iter_<n>" key
# becomes a column holding that layer's average mass_approx.
rows = []
for layer_num, iter_ranks in sorted(layer_iter_avg_ranks.items()):
    row = {'Layer': layer_num}
    row.update(iter_ranks)
    rows.append(row)

# Order the iteration columns numerically. Left alone, pandas keeps first-seen
# key order, which follows the lexicographically sorted file names
# (e.g. iter_100 ahead of iter_20).
iter_columns = sorted(
    {column for row in rows for column in row if column != 'Layer'},
    key=lambda column: int(column.split('_', 1)[1]),
)

# A layer missing some iteration gets NaN in that column automatically when
# pandas builds the frame from dicts, so no fillna pass is needed.
df = pd.DataFrame(rows, columns=['Layer', *iter_columns])

# Create directory for saving if it does not exist
output_dir = "/out/medium/save"
os.makedirs(output_dir, exist_ok=True)

# Save the DataFrame to a CSV file (without the positional index)
csv_output_path = os.path.join(output_dir, "medinit_mass.csv")
df.to_csv(csv_output_path, index=False)

print(f'CSV file with Layer and iteration average mass_approx saved to {csv_output_path}')
