import torch
import os.path as osp
import torch.nn.functional as F
import sys

def compute_svd_base(features: torch.Tensor, energy: float = 0.95) -> torch.Tensor:
    """Build a projector onto the dominant right-singular subspace of *features*.

    A randomized SVD of ``features`` is computed, and the smallest rank ``r``
    whose leading squared singular values account for at least ``energy`` of
    the total squared spectrum is selected.

    Args:
        features: 2-D matrix; ``min(features.shape)`` singular triplets are
            requested from ``torch.svd_lowrank``.
        energy: Fraction of the total squared singular values that the kept
            components must reach.

    Returns:
        The matrix ``V_k @ V_k.T`` where ``V_k`` holds the first ``r``
        right singular vectors as columns.

    Note:
        The global deterministic-algorithms setting is switched off only
        while the rank is computed and is restored to whatever the caller
        had configured (including ``warn_only``) afterwards.
    """
    # torch.svd_lowrank returns (U, S, V) with A ~= U diag(S) V^T, i.e. the
    # third output is V (columns are right singular vectors), not V^H.
    _, S, V = torch.svd_lowrank(features, q=min(features.shape), niter=2)
    sval_sq = S ** 2
    sval_ratio = sval_sq / sval_sq.sum()

    # Deterministic mode is disabled around cumsum (presumably because the
    # CUDA cumsum kernel has no deterministic implementation). Remember the
    # caller's configuration so it is restored instead of being forced on.
    was_deterministic = torch.are_deterministic_algorithms_enabled()
    warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
    torch.use_deterministic_algorithms(False)
    try:
        r = int(torch.sum(torch.cumsum(sval_ratio, dim=0) < energy)) + 1
    finally:
        torch.use_deterministic_algorithms(was_deterministic, warn_only=warn_only)

    print(f"Selected r {r} for energy threshold {energy}")
    # Slicing clamps r to the number of available singular vectors.
    V_k = V[:, 0:r]
    mu_S = torch.matmul(V_k, V_k.T)
    return mu_S

def compute_task_id(features, singular_list, threshold=0.99, l2_dist=False):
    """Assign each feature row to the basis whose projection resembles it most.

    Every row of ``features`` is multiplied by each matrix in
    ``singular_list``; the L2-normalized row is then compared with its
    L2-normalized projection.

    Args:
        features: 2-D tensor of shape ``[bs, dim]``.
        singular_list: Sequence of matrices that ``features`` can be
            right-multiplied by (one per candidate task).
        threshold: Rows whose best similarity is below this value get
            task id ``-1``.
        l2_dist: If true, similarity is ``exp(-||f - proj||_2)``; otherwise
            it is the cosine similarity (computed in float32).

    Returns:
        ``(task_ids, similarities)`` where ``task_ids`` has shape ``[bs]``
        (index of the best basis, or ``-1``) and ``similarities`` has shape
        ``[bs, len(singular_list)]`` in the dtype of ``features``.
    """
    bs, _ = features.shape
    n = len(singular_list)
    similarities = torch.zeros(bs, n, device=features.device, dtype=features.dtype)

    # With no candidate bases there is nothing to take a max over; report
    # every row as unassigned instead of failing inside torch.max.
    if n == 0:
        task_ids = torch.full((bs,), -1, dtype=torch.long, device=features.device)
        return task_ids, similarities

    # Normalizing the input does not depend on the basis, so do it once.
    unit_features = F.normalize(features, p=2, dim=1)

    for k, singular_k in enumerate(singular_list):
        singular_k = singular_k.to(device=features.device, dtype=features.dtype)
        unit_proj = F.normalize(features @ singular_k, p=2, dim=1)
        if l2_dist:
            # Keep the distance in its own local: rebinding the `l2_dist`
            # flag here would corrupt the branch test on later iterations.
            distance = torch.norm(unit_features - unit_proj, dim=1)
            similarities[:, k] = torch.exp(-distance)
        else:
            # Cosine similarity per row, shape [bs].
            similarities[:, k] = F.cosine_similarity(
                unit_features.to(dtype=torch.float),
                unit_proj.to(dtype=torch.float),
                dim=1,
            )

    max_similarities, task_ids = torch.max(similarities, dim=1)
    # Rows that match no basis well enough are marked as unknown (-1).
    task_ids = task_ids.masked_fill(max_similarities < threshold, -1)
    return task_ids, similarities