import os
import numpy as np
from pdb import set_trace as pds
import json
import os
import sys

def ensure_path(path, early_exit = False):
    """Make sure the directory ``path`` exists, creating it when missing.

    Args:
        path: Directory path to check or create (intermediate directories
            are created as needed).
        early_exit: When true and ``path`` already exists, ask the user on
            stdin whether to continue; the answer ``'n'`` terminates the
            process with exit status 0. Any other answer continues.
    """
    # Missing directory: create it and we are done.
    if not os.path.exists(path):
        os.makedirs(path)
        return

    # Existing directory and no confirmation requested: nothing to do.
    if not early_exit:
        return

    answer = input('{:s} exists, continue? ([y]/n): '.format(path))
    if answer == 'n':
        sys.exit(0)

# Function to load JSON data from a file
def load_json(file_path):
    """Read the file at ``file_path`` and return its parsed JSON content."""
    with open(file_path, 'r') as handle:
        parsed = json.load(handle)
    return parsed
    
def save_json(data, file_path, indent = 4):
    """Serialize ``data`` as JSON into ``file_path``, overwriting the file.

    Args:
        data: JSON-serializable object to write.
        file_path: Destination file; opened in text write mode.
        indent: Indentation passed through to ``json.dump`` (default 4).
    """
    with open(file_path, 'w') as sink:
        json.dump(data, sink, indent=indent)


def compute_mse_and_relative_mse(conv_data, naive_data):
    """Compare ``conv_data`` against ``naive_data``, skipping NaNs in ``conv_data``.

    Positions where ``conv_data`` is NaN are excluded from the mean squared
    error. The relative value divides that MSE by the squared norm
    (``np.linalg.norm``) of the *entire* ``naive_data`` array; the NaN mask
    is not applied to that norm.

    Args:
        conv_data: Array to evaluate; may contain NaNs.
        naive_data: Reference array, indexed with the mask built from
            ``conv_data`` (assumes matching shapes -- TODO confirm).

    Returns:
        Tuple ``(mse, relative_mse)``.
    """
    # True wherever conv_data holds a usable (non-NaN) value.
    valid = ~np.isnan(conv_data)

    # Element-wise error restricted to the usable positions.
    residual = conv_data[valid] - naive_data[valid]
    mse = np.mean(residual ** 2)

    # Normalizer: squared norm of the full, unmasked reference array.
    reference_norm = np.linalg.norm(naive_data)
    relative_mse = mse / (reference_norm ** 2)

    return mse, relative_mse

def check_nan_and_compare(directory, k_values=None, index_ranges=None):
    """Report NaN statistics for saved arrays and compare them to the naive run.

    For every combination of ``k`` and index range, the matching ``.npy``
    file in ``directory`` is loaded, its NaN count and percentage are
    printed, and (unless ``k`` is ``'naive'``) it is compared against the
    corresponding ``last_hidden_naive_<range>.npy`` file via
    ``compute_mse_and_relative_mse``. One record per existing file is
    collected and written to ``<directory>/errors.json``.

    Args:
        directory: Directory containing the ``.npy`` files; also where
            ``errors.json`` is written.
        k_values: Iterable of ``k`` values to analyze. The special value
            ``'naive'`` selects the naive files themselves. Defaults to
            ``(8, 16, 32, 64)``.
        index_ranges: Iterable of index-range suffixes used in the file
            names. Defaults to ``'0_10'`` through ``'70_80'`` in steps of 10.

    Returns:
        The list of per-file records that was written to ``errors.json``.
        ``mse`` and ``relative_mse`` are ``None`` (JSON ``null``) when no
        comparison was made for that file.
    """
    if k_values is None:
        k_values = (8, 16, 32, 64)
    if index_ranges is None:
        index_ranges = ('0_10', '10_20', '20_30', '30_40',
                        '40_50', '50_60', '60_70', '70_80')

    errors = []

    for k in k_values:
        print(f"\nAnalyzing files for k = {k}")
        is_naive = k == 'naive'

        for idx_range in index_ranges:
            naive_filename = f"last_hidden_naive_{idx_range}.npy"
            if is_naive:
                filename = naive_filename
            else:
                filename = f"last_hidden_conv_k_{k}_{idx_range}.npy"

            filepath = os.path.join(directory, filename)
            if not os.path.exists(filepath):
                print(f"  File not found: {filename}")
                continue

            data = np.load(filepath)

            # Check for NaN values
            print(f"data size: {data.shape}, total values: {data.size}")
            nan_count = np.isnan(data).sum()
            total_elements = data.size
            nan_percentage = (nan_count / total_elements) * 100
            print(f"\n  {filename}:")
            print(f"    NaN count: {nan_count} ({nan_percentage:.2f}% of total elements)")

            # Reset for every file: without this, a missing naive file either
            # raised UnboundLocalError (first file) or silently reused the
            # values computed for the previous file.
            mse = None
            relative_mse = None

            # Compare with naive approach if not naive
            if not is_naive:
                naive_filepath = os.path.join(directory, naive_filename)
                if os.path.exists(naive_filepath):
                    naive_data = np.load(naive_filepath)
                    mse, relative_mse = compute_mse_and_relative_mse(data, naive_data)
                    print(f"    MSE with naive (ignoring NaNs): {mse}")
                    print(f"    Relative MSE: {relative_mse}")
                else:
                    print(f"    No corresponding naive file found for comparison")

            errors.append({
                "k": k,
                "idx_range": idx_range,
                "NaN_count": str(nan_count),
                "NaN_percentage": f"{nan_percentage:.2f}%",
                # Numeric results are stored as strings, as before; None marks
                # "no comparison available" and serializes to JSON null.
                "mse": None if mse is None else str(mse),
                "relative_mse": None if relative_mse is None else str(relative_mse),
            })

    save_json(errors, f"{directory}/errors.json")
    return errors

if __name__ == "__main__":
    # Script entry point: analyze the results stored under "constant<N>".
    minus_constant = 100
    # Change this to the actual directory path if different
    directory = "constant{}".format(minus_constant)
    check_nan_and_compare(directory)