""" Module for loading and performing inference using a serialized graph generator"""

from collections import defaultdict
from functools import partial
import random
import re
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt
import torch
from transformers import PreTrainedTokenizerFast, BatchEncoding

from text2graph.data.base_dataset import TextGraph
from text2graph.models.base_model import BaseModel, GeneratedText
from text2graph.models.edge_feature_processing import reindex_edge_index
from text2graph.training.losses_and_metrics import UNLABELLED_CATEGORICAL

# Delimiter tokens of the serialized graph format. _sequence_graph writes every edge as
# PRED_NODE_TOKEN <node> EDGE_TOKEN <edge feature> SUCC_NODE_TOKEN <node>
SUCC_NODE_TOKEN = "<sn>"
EDGE_TOKEN = "<e>"
PRED_NODE_TOKEN = "<pe>"
# Separates a node's feature string from its occurrence counter (see disambiguate_nodes)
BRANCH_TOKEN = "<bn>"
# All serialization tokens; passed to `_add_new_tokens` when the model is constructed
S_TOKENS = [EDGE_TOKEN, PRED_NODE_TOKEN, BRANCH_TOKEN, SUCC_NODE_TOKEN]


class SerializedGraphGenerator(BaseModel):
    """ A language model that generates graphs as serialized token sequences and converts
        between text-graph pairs and the tensors consumed / produced by the model
    """
    def __init__(self, metadata: Dict[str, Any]):
        """ Builds the model from its metadata, passes the serialization tokens to
            `_add_new_tokens` and, unless metadata['message_passing_type'] is 'none', initializes
            graph information passing on the underlying transformer / decoder
        """
        assert 'serialization_type' in metadata
        super().__init__(metadata)
        self._add_new_tokens(S_TOKENS)
        graph_generator = (
            self.language_model.transformer if hasattr(self.language_model, 'transformer')
            else self.language_model.decoder
        )
        if self.metadata['message_passing_type'] != 'none':
            graph_token_ids = {
                'edge': self.tokenizer.convert_tokens_to_ids(EDGE_TOKEN),
                'eos': self.tokenizer.eos_token_id,
                'pred_node': self.tokenizer.convert_tokens_to_ids(PRED_NODE_TOKEN),
                'succ_node': self.tokenizer.convert_tokens_to_ids(SUCC_NODE_TOKEN),
                'pad': self.tokenizer.pad_token_id
            }
            graph_generator.init_graph_information_passing(
                gnn_type='sage',
                element_type=self.metadata['message_passing_type'],
                graph_token_ids=graph_token_ids
            )

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """ Performs a forward pass through the model and returns the logits under the key
            'graph_sequence_logits'. Encoder-decoder models read the 'text_*' and 'graph_*'
            entries of `inputs`, decoder-only models read the concatenated 'input_*' entries
        """
        if self.has_encoder:
            outputs = self.language_model(
                input_ids=inputs['text_sequence'],
                attention_mask=inputs['text_attn_mask'],
                decoder_input_ids=inputs['graph_sequence'],
                decoder_attention_mask=inputs['graph_attn_mask']
            )
        else:
            outputs = self.language_model(
                input_ids=inputs['input_sequence'],
                attention_mask=inputs['input_attn_mask']
            )
        return {'graph_sequence_logits': outputs.logits}

    def get_collate_fn(self):
        """ Returns a batching function for loading data to train and evaluate the model. The
            function is `text_graph2inputs_batch` with the tokenizer and metadata options bound
        """
        return partial(
            self.text_graph2inputs_batch,
            tokenizer=self.tokenizer,
            randomize=self.metadata['randomize_sequence'],
            serialization_type=self.metadata['serialization_type'],
            has_encoder=self.has_encoder,
            truncation_length=self.metadata.get('truncation_length', -1),
            quantize=self.metadata.get('quantize', False)
        )

    def generate_graph(
        self,
        text_sequence: torch.Tensor,
        text_attn_mask: torch.Tensor,
        do_sample: bool,
        max_new_tokens: int,
        num_beams: int = 1,
        **kwargs
    ) -> Tuple[List[TextGraph], GeneratedText]:
        """ Generates a graph for each data point in a batch given the points text input

            When `max_new_tokens` is -1 the length of kwargs['graph_sequence'] is used as the
            generation budget, so that keyword must then be supplied (KeyError otherwise).
            Returns the parsed graphs together with the raw generation output
        """
        if max_new_tokens == -1:
            max_new_tokens = kwargs['graph_sequence'].shape[1]
        generated_graphs = self._generate_language(
            text_sequence=text_sequence,
            text_attn_mask=text_attn_mask,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            start_token_id=self.tokenizer.convert_tokens_to_ids(PRED_NODE_TOKEN)
        )
        graph_sequences = self.tokenizer.batch_decode(
            generated_graphs.ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
            spaces_between_special_tokens=False
        )
        text_sequences = self.tokenizer.batch_decode(
            text_sequence,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
            spaces_between_special_tokens=False
        )
        graphs = []
        # distinct loop names so the tensor parameter `text_sequence` is not shadowed
        for decoded_text, decoded_graph in zip(text_sequences, graph_sequences):
            graph = _desequence_graph(decoded_graph)
            graph.text = decoded_text
            graphs.append(graph)
        return graphs, generated_graphs

    @staticmethod
    def text_graph2inputs(
        text_graph_pair: TextGraph,
        randomize: bool = False,
        serialization_type: str = 'depth'
    ) -> Dict[str, Union[str, npt.NDArray[np.int_]]]:
        """ Processes a data point's text and graph into a dictionary with the serialized graph
            ('graph_sequence'), the text ('text_sequence') and the source 'file_path'
        """
        # order the edges first, then derive the node order from the ordered edges
        edge_idxs_ordered = _order_edges(text_graph_pair.edge_index, randomize, serialization_type)
        edge_index_ordered = np.array(
            [text_graph_pair.edge_index[edge_idx] for edge_idx in edge_idxs_ordered]
        )
        node_idxs_ordered = np.array(_order_nodes(edge_index_ordered))
        nodes = [text_graph_pair.nodes[node_idx] for node_idx in node_idxs_ordered]
        edges = [text_graph_pair.edges[edge_idx] for edge_idx in edge_idxs_ordered]
        edge_index = reindex_edge_index(node_idxs_ordered, edge_index_ordered)
        graph_tokens = _sequence_graph(
            nodes=disambiguate_nodes(nodes),
            edges=edges,
            edge_index=edge_index,
            randomize=randomize
        )
        return {
            'graph_sequence': "".join(graph_tokens),
            'text_sequence': text_graph_pair.text,
            'file_path': text_graph_pair.file_path
        }

    @staticmethod
    def text_graph2inputs_batch(
        text_graph_pairs: List[TextGraph],
        tokenizer: PreTrainedTokenizerFast,
        randomize: bool = False,
        serialization_type: str = 'depth',
        truncation_length: int = -1,
        has_encoder: bool = False,
        quantize: bool = False
    ) -> Dict[str, torch.Tensor]:
        """ Processes a batch of text-graph pairs into a dictionary of tensors with the required
            inputs for a model and returns the dictionary

            Besides the padded tensors the dictionary holds 'file_path', a plain list of the
            file paths of the data points. A positive `truncation_length` cuts every tensor to
            that many columns; `quantize` is only applied together with truncation
        """
        sequenced_graphs = [
            SerializedGraphGenerator.text_graph2inputs(
                text_graph_pair,
                randomize=randomize,
                serialization_type=serialization_type
            )
            for text_graph_pair in text_graph_pairs
        ]
        batch = defaultdict(list)
        for sequenced_graph in sequenced_graphs:
            batch['text_sequence'].append(sequenced_graph['text_sequence'])
            batch['graph_sequence'].append(sequenced_graph['graph_sequence'])
            batch['file_path'].append(sequenced_graph['file_path'])
        input_ids, label_ids, graph_ids, text_ids = [], [], [], []
        for text, graph_sequence in zip(
            tokenizer(batch['text_sequence'], padding=False)['input_ids'],
            tokenizer(batch['graph_sequence'], padding=False)['input_ids']
        ):
            # targets are the graph tokens shifted by one position and closed with EOS
            shifted_graph = graph_sequence[1:] + [tokenizer.eos_token_id]
            if has_encoder:
                label_ids.append(shifted_graph)
            else:
                # text positions carry pad ids, which are replaced by the unlabelled value below
                label_ids.append(len(text) * [tokenizer.pad_token_id] + shifted_graph)
            input_ids.append(text + graph_sequence)
            graph_ids.append(graph_sequence)
            text_ids.append(text)
        padded_inputs = tokenizer.pad(
            BatchEncoding({'input_ids': input_ids}),
            return_attention_mask=True,
            return_tensors='pt',
        )
        padded_graphs = tokenizer.pad(
            BatchEncoding({'input_ids': graph_ids}),
            return_attention_mask=True,
            return_tensors='pt',
        )
        padded_texts = tokenizer.pad(
            BatchEncoding({'input_ids': text_ids}),
            return_tensors='pt',
            return_attention_mask=True,
        )
        batch['input_sequence'] = padded_inputs['input_ids']
        batch['input_attn_mask'] = padded_inputs['attention_mask']
        batch['graph_sequence'] = padded_graphs['input_ids']
        batch['graph_attn_mask'] = padded_graphs['attention_mask']
        batch['text_sequence'] = padded_texts['input_ids']
        batch['text_attn_mask'] = padded_texts['attention_mask']
        batch['target_sequence'] = tokenizer.pad(
            BatchEncoding({'input_ids': label_ids}),
            return_tensors='pt'
        )['input_ids']
        # NOTE(review): every pad id becomes unlabelled; if the tokenizer uses the same id for
        # pad and EOS the EOS target is masked as well - confirm this is intended
        batch['target_sequence'] = torch.where(
            batch['target_sequence'] == tokenizer.pad_token_id,
            UNLABELLED_CATEGORICAL,
            batch['target_sequence']
        )
        if truncation_length > 0:
            for key, value in batch.items():
                # 'file_path' is a list of strings and cannot be sliced along two dimensions
                if not isinstance(value, torch.Tensor):
                    continue
                batch[key] = value[:, :truncation_length]
                if quantize:
                    # NOTE(review): this also casts token ids and masks to float16 - confirm
                    batch[key] = batch[key].type(torch.float16)
        return dict(batch)

    @staticmethod
    def inputs2graph(
        inputs: Dict[str, torch.Tensor],
        tokenizer: PreTrainedTokenizerFast,
        file_paths: Optional[List[str]] = None
    ) -> List[TextGraph]:
        """ Processes a batch of model outputs, possibly generated, into a batch of graphs by
            decoding inputs['graph_sequence'] and parsing each decoded string. When `file_paths`
            is given, the i-th path is attached to the i-th graph
        """
        graph_sequences = tokenizer.batch_decode(
            inputs['graph_sequence'],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
            spaces_between_special_tokens=False
        )
        graphs = []
        for idx, graph_sequence in enumerate(graph_sequences):
            graph = _desequence_graph(graph_sequence)
            if file_paths is not None:
                graph.file_path = file_paths[idx]
            graphs.append(graph)
        return graphs


