"""
Module with custom training objectives for training models to perform inference on programme data
"""
import enum
from functools import partial
from typing import Callable, List

import numpy as np
import torch
from torch import nn

from text2graph.data.base_dataset import TextGraph

# Label value marking an unlabelled / missing categorical target; it is passed as `ignore_index`
# to nn.CrossEntropyLoss throughout this module so that such targets are skipped by the loss
UNLABELLED_CATEGORICAL = -100


def cross_entropy_w_missing_values(
    logits: torch.Tensor,
    label_idxs: torch.Tensor
) -> torch.Tensor:
    """ Returns the cross entropy loss summed over the last dimension, skipping unlabelled
        targets; unlabelled points are assumed to carry the label UNLABELLED_CATEGORICAL (-100)
    """
    # CrossEntropyLoss reads the class scores from dim 1, hence dims 1 and 2 are swapped
    class_first_logits = logits.transpose(1, 2)
    criterion = nn.CrossEntropyLoss(ignore_index=UNLABELLED_CATEGORICAL, reduction='none')
    # ignored targets produce a loss of zero, so they add nothing to the sum below
    elementwise_loss = criterion(class_first_logits, label_idxs)
    return elementwise_loss.sum(dim=-1)


def measure_functional_differences(
    graphs_ground_truth: List[TextGraph],
    graphs_generated: List[TextGraph],
    eval_function: Callable,
    metric_name: str
) -> torch.Tensor:
    """ Returns the metric called metric_name, as a float tensor, from the results that
        eval_function produces for the ground truth and generated graphs

        eval_function is called with the keyword arguments graphs_ground_truth and
        graphs_generated and its result is looked up by metric_name; when the metric is absent
        from the result a scalar NaN tensor is returned instead
    """
    metrics = eval_function(
        graphs_generated=graphs_generated,
        graphs_ground_truth=graphs_ground_truth
    )
    # guard clause: a missing metric is reported as NaN rather than raising
    if metric_name not in metrics:
        return torch.tensor(float('nan'))
    metric_values = np.array(metrics[metric_name])
    return torch.from_numpy(metric_values).float()


def compare_edge_indexes_undirected(
    *,
    edge_index0: List[List[int]],
    nodes0: List[str],
    edge_index1: List[List[int]],
    nodes1: List[str]
) -> bool:
    """ Returns a boolean indicating whether two edge indexes of undirected graphs describe the
        same collection of node-name pairs, which is a necessary condition for the graphs to be
        the same

        Each entry of an edge index is read as a pair of positions into the matching node list;
        the pair is translated to node names and ordered so that edge direction is irrelevant,
        and the two resulting collections are compared after sorting (duplicates count)
    """
    def canonical_edges(edge_index: List[List[int]], nodes: List[str]) -> List[List[str]]:
        # only the first two entries of every edge are used, as positions of its endpoints
        named_edges = [sorted((nodes[edge[0]], nodes[edge[1]])) for edge in edge_index]
        named_edges.sort()
        return named_edges

    return canonical_edges(edge_index0, nodes0) == canonical_edges(edge_index1, nodes1)


def compare_graphs(graph_0: TextGraph, graph_1: TextGraph) -> bool:
    """ Returns a boolean indicating whether two graphs pass a set of necessary conditions for
        being the same graph: equal sorted nodes, equal sorted edges and matching undirected
        edge indexes
    """
    # the checks run in this order and the first failing one ends the comparison
    if sorted(graph_0.nodes) != sorted(graph_1.nodes):
        return False
    if sorted(graph_0.edges) != sorted(graph_1.edges):
        return False
    return compare_edge_indexes_undirected(
        edge_index0=graph_0.edge_index,
        nodes0=graph_0.nodes,
        edge_index1=graph_1.edge_index,
        nodes1=graph_1.nodes
    )


def diversity(
    graphs_ground_truth: List[TextGraph],
    graphs_generated_0: List[TextGraph],
    graphs_generated_1: List[TextGraph]
) -> torch.Tensor:
    """ Returns a float tensor with one diversity indicator per triple of (ground truth graph,
        first generated graph, second generated graph)

        An indicator is 0 when any two graphs of the triple pass the necessary conditions checked
        by compare_graphs and 1 when all three differ; the lists are walked in lockstep with
        zip, so iteration stops at the end of the shortest list
    """
    indicators = [
        0 if (
            compare_graphs(truth, first)
            or compare_graphs(truth, second)
            or compare_graphs(first, second)
        ) else 1
        for truth, first, second in zip(
            graphs_ground_truth, graphs_generated_0, graphs_generated_1
        )
    ]
    return torch.from_numpy(np.array(indicators)).float()



def grapher_loss(
    logits_nodes: torch.Tensor,
    logits_edges: torch.Tensor,
    target_nodes: torch.Tensor,
    target_edges: torch.Tensor
) -> torch.Tensor:
    """ Returns the loss for the grapher model as the sum of a node loss and an edge loss
        most of the code for this model comes from https://github.com/IBM/Grapher
        and the paper https://arxiv.org/abs/2211.10511

        Both terms are element-wise cross entropies averaged with a plain mean, so targets equal
        to UNLABELLED_CATEGORICAL add zero loss but are still counted in the mean's denominator
    """
    criterion = nn.CrossEntropyLoss(reduction='none', ignore_index=UNLABELLED_CATEGORICAL)

    # --------- Node Loss ---------
    # dims 1 and 2 are swapped because CrossEntropyLoss reads the class scores from dim 1
    node_loss = criterion(logits_nodes.transpose(1, 2), target_nodes).mean()

    # --------- Edge Loss ---------
    # original dim 2 goes first and original dim 4 second, the positions where
    # CrossEntropyLoss expects the batch and class dimensions
    edge_scores = logits_edges.permute(2, 4, 0, 1, 3)
    # start from an all-ignored target shaped like the logits with the class dimension dropped
    padded_targets = torch.full_like(
        edge_scores[:, 0], UNLABELLED_CATEGORICAL, dtype=torch.long
    )
    n_batch, n_nodes, _, n_steps = target_edges.shape
    # only the last dimension is clamped to what the logits provide
    n_steps = min(edge_scores.shape[-1], n_steps)
    padded_targets[:n_batch, :n_nodes, :n_nodes, :n_steps] = (
        target_edges[:n_batch, :n_nodes, :n_nodes, :n_steps]
    )
    edge_loss = criterion(edge_scores, padded_targets).mean()

    return node_loss + edge_loss


#pylint: disable=invalid-name
class MetricFactory(enum.Enum):
    """ An enum of metric functions; the value of each member is a functools.partial wrapping one
        of the metric / loss functions of this module, so a metric is invoked through the
        member's value, e.g. MetricFactory.Diversity.value(...)

        partial is used because otherwise assigning the function by itself would make the
        function a method of MetricFactory as opposed to a member of the Enum
    """
    # NOTE(review): recent Python releases reportedly change how functools.partial behaves inside
    # class bodies (it becomes a method descriptor); confirm that these members are still created
    # on the target Python version - wrapping them in enum.member may be needed
    CrossEntropyWMissingValues = partial(cross_entropy_w_missing_values)
    FunctionalDifference = partial(measure_functional_differences)
    GrapherLoss = partial(grapher_loss)
    Diversity = partial(diversity)
