import re
import random
from datasets import load_dataset
import numpy as np
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

def prune_thinking_data(input_text, keep_percentage=0.5, end_start_ratio=0.5):
    """
    Shorten the reasoning inside a <think> block by dropping its middle.

    The thinking content is split into chunks on double newlines and
    whitespace-only chunks are discarded. When at least 60 chunks remain,
    only the leading and trailing chunks are kept (the selection is
    deterministic, nothing is sampled) and the middle is removed. Texts
    without a thinking block, or with fewer than 60 chunks, are returned
    unchanged.

    Args:
        input_text: Text containing a <think>...</think> block. If there is
            no closing tag, <think>...**Final Answer** is accepted instead;
            in that case the **Final Answer** marker is replaced by </think>
            in the pruned output.
        keep_percentage: Fraction of the chunks to keep (default: 0.5).
        end_start_ratio: Fraction of the kept chunks taken from the end of
            the reasoning; the rest is taken from the start (default: 0.5).
            At least one chunk is kept from the start and, when room is
            left, one from the end.

    Returns:
        The input text with the thinking section pruned. Text before
        <think> and after the closing marker is preserved.
    """
    match = re.search(r'<think>(.*?)</think>', input_text, re.DOTALL)
    if match is None:
        # Some traces never close the tag; fall back to the answer marker.
        match = re.search(r'<think>(.*?)\*\*Final Answer\*\*', input_text, re.DOTALL)
    if match is None:
        return input_text

    chunks = [chunk for chunk in re.split(r'\n\n', match.group(1)) if chunk.strip()]
    chunks_number = len(chunks)
    if chunks_number < 60:
        # Short reasoning is left untouched.
        return input_text

    num_to_keep = int(chunks_number * keep_percentage)
    num_end_to_keep = int(num_to_keep * end_start_ratio)
    num_start_to_keep = num_to_keep - num_end_to_keep

    # Keep at least one chunk from each side ...
    num_start_to_keep = max(num_start_to_keep, 1)
    num_end_to_keep = max(num_end_to_keep, 1)
    # ... but never let the two slices overlap, which would duplicate chunks.
    num_start_to_keep = min(num_start_to_keep, chunks_number)
    num_end_to_keep = min(num_end_to_keep, chunks_number - num_start_to_keep)

    kept_chunks = chunks[:num_start_to_keep] + chunks[chunks_number - num_end_to_keep:]
    new_thoughts = "\n\n".join(kept_chunks)

    # Re-attach the text on both sides of the thinking block.
    prefix = input_text[:match.start()]
    suffix = input_text[match.end():]
    return f"{prefix}<think>{new_thoughts}</think>{suffix}"

    
def prune_thinking_data_new(input_text, num_to_keep, end_start_ratio=0.5):
    """
    Shorten the reasoning inside a <think> block to a fixed chunk budget.

    The thinking content is split into chunks on double newlines and
    whitespace-only chunks are discarded. When at least 60 chunks remain and
    the budget is smaller than the chunk count, only leading and trailing
    chunks are kept (deterministically) and the middle is removed.

    Args:
        input_text: Text containing a <think>...</think> block. If there is
            no closing tag, <think>...**Final Answer** is accepted instead;
            in that case the **Final Answer** marker is replaced by </think>
            in the pruned output.
        num_to_keep: Target number of chunks to keep (an integer). A budget
            greater than or equal to the chunk count leaves the text
            unchanged.
        end_start_ratio: Fraction of the kept chunks taken from the end of
            the reasoning; the rest is taken from the start (default: 0.5).
            At least one chunk is kept from the start and, when room is
            left, one from the end, so a budget of 0 still keeps two chunks.

    Returns:
        The input text with the thinking section pruned, or the unchanged
        input when there is no thinking block, fewer than 60 chunks, or
        nothing to remove. Text before <think> and after the closing marker
        is preserved.
    """
    match = re.search(r'<think>(.*?)</think>', input_text, re.DOTALL)
    if match is None:
        # Some traces never close the tag; fall back to the answer marker.
        match = re.search(r'<think>(.*?)\*\*Final Answer\*\*', input_text, re.DOTALL)
    if match is None:
        return input_text

    chunks = [chunk for chunk in re.split(r'\n\n', match.group(1)) if chunk.strip()]
    chunks_number = len(chunks)
    if chunks_number < 60 or num_to_keep >= chunks_number:
        # Too short to prune, or the budget already covers every chunk.
        return input_text

    num_end_to_keep = int(num_to_keep * end_start_ratio)
    num_start_to_keep = num_to_keep - num_end_to_keep

    # Keep at least one chunk from each side ...
    num_start_to_keep = max(num_start_to_keep, 1)
    num_end_to_keep = max(num_end_to_keep, 1)
    # ... but never let the two slices overlap, which would duplicate chunks.
    num_start_to_keep = min(num_start_to_keep, chunks_number)
    num_end_to_keep = min(num_end_to_keep, chunks_number - num_start_to_keep)

    kept_chunks = chunks[:num_start_to_keep] + chunks[chunks_number - num_end_to_keep:]
    new_thoughts = "\n\n".join(kept_chunks)

    # Re-attach the text on both sides of the thinking block.
    prefix = input_text[:match.start()]
    suffix = input_text[match.end():]
    return f"{prefix}<think>{new_thoughts}</think>{suffix}"







