import os
import json
import pandas as pd
import glob
import re



'''
Loading examples:

GeneralThought
from _datasets import Datasets
dataset = Datasets(
    dataset_path = "../dataset/GeneralThought-323K",
    dataset_name = "general_thought",
    count = 100,  # load 100 records
    start_index = 0,  # start loading from record 0
    )


OpenR1-Math
from _datasets import Datasets
dataset = Datasets(
    dataset_path = "../dataset/OpenR1-Math-220k/data",
    dataset_name = "openr1-math",
    # count = 100,  # load 100 records
    start_index = 0,  # start loading from record 0
    )
'''


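'''
dataset_name mapping (benchmark -> string passed as dataset_name):
    SALAD-Bench: salda
    AdvBench: adv
Other names accepted by Datasets.load: jbb, general_thought, openr1-math, lima
'''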
class Datasets:
    """Load questions (and, where available, answers and reasoning) from a dataset.

    According to the original author's note, this class is used with the
    roughly one hundred example records, not the complete datasets.

    Supported ``dataset_name`` values and the input each one reads:

    * ``"salda"``: JSON-lines file; every line has ``question`` and ``answer``.
    * ``"jbb"``: CSV file with a ``Goal`` column (questions only).
    * ``"adv"``: CSV file with a header row; column 0 is the question and
      column 1 the answer.
    * ``"general_thought"``: directory of ``*.parquet`` files with the columns
      ``question``, ``model_reasoning`` and ``model_answer``.
    * ``"openr1-math"``: directory of ``*.parquet`` files with the columns
      ``problem``, ``answer``, ``problem_type``, ``generations``,
      ``is_reasoning_complete`` and ``correctness_math_verify``.
    * ``"lima"``: JSON-lines file; every line has a ``conversations`` list whose
      first two entries are used as question and answer.

    Attributes:
        question_list: Loaded questions.
        answer_list: Loaded answers (left untouched by ``"jbb"``).
        reasoning_list: Reasoning texts (``"general_thought"`` and
            ``"openr1-math"`` only; otherwise empty).
        direct_answer_list: Reference answers from the ``answer`` column
            (``"openr1-math"`` only; otherwise empty).
    """

    # Patterns used to split an OpenR1 generation into reasoning and answer.
    _REASONING_RE = re.compile(r'<think>(.*?)</think>', re.DOTALL)
    _FINAL_ANSWER_RE = re.compile(r'</think>(.*)', re.DOTALL)

    def __init__(self, dataset_path, dataset_name, count = None, start_index = 0, problem_type = None) -> None:
        """Create the containers and immediately load the requested dataset.

        Args:
            dataset_path: File or directory to read (see the class docstring).
            dataset_name: One of the supported dataset identifiers.
            count: Maximum number of records to load; ``None`` loads everything
                from ``start_index`` onwards.
            start_index: Zero-based position of the first record to load.
            problem_type: Optional ``problem_type`` filter, only used by
                ``"openr1-math"``.
        """
        self.question_list = []
        self.answer_list = []
        # Defined up front so every instance exposes the same attributes,
        # whichever dataset was loaded.
        self.reasoning_list = []
        self.direct_answer_list = []
        self.load(dataset_path, dataset_name, count, start_index, problem_type)

    def load(self, dataset_path, dataset_name, count = None, start_index= 0, problem_type = None):
        """Fill ``question_list`` / ``answer_list`` (and related lists) from disk.

        Args:
            dataset_path: File or directory to read (see the class docstring).
            dataset_name: One of the supported dataset identifiers.
            count: Maximum number of records to load, e.g. ``100``; ``None``
                means no limit.
            start_index: Zero-based position of the first record to load.
            problem_type: When given, keep only ``"openr1-math"`` rows whose
                ``problem_type`` column equals this value.

        Raises:
            ValueError: If ``dataset_name`` is not supported, or a parquet
                directory contains no ``*.parquet`` file.
        """
        if dataset_name == "salda":
            self._load_salda(dataset_path, count, start_index)
        elif dataset_name == "jbb":
            self._load_jbb(dataset_path, count, start_index)
        elif dataset_name == "adv":
            self._load_adv(dataset_path, count, start_index)
        elif dataset_name == "general_thought":
            self._load_general_thought(dataset_path, count, start_index)
        elif dataset_name == "openr1-math":
            self._load_openr1_math(dataset_path, count, start_index, problem_type)
        elif dataset_name == "lima":
            self._load_lima(dataset_path, count, start_index)
        else:
            raise ValueError("Wrong name in dataset name")

    @staticmethod
    def _take(items, count, start_index):
        """Yield the items at positions >= start_index, at most ``count`` of them."""
        taken = 0
        for position, item in enumerate(items):
            if position < start_index:
                continue
            if count is not None and taken >= count:
                break
            yield item
            taken += 1

    @staticmethod
    def _slice_rows(data, count, start_index):
        """Return the positional window of a pandas Series/DataFrame.

        ``start_index`` is honoured even when ``count`` is ``None``.
        """
        stop = None if count is None else start_index + count
        return data.iloc[start_index:stop]

    @staticmethod
    def _read_parquet_dir(dataset_path):
        """Concatenate every ``*.parquet`` file in ``dataset_path`` into one frame."""
        # glob returns files in arbitrary order; sort so that positional
        # slicing with start_index/count is reproducible across machines.
        parquet_files = sorted(glob.glob(os.path.join(dataset_path, '*.parquet')))
        if not parquet_files:
            raise ValueError("No .parquet files found in {!r}".format(dataset_path))
        frames = [pd.read_parquet(file, engine='pyarrow') for file in parquet_files]
        return pd.concat(frames, ignore_index=True)

    @staticmethod
    def _drop_blank(df, column):
        """Drop rows whose ``column`` is missing or a whitespace-only string."""
        # NOTE(review): for list-valued columns `.str.strip()` presumably yields
        # NaN, so only the notna() half filters anything there - confirm.
        return df[df[column].notna() & (df[column].str.strip() != "")]

    @classmethod
    def _split_generation(cls, generation_text):
        """Split a generation into ``(reasoning, answer)``.

        Returns ``None`` when the text has no ``<think>...</think>`` section.
        """
        reasoning = cls._REASONING_RE.search(generation_text)
        answer = cls._FINAL_ANSWER_RE.search(generation_text)
        if reasoning is None or answer is None:
            return None
        return reasoning.group(1).strip(), answer.group(1).strip()

    def _load_salda(self, dataset_path, count, start_index):
        """Append question/answer pairs from a JSON-lines file."""
        with open(dataset_path, "r", encoding="utf-8") as f:
            for line in self._take(f, count, start_index):
                json_object = json.loads(line.strip())
                self.question_list.append(json_object['question'])
                self.answer_list.append(json_object['answer'])

    def _load_jbb(self, dataset_path, count, start_index):
        """Replace ``question_list`` with a window of the CSV ``Goal`` column."""
        df = pd.read_csv(dataset_path)
        self.question_list = self._slice_rows(df["Goal"], count, start_index).tolist()

    def _load_adv(self, dataset_path, count, start_index):
        """Append question/answer pairs from the first two CSV columns."""
        import csv

        # newline="" is what the csv module expects for files it parses.
        with open(dataset_path, "r", encoding="utf-8", newline="") as file:
            reader = csv.reader(file)
            next(reader, None)  # skip the header row; tolerate an empty file
            for row in self._take(reader, count, start_index):
                self.question_list.append(row[0])
                self.answer_list.append(row[1])

    def _load_general_thought(self, dataset_path, count, start_index):
        """Replace the lists with question/reasoning/answer from parquet files."""
        df = self._read_parquet_dir(dataset_path)
        for column in ("model_reasoning", "model_answer"):
            df = self._drop_blank(df, column)
        df = self._slice_rows(df, count, start_index)
        # All three lists come from the same frame, so their lengths match.
        self.question_list = df.question.tolist()
        self.reasoning_list = df.model_reasoning.tolist()
        self.answer_list = df.model_answer.tolist()

    def _load_openr1_math(self, dataset_path, count, start_index, problem_type):
        """Replace the lists with one usable generation per OpenR1-Math problem.

        The model answers live in the ``generations`` column. According to the
        original author's note, ``is_reasoning_complete``, ``generations`` and
        ``correctness_math_verify`` hold lists of equal length per row.
        """
        df = self._read_parquet_dir(dataset_path)
        # First drop rows with empty content.
        for column in ("problem", "answer", "problem_type", "is_reasoning_complete",
                       "generations", "correctness_math_verify"):
            df = self._drop_blank(df, column)
        if problem_type is not None:
            df = df[df["problem_type"] == problem_type]
        # Keep rows where at least one generation is both complete and verified
        # correct; these are the ones usable for later fine-tuning.
        usable = [
            any(complete and correct for complete, correct in zip(complete_flags, correct_flags))
            for complete_flags, correct_flags
            in zip(df["is_reasoning_complete"], df["correctness_math_verify"])
        ]
        # An index-aligned boolean Series also behaves correctly on an empty frame.
        df = df[pd.Series(usable, index=df.index, dtype=bool)]
        df = self._slice_rows(df, count, start_index)

        self.direct_answer_list = []
        self.question_list = []
        self.reasoning_list = []
        self.answer_list = []
        rows = zip(
            df.problem.tolist(),
            df.answer.tolist(),
            df.generations.tolist(),
            df.is_reasoning_complete.tolist(),
            df.correctness_math_verify.tolist(),
        )
        for question, direct_answer, generations, complete_flags, correct_flags in rows:
            for generation_text, complete, correct in zip(generations, complete_flags, correct_flags):
                if not (complete and correct):
                    continue
                parts = self._split_generation(generation_text)
                if parts is None:
                    # No <think>...</think> section: try the next generation
                    # instead of crashing with a half-updated state.
                    continue
                reasoning_text, answer_text = parts
                # Append to all four lists together so they stay aligned.
                self.question_list.append(question)
                self.direct_answer_list.append(direct_answer)
                self.reasoning_list.append(reasoning_text)
                self.answer_list.append(answer_text)
                break  # only the first usable generation per problem is kept

    def _load_lima(self, dataset_path, count, start_index):
        """Append the first two ``conversations`` entries of each JSON line."""
        with open(dataset_path, "r", encoding="utf-8") as f:
            for line in self._take(f, count, start_index):
                json_object = json.loads(line.strip())
                self.question_list.append(json_object['conversations'][0])
                self.answer_list.append(json_object['conversations'][1])
        
class Salad_Datasets:
    """Sample questions from the Salad-Data ``base_set`` JSON file.

    The file is shuffled with a fixed seed and up to ``count_per_type``
    questions are taken for each of the six harmful categories, identified by
    the first two characters of ``"1-category"`` (``O1`` .. ``O6``).
    """

    def __init__(self, count_per_type = 10 ,
                 dataset_path = "./datasets/salad_data.json",
                 dataset_name = "salda",
                 count = 100) -> None:
        """Load the dataset right away.

        Args:
            count_per_type: Maximum number of questions kept per category.
            dataset_path: Path of the JSON file to read.
            dataset_name: Forwarded to ``load``, which does not use it.
            count: Accepted but not used.
        """
        self.dataset_dict = {}
        self.question_list = []
        self.load(dataset_path, dataset_name, count_per_type)

    def load(self, dataset_path, dataset_name, count_per_type = None):
        """Read, shuffle and bucket the records, then fill ``question_list``.

        Note that this reseeds the global ``random`` module with 42.

        Args:
            dataset_path: Path of the JSON file; its top-level value is shuffled
                in place, so it has to be a list of records.
            dataset_name: Not used.
            count_per_type: Per-category cap; ``None`` keeps every question.
        """
        import random

        with open(dataset_path, "r", encoding="utf-8") as handle:
            self.dataset_dict = json.load(handle)

        random.seed(42)
        random.shuffle(self.dataset_dict)

        # One bucket per harmful category, kept in O1..O6 order.
        prefixes = ("O1", "O2", "O3", "O4", "O5", "O6")
        buckets = {prefix: [] for prefix in prefixes}
        for record in self.dataset_dict:
            bucket = buckets.get(record["1-category"][:2])
            if bucket is not None:
                bucket.append(record["question"])

        selected = []
        for prefix in prefixes:
            selected.extend(buckets[prefix][:count_per_type])
        self.question_list = selected

    def create_subdataset(self, count_per_type, target_file):
        """Write up to ``count_per_type`` full records per category to a JSON file.

        Records are visited in the (already shuffled) order of
        ``self.dataset_dict`` and scanning stops once every category is full.

        Args:
            count_per_type: Number of records wanted for each category.
            target_file: Path of the JSON file to write (indent of 4).
        """
        category_total = 6
        known_prefixes = ("O1", "O2", "O3", "O4", "O5", "O6")
        taken = [0] * category_total
        quota = [count_per_type] * category_total
        chosen = []
        for record in self.dataset_dict:
            label = record["1-category"]
            if label[:2] in known_prefixes:
                slot = int(label[1]) - 1  # "O3" -> index 2
                if taken[slot] < count_per_type:
                    chosen.append(record)
                    taken[slot] += 1
            # Checked after every record, matching or not.
            if taken == quota:
                break
        with open(target_file, "w") as out:
            json.dump(chosen, out, indent=4)
        


if __name__ == "__main__":
    # Manual smoke run: take 3 questions per category and print the result.
    per_category = 3
    salad = Salad_Datasets(per_category, dataset_path="./datasets/salad_data.json")
    sampled_questions = salad.question_list
    print(len(sampled_questions))
    print(sampled_questions)