import functools
import io
import json
import os
import random
from datetime import datetime
from queue import Queue
from typing import Optional

import jsonlines
import mlxu
from absl import app, flags
from datasets import load_dataset
from google.cloud import storage
from retriv import SearchEngine
import logging
from typing import List, Dict
import sys
import tqdm
from .utils import MultiLogger
import string 
from EasyLM.models.llama.llama_model import LLaMATokenizer
from joblib import Parallel, delayed

# Module-wide logging setup.
# A standalone logging.Logger is instantiated directly (it is not fetched via
# logging.getLogger) at INFO level and given a stderr handler whose format
# prefixes every message with a timestamp, the "TSPPGatherer" tag and the level.
LOGGER = logging.Logger("Parquet dataset", level=logging.INFO)
LOGGER_HANDLER = logging.StreamHandler(sys.stderr)
LOGGER_HANDLER.setFormatter(logging.Formatter("[%(asctime)s] TSPPGatherer [%(levelname)s] : %(message)s"))
LOGGER.addHandler(LOGGER_HANDLER)

# Rebind LOGGER to the project's MultiLogger wrapper, handing it both `print`
# and the logger configured above. Every LOGGER.info / LOGGER.debug call in
# this file goes through this wrapper.
# NOTE(review): MultiLogger lives in .utils and is not visible here; how it
# dispatches to "basic" vs. "advanced" loggers should be confirmed there.
LOGGER = MultiLogger(basic_loggers=[print], advanced_loggers=[LOGGER])


# Command-line flags. They are only registered when the file is executed as a
# script; main() reads them through the module-level FLAGS object.
if __name__ == "__main__":
    # Where the input comes from: "hf" loads via datasets.load_dataset, "file"
    # reads a JSON-lines file through mlxu.open_file.
    flags.DEFINE_string("source_dataset_mode", "hf", "hf|file")
    # Dataset name (hf mode) or path of the JSON-lines file (file mode).
    flags.DEFINE_string("source_dataset", "bigcode/starcoderdata", "")
    # Passed as `data_dir` to load_dataset; also used in the output file name.
    flags.DEFINE_string("source_dataset_subset", None, "")
    # Passed as `split` to load_dataset.
    flags.DEFINE_string("source_dataset_split", None, "")
    # Output location prefix; main() refuses to run without it.
    flags.DEFINE_string("destination_dataset_path", None, "")
    # Name of the record field that holds the document text.
    flags.DEFINE_string("content_field", "content", "")

    # Lower length bound, applied both to input records and to merged documents.
    flags.DEFINE_integer("min_len_in_chars", 128, "")
    # Retrieval strategy; the listed values are the ones get_retriever accepts.
    flags.DEFINE_string("aggregate_mode", "bm25", "bm25|repo|repocatalog|random|plain")
    # Maximum number of unmarked neighbours considered per visited document.
    flags.DEFINE_integer("topk", 3, "")
    # If set, only this many of the (shuffled) neighbours are kept.
    flags.DEFINE_integer("topk_sample", None, "")
    # Emit the documents of each tree in reverse visiting order.
    flags.DEFINE_boolean("reverse_tree", True, "")
    # Upper length bound for a single input record.
    flags.DEFINE_integer("prefilter_max_doc_chars", 30_000, "")
    # Upper bound on the summed length of the documents merged into one output.
    flags.DEFINE_integer("max_doc_chars", 120_000, "")
    # max_tokens * chars_per_token is the character budget for the whole output.
    flags.DEFINE_integer("max_tokens", 100_000_000, "")
    flags.DEFINE_integer("chars_per_token", 4, "")
    # Compute the neighbours of every document once, up front.
    flags.DEFINE_boolean("precompute", True, "")
    # Stop once more than this fraction of the documents has been used.
    flags.DEFINE_float("stop_fraction", 0.8, "")
    # Shuffle the documents of each tree before merging.
    flags.DEFINE_boolean("random_shuffle_tree", False, "")
    # If set, stop taking input records once this many characters were kept.
    flags.DEFINE_integer("take_char_prefix", None, "")
    # Vocabulary file for the LLaMA tokenizer used by the BM25 retriever.
    flags.DEFINE_string("tokenizer_path", None, "")
    FLAGS = flags.FLAGS

