import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np

import os
import time
import torch
import random
import pickle
import argparse
import numpy as np
import pandas as pd

from csv import writer
from tqdm import tqdm
from copy import deepcopy
from transformers import AutoTokenizer
from transformers import GPTJForCausalLM
from dataset_utils.fever import FEVER
# from dataset_utils.hotpot import Hotpot
from dataset_utils.bias_in_bios import BiasBiosGender, BiasBiosOccupation
from dataset_utils.truthfulqa import get_truthfulqa_pointwise_data_no_logger
from dataset_utils.bigbench import get_bb_dataset
from laser.LaserWrapper import LaserWrapper
from study_utils.log_utils import Logger
from study_utils.metric_utils import Metrics, DatasetMetrics, ContextAnswerLogProb
from study_utils.time_utils import elapsed_from_str, Progress

def get_svd_of_weights(model):
    """Factorise every MLP projection weight of ``model`` with a reduced SVD.

    A parameter is selected when its name contains ``fc_in`` or ``fc_out``
    and also contains ``weight``.

    Args:
        model: Object exposing ``named_parameters()``.

    Returns:
        Dict mapping each selected parameter name to the list
        ``[U, S, Vh]`` (``weight = U @ diag(S) @ Vh``), all moved to the CPU.
    """
    factors_by_name = {}
    for name, param in model.named_parameters():
        is_mlp_projection = "fc_in" in name or "fc_out" in name
        if not (is_mlp_projection and "weight" in name):
            continue

        torch.cuda.empty_cache()
        with torch.no_grad():
            # The SVD is computed on a float32 copy of the weight.
            factors = torch.linalg.svd(param.float(), full_matrices=False)

        # Keep the factors on the CPU so they do not occupy accelerator memory.
        factors_by_name[name] = [factor.cpu() for factor in factors]
    return factors_by_name

