"""Data proeprocess."""
import os
import re
import json
import random
import numpy as np
import pandas as pd
from collections import defaultdict

from tqdm import tqdm
from langdetect import detect
from datasets import load_dataset, Dataset, load_from_disk
from sklearn.cluster import KMeans
from sentence_transformers import SentenceTransformer

from data_utils import save_dataset_to_disk
from api_utils import _set_openai_env, _call_gpt_api

class DataPreprocesser():
    """Preprocess training datasets."""
    def __init__(self, save_dir):
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        # seed everything
        random.seed(0)

    def preprocess_wikiHow(self, data_path, n_clusters=19, embedding_path=None, n_sampled=2000):
        """Cluster wikiHow titles and draw a cluster-proportional sample.

        The CSV at ``data_path`` is loaded, every title is embedded (or
        pre-computed embeddings are loaded), the embeddings are clustered
        with k-means, and each cluster contributes a number of randomly
        chosen rows proportional to its share of the data. Rows with a
        missing or non-string headline/title/text are dropped, and the
        result is written to ``<save_dir>/wikiHow.json``.

        Args:
            data_path: Path to the wikiHow CSV. The first three columns are
                read positionally as headline, title and text.
            n_clusters: Number of k-means clusters.
            embedding_path: Optional ``.npy`` file with one embedding per
                CSV row. When falsy, embeddings are computed with the
                ``all-MiniLM-L6-v2`` sentence-transformer and saved to
                ``embeddings.npy`` in the current working directory.
            n_sampled: Target total sample size before cleaning. Each
                cluster's quota is rounded down, so the actual number of
                samples can be smaller.

        Returns:
            The list of cleaned sample dicts (keys ``id``, ``headline``,
            ``title``, ``text``, ``cluster``) that was written to disk.
        """
        data = pd.read_csv(data_path)
        # to_records() places the DataFrame index in field 0; fields 1-3 are
        # the first three CSV columns.
        data_list = [
            {"id": int(item[0]), "headline": item[1], "title": item[2], "text": item[3]}
            for item in data.to_records()
        ]

        if not embedding_path:
            model = SentenceTransformer('all-MiniLM-L6-v2')
            embeddings = model.encode([item['title'] for item in data_list])
            np.save("embeddings.npy", embeddings)
        else:
            embeddings = np.load(embedding_path)

        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(embeddings)
        cluster_labels = kmeans.labels_

        assert len(cluster_labels) == len(data_list), \
            "number of embeddings does not match number of CSV rows"

        # Tag every row with its cluster and bucket the rows per cluster.
        dataset_groups = [[] for _ in range(n_clusters)]
        for item, label in zip(data_list, cluster_labels):
            item['cluster'] = int(label)
            dataset_groups[item['cluster']].append(item)

        # Sample from each cluster proportionally to its size.
        total_rows = len(data_list)
        samples = []
        for group in dataset_groups:
            quota = int(len(group) / total_rows * n_sampled)
            # Clamp to the cluster size: random.sample raises ValueError when
            # asked for more items than the population holds, which happens
            # when n_sampled exceeds the number of rows.
            samples.extend(random.sample(group, min(quota, len(group))))

        random.shuffle(samples)

        # Sanitize: keep only rows whose three text fields are non-empty
        # strings (missing CSV cells are not strings).
        cleaned_samples = []
        for item in samples:
            fields = (item['text'], item['title'], item['headline'])
            if not all(isinstance(value, str) and value for value in fields):
                continue
            item['text'] = item['text'].replace("This article", "The following answer").replace("this article", "the following answer")
            # add other operations here
            cleaned_samples.append(item)

        # Use a context manager so the output file is flushed and closed.
        with open(os.path.join(self.save_dir, "wikiHow.json"), "w") as f:
            json.dump(cleaned_samples, f)

        return cleaned_samples

    def format_wikihow(self, data_path):
        """Turn the sampled wikiHow JSON into an input/output dataset.

        Reads the JSON list at ``data_path``, takes each record's ``title``
        as the input and its ``text`` as the output, and hands the resulting
        ``Dataset`` to ``save_dataset_to_disk`` with the target
        ``./data/processed_wikihow/``.

        Args:
            data_path: Path to a JSON file holding a list of records with
                ``title`` and ``text`` keys.
        """
        with open(data_path) as f:
            records = json.load(f)

        titles = []
        texts = []
        for record in records:
            # The record's headline is intentionally not used.
            titles.append(record['title'])
            texts.append(record['text'])

        columns = {"input": titles, "output": texts}
        save_dataset_to_disk(Dataset.from_dict(columns), "./data/processed_wikihow/")

    def preprocess_stack_exchange(self, max_per_category=20, min_score=5,
                                  min_answer_len=200, max_answer_len=4096):
        """Select a bounded set of Stack Exchange Q&A pairs per category.

        Loads the ``train`` split of
        ``HuggingFaceH4/stack-exchange-preferences`` and walks it in reverse
        order. For every question the answer with the highest ``pm_score``
        is taken; the pair is kept only if that score is above ``min_score``
        and the tag-stripped answer length lies within
        ``[min_answer_len, max_answer_len]``. At most ``max_per_category``
        pairs are retained per category, and the result is written to
        ``<save_dir>/stack_exchange.json``.

        Args:
            max_per_category: Maximum number of pairs kept for one category.
            min_score: Answers whose ``pm_score`` is less than or equal to
                this value are discarded.
            min_answer_len: Minimum answer length in characters (after HTML
                tags are removed).
            max_answer_len: Maximum answer length in characters (after HTML
                tags are removed).

        Raises:
            ValueError: If ``max_per_category`` is smaller than 1.
        """
        if max_per_category < 1:
            raise ValueError("max_per_category must be at least 1")

        se_dataset = load_dataset("HuggingFaceH4/stack-exchange-preferences", streaming=False)['train']
        # Compiled once; used to delete all HTML tags from questions/answers.
        html_tag = re.compile(r'<[^>]+>')
        per_category_data = defaultdict(list)
        total_count = 0
        # reversed() yields an iterator without len(), so pass the total
        # explicitly to get a meaningful progress bar.
        for item in tqdm(reversed(se_dataset), total=len(se_dataset)):
            if len(item["answers"]) == 0:
                continue
            top_answer = max(item["answers"], key=lambda a: a["pm_score"])

            if top_answer["pm_score"] <= min_score:
                continue

            top_answer_text = html_tag.sub('', top_answer["text"])

            if len(top_answer_text) < min_answer_len or len(top_answer_text) > max_answer_len:
                continue

            question_text = html_tag.sub('', item["question"])

            # The category is the part of the second metadata entry before
            # its first dot (presumably the site's host name - confirm
            # against the dataset schema).
            category = item["metadata"][1].split(".")[0]

            add_item = {
                "url": item["metadata"][0],
                "qid": item["qid"],
                "question": question_text,
                "answer": top_answer_text,
                "category": category,
            }
            bucket = per_category_data[category]
            if len(bucket) == 0:
                # First time this category is seen: report progress.
                print(total_count, category)

            if len(bucket) < max_per_category:
                total_count = total_count + 1
                bucket.append(add_item)
            else:
                # Bucket is full: overwrite a random slot. randint is
                # inclusive on both ends, so the index is in
                # [0, max_per_category - 1].
                # NOTE(review): every later item replaces an earlier one, so
                # this is not uniform (reservoir) sampling - items visited
                # late are favoured. Kept as-is to preserve existing output.
                random_index = random.randint(0, max_per_category - 1)
                bucket[random_index] = add_item

        all_data = []
        for category_items in per_category_data.values():
            all_data.extend(category_items)
        print(len(all_data))
        with open(os.path.join(self.save_dir, "stack_exchange.json"), "w") as f:
            json.dump(all_data, f)


    def format_stack_exchange(self, data_path):
        """Turn the selected Stack Exchange JSON into an input/output dataset.

        Reads the JSON list at ``data_path``, takes each record's
        ``question`` as the input and its ``answer`` as the output, and
        hands the resulting ``Dataset`` to ``save_dataset_to_disk`` with the
        target ``./data/processed_stack_exchange/``.

        Args:
            data_path: Path to a JSON file holding a list of records with
                ``question`` and ``answer`` keys.
        """
        with open(data_path) as f:
            records = json.load(f)

        questions = []
        answers = []
        for record in records:
            questions.append(record['question'])
            answers.append(record['answer'])

        columns = {"input": questions, "output": answers}
        save_dataset_to_disk(Dataset.from_dict(columns), "./data/processed_stack_exchange/")
    
    def preprocess_open_assistant(self):
        """Extract English single-turn pairs from openassistant-guanaco.

        Loads the ``train`` split of ``timdettmers/openassistant-guanaco``
        and walks it in reverse order. From each conversation only the text
        between the first and second ``"### Human: "`` markers is used; it
        must contain exactly one ``"### Assistant: "`` marker, which splits
        it into instruction and answer. Pairs whose instruction is not
        detected as English (or for which language detection fails) are
        skipped. The result is written to ``<save_dir>/open_assistant.json``
        as a list of ``{"input", "output"}`` dicts.
        """
        dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

        human_tag = "### Human: "
        assistant_tag = "### Assistant: "

        data_list = []
        for item in tqdm(reversed(dataset)):
            human_parts = item['text'].split(human_tag)
            if len(human_parts) < 2:
                # No human turn at all.
                continue

            # Index 0 is whatever precedes the first human marker; index 1
            # is the first human turn together with its reply.
            first_turn = human_parts[1].split(assistant_tag)
            if len(first_turn) != 2:
                continue
            instruct, answer = first_turn

            try:
                language = detect(instruct)
            except Exception:
                # Best effort: an instruction whose language cannot be
                # detected is dropped.
                continue
            if language != "en":
                continue

            data_list.append({"input": instruct, "output": answer})

        print(len(data_list))
        with open(os.path.join(self.save_dir, "open_assistant.json"), "w") as f:
            json.dump(data_list, f)

    def format_open_assistant(self, data_path):
        """Turn the extracted Open Assistant JSON into an input/output dataset.

        Reads the JSON list at ``data_path``, copies each record's ``input``
        and ``output`` fields into two parallel columns, and hands the
        resulting ``Dataset`` to ``save_dataset_to_disk`` with the target
        ``./data/processed_open_assistant/``.

        Args:
            data_path: Path to a JSON file holding a list of records with
                ``input`` and ``output`` keys.
        """
        with open(data_path) as f:
            records = json.load(f)

        prompts = []
        replies = []
        for record in records:
            prompts.append(record['input'])
            replies.append(record['output'])

        columns = {"input": prompts, "output": replies}
        save_dataset_to_disk(Dataset.from_dict(columns), "./data/processed_open_assistant/")

    def preprocess_evaluation_datasets(self, self_instruct_path="user_oriented_instructions.jsonl",
                                       fast_chat_path="question.jsonl"):
        """Build the evaluation set and attach GPT responses to every prompt.

        Two JSONL files are read and merged into one list of
        ``{"input", "output"}`` dicts: for the first file the input is the
        ``instruction`` joined by a newline with the first instance's
        ``input`` and the output is that instance's ``output``; for the
        second file the input is ``text`` and the output is empty. For each
        prompt ``_call_gpt_api`` is called with ``model="gpt-3.5-turbo"``
        and ``n=5`` and its return value is stored under ``gpt_output``.
        Progress is written to ``<save_dir>/evaluation.json`` every 20
        prompts and once more at the end.

        Args:
            self_instruct_path: Path to the first JSONL file (records with
                ``instruction`` and ``instances`` keys). Defaults to the
                file name previously hard-coded, relative to the current
                working directory.
            fast_chat_path: Path to the second JSONL file (records with a
                ``text`` key). Defaults to the file name previously
                hard-coded, relative to the current working directory.
        """
        def read_jsonl(path):
            # Skip blank lines: json.loads on an empty/whitespace-only line
            # raises JSONDecodeError and would abort the whole run.
            with open(path, 'r', encoding='utf-8') as f:
                return [json.loads(line) for line in f if line.strip()]

        eval_data = []
        for example in read_jsonl(self_instruct_path):
            first_instance = example["instances"][0]
            eval_data.append({
                "input": example["instruction"] + "\n" + first_instance["input"],
                "output": first_instance["output"],
            })
        for example in read_jsonl(fast_chat_path):
            eval_data.append({"input": example["text"], "output": ""})

        print(len(eval_data))

        _set_openai_env()

        output_path = os.path.join(self.save_dir, "evaluation.json")
        g_data_items = []
        # call gpt
        for idx, instance in tqdm(enumerate(eval_data)):
            instance['gpt_output'] = _call_gpt_api(
                prompt=instance['input'],
                model="gpt-3.5-turbo",
                n=5
            )
            g_data_items.append(instance)

            # Periodic checkpoint so an interrupted run keeps what it has
            # generated so far.
            if idx % 20 == 0:
                print(f"Saving {idx} generated examples to file.")

                with open(output_path, "w") as f:
                    json.dump(g_data_items, f)

        # Final save containing every processed example.
        with open(output_path, "w") as f:
            json.dump(g_data_items, f)
        
    def format_evaluation(self, data_path):
        """Flatten the evaluation JSON into one row per GPT response.

        Reads the JSON list at ``data_path``; every entry of a record's
        ``gpt_output`` becomes its own row, paired with that record's
        ``input``. The resulting ``Dataset`` is handed to
        ``save_dataset_to_disk`` with the target
        ``./data/processed_evaluation/``.

        Args:
            data_path: Path to a JSON file holding a list of records with
                ``input`` and ``gpt_output`` keys.
        """
        with open(data_path) as f:
            records = json.load(f)

        prompts = []
        responses = []
        for record in records:
            prompt = record['input']
            candidates = record['gpt_output']
            # Repeat the prompt once per candidate so both columns stay
            # aligned row by row.
            prompts.extend([prompt] * len(candidates))
            responses.extend(candidates)

        columns = {"input": prompts, "output": responses}
        save_dataset_to_disk(Dataset.from_dict(columns), "./data/processed_evaluation/")

    def process_sharegpt(self):
        """Reduce the English ShareGPT set to clean single-turn pairs.

        Loads the dataset stored at ``./data/sharegpt_train_en/`` and keeps
        only examples whose ``input`` contains exactly one ``USER:`` and one
        ``ASSISTANT:`` marker, starts with ``USER:``, ends with
        ``ASSISTANT:``, and whose input plus output is at most 4096
        characters long. The role markers are then stripped from the input
        and the result is passed to ``save_dataset_to_disk`` with the target
        ``./data/sharegpt_train_en_clean/``. The dataset is printed after
        every step.
        """
        # NOTE: an earlier, now disabled, step appears to have produced
        # ./data/sharegpt_train_en/ by loading
        # "submodules/sharegpt_alpaca_oa_vicuna_format", renaming the columns
        # prompt -> input and label -> output, and keeping the train examples
        # whose output langdetect classified as "en".
        ds_train = load_from_disk("./data/sharegpt_train_en/")
        print(ds_train)

        # Each step: (message printed after filtering, predicate to keep a row).
        filter_steps = [
            ("after filter USER multiturn:",
             lambda x: x['input'].count("USER:") == 1),
            ("after filter ASSISTANT multiturn:",
             lambda x: x['input'].count("ASSISTANT:") == 1),
            ("after filter start user:",
             lambda x: x['input'].startswith("USER:")),
            ("after filter end ASSISTANT:",
             lambda x: x['input'].endswith("ASSISTANT:")),
            ("after filter length:",
             lambda x: len(x['input'] + x['output']) <= 4096),
        ]
        for message, keep in filter_steps:
            ds_train = ds_train.filter(keep)
            print(message, ds_train)

        def clean_sharegpt(example):
            # Drop the role markers; the output column is passed through.
            prompt = example['input'].replace("USER: ", "")
            prompt = prompt.replace("ASSISTANT:", "")
            return {"input": prompt, "output": example["output"]}

        # Show the first example before and after cleaning.
        print(ds_train[0])
        ds_train = ds_train.map(clean_sharegpt)
        print(ds_train[0])

        save_dataset_to_disk(ds_train, "./data/sharegpt_train_en_clean/")
        # NOTE: a disabled follow-up step reloaded the cleaned set, shuffled
        # it with seed 42, and saved prefixes of size 1, 2000, 4000, 8000,
        # 16000 and 32000 to ./data/processed_sharegpt_{i}_{size}/.


    def process_dolly(self):
        """Convert the locally stored Dolly dataset to input/output form.

        Loads the dataset from ``data/dolly``, renames ``context`` to
        ``input`` and ``response`` to ``output``, maps
        ``extract_alpaca_dataset`` over it, drops examples whose input plus
        output exceeds 4096 characters, and passes the result to
        ``save_dataset_to_disk`` with the target ``./data/processed_dolly/``.
        The dataset (and its first example) is printed along the way.
        """
        # The on-disk copy is used instead of
        # load_dataset("databricks/databricks-dolly-15k").
        ds = load_from_disk("data/dolly")
        print(ds)
        from data_utils import extract_alpaca_dataset

        column_renames = {"context": "input", "response": "output"}
        for old_name, new_name in column_renames.items():
            ds = ds.rename_column(old_name, new_name)
        print(ds)
        print(ds[0])

        ds = ds.map(extract_alpaca_dataset)
        print(ds)
        print(ds[0])

        # No language filter is applied: per the original author's note,
        # almost every item is English (15011 -> 15006 when filtering).
        def within_length_limit(example):
            return len(example['input'] + example['output']) <= 4096

        ds = ds.filter(within_length_limit)
        print("after filter length:", ds)
        save_dataset_to_disk(ds, "./data/processed_dolly/")


if __name__ == "__main__":
    pass