def disambiguate_nodes(nodes: List[str]) -> List[str]:
    """ Makes repeated node feature strings unique by appending the branch token and a running
        index (0, 1, ...) to each repetition; features occurring once are left untouched
    """
    occurrences = {}
    for feature in nodes:
        occurrences[feature] = occurrences.get(feature, 0) + 1
    next_ordinal = {}
    renamed = []
    for feature in nodes:
        if occurrences[feature] <= 1:
            renamed.append(feature)
            continue
        # number the repetitions in order of appearance
        ordinal = next_ordinal.get(feature, 0)
        next_ordinal[feature] = ordinal + 1
        renamed.append(feature + BRANCH_TOKEN + str(ordinal))
    return renamed


def ambiguate_nodes(nodes: List[str]) -> List[str]:
    """ Strips the branch token and everything after it from disambiguated node names; names
        without the branch token are returned unchanged
    """
    restored = []
    for node in nodes:
        if BRANCH_TOKEN in node:
            # keep only the part in front of the first branch token
            base_name = node.partition(BRANCH_TOKEN)[0]
            restored.append(base_name.strip())
        else:
            restored.append(node)
    return restored

def _sequence_graph(
    *,
    nodes: List[str],
    edges: List[str],
    edge_index: npt.NDArray[np.int_],
    randomize: bool
) -> List[str]:
    """ Returns a string representation of a graph which describes the graph based on the inputted
        ordering of edges
    """
    nodes2edges = defaultdict(list)
    for edge_idx, edge in enumerate(edge_index):
        for node_idx in np.unique(edge):
            nodes2edges[int(node_idx)].append(edge_idx)
    explored = defaultdict(lambda: False)
    graph_sequence = []
    for edge_idx, node_idxs in enumerate(edge_index):
        candidate_parents = [node_idx for node_idx in node_idxs if explored[node_idx]]
        if len(candidate_parents) == 0:
            candidate_parents = node_idxs
        parent_idx = random.choice(candidate_parents) if randomize else candidate_parents[0]
        child_idx = _get_neighbor_idx(edge_index[edge_idx].tolist(), parent_idx)
        graph_sequence.extend(
            [PRED_NODE_TOKEN, parent_idx, EDGE_TOKEN, edges[edge_idx], SUCC_NODE_TOKEN, child_idx]
        )
        explored[parent_idx] = True
        explored[child_idx] = True
    return [
        element if isinstance(element, str) else nodes[element]
        for element in graph_sequence
    ]