# Point the Hugging Face datasets cache at /dev/shm.
# NOTE(review): `datasets` is already imported at the top of this file; if the
# library reads HF_DATASETS_CACHE at import time, this assignment comes too
# late to have any effect — confirm against the datasets configuration docs.
os.environ["HF_DATASETS_CACHE"] = "/dev/shm"
# Seed the global `random` module; the shuffles and sampling below all use it.
random.seed(2137)

# Instantiate a Google Cloud Storage client.
# NOTE: this (and the bucket lookup below) runs at import time.
storage_client = storage.Client()

# Name of the GCS bucket that save_jsonl_to_gcs uploads to.
bucket_name = "focused-llama"

# Get the bucket from the client
bucket = storage_client.get_bucket(bucket_name)


# def process_chunk(df):
#     LOGGER.info("Num rows before filtering: " + str(len(df)))
#     LOGGER.info("Starting processing")
#     df = df[df[FLAGS.content_field].str.len() >= FLAGS.min_len_in_chars]
#     LOGGER.info("Filtered by length")
#     repo_level_code = df.to_dict(orient="records")

#     return [
#         {
#             "repo_name": dct["max_stars_repo_name"],
#             "repo_path": dct["max_stars_repo_path"],
#             "text": dct[FLAGS.content_field],
#         }
#         for dct in repo_level_code
#     ]


def filter_out(list_of_records: List[Dict[str, str]], min_len_in_chars: int, max_len_in_chars: int, take_char_prefix: Optional[int]):
    """Keep the records whose text length lies within the given bounds.

    Args:
        list_of_records: records carrying a ``"text"`` entry.
        min_len_in_chars: smallest accepted text length (inclusive).
        max_len_in_chars: largest accepted text length (inclusive).
        take_char_prefix: when not ``None``, scanning stops as soon as the
            kept records together hold at least this many characters.

    Returns:
        The accepted records, in their original order.
    """
    LOGGER.info(f"Before filtering has {len(list_of_records)} records")

    kept = []
    chars_kept = 0
    budget_enabled = take_char_prefix is not None

    for record in list_of_records:
        text_len = len(record["text"])
        if not (min_len_in_chars <= text_len <= max_len_in_chars):
            continue

        kept.append(record)
        chars_kept += text_len
        # The record that crosses the budget is still included.
        if budget_enabled and chars_kept >= take_char_prefix:
            break

    LOGGER.info(
        f"After filtering has {len(kept)} records and {chars_kept} chars." f"That is around {chars_kept/1e9} GB"
    )
    return kept


def save_jsonl_to_gcs(data, DS_SUBSET):
    """Upload ``data`` as a JSON-lines blob to the module-level GCS ``bucket``.

    The items are serialised into an in-memory buffer with ``jsonlines`` and
    uploaded under ``datasets/semi-synthetic/<DS_SUBSET>-<timestamp>.jsonl``.

    Returns:
        The name of the blob that was written.
    """
    buffer = io.StringIO()
    with jsonlines.Writer(buffer) as jsonl_writer:
        for record in data:
            jsonl_writer.write(record)

    # The timestamp (down to microseconds) makes the blob name unique per call.
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")
    filename = f"datasets/semi-synthetic/{DS_SUBSET}-{stamp}.jsonl"

    target_blob = bucket.blob(filename)
    target_blob.upload_from_string(buffer.getvalue())

    buffer.close()
    return filename


def save_jsonl_efficient(data, path: str):
    """Write every item of ``data`` as one JSON line to ``path``.

    The destination is opened through ``mlxu.open_file`` in text-write mode and
    items are streamed out one at a time.
    """
    with mlxu.open_file(path, "w") as sink:
        for record in data:
            serialized = json.dumps(record)
            sink.write(serialized + "\n")


def remove_random_element(lst):
    """Remove and return a uniformly chosen element of ``lst`` in O(1).

    The chosen slot is overwritten with the last element before the list is
    shortened, so the order of the remaining elements is not preserved.
    """
    pick = random.randrange(len(lst))
    chosen = lst[pick]

    # Move the tail element into the hole, then drop the now-redundant tail.
    lst[pick] = lst[-1]
    lst.pop()
    return chosen


