import logging

import torch


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)

class ParacrawlDataset(torch.utils.data.Dataset):
    """Map-style dataset over a tab-separated parallel (bitext) file.

    Each line of the file is expected to hold exactly two tab-separated
    fields: the source segment and the target segment. Lines with any other
    number of fields are skipped. The whole file is read into memory when
    the dataset is constructed.

    Items are returned in the nested ``{'translation': {lang: text, ...}}``
    layout, keyed by the configured source and target language codes.
    """

    def __init__(self, file_path, src_lang, trg_lang, path=None, split=None):
        """Load all segment pairs from ``file_path``.

        Args:
            file_path: Path of the tab-separated bitext file to read.
            src_lang: Key used for the source segment in returned items.
            trg_lang: Key used for the target segment in returned items.
            path: Unused; kept for Hugging Face compatibility.
            split: Unused; kept for Hugging Face compatibility.

        Raises:
            OSError: If the file cannot be opened.
            UnicodeDecodeError: If the file is not valid UTF-8.
        """
        self.file_path = file_path
        self.src_lang = src_lang
        self.trg_lang = trg_lang
        self.path = path  # dummy variable for Hugging Face 🤗 compatibility reasons
        self.split = split  # dummy variable for Hugging Face 🤗 compatibility reasons

        logger.info('Loading Paracrawl dataset from %s', file_path)
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default locale encoding.
        with open(self.file_path, encoding='utf-8') as f:
            self.src_segments, self.trg_segments = self.__parse_file(f)

    def __parse_file(self, file):
        """Split an iterable of lines into parallel source/target lists.

        Args:
            file: Iterable yielding text lines (e.g. an open text file).

        Returns:
            A ``(src_sentences, trg_sentences)`` tuple of equally long lists.
        """
        src_sentences = []
        trg_sentences = []
        for line in file:
            # Drop the line terminator first; otherwise it would stay
            # attached to the last field, i.e. the target segment.
            fields = line.rstrip('\r\n').split('\t')
            if len(fields) != 2:
                # Malformed line (missing or extra tab): skip it.
                continue
            src_sentences.append(fields[0])
            trg_sentences.append(fields[1])
        return src_sentences, trg_sentences

    def __getitem__(self, i):
        """Return the ``i``-th segment pair as a nested translation dict."""
        src_segment = self.src_segments[i]
        trg_segment = self.trg_segments[i]
        return {'translation': {self.src_lang: src_segment, self.trg_lang: trg_segment}}

    def __len__(self):
        """Return the number of loaded segment pairs."""
        return len(self.src_segments)
