import os
import numpy as np
from pdb import set_trace as pds
import json
import os
import sys
import time

class Timer:
    """Simple wall-clock stopwatch that starts running on construction.

    The start timestamp is kept in the public attribute ``v`` (seconds since
    the epoch, from ``time.time()``).
    """

    def __init__(self):
        self.start()

    def start(self):
        """Reset the stopwatch so elapsed time is measured from now."""
        self.v = time.time()

    def end(self):
        """Return the number of seconds elapsed since the last ``start()``."""
        now = time.time()
        elapsed = now - self.v
        return elapsed


def time_str(t):
    """Format a duration given in seconds as a short human-readable string.

    Durations of at least 3600s are shown in hours, durations strictly above
    60s in minutes, and everything else (including exactly 60s) in seconds,
    always with one decimal place, e.g. ``'1.5m'``.
    """
    if t >= 3600:
        fmt, scaled = '{:.1f}h', t / 3600
    elif t > 60:
        fmt, scaled = '{:.1f}m', t / 60
    else:
        fmt, scaled = '{:.1f}s', t
    return fmt.format(scaled)


def ensure_path(path, early_exit = False):
    """Make sure the directory ``path`` exists, creating parents as needed.

    If ``path`` already exists and ``early_exit`` is true, the user is asked
    whether to continue; answering exactly ``n`` terminates the process with
    exit status 0. Any other answer (including empty) continues.
    """
    if not os.path.exists(path):
        os.makedirs(path)
        return
    if early_exit and input('{:s} exists, continue? ([y]/n): '.format(path)) == 'n':
        sys.exit(0)

def load_json(file_path):
    """Open ``file_path`` for reading and return the decoded JSON value."""
    with open(file_path, 'r') as handle:
        data = json.load(handle)
    return data
    
def save_json(data, file_path, indent = 4):
    """Write ``data`` as JSON to ``file_path``, replacing any existing file.

    ``indent`` is forwarded to ``json.dump`` (default 4 spaces).
    """
    with open(file_path, 'w') as out_file:
        json.dump(obj=data, fp=out_file, indent=indent)


def compute_mse_and_relative_diff(conv_data, naive_data):
    """Compare an array under test against a reference array.

    Parameters
    ----------
    conv_data : np.ndarray
        Array under test. Positions where it is NaN are excluded from the MSE.
    naive_data : np.ndarray
        Reference array. It is indexed with a boolean mask derived from
        ``conv_data``, so it is expected to have the same shape.

    Returns
    -------
    tuple
        ``(mse, relative_diff)`` where ``mse`` is the mean squared error over
        the non-NaN positions of ``conv_data`` and ``relative_diff`` is
        ``||conv - naive||_F / ||naive||_F`` over ALL elements. The relative
        diff does not mask NaNs, so any NaN in ``conv_data`` makes it NaN.
    """
    keep = ~np.isnan(conv_data)
    residual = conv_data[keep] - naive_data[keep]
    mse = np.mean(residual ** 2)

    last_dim = conv_data.shape[-1]

    def _fro(arr):
        # Flatten all leading axes into one so the Frobenius norm covers
        # every element of the array.
        return np.linalg.norm(arr.reshape(-1, last_dim), ord='fro')

    relative_diff = _fro(conv_data - naive_data) / _fro(naive_data)
    return mse, relative_diff

def check_nan_and_compare(model_id = "llama3_8b_ins", seq_len = 2048,
                          k_values = None, index_ranges = None):
    """Compare saved ``conv`` hidden-state arrays against the naive baseline.

    For every index range and every k, this loads
    ``{model_id}/seq_len{seq_len}/hidden/last_hidden_conv_k_{k}_{range}.npy``,
    counts its NaNs, and compares it with
    ``last_hidden_naive_{range}.npy`` from the same directory using
    ``compute_mse_and_relative_diff``. One record per (k, range) pair that
    was found on disk is written to
    ``{model_id}/seq_len{seq_len}/errors_diff.json``.

    Args:
        model_id: Name of the top-level results directory for the model.
        seq_len: Sequence length; selects the ``seq_len{seq_len}`` subdirectory.
        k_values: Values of k to analyze. Defaults to
            ``[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2000]``.
        index_ranges: ``"start_end"`` filename suffixes to analyze. Defaults to
            ``["0_200", "200_400", "400_600", "600_800"]``.
    """
    if k_values is None:
        k_values = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2000]
    if index_ranges is None:
        index_ranges = ["0_200", "200_400", "400_600", "600_800"]

    base_dir = f"{model_id}/seq_len{seq_len}"
    directory = f"{base_dir}/hidden"
    errors = []

    timer = Timer()
    for idx, idx_range in enumerate(index_ranges):
        naive_filename = f"last_hidden_naive_{idx_range}.npy"
        naive_filepath = os.path.join(directory, naive_filename)
        if not os.path.exists(naive_filepath):
            # Without a baseline nothing in this range can be compared. Falling
            # through would raise NameError on the first range, or silently
            # compare against the previous range's baseline on later ones.
            print(f"    No corresponding naive file found for comparison: {naive_filename}")
            continue
        naive_data = np.load(naive_filepath)

        for k in k_values:
            print(f"\nAnalyzing files for k = {k}")

            filename = f"last_hidden_conv_k_{k}_{idx_range}.npy"
            filepath = os.path.join(directory, filename)
            if not os.path.exists(filepath):
                print(f"  File not found: {filename}")
                continue

            data = np.load(filepath)

            # NaN statistics for this output.
            nan_count = np.isnan(data).sum()
            total_elements = data.size
            nan_percentage = (nan_count / total_elements) * 100

            # Compare with the naive baseline (MSE ignores NaNs in `data`).
            mse, relative_diff = compute_mse_and_relative_diff(data, naive_data)

            errors.append({
                "k": k,
                "idx_range": idx_range,
                "NaN_count": str(nan_count),
                "NaN_percentage": f"{nan_percentage:.2f}%",
                "mse": str(mse),
                "relative_diff": str(relative_diff),
            })

        # Elapsed time so far | naive estimate of total time for all ranges.
        time_elapsed = timer.end()
        print(f"idx range {idx_range}, time elapsed {time_str(time_elapsed)} | {time_str(time_elapsed/(idx+1)*len(index_ranges))}")

    save_json(errors, f"{base_dir}/errors_diff.json")

if __name__ == "__main__":
    # Select the run to analyze; another available model id is
    # "mistral_7b_ins_v03".
    model_id, seq_len = "llama3_8b_ins", 2048
    check_nan_and_compare(model_id=model_id, seq_len=seq_len)