class IndexedData:
    """Document store that assigns positional ids and tracks "marked" documents.

    Each input document is copied into a dict with the keys ``id`` (its
    position in the input), ``text``, ``repo_name`` and ``repo_path`` (the last
    two default to ``None`` when absent).
    """

    def __init__(self, data) -> None:
        records = []
        for position, doc in enumerate(data):
            records.append(
                {
                    "id": position,
                    "text": doc["text"],
                    "repo_name": doc.get("repo_name", None),
                    "repo_path": doc.get("repo_path", None),
                }
            )
        self.data = records

        # Ids of the documents that are currently marked.
        self.marked = set()
        # Filled by precompute_nn: one neighbour list per document, by id.
        self.precomputed = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i: int):
        return self.data[i]

    def get_marked_fraction(self):
        """Return the share of documents that are currently marked."""
        total = len(self.data)
        return len(self.marked) / total

    def mark_data(self, data: dict):
        """Mark a document; it must not be marked already."""
        doc_id = data["id"]
        assert doc_id not in self.marked
        self.marked.add(doc_id)

    def is_data_marked(self, data: dict):
        """Tell whether the given document is marked."""
        return data["id"] in self.marked

    def unmark_data(self, data: dict):
        """Remove the mark of a document (KeyError if it was not marked)."""
        self.marked.remove(data["id"])

    def iter_unmarked(self):
        """Lazily yield the documents that are unmarked when they are reached."""
        for record in self.data:
            if record["id"] not in self.marked:
                yield record

    def precompute_nn(self, retriever, cutoff):
        """Store the result of one batch search over all documents."""
        self.precomputed = retriever.bsearch(node_list=self.data, cutoff=cutoff)

    def get_precomputed_nn(self, node):
        """Return the neighbours stored by precompute_nn for ``node``."""
        assert self.precomputed is not None
        doc_id = node["id"]
        # The id doubles as the index into both self.data and self.precomputed.
        assert self.data[doc_id]["id"] == doc_id
        return self.precomputed[doc_id]


class RetrieverBase:
    """Common interface of all retrievers operating on an ``IndexedData``."""

    def __init__(self, indexed_data: IndexedData):
        self.indexed_data = indexed_data

    def search(self, node: dict, cutoff: int):
        """Return up to ``cutoff`` neighbours of ``node``; subclasses implement it."""
        raise NotImplementedError

    def bsearch(self, node_list: List[dict], cutoff: int) -> List[List[dict]]:
        """Run ``search`` for every node, returning one result list per node."""
        return [self.search(node=query_node, cutoff=cutoff) for query_node in node_list]
    


# Digits followed by lowercase letters: the 36 symbols used to spell token ids.
_TOKEN_ID_ALPHABET = string.digits + string.ascii_lowercase


def _encode_token_id(number: int) -> str:
    """Spell a non-negative integer in base 36 (digits ``0-9`` then ``a-z``)."""
    assert number >= 0
    if number == 0:
        return "0"

    base = len(_TOKEN_ID_ALPHABET)
    digits = []
    while number > 0:
        number, remainder = divmod(number, base)
        digits.append(_TOKEN_ID_ALPHABET[remainder])
    # Digits were produced least-significant first.
    return "".join(reversed(digits))


def tokenizer_wrapper(record: dict, tokenizer):
    """Replace a record's text by a space-separated spelling of its token ids.

    The text is tokenized with ``tokenizer.encode`` and every token id is
    written in base 36, so that each token becomes one whitespace-delimited
    "word" made of digits and lowercase letters.

    Args:
        record: mapping with the keys ``"id"`` and ``"text"`` (a ``str``).
        tokenizer: object whose ``encode(text)`` returns a list of token ids.

    Returns:
        A new dict holding the original ``"id"`` and the encoded ``"text"``;
        any other key of ``record`` is dropped.

    Raises:
        AssertionError: if the text is not a ``str``, ``encode`` does not
            return a ``list``, or a token id is negative.
    """
    text = record["text"]
    assert isinstance(text, str)

    tokens = tokenizer.encode(text)
    assert isinstance(tokens, list)

    encoded_tokens = " ".join(_encode_token_id(int(token)) for token in tokens)
    return {"id": record["id"], "text": encoded_tokens}



