import torch
import time
import torch.nn.functional as F
from post_process.process_utils import get_virtual_perturbation, compute_and_time_pinv
from torch.optim import Adam
import torch.nn as nn
import copy

def transfer_weights_to_target(model, target, e_recovery_matrix, step_size):
    """Copy aggregation weights from ``model`` into ``target`` and add the recovery term.

    For every index ``i`` of ``model.aggs``:
      * ``target.aggs[i].agg`` is set to the same object as
        ``model.aggs[i].agg`` (shared by reference, not copied);
      * ``target.aggs[i].er_matrix`` starts from a *copy* of
        ``model.aggs[i].er_matrix``, or from zeros when the source layer's
        ``version`` is ``"base"``. It then has
        ``step_size * e_recovery_matrix[i]`` added to it in place.

    The source model's ``er_matrix`` tensors are left untouched.

    Args:
        model: Source model exposing ``aggs``. Each entry has ``agg``,
            ``version`` and ``er_matrix`` attributes.
        target: Destination model with an ``aggs`` list at least as long as
            ``model.aggs``. It is mutated in place.
        e_recovery_matrix: Sequence of tensors indexed like ``model.aggs``.
            Each must be broadcastable to, and on a compatible device with,
            the corresponding target ``er_matrix``.
        step_size: Scalar multiplier applied to each recovery matrix.

    Returns:
        ``target`` (the same object that was passed in).
    """
    for i, source_agg in enumerate(model.aggs):
        target_agg = target.aggs[i]
        # The aggregation sub-module itself is shared with the target.
        target_agg.agg = source_agg.agg
        if source_agg.version != "base":
            # Clone: assigning `.data` directly would make the target share
            # storage with the source, and the in-place add below would then
            # silently modify the source model's er_matrix as well.
            base = source_agg.er_matrix.data.clone()
        else:
            base = torch.zeros_like(target_agg.er_matrix.data)
        target_agg.er_matrix.data = base
        # In-place add keeps er_matrix's dtype.
        target_agg.er_matrix.data += step_size * e_recovery_matrix[i]

    return target

def train_e_recovery_matrix(embeddings_target, h, E_h, device, num_epochs=1000, lr=0.01, patience=10, min_delta=1e-4):
    """Fit an edge-recovery matrix W by gradient descent (alternative to a pseudoinverse).

    Minimizes ``mse_loss(embeddings_target @ W, h - E_h)`` with Adam, starting
    from ``W = 0``. Training stops early once the loss has failed to improve by
    more than ``min_delta`` for ``patience`` consecutive epochs. The wall-clock
    training time is printed.

    Args:
        embeddings_target: Tensor whose ``size(1)`` is the input width ``F_in``
            (assumed 2-D, ``(N, F_in)``).
        h: Tensor of targets, assumed ``(N, F_out)``.
        E_h: Tensor subtracted from ``h``; must match ``h``'s shape.
        device: Device on which all tensors and W are placed.
        num_epochs: Maximum number of optimization steps.
        lr: Adam learning rate.
        patience: Consecutive non-improving epochs tolerated before stopping.
        min_delta: Minimum loss decrease that counts as an improvement.

    Returns:
        A detached ``(F_in, F_out)`` tensor: the weights that produced the
        lowest evaluated loss, or the zero initialization if no epoch ran.
    """
    start_time = time.time()

    embeddings_target = embeddings_target.to(device)
    # Move both operands first so h and E_h may start on different devices.
    h_minus_Eh = h.to(device) - E_h.to(device)

    F_in = embeddings_target.size(1)
    F_out = h_minus_Eh.size(1)
    W_va = nn.Parameter(torch.zeros(F_in, F_out, device=device))
    optimizer = Adam([W_va], lr=lr)

    # Early-stopping state.
    best_loss = float('inf')
    best_W_va = None
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        loss = F.mse_loss(embeddings_target @ W_va, h_minus_Eh)
        current_loss = loss.item()

        # Snapshot W *before* the optimizer step: `loss` was computed with the
        # current weights, so these are the weights that achieved `current_loss`.
        if best_loss - current_loss > min_delta:
            best_loss = current_loss
            best_W_va = W_va.detach().clone()
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break

        loss.backward()
        optimizer.step()

    end_time = time.time()
    print(f"Time taken to train W_va: {end_time - start_time:.4f} seconds")

    return best_W_va if best_W_va is not None else W_va.detach()


