# PyTorch related imports

import torch
from torch.nn import Parameter

# Weight initialisation (Xavier/Glorot uniform)
from torch.nn.init import xavier_uniform_

from kge.dataset import TripleDataset


def get_param(shape):
    """Create a trainable parameter initialised with Xavier-uniform values.

    Args:
        shape: Iterable of dimension sizes; it is unpacked into the tensor
            constructor.

    Returns:
        A ``Parameter`` of the requested shape whose values were filled
        in place by ``xavier_uniform_``.

    """
    # The freshly allocated tensor holds arbitrary memory until it is
    # overwritten by the initialiser below.
    weights = torch.Tensor(*shape)
    xavier_uniform_(weights)
    return Parameter(weights)


def com_mult(a, b):
    """Multiply two complex numbers stored as (real, imag) pairs.

    Index 0 of the last dimension is read as the real part and index 1 as
    the imaginary part, for both operands.

    Args:
        a: First operand.
        b: Second operand.

    Returns:
        The product, with real and imaginary parts stacked along a new
        trailing dimension.

    """
    a_re = a[..., 0]
    a_im = a[..., 1]
    b_re = b[..., 0]
    b_im = b[..., 1]

    # (a_re + i*a_im) * (b_re + i*b_im)
    real_part = a_re * b_re - a_im * b_im
    imag_part = a_re * b_im + a_im * b_re
    return torch.stack((real_part, imag_part), dim=-1)


def conj(a):
    """Return the complex conjugate of a (real, imag)-pair tensor.

    Index 1 of the last dimension is treated as the imaginary part and is
    negated; every other entry is copied unchanged.

    The input is NOT modified: the result is a new tensor, so callers can
    keep using ``a`` afterwards.

    Args:
        a: Tensor whose last dimension holds the (real, imag) components.

    Returns:
        A new tensor equal to ``a`` with the imaginary component negated.

    """
    # Work on a copy so the caller's tensor is never mutated in place.
    result = a.clone()
    result[..., 1] = -a[..., 1]
    return result


def cconv(a, b):
    """Circular convolution of ``a`` and ``b`` along the last dimension.

    Computed in the frequency domain: the one-sided FFTs of both inputs are
    multiplied element-wise and transformed back.

    Args:
        a: Real-valued tensor; its last dimension is the signal axis.
        b: Real-valued tensor; presumably the same last-dimension length
            as ``a`` (the spectra are multiplied element-wise).

    Returns:
        Real tensor whose last dimension has length ``a.shape[-1]``.

    """
    # The one-sided spectrum cannot tell an even-length signal from an odd
    # one, so the original length must be passed to the inverse transform.
    signal_len = a.shape[-1]
    spectrum = torch.fft.rfft(a, dim=-1) * torch.fft.rfft(b, dim=-1)
    return torch.fft.irfft(spectrum, n=signal_len, dim=-1)


def ccorr(a, b):
    """Circular correlation of ``a`` and ``b`` along the last dimension.

    Computed in the frequency domain: the conjugated one-sided FFT of ``a``
    is multiplied element-wise by the one-sided FFT of ``b`` and transformed
    back, i.e. ``out[k] = sum_j a[j] * b[(j + k) mod n]``.

    Args:
        a: Real-valued tensor; its last dimension is the signal axis.
        b: Real-valued tensor; presumably the same last-dimension length
            as ``a`` (the spectra are multiplied element-wise).

    Returns:
        Real tensor whose last dimension has length ``a.shape[-1]``.

    """
    # The one-sided spectrum cannot tell an even-length signal from an odd
    # one, so the original length must be passed to the inverse transform.
    signal_len = a.shape[-1]
    spectrum = torch.conj(torch.fft.rfft(a, dim=-1)) * torch.fft.rfft(b, dim=-1)
    return torch.fft.irfft(spectrum, n=signal_len, dim=-1)


def construct_adj(
    train_dataset: TripleDataset,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Construct the adjacency matrix for GCN.

    Every ``(sub, rel, obj)`` entry of ``train_dataset.triples`` becomes one
    directed edge ``sub -> obj`` labelled with ``rel``. Edge order follows
    the iteration order of the triples.

    Args:
        train_dataset: The training dataset; only its ``triples`` attribute
            is read, and each entry must unpack into three values.
        device: The device to use

    Returns:
        edge_index: (2, num_edges) tensor with head and tail indices
        edge_type: (num_edges,) tensor with relation indices

        Both tensors are int64. With no triples the shapes are ``(2, 0)``
        and ``(0,)``.

    """
    endpoints = []
    relations = []
    for sub, rel, obj in train_dataset.triples:
        endpoints.append((sub, obj))
        relations.append(rel)

    # An empty list converts to a 1-D tensor, on which ``.t()`` is a no-op;
    # forcing the (num_edges, 2) layout first keeps the documented
    # (2, num_edges) result shape even when there are no edges.
    pair_tensor = torch.tensor(endpoints, dtype=torch.long, device=device)
    edge_index = pair_tensor.reshape(-1, 2).t()
    edge_type = torch.tensor(relations, dtype=torch.long, device=device)

    return edge_index, edge_type
