"""Simple CompGCN embedding model. Code based on https://github.com/malllabiisc/CompGCN."""

from dataclasses import dataclass
from typing import Literal

import torch
from torch.nn import Parameter
from torch.nn.init import xavier_uniform_

from kge.dataset import TripleDataset
from kge.models.gnn_embedding.base import GNNEmbedding

from .compgcn_src import CompGCNConv, CompGCNConvBasis


@dataclass
class CompGCNParams:
    """Hyper-parameters of the CompGCN encoder.

    Attributes:
        num_ent: Number of entities (determined from the dataset).
        num_rel: Number of relations (determined from the dataset).
        embed_dim: Embedding dimension given as input to the score function.
        gcn_dim: Number of hidden units in the GCN.
        init_dim: Initial dimension size for entities and relations.
        gcn_layer: Number of GCN layers to use.
        num_bases: Number of basis relation vectors to use. The model uses
            ``CompGCNConvBasis`` for its first layer only when this value is
            positive; otherwise it uses ``CompGCNConv`` only.
        dropout: Dropout to use in the GCN layer.
        hid_drop: Dropout applied after the GCN.
        use_bias: Whether to use a bias in the model.
        opn: Composition operation to be used in CompGCN.
    """

    # --- Sizes supplied by the caller -------------------------------------
    num_ent: int
    num_rel: int
    embed_dim: int

    # --- Model architecture ------------------------------------------------
    gcn_dim: int = 200
    init_dim: int = 100
    gcn_layer: int = 1
    num_bases: int = -1

    # --- Dropout and bias --------------------------------------------------
    dropout: float = 0.1
    hid_drop: float = 0.1
    use_bias: bool = False

    # --- Composition operation ---------------------------------------------
    opn: Literal["corr", "sub", "mult"] = "mult"
    # TODO: debug corr operation


