import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np

import os
import time
import torch
import pickle
import argparse
import random
import numpy as np
import pandas as pd

from csv import writer
from tqdm import tqdm
from copy import deepcopy
from transformers import AutoTokenizer
from transformers import GPTJForCausalLM
from transformers import RobertaForMaskedLM
from dataset_utils.fever import FEVER
# from dataset_utils.hotpot import Hotpot
from dataset_utils.bias_in_bios import BiasBiosGender, BiasBiosOccupation
from dataset_utils.truthfulqa import get_truthfulqa_pointwise_data_no_logger
from dataset_utils.bigbench import get_bb_dataset
from laser.LaserWrapper import LaserWrapper
from study_utils.log_utils import Logger
from study_utils.metric_utils import Metrics, DatasetMetrics, ContextAnswerLogProb
from study_utils.time_utils import elapsed_from_str, Progress

def get_svd_of_weights(model):
    grad_svd_dict = {}
    for name, param in model.named_parameters():
        if ("intermediate.dense" in name or "output.dense" in name) and "weight" in name:
            torch.cuda.empty_cache()
            with torch.no_grad():
                weight = param.float()
                U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
            U = U.cpu()
            S = S.cpu()
            Vh = Vh.cpu()
            
            grad_svd_dict[name] = [U, S, Vh]
    return grad_svd_dict

def get_ksvd_of_weights(model, k):
    grad_svd_dict = {}
    for name, param in model.named_parameters():
        if ("intermediate.dense" in name or "output.dense" in name) and "weight" in name:
            torch.cuda.empty_cache()
            with torch.no_grad():
                weight = param.float()
                n_rows = weight.shape[0]
                rows_per_group = int(n_rows/k)
                grouped_rows = torch.split(weight, rows_per_group)
                U = []
                S = []
                Vh = []
                for idx, group in enumerate(grouped_rows):
                    u, s, vh = torch.linalg.svd(group, full_matrices=False)
                    U.append(u.cpu())
                    S.append(s.cpu())
                    Vh.append(vh.cpu())
            
            grad_svd_dict[name] = [U, S, Vh]
    return grad_svd_dict

llm_name = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(llm_name)
model = RobertaForMaskedLM.from_pretrained(llm_name)

model.eval()

# dataset_util = FEVER()
# dataset = dataset_util.get_dataset_no_logger()
# sampled_data = dataset
# print(model)

##################################

hooks = {}

# def register_hooks(model):
#     for i, layer in enumerate(model.transformer.h):  # Iterate through all GPTJBlock layers
#         hooks[f"mlp_fc_in_{i}"] = layer.mlp.fc_in.register_forward_hook(forward_hook_fn)
#         hooks[f"mlp_fc_out_{i}"] = layer.mlp.fc_out.register_forward_hook(forward_hook_fn)

hooks = []  # Use a list to store all hooks for easy cleanup

def register_hooks(model):
    global hooks
    if hooks:
        print("Hooks already registered. Skipping re-registration.")
        return
    
    for i, layer in enumerate(model.transformer.h):
        hook_in = layer.mlp.fc_in.register_forward_hook(forward_hook_fn)
        hook_out = layer.mlp.fc_out.register_forward_hook(forward_hook_fn)
        hooks.extend([hook_in, hook_out])
    print(f"Registered hooks for {len(hooks)} operations.")

def remove_hooks():
    global hooks
    for hook in hooks:
        hook.remove()
    hooks = []
    print("All hooks have been removed.")

def forward_hook_fn(module, input, output):
    print(f"Hook triggered for {module}")
    layer_info = {}
    x = output.detach()
    
    layer_info["mean"] = torch.mean(x).item()
    layer_info["std"] = torch.std(x).item()
    
    matrix = x.view(-1, x.size(-1)).to(torch.float32).cpu().numpy()

    if np.isnan(matrix).any() or np.isinf(matrix).any():
        print("Warning: NaNs or Infs detected in matrix. Replacing with 0.")
        matrix = np.nan_to_num(matrix, nan=0.0, posinf=0.0, neginf=0.0)

    try:
        eigenvalues = np.linalg.eigvals(matrix @ matrix.T)
        eigenvalues = np.sort(eigenvalues)[::-1]
        layer_info["sum_last_k"] = sum(eigenvalues[-5:])

        drop_index = np.argmax(np.diff(eigenvalues))
        layer_info["drop_off_index"] = drop_index
    except np.linalg.LinAlgError as e:
        print(f"LinAlgError: {e}. Skipping eigenvalue computation for this layer.")
        layer_info["sum_last_k"] = None
        layer_info["drop_off_index"] = None
    
    per_layer_info.append(layer_info)

################################
def compute_gradients(model, inputs):
    torch.cuda.empty_cache()
    for param in model.parameters():
        param.requires_grad = False
    for name, param in model.named_parameters():
        if "intermediate.dense" in name or "output.dense" in name:
            param.requires_grad = True
    outputs = model(**inputs)
    loss = outputs.logits.sum() / outputs.logits.numel()  # Dummy loss for gradients
    # print(f"Loss: {loss.item()}")
    model.zero_grad()
    loss.backward()
    grad_norm_dict = {}
    for name, param in model.named_parameters():
        if ("intermediate.dense" in name or "output.dense" in name) and ".weight" in name:
            grad_norm = torch.norm(param.grad).item()
            # print(f"Gradient norm for {name}: {grad_norm}")
            grad_norm_dict[name] = grad_norm
    return grad_norm_dict

