import json
import tqdm
import pathlib
import random


# class definition

class Dataset:
    """One unified NER split: a prompt template plus its request states.

    Objects of this class (and of its nested classes) are plain attribute
    holders; the order in which attributes are assigned in ``__init__`` is
    the order of their ``__dict__`` entries.
    """

    class Prompt:
        """Text fragments wrapped around every input and output."""

        def __init__(self) -> None:
            self.instructions = ""
            self.input_prefix, self.input_suffix = "Text: ", "\n"
            self.output_prefix, self.output_suffix = "Answer: ", "\n"

    class RequestState:
        """A single example, stored under the ``instance`` attribute."""

        class IOText:
            """Wrapper holding one piece of text under ``text``."""

            def __init__(self) -> None:
                self.text = ""

        class Instance:
            """Input text, reference outputs, split name and identifier."""

            def __init__(self) -> None:
                text_cls = Dataset.RequestState.IOText
                self.input = text_cls()
                # Exactly one reference per instance, keyed by "output".
                self.references = [{"output": text_cls()}]
                self.split = ""
                self.id = ""

        def __init__(self) -> None:
            self.instance = self.Instance()

    def __init__(self) -> None:
        self.prompt = self.Prompt()
        # Filled with Dataset.RequestState objects by the preprocessing step.
        self.request_states = []

        
class ClassEncoder(json.JSONEncoder):
    """JSON encoder that serializes plain objects via their attribute dict."""

    def default(self, obj):
        """Return ``obj.__dict__`` for objects the base encoder cannot handle.

        Objects without a ``__dict__`` (for example ``set`` instances) are
        passed to the base implementation, which raises ``TypeError`` as the
        ``json`` encoder protocol expects, instead of leaking an
        ``AttributeError`` from the attribute access.
        """
        attributes = getattr(obj, "__dict__", None)
        if attributes is None:
            return super().default(obj)
        return attributes
        
# datasets

# Each entry: (dataset_name, input_directory, label_position), where
# label_position is the index of the label column in a whitespace-split line.
target_datasets = [
    (
        "conll-2003",
        "../../../data/CoNLL2003/",
        3,
    ),
    # Currently disabled:
    # ("few-nerd-inter", "../../../data/FewNERD/inter/", 1),
    # ("few-nerd-intra", "../../../data/FewNERD/intra/", 1),
    (
        "few-nerd-supervised",
        "../../../data/FewNERD/supervised/",
        1,
    ),
]

# Output root; one sub-directory per dataset name is created below it.
dest_prefix = "../../../unified_data/NER/"

# Splits to convert; each is read from "<input_directory><split>.txt".
split_types = [
    "train",
    "dev",
    "test",
]
        
# data preprocess

def _format_entity(entity_words, label):
    """Render one entity as the ``(<words>; is; <label>)`` reference triple."""
    return f"({' '.join(entity_words)}; is; {label})"


def _iter_sentences(lines, label_pos: int, label_map=None):
    """Yield ``(input_text, output_text)`` for every sentence in *lines*.

    Args:
        lines: Iterable of text lines. A non-blank line is split on
            whitespace; field 0 is the word and field ``label_pos`` the label.
            A blank line terminates the current sentence.
        label_pos: Index of the label field within a split line.
        label_map: When given, every non-"O" label has its ``B-``/``I-``
            prefix stripped and is then translated through this mapping
            (unknown labels raise ``KeyError``). When ``None`` labels are
            used verbatim.

    Yields:
        Tuples of the sentence text (each word followed by one space) and the
        entity triples joined with ``" | "``.

    Raises:
        ValueError: If a non-blank line has no field at ``label_pos``.
    """
    words = []        # every word of the sentence being read
    triples = []      # formatted entities already closed in this sentence
    entity = []       # words of the entity currently being extended
    last_label = "O"

    for line_no, line in enumerate(lines, start=1):
        fields = line.split()

        if not fields:
            # Blank line: close the sentence (even an empty one, so that
            # numbering stays aligned with the sentence separators).
            if entity:
                triples.append(_format_entity(entity, last_label))
            yield "".join(word + " " for word in words), " | ".join(triples)
            words, triples, entity, last_label = [], [], [], "O"
            continue

        if len(fields) <= label_pos:
            raise ValueError(
                f"line {line_no}: expected more than {label_pos} "
                f"whitespace-separated fields, got {len(fields)}: {line!r}"
            )

        word, label = fields[0], fields[label_pos]
        begins_entity = False
        if label_map is not None and label != "O":
            # Remember an explicit "begin" tag before the prefix is dropped,
            # otherwise two adjacent entities of one type would be merged.
            begins_entity = label.startswith("B-")
            label = label_map[label.removeprefix("B-").removeprefix("I-")]

        words.append(word)

        continues_entity = (
            label != "O" and label == last_label and not begins_entity
        )
        if entity and not continues_entity:
            triples.append(_format_entity(entity, last_label))
            entity = []
        if label != "O":
            entity.append(word)
        last_label = label

    # The file may end without a trailing blank line; do not lose the
    # sentence that was still being collected.
    if words:
        if entity:
            triples.append(_format_entity(entity, last_label))
        yield "".join(word + " " for word in words), " | ".join(triples)


