import json
import os 
from tqdm import tqdm
from datasets import load_dataset

def read_jsonl_data(filename):
    """Load a JSON Lines file into a list of decoded records.

    Args:
        filename: Path of the JSONL file to read.

    Returns:
        A list holding one decoded JSON value per non-blank line, in file
        order. An empty file yields an empty list.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly so the result does not depend on the locale's
    # default encoding.
    with open(filename, encoding="utf-8") as f:
        # Iterate the file lazily and skip blank/whitespace-only lines, which
        # json.loads would otherwise reject.
        return [json.loads(line) for line in f if line.strip()]

# Registry mapping a file-format name to the function that loads that format.
READ_METHODS = {"jsonl": read_jsonl_data}


class InstructionDataset:
    """Base class that turns raw records into prompt/answer instruction rows.

    The methods below read ``self.dataset``, ``self.task_type``,
    ``self.prompt`` and call ``self.fetch_data(datum)``; none of these are
    defined on this class, so subclasses have to supply them.
    """

    BASE_REPO = "hippocrates"

    def build_classification_instruction(self, data, prompt, index):
        """Build a single classification instruction row.

        Args:
            data: Mapping that contains the gold label under ``"answer"``.
            prompt: The fully formatted prompt text for this example.
            index: Position of the example, appended to the dataset name to
                form the task id.

        Returns:
            A dict with the keys ``"task_id"``, ``"prompt"`` and ``"gold"``.
        """
        row = {}
        row["task_id"] = f"{self.dataset}{index}"
        row["prompt"] = prompt
        row["gold"] = data["answer"]
        return row

    def construct_instructions(self, data, split):
        """Convert raw records into instructions and reload them as a dataset.

        Every record is passed through ``self.fetch_data``, the resulting
        fields are substituted into ``self.prompt`` and handed to the builder
        registered for ``self.task_type``. The rows are written to
        ``<split>_instructions.jsonl`` (a path relative to the current
        working directory) and then loaded back with ``load_dataset``.

        Args:
            data: Iterable of raw records accepted by ``self.fetch_data``.
            split: Split name, used as the prefix of the intermediate file.

        Returns:
            Whatever ``load_dataset("json", ..., split="train")`` returns for
            the intermediate file.

        Raises:
            KeyError: If ``self.task_type`` has no registered builder.
        """
        builders = {"classification": self.build_classification_instruction}
        build = builders[self.task_type]

        rows = []
        for position, record in enumerate(tqdm(data)):
            fields = self.fetch_data(record)
            # One row is appended per record, so ``position`` always equals
            # the number of rows built so far.
            rows.append(build(fields, self.prompt.format(**fields), position))

        cache_path = f"{split}_instructions.jsonl"
        with open(cache_path, "w") as handle:
            handle.writelines(json.dumps(row) + "\n" for row in tqdm(rows))

        return load_dataset("json", data_files=cache_path, split="train")


class MedNLI(InstructionDataset):
    """MedNLI premise/hypothesis pairs framed as a three-label classification task."""

    dataset = "MedNLI"
    task_type = "classification"
    choices = ["entailment", "contradiction", "neutral"]
    # The prompt is filled with str.format, hence the doubled braces around
    # the literal JSON schema; ``{text}`` is the only placeholder.
    prompt = """
TASK: Please classify the relationship between the given premise and hypothesis into one of the following labels: entailment, contradiction, or neutral. 
Return **valid JSON** that matches this schema:
{{"reason": <Explanation>   // one sentence, "label": entailment | contradiction | neutral}}
###
INPUT: {text}
OUTPUT:
"""

    def fetch_data(self, datum):
        """Extract the prompt text and gold label from one raw record.

        Args:
            datum: Mapping with the keys ``"sentence1"``, ``"sentence2"``
                and ``"gold_label"``.

        Returns:
            A dict with ``"text"`` (the two sentences tagged with ``[PRE]``
            and ``[HYP]``) and ``"answer"`` (the gold label).
        """
        tagged_pair = "[PRE] " + datum["sentence1"]
        tagged_pair = tagged_pair + " [HYP] " + datum["sentence2"]
        gold = datum["gold_label"]
        return {"text": tagged_pair, "answer": gold}



# Registry mapping a dataset name to the class that builds its instructions.
DATASETS = {"MedNLI": MedNLI}


def main():
    """Build the MedNLI instruction files for the train, dev and test splits.

    For each split the raw JSONL file is read, converted with
    ``MedNLI.construct_instructions`` and written to
    ``<split>_reason_first.jsonl`` in the output directory. The output path
    and the first row of each split are printed.
    """
    mednli = MedNLI()
    os.makedirs("//users///proxy_tuning/eval/mednli/processed_data", exist_ok=True)

    # The files are written to a different directory from the one created
    # above, so make sure that destination exists as well before any
    # conversion work is done.
    output_dir = "//users///proxy_tuning/datasets/mednli"
    os.makedirs(output_dir, exist_ok=True)

    split_files = {
        "train": "//datasets/mednli/1.0.0/mli_train_v1.jsonl",
        "dev":   "//datasets/mednli/1.0.0/mli_dev_v1.jsonl",
        "test":  "//datasets/mednli/1.0.0/mli_test_v1.jsonl",
    }
    for split, filename in split_files.items():
        out_path = f"{output_dir}/{split}_reason_first.jsonl"
        records = READ_METHODS['jsonl'](filename)
        dataset = mednli.construct_instructions(records, split)
        dataset.to_json(out_path)

        print(f"✓ wrote {out_path}")
        print(f"\n{split.upper()} – first row\n", dataset[0])


# Run the conversion only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()