import torch
import time
from model.model import GAT
from experiment.exp import remove_edges

def compute_pinv(matrix, num_iterations=1):
    """Return the Moore-Penrose pseudoinverse of ``matrix``, averaged over repeated runs.

    ``torch.linalg.pinv`` is evaluated ``num_iterations`` times and the results
    are averaged. Batched inputs of shape ``(..., m, n)`` are supported and
    produce ``(..., n, m)``.

    Args:
        matrix: Tensor with at least two dimensions.
        num_iterations: Number of pseudoinverse evaluations to average; must be >= 1.

    Returns:
        Tensor of shape ``(..., n, m)`` on the same device/dtype as ``matrix``.

    Raises:
        ValueError: If ``num_iterations`` is less than 1 (the average would be 0/0).
    """
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
    # Transpose only the matrix dims: ``.T`` reverses *all* dims on batched
    # tensors, which would not match pinv's ``(..., n, m)`` output shape.
    pinv_sum = torch.zeros_like(matrix.transpose(-2, -1))
    for _ in range(num_iterations):
        pinv_sum += torch.linalg.pinv(matrix)
    return pinv_sum / num_iterations

def compute_and_time_pinv(args, matrix, device):
    """Move ``matrix`` to ``device``, compute its pseudoinverse, and print the elapsed time.

    The timed region includes the device transfer and ``compute_pinv``.

    Args:
        args: Unused; kept for call-site compatibility.
        matrix: Tensor to pseudo-invert.
        device: Target passed to ``Tensor.to``.

    Returns:
        The pseudoinverse tensor, located on ``device``.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    pinv = compute_pinv(matrix.to(device))
    if pinv.is_cuda:
        # CUDA kernels run asynchronously; wait for them to finish so the
        # measurement covers the actual computation, not just kernel launch.
        torch.cuda.synchronize(pinv.device)
    elapsed = time.perf_counter() - start_time
    print(f"Time taken to calculate pseudoinverse: {elapsed:.4f} seconds")
    return pinv

def get_virtual_perturbation(model, data, embeddings, device, removal_ratio, num_samples=1):
    """Average the aggregator outputs of ``model`` over randomly edge-pruned graphs.

    For each of ``num_samples`` draws, the function does three things:

    1. Drops edges via ``remove_edges(data, removal_ratio, True)``.
    2. Re-runs every aggregator in ``model.aggs`` through
       ``model.agg_inference`` on the pruned ``edge_index``.
    3. Restores the original ``edge_index``.

    The per-aggregator outputs are then averaged across draws.

    Args:
        model: Object exposing ``aggs`` and
            ``agg_inference(data, embedding, edge_index, i, device)``.
        data: Graph object with an ``edge_index`` attribute. As a side effect a
            ``backup_edge_index`` copy is attached to it.
        embeddings: Indexable; ``embeddings[i]`` is fed to the ``i``-th aggregator.
        device: Forwarded unchanged to ``model.agg_inference``.
        removal_ratio: Forwarded unchanged to ``remove_edges``.
        num_samples: Number of draws to average; must be >= 1.

    Returns:
        List with one averaged tensor per aggregator in ``model.aggs``.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    data.backup_edge_index = data.edge_index.clone()
    # Accumulators are allocated lazily, once the output shapes are known.
    v_perturb = None

    for _ in range(num_samples):
        try:
            # NOTE(review): the result of remove_edges is rebound to `data`;
            # this assumes the returned object still carries
            # backup_edge_index (e.g. it is modified in place) — confirm.
            data = remove_edges(data, removal_ratio, True)
            modified_agg_embeddings = [
                model.agg_inference(data, embeddings[i], data.edge_index, i, device)
                for i, _ in enumerate(model.aggs)
            ]
        finally:
            # Always restore the unperturbed graph, even if a step above raised.
            data.edge_index = data.backup_edge_index.clone().detach()

        if v_perturb is None:
            v_perturb = [torch.zeros_like(embed) for embed in modified_agg_embeddings]

        # In-place accumulation into each aggregator's running sum.
        for acc, embed in zip(v_perturb, modified_agg_embeddings):
            acc += embed

    # Sample mean over the edge-removal draws.
    return [acc / num_samples for acc in v_perturb]

