import os
import torch
import math
import json
import torch.nn.functional as F
from safetensors.torch import load_file as safe_load_file, save_file as safe_save_file
from collections import defaultdict

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Root directory; SVD inputs are read from <base_path>/calsvd and the LoRA
# initialisation is written under <base_path>/initlora.
base_path = ""
svd_base_path = os.path.join(base_path, "calsvd")
save_base_path = os.path.join(base_path, "initlora")

# Tasks whose SVD factors are merged into one LoRA initialisation.
dataset_list = ["chartqa", "docvqa"]

# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
total_rank = 64          # overall LoRA rank per layer (task ranks + shared rank)
temperature = 0.2        # softmax temperature applied to spectral entropies
init_lora_scale = 5      # extra divisor applied to the initial B @ A product
normalize_entropy = True # True: softmax(entropy / T); False: entropy / sum(entropy)
min_rank = 2             # lower bound on every task's allocated rank

# The budget is cut into one share per task plus one extra "copart" share;
# the copart share is reserved for the shared block and the rest is split
# between tasks according to their entropies.
num_shares = len(dataset_list) + 1
copart_rank = total_rank // num_shares
task_rank_total = total_rank - copart_rank





def spectral_entropy(sigma):
    """Return the Shannon entropy (in nats) of a normalised spectrum.

    ``sigma`` is cast to float32 and turned into a probability vector
    ``p = sigma / sigma.sum()``. The result is ``-sum(p * log(p + 1e-12))``,
    returned as a Python float. A spectrum that sums to zero yields ``0.0``.
    """
    spectrum = sigma.to(dtype=torch.float32)
    total = spectrum.sum()
    if total == 0:
        return 0.0
    probs = spectrum / total
    # The epsilon keeps log() finite for zero entries (their term is then 0).
    return -(probs * torch.log(probs + 1e-12)).sum().item()

def discretize_rank_allocation(weights: torch.Tensor, total_rank: int, min_rank: int) -> list[int]:
    """Split an integer rank budget across tasks in proportion to ``weights``.

    Every task first receives ``min_rank``. The remaining budget is shared
    proportionally to ``weights`` using the largest-remainder method:
    - take the floor of each proportional share;
    - give the leftover units one by one to the tasks with the largest
      fractional parts (ties go to the lower index).
    The returned ranks therefore always sum to ``total_rank``.

    Args:
        weights: 1-D tensor with one non-negative weight per task. It does not
            need to be normalised. Negative entries are treated as 0. If the
            weights are all zero or non-finite, the budget is split evenly.
        total_rank: total rank budget to distribute.
        min_rank: minimum rank given to every task.

    Returns:
        A list of ``len(weights)`` ints, each ``>= min_rank``, summing to
        ``total_rank``.

    Raises:
        ValueError: if ``total_rank < len(weights) * min_rank``, or if a
            positive budget would have to be spread over zero tasks.
    """
    n = len(weights)
    if total_rank < n * min_rank:
        raise ValueError(
            f"total_rank={total_rank} cannot give min_rank={min_rank} to {n} tasks"
        )
    base = [min_rank] * n
    remaining = total_rank - n * min_rank
    if remaining == 0:
        return base
    if n == 0:
        raise ValueError(f"cannot distribute {remaining} rank units over zero tasks")

    # float64 keeps rounding noise in the proportional shares negligible.
    w = weights.detach().to(torch.float64).flatten().clamp_min(0)
    w_sum = w.sum()
    if not bool(torch.isfinite(w_sum)) or w_sum.item() <= 0:
        # Degenerate weights would make every share NaN; fall back to an even split.
        w = torch.ones_like(w)
        w_sum = w.sum()

    soft = w / w_sum * remaining
    floors = torch.floor(soft)
    extra = floors.long().tolist()
    remainders = (soft - floors).tolist()

    # sum(floors) <= remaining and the fractional parts sum to < n, so
    # 0 <= leftover <= n and the slice below hands out every leftover unit.
    leftover = remaining - sum(extra)
    order = sorted(range(n), key=lambda i: remainders[i], reverse=True)
    for i in order[:leftover]:
        extra[i] += 1
    return [b + e for b, e in zip(base, extra)]

# Spectral entropy of every task's singular values, per layer:
#   layer_entropy[layer_name][dataset] -> float
# Layer names are the ".svd_S" keys of each SVD file with that suffix removed.
layer_entropy = defaultdict(dict)
for dataset in dataset_list:
    svd_path = os.path.join(svd_base_path, f"{dataset}_svd_rank{total_rank}.safetensors")
    if not os.path.exists(svd_path):
        # Best effort: a missing dataset is skipped, but make that visible,
        # since the output directory name still lists every dataset.
        print(f"[warn] SVD file not found, skipping dataset '{dataset}': {svd_path}")
        continue
    svd_dict = safe_load_file(svd_path)
    for key, value in svd_dict.items():
        if key.endswith(".svd_S"):
            # removesuffix strips only the trailing marker; str.replace would
            # also strip any ".svd_S" occurring inside the layer name.
            layer_name = key.removesuffix(".svd_S")
            layer_entropy[layer_name][dataset] = spectral_entropy(value)