class BM25Retriever(RetrieverBase):
    """BM25 retriever backed by a retriv ``SearchEngine``.

    When ``tokenizer_path`` is given, documents and queries are first passed
    through ``tokenizer_wrapper`` with a LLaMA tokenizer, so that BM25 works on
    token ids instead of raw text; otherwise the raw text is indexed.
    """

    def __init__(self, indexed_data: IndexedData, tokenizer_path: Optional[str]):
        super().__init__(indexed_data)

        if tokenizer_path is not None:
            self.tokenizer = functools.partial(
                tokenizer_wrapper,
                tokenizer=LLaMATokenizer(
                    vocab_file=tokenizer_path,
                    add_bos_token=False,
                    add_eos_token=False,
                ),
            )
            delayed_tokenizer = delayed(self.tokenizer)
            pretokenized_indexed_data = Parallel(n_jobs=32)(delayed_tokenizer(d) for d in indexed_data)
        else:
            self.tokenizer = None
            # Only id and text are handed to the search engine.
            pretokenized_indexed_data = [{"id": d["id"], "text": d["text"]} for d in indexed_data]

        self.se = SearchEngine(model="bm25").index(pretokenized_indexed_data)

    def search(self, node: dict, cutoff: int):
        """Return up to ``cutoff`` records of ``indexed_data`` most similar to ``node``.

        The query is tokenized the same way as the indexed documents, and each
        hit is mapped back to the corresponding ``indexed_data`` record — the
        same kind of result ``bsearch`` produces. The mapped records do not
        carry the BM25 score.
        """
        query = {"id": 0, "text": node["text"]}
        if self.tokenizer is not None:
            # tokenizer_wrapper works on whole records, not on bare strings.
            query = self.tokenizer(query)

        hits = self.se.search(query["text"], cutoff=cutoff)
        # Each hit is a dict carrying the id of the matched document.
        return [self.indexed_data[hit["id"]] for hit in hits]

    def bsearch(self, node_list: List[dict], cutoff: int):
        """Batch variant of ``search``: one ranked record list per input node."""
        LOGGER.info("preparing")
        # Queries are numbered by their position in node_list.
        queries = [{"id": i, "text": node["text"]} for i, node in enumerate(node_list)]
        if self.tokenizer is not None:
            delayed_tokenizer = delayed(self.tokenizer)
            queries = Parallel(n_jobs=32)(delayed_tokenizer(q) for q in queries)
        result_dict = self.se.bsearch(queries=queries, cutoff=cutoff, batch_size=8192)
        LOGGER.info("got data")

        all_result = []
        for query in queries:
            qdocs = []
            previous_score = None
            for rdoc_id, rdoc_score in result_dict[query["id"]].items():
                # Results are expected to arrive sorted by decreasing score.
                if previous_score is not None:
                    assert rdoc_score <= previous_score
                previous_score = rdoc_score

                rdoc = self.indexed_data[rdoc_id]
                assert rdoc["id"] == rdoc_id
                qdocs.append(rdoc)

            all_result.append(qdocs)

        return all_result


def get_repo_path_with_id(data: dict):
    """Return the record's ``repo_path`` with ``#<id>`` appended."""
    repo_path = data["repo_path"]
    id_suffix = "#{}".format(data["id"])
    return repo_path + id_suffix


