"""The actual tf.data.Dataset."""
import csv
import dataclasses
import os

import tensorflow as tf

from em.util import env_util

###############################################################################

# Directory containing the per-split CSV files read by `load` (`<split>.csv`).
# The local path starts with '~'; it is expanded when the file is opened in
# `load_raw_from_file`.
DS_DIR = (
    '/fruitbasket/users/m/project_data/extract_merge1/ead1/datasets'
    if env_util.on_fruit()
    else '~/Desktop/projects_data/extract_merge1/antiderivative/datasets'
)

###############################################################################


def load(
    task: str,
    split: str,
    tokenizer,
    sequence_length: int,
    *extra_args,
    **extra_kwargs,
) -> tf.data.Dataset:
    """Load one split of the 'infix' task as a tokenized dataset.

    Reads `<DS_DIR>/<split>.csv`, drops unlabeled rows, and converts every
    expression into fixed-length token features.

    Args:
        task: Must be 'infix'; any other value fails the assertion below.
        split: Base name (without '.csv') of the file to read from DS_DIR.
        tokenizer: Tokenizer forwarded to `convert_dataset_to_features`.
        sequence_length: Length every feature vector is padded/truncated to.
        *extra_args: Accepted but not used.
        **extra_kwargs: Accepted but not used.

    Returns:
        The dataset produced by `convert_dataset_to_features`.
    """
    # For now, the full task_ri should always be ead/infix with the split providing
    # information on what file to read.
    assert task == 'infix'

    csv_path = os.path.join(DS_DIR, f'{split}.csv')
    labeled_examples = load_raw_from_file(csv_path, skip_unlabeled=True)
    return convert_dataset_to_features(
        labeled_examples,
        tokenizer,
        sequence_length=sequence_length,
    )


def n_classes_for_task(task: str):
    """Return the number of label classes for `task`.

    The result is the constant 2 regardless of `task` (labels are 0 and 1).
    """
    class_count = 2
    return class_count


def de_facto_validation_split(task: str):
    """Return None for every `task`.

    NOTE(review): presumably None means no substitute validation split is
    configured -- confirm against the callers of this hook.
    """
    # `task` is intentionally ignored; the answer is the same for all tasks.
    return None


def examples_per_epoch(task: str):
    """Return None for every `task`.

    NOTE(review): presumably None means the epoch size is not specified
    here -- confirm against the callers of this hook.
    """
    # `task` is intentionally ignored; the answer is the same for all tasks.
    return None


###############################################################################


def _char_label_to_int(label: str):
    """Convert a one-character CSV label into its integer code.

    Args:
        label: '1', '0', or '-'.

    Returns:
        1 for '1', 0 for '0', and -1 for '-'. The -1 code is the one this
        module counts as unlabeled and filters out with `y >= 0`.

    Raises:
        ValueError: If `label` is not one of the three recognized strings.
    """
    codes = {'1': 1, '0': 0, '-': -1}
    try:
        return codes[label]
    except KeyError:
        # Fail here, with the offending value, rather than returning None and
        # letting the int32 tensor construction fail later with an unrelated
        # message.
        raise ValueError(f'Unrecognized label: {label!r}') from None


def _skip_unlabeled_filter_fn(x, y):
    """Dataset filter predicate: keep an example only if its label is >= 0.

    Args:
        x: The example's features (not inspected).
        y: The example's integer label; -1 marks an unlabeled example.
    """
    is_labeled = 0 <= y
    return is_labeled


def load_raw_from_file(filepath: str, *, skip_unlabeled=True):
    """Read an `expression,label` CSV file into a tf.data.Dataset.

    Each non-blank row supplies an expression in its first column and a
    one-character label in its second column (see `_char_label_to_int`).
    Columns beyond the first two are ignored. Blank lines are skipped, and an
    empty file produces an empty dataset.

    Args:
        filepath: Path to the CSV file; a leading '~' is expanded.
        skip_unlabeled: If true, drop examples whose label code is negative.

    Returns:
        A dataset of `(expression, label)` pairs, where `expression` is a
        scalar tf.string and `label` is a scalar tf.int32.

    Raises:
        ValueError: If a non-blank row has fewer than two columns.
    """
    expressions = []
    labels = []
    with open(os.path.expanduser(filepath), newline='') as f:
        reader = csv.reader(f)
        for row in reader:
            if not row:
                # csv.reader yields an empty list for a blank line.
                continue
            if len(row) < 2:
                raise ValueError(
                    f'{filepath}:{reader.line_num}: expected at least 2 columns '
                    f'(expression, label), got {len(row)}')
            expressions.append(row[0])
            labels.append(_char_label_to_int(row[1]))

    # Explicit dtypes keep the tensors well-typed even when the file is empty.
    expressions = tf.constant(expressions, dtype=tf.string)
    labels = tf.constant(labels, dtype=tf.int32)

    ds = tf.data.Dataset.from_tensor_slices((expressions, labels))
    if skip_unlabeled:
        ds = ds.filter(_skip_unlabeled_filter_fn)

    return ds


