
import ast
from exceptiongroup import catch
from src import utils
import re
from collections import defaultdict

def gen_data_list(dir_path, batch_size=1):
    """Expand one batched-output jsonl file into one record per question.

    Reads ``f"{dir_path}/{batch_size}.jsonl"`` via ``utils.read_jsonl``. Each
    line is expected to carry ``split_output`` (the model output for the whole
    batch), ``id_list``, ``origin_correct_answer_list`` and
    ``origin_question_list``. ``split_output`` is split with
    ``utils.split_string_by_separator``. If the number of pieces differs from
    ``batch_size``, the line is counted as an error and every question in it
    gets the placeholder answer ``'test'``.

    Args:
        dir_path: Directory containing the ``<batch_size>.jsonl`` files.
        batch_size: Number of questions answered together per jsonl line.

    Returns:
        A list of dicts with keys ``origin_id``, ``origin_correct_answer``,
        ``origin_question``, ``split_answer`` and ``batch_size``. If the file
        cannot be read, an empty list is returned (not ``None``), so callers
        such as ``ret.extend(gen_data_list(...))`` keep working.
    """
    file_path = f"{dir_path}/{batch_size}.jsonl"
    ok, data = utils.read_jsonl(file_path)
    if not ok:
        print(f"{file_path} read fail")
        return []
    ret = []
    error_line = 0
    for item in data:
        split_answer_list = utils.split_string_by_separator(item['split_output'])
        if len(split_answer_list) != batch_size:
            # Output could not be aligned with the questions; keep the rows
            # but give each one a placeholder answer.
            error_line += 1
            split_answer_list = ['test'] * batch_size

        for index, origin_id in enumerate(item['id_list']):
            ret.append({
                'origin_id': origin_id,
                'origin_correct_answer': item['origin_correct_answer_list'][index],
                'origin_question': item['origin_question_list'][index],
                'split_answer': split_answer_list[index],
                'batch_size': batch_size,
            })
    print(f"file_path:{file_path} total_line:{len(data)} error_line:{error_line}")
    return ret

def get_raw_data_list(dir_list, data_size_list):
    """Concatenate the per-question records for every (directory, batch size) pair.

    Directories are visited in order. Within each directory, batch sizes are
    visited in the order of ``data_size_list``, and each pair's records come
    from ``gen_data_list``.
    """
    return [
        record
        for dir_path in dir_list
        for batch_size in data_size_list
        for record in gen_data_list(dir_path, batch_size)
    ]
def build_data_list(raw_data_list, data_size_list):
    """Merge per-batch-size records into one row per question.

    A question is kept only if it has a record for every size in
    ``data_size_list``. Each output row contains ``origin_id``,
    ``origin_question``, ``origin_correct_answer`` and one
    ``split_answer_<size>`` entry per requested size.

    The question and answer text are taken from the batch-size-1 record when
    present. Otherwise they come from the first size in ``data_size_list``.

    Args:
        raw_data_list: Records as produced by ``gen_data_list``.
        data_size_list: Batch sizes that every kept question must cover.

    Returns:
        List of merged rows, ordered by first appearance of ``origin_id`` in
        ``raw_data_list``.
    """
    # "<origin_id>_<batch_size>" -> index into raw_data_list. If a key occurs
    # more than once (e.g. several input directories), the last record wins.
    data_index_map = {
        f"{it['origin_id']}_{it['batch_size']}": index
        for index, it in enumerate(raw_data_list)
    }
    # dict.fromkeys de-duplicates while preserving first-seen order, so the
    # output order is reproducible. A set's order varies between runs for
    # str ids because of hash randomization.
    id_list = list(dict.fromkeys(item['origin_id'] for item in raw_data_list))
    print(f"total origin_id num:{len(id_list)}")
    results = []
    for i in id_list:
        tp = {}
        for x in data_size_list:
            key = f"{i}_{x}"
            if key in data_index_map:
                tp[f'split_answer_{x}'] = raw_data_list[data_index_map[key]]['split_answer']
        if len(tp) < len(data_size_list):
            continue
        # Prefer the batch-size-1 record for the question/answer text, as
        # before; fall back to the requested sizes when size 1 is absent.
        ref_index = next(
            (
                data_index_map[key]
                for key in (f"{i}_{x}" for x in (1, *data_size_list))
                if key in data_index_map
            ),
            None,
        )
        if ref_index is None:
            continue
        reference = raw_data_list[ref_index]
        results.append({
            'origin_id': i,
            'origin_question': reference['origin_question'],
            'origin_correct_answer': reference['origin_correct_answer'],
            **tp,
        })
    print(len(results))
    return results