def _get_neighbor_idx(node_idxs: List[int], node_idx: int) -> int:
    """ Returns the node idx of the neighboring node in an edge in a graph """
    assert len(node_idxs)== 2, 'edge contains more than two nodes'
    edge_node_idx = node_idxs.index(node_idx)
    edge_neighbor_idx = 0 if edge_node_idx == 1 else 1
    return node_idxs[edge_neighbor_idx]


def _order_edges(
    edge_index: List[List[int]],
    randomize: bool,
    serialization_type: str
) -> List[int]:
    """ Returns an ordering of the edges in the graph such that every edge later in the list
        has at least one node in one of the edges that appear earlier in the list. The ordering
        can be generated in a random or deterministic manner
    """
    assert serialization_type in ['arbitrary', 'depth', 'breadth', 'tree'], "Unsupported serialization type"
    if serialization_type == 'arbitrary':
        edge_idxs = list(range(len(edge_index)))
        if randomize:
            random.shuffle(edge_idxs)
        return edge_idxs
    nodes2edges, edges2edges = defaultdict(list), defaultdict(list)
    for edge_idx, edge in enumerate(edge_index):
        for node_idx in np.unique(edge):
            nodes2edges[int(node_idx)].append(edge_idx)
    for node_idx, edge_idxs in nodes2edges.items():
        for idx, edge_idx in enumerate(edge_idxs):
            if idx == 0:
                edges2edges[edge_idx].extend(edge_idxs[1:])
            elif idx == len(edge_idxs) - 1:
                edges2edges[edge_idx].extend(edge_idxs[:-1])
            else:
                edges2edges[edge_idx].extend(edge_idxs[:idx] + edge_idxs[(idx + 1):])
    start_edge_idx = random.choice(list(edges2edges.keys())) if randomize else 0
    queue = Chain([start_edge_idx], randomize, serialization_type)
    edge_idxs_ordered = []
    for _ in range(len(edge_index)):
        current_edge_idx = queue.pop()
        edge_idxs_ordered.append(current_edge_idx)
        queue.add(edges2edges[current_edge_idx])
    assert len(queue) == 0
    return edge_idxs_ordered