def get_ksvd_of_weights(model, k):
    """Compute an independent reduced SVD for each row group of every MLP weight.

    A parameter is selected when its name contains ``fc_in`` or ``fc_out``
    and also contains ``weight``. Each selected matrix is cut along its rows
    with ``torch.split`` into consecutive chunks of ``n_rows // k`` rows and
    every chunk is factorised separately.

    Args:
        model: Object exposing ``named_parameters()``.
        k: Requested number of row groups. When the row count is not a
            multiple of ``k``, ``torch.split`` emits an extra, smaller
            trailing chunk, so more than ``k`` groups are produced.

    Returns:
        Dict mapping each selected parameter name to ``[U, S, Vh]`` where
        ``U``, ``S`` and ``Vh`` are lists holding one CPU tensor per row
        group (``group = U[i] @ diag(S[i]) @ Vh[i]``).

    Raises:
        ValueError: If ``k`` is not positive, or is larger than the number
            of rows of a selected weight.
    """
    if k <= 0:
        raise ValueError(f"k must be a positive number of groups, got {k}")

    grad_svd_dict = {}
    for name, param in model.named_parameters():
        if not (("fc_in" in name or "fc_out" in name) and "weight" in name):
            continue

        torch.cuda.empty_cache()
        with torch.no_grad():
            weight = param.float()
            n_rows = weight.shape[0]
            rows_per_group = int(n_rows // k)
            if rows_per_group == 0:
                raise ValueError(
                    f"Cannot split {name} with {n_rows} rows into {k} groups"
                )

            U, S, Vh = [], [], []
            for group in torch.split(weight, rows_per_group):
                u, s, vh = torch.linalg.svd(group, full_matrices=False)
                U.append(u.cpu())
                # Singular values are tiny compared with U/Vh; keeping them
                # makes the three lists line up index by index.
                S.append(s.cpu())
                Vh.append(vh.cpu())

        grad_svd_dict[name] = [U, S, Vh]
    return grad_svd_dict

# Load the GPT-J tokenizer and causal-LM weights identified by `llm_path`.
llm_name = "GPTJ"
llm_path = "EleutherAI/gpt-j-6B"
tokenizer = AutoTokenizer.from_pretrained(llm_path)
# The "float16" revision is requested together with torch_dtype=torch.float16,
# i.e. the weights are loaded in half precision.
model = GPTJForCausalLM.from_pretrained(
    llm_path,
    revision="float16",
    torch_dtype=torch.float16
)

# Switch the model to evaluation mode before any forward pass.
model.eval()

# dataset_util = FEVER()
# dataset = dataset_util.get_dataset_no_logger()
# sampled_data = dataset
# print(model)

##################################


# Handles returned by ``register_forward_hook``. A list (not a dict) is used
# because ``register_hooks`` extends it and ``remove_hooks`` iterates over it
# to detach every hook.
hooks = []

def register_hooks(model):
    """Attach ``forward_hook_fn`` to the ``fc_in``/``fc_out`` modules of each block.

    The handles are appended to the module-level ``hooks`` list, in the
    order ``fc_in`` then ``fc_out`` for every entry of
    ``model.transformer.h``. Nothing is registered when ``hooks`` is already
    non-empty.

    Args:
        model: Object whose ``transformer.h`` iterates over blocks exposing
            ``mlp.fc_in`` and ``mlp.fc_out``.
    """
    global hooks
    if hooks:
        print("Hooks already registered. Skipping re-registration.")
        return

    for block in model.transformer.h:
        # Register both projections first, then record the pair of handles.
        block_handles = [
            getattr(block.mlp, projection).register_forward_hook(forward_hook_fn)
            for projection in ("fc_in", "fc_out")
        ]
        hooks.extend(block_handles)
    print(f"Registered hooks for {len(hooks)} operations.")

def remove_hooks():
    """Detach every handle stored in the module-level ``hooks`` list.

    After all handles have been removed, ``hooks`` is rebound to a fresh
    empty list so that hooks can be registered again later.
    """
    global hooks
    for handle in hooks:
        handle.remove()
    # Rebind rather than clear in place: the previous list object is left as is.
    hooks = []
    print("All hooks have been removed.")

def forward_hook_fn(module, input, output):
    """Forward hook that records summary statistics of a module's output.

    One dict is appended to the module-level ``per_layer_info`` list per
    call (that list must exist before the hook fires). The dict contains:

    * ``mean`` / ``std``: mean and standard deviation over all output entries.
    * ``sum_last_k``: sum of the five smallest eigenvalues of ``M @ M.T``,
      where ``M`` is the output flattened to ``(-1, last_dim)``.
    * ``drop_off_index``: index ``i`` at which the descending eigenvalue
      spectrum falls the most between position ``i`` and ``i + 1``; ``None``
      when fewer than two eigenvalues exist.

    Both eigenvalue entries are ``None`` when the eigen-decomposition fails.

    Args:
        module: The module the hook is attached to (only used for logging).
        input: Inputs of the module (unused).
        output: Output tensor of the module.
    """
    print(f"Hook triggered for {module}")
    layer_info = {}
    x = output.detach()

    layer_info["mean"] = torch.mean(x).item()
    layer_info["std"] = torch.std(x).item()

    matrix = x.view(-1, x.size(-1)).to(torch.float32).cpu().numpy()

    if not np.isfinite(matrix).all():
        print("Warning: NaNs or Infs detected in matrix. Replacing with 0.")
        matrix = np.nan_to_num(matrix, nan=0.0, posinf=0.0, neginf=0.0)

    try:
        # The Gram matrix is symmetric, so use the symmetric solver: it
        # always returns real eigenvalues (ascending); reverse to descending.
        eigenvalues = np.linalg.eigvalsh(matrix @ matrix.T)[::-1]
    except np.linalg.LinAlgError as e:
        print(f"LinAlgError: {e}. Skipping eigenvalue computation for this layer.")
        layer_info["sum_last_k"] = None
        layer_info["drop_off_index"] = None
    else:
        layer_info["sum_last_k"] = float(np.sum(eigenvalues[-5:]))

        # Consecutive differences of a descending spectrum are <= 0, so the
        # steepest drop is the most negative difference (argmin, not argmax).
        if eigenvalues.size > 1:
            layer_info["drop_off_index"] = int(np.argmin(np.diff(eigenvalues)))
        else:
            layer_info["drop_off_index"] = None

    # Save per-layer info
    per_layer_info.append(layer_info)

################################
def compute_gradients(model, inputs):
    """Return the gradient norm of every MLP projection weight for one input.

    Gradients are enabled only for parameters whose name contains ``fc_in``
    or ``fc_out``. The loss is a dummy objective: the mean of all logits.

    Side effects: ``requires_grad`` is rewritten on every parameter of
    ``model`` and the ``.grad`` fields of the selected parameters are left
    populated after the call.

    Args:
        model: Module called as ``model(**inputs)`` whose result exposes
            ``.logits``.
        inputs: Mapping of keyword arguments for the forward pass.

    Returns:
        Dict mapping each parameter name containing ``fc_in``/``fc_out`` and
        ``.weight`` to the norm (Python float) of its gradient.
    """
    torch.cuda.empty_cache()

    # Only the MLP projections take part in the backward pass.
    for name, param in model.named_parameters():
        param.requires_grad = "fc_in" in name or "fc_out" in name

    outputs = model(**inputs)
    loss = outputs.logits.sum() / outputs.logits.numel()  # Dummy loss for gradients
    model.zero_grad()
    loss.backward()

    grad_norm_dict = {}
    for name, param in model.named_parameters():
        if ("fc_in" in name or "fc_out" in name) and ".weight" in name:
            # Take the norm in float32, like the SVD-based variants in this
            # file, so a half-precision gradient is not reduced in fp16.
            grad = param.grad.to(torch.float)
            grad_norm_dict[name] = torch.norm(grad).item()
    return grad_norm_dict

def compute_gradients_with_svd(model, inputs, grad_svd_dict):
    """Return, per MLP weight, the norm of the gradient w.r.t. its singular values.

    For a weight ``W = U @ diag(S) @ Vh`` with gradient ``G``, the quantity
    ``diag(U.T @ G @ Vh.T)`` is computed and its norm is reported. The loss
    is a dummy objective: the mean of all logits.

    Side effects: ``requires_grad`` is rewritten on every parameter of
    ``model`` and the ``.grad`` fields of the selected parameters are left
    populated after the call.

    Args:
        model: Module called as ``model(**inputs)`` whose result exposes
            ``.logits``.
        inputs: Mapping of keyword arguments for the forward pass.
        grad_svd_dict: Dict mapping parameter names to ``[U, S, Vh]``
            tensors, as produced by ``get_svd_of_weights``.

    Returns:
        Dict mapping each parameter name containing ``fc_in``/``fc_out`` and
        ``weight`` to a Python float.
    """
    torch.cuda.empty_cache()  # Clear GPU memory cache

    # Only the MLP projections take part in the backward pass.
    for name, param in model.named_parameters():
        param.requires_grad = "fc_in" in name or "fc_out" in name

    outputs = model(**inputs)
    loss = outputs.logits.sum() / outputs.logits.numel()  # Dummy loss for gradients
    model.zero_grad()
    loss.backward()

    grad_norm_dict = {}
    for name, param in model.named_parameters():
        if ("fc_in" in name or "fc_out" in name) and "weight" in name:
            U, _, Vh = grad_svd_dict[name]
            grad = param.grad.to(torch.float)
            # Follow the gradient's device instead of assuming CUDA, so the
            # function also works when the model lives on the CPU.
            U = U.to(grad.device)
            Vh = Vh.to(grad.device)

            # diag(U.T @ G @ Vh.T)[i] == sum_j U[j, i] * (G @ Vh.T)[j, i];
            # this avoids building the full square product for its diagonal.
            diag_grad_D = torch.sum(U * (grad @ Vh.T), dim=0)
            grad_norm_dict[name] = torch.norm(diag_grad_D).item()

    return grad_norm_dict

def compute_gradients_with_svd_last_k(model, inputs, grad_svd_dict, k=20):
    """Score each MLP weight by the negative gradient mass on its last ``k`` singular values.

    For a weight ``W = U @ diag(S) @ Vh`` with gradient ``G``, the vector
    ``diag(U.T @ G @ Vh.T)`` is computed, its last ``k`` entries are taken,
    and the negated sum of the negative ones is reported. The loss is a
    dummy objective: the mean of all logits.

    Side effects: ``requires_grad`` is rewritten on every parameter of
    ``model`` and the ``.grad`` fields of the selected parameters are left
    populated after the call.

    Args:
        model: Module called as ``model(**inputs)`` whose result exposes
            ``.logits``.
        inputs: Mapping of keyword arguments for the forward pass.
        grad_svd_dict: Dict mapping parameter names to ``[U, S, Vh]``
            tensors, as produced by ``get_svd_of_weights``.
        k: Number of trailing diagonal entries to inspect. Values larger
            than the diagonal length select the whole diagonal; ``k <= 0``
            selects nothing.

    Returns:
        Dict mapping each parameter name containing ``fc_in``/``fc_out`` and
        ``weight`` to a non-negative Python float.
    """
    torch.cuda.empty_cache()

    # Only the MLP projections take part in the backward pass.
    for name, param in model.named_parameters():
        param.requires_grad = "fc_in" in name or "fc_out" in name

    outputs = model(**inputs)
    loss = outputs.logits.sum() / outputs.logits.numel()  # Dummy loss for gradients
    model.zero_grad()
    loss.backward()

    grad_norm_dict = {}
    for name, param in model.named_parameters():
        if ("fc_in" in name or "fc_out" in name) and "weight" in name:
            U, _, Vh = grad_svd_dict[name]
            grad = param.grad.to(torch.float)
            # Follow the gradient's device instead of assuming CUDA.
            U = U.to(grad.device)
            Vh = Vh.to(grad.device)

            # diag(U.T @ G @ Vh.T) without materialising the square product.
            diag_grad_D = torch.sum(U * (grad @ Vh.T), dim=0)
            # ``x[-0:]`` would select everything, so treat k <= 0 explicitly.
            last_k_values = diag_grad_D[-k:] if k > 0 else diag_grad_D[:0]
            negative_idxs = last_k_values < 0
            grad_norm_dict[name] = -torch.sum(last_k_values[negative_idxs]).item()

    return grad_norm_dict

def compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=2, k=20):
    """Score each MLP weight using per-row-group SVD factors.

    The gradient of every selected weight is split along its rows into
    chunks of ``n_rows // num_clusters`` rows (the same split used by
    ``get_ksvd_of_weights``). For each chunk ``G_i`` with factors
    ``U_i``/``Vh_i``, the last ``k`` entries of ``diag(U_i.T @ G_i @ Vh_i.T)``
    are taken and the negated sum of the negative ones is computed. The
    score of the weight is the sum of these values over all chunks. The
    loss is a dummy objective: the mean of all logits.

    Side effects: ``requires_grad`` is rewritten on every parameter of
    ``model`` and the ``.grad`` fields of the selected parameters are left
    populated after the call.

    Args:
        model: Module called as ``model(**inputs)`` whose result exposes
            ``.logits``.
        inputs: Mapping of keyword arguments for the forward pass.
        grad_svd_dict: Dict mapping parameter names to ``[U, S, Vh]`` where
            ``U`` and ``Vh`` are lists with one tensor per row group
            (``S`` is not used).
        num_clusters: Number of row groups used when the factors were built.
        k: Number of trailing diagonal entries inspected per group;
            ``k <= 0`` selects nothing.

    Returns:
        Dict mapping each parameter name containing ``fc_in``/``fc_out`` and
        ``weight`` to a non-negative Python float.

    Raises:
        ValueError: If the number of stored factors does not match the
            number of row groups of a gradient.
    """
    torch.cuda.empty_cache()

    # Only the MLP projections take part in the backward pass.
    for name, param in model.named_parameters():
        param.requires_grad = "fc_in" in name or "fc_out" in name

    outputs = model(**inputs)
    loss = outputs.logits.sum() / outputs.logits.numel()  # Dummy loss for gradients
    model.zero_grad()
    loss.backward()

    last_k = int(k)
    grad_norm_dict = {}
    for name, param in model.named_parameters():
        if not (("fc_in" in name or "fc_out" in name) and "weight" in name):
            continue

        U, _, Vh = grad_svd_dict[name]
        grad = param.grad.to(torch.float)
        rows_per_group = int(grad.shape[0] // num_clusters)
        grouped_rows = torch.split(grad, rows_per_group)
        if not (len(U) == len(Vh) == len(grouped_rows)):
            raise ValueError(
                f"{name}: expected {len(grouped_rows)} SVD factor groups, "
                f"got {len(U)} U and {len(Vh)} Vh entries"
            )

        # Accumulate over all row groups; assigning inside the loop would
        # keep only the contribution of the last group.
        score = 0.0
        for u, vh, group in zip(U, Vh, grouped_rows):
            # Follow the gradient's device instead of assuming CUDA.
            u = u.to(grad.device)
            vh = vh.to(grad.device)
            # diag(u.T @ group @ vh.T) without the full square product.
            diag_grad_d = torch.sum(u * (group @ vh.T), dim=0)
            # ``x[-0:]`` would select everything, so treat k <= 0 explicitly.
            last_k_values = diag_grad_d[-last_k:] if last_k > 0 else diag_grad_d[:0]
            negative_idxs = last_k_values < 0
            score += -torch.sum(last_k_values[negative_idxs]).item()
        grad_norm_dict[name] = score

    return grad_norm_dict

################################
# Use CUDA when torch reports it as available, otherwise the CPU, and move
# the model onto that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# print(len(sampled_data))
model = model.to(device)


# dataset_util = FEVER()
# dataset = dataset_util.get_dataset_no_logger()
# sampled_data = dataset
# print("Number of sampled_data: ", len(sampled_data))
# grad_norms = {}
# num_clusters = 2
# grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
# num_iters = 100 # len(sampled_data)
# sampled_data = random.sample(sampled_data, num_iters)
# # Run inference
# for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
#     # print(text)
#     question = sampled_data[i]["question"]
#     answer_ix = sampled_data[i]["answer"]
#     assert answer_ix in [0, 1], "Answer must be 0 (False) or 1 (True)"
    
#     # Create the prompt dynamically
#     if question.strip().endswith(".") or question.strip().endswith("?"):
#         prompted_question = "Consider the following claim: " + question.strip() + " Is this claim true or false. The claim is"
#     else:
#         prompted_question = "Consider the following claim: " + question.strip() + ". Is this claim true or false. The claim is"
    
#     # Tokenize input
#     inputs = tokenizer(prompted_question, return_tensors="pt").to(device)
#     # inputs["input_ids"] = inputs["input_ids"].to(torch.long)

#     with torch.no_grad():
#         outputs = model(**inputs)
    
#     if i == 0:
#         # remove_hooks()
#         # grad_norms = compute_gradients(model, inputs)
#         grad_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
#     else:
#         # iter_norms = compute_gradients(model, inputs)
#         iter_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
#         for key in iter_norms:
#             grad_norms[key] += iter_norms[key]
#     # compute_gradients_with_svd(model, inputs)



# # #  
# # remove_hooks()

# # # Print per-layer information
# # print(f"Number of layers recorded: {len(per_layer_info)}")
# # for idx, info in enumerate(per_layer_info):
# #     if idx % 2 == 0:
# #         print(f"Layer {idx / 2} matrix fc_in: {info}")
# #     else:
# #         print(f"Layer {(idx - 1) / 2} matrix fc_out: {info}")

# # print("Grad Norms List: ", grad_norms_list)

# for key in grad_norms:
#     grad_norms[key] /= num_iters

# print("Grad norms for Fever Dataset: ", grad_norms)

# # Sorting by values in descending order
# top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
# print("Top 10: ", top_10_weights)



# dataset_util = BiasBiosGender()
# dataset = dataset_util.get_dataset_no_logger()
# num_iters = 100
# sampled_data = dataset

# print("Number of sampled_data: ", len(sampled_data))
# per_layer_info = []
# grad_norms = {}
# num_clusters = 2
# grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
# sampled_data = random.sample(sampled_data, num_iters)
# # Run inference
# for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
#     # print(text)
    
#     question = sampled_data[i]["hard_text"]
#     answer_ix = sampled_data[i]["answer"]
#     assert answer_ix in [0, 1], "Answer must be 0 (False) or 1 (True)"
    
#     # Create the prompt dynamically
#     if question.strip().endswith(".") or question.strip().endswith("?"):
#         prompted_question = "Consider the following text: " + question.strip() + " Is the person in this text a male or female? The gender of this person is"
#     else:
#         prompted_question = "Consider the following text: " + question.strip() + ". Is the person in this text a male or female? The gender of this person is"
    
#     # Tokenize input
#     inputs = tokenizer(prompted_question, return_tensors="pt").to(device)
#     # inputs["input_ids"] = inputs["input_ids"].to(torch.long)
#     # print(len(inputs))
#     # print(len(prompted_question))
#     # if i == 40:
#     #     print("Index 40!!!")
#     #     print("-----------")
#     #     continue
#     # print("-----------")
#     # with torch.no_grad():
#     #     outputs = model(**inputs)
    
#     #   
#     if i == 0:
#         # remove_hooks()
#         # grad_norms = compute_gradients_with_svd_last_k(model, inputs)
#         grad_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters, k=20)
#     else:
#         # iter_norms = compute_gradients_with_svd_last_k(model, inputs)
#         iter_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters, k=20)
#         for key in iter_norms:
#             grad_norms[key] += iter_norms[key]
#     # compute_gradients_with_svd(model, inputs)
#     del inputs
#     # del outputs
#     torch.cuda.empty_cache()

# # print("Grad Norms List: ", grad_norms_list)

# for key in grad_norms:
#     grad_norms[key] /= num_iters

# print("Grad norms for Bios Gender Dataset: ", grad_norms)

# # Sorting by values in descending order
# top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
# print("Top 10: ", top_10_weights)

# dataset_util = BiasBiosOccupation()
# sampled_data = dataset_util.get_dataset_no_logger()
# # sampled_data = dataset

# print("Number of sampled_data: ", len(sampled_data))
# per_layer_info = []
# grad_norms = {}
# num_iters = 100 # len(sampled_data)
# num_clusters = 16
# grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
# sampled_data = random.sample(sampled_data, num_iters)
# # Run inference
# for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
#     # print(text)
#     question = sampled_data[i]["hard_text"]
#     answer_ix = sampled_data[i]["answer"]
#     # assert answer_ix in [0, 1], "Answer must be 0 (False) or 1 (True)"
    
#     # Create the prompt dynamically
#     if question.strip().endswith(".") or question.strip().endswith("?"):
#         prompted_question = "Consider the following text: " + question.strip()
#     else:
#         prompted_question = "Consider the following text: " + question.strip() + "."
#     prompted_question += " What is the profession of the person in this text? The profession of this person is"  
#     # Tokenize input
#     with torch.no_grad():
#         inputs = tokenizer(prompted_question, return_tensors="pt").to(device)
#     # inputs["input_ids"] = inputs["input_ids"].to(torch.long)

#     # with torch.no_grad():
#     #     outputs = model(**inputs)
    
#     #   
#     if i == 0:
#         # remove_hooks()
#         # grad_norms = compute_gradients_with_svd_last_k(model, inputs)
#         grad_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
#     else:
#         # iter_norms = compute_gradients_with_svd_last_k(model, inputs)
#         iter_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
#         for key in iter_norms:
#             grad_norms[key] += iter_norms[key]
#     # compute_gradients_with_svd(model, inputs)
#     del inputs
#     # del outputs
#     torch.cuda.empty_cache()

# # print("Grad Norms List: ", grad_norms_list)

# for key in grad_norms:
#     grad_norms[key] /= num_iters

# print("Grad norms for Bios Profession Dataset: ", grad_norms)

# top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
# print("Top 10: ", top_10_weights)

# ---- TruthfulQA: average the clustered SVD gradient score over random samples ----
dataset = get_truthfulqa_pointwise_data_no_logger()
sampled_data = dataset

print("Number of sampled_data: ", len(sampled_data))
per_layer_info = []  # filled by forward_hook_fn when hooks are registered
grad_norms = {}
# random.sample raises when asked for more items than exist, so cap the count.
num_iters = min(100, len(sampled_data))
num_clusters = 16
grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
sampled_data = random.sample(sampled_data, num_iters)

# Iterate over the random sample itself (indexing ``dataset`` here would
# silently ignore the sampling and always use the first entries).
for sample in tqdm(sampled_data, desc="Processing Samples", unit="sample"):
    prompt = sample[0]

    # Tokenize input
    input_and_answer = tokenizer(prompt, return_tensors="pt").to(device)

    # No separate no-grad forward pass is needed: the gradient routine runs
    # its own forward pass and the extra outputs were never used.
    iter_norms = compute_gradients_with_svd_clustering_last_k(
        model, input_and_answer, grad_svd_dict, num_clusters=num_clusters
    )
    for key, value in iter_norms.items():
        grad_norms[key] = grad_norms.get(key, 0.0) + value

    del input_and_answer
    torch.cuda.empty_cache()

# Turn the accumulated sums into per-sample averages (no-op when empty).
for key in grad_norms:
    grad_norms[key] /= num_iters

print("Grad norms for TruthfulQA Dataset: ", grad_norms)

# Sorting by values in descending order
top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
print("Top 10: ", top_10_weights)

def get_choice_tokens(choices, tokenizer):
    """Map each answer choice to the id of its single token.

    Every choice is tokenized with one leading space prepended.

    Args:
        choices: Iterable of strings that must not start or end with a space
            (checked with ``assert``).
        tokenizer: Callable returning a mapping with an ``"input_ids"``
            sequence for a string.

    Returns:
        List of integer token ids, one per choice, or ``None`` as soon as a
        choice tokenizes to anything other than exactly one token.
    """
    single_token_ids = []
    for choice in choices:
        assert not choice.startswith(" "), f"Expecting choice token {choice} to not start with space"
        assert not choice.endswith(" "), f"Expecting choice token {choice} to not end with space"

        encoded_ids = tokenizer(f" {choice}")["input_ids"]
        if len(encoded_ids) != 1:
            # This is a multi-token target and so must be evaluated differently
            return None
        single_token_ids.append(int(encoded_ids[0]))

    return single_token_ids

# dataset, choices = get_bb_dataset("epistemic_reasoning")
# sampled_data = dataset

# choice_token_ids = get_choice_tokens(choices, tokenizer)
# single_token_choices = False
# if choice_token_ids is None:
#     single_token_choices = False
# else:
#     single_token_choices = True

# print("Number of sampled_data: ", len(sampled_data))
# per_layer_info = []
# grad_norms = {}
# num_iters = 100 # len(sampled_data)
# num_clusters = 16
# grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
# sampled_data = random.sample(sampled_data, num_iters)
# # Run inference
# for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
#     # print(text)
#     prompt = dataset[i][0]
#     label = dataset[i][1]
 
#     # Tokenize input
#     with torch.no_grad():

#         if single_token_choices:
#             input_and_answer = tokenizer(prompt, return_tensors="pt").to(device)
#         else:
#             input_and_answer = tokenizer(prompt + " " + choices[0], return_tensors="pt").to(device)
#     # inputs["input_ids"] = inputs["input_ids"].to(torch.long)

#     with torch.no_grad():
#         outputs = model(input_and_answer.input_ids)
    
#     #   
#     if i == 0:
#         # remove_hooks()
#         # grad_norms = compute_gradients_with_svd_last_k(model, input_and_answer, grad_svd_dict)
#         grad_norms = compute_gradients_with_svd_clustering_last_k(model, input_and_answer, grad_svd_dict, num_clusters=num_clusters)
#     else:
#         # iter_norms = compute_gradients_with_svd_last_k(model, input_and_answer, grad_svd_dict)
#         iter_norms = compute_gradients_with_svd_clustering_last_k(model, input_and_answer, grad_svd_dict, num_clusters=num_clusters)
#         for key in iter_norms:
#             grad_norms[key] += iter_norms[key]
#     # compute_gradients_with_svd(model, input_and_answer)
#     del input_and_answer
#     torch.cuda.empty_cache()

# # print("Grad Norms List: ", grad_norms_list)

# for key in grad_norms:
#     grad_norms[key] /= num_iters

# print("Grad norms for BBH ER Dataset: ", grad_norms)

# # Sorting by values in descending order
# top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
# print("Top 10: ", top_10_weights)

# dataset, choices = get_bb_dataset("qa_wikidata")
# sampled_data = dataset

# # choice_token_ids = get_choice_tokens(choices, tokenizer)
# # single_token_choices = False
# # if choice_token_ids is None:
# #     single_token_choices = False
# # else:
# #     single_token_choices = True

# print("Number of sampled_data: ", len(sampled_data))
# per_layer_info = []
# grad_norms = {}
# num_iters = len(sampled_data)
# num_clusters = 16
# grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
# sampled_data = random.sample(sampled_data, num_iters)
# # Run inference
# for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
#     # print(text)
#     prompt = dataset[i][0].strip()
#     answer = dataset[i][1].strip()
 
#     # Tokenize input
#     with torch.no_grad():
#         input_and_answer = tokenizer(prompt + " " + answer, return_tensors="pt").to(device)

#     # with torch.no_grad():
#     #     outputs = model(input_and_answer.input_ids)
    
#     #   
#     if i == 0:
#         # remove_hooks()
#         # grad_norms = compute_gradients_with_svd_last_k(model, input_and_answer, grad_svd_dict)
#         grad_norms = compute_gradients_with_svd_clustering_last_k(model, input_and_answer, grad_svd_dict, num_clusters=num_clusters)
#     else:
#         # iter_norms = compute_gradients_with_svd_last_k(model, input_and_answer, grad_svd_dict)
#         iter_norms = compute_gradients_with_svd_clustering_last_k(model, input_and_answer, grad_svd_dict, num_clusters=num_clusters)
#         for key in iter_norms:
#             grad_norms[key] += iter_norms[key]
#     # compute_gradients_with_svd(model, input_and_answer)
#     del input_and_answer
#     torch.cuda.empty_cache()

# # print("Grad Norms List: ", grad_norms_list)

# for key in grad_norms:
#     grad_norms[key] /= num_iters

# print("Grad norms for BBH WikidataQA Dataset: ", grad_norms)

# # Sorting by values in descending order
# top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
# print("Top 10: ", top_10_weights)


# with open("counterfact", "rb") as f:
#     data = pickle.load(f)

# num_dp = len(data)
# dataset = []

# for i in range(num_dp):
#     question = data[i]["question"]
#     answer = data[i]["gold-answer"]
#     assert answer.startswith(" "), f"Found answer that doesn't start with space ${answer}$"
#     dataset.append((question, answer))

# validation_index = int(len(dataset) * 0.2)
# dataset = dataset[:validation_index]
# sampled_data = dataset

# print("Number of sampled_data: ", len(sampled_data))
# per_layer_info = []
# grad_norms = {}
# num_iters = 100 # len(sampled_data)
# num_clusters = 16
# grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
# sampled_data = random.sample(sampled_data, num_iters)
# # Run inference
# for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
#     # print(text)
#     question = dataset[i][0].strip()
#     answer = dataset[i][1].strip()
 
#     # Tokenize input
#     with torch.no_grad():
#         inputs = tokenizer(question, return_tensors="pt").to(device)

#     # with torch.no_grad():
#     #     outputs = model(inputs.input_ids)
    
#     #   
#     if i == 0:
#         # remove_hooks()
#         # grad_norms = compute_gradients_with_svd_last_k(model, inputs, grad_svd_dict)
#         grad_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
#     else:
#         # iter_norms = compute_gradients_with_svd_last_k(model, inputs, grad_svd_dict)
#         iter_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
#         for key in iter_norms:
#             grad_norms[key] += iter_norms[key]
#     # compute_gradients_with_svd(model, inputs)
#     del inputs
#     torch.cuda.empty_cache()

# # print("Grad Norms List: ", grad_norms_list)

# for key in grad_norms:
#     grad_norms[key] /= num_iters

# print("Grad norms for Counterfact Dataset: ", grad_norms)

# # Sorting by values in descending order
# top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
# print("Top 10: ", top_10_weights)

# dataset_util = Hotpot(llama_tokenizer_path="/data/cl/scratch/llama_weights")
# sampled_data = dataset_util.get_dataset_no_logger()
# # sampled_data = dataset

# print("Number of sampled_data: ", len(sampled_data))
# per_layer_info = []
# grad_norms = {}
# num_iters = 100 # len(sampled_data)
# num_clusters = 16
# grad_svd_dict = get_ksvd_of_weights(model, num_clusters)
# sampled_data = random.sample(sampled_data, num_iters)
# # Run inference
# for i in tqdm(range(num_iters), desc="Processing Samples", unit="sample"):
#     # print(text)
#     question = sampled_data[i]["question"]

#     if not question.endswith("?") and not question.endswith("."):
#         prompted_question = f"{question}? The answer is"
#     else:
#         prompted_question = f"{question} The answer is"

#     answer = sampled_data[i]["answer"]
#     inputs = tokenizer(prompted_question, return_tensors="pt").to(device)

#     # inputs["input_ids"] = inputs["input_ids"].to(torch.long)

#     # with torch.no_grad():
#     #     outputs = model(**inputs)
    
#     #   
#     if i == 0:
#         # remove_hooks()
#         # grad_norms = compute_gradients_with_svd_last_k(model, inputs)
#         grad_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
#     else:
#         # iter_norms = compute_gradients_with_svd_last_k(model, inputs)
#         iter_norms = compute_gradients_with_svd_clustering_last_k(model, inputs, grad_svd_dict, num_clusters=num_clusters)
#         for key in iter_norms:
#             grad_norms[key] += iter_norms[key]
#     # compute_gradients_with_svd(model, inputs)
#     del inputs
#     # del outputs
#     torch.cuda.empty_cache()

# # print("Grad Norms List: ", grad_norms_list)

# for key in grad_norms:
#     grad_norms[key] /= num_iters

# print("Grad norms for Hotpot Dataset: ", grad_norms)

# top_10_weights = sorted(grad_norms.items(), key=lambda x: x[1], reverse=True)[:10]
# print("Top 10: ", top_10_weights)