class RepoRetrieverBase(RetrieverBase):
    """Retriever that returns random unmarked files from the query's repository.

    All documents are arranged in a nested dict ("catalog") keyed by repository
    name and then by the components of ``repo_path#id``. Directory dicts carry
    a marker key; leaves are the document records themselves.
    """

    # Key present in every directory dict; a dict without it is a document.
    _CATALOG_MARKER = "__/$/catalog"

    def __init__(self, indexed_data: IndexedData):
        super().__init__(indexed_data)
        self.path_to_data = {RepoRetrieverBase._CATALOG_MARKER: True}  # this should be a hash map
        for d in indexed_data:
            RepoRetrieverBase.append_to_catalog(d, self.path_to_data)

        # The marker key is not a repository, so leave it out of the count.
        repo_count = len(self.path_to_data) - 1
        LOGGER.debug(
            f"RepoRetriever: {repo_count} repos, {len(indexed_data)} files; {len(indexed_data)/max(repo_count, 1)}"
        )

    @staticmethod
    def append_to_catalog(data: dict, catalog_root: dict):
        """Insert ``data`` into the catalog under ``repo_name/repo_path#id``."""
        repo = data["repo_name"]
        path = get_repo_path_with_id(data).split("/")
        full_path = [repo] + path
        assert len(full_path) > 1

        catalog = catalog_root
        for directory in full_path[:-1]:
            if directory not in catalog:
                catalog[directory] = {RepoRetrieverBase._CATALOG_MARKER: True}
            catalog = catalog[directory]

        leaf_name = full_path[-1]
        # The "#id" suffix makes the leaf name unique within its directory.
        assert leaf_name not in catalog
        catalog[leaf_name] = data

    @staticmethod
    def retrieve_repo_dfs(data: dict, catalog_root: dict, indexed_data: Optional[IndexedData] = None):
        """List the files of ``data``'s repository in depth-first order.

        Args:
            data: record whose ``repo_name`` selects the repository.
            catalog_root: catalog built with ``append_to_catalog``.
            indexed_data: when given, marked documents are left out.

        Returns:
            ``(files, path_to_range)`` where ``files`` is the list of records
            and ``path_to_range`` maps every path inside the repository (``""``
            for its root) to the half-open index range its files occupy in
            ``files``.
        """

        def dfs(node: dict, result: list, current_path: str, path_to_range: dict):
            if RepoRetrieverBase._CATALOG_MARKER not in node:
                # Leaf: a document record.
                assert "id" in node and "text" in node
                beg = len(result)
                if indexed_data is None or (not indexed_data.is_data_marked(node)):
                    result.append(node)

                # An excluded (marked) document gets an empty range.
                path_to_range[current_path] = (beg, len(result))
            else:
                path_begin = len(result)
                for k, v in node.items():
                    if k != RepoRetrieverBase._CATALOG_MARKER:
                        next_path = current_path + "/" + k if current_path != "" else k
                        dfs(
                            v,
                            result,
                            current_path=next_path,
                            path_to_range=path_to_range,
                        )

                path_end = len(result)
                path_to_range[current_path] = (path_begin, path_end)

        result = []
        path_to_range = {}
        repo = data["repo_name"]
        dfs(
            node=catalog_root[repo],
            result=result,
            current_path="",
            path_to_range=path_to_range,
        )
        return result, path_to_range

    def search(self, node: dict, cutoff: int):
        """Return up to ``cutoff`` unmarked files of ``node``'s repository, shuffled."""
        repo_files, path_to_range = RepoRetrieverBase.retrieve_repo_dfs(node, self.path_to_data, self.indexed_data)
        random.shuffle(repo_files)
        result = repo_files[:cutoff]
        LOGGER.debug(f"RepoRetrieverBase retrieved {len(result)} files from {node['repo_name']}")
        return result


class RepoRetrieverCatalog(RepoRetrieverBase):
    """Repository retriever that prefers files close to the query in the tree.

    Unmarked files of the query's repository are returned directory by
    directory, starting at the query's own location and widening towards the
    repository root; the files added at each level are shuffled among
    themselves.
    """

    def __init__(self, indexed_data: IndexedData):
        super().__init__(indexed_data)

    def search(self, node: dict, cutoff: int):
        repo_files, path_to_range = RepoRetrieverBase.retrieve_repo_dfs(node, self.path_to_data, self.indexed_data)

        components = get_repo_path_with_id(node).split("/")

        # Start from an empty range at the node's own offset: the first pass of
        # the loop then picks up the node itself whenever it is still unmarked.
        inner_start = path_to_range["/".join(components)][0]
        inner_end = inner_start

        ordered = []
        depth = len(components)
        while depth >= 0:
            # depth == 0 gives the empty path, i.e. the repository root.
            enclosing_path = "/".join(components[:depth])
            outer_start, outer_end = path_to_range[enclosing_path]

            # Files of the enclosing level that were not covered so far.
            newly_covered = repo_files[outer_start:inner_start] + repo_files[inner_end:outer_end]
            random.shuffle(newly_covered)
            ordered.extend(newly_covered)

            inner_start, inner_end = outer_start, outer_end
            depth -= 1

        assert len(ordered) == len(repo_files)
        closest = ordered[:cutoff]
        LOGGER.debug(
            f"RepoRetrieverCatalog retrieved {len(closest)} files from {node['repo_name']} with {len(repo_files)} unmarked files"
        )
        return closest


class RandomRetriever(RetrieverBase):
    """Retriever that ignores the query and returns random documents."""

    def __init__(self, indexed_data: IndexedData):
        super().__init__(indexed_data)

    def search(self, node: dict, cutoff: int):
        """Return ``cutoff`` distinct random documents (fewer if the corpus is smaller).

        Marked documents and ``node`` itself are not excluded.
        """
        population = len(self.indexed_data)
        # random.sample raises ValueError when asked for more than it has.
        sample_size = min(cutoff, population)
        picked_ids = random.sample(range(population), sample_size)
        return [self.indexed_data[doc_id] for doc_id in picked_ids]


