import numpy as np
import torch
import torch.nn.functional as F
from src.utils import config


def _to_float_feature(feature, device):
    """Return one true-feature entry as a float32 tensor on ``device``."""
    if isinstance(feature, np.ndarray):
        feature = torch.tensor(feature)
    return feature.to(device).float()


def greedy_loss(pred_feats, true_feats, pred_missing, true_missing):
    """Match every predicted feature to its closest true feature (by MSE).

    For sample ``i`` the first ``pred_missing[i]`` prediction slots are each
    compared against the first ``true_missing[i]`` true features (both counts
    truncated to integers and clipped to ``[0, config.num_pred]``); the
    smallest mean-squared error found is stored in that slot.

    Args:
        pred_feats: Tensor indexed as ``pred_feats[sample][slot]``; gradients
            flow back into it through the selected MSE values.
        true_feats: Indexed as ``true_feats[sample][k]``; each entry is either
            a ``numpy.ndarray`` or a tensor.
            NOTE(review): entries are assumed to have the same shape as one
            prediction slot -- confirm against callers.
        pred_missing: Tensor with the predicted neighbour count per sample.
        true_missing: Tensor with the true neighbour count per sample.

    Returns:
        A float32 tensor with the shape of ``pred_feats`` on the same device.
        Every matched slot holds its minimum MSE broadcast over the remaining
        feature dimensions; every other slot is zero.
    """
    num_pred = config.num_pred
    device = pred_feats.device

    # Counts are only used as Python loop bounds, so move them to host memory
    # and clip them to the number of prediction slots the generator emits.
    pred_counts = np.clip(
        pred_missing.detach().cpu().numpy().reshape(-1).astype(np.int32),
        0, num_pred)
    true_counts = np.clip(
        true_missing.detach().cpu().numpy().reshape(-1).astype(np.int32),
        0, num_pred)

    # Deliberately NOT a leaf that requires grad: autograd forbids in-place
    # writes into such a tensor, and the slots are filled in place below.
    loss = torch.zeros(pred_feats.shape, device=device)
    trailing_dims = (1,) * (loss.dim() - 2)
    has_true_missing = False

    for i in range(len(pred_feats)):
        n_pred = int(pred_counts[i])
        n_true = int(true_counts[i])
        if n_pred == 0 or n_true == 0:
            # Nothing predicted or nothing to match against for this sample.
            continue
        has_true_missing = True

        preds = pred_feats[i][:n_pred].float()
        # Convert the candidate true features once per sample.
        targets = torch.stack([
            _to_float_feature(true_feats[i][k], device)
            for k in range(n_true)
        ])

        # Pairwise MSE between every prediction slot and every candidate.
        sq_err = (preds.unsqueeze(1) - targets.unsqueeze(0)) ** 2
        pair_mse = sq_err.reshape(n_pred, n_true, -1).mean(dim=2)

        # Select on the scalar MSE values so the feature width cannot bias
        # which candidate is kept.
        best = pair_mse.min(dim=1).values
        loss[i, :n_pred] = best.reshape((n_pred,) + trailing_dims)

    if not has_true_missing:
        # No slot was filled: tie the result to pred_feats with a zero-valued
        # term so a later backward() still has a graph to traverse.
        loss = loss + torch.sum(pred_feats) * 0.0

    return loss


def loss_from_other(own, other, output_feat, output_missing, feat_shape):
    """Score this owner's predictions against features sampled from ``other``.

    One node of ``other.subG`` is drawn (with replacement) for every entry of
    ``own.train_ilocs``. For each drawn node, ``config.num_pred`` of its
    neighbours are sampled with replacement and their features become the
    targets handed to ``greedy_loss``. A drawn node without neighbours is
    replaced by redrawing single nodes until one with neighbours turns up.

    Args:
        own: Object providing ``train_ilocs`` and ``all_targets_missing``.
        other: Object whose ``subG`` provides ``nodes()``, ``neighbors()`` and
            ``node_features()``.
        output_feat: Predicted features, indexable by ``own.train_ilocs``.
        output_missing: Predicted missing counts, indexable the same way.
        feat_shape: Size of the last axis of the reshaped target array.

    Returns:
        The mean of the ``greedy_loss`` result, cast to float.
    """
    graph = other.subG
    train_idx = own.train_ilocs
    n_train = len(train_idx)

    def _draw_positions(size):
        # Uniform draw of positions into the current node listing.
        return np.random.choice(len(list(graph.nodes())), size)

    sampled_nodes = graph.nodes()[_draw_positions(n_train)]

    sampled_feats = []
    for node_id in sampled_nodes:
        nbrs = graph.neighbors(node_id)
        # Keep redrawing one node at a time until it has neighbours.
        while len(nbrs) == 0:
            position = _draw_positions(1)[0]
            replacement_id = graph.nodes()[position]
            nbrs = graph.neighbors(replacement_id)
        picked = np.random.choice(nbrs, config.num_pred)
        for nbr in picked:
            sampled_feats.append(graph.node_features([nbr])[0])

    target_shape = (n_train, config.num_pred, feat_shape)
    targets = np.asarray(sampled_feats).reshape(target_shape)

    per_slot_loss = greedy_loss(
        output_feat[train_idx],
        targets,
        output_missing[train_idx],
        own.all_targets_missing[train_idx],
    )
    return per_slot_loss.unsqueeze(0).mean().float()