# Turn each layer's per-task entropies into integer ranks:
#   rank_allocation[dataset][layer_name] -> int
# Tasks with higher spectral entropy receive a larger share of task_rank_total.
rank_allocation = defaultdict(dict)
for layer_name, per_task in layer_entropy.items():
    task_names = list(per_task)
    ent = torch.tensor(list(per_task.values()), dtype=torch.float32)
    if normalize_entropy:
        # Temperature-scaled softmax: a lower temperature sharpens the split.
        share = F.softmax(ent / temperature, dim=0)
    else:
        share = ent / ent.sum()
    allocated = discretize_rank_allocation(share, total_rank=task_rank_total, min_rank=min_rank)
    for task_name, task_rank in zip(task_names, allocated):
        rank_allocation[task_name][layer_name] = task_rank

# Per layer, the sqrt(S)-scaled SVD factors of every task, appended in
# dataset_list order:
#   "U": U[:, :r] * sqrt(S[:r])   (one column per allocated rank)
#   "V": Vh[:r, :].T * sqrt(S[:r])
# so that U_scaled @ V_scaled.T == U[:, :r] @ diag(S[:r]) @ Vh[:r, :].
layer_init = defaultdict(lambda: {"U": [], "V": []})

for dataset in dataset_list:
    svd_path = os.path.join(svd_base_path, f"{dataset}_svd_rank{total_rank}.safetensors")
    if not os.path.exists(svd_path):
        continue

    # .get avoids inserting an empty entry into the defaultdict, which would
    # otherwise end up in rank_allocation.json.
    task_ranks = rank_allocation.get(dataset, {})
    svd_dict = safe_load_file(svd_path)
    for key, S_full in svd_dict.items():
        if not key.endswith(".svd_S"):
            continue

        layer = key.removesuffix(".svd_S")
        if layer not in task_ranks:
            continue

        rk = task_ranks[layer]
        # Slicing would silently return fewer than rk components, leaving the
        # saved factors inconsistent with the recorded rank allocation.
        if S_full.shape[0] < rk:
            raise ValueError(
                f"{svd_path}: layer '{layer}' has only {S_full.shape[0]} singular values "
                f"but {rk} were allocated; recompute the SVD with a larger rank"
            )
        S = S_full[:rk].to(torch.float32)
        U = svd_dict[f"{layer}.svd_U"][:, :rk].to(torch.float32)
        Vh = svd_dict[f"{layer}.svd_Vh"][:rk, :].to(torch.float32)

        # Split sqrt(S) evenly between the two factors.
        sqrt_S = S.sqrt()
        layer_init[layer]["U"].append(U * sqrt_S.view(1, -1))
        layer_init[layer]["V"].append(Vh.T * sqrt_S.view(1, -1))

# Assemble the final LoRA factors per layer:
# - first the task-specific ranks, in dataset_list order;
# - then `copart_rank` shared rows of A / columns of B.
lora_A, lora_B = {}, {}
for layer_name, factors in layer_init.items():
    # "+ 1" counts the shared block. Both A and B are multiplied by
    # 1/sqrt(div_scale), so their product B @ A is scaled by 1/div_scale.
    div_scale = (len(factors["V"]) + 1) * init_lora_scale
    scale = 1 / (div_scale ** 0.5)

    in_dim = factors["V"][0].size(0)
    out_dim = factors["U"][0].size(0)

    # Shared A rows get a random Kaiming-uniform init (a = sqrt(5)).
    A_shared = torch.empty(copart_rank, in_dim, dtype=torch.float32)
    torch.nn.init.kaiming_uniform_(A_shared, a=math.sqrt(5))
    # Shared B columns start at zero, so the shared block adds nothing to
    # B @ A at initialisation.
    B_shared = torch.zeros(out_dim, copart_rank, dtype=torch.float32)

    A_full = torch.cat([*(v.T for v in factors["V"]), A_shared], dim=0) * scale
    B_full = torch.cat([*factors["U"], B_shared], dim=1) * scale

    lora_A[f"{layer_name}.lora_A"] = A_full.to(torch.bfloat16)
    lora_B[f"{layer_name}.lora_B"] = B_full.to(torch.bfloat16)

# Write the A/B initialisations and the per-task rank allocation into a
# directory whose name encodes the run configuration.
dataset_str = "_".join(dataset_list)
run_tag = f"llava1.5_7b-{dataset_str}-r{total_rank}wCoPartRank{copart_rank}-T{temperature}-S{init_lora_scale}"
save_dir = os.path.join(save_base_path, run_tag)
os.makedirs(save_dir, exist_ok=True)

for tensors, file_name in (
    (lora_A, "lora_A_init.safetensors"),
    (lora_B, "lora_B_init.safetensors"),
):
    safe_save_file(tensors, os.path.join(save_dir, file_name))

with open(os.path.join(save_dir, "rank_allocation.json"), "w") as fh:
    json.dump(rank_allocation, fh, indent=2)
