""" Module for loading, training and performing inference with the baseline Grapher method.
    Most of the code for this model comes from https://github.com/IBM/Grapher
    and the paper https://arxiv.org/abs/2211.10511
"""

import itertools
from collections import defaultdict
from functools import partial
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt
import torch
import torch.nn.functional as F
from torch import nn
from transformers import BatchEncoding, PreTrainedTokenizerFast

from text2graph.data.base_dataset import TextGraph
from text2graph.models.base_model import BaseModel, GeneratedText
from text2graph.models.serialize import disambiguate_nodes, ambiguate_nodes
from text2graph.training.losses_and_metrics import UNLABELLED_CATEGORICAL

logger = logging.getLogger(__name__)

# Extra tokens registered with the tokenizer in Grapher.__init__
NODE_TOKEN: str = "<n>"  # wraps each node string in the serialized node sequence
EDGE_TOKEN: str = "<e>"  # wraps an edge label inside an edge-matrix entry
NO_EDGE_TOKEN: str = "<ne>"  # edge-matrix entry for a node pair without an edge


class Grapher(BaseModel):
    """ Text-to-graph model following Grapher (https://arxiv.org/abs/2211.10511)

        The wrapped language model generates a graph's nodes as one text sequence in which node
        strings are delimited by ``NODE_TOKEN``. The last-layer hidden states over that sequence are
        averaged into one feature vector per node slot (``split_nodes``). ``EdgesGen`` then decodes
        an edge-label token sequence for every ordered pair of node slots.

        ``metadata`` keys read here: ``max_nodes``, ``default_seq_len_edge``, ``embedding_size`` and
        ``vocab_size``.
    """

    def __init__(self, metadata: Dict[str, Any]):
        super().__init__(metadata)
        # number of node slots pooled by split_nodes (and therefore decoded by the edge generator)
        self.max_nodes = self.metadata['max_nodes']
        # number of decoding steps (including the start step) for every edge sequence
        self.default_seq_len_edge = self.metadata['default_seq_len_edge']
        self._add_new_tokens([NODE_TOKEN, EDGE_TOKEN, NO_EDGE_TOKEN])
        self.node_token_id = self.tokenizer.convert_tokens_to_ids(NODE_TOKEN)
        self.tokenizer.padding_side = "right"
        # NOTE(review): vocab_size is read after _add_new_tokens; this assumes BaseModel keeps
        # metadata['vocab_size'] in sync with the enlarged tokenizer — confirm.
        self.edge_generator = EdgesGen(
            hidden_dim=self.metadata['embedding_size'],
            vocab_size=self.metadata['vocab_size'],
            # the pad token id doubles as the start token of every edge sequence
            bos_token_id=self.tokenizer.pad_token_id
        )

    def get_collate_fn(self):
        """ Returns a batching function for loading data to train and evaluate the model

            The function is ``text_graph2inputs_batch`` bound to this model's tokenizer and
            ``has_encoder`` flag; it takes a list of ``TextGraph`` objects.
        """
        return partial(
            self.text_graph2inputs_batch,
            tokenizer=self.tokenizer,
            has_encoder=self.has_encoder
        )

    def split_nodes(self, output_ids, features):
        """ Pools token features into one averaged feature vector per node slot

            A position belongs to node slot ``n`` when exactly ``n`` node tokens occur at or before
            it and it is not itself a node token. Tokens before the first node token therefore fall
            into slot 0. Slots without any tokens get an all-zero vector.

            NOTE(review): ``_add_separator_tokens_nodes`` wraps every node as ``<n>node<n>``. If
            ``output_ids`` follows that format, adjacent nodes are separated by two node tokens and
            every other slot stays empty. Confirm this matches the 0-based node indices that
            ``_encode_edges`` uses for the edge targets.

            Args:
                output_ids: batch_size x seq_len token ids
                features: batch_size x seq_len x hidden_dim token features

            Returns:
                max_nodes x batch_size x hidden_dim tensor of per-slot averaged features
        """
        batch_size, _ = output_ids.size()
        split_features = torch.zeros(
            (self.max_nodes, batch_size, self.metadata['embedding_size']),
            device=features.device,
            dtype=features.dtype
        )  # num_nodes X batch_size X hidden_dim
        # loop-invariant: which positions are node tokens, and which slot each position falls in
        is_node_token = output_ids == self.node_token_id
        node_slot = torch.cumsum(is_node_token, 1)
        for n in range(self.max_nodes):
            mask_node_n = ((node_slot == n) & ~is_node_token).unsqueeze(2)
            sum_features_node_n = (features * mask_node_n).sum(1)
            # clamp avoids 0/0 for empty slots, whose summed features are already zero
            num_tokens_node_n = mask_node_n.sum(1).clamp(min=1)
            split_features[n] = sum_features_node_n / num_tokens_node_n
        return split_features

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """ Runs the language model and the edge generator on a batch

            Args:
                inputs: batch as produced by ``text_graph2inputs_batch``. Encoder-decoder models read
                    ``text_sequence``/``text_attn_mask`` and ``node_sequence``/``node_attn_mask``;
                    decoder-only models read ``input_sequence``/``input_attn_mask``.

            Returns:
                ``node_sequence_logits``: batch_size x seq_len x vocab_size.
                ``edge_matrices_logits``: num_nodes x num_nodes x batch_size x seq_len_edge x
                vocab_size, as returned by ``EdgesGen``.
        """
        if self.has_encoder:
            output = self.language_model(
                input_ids=inputs['text_sequence'],
                attention_mask=inputs['text_attn_mask'],
                decoder_input_ids=inputs['node_sequence'],
                decoder_attention_mask=inputs['node_attn_mask'],
                output_hidden_states=True
            )
        else:
            output = self.language_model(
                input_ids=inputs['input_sequence'],
                attention_mask=inputs['input_attn_mask'],
                output_hidden_states=True
            )
        node_logits = output.logits  # batch_size x seq_len x vocab_size
        # use decoder_hidden_states when the output has them (presumably encoder-decoder models)
        last_hidden = (
            output.decoder_hidden_states[-1] if hasattr(output, 'decoder_hidden_states')
            else output.hidden_states[-1]
        )
        # node slots are located from the model's argmax predictions, not from the input ids
        features = self.split_nodes(
            node_logits.argmax(-1),
            last_hidden
        )  # num_nodes x batch_size x hidden_dim
        edge_matrix_logits = self.edge_generator(features, self.default_seq_len_edge)
        return {
            'node_sequence_logits': node_logits,
            'edge_matrices_logits': edge_matrix_logits
        }

    def generate_graph(
        self,
        text_sequence: torch.Tensor,
        text_attn_mask: torch.Tensor,
        do_sample: bool,
        max_new_tokens: int,
        num_beams: int = 1,
        **kwargs
    ) -> Tuple[List[TextGraph], GeneratedText]:
        """ Generates a graph for each data point in a batch given the point's text input

            Nodes are generated first with the language model. The generated node sequence is then
            fed back through ``forward`` to predict every graph's edge matrix.

            Args:
                text_sequence: padded token ids of the input texts
                text_attn_mask: attention mask for ``text_sequence``
                do_sample: forwarded to the node-sequence generation
                max_new_tokens: maximum number of generated node-sequence tokens. ``-1`` means use
                    the length of ``kwargs['node_sequence']``, which must then be provided.
                num_beams: forwarded to the node-sequence generation
                **kwargs: only ``node_sequence`` is read, and only when ``max_new_tokens == -1``

            Returns:
                the decoded graphs (one per batch element) and the generated node sequence
        """
        if max_new_tokens == -1:
            max_new_tokens = kwargs['node_sequence'].shape[1]
        graphs = []
        with torch.no_grad():
            nodes_sequence = self._generate_language(
                start_token_id=self.tokenizer.convert_tokens_to_ids(NODE_TOKEN),
                text_sequence=text_sequence,
                text_attn_mask=text_attn_mask,
                do_sample=do_sample,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams
            )
            nodes_tokens = self.tokenizer.batch_decode(
                nodes_sequence.ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
                spaces_between_special_tokens=False
            )
            nodes_list = _desequence_nodes(nodes_tokens)
            input_sequence = torch.cat([text_sequence, nodes_sequence.ids], dim=1)
            pad_id = self.tokenizer.pad_token_id
            edge_logits = self.forward({
                'input_sequence': input_sequence,
                'input_attn_mask': torch.where(input_sequence == pad_id, 0, 1),
                'text_sequence': text_sequence,
                'text_attn_mask': text_attn_mask,
                'node_sequence': nodes_sequence.ids,
                'node_attn_mask': torch.where(nodes_sequence.ids == pad_id, 0, 1),
            })['edge_matrices_logits']
            # EdgesGen returns num_nodes x num_nodes x batch_size x seq_len x vocab_size. Move the
            # batch axis first so each zipped item below is one graph's
            # num_nodes x num_nodes x seq_len matrix of token ids (the layout inputs2graph reads).
            edge_matrices = edge_logits.argmax(-1).permute(2, 0, 1, 3)
        for nodes, edge_matrix in zip(nodes_list, edge_matrices):
            graph = TextGraph(nodes=nodes, edges=[], edge_index=[])
            if len(nodes) > 0:
                # one decoded string per (row, column) entry of the edge matrix
                edge_matrix_decoded = [
                    self.tokenizer.batch_decode(
                        edges,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False,
                        spaces_between_special_tokens=False
                    )
                    for edges in edge_matrix
                ]
                graph.edges, graph.edge_index = _decode_edges(edge_matrix_decoded)
            graphs.append(graph)
        return graphs, nodes_sequence

    @staticmethod
    def text_graph2inputs(
        text_graph_pair: TextGraph,
        randomize: bool = False
    ) -> Dict[str, Union[str, List[List[str]]]]:
        """ Processes a data point's text and graph into the un-tokenized inputs of the model

            Returns a dictionary with:
                ``node_sequence``: the (disambiguated) nodes joined with ``NODE_TOKEN`` separators
                ``edge_matrix``: square nested list of edge strings from ``_encode_edges``
                ``text_sequence``: the raw text

            ``randomize`` is not used by this implementation.
        """
        sequenced_graph = {}
        nodes_disambiguated = disambiguate_nodes(text_graph_pair.nodes)
        sequenced_graph['node_sequence'] = _add_separator_tokens_nodes(nodes_disambiguated)
        sequenced_graph['edge_matrix'] = _encode_edges(
            text_graph_pair.edge_index,
            text_graph_pair.edges
        )
        sequenced_graph['text_sequence'] = text_graph_pair.text
        return sequenced_graph

    @staticmethod
    def text_graph2inputs_batch(
        text_graph_pairs: List[TextGraph],
        tokenizer: PreTrainedTokenizerFast,
        randomize: bool = False,
        has_encoder: bool = False
    ) -> Dict[str, torch.Tensor]:
        """ Processes a batch of text-graph pairs into a dictionary of tensors with the required
            inputs for a model and returns the dictionary

            Keys produced:
                ``text_sequence`` / ``text_attn_mask``: padded text token ids and mask
                ``node_sequence`` / ``node_attn_mask``: padded node-sequence token ids and mask
                ``input_sequence`` / ``input_attn_mask``: text followed by the node sequence, and
                    its mask
                ``target_sequence``: next-token targets for the node sequence. Every pad-token id
                    (padding and, for decoder-only models, the text positions) is replaced by
                    ``UNLABELLED_CATEGORICAL``.
                ``edge_matrices``: batch_size x N x N x L edge token ids (long), padded with the
                    pad token id. N is the largest edge matrix in the batch and L the longest
                    tokenized edge.
                ``target_edge_matrices``: ``edge_matrices`` with pad ids replaced by
                    ``UNLABELLED_CATEGORICAL``

            ``randomize`` is not used by this implementation.
        """
        sequenced_graphs = [
            Grapher.text_graph2inputs(text_graph_pair)
            for text_graph_pair in text_graph_pairs
        ]
        batch = defaultdict(list)
        edge_matrices, max_nodes, max_sequence = [], 0, 0
        # edge strings are right-padded; the caller's padding side is restored afterwards
        padding_side = tokenizer.padding_side
        tokenizer.padding_side = 'right'
        try:
            for graph in sequenced_graphs:
                batch['text_sequence'].append(graph['text_sequence'])
                batch['node_sequence'].append(graph['node_sequence'])
                # one tensor per matrix row: num_columns x longest_edge_in_row token ids
                edge_matrices.append([
                    tokenizer(edges, return_tensors='pt', padding=True)['input_ids']
                    for edges in graph['edge_matrix']
                ])
                max_nodes = max(max_nodes, edge_matrices[-1][0].size(0))
                sequence_length = max([edges.size(1) for edges in edge_matrices[-1]])
                max_sequence = max(max_sequence, sequence_length)
        finally:
            # restore even if tokenization raises, so the shared tokenizer is not left modified
            tokenizer.padding_side = padding_side
        # pad every graph's edge matrix to the batch-wide size, built directly as integer ids
        resized_edge_matrices = []
        for edge_matrix in edge_matrices:
            resized_edge_matrix = torch.full(
                (max_nodes, max_nodes, max_sequence),
                tokenizer.pad_token_id,
                dtype=torch.long
            )
            for row_idx, edges in enumerate(edge_matrix):
                resized_edge_matrix[row_idx, :edges.size(0), :edges.size(1)] = edges
            resized_edge_matrices.append(resized_edge_matrix)
        batch['edge_matrices'] = torch.stack(resized_edge_matrices, dim=0)
        batch['target_edge_matrices'] = torch.where(
            batch['edge_matrices'] == tokenizer.pad_token_id,
            UNLABELLED_CATEGORICAL,
            batch['edge_matrices']
        )
        inputs, labels, nodes, texts = [], [], [], []
        for text, node_sequence in zip(
            tokenizer(batch['text_sequence'], padding=False)['input_ids'],
            tokenizer(batch['node_sequence'], padding=False)['input_ids']
        ):
            # targets are the node sequence shifted left by one and terminated with eos
            if has_encoder:
                labels.append(node_sequence[1:] + [tokenizer.eos_token_id])
            else:
                # decoder-only: text positions get pad ids so they are unlabelled below
                labels.append(
                    len(text) * [tokenizer.pad_token_id]
                    + node_sequence[1:] + [tokenizer.eos_token_id]
                )
            inputs.append(text + node_sequence)
            nodes.append(node_sequence)
            texts.append(text)
        inputs = tokenizer.pad(
            BatchEncoding({'input_ids': inputs}),
            return_attention_mask=True,
            return_tensors='pt'
        )
        batch['input_sequence'] = inputs['input_ids']
        batch['input_attn_mask'] = inputs['attention_mask']

        batch['target_sequence'] = tokenizer.pad(
            BatchEncoding({'input_ids': labels}),
            return_tensors='pt'
        )['input_ids']
        batch['target_sequence'] = torch.where(
            batch['target_sequence'] == tokenizer.pad_token_id,
            UNLABELLED_CATEGORICAL,
            batch['target_sequence']
        )
        node = tokenizer.pad(
            BatchEncoding({'input_ids': nodes}),
            return_attention_mask=True,
            return_tensors='pt'
        )
        batch['node_sequence'] = node['input_ids']
        batch['node_attn_mask'] = node['attention_mask']
        texts = tokenizer.pad(
            BatchEncoding({'input_ids': texts}),
            return_tensors='pt',
            return_attention_mask=True
        )
        batch['text_sequence'] = texts['input_ids']
        batch['text_attn_mask'] = texts['attention_mask']
        return dict(batch)

    @staticmethod
    def inputs2graph(
        inputs: Dict[str, torch.Tensor],
        tokenizer: PreTrainedTokenizerFast,
        file_paths: Optional[List[str]] = None
    ) -> List[TextGraph]:
        """ Processes a batch of model outputs, possibly generated, into a batch of graphs

            Reads ``inputs['node_sequence']`` (batch_size x seq_len token ids) and
            ``inputs['edge_matrices']`` (batch_size x N x N x L token ids). ``file_paths`` is not
            used by this implementation.
        """
        graphs = []
        nodes_sequence = tokenizer.batch_decode(
            inputs['node_sequence'],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
            spaces_between_special_tokens=False
        )
        nodes_list = _desequence_nodes(nodes_sequence)
        for batch_idx, nodes in enumerate(nodes_list):
            graph = TextGraph(nodes=nodes, edges=[], edge_index=[])
            # one decoded string per (row, column) entry of this graph's edge matrix
            edge_matrix = [
                tokenizer.batch_decode(
                    edges,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                    spaces_between_special_tokens=False
                )
                for edges in inputs['edge_matrices'][batch_idx]
            ]
            graph.edges, graph.edge_index = _decode_edges(edge_matrix)
            graphs.append(graph)
        return graphs