def prune_ratio_dist(keep_percentage=0.5,end_start_ratio=0.5):
    """
    Prune OpenR1-Math reasoning with a per-sample budget that grows with trace length.

    Loads the "open-r1/OpenR1-Math-220k" train split and counts the thought
    chunks in the last non-user message of every sample (samples with fewer
    than 60 chunks count as 0). Each sample gets a weight proportional to the
    square of its chunk count, normalised so the weights sum to 1. Every
    non-user message is then pruned with ``prune_thinking_data_new`` using
    ``num_to_keep = int(keep_percentage * weight * total_chunks)``.

    As a side effect the weights and their sum are printed, and a histogram
    of the per-sample chunk budget is drawn on a new matplotlib figure (the
    figure is neither shown nor saved here).

    Args:
        keep_percentage: Global scale applied to every per-sample budget.
        end_start_ratio: Forwarded to ``prune_thinking_data_new``.

    Returns:
        The pruned dataset with an extra "prob" column holding the weights.
        If no sample reaches 60 chunks, the dataset is returned unmodified
        (without the extra column).
    """
    dataset = load_dataset("open-r1/OpenR1-Math-220k", split="train")

    def last_reply(messages):
        # Content of the last non-user message; '' when there is none.
        reply = ''
        for message in messages:
            if message['role'] != 'user':
                reply = message["content"]
        return reply

    chunk_counts = []
    for sample in tqdm(dataset):
        chunks = count_thought_chunks(last_reply(sample.get('messages') or []))
        # Only traces long enough to be pruned contribute to the budget.
        chunk_counts.append(chunks if chunks >= 60 else 0)

    total_thoughts = sum(chunk_counts)
    if total_thoughts == 0:
        # Nothing is long enough to prune; normalising would divide by zero.
        return dataset

    counts = np.array(chunk_counts, dtype=float)
    probs = counts / counts.sum()
    # Weighting each share by its own chunk count again makes the final
    # weight proportional to the squared chunk count.
    weights = probs * counts
    weights = weights / weights.sum()
    number_of_thoughts = weights * total_thoughts
    new_thoughts = weights.tolist()
    print(new_thoughts)
    print(np.sum(new_thoughts))

    plt.figure(figsize=(10, 6))
    plt.hist(number_of_thoughts, bins=50, edgecolor='black')
    plt.xlabel('Number of Chunks')
    plt.ylabel('Frequency')

    dataset = dataset.add_column("prob", new_thoughts)

    def map_fn(examples):
        for i, example in enumerate(examples['messages']):
            budget = int(keep_percentage * examples["prob"][i] * total_thoughts)
            for j, text in enumerate(example):
                if text['role'] == 'user':
                    continue
                pruned_solution = prune_thinking_data_new(text["content"], num_to_keep=budget, end_start_ratio=end_start_ratio)
                examples['messages'][i][j]['content'] = pruned_solution
        return examples

    dataset = dataset.map(map_fn, batched=True)
    return dataset


