import torch
import copy
import torch.nn.functional as F


def _cal_optimal_tv_1d(tvs):
    numerator, denominator = 0.0, 0.0
    for tv in tvs:
        weight = torch.sum(tv ** 2)
        numerator += weight * tv
        denominator += weight
    denominator = torch.where(denominator == 0, torch.tensor(1e-10), denominator)
    T_optimal = numerator / denominator
    return T_optimal

def _cal_optimal_tv_2d(tvs):
    sum_TT, sum_TTT = 0.0, 0.0
    for tv in tvs:
        sum_TT += tv.T @ tv # calculating T_k T_k^\top
        sum_TTT += tv.T @ tv @ tv.T # calculating T_k T_k^\top T_k
    T_optimal = torch.linalg.pinv(sum_TT) @ sum_TTT
    return T_optimal.T


def GMF_Mean_(task_vectors, pretrained_checkpoint, device='cuda'):
    """Merge several task vectors into one, parameter by parameter.

    Parameter names and shapes come from the model returned by
    ``task_vectors[0].apply_to(pretrained_checkpoint, scaling_coef=1)``.
    For each such parameter:

    * rank 1: energy-weighted average via ``_cal_optimal_tv_1d``;
    * rank 2: closed-form solution via ``_cal_optimal_tv_2d``, computed on
      ``device`` and moved back to the entry's original device;
    * any other rank: arithmetic mean over the task vectors.

    Parameters whose entry in the summed task vector is ``None`` are skipped.
    Entries of ``.vector`` that are not model parameters are never visited and
    therefore keep the plain sum ``sum(task_vectors)``.

    Args:
        task_vectors: Non-empty sequence of task-vector objects exposing a
            ``.vector`` mapping (name -> tensor), supporting ``sum(...)`` and
            ``apply_to(pretrained_checkpoint, scaling_coef=...)``.
        pretrained_checkpoint: Forwarded to ``apply_to``; used here only to
            enumerate parameter names and shapes.
        device: Device for the rank-2 solve. Defaults to ``'cuda'``, matching
            the previous hard-coded behavior.

    Returns:
        A deep copy of ``sum(task_vectors)`` whose ``.vector`` entries have
        been replaced by the merged tensors.

    Raises:
        ValueError: if ``task_vectors`` is empty.
    """
    if not task_vectors:
        raise ValueError("GMF_Mean_ requires at least one task vector")
    num_tvs = len(task_vectors)

    # deepcopy guards against ``sum`` handing back one of the inputs (possible
    # if the task-vector type's ``__radd__`` returns ``self``); the entries of
    # ``opt_tv.vector`` are reassigned below.
    opt_tv = copy.deepcopy(sum(task_vectors))

    # The model is only used for parameter names and shapes, so it is not
    # moved to the GPU.
    model = task_vectors[0].apply_to(pretrained_checkpoint, scaling_coef=1)
    for name, pp in model.named_parameters():
        current = opt_tv.vector[name]
        # Checked before any arithmetic so a ``None`` entry is actually skipped.
        if current is None:
            continue
        if pp.ndim == 1:
            tvs = [tv.vector[name] for tv in task_vectors]
            opt_tv.vector[name] = _cal_optimal_tv_1d(tvs)
        elif pp.ndim == 2:
            tvs = [tv.vector[name].to(device) for tv in task_vectors]
            # Return the result next to the other entries instead of a fixed
            # device, so the merged vector does not end up split across devices.
            opt_tv.vector[name] = _cal_optimal_tv_2d(tvs).to(current.device)
        else:
            opt_tv.vector[name] = current / num_tvs

    return opt_tv