def data_process(dataset_name: str, source: str, label_pos: int, split_type: str, random_sample: bool = True, sample_size: int = 1000):
    """Convert one split of a token-per-line NER file into the unified JSON.

    Reads ``source + split_type + ".txt"``, turns every sentence into a
    ``Dataset.RequestState`` (ids ``id1``, ``id2``, ...) and writes the
    resulting ``Dataset`` to ``dest_prefix + dataset_name + "/" + split_type
    + ".json"``, creating the destination directory if necessary.

    Args:
        dataset_name: Must start with ``"conll-2003"`` or ``"few-nerd"``.
            For CoNLL-2003, ``B-``/``I-`` prefixes are removed from labels.
        source: Path prefix of the input files (concatenated as-is, so it is
            expected to end with a path separator).
        label_pos: Index of the label field in a whitespace-split line.
        split_type: Split name, used for the input/output file names and
            stored on every instance.
        random_sample: When true, keep only a random subset of the examples.
        sample_size: Maximum number of examples kept when ``random_sample``
            is true. Sampling is without replacement and is capped at the
            number of available examples.

    Raises:
        ValueError: If ``dataset_name`` is not recognised, or an input line
            has no label field at ``label_pos``.
    """
    conll_label_dict = {
        "PER": "PER",
        "LOC": "LOC",
        "ORG": "ORG",
        "MISC": "MISC",
    }

    # Validate the dataset name before touching any file.
    if dataset_name.startswith("conll-2003"):
        label_map = conll_label_dict
    elif dataset_name.startswith("few-nerd"):
        label_map = None  # labels are used verbatim
    else:
        raise ValueError(dataset_name)

    dataset = Dataset()
    # The instruction is just the short task tag; the label inventory is not
    # part of the prompt.
    dataset.prompt.instructions = "Named Entity Recognition: "

    source_path = source + split_type + ".txt"
    print(source_path)
    with open(source_path, encoding="utf-8") as source_file:
        all_lines = source_file.readlines()

    progress = tqdm.tqdm(all_lines, total=len(all_lines))
    sentences = _iter_sentences(progress, label_pos, label_map)
    for request_id, (input_text, output_text) in enumerate(sentences, start=1):
        request_state = Dataset.RequestState()
        request_state.instance.id = "id" + str(request_id)
        request_state.instance.input.text = input_text
        request_state.instance.split = split_type
        request_state.instance.references[0]["output"].text = output_text
        dataset.request_states.append(request_state)

    if random_sample:
        # Sample without replacement so no example is duplicated, and never
        # ask for more examples than exist (also covers an empty dataset).
        keep = min(sample_size, len(dataset.request_states))
        dataset.request_states = random.sample(dataset.request_states, keep)

    dest_dir = dest_prefix + dataset_name
    dest_path = dest_dir + "/" + split_type + ".json"
    pathlib.Path(dest_dir).mkdir(exist_ok=True, parents=True)
    print(len(dataset.request_states))
    with open(dest_path, "w", encoding="utf-8") as dest_file:
        json.dump(obj=dataset, fp=dest_file, cls=ClassEncoder, indent=4)

if __name__ == "__main__":
    # Convert every configured split of every configured dataset, keeping
    # all examples (no random subsampling).
    for name, directory, label_column in target_datasets:
        for split in split_types:
            data_process(
                dataset_name=name,
                source=directory,
                label_pos=label_column,
                split_type=split,
                random_sample=False,
            )