def prune_ratio_correct(keep_percentage=0.5,end_start_ratio=0.5):
    """
    Prune OpenR1-Math reasoning with a budget driven by length and error rate.

    Loads the "open-r1/OpenR1-Math-220k" train split. For every sample it
    counts the thought chunks in the last non-user message (samples with
    fewer than 60 chunks count as 0) and computes the fraction of correct
    trajectories as ``correctness_count / len(correctness_math_verify)``.
    The per-sample weight is ``(1 - correct_fraction) * min-max-normalised
    chunk count``, normalised so all weights sum to 1. Every non-user message
    is then pruned with ``prune_thinking_data_new`` using
    ``num_to_keep = int(keep_percentage * weight * total_chunks)``.

    As a side effect the Spearman correlation between the error fraction and
    the normalised chunk count is printed, followed by the weights and their
    sum.

    Args:
        keep_percentage: Global scale applied to every per-sample budget.
        end_start_ratio: Forwarded to ``prune_thinking_data_new``.

    Returns:
        The pruned dataset with an extra "prune_ratio" column holding the
        weights. If the dataset is empty or every weight is zero, the dataset
        is returned unmodified (without the extra column).
    """
    from scipy.stats import spearmanr

    dataset = load_dataset("open-r1/OpenR1-Math-220k", split="train")

    def last_reply(messages):
        # Content of the last non-user message; '' when there is none.
        reply = ''
        for message in messages:
            if message['role'] != 'user':
                reply = message["content"]
        return reply

    chunk_counts = []
    correct_ratio = []
    for sample in tqdm(dataset):
        chunks = count_thought_chunks(last_reply(sample.get('messages') or []))
        num_thinking_trajectories = len(sample.get("correctness_math_verify") or [])
        correctness_count = sample.get("correctness_count") or 0
        if num_thinking_trajectories == 0:
            ratio = 0
        else:
            ratio = correctness_count / num_thinking_trajectories
        # Only traces long enough to be pruned contribute to the budget.
        chunk_counts.append(chunks if chunks >= 60 else 0)
        correct_ratio.append(ratio)

    if not chunk_counts:
        # Empty dataset: there is nothing to normalise or prune.
        return dataset

    total_thoughts = sum(chunk_counts)
    max_thoughts = max(chunk_counts)
    min_thoughts = min(chunk_counts)
    span = max_thoughts - min_thoughts

    # Min-max normalise; with identical counts the span is 0, so fall back
    # to 0.0 instead of dividing by zero.
    normalized_chunks = [
        (chunk - min_thoughts) / span if span else 0.0 for chunk in chunk_counts
    ]
    normalized_ratio = [1 - ratio for ratio in correct_ratio]
    new_thoughts = np.array(
        [error * chunk for error, chunk in zip(normalized_ratio, normalized_chunks)]
    )

    spearman_corr, _ = spearmanr(normalized_ratio, normalized_chunks)
    print(f"Spearman correlation: {spearman_corr}")

    weight_sum = np.sum(new_thoughts)
    if weight_sum == 0:
        # Every sample is either too short or fully correct: no budget to
        # distribute, and normalising would produce NaN.
        return dataset

    new_thoughts = (new_thoughts / weight_sum).tolist()
    print(new_thoughts)
    print(np.sum(new_thoughts))
    dataset = dataset.add_column("prune_ratio", new_thoughts)

    def map_fn(examples):
        for i, example in enumerate(examples['messages']):
            budget = int(keep_percentage * examples["prune_ratio"][i] * total_thoughts)
            for j, text in enumerate(example):
                if text['role'] == 'user':
                    continue
                pruned_solution = prune_thinking_data_new(text["content"], num_to_keep=budget, end_start_ratio=end_start_ratio)
                examples['messages'][i][j]['content'] = pruned_solution
        return examples

    dataset = dataset.map(map_fn, batched=True)
    return dataset