def compute_gradients_with_svd(model, inputs, grad_svd_dict):
    torch.cuda.empty_cache()

    for param in model.parameters():
        param.requires_grad = False

    for name, param in model.named_parameters():
        print(name)
        if "intermediate.dense" in name or "output.dense" in name:
            param.requires_grad = True

    outputs = model(**inputs)
    loss = outputs.logits.sum() / outputs.logits.numel() # Dummy loss for gradients

    model.zero_grad()

    loss.backward()

    svd_grads = {}
    grad_norm_dict = {}
    for name, param in model.named_parameters():
        if ("intermediate.dense" in name or "output.dense" in name) and "weight" in name:
            weight = param.float()
            [U, S, Vh] = grad_svd_dict[name]
            U = U.to("cuda")
            S = S.to("cuda")
            Vh = Vh.to("cuda")

            grad = (param.grad).to(torch.float)
            grad_D = U.T @ grad @ Vh.T
            # print("Gradient of diagonal: ", grad_D)
            # print("Gradient norm of diagonal: ", torch.norm(torch.diagonal(grad_D)))
            grad_norm_dict[name] = torch.norm(torch.diagonal(grad_D)).item()

    return grad_norm_dict

def compute_gradients_with_svd_last_k(model, inputs, grad_svd_dict, k=20):
    torch.cuda.empty_cache()

    for param in model.parameters():
        param.requires_grad = False

    for name, param in model.named_parameters():
        if "intermediate.dense" in name or "output.dense" in name:
            param.requires_grad = True

    outputs = model(**inputs)
    loss = outputs.logits.sum() / outputs.logits.numel() # Dummy loss for gradients

    model.zero_grad()

    loss.backward()

    svd_grads = {}
    grad_norm_dict = {}
    for name, param in model.named_parameters():
        if ("intermediate.dense" in name or "output.dense" in name) and "weight" in name:
            weight = param.float()
            [U, S, Vh] = grad_svd_dict[name]
            U = U.to("cuda")
            S = S.to("cuda")
            Vh = Vh.to("cuda")
            grad = (param.grad).to(torch.float)
            grad_D = U.T @ grad @ Vh.T
            diag_grad_D = torch.diagonal(grad_D)
            last_k_values = diag_grad_D[-k:]
            negative_idxs = last_k_values < 0
            # print("Last k values: ", last_k_values)
            # print("Gradient of diagonal: ", grad_D)
            # print("Gradient norm of diagonal: ", torch.norm(torch.diagonal(grad_D)))
            grad_norm_dict[name] = -torch.sum(last_k_values[negative_idxs]).item()

    return grad_norm_dict

def compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=2, k=20):
    torch.cuda.empty_cache()

    for param in model.parameters():
        param.requires_grad = False

    for name, param in model.named_parameters():
        if "intermediate.dense" in name or "output.dense" in name:
            param.requires_grad = True

    outputs = model(**inputs)
    loss = outputs.logits.sum() / outputs.logits.numel() # Dummy loss for gradients

    model.zero_grad()

    loss.backward()

    svd_grads = {}
    grad_norm_dict = {}
    for name, param in model.named_parameters():
        if ("intermediate.dense" in name or "output.dense" in name) and "weight" in name:
            weight = param.float()
            [U, S, Vh] = grad_svd_dict[name]
            grad = (param.grad).to(torch.float)
            n_rows = grad.shape[0]
            rows_per_group = int(n_rows/num_clusters)
            grouped_rows = torch.split(grad, rows_per_group)
            for idx, group in enumerate(grouped_rows):
                u = U[idx].to("cuda")
                s = S[idx].to("cuda")
                vh = Vh[idx].to("cuda")
                grad_d = u.T @ group @ vh.T
                diag_grad_d = torch.diagonal(grad_d)
                last_k_idx = (int)(k)
                last_k_values = diag_grad_d[-last_k_idx:]
                negative_idxs = last_k_values < 0
                # print("Last k values: ", last_k_values)
                # print("Gradient of diagonal: ", grad_D)
                # print("Gradient norm of diagonal: ", torch.norm(torch.diagonal(grad_D)))
                grad_norm_dict[name] = -torch.sum(last_k_values[negative_idxs]).item()


    return grad_norm_dict

################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# print(len(sampled_data))
model = model.to(device)

per_layer_info = []

# layer_statistics = defaultdict(list)
# register_hooks(model)

'''
dataset_util = FEVER()
dataset = dataset_util.get_dataset_no_logger()
sampled_data = dataset
print("Number of sampled_data: ", len(sampled_data))
grad_norms = {}
# grad_svd_dict = get_svd_of_weights(model)
num_clusters = 16
grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
# print(grad_svd_dict)
num_iters = 100 # len(sampled_data)
sampled_data = random.sample(sampled_data, num_iters)
# Run inference
for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
    # print(text)
    question = sampled_data[i]["question"]
    answer_ix = sampled_data[i]["answer"]
    assert answer_ix in [0, 1], "Answer must be 0 (False) or 1 (True)"
    
    # Create the prompt dynamically
    if question.strip().endswith(".") or question.strip().endswith("?"):
        prompted_question = "Consider the following claim: " + question.strip() + " Is this claim true or false. The claim is"
    else:
        prompted_question = "Consider the following claim: " + question.strip() + ". Is this claim true or false. The claim is"
    
    # Tokenize input
    inputs = tokenizer(prompted_question, return_tensors="pt").to(device)
    # inputs["input_ids"] = inputs["input_ids"].to(torch.long)

    with torch.no_grad():
        outputs = model(**inputs)
    
    #   
    if i == 0:
        # remove_hooks()
        # grad_norms = compute_gradients(model, inputs)
        # grad_norms = compute_gradients_with_svd_last_k(model, inputs, grad_svd_dict)
        grad_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
    else:
        # iter_norms = compute_gradients(model, inputs)
        # iter_norms = compute_gradients_with_svd_last_k(model, inputs, grad_svd_dict)
        iter_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
        for key in iter_norms:
            grad_norms[key] += iter_norms[key]
    # compute_gradients_with_svd(model, inputs)



# #   
# remove_hooks()

# Print per-layer information
# print(f"Number of layers recorded: {len(per_layer_info)}")
# for idx, info in enumerate(per_layer_info):
#     if idx % 2 == 0:
#         print(f"Layer {idx / 2} matrix fc_in: {info}")
#     else:
#         print(f"Layer {(idx - 1) / 2} matrix fc_out: {info}")

# print("Grad Norms List: ", grad_norms_list)

for key in grad_norms:
    grad_norms[key] /= num_iters

print("Grad norms for Fever Dataset Clustering: ", grad_norms)

# Sorting by values in descending order
top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
print("Top 10: ", top_10_weights)
'''

