import os
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from safetensors.torch import save_file






@torch.no_grad()
def llava_calib_cov_distribution(model, trainer, epoch_iterator, use_cache=True, calib_dataset="iconqa_txt", calib_size=500, seed=42):
    """Accumulate input covariance matrices for selected linear layers and save them.

    A forward hook is attached to every ``nn.Linear`` whose qualified name
    contains one of the attention/MLP projection keywords and does not contain
    ``"vision_tower"``. For each forward pass the hook flattens the layer input
    to ``(tokens, features)``, normalises it, and adds ``X^T X / calib_size`` to
    ``module.covariance_matrix``. After iterating ``epoch_iterator`` the
    matrices are written to a single safetensors file keyed by module name.

    Args:
        model: Module to calibrate; it is switched to eval mode and called as
            ``model(**inputs)``.
        trainer: Object providing ``_prepare_inputs(inputs)``, applied to each
            batch before the forward pass.
        epoch_iterator: Iterable yielding calibration batches.
        use_cache: When true and the output file already exists, return
            immediately without running calibration.
        calib_dataset: Dataset label; used in the output filename and the
            progress-bar description.
        calib_size: Divisor applied to every per-batch covariance; also part of
            the output filename.
        seed: Only used as part of the output filename.

    Returns:
        None. The result is the saved file; ``covariance_matrix`` is also left
        set on each hooked module.

    Raises:
        Exception: If NaN or Inf appears in a normalised input or a covariance.
        RuntimeError: If a hooked module never received a forward call, so no
            covariance was accumulated for it.
    """
    save_path_prefix = "/calcov"
    save_path_safetensors = f"{save_path_prefix}/llava1_5_cov_{calib_dataset}_size{calib_size}_seed{seed}.safetensors"

    if use_cache and os.path.exists(save_path_safetensors):
        return

    # Create the output directory up front so an unwritable location fails
    # before the expensive calibration pass rather than after it.
    os.makedirs(save_path_prefix, exist_ok=True)

    model.eval()

    def check_finite(tensor, what, suffix=""):
        """Print a notice and raise if ``tensor`` contains NaN or Inf."""
        if torch.isnan(tensor).any():
            print("nan detected")
            raise Exception(f"nan in {what}{suffix}")
        if torch.isinf(tensor).any():
            print("inf detected")
            raise Exception(f"inf in {what}{suffix}")

    def hook(module, input, output, calib_size=calib_size):
        acts = input[0].detach().float()
        # Flatten every leading dimension into one token axis. This matches the
        # old squeeze(0) for a batch of one, and additionally supports larger
        # batches and 2-D inputs (where squeeze(0) could collapse to 1-D).
        # NOTE(review): with batch size > 1 any padding positions are counted
        # as tokens too - confirm the calibration loader's batching.
        acts = acts.reshape(-1, acts.shape[-1])
        # NOTE(review): this divides by |max(x)|, not max(|x|). Kept as-is so
        # results stay comparable with previously cached files - confirm intent.
        acts = acts / torch.max(acts).abs()
        check_finite(acts, "input", ", break")

        covariance = acts.t().matmul(acts)
        check_finite(covariance, "covariance", ", break")
        module.covariance_matrix += covariance / calib_size

    target_keywords = [
        "self_attn.q_proj",
        "self_attn.k_proj",
        "self_attn.v_proj",
        "self_attn.o_proj",
        "mlp.gate_proj",
        "mlp.up_proj",
        "mlp.down_proj",
    ]

    # (name, module, hook handle) for every layer being calibrated.
    hooked = []
    for name, module in model.named_modules():
        if (
            isinstance(module, nn.Linear)
            and any(k in name for k in target_keywords)
            and "vision_tower" not in name
        ):
            print(f"Registering covariance_matrix on: {name}")
            module.covariance_matrix = 0
            hooked.append((name, module, module.register_forward_hook(hook)))

    try:
        for inputs in tqdm(epoch_iterator, desc=f"Calculating Cov Matrix of {calib_dataset}"):
            inputs = trainer._prepare_inputs(inputs)
            model(**inputs)
    finally:
        # Remove only the hooks registered here, even when a forward pass
        # fails; any other forward hooks on these modules are left alone.
        for _, _, handle in hooked:
            handle.remove()

    all_covariance_matrix = {}
    for name, module, _ in hooked:
        covariance = module.covariance_matrix
        if not torch.is_tensor(covariance):
            # Still the integer placeholder: the hook never fired for this layer.
            raise RuntimeError(
                f"no covariance accumulated for {name}; "
                "the module was never called during calibration"
            )
        check_finite(covariance, "covariance")
        all_covariance_matrix[name] = covariance

    save_file(all_covariance_matrix, save_path_safetensors)
    print("covariance matrices saved")