class EmptyRetriever(RetrieverBase):
    """Retriever that never finds any neighbour."""

    def __init__(self, indexed_data: IndexedData):
        super().__init__(indexed_data)

    def search(self, node: dict, cutoff: int):
        """Return a fresh empty list regardless of ``node`` and ``cutoff``."""
        no_neighbours = []
        return no_neighbours


def get_retriever(mode, indexed_data: IndexedData, tokenizer_path: Optional[str]):
    """Build the retriever selected by ``mode``.

    Supported modes are ``bm25``, ``repo``, ``repocatalog``, ``random`` and
    ``plain``; only ``bm25`` makes use of ``tokenizer_path``.

    Raises:
        ValueError: for any other mode.
    """
    # Factories are wrapped in lambdas so that only the chosen one is built.
    factories = (
        ("bm25", lambda: BM25Retriever(indexed_data=indexed_data, tokenizer_path=tokenizer_path)),
        ("repo", lambda: RepoRetrieverBase(indexed_data=indexed_data)),
        ("repocatalog", lambda: RepoRetrieverCatalog(indexed_data=indexed_data)),
        ("random", lambda: RandomRetriever(indexed_data=indexed_data)),
        ("plain", lambda: EmptyRetriever(indexed_data=indexed_data)),
    )
    for supported_mode, build in factories:
        if mode == supported_mode:
            return build()

    raise ValueError(f"Mode {mode} not supported")


def aggregate_docs(
    docs,
    topk,
    max_doc_chars,
    mode,
    max_total_chars,
    min_len_in_chars,
    precompute,
    stop_fraction,
    random_shuffle_tree,
    reverse_tree,
    topk_sample,
    tokenizer_path,
):
    """Merge related documents into long documents by breadth-first retrieval.

    Every still-unused document in turn becomes the root of a tree. The tree
    is grown breadth-first: for each visited document up to ``topk`` unused
    neighbours are taken from the retriever, shuffled, optionally cut down to
    ``topk_sample`` and queued, as long as the summed text length stays within
    ``max_doc_chars``. The texts of a tree are joined with newlines into one
    output document. Used documents are "marked" so they are not reused.

    Args:
        docs: input records (each with a ``"text"`` entry).
        topk: maximum number of unmarked neighbours considered per document.
        max_doc_chars: bound on the summed text length of one tree.
        mode: retriever mode, see ``get_retriever``.
        max_total_chars: no new tree is started once this many characters
            were produced (the last tree may overshoot the budget).
        min_len_in_chars: merged documents shorter than this are dropped.
        precompute: compute the neighbours of all documents up front instead
            of querying the retriever per visited document.
        stop_fraction: stop once more than this fraction of documents is marked.
        random_shuffle_tree: shuffle the texts of a tree before merging.
        reverse_tree: merge the texts of a tree in reverse order.
        topk_sample: if not ``None``, number of shuffled neighbours to keep.
        tokenizer_path: forwarded to ``get_retriever``.

    Returns:
        The list of merged document strings.
    """
    LOGGER.info(f"reverse_tree {reverse_tree} topk_sample {topk_sample}")
    indexed_data = IndexedData(data=docs)
    LOGGER.info(f"Total number of documents: {len(indexed_data)}")

    retriever = get_retriever(mode=mode, indexed_data=indexed_data, tokenizer_path=tokenizer_path)

    def get_cutoff(marked_fraction: float):
        # Ask for more candidates the more documents are already marked:
        # 100 when nothing is marked, about 500 from a marked fraction of 0.8 on.
        return int(100 / (1 - min(marked_fraction, 0.8)))

    if precompute:
        indexed_data.precompute_nn(retriever=retriever, cutoff=get_cutoff(1.0 - stop_fraction))

    def length_within_bounds(node_text: str, total_doc_length: int):
        return len(node_text) + total_doc_length <= max_doc_chars

    def merge_docs(parts):
        return "\n".join(parts)

    resulting_dataset = []
    documents_created = 0
    characters_added = 0

    # The context manager makes sure the progress bar is closed on every exit.
    with tqdm.tqdm(total=max_total_chars, desc="Gathering docs") as progress_bar:
        for root_node in indexed_data.iter_unmarked():
            if characters_added >= max_total_chars:
                break
            if indexed_data.get_marked_fraction() > stop_fraction:
                LOGGER.info(f"Visited more than {stop_fraction:.0%} of documents, stopping")
                break

            assert root_node["id"] not in indexed_data.marked

            q = Queue()
            indexed_data.mark_data(root_node)
            q.put(root_node)

            resulting_tree = []
            total_doc_length = 0
            to_unmark = []

            while not q.empty():
                node = q.get()
                node_text = node["text"]
                if not length_within_bounds(node_text, total_doc_length):
                    # It fitted when queued but no longer does: release it
                    # again once this tree is finished.
                    to_unmark.append(node)
                    continue

                resulting_tree.append(node_text)
                total_doc_length += len(node_text)

                cutoff = get_cutoff(indexed_data.get_marked_fraction())

                if precompute:
                    next_nodes = indexed_data.get_precomputed_nn(node)
                else:
                    next_nodes = retriever.search(node, cutoff=cutoff)

                descendants = []
                for nn in next_nodes:
                    if len(descendants) >= topk:
                        break
                    if not indexed_data.is_data_marked(nn):
                        descendants.append(nn)

                # Checked before sampling, so that a topk_sample below topk
                # does not trigger the message on its own.
                if len(descendants) < topk:
                    LOGGER.debug(
                        f"Not enough descendants! got {len(descendants)} but wanted {topk}, cutoff is {cutoff}"
                    )

                random.shuffle(descendants)

                if topk_sample is not None:
                    descendants = descendants[:topk_sample]

                for d in descendants:
                    assert not indexed_data.is_data_marked(d)
                    if length_within_bounds(d["text"], total_doc_length):
                        indexed_data.mark_data(d)
                        q.put(d)

            for tu in to_unmark:
                indexed_data.unmark_data(tu)

            if random_shuffle_tree:
                random.shuffle(resulting_tree)

            # Reversing does not change the merged length, so merge only once.
            ordered_tree = resulting_tree[::-1] if reverse_tree else resulting_tree
            doc_string = merge_docs(ordered_tree)
            if len(doc_string) < min_len_in_chars:
                continue

            assert len(doc_string) <= max_doc_chars + 2 * len(resulting_tree)
            resulting_dataset.append(doc_string)
            progress_bar.update(len(doc_string))
            characters_added += len(doc_string)
            documents_created += 1
            LOGGER.debug(f"Number of documents created: {documents_created}")
            LOGGER.debug(f"Number of characters added: {characters_added}")
            LOGGER.debug(f"Number of documents visited: {len(indexed_data.marked)}")

    LOGGER.info(f"Number of documents created: {documents_created}")
    LOGGER.info(f"Number of characters added: {characters_added}")
    LOGGER.info(f"Number of documents visited: {len(indexed_data.marked)}")

    return resulting_dataset


