import mlxu
from datetime import datetime
import tqdm
import json
from typing import List


class LineFile:
    """In-memory stand-in for a readable text file, backed by a list of lines.

    Only ``readline`` and ``close`` are provided. Unlike a real file object,
    ``readline`` signals exhaustion with ``None`` rather than ``""``.
    """

    def __init__(self, line_list: List[str]):
        # Lines served by ``readline``, in order.
        self.line_list = line_list
        # Index of the next line to hand out.
        self.pointer = 0

    def readline(self):
        """Return the next stored line, or ``None`` once all lines were read."""
        idx = self.pointer
        if idx != len(self.line_list):
            line = self.line_list[idx]
            self.pointer = idx + 1
            return line
        return None

    def close(self):
        """Do nothing; there is no underlying resource to release."""
        return None


class BuffferedFile:
    """Write-only text file that batches ``write`` calls in memory.

    Strings passed to ``write`` are collected in a list. They are written to
    the underlying file as one concatenated string once ``buffer_size`` pieces
    have accumulated, or when ``flush``/``close`` is called. Instances can
    also be used as a context manager, which closes the file on exit.

    NOTE: the class name's spelling ("Bufffered") is kept because other code
    may import it under this name.
    """

    def __init__(self, path, buffer_size=2000):
        # Opened for writing through mlxu.open_file with mode "w".
        self.file = mlxu.open_file(path, "w")
        # Pending strings not yet handed to self.file.
        self.buffer = []
        # Number of pending strings that triggers an automatic flush.
        self.buffer_size = buffer_size

    def flush(self):
        """Write any buffered text to the file, then flush the file itself."""
        if self.buffer:
            text = "".join(self.buffer)
            # Clear before writing so a failed write is not retried with
            # duplicated content on the next flush.
            self.buffer = []
            self.file.write(text)
        self.file.flush()

    def write(self, text: str):
        """Queue ``text``; flush automatically once ``buffer_size`` items are pending."""
        self.buffer.append(text)
        if len(self.buffer) >= self.buffer_size:
            self.flush()

    def close(self):
        """Flush remaining buffered text and close the underlying file.

        The file is closed even if the final flush raises. The exception is
        still propagated to the caller.
        """
        try:
            self.flush()
        finally:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()


class DatasetWriter:
    """Context manager that streams records to a timestamped JSONL file.

    On entry it opens ``<destination>.<YYYYmmdd-HHMMSSffffff>.jsonl`` through
    ``BuffferedFile`` and creates a tqdm progress bar. The bar is measured in
    characters of each record's ``text_field``. ``add`` appends one JSON
    object per line. ``is_full`` reports whether the expected character
    budget has been reached.

    Args:
        destination: Path prefix (a string) for the output file.
        expected_chars: Target number of characters. It is used as the
            progress-bar total and as the ``is_full`` threshold. If ``None``,
            ``is_full`` always returns ``False``.
        text_field: Key of the record field whose length counts toward progress.
        log_freq: Print a sample of every ``log_freq``-th record, starting
            with the first. Must be >= 1.

    Raises:
        ValueError: If ``log_freq`` is less than 1.
    """

    def __init__(
        self,
        destination,
        expected_chars,
        text_field: str = "text",
        log_freq: int = 10000,
    ):
        if log_freq < 1:
            # Used as a modulus in ``add``; 0 would raise ZeroDivisionError there.
            raise ValueError(f"log_freq must be >= 1, got {log_freq}")
        self.destination = destination
        self.expected_chars = expected_chars
        self.text_field = text_field
        self.log_freq = log_freq
        # Microsecond-resolution stamp keeps repeated runs from colliding on
        # the same output path.
        self.time_stamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")

    def __enter__(self):
        self.stamped_destination = self.destination + f".{self.time_stamp}.jsonl"
        print(f"Will write result to {self.stamped_destination}")
        self.file = BuffferedFile(self.stamped_destination)

        self.written_chars = 0
        self.progress_bar = tqdm.tqdm(total=self.expected_chars)
        self.call_id = 0

        return self

    @staticmethod
    def _format_percent(written, expected):
        """Return ``written/expected`` as a percentage string such as ``"12.5%"``.

        Returns ``"n/a"`` when ``expected`` is 0 or ``None``, so logging never
        divides by zero or by ``None``.
        """
        if not expected:
            return "n/a"
        return f"{round(written / expected * 100.0, 2)}%"

    def add(self, data):
        """Append ``data`` as one JSON line and advance the progress bar.

        Args:
            data: JSON-serializable mapping containing ``self.text_field``.
                The ``len`` of that field is added to the written-character
                count.
        """
        update = len(data[self.text_field])
        dump = json.dumps(data) + "\n"

        self.file.write(dump)

        self.written_chars += update
        self.progress_bar.update(update)
        if self.call_id % self.log_freq == 0:
            # tqdm.write prints a message without corrupting the progress bar.
            percent = self._format_percent(self.written_chars, self.expected_chars)
            tqdm.tqdm.write(
                f"Saved {dump[:32]}...{dump[-32:]}  {percent} "
                f"{self.written_chars}/{self.expected_chars}"
            )
        self.call_id += 1

    def is_full(self):
        """Return True once at least ``expected_chars`` characters were written.

        Always returns False when ``expected_chars`` is ``None``.
        """
        if self.expected_chars is None:
            return False
        return self.written_chars >= self.expected_chars

    def flush(self):
        """Push buffered records through to the output file."""
        self.file.flush()

    def __exit__(self, *args):
        # Close both resources even if one close raises. Report success only
        # after the file has actually been flushed and closed.
        try:
            self.progress_bar.close()
        finally:
            self.file.close()
        print(f"Data written to {self.stamped_destination}")