class Chain:
    """ A linked list for storing a list of integers which does not allow elements to be added
        back again once they have been removed

        `serialization_type` decides which element `pop` removes: the most recently added one
        for 'depth', the oldest one for 'breadth' and a random one for 'tree'
    """
    def __init__(self, initial_elements: List[int], randomize: bool, serialization_type: str):
        """ Creates the list from a non-empty sequence of initial elements """
        assert serialization_type in ['tree', 'breadth', 'depth'], "Unsupported serialization type"
        if serialization_type == 'tree':
            assert randomize, 'tree orderings must be randomly generated'
        self.serialization_type = serialization_type
        self.randomize = randomize
        # successor / predecessor of every stored element, None at the ends of the list
        self._next_links, self._prev_links = {}, {}
        self._queued, self._explored = defaultdict(bool), defaultdict(bool)
        for element0, element1 in zip(initial_elements[:-1], initial_elements[1:]):
            self._next_links[element0] = element1
            self._prev_links[element1] = element0
        self._next_links[initial_elements[-1]] = None
        self._prev_links[initial_elements[0]] = None
        self._end = initial_elements[-1]
        for key in self._next_links:
            self._queued[key] = True

    def __len__(self):
        """ Returns the number of elements currently stored """
        return len(self._next_links)

    def _remove_element(self, element: int):
        """ Unlinks a stored element and reconnects its neighbors """
        prev_element = self._prev_links.pop(element)
        next_element = self._next_links.pop(element)
        if next_element is not None:
            self._prev_links[next_element] = prev_element
        else:
            # the removed element was the tail
            self._end = prev_element
        if prev_element is not None:
            self._next_links[prev_element] = next_element

    def pop(self) -> int:
        """ Removes an element from the linked list and returns it. Raises IndexError when the
            list is empty
        """
        assert len(self._next_links) == len(self._prev_links)
        # elements are re-inserted whenever they move, so dict order follows list order
        candidates = list(self._next_links.keys())
        if self.serialization_type == 'tree':
            current_element = random.choice(candidates)
        elif self.serialization_type in ['depth']:
            current_element = candidates[-1]
        else:
            current_element = candidates[0]
        self._explored[current_element] = True
        self._queued[current_element] = False
        self._remove_element(current_element)
        return current_element

    def add(self, new_elements: List[int]):
        """ Adds elements to the end of a linked list if they have not already been explored /
            removed from the linked list

            Elements that are already stored keep their position for 'breadth' and are moved to
            the end otherwise. With `randomize` the elements are added in shuffled order; the
            caller's list is left untouched
        """
        if self.randomize:
            # shuffle a copy so the caller's list is not reordered as a side effect
            new_elements = list(new_elements)
            random.shuffle(new_elements)
        for element in new_elements:
            if self._explored[element]:
                continue
            if self._queued[element]:
                if self.serialization_type == 'breadth':
                    continue
                # re-append below so the element becomes the most recent one
                self._remove_element(element)
            if len(self._next_links) == 0:
                self._prev_links[element] = None
            else:
                self._next_links[self._end] = element
                self._prev_links[element] = self._end
            self._next_links[element] = None
            self._end = element
            self._queued[element] = True


