# src/train_utils.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_geometric.data import Data

def compute_loss(model, dataset, edge_index, device):
    """Run ``model`` over every sample of ``dataset`` and collect its outputs.

    Despite the name, no loss is computed here: for each sample the model's
    first three outputs are gathered under ``torch.no_grad()`` and then
    concatenated / stacked across samples.

    Args:
        model: Module called as ``model(x, edge_index, batch=batch)``; it must
            return an indexable with at least three entries. It is switched to
            eval mode and left there.
        dataset: NumPy array indexed ``dataset[sample, a, b]``; it is
            transposed to ``(sample, b, a)`` so each sample becomes a matrix
            with one row per ``b`` (presumably samples x features x nodes —
            confirm against the caller).
        edge_index: Connectivity tensor shared by every sample.
        device: Device the inputs are moved to.

    Returns:
        ``[soft_prob, feat, emb]``:
        ``soft_prob`` – first outputs concatenated along dim 0;
        ``feat`` – transposed second outputs stacked along a new dim 0;
        ``emb`` – third outputs stacked along a new dim 0, or, when their
        shapes differ, concatenated along dim 0 (tensors with fewer than 3
        dims get a leading axis first).
    """
    out, feat, emb = [], [], []

    model.eval()
    with torch.no_grad():
        x_data = torch.tensor(dataset.transpose(0, 2, 1), dtype=torch.float, device=device)
        edge_index = edge_index.to(device)

        for x in x_data:
            # All nodes of this sample belong to one graph (graph id 0).
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=device)
            output_model = model(x, edge_index, batch=batch)
            out.append(output_model[0])
            feat.append(output_model[1])
            emb_val = output_model[2]
            # Some models return several embedding tensors as a tuple; merge them.
            if isinstance(emb_val, tuple):
                emb_val = torch.stack(emb_val, dim=0)
            emb.append(emb_val)

    soft_prob = torch.cat(out, dim=0)
    feat = torch.stack([f.T for f in feat], dim=0)

    try:
        emb = torch.stack(emb, dim=0)
    except (TypeError, RuntimeError):
        # torch.stack raises RuntimeError when the embeddings differ in shape;
        # fall back to concatenating them along dim 0 instead.
        emb = torch.cat([e.unsqueeze(0) if e.dim() < 3 else e for e in emb], dim=0)

    return [soft_prob, feat, emb]


def model_acc(model, dataset, edge_index, labels, device):
    """Return the fraction of samples in ``dataset`` that ``model`` classifies correctly.

    Each sample is run through the model as a single graph; the predicted
    class is the argmax over dim 1 of the first row of the model's first
    output.

    Args:
        model: Module called as ``model(x, edge_index, batch=batch)``. It is
            switched to eval mode and left in eval mode on return.
        dataset: NumPy array indexed ``dataset[sample, a, b]``; transposed to
            ``(sample, b, a)`` before use (presumably samples x features x
            nodes — confirm).
        edge_index: Connectivity tensor shared by every sample.
        labels: One integer class per sample; anything ``torch.as_tensor``
            accepts (tensor on any device, list, NumPy array). A column
            vector of shape ``(N, 1)`` is flattened.
        device: Device the inputs are moved to.

    Returns:
        Accuracy as a Python float in ``[0, 1]``.
    """
    model.eval()  # Set model to evaluation mode
    predictions = []

    with torch.no_grad():
        x_data = torch.tensor(dataset.transpose(0, 2, 1), dtype=torch.float, device=device)
        edge_index = edge_index.to(device)

        for x in x_data:
            # All nodes of this sample belong to one graph (graph id 0).
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=device)
            logits = model(x, edge_index, batch=batch)[0]
            predictions.append(int(logits.argmax(dim=1)[0]))

    pred_tensor = torch.tensor(predictions, dtype=torch.long, device=device)
    # Put labels on the same device as the predictions and flatten them, so an
    # (N, 1) column cannot broadcast into an (N, N) comparison.
    label_tensor = torch.as_tensor(labels, device=pred_tensor.device).reshape(-1)

    correct = (pred_tensor == label_tensor).sum().item()
    accuracy = correct / dataset.shape[0]

    return accuracy


