import mlxu
from datasets import load_dataset
from typing import Optional, List
from absl import flags, app
import json
import numpy as np
import io
import pandas as pd
import requests
import tempfile
import random
import pyarrow as pa
import pyarrow.parquet as pq
import logging
import sys
import tqdm
from huggingface_hub import HfFileSystem
import functools
# Module logger writing to stdout. The handler is built first and then attached
# to a directly constructed ``logging.Logger`` (not ``logging.getLogger``), so
# the logger is not registered in the logging hierarchy.
LOGGER_HANDLER = logging.StreamHandler(sys.stdout)
LOGGER_HANDLER.setFormatter(
    logging.Formatter("[%(asctime)s] Parkiet [%(levelname)s] : %(message)s")
)
LOGGER = logging.Logger("Parquet dataset", level=logging.INFO)
LOGGER.addHandler(LOGGER_HANDLER)

FLAGS = flags.FLAGS
# Flags are registered only when this file is executed as a script, so merely
# importing the module does not define them.
if __name__ == "__main__":
    flags.DEFINE_string(
        "destination_dataset_path",
        None,
        "Output path prefix; '.<timestamp>.jsonl' is appended to form the file name.",
    )
    flags.DEFINE_string(
        "dataset",
        None,
        "Path listed through HfFileSystem to discover the input parquet files.",
    )
    flags.DEFINE_string(
        "source_field",
        None,
        "Parquet column whose value is saved under the 'text' key of each record.",
    )
    flags.DEFINE_multi_string(
        "fields_to_retain",
        [],
        "Additional parquet column copied into each record; may be repeated.",
    )
    # NOTE(review): `split` is forwarded to prepare_dataset but nothing in the
    # visible code reads it afterwards - confirm whether it is still needed.
    flags.DEFINE_string("split", None, "Dataset split name.")
    # NOTE(review): `API_TOKEN` is not read anywhere in the visible code.
    flags.DEFINE_string("API_TOKEN", None, "API token.")
    flags.DEFINE_integer(
        "num_tokens_to_load",
        50_000_000,
        "Approximate number of tokens to write before stopping.",
    )
    flags.DEFINE_float(
        "chars_per_token",
        4,
        "Assumed characters per token; multiplied by num_tokens_to_load to get the character budget.",
    )
    flags.DEFINE_integer(
        "drop_prefix",
        0,
        "Number of leading entries of the file listing to skip.",
    )

def get_ds_files_urls(hf_dataset: str):
    """Return the listing of ``hf_dataset`` from a fresh ``HfFileSystem``.

    The listing is requested with ``detail=False`` and returned unchanged.
    """
    return HfFileSystem().ls(hf_dataset, detail=False)

# Unused
def filter_out(response: dict, split: str):
    """Collect the ``url`` of every ``response["parquet_files"]`` entry for ``split``.

    Entries are kept in their original order; an entry is selected when its
    ``"split"`` value equals ``split``.
    """
    return [
        entry["url"]
        for entry in response["parquet_files"]
        if entry["split"] == split
    ]


class BuffferedFile:
    """Text sink that batches small writes before passing them to the file.

    Chunks given to ``write`` are collected in a list. Once ``buffer_size``
    chunks have accumulated they are concatenated, written with a single call
    and the underlying file is flushed.
    """

    def __init__(self, path, buffer_size=2000):
        # The file is opened for writing through mlxu before any state is set.
        self.file = mlxu.open_file(path, "w")
        self.buffer_size = buffer_size
        self.buffer = []

    def flush(self):
        """Write pending chunks (if any), then flush the underlying file."""
        if self.buffer:
            joined = "".join(self.buffer)
            # The buffer is emptied before the write is attempted.
            self.buffer = []
            self.file.write(joined)
        self.file.flush()

    def write(self, text: str):
        """Queue ``text``; flush once ``buffer_size`` chunks are pending."""
        self.buffer.append(text)
        if len(self.buffer) < self.buffer_size:
            return
        self.flush()

    def close(self):
        """Flush whatever is still pending and close the underlying file."""
        self.flush()
        self.file.close()


from datetime import datetime
import os

class DatasetWriter:
    """Write records as JSON lines and track how many text characters were saved.

    The output goes to ``<destination>.<timestamp>.jsonl`` through a
    ``BuffferedFile``. The writer counts the length of each record's ``"text"``
    value and reports itself full once that count reaches ``char_limit``.
    """

    def __init__(self, destination, char_limit):
        self.destination = destination
        self.char_limit = char_limit
        # The creation time (down to microseconds) is embedded in the file name.
        stamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")
        self.file = BuffferedFile(destination + f".{stamp}.jsonl")
        self.written_chars = 0
        self.progress_bar = tqdm.tqdm(total=self.char_limit)
        self.log_freq = 10000  # emit a sample line every `log_freq` calls to add()
        self.call_id = 0

    def add(self, data):
        """Serialize ``data`` as one JSON line and account for ``len(data["text"])``."""
        line = json.dumps(data) + "\n"
        self.file.write(line)

        n_chars = len(data["text"])
        self.written_chars += n_chars
        self.progress_bar.update(n_chars)

        if self.call_id % self.log_freq == 0:
            # Show the head and tail of the serialized record plus overall progress.
            head, tail = line[:32], line[-32:]
            percent = round(self.written_chars / self.char_limit * 100.0, 2)
            tqdm.tqdm.write(
                f"Saved {head}...{tail}  {percent}% {self.written_chars}/{self.char_limit}"
            )
        self.call_id += 1

    def is_full(self):
        """Return True once the written character count has reached the limit."""
        return not (self.written_chars < self.char_limit)

    def flush(self):
        """Flush the buffered output file."""
        self.file.flush()

    def finish(self):
        """Close the progress bar, then close the output file."""
        self.progress_bar.close()
        self.file.close()

    

    

