from tqdm import tqdm
import torch
import numpy as np
from pykeen.nn.compositions import composition_resolver, CompositionModule


def propagate_embeddings(ctrl, graph, core_embed, relations_embed, subgraphs, **kwargs):
    """Dispatch to the propagation routine selected by ``ctrl.propagation_type``.

    Args:
        ctrl: Configuration object; only ``propagation_type`` is read here,
            the rest is consumed by the selected routine.
        graph: Graph object forwarded unchanged to the selected routine.
        core_embed: Embeddings of the core subgraph entities.
        relations_embed: Relation embeddings.
        subgraphs: Collection of entity-index arrays, one per subgraph.
        **kwargs: Extra keyword arguments forwarded to the selected routine.

    Returns:
        The embeddings tensor produced by the selected routine.

    Raises:
        ValueError: If ``ctrl.propagation_type`` is neither
            ``"normalized_sum"`` nor ``"barycenter"``.
    """
    propagation_type = ctrl.propagation_type

    if propagation_type == "normalized_sum":
        return normalized_sum_propagation(
            ctrl, graph, core_embed, relations_embed, subgraphs, **kwargs
        )
    if propagation_type == "barycenter":
        return barycenter_propagation(
            ctrl, graph, core_embed, relations_embed, subgraphs, **kwargs
        )

    # Fail loudly instead of hitting an UnboundLocalError on the return value.
    raise ValueError(
        f"Unknown propagation_type {propagation_type!r}; "
        "expected 'normalized_sum' or 'barycenter'."
    )


def normalized_sum_propagation(
    ctrl, graph, core_embed, relations_embed, subgraphs, **kwargs
):
    """Propagate core embeddings to outer entities by summed message passing.

    For every subgraph (extended with the core subgraph), the subgraph
    embeddings are repeatedly updated with
    ``embed += ctrl.propagation_lr * composition(...)``, optionally
    L2-normalized row-wise, and the core rows are optionally reset to
    ``core_embed``. Outer rows are written back to the global CPU buffer after
    each subgraph; core rows are written once at the very end.

    Args:
        ctrl: Configuration object. Reads ``embed_dim``, ``embed_method``,
            ``n_passes_by_subgraph``, ``n_propagation_steps``, ``device``,
            ``propagation_lr``, ``embed_setting.composition``,
            ``normalize_embeddings`` and ``reset_embed``.
        graph: Object exposing ``num_entities``, ``core_subgraph_idx`` and
            ``get_subgraph_edge_index_and_type(idx)``.
        core_embed: Embeddings of the core entities, one row per entry of
            ``graph.core_subgraph_idx`` (same order).
            NOTE(review): presumably already on ``ctrl.device`` -- confirm.
        relations_embed: Relation embeddings indexed by edge type.
        subgraphs: Iterable of entity-index arrays.
        **kwargs: Forwarded to ``composition`` (e.g. ``core_tensor``,
            ``batch_size``).

    Returns:
        A CPU tensor of shape ``(graph.num_entities, ctrl.embed_dim)``;
        complex-valued when ``ctrl.embed_method == "rotate"``.
    """
    # Pick the dtype first so the (potentially large) buffer is allocated
    # only once; ``dtype=None`` means torch's default dtype.
    dtype = torch.cfloat if ctrl.embed_method == "rotate" else None
    embeddings = torch.zeros((graph.num_entities, ctrl.embed_dim), dtype=dtype)

    # Reorder core_embed so that its rows follow ascending entity index, which
    # is the order in which core entities appear in the sorted unions below.
    core_embed = core_embed[np.argsort(graph.core_subgraph_idx)]

    for _pass in range(ctrl.n_passes_by_subgraph):
        # Computed inside the loop so that zero passes never divides by zero.
        steps_per_pass = ctrl.n_propagation_steps // ctrl.n_passes_by_subgraph

        for subgraph_idx in tqdm(subgraphs, desc="Propagating through subgraphs"):
            # Add core subgraph (np.union1d returns sorted unique indices)
            subgraph_idx = np.union1d(subgraph_idx, graph.core_subgraph_idx)
            is_core = np.isin(subgraph_idx, graph.core_subgraph_idx)
            core_idx = np.where(is_core)[0]
            outer_idx = np.where(~is_core)[0]

            # Load subgraph embeddings on the device with a single gather.
            # Core rows of the global buffer are still zero at this point and
            # are overwritten with the actual core embeddings right away.
            subgraph_embed = embeddings[subgraph_idx].to(ctrl.device)
            subgraph_embed[core_idx] = core_embed

            # Create edge_index and edge_type variables
            subgraph_edge_index, subgraph_edge_type = (
                graph.get_subgraph_edge_index_and_type(subgraph_idx)
            )
            subgraph_edge_index = subgraph_edge_index.to(ctrl.device)
            subgraph_edge_type = subgraph_edge_type.to(ctrl.device)

            for _step in range(steps_per_pass):
                # Propagate
                subgraph_embed += ctrl.propagation_lr * composition(
                    subgraph_embed,
                    relations_embed,
                    subgraph_edge_index,
                    subgraph_edge_type,
                    kind=ctrl.embed_setting.composition,
                    **kwargs,
                )

                # Normalize
                if ctrl.normalize_embeddings:
                    subgraph_embed = torch.nn.functional.normalize(
                        subgraph_embed, p=2, dim=1
                    )

                # Reset core subgraph
                if ctrl.reset_embed:
                    subgraph_embed[core_idx] = core_embed

            # Load outer embeddings back to CPU
            embeddings[subgraph_idx[outer_idx]] = subgraph_embed[outer_idx].cpu()

    # Load core embeddings on CPU (rows already sorted by entity index)
    embeddings[np.sort(graph.core_subgraph_idx)] = core_embed.cpu()
    return embeddings