def count_thought_chunks(text):
    """
    Return how many non-empty thought chunks a text's thinking section holds.

    The thinking section is the content of the first <think>...</think>
    block; when no closing tag exists, the span from <think> up to the first
    **Final Answer** marker is used instead. Chunks are the pieces obtained
    by splitting that content on double newlines; pieces that are empty or
    whitespace-only are not counted.

    Args:
        text: The input text to inspect.

    Returns:
        The number of non-empty chunks, or 0 when no thinking section exists.
    """
    found = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
    if found is None:
        # No closing tag: accept the answer marker as the end of the thoughts.
        found = re.search(r'<think>(.*?)\*\*Final Answer\*\*', text, re.DOTALL)
    if found is None:
        return 0

    pieces = re.split(r'\n\n', found.group(1))
    return sum(1 for piece in pieces if piece.strip())



def prune_openr1_math_dataset(keep_ratio=0.5, end_start_ratio=0.5):
    """
    Prune the thinking section of every non-user message in OpenR1-Math.

    Loads the "open-r1/OpenR1-Math-220k" train split and replaces the content
    of each message whose role is not 'user' with the result of
    ``prune_thinking_data``.

    Args:
        keep_ratio: Passed to ``prune_thinking_data`` as ``keep_percentage``.
        end_start_ratio: Passed through to ``prune_thinking_data``.

    Returns:
        The mapped dataset with pruned message contents.
    """
    dataset = load_dataset("open-r1/OpenR1-Math-220k", split="train")

    # Compiled once here rather than once per message inside the map loop.
    wait_pattern = re.compile(r'\bwait\b', re.IGNORECASE)
    # Diagnostic tally of the word "wait" in the pruned outputs; it is
    # accumulated but neither returned nor printed by this function.
    total_num_wait = 0

    def map_fn(examples):
        nonlocal total_num_wait
        for i, example in enumerate(examples['messages']):
            for j, text in enumerate(example):
                if text['role'] == 'user':
                    continue
                pruned_solution = prune_thinking_data(text["content"], keep_percentage=keep_ratio, end_start_ratio=end_start_ratio)
                total_num_wait += len(wait_pattern.findall(pruned_solution))
                examples['messages'][i][j]['content'] = pruned_solution
        return examples

    dataset = dataset.map(map_fn, batched=True)
    return dataset

def prune_general_thoughts(keep_ratio=0.5, end_start_ratio=0.5):
    """
    Prune the ``model_reasoning`` column of the GeneralThought dataset.

    Loads the "GeneralReasoning/GeneralThought-195K" train split. Each
    reasoning string is split into chunks on double newlines (whitespace-only
    chunks are discarded); strings with at least 15 chunks keep only their
    leading and trailing chunks, shorter ones are left unchanged. Entries
    that are None stay None.

    Args:
        keep_ratio: Fraction of the chunks to keep (default: 0.5).
        end_start_ratio: Fraction of the kept chunks taken from the end of
            the reasoning; the rest is taken from the start (default: 0.5).
            At least one chunk is kept from the start and, when room is
            left, one from the end.

    Returns:
        The mapped dataset with pruned ``model_reasoning`` values.
    """
    def prune_thinking(text):
        # Returns the pruned reasoning, the original text when it is too
        # short to prune, or None when there is no reasoning at all.
        if text is None:
            return None
        chunks = [chunk for chunk in re.split(r'\n\n', text) if chunk.strip()]
        chunks_number = len(chunks)
        if chunks_number < 15:
            return text

        num_to_keep = int(chunks_number * keep_ratio)
        num_end_to_keep = int(num_to_keep * end_start_ratio)
        num_start_to_keep = num_to_keep - num_end_to_keep

        # Keep at least one chunk from each side ...
        num_start_to_keep = max(num_start_to_keep, 1)
        num_end_to_keep = max(num_end_to_keep, 1)
        # ... but never let the two slices overlap, which would duplicate chunks.
        num_start_to_keep = min(num_start_to_keep, chunks_number)
        num_end_to_keep = min(num_end_to_keep, chunks_number - num_start_to_keep)

        kept_chunks = chunks[:num_start_to_keep] + chunks[chunks_number - num_end_to_keep:]
        return "\n\n".join(kept_chunks)

    dataset = load_dataset("GeneralReasoning/GeneralThought-195K", split="train")

    def map_fn(examples):
        for i, example in enumerate(examples['model_reasoning']):
            pruned_thoughts = prune_thinking(example)
            if pruned_thoughts is not None:
                examples['model_reasoning'][i] = pruned_thoughts
        return examples

    dataset = dataset.map(map_fn, batched=True)
    # Hand the result back to the caller; previously it was discarded.
    return dataset

