""" Base class for all data sets """

from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
import json
import os
from typing import Any, Dict, List, Union

from torch.utils.data import Dataset

# File extension of the per-data-point JSON files that datasets count and load
JSON_SUFFIX = ".json"


@dataclass
class TextGraph:
    """ A data class for representing a text graph where you can optionally include the text
        describing the graph

        The graph itself is given by three lists: ``nodes``, ``edges`` and ``edge_index``.
    """
    # One string per node of the graph
    nodes: List[str]
    # Edge strings; presumably one per edge and aligned with ``edge_index`` — TODO confirm
    edges: List[str]
    # Integer connectivity of the graph, presumably indices into ``nodes``. Whether this is a
    # list of [src, dst] pairs or a [sources, targets] pair of lists is not fixed here — TODO confirm
    edge_index: List[List[int]]
    # Optional text describing the graph; empty string when not provided
    text: str = ""
    # Path of the JSON file the graph was loaded from (set by ``BaseDataset.__getitem__``);
    # empty string when not provided
    file_path: str = ""


class BaseDataset(Dataset, ABC):
    """ Base class for all data sets.

        Each processed data point is stored as its own JSON file named
        ``{file_prefix}_{split_name}_{idx}.json`` inside ``{parent_path}/{dataset_name}``;
        ``__len__`` counts and ``__getitem__`` loads those files. Subclasses implement the raw
        data processing, the conversion between a TextGraph and the data set's evaluatable
        representation, and the evaluation metrics.
    """

    class _DataPointNotFoundError(FileNotFoundError, IndexError):
        """ Raised by ``__getitem__`` when the file of the requested data point does not exist.

            Subclassing ``FileNotFoundError`` keeps handlers written for the plain file error
            working, while subclassing ``IndexError`` lets Python's sequence-iteration protocol
            (``for graph in dataset``, ``list(dataset)``) stop cleanly after the last data point
            instead of crashing.
        """

    def __init__(
        self,
        *,
        split_name: str,
        parent_path: str,
        dataset_name: str,
        file_prefix: str,
        **kwargs
    ):
        """ Args:
                split_name: name of the split to read; part of every data point file name
                parent_path: directory that contains the data set's directory
                dataset_name: name of the data set, also used as its directory name
                file_prefix: prefix of every data point file name
                **kwargs: additional keyword arguments, accepted and ignored by this base class
        """
        super().__init__()
        self.dataset_path = os.path.join(parent_path, dataset_name)
        self.name = dataset_name
        self.split_name = split_name
        self.file_prefix = file_prefix

    def __len__(self) -> int:
        """ Returns the number of data points in this split.

            Only files named ``{file_prefix}_{split_name}_{idx}.json`` with a decimal ``idx`` are
            counted, i.e. exactly the files ``__getitem__`` can address. The directory is listed
            on every call.
        """
        prefix = f"{self.file_prefix}_{self.split_name}_"

        def is_data_point_file(file_name: str) -> bool:
            # A plain substring test would also match other prefixes/splits that merely
            # contain this one (e.g. "train" inside "x_train_extra_3.json"), so require the
            # exact prefix, the suffix, and an integer index in between.
            if not (file_name.startswith(prefix) and file_name.endswith(JSON_SUFFIX)):
                return False
            index = file_name[len(prefix):len(file_name) - len(JSON_SUFFIX)]
            return index.isdecimal()

        return sum(1 for file_name in os.listdir(self.dataset_path) if is_data_point_file(file_name))

    def __getitem__(self, idx: int) -> TextGraph:
        """ Returns a data point in the data set as a TextGraph

            Loads ``{file_prefix}_{split_name}_{idx}.json``, which must contain the keys
            ``text``, ``nodes``, ``edges`` and ``edge_index``.

            Raises:
                FileNotFoundError: (also an ``IndexError``) if no file exists for ``idx``
                KeyError: if one of the required keys is missing from the file
        """
        file_path = os.path.join(
            self.dataset_path,
            f"{self.file_prefix}_{self.split_name}_{idx}{JSON_SUFFIX}"
        )
        try:
            with open(file_path, 'r', encoding='utf-8') as data_json:
                graph = json.load(data_json)
        except FileNotFoundError as err:
            # Re-raise as an error that is both FileNotFoundError and IndexError so that
            # iterating the data set terminates at the end instead of crashing.
            raise self._DataPointNotFoundError(err.errno, err.strerror, err.filename) from err
        return TextGraph(
            text=graph['text'],
            nodes=graph['nodes'],
            edges=graph['edges'],
            edge_index=graph['edge_index'],
            file_path=file_path
        )

    @classmethod
    @abstractmethod
    def metric_names(cls) -> List[str]:
        """ Returns a list of the names of evaluation metrics used to evaluate generated graphs """

    @staticmethod
    @abstractmethod
    def download_and_process(parent_path: str) -> None:
        """ Processes the raw data of a data set into (text, graph) pairs where
            each graph is represented by three lists: nodes -> List[str], edges -> List[str],
            edge_index -> List[List[int]].

            The processed data is expected to be saved for each split (e.g. train, val, test) as
            one JSON file per data point, named ``{file_prefix}_{split_name}_{idx}.json`` and
            holding the keys ``text``, ``nodes``, ``edges`` and ``edge_index``; this is the layout
            ``__len__`` and ``__getitem__`` read.
        """

    @staticmethod
    @abstractmethod
    def eval_rep2graph(data_point: Any) -> Dict[str, Any]:
        """ NOTE: each data point in the data set will have a representation from which it is
            possible to calculate metrics which measure model performance.

            Maps the point's evaluatable representation to its graph representation and returns
            that graph representation as a dictionary (as opposed to a TextGraph instance).

            This function is used for processing raw data, which is why it returns a dictionary:
            besides the graph it may contain additional fields required for related data analysis
            tasks.
        """

    @staticmethod
    @abstractmethod
    def graph2eval_rep(graph: TextGraph) -> Any:
        """ NOTE: each data point in the data set will have a representation from which it is
            possible to calculate metrics which measure model performance.

            Maps the point's text graph to its evaluatable representation and returns the
            evaluatable representation.
        """

    @abstractmethod
    def calculate_metrics(
        self,
        *,
        ground_truth_point: Any,
        generated_point: Any
    ) -> Dict[str, Union[int, float]]:
        """ Returns a dictionary of metrics comparing a generated point to its ground truth to
            measure the performance of the generative model. Both points are given in the
            evaluatable representation produced by ``graph2eval_rep``.
        """

    def calculate_metrics_from_graph(
        self,
        *,
        graph_ground_truth: TextGraph,
        graph_generated: TextGraph
    ) -> Dict[str, Union[int, float]]:
        """ Returns a dictionary of metrics comparing a generated point to its ground truth to
            measure the performance of the generative model where the ground truth and generated
            point are represented as a graph

            The result always contains ``parsability``: 1 when the generated graph is non-empty
            and could be converted and scored, in which case the metrics from
            ``calculate_metrics`` are included as well; otherwise 0 and no other metrics.
            Failures converting the ground truth graph are not caught.
        """
        metrics = {}
        ground_truth_eval_rep = self.graph2eval_rep(graph_ground_truth)
        try:
            # Explicit check instead of ``assert``: asserts are stripped under ``python -O``,
            # which would let empty generated graphs be scored as parsable.
            if len(graph_generated.nodes) == 0:
                raise ValueError("generated graph has no nodes")
            generated_eval_rep = self.graph2eval_rep(graph_generated)
            metrics.update(self.calculate_metrics(
                generated_point=generated_eval_rep,
                ground_truth_point=ground_truth_eval_rep
            ))
            metrics['parsability'] = 1
        except Exception:  # pylint: disable=broad-except
            # Deliberately broad: any failure to convert or score the generated graph
            # means the model output is counted as unparsable rather than aborting evaluation.
            metrics['parsability'] = 0
        return metrics

    def calculate_metrics_batch(
        self,
        *,
        graphs_ground_truth: List[TextGraph],
        graphs_generated: List[TextGraph]
    ) -> Dict[str, Union[List[int], List[float]]]:
        """ Returns a dictionary of metrics comparing a batch of generated points to their ground
            truth to measure the performance of the generative model where the ground truth and
            generated point are represented as a graph

            Each metric name maps to the list of its per-point values, in batch order. Metrics
            other than ``parsability`` are only recorded for parsable points, so those lists can
            be shorter than the batch and are then not index-aligned with it.

            Raises:
                ValueError: if the two lists have different lengths
        """
        if len(graphs_ground_truth) != len(graphs_generated):
            # zip() would silently drop the unmatched tail and skew the metrics.
            raise ValueError(
                f"Got {len(graphs_ground_truth)} ground truth graphs but "
                f"{len(graphs_generated)} generated graphs"
            )
        metrics = defaultdict(list)
        for graph_gt, graph_gen in zip(graphs_ground_truth, graphs_generated):
            point_metrics = self.calculate_metrics_from_graph(
                graph_ground_truth=graph_gt,
                graph_generated=graph_gen
            )
            for key, value in point_metrics.items():
                metrics[key].append(value)
        return dict(metrics)