def weight_loss(model, dataset, weights, edge_index, training_label, device):
    """Return the weighted misclassification rate of ``model`` on ``dataset``.

    ``err = sum(weights * (pred != label)) / sum(weights)``, where each
    prediction is the argmax over dim 1 of the first row of the model's first
    output. The model's train/eval mode is not changed.

    Args:
        model: Module called as ``model(x, edge_index)``.
        dataset: NumPy array; each ``dataset[i]`` is transposed into a node
            feature matrix (presumably features x nodes — confirm).
        weights: 1-D tensor with one weight per sample.
        edge_index: Connectivity tensor shared by every sample.
        training_label: One integer class per sample; anything
            ``torch.as_tensor`` accepts. A ``(N, 1)`` column is flattened.
        device: Device the computation runs on.

    Returns:
        0-d tensor holding the weighted error rate.
    """
    edge_index = edge_index.to(device)
    pred = []
    # Only argmax predictions are needed; skip building autograd graphs.
    with torch.no_grad():
        for sample in dataset:
            x = torch.tensor(np.transpose(sample), dtype=torch.float, device=device)
            pred.append(int(model(x, edge_index)[0].argmax(dim=1)[0]))
    pred = torch.tensor(pred, device=device)
    # Align labels with predictions (device and flat shape) before comparing.
    labels = torch.as_tensor(training_label, device=device).reshape(-1)
    err_rate = (weights.to(device) * (pred != labels)).sum() / weights.sum()
    return err_rate


def quality_update(err_rate):
    """Return the learner weight ``alpha`` for a weighted error ``err_rate``.

    Computes ``-log(0.01 + err / (1 - err))``, floored at 0.05. The 0.01
    term keeps the log finite when ``err_rate`` is 0. ``err_rate == 1``
    yields the floor because the ratio becomes infinite.

    Args:
        err_rate: Tensor or Python number with the error rate.

    Returns:
        Tensor on the same device and dtype as the (tensor-converted)
        ``err_rate``.
    """
    # Accept plain floats too; torch.log only takes tensors.
    err_rate = torch.as_tensor(err_rate)
    alpha = torch.log(0.01 + err_rate / (1 - err_rate))
    # Floor follows alpha's device/dtype instead of a fixed CPU float32 tensor.
    return torch.clamp(-alpha, min=0.05)


def weight_update(model, err_rate, dataset, weights, edge_index, training_label, device):
    """Recompute per-sample boosting weights from the model's current predictions.

    Each sample's weight is set to ``exp(alpha * miss)``, where
    ``alpha = quality_update(err_rate)`` and ``miss`` is 1 when the model's
    prediction differs from ``training_label[i]``, else 0. The weight vector
    is then L2-normalised along dim 0.

    Side effect: ``weights`` is written in place (unnormalised values) before
    a new, normalised tensor is returned.

    Args:
        model: Module called as ``model(x, edge_index)``; its train/eval
            mode is not changed.
        err_rate: Weighted error rate passed to ``quality_update``.
        dataset: NumPy array; each ``dataset[i]`` is transposed into a node
            feature matrix (presumably features x nodes — confirm).
        weights: 1-D tensor with one weight per sample.
        edge_index: Connectivity tensor shared by every sample.
        training_label: Indexable with one class label per sample.
        device: Device the forward passes run on.

    Returns:
        L2-normalised weight tensor.
    """
    alpha = quality_update(err_rate)
    edge_index = edge_index.to(device)

    # Predictions only; no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for i, sample in enumerate(dataset):
            # Build the feature matrix in one step; converting a list of
            # per-row ndarrays is slow and triggers a torch warning.
            x = torch.tensor(np.transpose(sample), dtype=torch.float, device=device)
            curr_pred = model(x, edge_index)[0].argmax(dim=1)[0].to(device)
            # NOTE(review): the weight is replaced by exp(alpha * miss), not
            # multiplied into the previous weight as in classic AdaBoost —
            # confirm this is intended.
            weights[i] = torch.exp(alpha * (training_label[i] != curr_pred))

    # normalize the weights for all samples
    weights = nn.functional.normalize(weights, p=2, dim=0)

    return weights