'''
dataset_util = BiasBiosGender()
dataset = dataset_util.get_dataset_no_logger()
sampled_data = dataset

print("Number of sampled_data: ", len(sampled_data))
per_layer_info = []
grad_norms = {}
num_clusters = 16
grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
num_iters = 100
sampled_data = random.sample(sampled_data, num_iters)

# Run inference
for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
    # print(text)
    question = sampled_data[i]["hard_text"]
    answer_ix = sampled_data[i]["answer"]
    assert answer_ix in [0, 1], "Answer must be 0 (False) or 1 (True)"
    
    # Create the prompt dynamically
    if question.strip().endswith(".") or question.strip().endswith("?"):
        prompted_question = "Consider the following text: " + question.strip() + " Is the person in this text a male or female? The gender of this person is"
    else:
        prompted_question = "Consider the following text: " + question.strip() + ". Is the person in this text a male or female? The gender of this person is"
    
    # Tokenize input
    inputs = tokenizer(prompted_question, return_tensors="pt").to(device)
    # inputs["input_ids"] = inputs["input_ids"].to(torch.long)

    with torch.no_grad():
        outputs = model(**inputs)
    
    #   
    if i == 0:
        # remove_hooks()
        grad_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
    else:
        iter_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
        for key in iter_norms:
            grad_norms[key] += iter_norms[key]
    # compute_gradients_with_svd(model, inputs)
    del inputs
    torch.cuda.empty_cache()

# print("Grad Norms List: ", grad_norms_list)

for key in grad_norms:
    grad_norms[key] /= num_iters

print("Grad norms for Bios Gender Dataset Clustering: ", grad_norms)

# Sorting by values in descending order
top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
print("Top 10: ", top_10_weights)
'''


'''
dataset_util = BiasBiosOccupation()
sampled_data = dataset_util.get_dataset_no_logger()
# sampled_data = dataset

print("Number of sampled_data: ", len(sampled_data))
per_layer_info = []
grad_norms = {}
num_clusters = 16
grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
num_iters = 100
sampled_data = random.sample(sampled_data, num_iters)

# Run inference
for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
    # print(text)
    question = sampled_data[i]["hard_text"]
    answer_ix = sampled_data[i]["answer"]
    # assert answer_ix in [0, 1], "Answer must be 0 (False) or 1 (True)"
    
    # Create the prompt dynamically
    if question.strip().endswith(".") or question.strip().endswith("?"):
        prompted_question = "Consider the following text: " + question.strip()
    else:
        prompted_question = "Consider the following text: " + question.strip() + "."
    prompted_question += " What is the profession of the person in this text? The profession of this person is"  
    # Tokenize input
    with torch.no_grad():
        inputs = tokenizer(prompted_question, return_tensors="pt").to(device)
    # inputs["input_ids"] = inputs["input_ids"].to(torch.long)

    with torch.no_grad():
        outputs = model(**inputs)
    
    #   
    if i == 0:
        # remove_hooks()
        grad_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
    else:
        iter_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
        for key in iter_norms:
            grad_norms[key] += iter_norms[key]
    # compute_gradients_with_svd(model, inputs)
    del inputs
    del outputs
    torch.cuda.empty_cache()

# print("Grad Norms List: ", grad_norms_list)

for key in grad_norms:
    grad_norms[key] /= num_iters

print("Grad norms for Bios Profession Dataset: ", grad_norms)

top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
print("Top 10: ", top_10_weights)
'''


'''
dataset = get_truthfulqa_pointwise_data_no_logger()
sampled_data = dataset

print("Number of sampled_data: ", len(sampled_data))
per_layer_info = []
grad_norms = {}
num_clusters = 16
grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
num_iters = 100 # len(sampled_data)
sampled_data = random.sample(sampled_data, num_iters)
# Run inference
for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
    # print(text)
    prompt = dataset[i][0]
    label = dataset[i][1]
 
    # Tokenize input
    with torch.no_grad():
        input_and_answer = tokenizer(prompt, return_tensors="pt").to(device)
    # inputs["input_ids"] = inputs["input_ids"].to(torch.long)

    with torch.no_grad():
        outputs = model(input_and_answer.input_ids)
    
    #   
    if i == 0:
        # remove_hooks()
        # grad_norms = compute_gradients(model, input_and_answer)
        grad_norms = compute_gradients_with_svd_clustering_last_k(model, input_and_answer, grad_svd_dict, num_clusters=num_clusters)
    else:
        # iter_norms = compute_gradients(model, input_and_answer)
        iter_norms = compute_gradients_with_svd_clustering_last_k(model, input_and_answer, grad_svd_dict, num_clusters=num_clusters)
        for key in iter_norms:
            grad_norms[key] += iter_norms[key]
    # compute_gradients_with_svd(model, input_and_answer)
    del input_and_answer
    torch.cuda.empty_cache()

# print("Grad Norms List: ", grad_norms_list)

for key in grad_norms:
    grad_norms[key] /= num_iters

print("Grad norms for TruthfulQA Dataset: ", grad_norms)

# Sorting by values in descending order
top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
print("Top 10: ", top_10_weights)

def get_choice_tokens(choices, tokenizer):

    choice_token_ids = []
    for choice in choices:
        assert not choice.startswith(" "), f"Expecting choice token {choice} to not start with space"
        assert not choice.endswith(" "), f"Expecting choice token {choice} to not end with space"
        token_ids = tokenizer(f" {choice}")

        if len(token_ids["input_ids"]) != 1:
            # This is a multi-token target and so must be evaluated differently
            return None
        else:
            token_id = int(token_ids["input_ids"][0])
            choice_token_ids.append(token_id)

    return choice_token_ids
'''

