import os

def extract_think_blocks(file_path):
    """Return the stripped contents of every ``<think>...</think>`` pair in a file.

    The file is read in full as UTF-8 and scanned left to right. If another
    ``<think>`` occurs before the next ``</think>``, the earlier opening tag
    is discarded and scanning resumes from the later one, so each returned
    block never contains an opening tag. Scanning stops at the first opening
    tag that has no closing tag after it.

    Args:
        file_path: Path of the text file to scan.

    Returns:
        A list of block contents, in file order, with surrounding
        whitespace removed.
    """
    open_tag, close_tag = '<think>', '</think>'

    with open(file_path, 'r', encoding='utf-8') as handle:
        content = handle.read()

    blocks = []
    cursor = 0
    while True:
        open_pos = content.find(open_tag, cursor)
        if open_pos < 0:
            return blocks

        body_start = open_pos + len(open_tag)
        close_pos = content.find(close_tag, body_start)
        if close_pos < 0:
            # Unterminated block: nothing further can be extracted.
            return blocks

        # A second opening tag inside the candidate body means the current
        # one was never closed; restart the scan from the inner tag.
        inner_open = content.find(open_tag, body_start, close_pos)
        if inner_open >= 0:
            cursor = inner_open
            continue

        blocks.append(content[body_start:close_pos].strip())
        cursor = close_pos + len(close_tag)

def split_blocks(blocks):
    """Split each block into segments at the words "wait" / "alternatively".

    Matching is case-insensitive and is a plain substring match (no word
    boundaries), so the keyword is also found inside longer words. Each
    keyword is kept and placed at the start of the segment that follows it;
    the text before the first keyword forms its own segment. A block that
    contains no keyword yields a single segment.

    After splitting, every segment is stripped of surrounding whitespace,
    segments that are empty after stripping are dropped, and any leading
    run of whitespace/commas is removed from the remaining ones.

    Args:
        blocks: Iterable of strings to split.

    Returns:
        A list with one entry per input block; each entry is the list of
        cleaned segments for that block.
    """
    import re

    # The flag is passed explicitly: an inline ``(?i)`` that is not at the
    # very start of the pattern is rejected by Python 3.11+.
    keyword_re = re.compile(r'(wait|alternatively)', re.IGNORECASE)
    leading_junk_re = re.compile(r'^[\s,]+')

    split_results = []
    for block in blocks:
        # With one capture group, re.split returns
        # [text0, kw1, text1, kw2, text2, ...] (always an odd length).
        pieces = keyword_re.split(block)

        # Keep the leading text, then glue every keyword onto the text that
        # follows it so each segment appears exactly once.
        merged = [pieces[0]]
        merged.extend(
            keyword + following
            for keyword, following in zip(pieces[1::2], pieces[2::2])
        )

        cleaned = [
            leading_junk_re.sub('', segment.strip())
            for segment in merged
            if segment.strip()
        ]
        split_results.append(cleaned)
    return split_results

def ave_length(split_blocks_result):
    """Summarise how many parts each entry of ``split_blocks_result`` has.

    Args:
        split_blocks_result: Iterable of sized items (e.g. the list of
            segment lists returned by ``split_blocks``).

    Returns:
        A dict with keys ``"max_len"``, ``"min_len"`` and ``"ave_len"``
        holding the largest, smallest and mean ``len()`` of the entries.
        For empty input all three values are zero instead of raising
        ``ZeroDivisionError``.
    """
    counts = [len(parts) for parts in split_blocks_result]
    if not counts:
        # No entries: there is nothing to average, so report zeros.
        return {"max_len": 0, "min_len": 0, "ave_len": 0.0}
    return {
        "max_len": max(counts),
        "min_len": min(counts),
        "ave_len": sum(counts) / len(counts),
    }

def get_cot(file_path):
    """Extract the think blocks from ``file_path`` and split each into parts.

    Passes the result of ``extract_think_blocks(file_path)`` straight to
    ``split_blocks`` and returns what that produces.
    """
    return split_blocks(extract_think_blocks(file_path))

if __name__ == '__main__':
    # Script entry point: print the max/min/average number of parts per
    # think block found in the log file below.
    file_path = './gsm8k_7b_6_5,4.log'
    print(ave_length(get_cot(file_path)))