import os

def extract_think_blocks(file_path):
    """Return the stripped contents of every well-formed <think> block in a file.

    The file is read as UTF-8 and scanned left to right. An opening tag that
    is followed by another opening tag before any closing tag is treated as
    dangling and skipped; scanning stops as soon as no closing tag remains.

    Args:
        file_path: path of the text file to scan.

    Returns:
        List of block bodies (tags excluded, surrounding whitespace removed),
        in order of appearance.
    """
    open_tag = '<think>'
    close_tag = '</think>'

    with open(file_path, 'r', encoding='utf-8') as handle:
        text = handle.read()

    blocks = []
    cursor = 0
    while True:
        opening = text.find(open_tag, cursor)
        if opening < 0:
            # No further opening tag (also covers cursor == len(text)).
            break
        body_start = opening + len(open_tag)

        closing = text.find(close_tag, body_start)
        if closing < 0:
            # Unterminated block: nothing after this point can be complete.
            break

        reopened = text.find(open_tag, body_start)
        if 0 <= reopened < closing:
            # A new <think> starts before this one closes; restart from it.
            cursor = reopened
            continue

        blocks.append(text[body_start:closing].strip())
        cursor = closing + len(close_tag)
    return blocks

def split_blocks(blocks):
    """Split each block into segments at "wait" / "alternatively" markers.

    Matching is case-insensitive and purely textual (no word boundaries, so
    a marker inside a longer word such as "awaits" also splits). Each marker
    stays attached to the text that follows it; the text before the first
    marker forms its own segment. Segments are stripped of surrounding
    whitespace and of leading commas, and segments that end up empty are
    dropped.

    Args:
        blocks: iterable of strings, one per block.

    Returns:
        A list with one entry per input block, each entry being the list of
        cleaned segments of that block. A block without any marker yields a
        single segment containing the whole (stripped) block.
    """
    import re

    # The case-insensitive flag is passed as an argument: an inline "(?i)"
    # placed after the opening parenthesis is rejected by Python 3.11+.
    marker = re.compile(r'(wait|alternatively)', re.IGNORECASE)
    leading_junk = re.compile(r'^[\s,]+')

    split_results = []
    for block in blocks:
        # With one capturing group, split() returns
        # [text0, marker1, text1, marker2, text2, ...] (always odd length).
        pieces = marker.split(block)
        segments = [pieces[0]]
        segments.extend(
            found + following
            for found, following in zip(pieces[1::2], pieces[2::2])
        )
        cleaned = (leading_junk.sub('', segment.strip()) for segment in segments)
        split_results.append([segment for segment in cleaned if segment])
    return split_results

def get_cot(file_path):
    """Load a file and return its <think> blocks split into segments.

    Runs extract_think_blocks() on ``file_path`` and feeds the resulting
    list of blocks to split_blocks().

    Args:
        file_path: path of the file to read.

    Returns:
        Whatever split_blocks() returns for the extracted blocks.
    """
    return split_blocks(extract_think_blocks(file_path))

def ave_length(split_blocks_result):
    """Summarise how many segments each split block contains.

    Args:
        split_blocks_result: iterable of sized items (e.g. lists of
            segments), one per block.

    Returns:
        Dict with keys "max_len", "min_len" and "ave_len": the largest,
        smallest and mean ``len()`` over the items. For empty input all
        three are zero instead of raising ZeroDivisionError.
    """
    lengths = [len(parts) for parts in split_blocks_result]
    if not lengths:
        # Nothing to average; report zeros rather than dividing by zero.
        return {"max_len": 0, "min_len": 0, "ave_len": 0.0}
    return {
        "max_len": max(lengths),
        "min_len": min(lengths),
        "ave_len": sum(lengths) / len(lengths),
    }

def check_block_fields(split_block):
    """
    Classify a split block by the calculation / verification keywords it contains.

    All parts of split_block are joined with spaces and lower-cased, then
    searched by plain substring match (so "checked" counts as "check").

    Args:
        split_block: iterable of strings (the parts of one split block).

    Return: calon, verion, calve, seek (four 0/1 flags, exactly one of them is 1)
    calon: only a calculating/calculater/compute keyword is present
    verion: only a check/verify/confirm/"correct?" keyword is present
    calve: both kinds of keyword are present
    seek: neither kind of keyword is present
    """
    # Merge all parts of split_block into a string for checking
    text = ' '.join(split_block).lower()

    # NOTE(review): 'calculater' is kept verbatim from the existing keyword
    # list; plain "calculate" matches none of these -- confirm this is intended.
    calc_keywords = ('calculating', 'calculater', 'compute')
    # "correct?" is searched as a whole so that any occurrence counts, not
    # just the case where the first "correct" in the text is followed by "?".
    verify_keywords = ('check', 'verify', 'confirm', 'correct?')

    cal = int(any(keyword in text for keyword in calc_keywords))
    veri = int(any(keyword in text for keyword in verify_keywords))

    # Derive the mutually exclusive flags from the two raw indicators.
    calve = cal * veri
    seek = 1 if (cal == 0 and veri == 0) else 0
    calon = cal - calve
    verion = veri - calve

    return calon, verion, calve, seek

if __name__ == '__main__':
    # Ad-hoc driver: print the first split block, the number of blocks,
    # the length statistics, and the keyword flags of the first block.
    file_path = './raw_answer/part_split_cot/gsm8k_7b_6_5,4.log'
    cot = get_cot(file_path)
    if not cot:
        # Everything below indexes cot[0]; fail with a clear message instead
        # of an IndexError when the log holds no complete <think> block.
        raise SystemExit(f"No <think> blocks found in {file_path}")
    print(cot[0])
    print(len(cot))
    print(ave_length(cot))

    # check_block_fields returns exactly four flags, in this order.
    calon, verion, calve, seek = check_block_fields(cot[0])
    print(f"Test results: cal={calon}, veri={verion}, calve={calve}, seek={seek}")