'''
dataset, choices = get_bb_dataset("epistemic_reasoning")
sampled_data = dataset

choice_token_ids = get_choice_tokens(choices, tokenizer)
single_token_choices = False
if choice_token_ids is None:
    single_token_choices = False
else:
    single_token_choices = True

print("Number of sampled_data: ", len(sampled_data))
per_layer_info = []
grad_norms = {}
num_iters = 100 # len(sampled_data)
num_clusters = 16
grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
sampled_data = random.sample(sampled_data, num_iters)
# Run inference
for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
    # print(text)
    prompt = dataset[i][0]
    label = dataset[i][1]
 
    # Tokenize input
    with torch.no_grad():

        if single_token_choices:
            input_and_answer = tokenizer(prompt, return_tensors="pt").to(device)
        else:
            input_and_answer = tokenizer(prompt + " " + choices[0], return_tensors="pt").to(device)
    # inputs["input_ids"] = inputs["input_ids"].to(torch.long)

    with torch.no_grad():
        outputs = model(input_and_answer.input_ids)
    
    #   
    if i == 0:
        # remove_hooks()
        grad_norms = compute_gradients_with_svd_clustering_last_k(model, input_and_answer, grad_svd_dict, num_clusters=num_clusters)
    else:
        iter_norms = compute_gradients_with_svd_clustering_last_k(model, input_and_answer, grad_svd_dict, num_clusters=num_clusters)
        for key in iter_norms:
            grad_norms[key] += iter_norms[key]
    # compute_gradients_with_svd(model, input_and_answer)
    del input_and_answer
    torch.cuda.empty_cache()

# print("Grad Norms List: ", grad_norms_list)

for key in grad_norms:
    grad_norms[key] /= num_iters

print("Grad norms for BBH ER Dataset: ", grad_norms)

# Sorting by values in descending order
top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
print("Top 10: ", top_10_weights)
'''

'''
dataset, choices = get_bb_dataset("qa_wikidata")
sampled_data = dataset

# choice_token_ids = get_choice_tokens(choices, tokenizer)
# single_token_choices = False
# if choice_token_ids is None:
#     single_token_choices = False
# else:
#     single_token_choices = True

print("Number of sampled_data: ", len(sampled_data))
per_layer_info = []
grad_norms = {}
num_iters = len(sampled_data)
num_clusters = 1
grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
sampled_data = random.sample(sampled_data, num_iters)
# Run inference
for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
    # print(text)
    prompt = dataset[i][0].strip()
    answer = dataset[i][1].strip()
 
    # Tokenize input
    with torch.no_grad():
        input_and_answer = tokenizer(prompt + " " + answer, return_tensors="pt").to(device)

    # with torch.no_grad():
    #     outputs = model(input_and_answer.input_ids)
    
    #   
    if i == 0:
        # remove_hooks()
        # grad_norms = compute_gradients_with_svd_last_k(model, input_and_answer, grad_svd_dict)
        grad_norms = compute_gradients_with_svd_clustering_last_k(model, input_and_answer, grad_svd_dict, num_clusters=num_clusters)
    else:
        # iter_norms = compute_gradients_with_svd_last_k(model, input_and_answer, grad_svd_dict)
        iter_norms = compute_gradients_with_svd_clustering_last_k(model, input_and_answer, grad_svd_dict, num_clusters=num_clusters)
        for key in iter_norms:
            grad_norms[key] += iter_norms[key]
    # compute_gradients_with_svd(model, input_and_answer)
    del input_and_answer
    torch.cuda.empty_cache()

# print("Grad Norms List: ", grad_norms_list)

for key in grad_norms:
    grad_norms[key] /= num_iters

print("Grad norms for BBH WikidataQA Dataset: ", grad_norms)

# Sorting by values in descending order
top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
print("Top 10: ", top_10_weights)
'''



with open("counterfact", "rb") as f:
    data = pickle.load(f)

num_dp = len(data)
dataset = []

for i in range(num_dp):
    question = data[i]["question"]
    answer = data[i]["gold-answer"]
    assert answer.startswith(" "), f"Found answer that doesn't start with space ${answer}$"
    dataset.append((question, answer))

validation_index = int(len(dataset) * 0.2)
dataset = dataset[:validation_index]
sampled_data = dataset

print("Number of sampled_data: ", len(sampled_data))
per_layer_info = []
grad_norms = {}
num_iters = 100 # len(sampled_data)
num_clusters = 1
grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
sampled_data = random.sample(sampled_data, num_iters)
# Run inference
for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
    # print(text)
    question = dataset[i][0].strip()
    answer = dataset[i][1].strip()
 
    # Tokenize input
    with torch.no_grad():
        inputs = tokenizer(question, return_tensors="pt").to(device)

    # with torch.no_grad():
    #     outputs = model(inputs.input_ids)
    
    #   
    if i == 0:
        # remove_hooks()
        # grad_norms = compute_gradients_with_svd_last_k(model, inputs, grad_svd_dict)
        grad_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
    else:
        # iter_norms = compute_gradients_with_svd_last_k(model, inputs, grad_svd_dict)
        iter_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
        for key in iter_norms:
            grad_norms[key] += iter_norms[key]
    # compute_gradients_with_svd(model, inputs)
    del inputs
    torch.cuda.empty_cache()