def composition(
    entity_embed,
    relations_embed,
    edge_index,
    edge_type,
    kind="multiplication",
    batch_size=100000,
    **kwargs,
):
    """Compose source-entity and relation embeddings and sum them per target.

    For every edge ``(source, relation, target)`` a message is computed from
    ``entity_embed[source]`` and ``relations_embed[relation]``; messages are
    summed into the row of their target entity.

    Args:
        entity_embed: Entity embeddings, indexed by the values of
            ``edge_index``.
        relations_embed: Relation embeddings, indexed by ``edge_type``.
        edge_index: Pair ``(source, target)`` of index tensors (unpacked along
            the first dimension).
        edge_type: Relation index of every edge.
        kind: ``"tucker"`` to use ``TuckerCompositionModule``; any other value
            is resolved through ``composition_resolver.make``.
        batch_size: Number of edges processed per batch.
        **kwargs: Must contain ``core_tensor`` when ``kind == "tucker"``;
            other entries are ignored.

    Returns:
        A new tensor shaped like ``entity_embed`` holding the summed messages
        (zeros for entities without incoming edges).

    Raises:
        KeyError: If ``kind == "tucker"`` and ``core_tensor`` is missing while
            there is at least one edge.
    """
    source, target = edge_index

    # Accumulator with the same shape/dtype/device as entity_embed
    new_embed = torch.zeros_like(entity_embed)

    num_edges = len(edge_type)
    if num_edges == 0:
        # Nothing to aggregate; also avoids resolving the module needlessly.
        return new_embed

    # Build the composition callable once instead of once per batch.
    if kind == "tucker":
        core_tensor = kwargs["core_tensor"]
        tucker_module = TuckerCompositionModule()

        def compose(heads, relations):
            return tucker_module(heads, relations, core_tensor)

    else:
        compose = composition_resolver.make(kind)

    # Loop over the edges in batches (slices clamp at the end automatically)
    for start in range(0, num_edges, batch_size):
        batch = slice(start, start + batch_size)

        message = compose(
            entity_embed[source[batch]], relations_embed[edge_type[batch]]
        )

        # Aggregate by sum, in place: no full-size temporary per batch
        new_embed.index_add_(0, target[batch], message)

    return new_embed


def tucker_product(
    a: torch.FloatTensor, b: torch.FloatTensor, core_tensor: torch.FloatTensor
) -> torch.FloatTensor:
    """Return the 'Tucker product' ``W x_1 a x_2 b`` of ``a`` and ``b``.

    With ``W = core_tensor`` of shape ``(i, j, k)``, ``a`` of shape
    ``(..., i)`` and ``b`` of shape ``(..., j)``, the result has shape
    ``(..., k)`` and equals ``sum_{i,j} W[i, j, k] * a[..., i] * b[..., j]``.
    """
    # x_2 contraction: fold b into the second axis of the core tensor
    contracted_with_b = torch.einsum("ijk,...j->...ik", core_tensor, b)
    # x_1 contraction: fold a into the remaining first axis
    return torch.einsum("...ik,...i->...k", contracted_with_b, a)