def get_records_from_dataset(
    ds_mode: str,
    ds_path: str,
    ds_subset: Optional[str],
    ds_split: Optional[str],
    topk: int,
    max_doc_chars: int,
    aggregate_mode: str,
    content_field: str,
    min_len_in_chars: int,
    max_total_chars: int,
    precompute: bool,
    stop_fraction: float,
    prefilter_max_doc_chars: int,
    random_shuffle_tree: bool,
    reverse_tree: bool,
    topk_sample: Optional[int],
    take_char_prefix: Optional[int],
    tokenizer_path: Optional[str],
):
    """Load a dataset, filter it by length and aggregate it into long documents.

    Args:
        ds_mode: ``"hf"`` to load with ``datasets.load_dataset`` or ``"file"``
            to read a JSON-lines file through ``mlxu.open_file``.
        ds_path: dataset name (hf) or file path (file).
        ds_subset: ``data_dir`` for ``load_dataset``; must be ``None`` in file mode.
        ds_split: ``split`` for ``load_dataset``; must be ``None`` in file mode.
        content_field: name of the field holding the document text.
        prefilter_max_doc_chars: longest input record that is kept.
        take_char_prefix: forwarded to ``filter_out``.
        The remaining arguments are forwarded to ``aggregate_docs``.

    Returns:
        A shuffled list of ``{"text": ...}`` dicts, one per merged document.

    Raises:
        ValueError: if ``ds_mode`` is neither ``"hf"`` nor ``"file"``.
    """
    #assert max_doc_chars >= prefilter_max_doc_chars
    LOGGER.info(f"Loading mode:{ds_mode} path:{ds_path} subset:{ds_subset} split:{ds_split}")
    if ds_mode == "hf":
        dataset = load_dataset(
            ds_path,
            data_dir=ds_subset,
            split=ds_split,
            # cache_dir="./hf_tmp_cache",
        )
        LOGGER.info(f"Dataset downloaded number of records is {len(dataset)}")
        dataset = dataset.select_columns([content_field, "max_stars_repo_path", "max_stars_repo_name"])
        LOGGER.info("Dataset columns selected")
        df = dataset.to_pandas()
        LOGGER.info("Dataset converted to pandas")
        result = [
            {
                "repo_name": dct["max_stars_repo_name"],
                "repo_path": dct["max_stars_repo_path"],
                "text": dct[content_field],
            }
            for dct in df.to_dict(orient="records")
        ]
    elif ds_mode == "file":
        assert ds_subset is None
        assert ds_split is None

        result = []
        with mlxu.open_file(ds_path, "r") as f:
            # readline returns "" only at end of file. Blank lines are skipped
            # rather than treated as the end of the data, so a stray empty
            # line cannot silently truncate the dataset.
            for line in iter(f.readline, ""):
                if not line.strip():
                    continue
                d = json.loads(line)
                result.append(
                    {
                        "repo_name": d.get("max_stars_repo_name", None),
                        "repo_path": d.get("max_stars_repo_path", None),
                        "text": d[content_field],
                    }
                )
    else:
        raise ValueError(f"mode {ds_mode} not supported")

    LOGGER.info("Filtering")
    result = filter_out(result, min_len_in_chars, prefilter_max_doc_chars, take_char_prefix)

    # Randomise the order in which documents become tree roots.
    random.shuffle(result)

    LOGGER.info("Aggregating")
    result = aggregate_docs(
        result,
        topk=topk,
        max_doc_chars=max_doc_chars,
        mode=aggregate_mode,
        max_total_chars=max_total_chars,
        min_len_in_chars=min_len_in_chars,
        precompute=precompute,
        stop_fraction=stop_fraction,
        random_shuffle_tree=random_shuffle_tree,
        reverse_tree=reverse_tree,
        topk_sample=topk_sample,
        tokenizer_path=tokenizer_path,
    )

    LOGGER.info("Finalizing")
    random.shuffle(result)

    return [{"text": txt} for txt in result]