# print("Grad Norms List: ", grad_norms_list)

for key in grad_norms:
    grad_norms[key] /= num_iters

print("Grad norms for Counterfact Dataset: ", grad_norms)

# Sorting by values in descending order
top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
print("Top 10: ", top_10_weights)


'''
dataset_util = Hotpot(llama_tokenizer_path="/data/cl/scratch/llama_weights")
sampled_data = dataset_util.get_dataset_no_logger()
# sampled_data = dataset

print("Number of sampled_data: ", len(sampled_data))
per_layer_info = []
grad_norms = {}
num_iters = 100 # len(sampled_data)
num_clusters = 2
grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
sampled_data = random.sample(sampled_data, num_iters)
# Run inference
for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
    # print(text)
    question = sampled_data[i]["question"]

    if not question.endswith("?") and not question.endswith("."):
        prompted_question = f"{question}? The answer is"
    else:
        prompted_question = f"{question} The answer is"

    answer = sampled_data[i]["answer"]
    inputs = tokenizer(prompted_question, return_tensors="pt").to(device)

    # inputs["input_ids"] = inputs["input_ids"].to(torch.long)

    # with torch.no_grad():
    #     outputs = model(**inputs)
    
    #   
    if i == 0:
        # remove_hooks()
        # grad_norms = compute_gradients_with_svd_last_k(model, inputs)
        grad_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
    else:
        # iter_norms = compute_gradients_with_svd_last_k(model, inputs)
        iter_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
        for key in iter_norms:
            grad_norms[key] += iter_norms[key]
    # compute_gradients_with_svd(model, inputs)
    del inputs
    # del outputs
    torch.cuda.empty_cache()

# print("Grad Norms List: ", grad_norms_list)

for key in grad_norms:
    grad_norms[key] /= num_iters

print("Grad norms for Hotpot Dataset: ", grad_norms)

top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
print("Top 10: ", top_10_weights)
'''