def _order_nodes(edge_index: npt.NDArray[np.int_]) -> List[int]:
    """ Returns a list of nodes ordered based on the order in which they appear in a list of
        ordered edges in a graph
    """
    explored = defaultdict(lambda: False)
    node_idxs_ordered = []
    for edge in edge_index:
        for node_idx in np.unique(edge):
            if explored[node_idx]:
                continue
            node_idxs_ordered.append(node_idx)
            explored[node_idx] = True
    return node_idxs_ordered


def _desequence_graph(graph_sequence: str) -> TextGraph:
    """ Desequences a string that describes a graph into a list of node features, edge_features
        and adjacency matrix indices (edge_index) and returns them as a TextGraph

        Everything in front of the first predecessor token is ignored and a sequence without
        that token yields an empty graph. Node names are stripped and deduplicated before the
        disambiguation suffix is removed, so disambiguated nodes stay separate nodes.

        NOTE(review): for malformed sequences `edges` and `edge_index` can differ in length,
        e.g. an edge feature that is not followed by a successor node is still added to `edges`
    """
    delimiters = [PRED_NODE_TOKEN, SUCC_NODE_TOKEN, EDGE_TOKEN]
    node_tokens = [PRED_NODE_TOKEN, SUCC_NODE_TOKEN]
    # the capturing group keeps the delimiter tokens in the split result
    split_sequence = re.split("(" + "|".join(delimiters) + ")", graph_sequence)
    if PRED_NODE_TOKEN not in split_sequence:
        return TextGraph(nodes=[], edges=[], edge_index=[])
    split_sequence = [
        element for element in split_sequence[split_sequence.index(PRED_NODE_TOKEN):] if element
    ]
    nodes, edges, edge_index = [], [], []
    # node name -> position in `nodes`, avoids a linear search for every node mention
    node_positions = {}
    # node mentions as ints interleaved with the edge tokens that connect them
    split_sequence_processed = []
    for previous, current in zip(split_sequence, split_sequence[1:]):
        if previous in node_tokens:
            name = current.strip()
            if name not in node_positions:
                node_positions[name] = len(nodes)
                nodes.append(name)
            split_sequence_processed.append(node_positions[name])
        elif previous == EDGE_TOKEN and current not in delimiters:
            edges.append(current.strip())
        elif current not in node_tokens:
            split_sequence_processed.append(current)
    nodes = ambiguate_nodes(nodes)
    # an edge is a node idx, an edge token and another node idx in direct succession
    for source, marker, target in zip(
        split_sequence_processed,
        split_sequence_processed[1:],
        split_sequence_processed[2:]
    ):
        if isinstance(source, int) and marker == EDGE_TOKEN and isinstance(target, int):
            edge_index.append([source, target])
    return TextGraph(nodes=nodes, edges=edges, edge_index=edge_index)