def analyze_general_thoughts_dataset(sample_size=None):
    """
    Compute chunk and length statistics for GeneralThought reasoning traces.

    Loads the "GeneralReasoning/GeneralThought-195K" train split and, for
    every sample whose ``model_reasoning`` is not None, counts the non-empty
    chunks (pieces separated by double newlines) and the whitespace-separated
    words. Reasoning strings that contain no non-empty chunk are printed.
    A histogram of the chunk counts is drawn on a new matplotlib figure (the
    figure is neither shown nor saved here).

    Args:
        sample_size: If given, only the first ``sample_size`` samples are
            analysed; None analyses the whole split.

    Returns:
        A dict with the keys 'mean_chunks', 'median_chunks', 'min_chunks',
        'max_chunks', 'std_dev', 'samples_with_thinking', 'samples_total',
        'percentage_with_thinking' and 'average_length'. The chunk statistics
        are 0 when no sample has any chunk.
    """
    dataset = load_dataset("GeneralReasoning/GeneralThought-195K", split="train")
    if sample_size is not None:
        # Restrict the analysis to a prefix of the split.
        dataset = dataset.select(range(min(sample_size, len(dataset))))

    chunk_counts = []
    lens_counts = []
    samples_with_thinking = 0
    samples_total = len(dataset)
    for sample in tqdm(dataset):
        reasoning = sample['model_reasoning']
        if reasoning is None:
            continue
        chunks = sum(1 for chunk in re.split(r'\n\n', reasoning) if chunk.strip())
        lens_counts.append(len(reasoning.split()))
        if chunks > 0:
            chunk_counts.append(chunks)
            samples_with_thinking += 1
        else:
            # Reasoning made only of whitespace: surface it for inspection.
            print(reasoning)

    average_length = np.mean(lens_counts) if lens_counts else 0
    if chunk_counts:
        stats = {
            'mean_chunks': np.mean(chunk_counts),
            'median_chunks': np.median(chunk_counts),
            'min_chunks': min(chunk_counts),
            'max_chunks': max(chunk_counts),
            'std_dev': np.std(chunk_counts),
            'samples_with_thinking': samples_with_thinking,
            'samples_total': samples_total,
            'percentage_with_thinking': (samples_with_thinking / samples_total) * 100,
            'average_length': average_length,
        }
    else:
        # Same keys as above so callers can format either result uniformly.
        stats = {
            'mean_chunks': 0,
            'median_chunks': 0,
            'min_chunks': 0,
            'max_chunks': 0,
            'std_dev': 0,
            'samples_with_thinking': 0,
            'samples_total': samples_total,
            'percentage_with_thinking': 0,
            'average_length': average_length,
        }

    plt.figure(figsize=(10, 6))
    plt.hist(chunk_counts, bins=50, edgecolor='black')
    plt.xlabel('Number of Chunks')
    plt.ylabel('Frequency')
    return stats


