import torch

def sample_loss(model, X, y, beta, lbd):
    """Squared-error loss of a linear model, optionally with a proximity penalty.

    Two modes are selected by the rank of ``y``:

    * ``y`` is 0-dim (a single target): return ``(X @ model - y)^2`` only;
      ``beta`` and ``lbd`` are not used.
    * otherwise (a batch of targets, e.g. a whole task): return the mean squared
      error over ``len(y)`` targets plus ``lbd * ||model - beta||^2``.

    Args:
        model: parameter vector of the linear model.
        X: inputs; ``X @ model`` must be broadcast-compatible with ``y.squeeze()``.
        y: target(s).
        beta: tensor the model is pulled towards by the penalty term.
        lbd: weight of the penalty term.
    """
    residual = X @ model - y.squeeze()
    if y.dim() == 0:
        # Single-sample loss: the regularisation term is intentionally omitted.
        return residual.square()
    data_term = residual.square().sum() / len(y)
    penalty = lbd * (model - beta).square().sum()
    return data_term + penalty

def compute_attribution_linear(models, beta, x, y, num_dim, lbd, device):
    """Per-sample attribution of every training point on every task model.

    For sample j of task i, let g be the gradient of that sample's
    squared-error loss (``sample_loss`` with a scalar target) w.r.t.
    ``models[i]``. The attribution on task k is ``-(H^{-1})_{k,i} @ g``, where
    ``(H^{-1})_{k,i}`` is the task/task block of the inverse Hessian returned
    by ``hessian_inverse`` for the joint loss assembled in ``get_hessian``.

    Args:
        models: per-task parameter tensors, indexed as ``models[i]``; they must
            require gradients.
        beta: shared tensor the task models are regularised towards.
        x: list of per-task inputs (a list because tasks may have different
            sample counts); ``x[i]`` is (num_samples_i, num_dimensions).
        y: list of per-task targets; ``y[i]`` is (num_samples_i,).
        num_dim: not used as given; the dimension is re-derived from
            ``x[0][0].shape[0]``.
        lbd: per-task regularisation weights, indexed as ``lbd[i]``.
        device: device on which the attribution tensors are allocated.

    Returns:
        A list with one tensor per task i, of shape
        (num_samples_i, num_tasks, num_dimensions). Entry ``[j, k]`` is the
        attribution of sample j of task i on model k.

    Unlike the previous ``loss.backward()`` approach, gradients are obtained
    with ``torch.autograd.grad``. Pre-existing ``models[i].grad`` values can
    therefore not contaminate the result, and they are left untouched.
    """
    num_tasks = len(x)
    # The dimension is taken from the data, overriding the num_dim argument.
    num_dim = x[0][0].shape[0]
    H_ss, H_s_m_plus_one, H_m_plus_one_m_plus_one = get_hessian(models, beta, x, y, num_dim, lbd, device)
    H_st_inv, _, _ = hessian_inverse(H_ss, H_s_m_plus_one, H_m_plus_one_m_plus_one, num_tasks, num_dim, device)

    attribution = []
    for i, (task_data, task_label) in enumerate(zip(x, y)):
        model = models[i]
        temp = torch.zeros([task_data.shape[0], num_tasks, num_dim], dtype=model.dtype, device=device)
        # Blocks (H^{-1})_{k,i} for every k, shape (num_tasks, num_dim, num_dim).
        # Cast to the model dtype because torch.matmul does not promote dtypes.
        inv_col = H_st_inv[:, i].to(dtype=model.dtype)
        for j, (target_data, target_label) in enumerate(zip(task_data, task_label)):
            # Scalar target -> sample_loss returns the bare squared error.
            loss = sample_loss(model, target_data, target_label.squeeze(), beta, lbd[i])
            # A fresh gradient; nothing is accumulated into model.grad.
            (grad,) = torch.autograd.grad(loss, model)
            # Batched mat-vec: (num_tasks, d, d) @ (d,) -> (num_tasks, d).
            temp[j] = -(inv_col @ grad)
        attribution.append(temp)
    return attribution

