""" Base class for all data sets """

import argparse
import json
import os
import subprocess
from typing import Any, Dict, List, Tuple
import unicodedata as ud
from xml.etree.ElementTree import Element, SubElement
from xml.etree import ElementTree
from xml.dom import minidom
from typing import Any, Dict, List, Union, Tuple

import networkx as nx
import nltk
# Fetch the NLTK 'punkt' resource as a side effect of importing this module.
# This is best-effort: any exception raised by the download is swallowed so the
# import itself does not fail.
try:
    nltk.download('punkt')
except Exception:
    pass

from text2graph.data.base_dataset import BaseDataset, TextGraph
from text2graph.webnlg.WebNLGTexttoTriples import Evaluation_script_json

# Working directory captured once, at import time. WebNLG2020Dataset.calculate_metrics
# writes its intermediate ref.xml / hyp.xml / scores.json files into this directory.
CURRENT_PATH = os.getcwd()


class WebNLG2020Dataset(BaseDataset):
    """ Data set class for WebNLG+ knowledge graph <> text data set """
    @staticmethod
    def metric_names() -> List[str]:
        """ Returns a list of the names of evaluation metrics used to generate graphs """
        return [
            'parsability',
            'F1_exact',
            'precision_exact',
            'recall_exact',
            'F1_partial',
            'precision_partial',
            'recall_partial',
            'F1_strict',
            'precision_strict',
            'recall_strict'
        ]

    @staticmethod
    def download_and_process(parent_path: str) -> None:
        """ Downloads the raw WebNLG data into `<parent_path>/webnlg2020` and processes it into
            (text, graph) pairs where each graph is represented by three lists:
            nodes -> List[str], edges -> List[str], edge_index -> List[List[int]].

            Every pair is saved as its own json file named
            `webnlg2020_<split>_<idx>.json` inside the data set directory, where `<split>` is
            one of train, val, test (the raw 'dev' split is written out as 'val').
        """
        # Local import: only needed here, to quote the path handed to the shell.
        import shlex

        dataset_name = 'webnlg2020'
        dataset_path = os.path.join(parent_path, dataset_name)
        # Also creates missing parent directories and tolerates an existing directory.
        os.makedirs(dataset_path, exist_ok=True)
        # The path is quoted so that spaces or shell metacharacters in it cannot split or
        # alter the command line.
        # NOTE(review): the script path is relative to the current working directory and the
        # script's exit status is ignored - confirm whether a failed download should abort.
        subprocess.call(
            f"../data/download_scripts/download_webnlg2020.sh {shlex.quote(dataset_path)}",
            shell=True
        )
        for split_name in ['train', 'dev', 'test']:
            split_dataset = _read_webnlg_dataset(dataset_path=dataset_path, split_name=split_name)
            triples_list, text_list = _extract_text_triples_pairs(split_dataset)
            nodes_list, edges_list, edge_index_list = _parse_triples(triples_list)
            # The raw 'dev' split is stored under the name 'val'.
            file_split_name = 'val' if split_name == 'dev' else split_name
            for idx, (text, nodes, edges, edge_index) in enumerate(
                zip(text_list, nodes_list, edges_list, edge_index_list)
            ):
                graph = {'nodes': nodes, 'edges': edges, 'edge_index': edge_index, 'text': text}
                processed_file_path = os.path.join(
                    dataset_path, f"{dataset_name}_{file_split_name}_{idx}.json"
                )
                with open(processed_file_path, 'w', encoding='utf-8') as split_file:
                    json.dump(graph, split_file)

    @staticmethod
    def eval_rep2graph(data_point: List[List[str]]) -> Dict[str, Any]:
        """ NOTE: each data point in the data set will have a representation from which it is
            possible to calculate metrics which measure model performance.

            Maps the point's evaluatable representation (a list of [head, relation, tail]
            triples) to its graph representation and returns the graph representation. Each
            graph is represented by three lists: nodes -> List[str], edges -> List[str],
            edge_index -> List[List[int]] contained within the output dictionary.

            Nodes are listed in the order in which they first appear in the triples, so the
            node indices are reproducible from run to run.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order; iterating a set of
        # strings would give an order that can change between interpreter runs.
        nodes = list(dict.fromkeys(
            node for triple in data_point for node in (triple[0], triple[2])
        ))
        node2idx = {node: idx for idx, node in enumerate(nodes)}
        edges = [triple[1] for triple in data_point]
        edge_index = [[node2idx[triple[0]], node2idx[triple[2]]] for triple in data_point]
        return {'nodes': nodes, 'edges': edges, 'edge_index': edge_index}

    @staticmethod
    def graph2eval_rep(graph: TextGraph) -> List[List[str]]:
        """ NOTE: each data point in the data set will have a representation from which it is
            possible to calculate metrics which measure model performance.

            Maps the point's graph to its evaluatable representation and returns the evaluatable
            representation: one [head node, edge feature, tail node] triple per edge, in the
            order of `graph.edges`.
        """
        return [
            [graph.nodes[node_idxs[0]], edge_feature, graph.nodes[node_idxs[1]]]
            for node_idxs, edge_feature in zip(graph.edge_index, graph.edges)
        ]

    @staticmethod
    def calculate_metrics(
        *,
        ground_truth_point: List[List[str]],
        generated_point: List[List[str]]
    ) -> Dict[str, Union[int, float]]:
        """ Returns a dictionary of metrics comparing a generated point to its ground truth to
            measure the performance of the generative model.

            Each triple is serialised as 'head | relation | tail'. Both triple lists are
            truncated to the length of the shorter one before scoring. The triples are written
            to ref.xml / hyp.xml in CURRENT_PATH, scored by the WebNLG evaluation script, and
            the exact / partial / strict F1, precision and recall are read back from
            scores.json in the same directory.

            Note that the 'parsability' entry listed by `metric_names` is not computed here.
        """
        refs = [' | '.join(t) for t in ground_truth_point]
        hyps = [' | '.join(t) for t in generated_point]
        min_length = min(len(refs), len(hyps))
        refs = refs[:min_length]
        hyps = hyps[:min_length]

        # A single entry is scored, hence the one-element outer lists.
        ref_fname, hyp_fname = save_webnlg_rdf([hyps], [refs], CURRENT_PATH)

        scores_fname = os.path.join(CURRENT_PATH, 'scores.json')

        Evaluation_script_json.main(ref_fname, hyp_fname, scores_fname)

        # Context manager so the scores file handle is closed deterministically.
        with open(scores_fname, encoding='utf-8') as scores_file:
            total_scores = json.load(scores_file)['Total_scores']
        return {
            'F1_exact': total_scores['Exact']['F1'],
            'precision_exact': total_scores['Exact']['Precision'],
            'recall_exact': total_scores['Exact']['Recall'],
            'F1_partial': total_scores['Partial']['F1'],
            'precision_partial': total_scores['Partial']['Precision'],
            'recall_partial': total_scores['Partial']['Recall'],
            'F1_strict': total_scores['Strict']['F1'],
            'precision_strict': total_scores['Strict']['Precision'],
            'recall_strict': total_scores['Strict']['Recall'],
        }


def _read_webnlg_dataset(dataset_path: str, split_name: str) -> Dict[str, Any]:
    """ Reads one split of the WebNLG data set from disk and returns it as a dictionary.

        The split's XML files are looked up under
        `<dataset_path>/webnlg-dataset/release_v3.0/en/<split_name>`; for the 'test' split only
        'semantic-parsing-test-data-with-refs-en.xml' is used. The corpus is loaded into a
        `Benchmark`, exported through `Benchmark.b2json` as `<split_name>.json` in
        `dataset_path`, and that json file is then parsed and returned.
    """
    from text2graph.webnlg.corpusreader.benchmark_reader import Benchmark, select_files
    benchmark = Benchmark()
    split_path = os.path.join(dataset_path, 'webnlg-dataset', 'release_v3.0', 'en', split_name)
    if split_name == 'test':
        files = [(split_path, 'semantic-parsing-test-data-with-refs-en.xml')]
    else:
        files = select_files(split_path)
    benchmark.fill_benchmark(files)
    json_name = f'{split_name}.json'
    benchmark.b2json(dataset_path, json_name)
    # Context manager so the file handle is closed as soon as the json is parsed.
    with open(os.path.join(dataset_path, json_name), encoding='utf-8') as json_file:
        return json.load(json_file)


def _extract_text_triples_pairs(
    dataset: Dict[str, Any]
) -> Tuple[List[str], List[str]]:
    """ Builds aligned (linearised graph, prompt text) lists from a WebNLG json dictionary.

        The i-th element of `dataset['entries']` is expected to hold its content under the key
        str(i + 1). Its triples are joined into one string of the form
        '__subject__ S __predicate__ P __object__ O ...' (subject and object have surrounding
        double quotes stripped, underscores replaced by spaces and are reduced to ASCII; the
        predicate is kept as is). That string is repeated once per lexicalisation of the entry,
        and each lexicalisation is reduced to ASCII and prefixed with
        'summarize as a knowledge graph: '.

        Returns (triples_list, text_list), two lists of equal length.
    """
    def to_ascii(text):
        # Decompose accented characters, then drop everything that is not plain ASCII.
        decomposed = ud.normalize('NFKD', text)
        return decomposed.encode('ascii', 'ignore').decode('ascii')

    prompt = to_ascii('summarize as a knowledge graph: ')
    triples_list = []
    text_list = []
    for position, entry in enumerate(dataset['entries'], start=1):
        record = entry[str(position)]
        serialised = []
        for triple in record['modifiedtripleset']:
            raw_object = triple['object']
            predicate = triple['property']
            raw_subject = triple['subject']
            clean_object = to_ascii(raw_object.strip('\"').replace('_', ' '))
            clean_subject = to_ascii(raw_subject.strip('\"').replace('_', ' '))
            serialised.append(
                f'__subject__ {clean_subject} __predicate__ {predicate} __object__ {clean_object}'
            )
        linearised = ' '.join(serialised)
        sentences = [to_ascii(item['lex']) for item in record['lexicalisations']]
        # One (graph, text) pair per lexicalisation of the entry.
        text_list.extend(prompt + sentence for sentence in sentences)
        triples_list.extend([linearised] * len(sentences))
    return triples_list, text_list


def _parse_triples(
    graph_data: List[str]
) -> Tuple[List[List[str]], List[List[str]], List[List[List[int]]]]:
    """ Processes a list of knowledge graphs, each given as one string of semantic / rdf triples
        in the form '__subject__ S __predicate__ P __object__ O ...', and returns three parallel
        lists with, per graph, its nodes, its edges (relation names) and its edge indexes
        ([head position, tail position] into the node list).

        Nodes are listed in the order in which they first appear (head before tail). Edges are
        grouped by head node in that same order, then by tail node in the order the tail was
        first seen for that head. Every triple yields an edge: triples that share the same head
        and tail but carry different relations are all kept.

        Raises IndexError if a triple lacks its '__predicate__' or '__object__' marker.
    """
    all_nodes = []
    all_edges = []
    all_edges_ind = []
    for triples_str in graph_data:
        # head -> tail -> relations; plain dicts keep insertion order, which fixes the
        # node and edge ordering described in the docstring.
        adjacency: Dict[str, Dict[str, List[str]]] = {}
        # The trailing space lets the [1:-1] slices below strip exactly one separator
        # character from each side of every field, including the last one.
        for triple_str in (triples_str + ' ').split('__subject__')[1:]:
            fields = triple_str.split('__predicate__')
            head = fields[0][1:-1]
            relation_and_tail = fields[1].split('__object__')
            relop = relation_and_tail[0][1:-1]
            tail = relation_and_tail[1][1:-1]
            adjacency.setdefault(head, {}).setdefault(tail, []).append(relop)
            # Make sure the tail is registered as a node even if it never acts as a head.
            adjacency.setdefault(tail, {})
        nodes = list(adjacency)
        node_position = {node: position for position, node in enumerate(nodes)}
        edges = []
        edges_ind = []
        for head, successors in adjacency.items():
            for tail, relations in successors.items():
                for relop in relations:
                    edges.append(relop)
                    edges_ind.append([node_position[head], node_position[tail]])
        all_nodes.append(nodes)
        all_edges.append(edges)
        all_edges_ind.append(edges_ind)
    return all_nodes, all_edges, all_edges_ind


def create_xml(data, categories, ts_header, t_header):
    """ Builds and returns an xml.etree.ElementTree.Element tree for a list of triple sets.

        The tree has the shape benchmark > entries > entry > <ts_header> > <t_header>. Each
        item of `data` (a list of triple strings) becomes one 'entry' element whose 'category'
        attribute is the matching item of `categories` and whose 'eid' attribute is 'Id1',
        'Id2', ... in order; each triple string becomes the text of one <t_header> element.

        `categories` must have the same length as `data` (checked with an assert).
    """
    root = Element('benchmark')
    entries_node = SubElement(root, 'entries')

    assert len(categories) == len(data)

    for position, triple_strings in enumerate(data):
        attributes = {'category': categories[position], 'eid': f'Id{position + 1}'}
        entry_node = SubElement(entries_node, 'entry', attributes)
        tripleset_node = SubElement(entry_node, ts_header)

        for triple_string in triple_strings:
            SubElement(tripleset_node, t_header).text = triple_string

    return root


def xml_prettify(elem):
    """Serialise an Element to an indented, human-readable XML string.

       The element is first dumped as utf-8 bytes, then re-parsed with minidom so that its
       pretty printer (two-space indent) can be used.
       source : https://pymotw.com/2/xml/etree/ElementTree/create.html
    """
    serialized = ElementTree.tostring(elem, 'utf-8')
    return minidom.parseString(serialized).toprettyxml(indent="  ")


def save_webnlg_rdf(hyps: List[List[str]], refs: List[List[str]], out_dir: str) -> Tuple[str, str]:
    """ Writes the ground truth and generated triple sets to `ref.xml` and `hyp.xml` inside
        `out_dir` (pretty-printed, utf-8) for use downstream when calculating metrics.

        Reference triples are stored under 'modifiedtripleset' / 'mtriple' elements, generated
        triples under 'generatedtripleset' / 'gtriple'; every entry gets the category ' '.

        Returns the pair (path of ref.xml, path of hyp.xml).
    """
    # Build both trees before touching the file system.
    reference_tree = create_xml(refs, [' '] * len(refs), "modifiedtripleset", "mtriple")
    hypothesis_tree = create_xml(hyps, [' '] * len(hyps), "generatedtripleset", "gtriple")

    ref_fname = os.path.join(out_dir, "ref.xml")
    hyp_fname = os.path.join(out_dir, "hyp.xml")

    for target_path, tree in ((ref_fname, reference_tree), (hyp_fname, hypothesis_tree)):
        with open(target_path, 'w', encoding='utf-8') as xml_file:
            xml_file.write(xml_prettify(tree))

    return ref_fname, hyp_fname


if __name__ == '__main__':
    # Script entry point: download and process the data set into the directory given by the
    # mandatory --dataset-path option.
    cli = argparse.ArgumentParser(description='data set location')
    cli.add_argument('--dataset-path', type=str, required=True)
    cli_args = cli.parse_args()
    WebNLG2020Dataset.download_and_process(cli_args.dataset_path)
