
import torch
from torch_geometric.loader import NeighborLoader
from torch.nn.utils import clip_grad_norm_

from utils.training.sampling_functions import sample_positive_pairs, sample_negative_pairs_from_positives
from utils.training.loss import cosine_similarity_stats


def _localize_pairs_densefree(num_nodes: int, batch_n_id: torch.Tensor, pairs: torch.Tensor) -> torch.Tensor:
    device = batch_n_id.device
    inv = torch.full((num_nodes,), -1, device=device, dtype=torch.long)
    inv[batch_n_id] = torch.arange(batch_n_id.numel(), device=device)
    return inv[pairs]


def train_step(training_graph, model, optimizer,
               cluster_to_nodes: dict, num_anchors: int, num_neg_per_anchor: int,
               alpha_sampling: float, min_size_clusters: int,
               num_neighbors_loader: list[int],
               contrastive_loss_fn: callable,
               use_edge_features: bool = False,
               max_grad_norm: float = 1.0,
               **kwargs
               ) -> torch.Tensor:
    """Run one contrastive optimisation step on a freshly sampled sub-graph.

    The step proceeds as follows:

    1. Sample positive pairs and negative pairs from the cluster assignment.
       ``training_graph.y`` is passed as the node-to-cluster mapping.
    2. Use every node appearing in any pair as a seed of a one-shot ``NeighborLoader``.
    3. Embed the sampled sub-graph with ``model``.
    4. Evaluate ``contrastive_loss_fn`` on the seed embeddings only.
    5. Back-propagate, clip gradients and take an optimizer step.

    Args:
        training_graph: PyG-style graph (reads ``.y`` and ``.num_nodes``).
        model: GNN called as ``model(x=..., edge_index=...[, edge_attr=...])``.
        optimizer: Optimizer over ``model``'s parameters.
        cluster_to_nodes: Cluster -> nodes mapping forwarded to the samplers.
        num_anchors: Number of anchors for positive-pair sampling.
        num_neg_per_anchor: Negatives drawn per anchor.
        alpha_sampling: Sampling exponent forwarded to both samplers.
        min_size_clusters: Minimum cluster size forwarded to both samplers.
        num_neighbors_loader: Per-hop fan-out for ``NeighborLoader``.
        contrastive_loss_fn: ``fn(seed_embeddings, pos_pairs_local, neg_pairs_local)``
            returning a scalar loss tensor.
        use_edge_features: Also pass ``batch.edge_attr`` to the model.
        max_grad_norm: Max total gradient norm for ``clip_grad_norm_``.
        **kwargs: Accepted and ignored, so callers may share one config dict.

    Returns:
        The loss of this step, detached from the autograd graph.

    Raises:
        ValueError: If no pairs were sampled, so there are no seed nodes.
    """
    model.train()
    optimizer.zero_grad()

    positive_pairs = sample_positive_pairs(cluster_to_nodes=cluster_to_nodes, num_anchors=num_anchors,
                                           alpha=alpha_sampling, min_size_clusters=min_size_clusters)
    negative_pairs = sample_negative_pairs_from_positives(cluster_to_nodes=cluster_to_nodes,
                                                          node_to_cluster=training_graph.y,
                                                          pos_pairs=positive_pairs,
                                                          num_neg_per_anchor=num_neg_per_anchor,
                                                          min_size_clusters=min_size_clusters,
                                                          alpha=alpha_sampling,
                                                          )

    # Every node referenced by a pair becomes a seed. Seeds are unique, so the whole
    # sample fits in a single loader batch.
    input_nodes = torch.unique(torch.cat([positive_pairs, negative_pairs], dim=0))
    if input_nodes.numel() == 0:
        raise ValueError("no positive/negative pairs were sampled; cannot build a training batch")

    device = next(model.parameters()).device
    loader = NeighborLoader(training_graph, num_neighbors=num_neighbors_loader,
                            batch_size=input_nodes.numel(), input_nodes=input_nodes, replace=True)
    batch = next(iter(loader)).to(device, non_blocking=True)

    # NeighborLoader places seed nodes first in batch.n_id. Because every pair node
    # is a seed, the local indices fall within [0, batch.batch_size).
    positive_pairs_local = _localize_pairs_densefree(
        training_graph.num_nodes, batch.n_id, positive_pairs.to(device=device, dtype=torch.long))
    negative_pairs_local = _localize_pairs_densefree(
        training_graph.num_nodes, batch.n_id, negative_pairs.to(device=device, dtype=torch.long))

    if use_edge_features:
        out = model(x=batch.x, edge_index=batch.edge_index, edge_attr=batch.edge_attr)
    else:
        out = model(x=batch.x, edge_index=batch.edge_index)
    loss = contrastive_loss_fn(out[:batch.batch_size], positive_pairs_local, negative_pairs_local)

    loss.backward()
    clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()

    # Detach so callers that accumulate per-step losses do not keep graph nodes alive.
    return loss.detach()


def valid_step(valid_batch, model, valid_pos_pairs, valid_neg_pairs,
               contrastive_loss_fn,
               use_edge_features: bool = False) -> dict:
    """Score the model on a fixed validation batch with gradient tracking disabled.

    The model is switched to eval mode and the whole batch graph is embedded. The
    contrastive loss is computed on the first ``valid_batch.batch_size`` rows.
    Cosine-similarity statistics are computed from the full embedding matrix via
    ``cosine_similarity_stats``.

    Args:
        valid_batch: Batch object exposing ``x``, ``edge_index``, ``batch_size``
            and, if ``use_edge_features`` is set, ``edge_attr``.
        model: GNN called with keyword arguments ``x``, ``edge_index``[, ``edge_attr``].
        valid_pos_pairs: Positive pairs, presumably already in batch-local indices.
        valid_neg_pairs: Negative pairs, presumably already in batch-local indices.
        contrastive_loss_fn: ``fn(embeddings, pos_pairs, neg_pairs)`` -> scalar tensor.
        use_edge_features: Forward ``valid_batch.edge_attr`` to the model.

    Returns:
        Dict of Python floats with keys ``loss``, ``pos_sim_avg``, ``pos_sim_std``,
        ``neg_sim_avg`` and ``neg_sim_std``.
    """
    model.eval()
    with torch.no_grad():
        forward_inputs = {"x": valid_batch.x, "edge_index": valid_batch.edge_index}
        if use_edge_features:
            forward_inputs["edge_attr"] = valid_batch.edge_attr
        embeddings = model(**forward_inputs)

        n_seeds = valid_batch.batch_size
        loss_tensor = contrastive_loss_fn(embeddings[:n_seeds], valid_pos_pairs, valid_neg_pairs)
        similarity_stats = cosine_similarity_stats(embeddings, valid_pos_pairs, valid_neg_pairs)
        pos_mean, pos_spread, neg_mean, neg_spread = similarity_stats

    metrics = {}
    metrics["loss"] = float(loss_tensor.item())
    metrics["pos_sim_avg"] = float(pos_mean)
    metrics["pos_sim_std"] = float(pos_spread)
    metrics["neg_sim_avg"] = float(neg_mean)
    metrics["neg_sim_std"] = float(neg_spread)
    return metrics