def get_hessian(models, beta, x, y, num_dim, lbd, device):
    """Hessian blocks of the joint multi-task loss w.r.t. (models[0..T-1], beta).

    The joint loss is ``L = sum_i sample_loss(models[i], x[i], y[i], beta, lbd[i])``.
    Term i depends only on ``models[i]`` and ``beta``. The returned blocks are
    therefore the per-task diagonal blocks, the task/beta blocks and the
    beta/beta block. Task/task cross blocks are not returned.

    Args:
        models: per-task parameter tensors, indexed as ``models[i]``.
        beta: shared parameter tensor.
        x: list of per-task inputs; ``x[i]`` is (num_samples_i, num_dimensions).
        y: list of per-task targets; ``y[i]`` is (num_samples_i,).
        num_dim: size of every parameter vector (models[i] and beta).
        lbd: per-task regularisation weights, indexed as ``lbd[i]``.
        device: device on which the returned blocks are allocated.

    Returns:
        H_ss: (num_tasks, num_dim, num_dim), ``H_ss[i] = d^2 L / d models[i]^2``.
        H_s_m_plus_one: (num_tasks, num_dim, num_dim),
            ``H_s_m_plus_one[i] = d^2 L / d models[i] d beta``.
        H_m_plus_one_m_plus_one: (num_dim, num_dim), ``d^2 L / d beta^2``.
        All blocks keep the Hessian's dtype rather than being downcast to the
        default float32. This assumes models and beta share a dtype.
    """
    num_tasks = len(x)
    parameters = tuple(models[i] for i in range(num_tasks)) + (beta,)

    def loss_func(*params):
        # Out-of-place accumulation keeps the parameters' dtype. An in-place add
        # into a default-dtype buffer would silently downcast float64 losses.
        loss = torch.zeros((), dtype=params[-1].dtype, device=device)
        for i in range(num_tasks):
            loss = loss + sample_loss(params[i], x[i], y[i], params[-1], lbd[i])
        return loss

    hess = torch.autograd.functional.hessian(loss_func, parameters)
    dtype = hess[-1][-1].dtype
    H_ss = torch.zeros([num_tasks, num_dim, num_dim], dtype=dtype, device=device)
    H_s_m_plus_one = torch.zeros([num_tasks, num_dim, num_dim], dtype=dtype, device=device)
    for i in range(num_tasks):
        H_ss[i] = hess[i][i]
        H_s_m_plus_one[i] = hess[i][-1]
    H_m_plus_one_m_plus_one = hess[-1][-1].to(device=device, dtype=dtype).clone()

    return H_ss, H_s_m_plus_one, H_m_plus_one_m_plus_one

def hessian_inverse(H_ss, H_s_m_plus_one, H_m_plus_one_m_plus_one, num_tasks, num_dim, device):
    """Invert the block-structured Hessian using its Schur complement.

    The full Hessian is

        H = [[A,   B],
             [B^T, C]]

    where A = diag(H_ss[0], ..., H_ss[T-1]) is block-diagonal (one block per
    task model), B stacks the blocks B_i = H_s_m_plus_one[i] (task i vs. beta),
    and C = H_m_plus_one_m_plus_one (beta vs. beta). With the Schur complement
    S = C - sum_i B_i^T A_i^{-1} B_i, the inverse blocks are

        (H^{-1})_{i,j}       = delta_ij A_i^{-1} + A_i^{-1} B_i S^{-1} B_j^T A_j^{-1}
        (H^{-1})_{i,beta}    = -A_i^{-1} B_i S^{-1}
        (H^{-1})_{beta,beta} = S^{-1}

    The implementation uses (A_j^{-1} B_j)^T in place of B_j^T A_j^{-1}. That is
    exact only when every H_ss[j] is symmetric, as it is for a Hessian.

    Args:
        H_ss: (num_tasks, num_dim, num_dim) diagonal blocks A_i.
        H_s_m_plus_one: (num_tasks, num_dim, num_dim) blocks B_i.
        H_m_plus_one_m_plus_one: (num_dim, num_dim) block C.
        num_tasks: number of task blocks. Kept for interface compatibility;
            shapes are read from the tensors.
        num_dim: block size. Kept for interface compatibility.
        device: device on which the computation runs and results are returned.

    Returns:
        (H_st_inv, H_s_m_plus_one_inv, H_m_plus_one_m_plus_one_inv) with shapes
        (num_tasks, num_tasks, num_dim, num_dim), (num_tasks, num_dim, num_dim)
        and (num_dim, num_dim). They are in the inputs' dtype; float64 is not
        downcast to float32.
    """
    H_ss = H_ss.to(device)
    H_s_m_plus_one = H_s_m_plus_one.to(device)
    H_m_plus_one_m_plus_one = H_m_plus_one_m_plus_one.to(device)

    # A is block-diagonal, so A^{-1} is the batch of per-block inverses.
    inv_blocks = torch.inverse(H_ss)
    # cached[i] = A_i^{-1} B_i, shared by all formulas below.
    cached = inv_blocks @ H_s_m_plus_one
    schur = H_m_plus_one_m_plus_one - (H_s_m_plus_one.transpose(-2, -1) @ cached).sum(dim=0)
    H_m_plus_one_m_plus_one_inv = torch.inverse(schur)

    # left[i] = A_i^{-1} B_i S^{-1}
    left = cached @ H_m_plus_one_m_plus_one_inv
    H_s_m_plus_one_inv = -left
    # Broadcast (T, 1, d, d) @ (1, T, d, d) -> (T, T, d, d); entry [i, j] = left[i] @ cached[j]^T.
    H_st_inv = left.unsqueeze(1) @ cached.transpose(-2, -1).unsqueeze(0)
    for t in range(H_ss.shape[0]):
        H_st_inv[t, t] += inv_blocks[t]
    return H_st_inv, H_s_m_plus_one_inv, H_m_plus_one_m_plus_one_inv