def _extract_boxed(text):
    """Return the content of the first non-empty, brace-balanced ``\\boxed{...}`` in *text*, or None."""
    prefix = "\\boxed{"
    start = text.find(prefix)
    while start != -1:
        content_start = start + len(prefix)
        depth = 1
        pos = content_start
        # Walk forward counting braces so nested LaTeX such as
        # \boxed{\frac{1}{2}} yields "\frac{1}{2}" rather than "\frac{1".
        while pos < len(text):
            ch = text[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    break
            pos += 1
        if depth == 0 and pos > content_start:
            return text[content_start:pos]
        # Empty or unterminated box: try the next occurrence.
        start = text.find(prefix, start + 1)
    return None


def filter_data(data, data_size_list):
    """Return True iff every ``split_answer_<size>`` boxes the correct answer.

    For each size in ``data_size_list``, the first non-empty ``\\boxed{...}``
    in ``data[f'split_answer_{size}']`` is extracted, with nested braces
    honoured. It must equal ``data['origin_correct_answer']`` exactly. A
    missing box or any mismatch rejects the row.
    """
    for x in data_size_list:
        content_inside_boxed = _extract_boxed(data[f'split_answer_{x}'])
        if content_inside_boxed is None:
            return False
        if content_inside_boxed != data['origin_correct_answer']:
            return False
    return True

def filter_data_list(data_list, data_size_list):
    """Keep, in order, only the rows that ``filter_data`` accepts for ``data_size_list``."""
    return [row for row in data_list if filter_data(row, data_size_list)]

def get_sft_dataset(dir_list, batch_size_list):
    """Run the full pipeline: load, merge, filter, then wrap with ``utils.gen_data_set``.

    The row count is printed after each stage, followed by the final object.
    Returns whatever ``utils.gen_data_set`` produces (type defined in
    ``src.utils``; not visible here).
    """
    raw_records = get_raw_data_list(dir_list, batch_size_list)
    print(f"raw_data_list len:{len(raw_records)}")

    merged_rows = build_data_list(raw_records, batch_size_list)
    print(f"data_list len:{len(merged_rows)}")

    kept_rows = filter_data_list(merged_rows, batch_size_list)
    print(f"filter_data_list len:{len(kept_rows)}")

    dataset = utils.gen_data_set(kept_rows)
    print(dataset)
    return dataset

def insight_len(dataset, source=None, sizes=(1, 2, 3, 5, 10)):
    """Print the mean character length of each ``split_answer_<size>`` column.

    Args:
        dataset: Sized, indexable collection of row mappings.
        source: If not None, only rows whose ``'source'`` equals it are used,
            and the selected source is printed first.
        sizes: Batch sizes to report. Only columns present in the first row
            are printed, one ``"<size>:len:<mean>"`` line each.
    """
    if source is not None:
        dataset = [item for item in dataset if item['source'] == source]
        print(f"source:{source}")
    if len(dataset) == 0:
        # Nothing to average; avoids IndexError on dataset[0] and a
        # ZeroDivisionError below.
        print("no rows to report")
        return
    first_row = dataset[0]
    for i in sizes:
        column = f"split_answer_{i}"
        if column in first_row:
            mean_len = sum(len(item[column]) for item in dataset) / len(dataset)
            print(f"{i}:len:{mean_len}")

    # len1=sum(len(item['split_answer_1']) for item in dataset)/len(dataset)
    # len2=sum(len(item['split_answer_2']) for item in dataset)/len(dataset)
    # len3=sum(len(item['split_answer_3']) for item in dataset)/len(dataset)
    # len5=0
    # if 'split_answer_5' in dataset[0]:
    #     len5=sum(len(item['split_answer_5']) for item in dataset)/len(dataset)
    # print(f"{source} len1:{len1} len2:{len2} len3:{len3} len5:{len5}")