class TuckerCompositionModule(CompositionModule):
    """Composition module combining two embeddings through a Tucker core tensor."""

    # Kept as a class attribute and looked up through the class in ``forward``
    # so that it is called as a plain function (no ``self`` gets bound).
    func = tucker_product

    # docstr-coverage: inherited
    def forward(
        self, a: torch.FloatTensor, b: torch.FloatTensor, core_tensor: torch.FloatTensor
    ) -> torch.FloatTensor:  # noqa: D102
        compose = type(self).func
        return compose(a, b, core_tensor)


def barycenter_propagation(
    ctrl, graph, core_embed, relations_embed, subgraphs, **kwargs
):
    """Propagate core embeddings outward by averaging incoming messages.

    The global buffer starts as NaN everywhere except on the core rows. The
    subgraphs are visited cyclically; each one (extended with the core
    subgraph) is updated ``ctrl.n_propagation_steps`` times with
    ``barycenter_composition``, optionally L2-normalized row-wise, and its core
    rows are optionally reset. The loop ends when no NaN remains.

    If a complete sweep over all subgraphs leaves the set of NaN rows
    unchanged (for instance because an entity belongs to no subgraph), or if
    ``subgraphs`` is empty, propagation stops and a ``RuntimeWarning`` is
    issued instead of looping forever; the affected rows stay NaN.

    Args:
        ctrl: Configuration object. Reads ``embed_dim``, ``device``,
            ``n_propagation_steps``, ``embed_setting.composition``,
            ``normalize_embeddings`` and ``reset_embed``.
        graph: Object exposing ``num_entities``, ``core_subgraph_idx`` and
            ``get_subgraph_edge_index_and_type(idx)``.
        core_embed: Embeddings of the core entities, one row per entry of
            ``graph.core_subgraph_idx`` (same order).
        relations_embed: Relation embeddings indexed by edge type.
        subgraphs: Indexable, sized collection of entity-index arrays.
        **kwargs: Accepted for signature parity with the other propagation
            routine; not used here.

    Returns:
        A CPU tensor of shape ``(graph.num_entities, ctrl.embed_dim)`` in
        torch's default dtype.
    """
    import warnings

    def report_unreached(nan_rows):
        # Progress line, rewritten in place thanks to the carriage return.
        print(
            f"{int(nan_rows.sum()) / graph.num_entities:.1%} of entities unreached  ",
            end="\r",
        )

    # Initialize embeddings with NaNs outside the core subgraph
    embeddings = torch.full((graph.num_entities, ctrl.embed_dim), torch.nan)
    embeddings[graph.core_subgraph_idx] = core_embed.cpu()

    # One NaN scan per iteration feeds both the progress line and the loop test
    nan_rows = torch.isnan(embeddings).any(dim=1)
    report_unreached(nan_rows)

    num_subgraphs = len(subgraphs)
    nan_rows_at_sweep_start = nan_rows
    stalled = num_subgraphs == 0

    # Iterate over subgraphs until no NaNs remain (or no progress is possible)
    i = 0
    while nan_rows.any() and not stalled:
        # Get current subgraph and add the core subgraph
        subgraph_idx = np.union1d(subgraphs[i], graph.core_subgraph_idx)
        core_idx = np.where(np.isin(subgraph_idx, graph.core_subgraph_idx))[0]
        # Current core rows, re-read from the buffer on every visit
        core_rows = embeddings[subgraph_idx[core_idx]].to(ctrl.device)
        # Load subgraph embeddings on the device
        subgraph_embed = embeddings[subgraph_idx].to(ctrl.device)
        # Create edge_index and edge_type variables
        subgraph_edge_index, subgraph_edge_type = (
            graph.get_subgraph_edge_index_and_type(subgraph_idx)
        )
        subgraph_edge_index = subgraph_edge_index.to(ctrl.device)
        subgraph_edge_type = subgraph_edge_type.to(ctrl.device)

        for _step in range(ctrl.n_propagation_steps):
            # Propagate
            subgraph_embed = barycenter_composition(
                subgraph_embed,
                core_idx,
                relations_embed,
                subgraph_edge_index,
                subgraph_edge_type,
                kind=ctrl.embed_setting.composition,
            )

            # Normalize
            if ctrl.normalize_embeddings:
                subgraph_embed = torch.nn.functional.normalize(
                    subgraph_embed, p=2, dim=1
                )

            # Reset core subgraph
            if ctrl.reset_embed:
                subgraph_embed[core_idx] = core_rows

        embeddings[subgraph_idx] = subgraph_embed.cpu()

        i = (i + 1) % num_subgraphs
        nan_rows = torch.isnan(embeddings).any(dim=1)
        report_unreached(nan_rows)

        if i == 0:
            # A full sweep just finished: if the NaN pattern did not change,
            # further sweeps would not change it either.
            stalled = torch.equal(nan_rows, nan_rows_at_sweep_start)
            nan_rows_at_sweep_start = nan_rows

    print("")

    if stalled and nan_rows.any():
        warnings.warn(
            f"barycenter_propagation stopped with {int(nan_rows.sum())} "
            "entities still unreached: a full pass over all subgraphs made "
            "no progress. Their embeddings are left as NaN.",
            RuntimeWarning,
            stacklevel=2,
        )
    return embeddings