# Normalize Results:
# fever_results_dict = {'transformer.h.0.mlp.fc_in.weight': 1.04865012717329, 'transformer.h.0.mlp.fc_out.weight': 1.8958723908578525, 'transformer.h.1.mlp.fc_in.weight': 1.5124460409103935, 'transformer.h.1.mlp.fc_out.weight': 1.5990787399694306, 'transformer.h.2.mlp.fc_in.weight': 1.4348533551657432, 'transformer.h.2.mlp.fc_out.weight': 2.720607595768055, 'transformer.h.3.mlp.fc_in.weight': 1.6682232458444783, 'transformer.h.3.mlp.fc_out.weight': 1.6678030664883454, 'transformer.h.4.mlp.fc_in.weight': 1.7731938258263278, 'transformer.h.4.mlp.fc_out.weight': 1.5445393849111577, 'transformer.h.5.mlp.fc_in.weight': 1.8115178400840657, 'transformer.h.5.mlp.fc_out.weight': 1.5181957065580818, 'transformer.h.6.mlp.fc_in.weight': 1.6408709131400459, 'transformer.h.6.mlp.fc_out.weight': 1.4005137681505542, 'transformer.h.7.mlp.fc_in.weight': 1.668882248280474, 'transformer.h.7.mlp.fc_out.weight': 1.3572502209113488, 'transformer.h.8.mlp.fc_in.weight': 1.5601434878916698, 'transformer.h.8.mlp.fc_out.weight': 1.2510562323867978, 'transformer.h.9.mlp.fc_in.weight': 1.3920064422525793, 'transformer.h.9.mlp.fc_out.weight': 1.2335454629585403, 'transformer.h.10.mlp.fc_in.weight': 1.34247621471389, 'transformer.h.10.mlp.fc_out.weight': 1.1597498776031716, 'transformer.h.11.mlp.fc_in.weight': 1.1570056511511273, 'transformer.h.11.mlp.fc_out.weight': 1.073791853744746, 'transformer.h.12.mlp.fc_in.weight': 1.136183450456152, 'transformer.h.12.mlp.fc_out.weight': 1.0196066285345815, 'transformer.h.13.mlp.fc_in.weight': 1.131860752710642, 'transformer.h.13.mlp.fc_out.weight': 0.9794250185087887, 'transformer.h.14.mlp.fc_in.weight': 1.030036480225449, 'transformer.h.14.mlp.fc_out.weight': 0.9494988073772449, 'transformer.h.15.mlp.fc_in.weight': 0.9516437371632595, 'transformer.h.15.mlp.fc_out.weight': 0.8647714687022354, 'transformer.h.16.mlp.fc_in.weight': 0.9465958008931983, 'transformer.h.16.mlp.fc_out.weight': 0.8289427824679977, 'transformer.h.17.mlp.fc_in.weight': 0.9532393738655903, 'transformer.h.17.mlp.fc_out.weight': 0.8005265675749904, 'transformer.h.18.mlp.fc_in.weight': 0.8732174096054642, 'transformer.h.18.mlp.fc_out.weight': 0.7715522903133358, 'transformer.h.19.mlp.fc_in.weight': 0.8645936574560565, 'transformer.h.19.mlp.fc_out.weight': 0.7824961713675009, 'transformer.h.20.mlp.fc_in.weight': 0.7785223792152274, 'transformer.h.20.mlp.fc_out.weight': 0.7247054639735384, 'transformer.h.21.mlp.fc_in.weight': 0.8486897195619985, 'transformer.h.21.mlp.fc_out.weight': 0.7336999516383264, 'transformer.h.22.mlp.fc_in.weight': 0.8112948763493504, 'transformer.h.22.mlp.fc_out.weight': 0.7247228159629346, 'transformer.h.23.mlp.fc_in.weight': 0.9652475102693925, 'transformer.h.23.mlp.fc_out.weight': 0.7422524300248377, 'transformer.h.24.mlp.fc_in.weight': 1.1571847684610241, 'transformer.h.24.mlp.fc_out.weight': 0.7965253480846389, 'transformer.h.25.mlp.fc_in.weight': 1.2416857849517577, 'transformer.h.25.mlp.fc_out.weight': 0.9017840084065724, 'transformer.h.26.mlp.fc_in.weight': 1.387544928591899, 'transformer.h.26.mlp.fc_out.weight': 0.9103838782479939, 'transformer.h.27.mlp.fc_in.weight': 3.0038241545662974, 'transformer.h.27.mlp.fc_out.weight': 1.0481553156047}
# biosg_results_dict = {'transformer.h.0.mlp.fc_in.weight': 1.1317874145507814, 'transformer.h.0.mlp.fc_out.weight': 2.1360762532552084, 'transformer.h.1.mlp.fc_in.weight': 1.5963194274902344, 'transformer.h.1.mlp.fc_out.weight': 1.6691835530598957, 'transformer.h.2.mlp.fc_in.weight': 1.6080424499511718, 'transformer.h.2.mlp.fc_out.weight': 2.5285362752278644, 'transformer.h.3.mlp.fc_in.weight': 1.8505101013183594, 'transformer.h.3.mlp.fc_out.weight': 1.8290023803710938, 'transformer.h.4.mlp.fc_in.weight': 1.867293446858724, 'transformer.h.4.mlp.fc_out.weight': 1.6914052327473958, 'transformer.h.5.mlp.fc_in.weight': 2.0216649373372397, 'transformer.h.5.mlp.fc_out.weight': 1.707790018717448, 'transformer.h.6.mlp.fc_in.weight': 1.8550031534830729, 'transformer.h.6.mlp.fc_out.weight': 1.5871132405598958, 'transformer.h.7.mlp.fc_in.weight': 1.8556094360351563, 'transformer.h.7.mlp.fc_out.weight': 1.5286573282877605, 'transformer.h.8.mlp.fc_in.weight': 1.7395070393880208, 'transformer.h.8.mlp.fc_out.weight': 1.3419957478841147, 'transformer.h.9.mlp.fc_in.weight': 1.3841307067871094, 'transformer.h.9.mlp.fc_out.weight': 1.2363413492838542, 'transformer.h.10.mlp.fc_in.weight': 1.3191409810384114, 'transformer.h.10.mlp.fc_out.weight': 1.1392608642578126, 'transformer.h.11.mlp.fc_in.weight': 1.1001375834147136, 'transformer.h.11.mlp.fc_out.weight': 1.0730325317382812, 'transformer.h.12.mlp.fc_in.weight': 1.046516876220703, 'transformer.h.12.mlp.fc_out.weight': 1.0113560485839843, 'transformer.h.13.mlp.fc_in.weight': 1.045888671875, 'transformer.h.13.mlp.fc_out.weight': 0.9552669270833334, 'transformer.h.14.mlp.fc_in.weight': 0.873486073811849, 'transformer.h.14.mlp.fc_out.weight': 0.9198356628417969, 'transformer.h.15.mlp.fc_in.weight': 0.8060215250651042, 'transformer.h.15.mlp.fc_out.weight': 0.8255031331380208, 'transformer.h.16.mlp.fc_in.weight': 0.8118447875976562, 'transformer.h.16.mlp.fc_out.weight': 0.8067711893717447, 'transformer.h.17.mlp.fc_in.weight': 0.8049703470865885, 'transformer.h.17.mlp.fc_out.weight': 0.7744306945800781, 'transformer.h.18.mlp.fc_in.weight': 0.7795513407389323, 'transformer.h.18.mlp.fc_out.weight': 0.7640529378255209, 'transformer.h.19.mlp.fc_in.weight': 0.778843994140625, 'transformer.h.19.mlp.fc_out.weight': 0.7669414774576823, 'transformer.h.20.mlp.fc_in.weight': 0.6795903015136718, 'transformer.h.20.mlp.fc_out.weight': 0.6931246948242188, 'transformer.h.21.mlp.fc_in.weight': 0.7565930684407552, 'transformer.h.21.mlp.fc_out.weight': 0.7004726155598958, 'transformer.h.22.mlp.fc_in.weight': 0.7190971883138021, 'transformer.h.22.mlp.fc_out.weight': 0.7130079142252604, 'transformer.h.23.mlp.fc_in.weight': 0.8633175150553385, 'transformer.h.23.mlp.fc_out.weight': 0.7464804585774739, 'transformer.h.24.mlp.fc_in.weight': 1.1283934529622395, 'transformer.h.24.mlp.fc_out.weight': 0.7542613220214843, 'transformer.h.25.mlp.fc_in.weight': 1.1461099243164063, 'transformer.h.25.mlp.fc_out.weight': 0.8673237101236979, 'transformer.h.26.mlp.fc_in.weight': 1.2906497192382813, 'transformer.h.26.mlp.fc_out.weight': 0.8983906046549479, 'transformer.h.27.mlp.fc_in.weight': 2.8138991292317708, 'transformer.h.27.mlp.fc_out.weight': 1.000883585611979}
# truth_results_dict = {'transformer.h.0.mlp.fc_in.weight': 0.8777266521843112, 'transformer.h.0.mlp.fc_out.weight': 1.6064884938350341, 'transformer.h.1.mlp.fc_in.weight': 1.2027218358046343, 'transformer.h.1.mlp.fc_out.weight': 1.2973387841464712, 'transformer.h.2.mlp.fc_in.weight': 1.2509977379623725, 'transformer.h.2.mlp.fc_out.weight': 2.353120349702381, 'transformer.h.3.mlp.fc_in.weight': 1.3810395740327381, 'transformer.h.3.mlp.fc_out.weight': 1.4046365128082483, 'transformer.h.4.mlp.fc_in.weight': 1.5133745881164966, 'transformer.h.4.mlp.fc_out.weight': 1.3817931713701106, 'transformer.h.5.mlp.fc_in.weight': 1.6892268381962159, 'transformer.h.5.mlp.fc_out.weight': 1.4443288790125426, 'transformer.h.6.mlp.fc_in.weight': 1.5921203198076106, 'transformer.h.6.mlp.fc_out.weight': 1.3640697245695153, 'transformer.h.7.mlp.fc_in.weight': 1.6079620568930697, 'transformer.h.7.mlp.fc_out.weight': 1.3274681454613095, 'transformer.h.8.mlp.fc_in.weight': 1.5647421077806123, 'transformer.h.8.mlp.fc_out.weight': 1.2136857428518282, 'transformer.h.9.mlp.fc_in.weight': 1.2985902124521684, 'transformer.h.9.mlp.fc_out.weight': 1.1563189240539966, 'transformer.h.10.mlp.fc_in.weight': 1.2878476097470237, 'transformer.h.10.mlp.fc_out.weight': 1.1204572405133928, 'transformer.h.11.mlp.fc_in.weight': 1.1697487344547195, 'transformer.h.11.mlp.fc_out.weight': 1.0627815090880102, 'transformer.h.12.mlp.fc_in.weight': 1.1125330503295068, 'transformer.h.12.mlp.fc_out.weight': 0.9943673270089286, 'transformer.h.13.mlp.fc_in.weight': 1.0998119951105443, 'transformer.h.13.mlp.fc_out.weight': 0.9519242267219388, 'transformer.h.14.mlp.fc_in.weight': 0.9840378534226191, 'transformer.h.14.mlp.fc_out.weight': 0.9209013439360119, 'transformer.h.15.mlp.fc_in.weight': 0.8770963707748725, 'transformer.h.15.mlp.fc_out.weight': 0.8270749461894132, 'transformer.h.16.mlp.fc_in.weight': 0.8657928259194303, 'transformer.h.16.mlp.fc_out.weight': 0.79812850433142, 'transformer.h.17.mlp.fc_in.weight': 0.8719063064678997, 'transformer.h.17.mlp.fc_out.weight': 0.7614272759885204, 'transformer.h.18.mlp.fc_in.weight': 0.7929147733312075, 'transformer.h.18.mlp.fc_out.weight': 0.7274975419855442, 'transformer.h.19.mlp.fc_in.weight': 0.7742658342633929, 'transformer.h.19.mlp.fc_out.weight': 0.7405586502178997, 'transformer.h.20.mlp.fc_in.weight': 0.6900667982036565, 'transformer.h.20.mlp.fc_out.weight': 0.6693580264136905, 'transformer.h.21.mlp.fc_in.weight': 0.7512975160767432, 'transformer.h.21.mlp.fc_out.weight': 0.6670590355282738, 'transformer.h.22.mlp.fc_in.weight': 0.7024291161777211, 'transformer.h.22.mlp.fc_out.weight': 0.6823485497714711, 'transformer.h.23.mlp.fc_in.weight': 0.8871077972204506, 'transformer.h.23.mlp.fc_out.weight': 0.7016128228635204, 'transformer.h.24.mlp.fc_in.weight': 1.0964538159013606, 'transformer.h.24.mlp.fc_out.weight': 0.7245973340508078, 'transformer.h.25.mlp.fc_in.weight': 1.18500212585034, 'transformer.h.25.mlp.fc_out.weight': 0.8602236793154762, 'transformer.h.26.mlp.fc_in.weight': 1.2609813456632653, 'transformer.h.26.mlp.fc_out.weight': 0.8239114981930272, 'transformer.h.27.mlp.fc_in.weight': 2.803566446109694, 'transformer.h.27.mlp.fc_out.weight': 0.9956868489583334}
# bbher_results_dict = {'transformer.h.0.mlp.fc_in.weight': 0.72726318359375, 'transformer.h.0.mlp.fc_out.weight': 1.32373291015625, 'transformer.h.1.mlp.fc_in.weight': 0.962685546875, 'transformer.h.1.mlp.fc_out.weight': 1.043179931640625, 'transformer.h.2.mlp.fc_in.weight': 1.01311767578125, 'transformer.h.2.mlp.fc_out.weight': 1.77912353515625, 'transformer.h.3.mlp.fc_in.weight': 1.010750732421875, 'transformer.h.3.mlp.fc_out.weight': 1.073232421875, 'transformer.h.4.mlp.fc_in.weight': 1.183431396484375, 'transformer.h.4.mlp.fc_out.weight': 1.098284912109375, 'transformer.h.5.mlp.fc_in.weight': 1.33456787109375, 'transformer.h.5.mlp.fc_out.weight': 1.18592041015625, 'transformer.h.6.mlp.fc_in.weight': 1.2648828125, 'transformer.h.6.mlp.fc_out.weight': 1.16466552734375, 'transformer.h.7.mlp.fc_in.weight': 1.32064208984375, 'transformer.h.7.mlp.fc_out.weight': 1.2002880859375, 'transformer.h.8.mlp.fc_in.weight': 1.324619140625, 'transformer.h.8.mlp.fc_out.weight': 1.155164794921875, 'transformer.h.9.mlp.fc_in.weight': 1.24272705078125, 'transformer.h.9.mlp.fc_out.weight': 1.19887939453125, 'transformer.h.10.mlp.fc_in.weight': 1.34179931640625, 'transformer.h.10.mlp.fc_out.weight': 1.154412841796875, 'transformer.h.11.mlp.fc_in.weight': 1.16597900390625, 'transformer.h.11.mlp.fc_out.weight': 1.084248046875, 'transformer.h.12.mlp.fc_in.weight': 1.117425537109375, 'transformer.h.12.mlp.fc_out.weight': 1.01807373046875, 'transformer.h.13.mlp.fc_in.weight': 1.182584228515625, 'transformer.h.13.mlp.fc_out.weight': 1.015654296875, 'transformer.h.14.mlp.fc_in.weight': 1.00908935546875, 'transformer.h.14.mlp.fc_out.weight': 0.964180908203125, 'transformer.h.15.mlp.fc_in.weight': 0.899246826171875, 'transformer.h.15.mlp.fc_out.weight': 0.8433251953125, 'transformer.h.16.mlp.fc_in.weight': 0.882059326171875, 'transformer.h.16.mlp.fc_out.weight': 0.8004541015625, 'transformer.h.17.mlp.fc_in.weight': 0.8288427734375, 'transformer.h.17.mlp.fc_out.weight': 0.76293212890625, 'transformer.h.18.mlp.fc_in.weight': 0.751593017578125, 'transformer.h.18.mlp.fc_out.weight': 0.748994140625, 'transformer.h.19.mlp.fc_in.weight': 0.710816650390625, 'transformer.h.19.mlp.fc_out.weight': 0.73787109375, 'transformer.h.20.mlp.fc_in.weight': 0.63523193359375, 'transformer.h.20.mlp.fc_out.weight': 0.64753662109375, 'transformer.h.21.mlp.fc_in.weight': 0.694547119140625, 'transformer.h.21.mlp.fc_out.weight': 0.64552490234375, 'transformer.h.22.mlp.fc_in.weight': 0.640220947265625, 'transformer.h.22.mlp.fc_out.weight': 0.646861572265625, 'transformer.h.23.mlp.fc_in.weight': 0.809072265625, 'transformer.h.23.mlp.fc_out.weight': 0.67755859375, 'transformer.h.24.mlp.fc_in.weight': 0.92193603515625, 'transformer.h.24.mlp.fc_out.weight': 0.6558837890625, 'transformer.h.25.mlp.fc_in.weight': 1.07173828125, 'transformer.h.25.mlp.fc_out.weight': 0.77143798828125, 'transformer.h.26.mlp.fc_in.weight': 1.13881591796875, 'transformer.h.26.mlp.fc_out.weight': 0.803399658203125, 'transformer.h.27.mlp.fc_in.weight': 2.58853515625, 'transformer.h.27.mlp.fc_out.weight': 0.9593408203125}