def convert_to_prefix(raw_ds):
    """Placeholder with no implementation yet; ignores `raw_ds` and returns None."""
    return None


def convert_dataset_to_features(ds, tokenizer, sequence_length: int, *, ignore_too_long=True):
    """Tokenize every expression in `ds` into fixed-length model features.

    Args:
        ds: Dataset of `(expression, label)` pairs; the expression must be a
            string scalar.
        tokenizer: Object exposing `pad_token_id`, `pad_token_type_id` and
            `encode_plus` (presumably a Hugging Face tokenizer -- confirm).
        sequence_length: Length that `input_ids` and `token_type_ids` are
            truncated and padded to.
        ignore_too_long: If true, drop every example whose last position is
            not the padding token.

    Returns:
        A dataset of `(features, label)` pairs, where `features` is a dict
        with int32 `input_ids` and `token_type_ids` tensors of shape
        `[sequence_length]`.
    """
    pad_id = tokenizer.pad_token_id
    pad_type_id = tokenizer.pad_token_type_id

    def padded(values, fill):
        # Right-pad a copy of `values` with `fill` up to sequence_length. A
        # non-positive repeat count yields no padding.
        return list(values) + (sequence_length - len(values)) * [fill]

    def tokenize_eagerly(equation_tensor):
        # Runs as ordinary Python via tf.py_function, so .numpy() is available.
        text = tf.compat.as_str(equation_tensor.numpy())
        encoded = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=sequence_length,
            return_token_type_ids=True,
            truncation=True,
        )
        ids = padded(encoded["input_ids"], pad_id)
        type_ids = padded(encoded["token_type_ids"], pad_type_id)
        return ids, type_ids

    def to_features(x, y):
        ids, type_ids = tf.py_function(
            func=tokenize_eagerly,
            inp=[x],
            Tout=[tf.int32, tf.int32],
        )
        # py_function outputs carry no static shape; reshape pins it.
        features = {
            'input_ids': tf.reshape(ids, [sequence_length]),
            'token_type_ids': tf.reshape(type_ids, [sequence_length]),
        }
        return features, y

    def ends_with_padding(x, y):
        # NOTE: This also filters out examples that are exactly the sequence length.
        last_id = x['input_ids'][-1]
        return last_id == pad_id

    featurized = ds.map(to_features)
    return featurized.filter(ends_with_padding) if ignore_too_long else featurized


###################################################################################


@dataclasses.dataclass
class EadDsStats:
    """Per-label example counts for a raw dataset, as built by `get_stats`."""

    # Number of examples whose label is 1.
    elementary_antiderivative_count: int
    # Number of examples whose label is 0.
    nonelementary_antiderivative_count: int
    # Number of examples whose label is -1 (no label).
    unlabeled_count: int

    @property
    def total_count(self) -> int:
        """Sum of the three per-label counts."""
        labeled = (
            self.elementary_antiderivative_count
            + self.nonelementary_antiderivative_count
        )
        return labeled + self.unlabeled_count


def get_stats(raw_ds: tf.data.Dataset) -> EadDsStats:
    """Count the examples in `raw_ds` by label.

    Args:
        raw_ds: Dataset of `(expression, label)` pairs whose label is -1
            (unlabeled), 0, or 1. Only the label is inspected.

    Returns:
        An EadDsStats in which label 1 is counted as
        `elementary_antiderivative_count`, label 0 as
        `nonelementary_antiderivative_count`, and label -1 as
        `unlabeled_count`.

    Raises:
        ValueError: If any label is not -1, 0, or 1.
    """
    ead_count = 0
    non_ead_count = 0
    unlabeled_count = 0
    # The expression is never used for the statistics, so it is not decoded.
    for _, y in raw_ds.as_numpy_iterator():
        if y == -1:
            unlabeled_count += 1
        elif y == 0:
            non_ead_count += 1
        elif y == 1:
            ead_count += 1
        else:
            raise ValueError(f'Unrecognized label: {y}')

    return EadDsStats(
        elementary_antiderivative_count=ead_count,
        nonelementary_antiderivative_count=non_ead_count,
        unlabeled_count=unlabeled_count,
    )
