import os
import json

class Datasets:
    """Question/answer pairs loaded from a JSON-lines file.

    After construction, ``question_list`` and ``answer_list`` hold the
    ``"question"`` and ``"answer"`` fields of the loaded records, in file
    order and index-aligned with each other.
    """

    def __init__(self, dataset_path, dataset_name, count=None) -> None:
        """Create the containers and immediately load ``dataset_path``.

        Args:
            dataset_path: Path of the JSON-lines file to read.
            dataset_name: Dataset identifier; only ``"salda"`` is accepted.
            count: Maximum number of records to load, or ``None`` for all.
        """
        self.question_list = []
        self.answer_list = []
        self.load(dataset_path, dataset_name, count)

    def load(self, dataset_path, dataset_name, count=None):
        """Append records from a JSON-lines file to the question/answer lists.

        Each non-blank line must be a JSON object with ``"question"`` and
        ``"answer"`` keys. Blank lines are skipped and do not count towards
        ``count``.

        Args:
            dataset_path: Path of the JSON-lines file to read.
            dataset_name: Dataset identifier; only ``"salda"`` is accepted.
            count: Maximum number of records to load, or ``None`` for all.

        Raises:
            ValueError: If ``dataset_name`` is not a supported name.
        """
        if dataset_name != "salda":
            raise ValueError("Wrong name in dataset name")

        loaded = 0
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(dataset_path, "r", encoding="utf-8") as f:
            for line in f:
                if count is not None and loaded >= count:
                    break
                payload = line.strip()
                if not payload:
                    # Tolerate empty lines (e.g. a trailing newline at EOF).
                    continue
                record = json.loads(payload)
                # Read both fields before appending so the two lists can
                # never get out of step if one key is missing.
                question, answer = record['question'], record['answer']
                self.question_list.append(question)
                self.answer_list.append(answer)
                loaded += 1
        
class Salad_Datasets:
    """Category-balanced question sampler over a JSON list of records.

    Each record is expected to be a mapping with a ``"question"`` field and a
    ``"1-category"`` field. Records are grouped by the first two characters of
    ``"1-category"`` (``"O1"`` .. ``"O6"``); records with any other prefix are
    ignored when sampling.

    After construction:
      * ``dataset_dict`` holds every record from the file, shuffled with a
        fixed seed (despite the name, it is the list parsed from the file).
      * ``question_list`` holds ``[question, category]`` pairs, at most
        ``count_per_type`` per category, ordered ``O1`` first through ``O6``.
    """

    # Two-character category prefixes, in the order they appear in the output.
    _CATEGORY_PREFIXES = ("O1", "O2", "O3", "O4", "O5", "O6")

    def __init__(self, count_per_type=10,
                 dataset_path="./datasets/salad_data.json",
                 dataset_name="salda") -> None:
        """Load ``dataset_path`` and sample ``count_per_type`` per category.

        Args:
            count_per_type: Maximum number of questions kept per category.
            dataset_path: Path of the JSON file holding the list of records.
            dataset_name: Forwarded to ``load``, which does not use it.
        """
        self.dataset_dict = {}
        self.question_list = []
        self.load(dataset_path, dataset_name, count_per_type)

    def load(self, dataset_path, dataset_name, count_per_type=None):
        """Read, shuffle and sample the dataset.

        Args:
            dataset_path: Path of the JSON file holding the list of records.
            dataset_name: Accepted for interface compatibility; unused.
            count_per_type: Maximum number of questions kept per category;
                ``None`` keeps all of them.
        """
        with open(dataset_path, "r", encoding="utf-8") as f:
            self.dataset_dict = json.load(f)

        import random

        # A private generator seeded with 42 yields the same permutation as
        # seeding the module-level generator, but leaves the process-wide
        # random state untouched for the rest of the program.
        random.Random(42).shuffle(self.dataset_dict)

        buckets = {prefix: [] for prefix in self._CATEGORY_PREFIXES}
        for record in self.dataset_dict:
            category = record["1-category"]
            bucket = buckets.get(category[:2])
            if bucket is not None:
                bucket.append([record["question"], category])

        self.question_list = [
            pair
            for prefix in self._CATEGORY_PREFIXES
            for pair in buckets[prefix][:count_per_type]
        ]

    def create_subdataset(self, count_per_type, target_file):
        """Write a category-balanced subset of the loaded records to a file.

        Walks ``dataset_dict`` in its current (shuffled) order and keeps the
        first ``count_per_type`` records of every known category. The kept
        records are written to ``target_file`` as an indented JSON list.

        Args:
            count_per_type: Number of records to keep per category.
            target_file: Path of the JSON file to write.
        """
        kept_per_prefix = {prefix: 0 for prefix in self._CATEGORY_PREFIXES}
        result_list = []
        for record in self.dataset_dict:
            prefix = record["1-category"][:2]
            if prefix in kept_per_prefix and kept_per_prefix[prefix] < count_per_type:
                result_list.append(record)
                kept_per_prefix[prefix] += 1
            # Stop scanning as soon as every category has reached its quota.
            if all(kept == count_per_type for kept in kept_per_prefix.values()):
                break
        with open(target_file, "w", encoding="utf-8") as f:
            json.dump(result_list, f, indent=4)
        


if __name__ == "__main__":
    # Manual smoke run: sample a few questions per category from the default
    # dataset file and print how many were selected, followed by the pairs.
    per_category = 3
    salad = Salad_Datasets(per_category, dataset_path="./datasets/salad_data.json")
    sampled = salad.question_list
    print(len(sampled))
    print(sampled)