class CompGCN(GNNEmbedding):
    """Simple CompGCN embedding model.

    Learnable initial entity and relation embeddings are passed through one
    or two CompGCN convolution layers over the graph built from the training
    triples. The resulting entity/relation matrices are then indexed by
    ``embed_sr``.
    """

    def __init__(
        self,
        num_entities: int,
        num_relations: int,
        dimension: int,
        train_dataset: TripleDataset,
        device: torch.device,
    ):
        """Build the graph tensors, initial embeddings and convolution layers.

        Args:
            num_entities: Number of entities; sets the row count of the
                initial entity embedding and the length of ``bias``.
            num_relations: Number of relations. Without basis vectors the
                initial relation embedding has ``num_relations * 2`` rows.
            dimension: Embedding dimension produced for the score function.
            train_dataset: Training triples from which the edge tensors are
                constructed.
            device: Device on which the edge tensors are placed.

        """
        super().__init__(num_entities, num_relations, dimension)
        self.train_dataset = train_dataset
        # NOTE(review): the edge tensors are plain attributes (not registered
        # buffers), so presumably they stay on `device` even if the module is
        # moved later — confirm callers pass the final device here.
        self.edge_index, self.edge_type = construct_adj(train_dataset, device)
        self.p = CompGCNParams(
            num_ent=num_entities,
            num_rel=num_relations,
            embed_dim=dimension,
        )
        self.drop = torch.nn.Dropout(self.p.hid_drop)
        self.act = torch.tanh
        # With a single layer the first convolution must already emit the
        # final embedding size, so the hidden size is overridden.
        if self.p.gcn_layer == 1:
            self.p.gcn_dim = self.p.embed_dim
        self.init_embed = get_param((self.p.num_ent, self.p.init_dim))
        self.device = self.edge_index.device

        # Either one row per basis vector, or one row per relation times two.
        use_basis = self.p.num_bases > 0
        rel_rows = self.p.num_bases if use_basis else num_relations * 2
        self.init_rel = get_param((rel_rows, self.p.init_dim))

        # Keyword arguments shared by every plain CompGCNConv layer.
        conv_kwargs = {
            "act": self.act,
            "dropout": self.p.dropout,
            "use_bias": self.p.use_bias,
            "opn": self.p.opn,
        }
        if use_basis:
            self.conv1 = CompGCNConvBasis(
                self.p.init_dim,
                self.p.gcn_dim,
                num_relations,
                self.p.num_bases,
                act=self.act,
                params=self.p,
            )
        else:
            self.conv1 = CompGCNConv(
                self.p.init_dim,
                self.p.gcn_dim,
                num_relations,
                **conv_kwargs,
            )

        # The optional second layer is always a plain CompGCNConv, whichever
        # kind of first layer was chosen above.
        self.conv2 = (
            CompGCNConv(
                self.p.gcn_dim,
                self.p.embed_dim,
                num_relations,
                **conv_kwargs,
            )
            if self.p.gcn_layer == 2
            else None
        )

        self.register_parameter("bias", Parameter(torch.zeros(self.p.num_ent)))

    def embed_sr(
        self,
        s: torch.Tensor,
        r: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the GCN over the whole graph and look up subject/relation rows.

        Args:
            s: Index tensor selecting rows (dim 0) of the entity embeddings.
            r: Index tensor selecting rows (dim 0) of the relation embeddings.

        Returns:
            A tuple ``(s_emb, rel_emb, all_ent_emb)``: the selected entity
            rows, the selected relation rows, and the full entity embedding
            matrix produced by the GCN.

        """
        ent_emb, rel_mat = self.conv1(
            self.init_embed,
            self.edge_index,
            self.edge_type,
            rel_embed=self.init_rel,
        )
        ent_emb = self.drop(ent_emb)
        if self.p.gcn_layer == 2:
            ent_emb, rel_mat = self.conv2(
                ent_emb,
                self.edge_index,
                self.edge_type,
                rel_embed=rel_mat,
            )
            ent_emb = self.drop(ent_emb)

        s_emb = torch.index_select(ent_emb, 0, s)
        rel_emb = torch.index_select(rel_mat, 0, r)

        return s_emb, rel_emb, ent_emb

    def regularization_term(self) -> torch.Tensor:
        """Return a zero regularization penalty.

        The value is a 0-dim tensor created from ``self.bias`` (same dtype
        and device) so that the result matches the annotated return type
        rather than being a bare Python float.
        """
        return self.bias.new_zeros(())


def construct_adj(
    train_dataset: TripleDataset,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Construct the edge tensors for the GCN from the training triples.

    Each ``(sub, rel, obj)`` entry of ``train_dataset.triples`` yields one
    edge ``sub -> obj`` of type ``rel``; edges keep the iteration order of
    the triples.

    NOTE(review): only the triples as stored are used — no inverse edges are
    added here, while the model allocates ``num_relations * 2`` relation
    rows. Confirm that the dataset already contains inverse triples if the
    convolution layers expect them.

    Args:
        train_dataset: The training dataset; its ``triples`` attribute is
            iterated as ``(sub, rel, obj)``.
        device: The device to use

    Returns:
        edge_index: (2, num_edges) long tensor with head and tail indices
        edge_type: (num_edges,) long tensor with relation indices

    """
    endpoints, relations = [], []
    for sub, rel, obj in train_dataset.triples:
        endpoints.append((sub, obj))
        relations.append(rel)

    # reshape(-1, 2) keeps the pair matrix two-dimensional even when there
    # are no triples, so the transposed result is always (2, num_edges).
    edge_index = (
        torch.tensor(endpoints, dtype=torch.long, device=device).reshape(-1, 2).t()
    )
    edge_type = torch.tensor(relations, dtype=torch.long, device=device)

    return edge_index, edge_type


def get_param(shape):
    """Create a learnable parameter of the given shape with Xavier-uniform values.

    Args:
        shape: Iterable of dimension sizes for the new parameter.

    Returns:
        A ``Parameter`` whose data was filled by ``xavier_uniform_``.

    """
    # Allocate uninitialised storage, fill it in place, then wrap it.
    weights = torch.empty(*shape)
    xavier_uniform_(weights)
    return Parameter(weights)