def barycenter_composition(
    entity_embed,
    core_idx,
    relations_embed,
    edge_index,
    edge_type,
    kind="multiplication",
    batch_size=10000,
):
    """Replace every non-core row by the mean of its incoming messages.

    A message is computed for each edge whose source embedding is not NaN
    (tested on the first component) and whose target is not a core row.
    Each non-core ("outer") row then becomes the mean of the messages it
    receives. Core rows are left untouched.

    NOTE(review): outer rows that receive no message are set to zero (not left
    as NaN), as in the previous implementation -- confirm this is intended.

    Args:
        entity_embed: 2-D tensor of entity embeddings; modified in place.
        core_idx: Row indices of ``entity_embed`` that belong to the core.
        relations_embed: Relation embeddings indexed by ``edge_type``.
        edge_index: Pair ``(source, target)`` of index tensors addressing rows
            of ``entity_embed``.
        edge_type: Relation index of every edge.
        kind: Composition name resolved through ``composition_resolver.make``.
        batch_size: Number of outer rows processed per batch.

    Returns:
        ``entity_embed`` itself, after the in-place update.
    """
    source, target = edge_index

    # Rows that may be updated: everything except the core
    outer_mask = torch.ones(entity_embed.size(0), dtype=torch.bool)
    outer_mask[core_idx] = False
    num_outer = int(outer_mask.sum())
    if num_outer == 0:
        # Only core rows: nothing to update
        return entity_embed
    outer_mask = outer_mask.to(entity_embed.device)

    # Rank of every outer row inside the outer-only tensor. ``target`` indexes
    # rows of ``entity_embed``, so it has to be translated before it can be
    # used against ``new_embed``, which only contains the outer rows.
    outer_rank = torch.cumsum(outer_mask, dim=0) - 1

    # Drop edges with a NaN source (not reached yet) or a core target
    keep = ~torch.isnan(entity_embed[source, 0]) & outer_mask[target]
    source, edge_type = source[keep], edge_type[keep]
    target = outer_rank[target[keep]]
    del keep

    # Zero-initialized tensor for the outer entities' new embeddings
    new_embed = torch.zeros(
        (num_outer, entity_embed.size(1)),
        dtype=entity_embed.dtype,
        device=entity_embed.device,
    )

    # Resolve the composition module once instead of once per batch
    compose = composition_resolver.make(kind)

    # Loop over the outer entities in batches
    for start in range(0, num_outer, batch_size):
        stop = min(start + batch_size, num_outer)
        edge_mask = (target >= start) & (target < stop)

        # Compose
        message = compose(
            entity_embed[source[edge_mask]], relations_embed[edge_type[edge_mask]]
        )

        # Aggregate by mean
        indices = target[edge_mask] - start  # Adjust indices within the batch
        unique_indices, counts = indices.unique(return_counts=True)
        # Match the message dtype so index_add_ also works for non-float32
        message_sum = torch.zeros(
            (stop - start, message.shape[1]),
            dtype=message.dtype,
            device=message.device,
        )
        message_sum.index_add_(0, indices, message)
        new_embed[start:stop][unique_indices] = message_sum[
            unique_indices
        ] / counts.unsqueeze(-1)

    entity_embed[outer_mask] = new_embed

    return entity_embed
