import json, tqdm, pathlib, random

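# Data model for the unified output format; objects are serialized to JSON via their attribute dicts (see ClassEncoder below).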

class Dataset:
    """Root of the unified-format object graph.

    Holds one prompt template and a list of request states. When the graph
    is serialized through each object's ``__dict__``, attribute assignment
    order becomes JSON key order, so the order of assignments below matters.
    """

    class Prompt:
        """Prompt template: task instructions plus input/output affixes."""

        def __init__(self) -> None:
            # Tuple targets are bound left to right, so attribute order is
            # instructions, input_prefix, input_suffix, output_prefix, output_suffix.
            (
                self.instructions,
                self.input_prefix,
                self.input_suffix,
                self.output_prefix,
                self.output_suffix,
            ) = ("", "Text: ", "\n", "Answer: ", "\n")

    class RequestState:
        """One request, wrapping a single ``Instance``."""

        class IOText:
            """Mutable holder for one piece of input or output text."""

            def __init__(self) -> None:
                self.text = ""

        class Instance:
            """A single example: input text, reference outputs, split name and id."""

            def __init__(self) -> None:
                text_cls = Dataset.RequestState.IOText
                self.input = text_cls()
                # A single reference whose "output" text starts out empty.
                self.references = [dict(output=text_cls())]
                self.split = ""
                self.id = ""

        def __init__(self) -> None:
            self.instance = type(self).Instance()

    def __init__(self) -> None:
        self.prompt = type(self).Prompt()
        self.request_states = list()
        
# Dataset I/O settings: where the raw splits are read from and where the
# converted files are written. Both paths are relative to the current
# working directory and must keep their trailing "/" (they are joined by "+").
dataset_name = "ace2005-ner"
dataset_dir = f"../../../data/{dataset_name}/"
dest_prefix = "../../../unified_data/NER/"

# Splits to convert; each is read from "<dataset_dir><split>.json".
split_types = ["train", "dev", "test"]

class ClassEncoder(json.JSONEncoder):
    """JSON encoder that serializes plain objects through their attribute dict.

    Each object the base encoder cannot handle becomes a JSON object whose
    keys are its instance attributes, in assignment order.
    """

    def default(self, obj):
        attrs = getattr(obj, "__dict__", None)
        if attrs is None:
            # Objects without __dict__ (sets, __slots__ classes, ...): defer to
            # the base class so callers get the documented TypeError rather
            # than an AttributeError.
            return super().default(obj)
        return attrs
        
        
if __name__ == "__main__":
    for split_type in split_types:
        
        dataset = Dataset()

        if split_type in ["train", "test", "dev"]:
            source_path = dataset_dir + split_type + ".json"
        print(source_path)
        all_lines = open(source_path, encoding="utf-8").readlines()
        
        label_set = set()
        for line_id, line in enumerate(all_lines):
            instance = json.loads(line)
            assert len(instance["sentences"]) == len(instance["ner"])
            start = 0
            for sentence, entities in zip(instance["sentences"], instance["ner"]):
                # import pdb; pdb.set_trace()
                input_text = " ".join(sentence)
                ref_output = []
                entities = sorted(entities, key=lambda item: item[0])
                for entity in entities:
                    type_ent = entity[-1]
                    name_ent = " ".join(sentence[entity[0]-start:entity[1]-start+1])
                    ref_output.append(f"({name_ent}; is; {type_ent})")
                request_state = Dataset.RequestState()
                request_state.instance.id = "id" + str(line_id)
                request_state.instance.input.text = input_text
                request_state.instance.split = split_type
                request_state.instance.references[0]["output"].text = " | ".join(ref_output)
                dataset.request_states.append(request_state)
                start += len(sentence)
                
        dataset.prompt.instructions = "Please recognize entities for the given text and classify them into a suitable type. The collection of types is as follows: "
        dataset.prompt.instructions += "["
        dataset.prompt.instructions += ", ".join(label_set)
        dataset.prompt.instructions += "]."
        dataset.prompt.instructions = "Named Entity Recognition: "
 
        # dataset.request_states = random.choices(dataset.request_states, k=1000)
        dest_dir = dest_prefix + dataset_name
        dest_path = dest_dir + "/" + split_type + ".json"
        pathlib.Path(dest_dir).mkdir(exist_ok = True, parents = True) 
        
        print(len(dataset.request_states))
        json.dump(obj=dataset, fp=open(dest_path, "w", encoding="utf-8"), cls=ClassEncoder, indent=4)
