import os
import sys
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from safetensors.torch import load_file as safe_load_file
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
# Task selector: every entry of the known task list (defined further down)
# whose name occurs as a substring of this string is processed.
task_string = "chartqa"

# Rank of the truncated decomposition: the number of leading singular
# triplets (columns of U, entries of S, rows of Vh) kept per layer.
r = 64

# Substrings identifying the linear-layer parameters to decompose; a
# parameter is selected when its name contains any of these.
target_keywords = [
    # attention projections
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "self_attn.o_proj",
] + [
    # feed-forward projections
    "mlp.gate_proj",
    "mlp.up_proj",
    "mlp.down_proj",
]

model_path = ""
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path)
)


device = model.device








# Tasks for which covariance files may exist. A task is selected when its
# name occurs as a substring of task_string.
full_task_list = ["chartqa", "docvqa", "iconqa_txt", "scienceqa"]
selected_tasks = [t for t in full_task_list if t in task_string]
if not selected_tasks:
    # Without this, a typo in task_string makes the script silently do nothing.
    print(
        f"[warn] task_string={task_string!r} matches none of {full_task_list}; "
        "nothing will be computed.",
        file=sys.stderr,
    )

# Per-task covariance file; "{}" is replaced by the task name.
cov_path_template = "calcov/llava1_5_cov_{}_500_size500_seed42.safetensors"

# task -> {parameter name: covariance tensor (loaded on CPU)}
cov_dicts = {}
for task in selected_tasks:
    path = cov_path_template.format(task)
    # Explicit check rather than `assert`, so it still fires under `python -O`.
    if not os.path.exists(path):
        raise FileNotFoundError(f"Covariance file not found: {path}")
    raw_cov_dict = safe_load_file(path, device="cpu")
    # Normalize keys to end in ".weight" so they can be looked up directly
    # with the names yielded by model.named_parameters().
    cov_dicts[task] = {
        (k if k.endswith(".weight") else k + ".weight"): v
        for k, v in raw_cov_dict.items()
    }
def _damped_inverse(cov, damp=0.01, tol=0.05, max_doublings=30):
    """Regularize ``cov`` with a diagonal ridge and invert it.

    The ridge is ``damp * mean(diag(cov))`` added to the diagonal. ``damp`` is
    doubled until ``||fix_cov @ inv(fix_cov) - I||_F < tol``.

    Returns:
        ``(fix_cov, cov_inv)`` on success, or ``None`` if the check never
        passes within ``max_doublings`` doublings (e.g. NaN/Inf in ``cov``).
        The previous ``while True`` loop spun forever in that case.

    ``torch.linalg.inv`` errors (e.g. an exactly singular matrix) propagate.
    """
    n = cov.size(0)
    # Loop-invariant; built on cov's own device so the comparison cannot
    # hit a device mismatch (model.device may differ under device_map/offload).
    eye = torch.eye(n, device=cov.device, dtype=cov.dtype)
    diag_mean = torch.mean(torch.diag(cov))
    for _ in range(max_doublings + 1):
        fix_cov = cov + torch.diag(torch.ones(n, device=cov.device, dtype=cov.dtype) * diag_mean * damp)
        cov_inv = torch.linalg.inv(fix_cov)
        err = torch.dist(fix_cov @ cov_inv, eye)
        if err.item() < tol:
            return fix_cov, cov_inv
        damp *= 2
    return None


# (task, parameter name) -> {"U", "S", "Vh"}: rank-r factors, bf16, on CPU.
task_to_svd = {}
for task in selected_tasks:
    cov_dict = cov_dicts[task]

    for name, param in model.named_parameters():
        # Only trainable language-model linear weights matching target_keywords,
        # and only those for which covariance statistics exist.
        if not (
            param.requires_grad
            and any(k in name for k in target_keywords)
            and "vision_tower" not in name
        ):
            continue
        if name not in cov_dict:
            continue

        pretrained_w = param.data.float()
        covariance_matrix = cov_dict[name].float().to(pretrained_w.device)

        damped = _damped_inverse(covariance_matrix)
        if damped is None:
            print(
                f"[warn] {task}/{name}: damped covariance inverse did not "
                "converge (non-finite covariance?); skipping layer",
                file=sys.stderr,
            )
            continue
        fix_cov, cov_inv = damped

        # Covariance-weighted SVD: W @ C = U S Vh, hence W = U S (Vh @ C^-1).
        # The matmul requires C to be (in_features, in_features) for W.
        w = pretrained_w @ fix_cov

        try:
            U, S, Vh = torch.linalg.svd(w, full_matrices=False)
            V_hat_h = Vh @ cov_inv
        except Exception as e:
            # Best-effort: skip this layer, but make the failure visible.
            print(f"[warn] {task}/{name}: SVD failed ({e!r}); skipping layer", file=sys.stderr)
            continue

        # Keep the leading r singular triplets; store compactly on CPU.
        U_r = U[:, :r].to(torch.bfloat16).contiguous().cpu()
        S_r = S[:r].to(torch.bfloat16).contiguous().cpu()
        Vh_r = V_hat_h[:r, :].to(torch.bfloat16).contiguous().cpu()

        task_to_svd[(task, name)] = {
            "U": U_r,
            "S": S_r,
            "Vh": Vh_r,
        }




from safetensors.torch import save_file as safe_save_file

# Output directory for the per-task factor files.
# NOTE(review): this is an absolute path at the filesystem root, while inputs
# are read from the relative "calcov/" directory. Confirm this is intended.
save_dir = "/calsvd"
os.makedirs(save_dir, exist_ok=True)

# Write one safetensors file per selected task. Each decomposed layer
# contributes three tensors keyed "base_model.model.<param>.svd_{U,S,Vh}";
# presumably the prefix matches a wrapped model's parameter names. Confirm
# against the consumer.
for task in selected_tasks:
    save_dict = {}
    for (owner_task, param_name), factors in task_to_svd.items():
        if task_name_matches := (owner_task == task):
            prefix = "base_model.model." + param_name
            save_dict.update({
                f"{prefix}.svd_U": factors["U"].cpu(),
                f"{prefix}.svd_S": factors["S"].cpu(),
                f"{prefix}.svd_Vh": factors["Vh"].cpu(),
            })
    out_file = os.path.join(save_dir, f"{task}_svd_rank{r}.safetensors")
    safe_save_file(save_dict, out_file)