def main(_):
    """Entry point: build the aggregated dataset from FLAGS and save it as JSON lines.

    The output is written to
    ``<destination_dataset_path>/<source_dataset_subset>--<timestamp>.<aggregate_mode>.jsonl``.
    Note that an unset ``source_dataset_subset`` appears as the literal text
    ``None`` in that file name.

    Raises:
        ValueError: if ``--destination_dataset_path`` was not provided.
    """
    # An explicit check instead of `assert`, which is skipped under `python -O`.
    if FLAGS.destination_dataset_path is None:
        raise ValueError("--destination_dataset_path must be provided")

    all_results = get_records_from_dataset(
        ds_mode=FLAGS.source_dataset_mode,
        ds_path=FLAGS.source_dataset,
        ds_subset=FLAGS.source_dataset_subset,
        ds_split=FLAGS.source_dataset_split,
        topk=FLAGS.topk,
        max_doc_chars=FLAGS.max_doc_chars,
        aggregate_mode=FLAGS.aggregate_mode,
        content_field=FLAGS.content_field,
        min_len_in_chars=FLAGS.min_len_in_chars,
        # The overall budget is given in tokens and converted to characters.
        max_total_chars=FLAGS.max_tokens * FLAGS.chars_per_token,
        precompute=FLAGS.precompute,
        stop_fraction=FLAGS.stop_fraction,
        prefilter_max_doc_chars=FLAGS.prefilter_max_doc_chars,
        random_shuffle_tree=FLAGS.random_shuffle_tree,
        reverse_tree=FLAGS.reverse_tree,
        topk_sample=FLAGS.topk_sample,
        take_char_prefix=FLAGS.take_char_prefix,
        tokenizer_path=FLAGS.tokenizer_path,
    )

    LOGGER.info("Finished processing partitions. Now saving to GCS.")
    timestamp = datetime.now().strftime("%Y%m%d-%H%M")
    result_path = os.path.join(str(FLAGS.destination_dataset_path), str(FLAGS.source_dataset_subset))
    result_path = f"{result_path}--{timestamp}.{FLAGS.aggregate_mode}.jsonl"
    LOGGER.info("Starting saving to: " + result_path)
    save_jsonl_efficient(all_results, result_path)

    LOGGER.info("Finished saving to GCS. All results len: " + str(len(all_results)))
    LOGGER.info("Saved to: " + result_path)