# Script entry point
if __name__ == "__main__":
    # NOTE(review): `sys` is imported here but not referenced in this block.
    import sys
    import argparse
    
    # Command-line interface. NOTE(review): the parsed `args` are only used by
    # the commented-out experiments below; the single live call at the end of
    # this block uses hard-coded values, so the flags currently have no effect.
    parser = argparse.ArgumentParser(description='Process thinking data and analyze datasets.')
    parser.add_argument('--analyze-dataset', action='store_true', help='Analyze the OpenR1-Math-220k dataset')
    parser.add_argument('--sample-size', type=int, default=None, help='Number of samples to analyze from the dataset')
    parser.add_argument('--prune-dataset', action='store_true', help='Prune the OpenR1-Math-220k dataset')
    parser.add_argument('--keep-ratio', type=float, default=0.5, help='Percentage of chunks to keep (default: 0.5)')
    parser.add_argument('--end-start-ratio', type=float, default=0.5, help='Percentage of chunks to keep from the end (default: 0.5)')
    args = parser.parse_args()
    
    # Disabled experiment: OpenR1-Math analysis and pruning.
    # NOTE(review): `analyze_openr1_math_dataset` is not defined in the part of
    # the file reviewed here — confirm it exists before re-enabling.
    # if args.analyze_dataset:
    #     stats = analyze_openr1_math_dataset(sample_size=args.sample_size)
    #     print("\nDataset Analysis Results:")
    #     print(f"Mean number of thought chunks: {stats['mean_chunks']:.2f}")
    #     print(f"Median number of thought chunks: {stats['median_chunks']:.2f}")
    #     print(f"Range: {stats['min_chunks']} to {stats['max_chunks']} chunks")
    #     print(f"Standard deviation: {stats['std_dev']:.2f}")
    #     print(f"Samples with thinking: {stats['samples_with_thinking']} out of {stats['samples_total']} ({stats['percentage_with_thinking']:.2f}%)")
    
    # if args.prune_dataset:
    #     prune_openr1_math_dataset(keep_ratio=args.keep_ratio, end_start_ratio=args.end_start_ratio)
    #     print("Dataset pruned and pushed to HuggingFace.")
    # Disabled experiment: GeneralThought analysis and pruning driven by the CLI flags.
    # if args.analyze_dataset:
    #     stats = analyze_general_thoughts_dataset(sample_size=args.sample_size)
    #     print("\nDataset Analysis Results:")
    #     print(f"Mean number of thought chunks: {stats['mean_chunks']:.2f}")
    #     print(f"Median number of thought chunks: {stats['median_chunks']:.2f}")
    #     print(f"Range: {stats['min_chunks']} to {stats['max_chunks']} chunks")
    #     print(f"Standard deviation: {stats['std_dev']:.2f}")
    #     print(f"Samples with thinking: {stats['samples_with_thinking']} out of {stats['samples_total']} ({stats['percentage_with_thinking']:.2f}%)")

    # if args.prune_dataset:
    #     prune_general_thoughts(keep_ratio=args.keep_ratio, end_start_ratio=args.end_start_ratio)
    #     print("Dataset pruned and pushed to HuggingFace.")
    # Disabled experiment. NOTE(review): `prune_ratio_acc` is not defined in the
    # part of the file reviewed here — confirm it exists before re-enabling.
    # prune_ratio_acc(args.keep_ratio)
    # stats= prune_ratio_acc()
    # print("\nDataset Analysis Results:")
    # print(f"Mean number of thought chunks: {stats['mean_chunks']:.2f}")
    # print(f"Median number of thought chunks: {stats['median_chunks']:.2f}")
    # print(f"Range: {stats['min_chunks']} to {stats['max_chunks']} chunks")
    # print(f"Standard deviation: {stats['std_dev']:.2f}")
    # print(f"Samples with thinking: {stats['samples_with_thinking']} out of {stats['samples_total']} ({stats['percentage_with_thinking']:.2f}%)")    
    # print(f"Average length: {stats['average_length']:.2f}")
    # The only statement that runs after argument parsing: prune the
    # GeneralThought reasoning with a fixed keep_ratio of 0.01, regardless of
    # --keep-ratio / --end-start-ratio.
    prune_general_thoughts(keep_ratio=0.01, end_start_ratio=0.5)