class EdgesGen(nn.Module):
    """ Decodes an edge-label token sequence for every ordered pair of node features

        The sequence for pair ``(i, j)`` is decoded greedily by a ``GRUDecoder``. Its initial hidden
        state is ``features[i] - features[j]``.
    """

    def __init__(
        self,
        hidden_dim: int,
        vocab_size: int,
        bos_token_id: int
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.bos_token_id = bos_token_id
        self.edgeDecoder = GRUDecoder(hidden_dim, vocab_size)

    def forward(self, features: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
        """ Greedily decodes ``max_sequence_length`` steps for all node pairs at once

            Args:
                features: num_nodes x batch_size x hidden_dim node features
                max_sequence_length: length of every decoded sequence, including the start step

            Returns:
                logits of shape num_nodes x num_nodes x batch_size x max_sequence_length x
                vocab_size. Step 0 is not predicted: it is 1.0 at ``bos_token_id`` and 0 elsewhere.
        """
        num_nodes, batch_size, hidden_dim = features.size()
        num_pairs = num_nodes * num_nodes * batch_size
        all_logits = torch.zeros(
            max_sequence_length, num_pairs, self.vocab_size, device=features.device
        )
        # pair (i, j) starts from features[i] - features[j]; flattened in (i, j, batch) order
        pair_diffs = features.unsqueeze(1) - features.unsqueeze(0)
        hidden = pair_diffs.reshape(-1, hidden_dim)
        # the first position of every sequence is the start token
        all_logits[0, :, self.bos_token_id] = 1.0
        tokens = torch.full(
            (num_pairs,), self.bos_token_id, dtype=torch.long, device=features.device
        )
        for step in range(1, max_sequence_length):
            step_logits, hidden = self.edgeDecoder(tokens, hidden)
            all_logits[step] = step_logits
            # greedy decoding: feed the most likely token back in
            tokens = step_logits.max(dim=1).indices
        per_pair = all_logits.reshape(
            max_sequence_length, num_nodes, num_nodes, batch_size, -1
        )
        return per_pair.permute(1, 2, 3, 0, 4)


class GRUDecoder(nn.Module):
    """ Single-layer GRU mapping input token ids and a hidden state to output-token logits """

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, num_layers=1)
        self.out = nn.Linear(hidden_size, vocab_size)
        self.embedding = nn.Embedding(vocab_size, hidden_size)

    def forward(self, x: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Runs the GRU over embedded (ReLU-activated) tokens

            Args:
                x: token ids, either ``bsize`` (a single step) or ``bsize x sent_len``
                hidden: ``bsize x hidden_dim`` or ``1 x bsize x hidden_dim``

            Returns:
                The logits and the new hidden state (``1 x bsize x hidden_dim``). Logits are
                ``bsize x vocab_size`` for a single step or a length-1 sequence, and
                ``bsize x sent_len x vocab_size`` otherwise. The batch axis is kept even when
                ``bsize == 1``.
        """
        if hidden.dim() == 2:
            hidden = hidden.unsqueeze(0)  # add the num_layers axis expected by nn.GRU
        emb_input = self.embedding(x)
        if x.dim() == 1:
            emb_input = emb_input.unsqueeze(1)  # bsize X 1 X emb_dim
        # else bsize X sent_len X emb_dim
        output = F.relu(emb_input)
        output, hidden = self.gru(output, hidden)
        # squeeze only the time axis: a bare squeeze() would also drop the batch axis when
        # bsize == 1, and callers then fail on output.max(1)
        output = self.out(output.squeeze(1))
        return output, hidden


def _encode_edges(
    edge_index: List[List[int]],
    edges: List[str],
    num_nodes: Optional[int] = None
) -> List[List[str]]:
    """ Returns a square matrix of edge strings describing the edges of a graph

        Entry ``[i][j]`` is ``EDGE_TOKEN + label + EDGE_TOKEN`` for an edge from node ``i`` to
        node ``j``, and ``NO_EDGE_TOKEN`` otherwise.

        Args:
            edge_index: ``[source, target]`` node index pairs, one per edge
            edges: edge labels, aligned with ``edge_index``
            num_nodes: optional number of nodes in the graph. When given, the matrix is at least
                ``num_nodes x num_nodes``, so nodes after the largest indexed one still get a row
                and column.

        Returns:
            A ``size x size`` nested list. ``size`` is the largest of: the largest node index in
            ``edge_index`` plus one, ``num_nodes``, and 1. The minimum of 1 means a graph without
            edges yields a single ``NO_EDGE_TOKEN`` entry instead of failing on ``np.max`` of an
            empty sequence.
    """
    largest_index = int(np.max(edge_index)) + 1 if len(edge_index) > 0 else 0
    size = max(largest_index, num_nodes or 0, 1)
    edge_matrix = [[NO_EDGE_TOKEN for _ in range(size)] for _ in range(size)]
    for edge_idx, node_idxs in enumerate(edge_index):
        edge_matrix[node_idxs[0]][node_idxs[1]] = EDGE_TOKEN + edges[edge_idx] + EDGE_TOKEN
    return edge_matrix


def _decode_edges(edge_matrix: List[List[str]]) -> Tuple[List[str], List[List[int]]]:
    """ Recovers edge labels and ``[row, col]`` indices from a matrix of decoded edge strings

        Only entries that start with ``EDGE_TOKEN`` count as edges. The label is the first non-empty
        fragment between edge tokens, stripped of surrounding whitespace ('' if there is none).
    """
    labels: List[str] = []
    indices: List[List[int]] = []
    for source, row in enumerate(edge_matrix):
        for target, entry in enumerate(row):
            if not entry.startswith(EDGE_TOKEN):
                continue
            indices.append([source, target])
            first_piece = next((piece for piece in entry.split(EDGE_TOKEN) if piece), None)
            labels.append('' if first_piece is None else first_piece.strip())
    return labels, indices


def _add_separator_tokens_nodes(nodes: List[str]) -> str:
    """ Serializes nodes into one string, wrapping every node in a pair of ``NODE_TOKEN``s

        For example, ``["a", "b"]`` becomes ``"<n>a<n><n>b<n>"``. The tokens let the nodes be
        separated again after a language model generates the sequence.
    """
    return "".join(NODE_TOKEN + node + NODE_TOKEN for node in nodes)


def _desequence_nodes(nodes_list: List[str]) -> List[List[str]]:
    """ Splits each decoded node sequence into its list of node strings

        For each sequence:
            - Text before the first ``NODE_TOKEN`` is discarded.
            - The remainder is split on ``NODE_TOKEN`` and every fragment is stripped.
            - Fragments that are empty after stripping are dropped.
            - The result is passed through ``ambiguate_nodes``, presumably the inverse of the
              ``disambiguate_nodes`` step used when building training inputs (confirm in
              ``text2graph.models.serialize``).

        A sequence without any ``NODE_TOKEN`` yields an empty list.
    """
    desequenced_nodes_list = []
    for nodes in nodes_list:
        start = nodes.find(NODE_TOKEN)
        if start == -1:
            desequenced_nodes_list.append([])
            continue
        # strip before filtering so whitespace-only fragments (e.g. "<n> <n>") do not
        # become empty-string nodes
        fragments = (fragment.strip() for fragment in nodes[start:].split(NODE_TOKEN))
        desequenced_nodes_list.append(
            ambiguate_nodes([fragment for fragment in fragments if fragment])
        )
    return desequenced_nodes_list