def post_process(data, model, target, device, args, ratio, step_size=0.2):
    """Populate ``target`` from ``model`` plus a closed-form edge-recovery correction.

    Steps, as implemented below:
      1. Put both models in eval mode and run ``model.inference(data, device)``.
         The per-layer ``embeddings`` and ``agg_embeddings`` it returns are kept.
      2. Pick ``int(data.x.size(0) * ratio)`` node indices uniformly at random
         via ``torch.randperm``, so the result depends on the global RNG state.
      3. For each aggregation layer ``i``, compute
         ``W_i = pinv_i @ (h_i - E_h_i)`` over the sampled nodes, where:
         ``pinv_i`` comes from ``compute_and_time_pinv`` on ``embeddings[i]``;
         ``h_i`` is ``agg_embeddings[i]``; and ``E_h_i`` is the output of
         ``get_virtual_perturbation`` (``removal_ratio=1.0``, ``num_samples=1``).
      4. Write the weights into ``target`` via ``transfer_weights_to_target``
         (scaled by ``step_size``), and set ``target.edge_num``.

    Args:
        data: Graph data object. ``x``, ``train_mask`` and ``edge_index`` are read.
        model: Source model exposing ``inference`` and ``aggs``.
        target: Model to populate. It is mutated in place.
        device: Device used for inference and the matrix computations.
        args: Passed through to ``compute_and_time_pinv``.
        ratio: Fraction of nodes sampled to fit the recovery matrices.
        step_size: Scale applied to each recovery matrix when transferring.

    Returns:
        ``target`` (the same object that was passed in).
    """
    model.eval()
    target.eval()

    # Full-graph inference; the first returned value is not needed here.
    _, embeddings, agg_embeddings = model.inference(data, device)

    # Random node subset (fraction `ratio`) used to fit the recovery matrices.
    num_used = int(data.x.size(0) * ratio)
    random_mask = torch.zeros_like(data.train_mask, dtype=torch.bool)
    random_indices = torch.randperm(data.train_mask.size(0))[:num_used]
    random_mask[random_indices] = True
    target_idx = random_mask.cpu()

    embeddings_target = [embed[target_idx] for embed in embeddings]
    # Per-layer pseudoinverse of the sampled embeddings. This is not needed if
    # the gradient-based alternative (train_e_recovery_matrix) is used below.
    embeddings_pinv = [compute_and_time_pinv(args, embed, device) for embed in embeddings_target]

    # Edge-removal ratio for the virtual perturbation. Change it to evaluate
    # other virtual perturbations.
    edge_ratio_vp = 1.0
    virtual_perturb = get_virtual_perturbation(model, data, embeddings, device, removal_ratio=edge_ratio_vp, num_samples=1)
    # Restrict the aggregated and perturbed embeddings to the sampled nodes.
    agg_embed_target = [agg_embed[target_idx] for agg_embed in agg_embeddings]
    virtual_perturb_target = [E_agg[target_idx] for E_agg in virtual_perturb]

    # Edge-shift recovery weight matrix W_i for each aggregation layer.
    e_recovery_matrix = []
    for i, _ in enumerate(model.aggs):
        h = agg_embed_target[i].to(device)
        E_h = virtual_perturb_target[i].to(device)
        # If compute_and_time_pinv returns the Moore-Penrose pseudoinverse,
        # this is the least-squares solution of embeddings_target[i] @ W ~= h - E_h.
        e_recovery_matrix.append(embeddings_pinv[i] @ (h - E_h))
        # Gradient-based alternative to the pseudoinverse:
        # e_recovery_matrix.append(train_e_recovery_matrix(embeddings_target[i], h, E_h, device, num_epochs=100, lr=0.01))

    target = transfer_weights_to_target(model, target, e_recovery_matrix, step_size)
    # Number of edges kept in the training graph (used for non-deterministic
    # perturbation). With edge_ratio_vp = 1.0 this is 0.
    # NOTE(review): torch.tensor(...) is created on CPU regardless of `device`; confirm intended.
    target.edge_num.data = torch.tensor(int(data.edge_index.size(1) * (1 - edge_ratio_vp)))

    return target