#test = ["https://huggingface.co/datasets/tiiuae/falcon-refinedweb/resolve/refs%2Fconvert%2Fparquet/tiiuae--falcon-refinedweb/parquet-train-05533-of-05534.parquet", "https://huggingface.co/datasets/tiiuae/falcon-refinedweb/resolve/refs%2Fconvert%2Fparquet/tiiuae--falcon-refinedweb/parquet-train-05526-of-05534.parquet"]
        


def save_dicts(urls, source_field, writer: DatasetWriter, fields_to_retain: List[str]):
    """Copy rows from parquet files into ``writer`` until it reports being full.

    The files are visited in random order. Each file is read completely into
    memory, converted to a pandas DataFrame, and every row is written as a dict
    holding ``source_field`` under the key ``"text"`` plus each column named in
    ``fields_to_retain`` under its own name.

    Args:
        urls: Paths that ``HfFileSystem.open`` can open; the sequence itself is
            not modified.
        source_field: Name of the column stored as ``"text"``.
        writer: Destination; ``add`` is called per row, ``flush`` after each
            fully processed file.
        fields_to_retain: Names of additional columns to copy into each record.

    Raises:
        KeyError: If ``source_field`` or a retained field is not a column of a
            file's DataFrame.
    """
    LOGGER.info(f"In addition to text will retain fields {fields_to_retain}")
    shuffled_urls = list(urls)  # copy so the caller's sequence is not reordered
    random.shuffle(shuffled_urls)
    fs = HfFileSystem()
    column_names = [source_field, *fields_to_retain]
    for url in shuffled_urls:
        LOGGER.info(F"Processing {url}")

        with fs.open(url) as f:
            file = io.BytesIO(f.read())
        frame = pq.read_table(file).to_pandas()

        # Iterate the columns positionally. Going through DataFrame.to_dict()
        # keys every value by its index label, so rows sharing a label (e.g.
        # with an index restored from pandas metadata stored in the file)
        # would silently overwrite each other.
        columns = [frame[name].tolist() for name in column_names]
        for row in zip(*columns):
            to_save = {"text": row[0]}
            to_save.update(zip(fields_to_retain, row[1:]))
            writer.add(to_save)
            if writer.is_full():
                # Budget reached: stop without flushing here; the caller is
                # expected to finish/close the writer.
                return

        writer.flush()




            
            
def prepare_dataset(dataset: str, split: str, source_field: str, save_path: str, num_tokens_to_load: int, chars_per_token: float, drop_prefix: int = 0):
    """List the files of ``dataset`` and write a character-limited JSONL sample.

    Args:
        dataset: Path passed to ``get_ds_files_urls`` to obtain the file listing.
        split: Accepted for interface compatibility; not used by this function.
        source_field: Column saved as ``"text"`` in every record.
        save_path: Destination prefix handed to ``DatasetWriter``.
        num_tokens_to_load: Token budget for the output.
        chars_per_token: Factor converting the token budget into the character
            limit given to the writer.
        drop_prefix: Number of leading entries of the listing to skip.

    The additional columns to keep are read from ``FLAGS.fields_to_retain``.
    """
    LOGGER.info(F"Gathering info about files from {dataset}")
    remote_files = get_ds_files_urls(hf_dataset=dataset)
    LOGGER.info(F"Got {len(remote_files)} files from {dataset}, truncating to {drop_prefix}")
    remote_files = remote_files[drop_prefix:]
    LOGGER.info(F"Those files are {remote_files}")

    char_limit = num_tokens_to_load * chars_per_token
    message = F"Creating writer {save_path} for {char_limit} chars that is {num_tokens_to_load} tokens assuming chars_per_token={chars_per_token}"
    LOGGER.info(message)
    print(message)
    writer = DatasetWriter(save_path, char_limit)
    try:
        save_dicts(remote_files, source_field, writer, fields_to_retain=FLAGS.fields_to_retain)
    finally:
        # Always close the writer so buffered rows reach the file and the
        # handle is released even when reading a source file fails.
        writer.finish()
    
    


    






def main(_):
    """Forward the parsed command-line flags to ``prepare_dataset``.

    NOTE(review): no ``app.run(main)`` call is visible in this part of the
    file - confirm the entry point is invoked somewhere.
    """
    prepare_dataset(
        dataset=FLAGS.dataset,
        split=FLAGS.split,
        source_field=FLAGS.source_field,
        save_path=FLAGS.destination_dataset_path,
        num_tokens_to_load=FLAGS.num_tokens_to_load,
        chars_per_token=FLAGS.chars_per_token,
        drop_prefix=FLAGS.drop_prefix,
    )