# for name, param in model.named_parameters():
#     if ("fc_in" in name or "fc_out" in name) and ".weight" in name:
#         print(param.size())
#         fever_results_dict[name] = fever_results_dict[name]**2 / (param.size()[0] * param.size()[1])

# print("Normalized Fever Results: ", fever_results_dict)

# per_layer_info = []

# Register hooks
# Use a dictionary to store unique per-layer statistics
# layer_statistics = defaultdict(list)
# register_hooks(model)


# with open("counterfact", "rb") as f:
#     data = pickle.load(f)

# num_dp = len(data)
# dataset = []

# for i in range(num_dp):
#     question = data[i]["question"]
#     answer = data[i]["gold-answer"]
#     assert answer.startswith(" "), f"Found answer that doesn't start with space ${answer}$"
#     dataset.append((question, answer))

# validation_index = int(len(dataset) * 0.2)
# dataset = dataset[:validation_index]
# sampled_data = dataset

# print("Number of sampled_data: ", len(sampled_data))
# grad_norms = {}
# num_iters = 100 # len(sampled_data)
# # Run inference
# for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
#     # print(text)
#     question, answer = sampled_data[i]
    
#     # Tokenize input
#     inputs = tokenizer(question, return_tensors="pt").to(device)
#     # inputs["input_ids"] = inputs["input_ids"].to(torch.long)

#     with torch.no_grad():
#         outputs = model(**inputs)
    
#     #   
#     if i == 0:
#         # remove_hooks()
#         # grad_norms = compute_gradients(model, inputs)
#         grad_norms = compute_gradients_with_svd_last_k(model, inputs)
#     else:
#         # iter_norms = compute_gradients(model, inputs)
#         iter_norms = compute_gradients_with_svd_last_k(model, inputs)
#         for key in iter_norms:
#             grad_norms[key] += iter_norms[key]
#     # compute_gradients_with_svd(model, inputs)



# # #   
# # remove_hooks()

# # Print per-layer information
# print(f"Number of layers recorded: {len(per_layer_info)}")
# for idx, info in enumerate(per_layer_info):
#     if idx % 2 == 0:
#         print(f"Layer {idx / 2} matrix fc_in: {info}")
#     else:
#         print(f"Layer {(idx - 1) / 2} matrix fc_out: {info}")

# # print("Grad Norms List: ", grad_norms_list)

# for key in grad_norms:
#     grad_norms[key] /= num_iters

# print("Grad norms for Counterfact Dataset: ", grad_norms)

# # Sorting by values in descending order
# top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
# print("Top 10: ", top_10_weights)