#! /usr/bin/env python3
# coding=utf-8


import ast
import hashlib
import logging
import os
import pickle

import pandas as pd
from torch.utils.data import Dataset
from tqdm import tqdm

from configs.rl_config import RLConfig

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class AllsidesDataset(Dataset):
    """Map-style dataset of tokenized source/target pairs read from a CSV.

    Each CSV row is expected to provide ``sent1`` (source text), ``sent2``
    (one target sentence, or the Python-literal text of a list of target
    sentences) and ``spice_score``. Encoded features are pickled under
    ``config.output_dir/config.cache_dir`` and reused on later runs unless
    ``config.overwrite_cache`` is set.
    """

    def __init__(self, tokenizer, df_path, config=None):
        """
        :param tokenizer: LM tokenizer exposing
            ``encode(text, max_length=..., padding=..., truncation=...)``
        :param df_path: path to the csv file
        :param config: optional config object; defaults to ``RLConfig()``.
            It must provide ``output_dir``, ``cache_dir``, ``overwrite_cache``,
            ``max_source_length`` and ``max_target_length``.
        """
        self.config = RLConfig() if config is None else config

        # The cache file name depends on the CSV path and the encoding lengths,
        # so different splits / settings never share (and clobber) one cache.
        _features_path = self._features_cache_path(df_path)

        if os.path.exists(_features_path) and not self.config.overwrite_cache:
            self.features = self.load_features(_features_path)
        else:
            logger.info("Creating features from dataset file at %s", df_path)
            _df = pd.read_csv(df_path, header=0)
            self.features = self.encode_input(_df, tokenizer)
            self.save_features(_features_path)

    def _features_cache_path(self, df_path):
        """Return the pickle path used to cache the features of ``df_path``.

        NOTE(review): the tokenizer is not part of the cache key; set
        ``overwrite_cache`` after switching tokenizers.
        """
        cache_dir = os.path.join(self.config.output_dir, self.config.cache_dir)
        stem = os.path.splitext(os.path.basename(df_path))[0]
        key = "{}|{}|{}".format(
            os.path.abspath(df_path),
            self.config.max_source_length,
            self.config.max_target_length,
        )
        digest = hashlib.sha256(key.encode("utf-8")).hexdigest()[:12]
        return os.path.join(cache_dir, "features_{}_{}".format(stem, digest))

    @staticmethod
    def load_features(feature_path):
        """Unpickle and return features previously written by ``save_features``."""
        logger.info("Loading features from cached file %s", feature_path)
        # pickle is only acceptable here because the file is a locally
        # produced cache; never point this at untrusted data.
        with open(feature_path, "rb") as handle:
            return pickle.load(handle)

    def save_features(self, feature_path):
        """Pickle ``self.features`` to ``feature_path``, creating parent dirs."""
        logger.info("Saving features into cached file %s", feature_path)
        cache_dir = os.path.dirname(feature_path)
        if cache_dir:
            # The cache directory may not exist yet on a fresh output_dir.
            os.makedirs(cache_dir, exist_ok=True)
        with open(feature_path, "wb") as handle:
            pickle.dump(self.features, handle, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def _as_sentence_list(value):
        """Normalise a ``sent2`` cell into a list of target sentences.

        ``pd.read_csv`` yields plain strings, so several references can only
        arrive as the Python-literal text of a list (e.g. ``"['a', 'b']"``),
        which is parsed with ``ast.literal_eval``. Any other string is a single
        target sentence. A list/tuple value is used as-is.
        """
        if isinstance(value, (list, tuple)):
            return list(value)
        if isinstance(value, str):
            text = value.strip()
            if text.startswith(("[", "(")):
                try:
                    parsed = ast.literal_eval(text)
                except (ValueError, SyntaxError, TypeError):
                    parsed = None
                if isinstance(parsed, (list, tuple)):
                    return list(parsed)
        return [value]

    def encode_input(self, df, tokenizer):
        """Tokenize every row of ``df`` into a feature dict.

        :param df: DataFrame with ``sent1``, ``sent2`` and ``spice_score`` columns
        :param tokenizer: tokenizer used for both source and targets
        :return: list of dicts with keys ``input_ids``, ``input_tokens``,
            ``labels_ids`` (one padded id list per target sentence),
            ``labels_tokens`` (the raw ``sent2`` cell) and ``spice_scores``
        """
        features = []
        for _, row in tqdm(df.iterrows(), total=len(df)):
            input_ids = tokenizer.encode(row['sent1'],
                                         max_length=self.config.max_source_length,
                                         padding='max_length',
                                         truncation=True)
            # Iterate over sentences, not over the characters of the CSV string.
            targets = self._as_sentence_list(row['sent2'])
            labels = [tokenizer.encode(sent,
                                       max_length=self.config.max_target_length,
                                       padding='max_length',
                                       truncation=True) for sent in targets]
            features.append({'input_ids': input_ids,
                             'input_tokens': row['sent1'],
                             'labels_ids': labels,
                             'labels_tokens': row['sent2'],
                             'spice_scores': row['spice_score']})
        return features

    def __len__(self):
        return len(self.features)

    def __getitem__(self, item):
        return self.features[item]
