import json
from pathlib import Path
from collections import Counter
import numpy as np
from scipy.stats import ttest_1samp
from scipy.stats import t
import pandas as pd
from collections import Counter


class EarlyStopCoT:
    """Simulate slope-based early stopping over recorded chain-of-thought runs.

    Records are read from a JSONL file. The methods below read these keys from
    each record: ``answer`` (the ground truth) and the per-sample lists
    ``step_answer``, ``step_tokens``, ``step_answer_tokens`` and
    ``generated_answer`` (one entry per sampled chain).
    """

    def __init__(self, path):
        """Load the records at *path* and run-length encode their step answers.

        Args:
            path: Path (str or ``pathlib.Path``) of a JSONL file.
        """
        self.data = self.load_data(path)

        # Adds merged_step_answer, merged_step_answer_count,
        # merged_step_answer_tf and merged_step_answer_final to every record.
        self.process_data()

    def load_data(self, path):
        """Load records from a JSONL file (one JSON object per line).

        The file is decoded as UTF-8 regardless of the platform locale, and
        whitespace-only lines (e.g. a trailing blank line) are skipped instead
        of raising ``json.JSONDecodeError``.

        Returns:
            A list of the decoded objects, in file order.
        """
        with open(path, 'r', encoding='utf-8') as file:
            return [json.loads(line) for line in file if line.strip()]

    def process_data(self):
        """Run-length encode every sample's per-step answers, in place.

        For each record and each sampled chain in ``step_answer``, runs of
        consecutive equal answers are collapsed. Four lists (one inner list per
        sample, one entry per run) are stored on the record:

        - ``merged_step_answer``: the answer of each run;
        - ``merged_step_answer_count``: the number of steps in each run;
        - ``merged_step_answer_tf``: whether the run's answer equals
          ``answer`` (the ground truth);
        - ``merged_step_answer_final``: whether the run's answer equals the
          chain's last step answer.

        An empty (or otherwise falsy) sample yields four empty lists.
        """
        from itertools import groupby

        for record in self.data:
            ground_truth = record['answer']
            itemss, countss, tfss, finalss = [], [], [], []
            for step_answer in record['step_answer']:
                if step_answer:
                    last = step_answer[-1]
                    runs = [(key, sum(1 for _ in group))
                            for key, group in groupby(step_answer)]
                else:
                    last, runs = None, []
                itemss.append([key for key, _ in runs])
                countss.append([count for _, count in runs])
                tfss.append([bool(key == ground_truth) for key, _ in runs])
                finalss.append([bool(key == last) for key, _ in runs])
            record['merged_step_answer'] = itemss
            record['merged_step_answer_count'] = countss
            record['merged_step_answer_tf'] = tfss
            record['merged_step_answer_final'] = finalss

    def early_stop_slope(self, min_slope=5, threshold=0.05):
        """Apply slope-based early stopping to every sample of every record.

        Stores three per-sample lists on each record (overwriting any previous
        call's results): ``early_stop_answer``, ``early_stop_tokens`` and
        ``early_stop_step_tokens``. See ``early_stop_slope_single``.
        """
        for data_id, record in enumerate(self.data):
            outcomes = [
                self.early_stop_slope_single(data_id, sample_id, min_slope, threshold)
                for sample_id in range(len(record['merged_step_answer']))
            ]
            record['early_stop_answer'] = [o[0] for o in outcomes]
            record['early_stop_step_tokens'] = [o[2] for o in outcomes]
            record['early_stop_tokens'] = [o[1] for o in outcomes]

    def early_stop_slope_single(self, data_id, sample_id, min_slope=5, threshold=0.05):
        """Decide where (if anywhere) a single sampled chain would stop early.

        ``d[i] = count[i] - count[i - 1]`` is computed over the run lengths
        from ``process_data`` (with ``d[0] = count[0]``). Runs are scanned from
        index 2. The chain stops at the first run ``i`` whose ``d[i]`` is at
        least ``min_slope`` and at least ``target``, where ``target`` is
        ``get_mu(d[1:i], threshold)``. When only one previous difference is
        available, no variance estimate exists, so ``target`` is ``min_slope``.

        The stop position is the first step of run ``i`` plus
        ``max(target, min_slope)`` steps.

        Returns:
            ``(answer, tokens, answer_tokens, stopped_early)``. Here
            ``tokens`` / ``answer_tokens`` are the sums of ``step_tokens`` /
            ``step_answer_tokens`` up to the stop position. When no stop
            triggers, the chain's ``generated_answer`` and the full token sums
            are returned with ``stopped_early=False``.
        """
        record = self.data[data_id]
        merge_counts = record['merged_step_answer_count'][sample_id]
        merge_answers = record['merged_step_answer'][sample_id]
        step_tokens = record['step_tokens'][sample_id]
        step_answer_tokens = record['step_answer_tokens'][sample_id]
        answer = record['generated_answer'][sample_id]

        # Difference between adjacent run lengths; prepend=0 keeps d[0] = count[0].
        difference_counts = np.diff(merge_counts, prepend=0)
        for i in range(2, len(difference_counts)):
            current = difference_counts[i]
            if current < min_slope:
                continue
            previous = difference_counts[1:i]
            if len(previous) < 2:
                # A t-bound needs >= 2 observations (ddof=1, df=n-1). Without
                # this guard get_mu produces NaN and an undefined NaN->int cast.
                target = min_slope
            else:
                target = self.get_mu(previous, threshold)
            if current >= target:
                # The current run length jumped past the bound: stop here.
                current_step = max(target, min_slope)
                cutoff = sum(merge_counts[:i]) + current_step
                return (merge_answers[i],
                        sum(step_tokens[:cutoff]),
                        sum(step_answer_tokens[:cutoff]),
                        True)

        return answer, sum(step_tokens), sum(step_answer_tokens), False

    def get_mu(self, x, threshold):
        """Return the ceiling of a one-sided Student-t bound on ``mean(x)``.

        Computes ``mean(x) - t.ppf(threshold, n - 1) * s / sqrt(n)``, where
        ``s`` is the sample standard deviation (``ddof=1``). ``s`` is replaced
        by 1 when it is 0. For ``threshold < 0.5`` the quantile is negative, so
        the bound lies above the sample mean.

        Requires ``len(x) >= 2``; with fewer observations ``s`` and the degrees
        of freedom are undefined and the result is meaningless.
        """
        n = len(x)
        x_bar = np.mean(x)
        s = np.std(x, ddof=1)
        if s == 0:
            s = 1
        df = n - 1

        t_crit = t.ppf(threshold, df)
        mu = x_bar - t_crit * (s / np.sqrt(n))
        # Round mu up to an integer.
        mu = np.ceil(mu).astype(int)
        return mu

    @staticmethod
    def _summarize(count, cot_true_count, es_true_count, cot_tokens_sum,
                   es_tokens_sum, es_step_tokens_sum, print_info):
        """Optionally print a markdown table row, then return the summary dict.

        Averages are NaN (not a ZeroDivisionError) when *count* is 0.
        """
        def per_item(value):
            return value / count if count else float('nan')

        if print_info is not None:
            print('| Total samples | CoT correct answers | Early Stop correct answers | CoT average tokens | Early Stop average tokens |')
            print(f'| {count} | {cot_true_count} ({per_item(cot_true_count):.2%}) | {es_true_count} ({per_item(es_true_count):.2%}) | {per_item(cot_tokens_sum):.2f} | {per_item(es_tokens_sum):.2f} |')

        return {'total_samples': count,
                'cot_correct_answers': cot_true_count,
                'es_correct_answers': es_true_count,
                'cot_average_tokens': per_item(cot_tokens_sum),
                'es_average_tokens': per_item(es_tokens_sum),
                'es_step_average_tokens': per_item(es_step_tokens_sum)}

    def statistics(self, sample_id=None, print_info=None):
        """Per-sample accuracy and token usage of full CoT vs. early stopping.

        Requires ``early_stop_slope`` to have been run first.

        Args:
            sample_id: Sample indices to evaluate in every record. ``None``
                means all samples of each record, determined per record.
            print_info: If not ``None``, a markdown table row is printed.

        Returns:
            A dict of counts and per-sample averages.
        """
        count = cot_true_count = es_true_count = 0
        cot_tokens_sum = es_tokens_sum = es_step_tokens_sum = 0
        for data in self.data:
            # Resolve the default per record; never overwrite the argument.
            ids = sample_id if sample_id is not None else range(len(data['step_answer']))
            ground_truth = data['answer']
            for s_id in ids:
                count += 1
                if data['generated_answer'][s_id] == ground_truth:
                    cot_true_count += 1
                if data['early_stop_answer'][s_id] == ground_truth:
                    es_true_count += 1
                cot_tokens_sum += sum(data['step_tokens'][s_id])
                es_tokens_sum += data['early_stop_tokens'][s_id]
                es_step_tokens_sum += data['early_stop_step_tokens'][s_id]

        return self._summarize(count, cot_true_count, es_true_count, cot_tokens_sum,
                               es_tokens_sum, es_step_tokens_sum, print_info)

    def statistics_sc(self, sample_id=None, print_info=None):
        """Self-consistency (majority vote) accuracy per record.

        For each record, the most common answer among the selected samples is
        compared with the ground truth. This is done separately for the full
        CoT answers and the early-stop answers; ties go to the answer
        encountered first (``Counter.most_common``). Token sums are totals
        over the selected samples of a record, averaged over records. Records
        with no selected samples are skipped.

        Args:
            sample_id: Sample indices to vote over in every record. ``None``
                means all samples of each record, determined per record.
            print_info: If not ``None``, a markdown table row is printed.

        Returns:
            A dict of counts and per-record averages.
        """
        count = cot_true_count = es_true_count = 0
        cot_tokens_sum = es_tokens_sum = es_step_tokens_sum = 0
        for data in self.data:
            ids = list(sample_id) if sample_id is not None else list(range(len(data['step_answer'])))
            if not ids:
                continue

            ground_truth = data['answer']
            cot_sc_answer = Counter(data['generated_answer'][s_id] for s_id in ids).most_common(1)[0][0]
            early_stop_answer = Counter(data['early_stop_answer'][s_id] for s_id in ids).most_common(1)[0][0]

            count += 1
            if cot_sc_answer == ground_truth:
                cot_true_count += 1
            if early_stop_answer == ground_truth:
                es_true_count += 1
            cot_tokens_sum += sum(sum(data['step_tokens'][s_id]) for s_id in ids)
            es_tokens_sum += sum(data['early_stop_tokens'][s_id] for s_id in ids)
            es_step_tokens_sum += sum(data['early_stop_step_tokens'][s_id] for s_id in ids)

        return self._summarize(count, cot_true_count, es_true_count, cot_tokens_sum,
                               es_tokens_sum, es_step_tokens_sum, print_info)


if __name__ == '__main__':
    # Sweep parameters.
    rootpath = Path('/data/project/Reasoning/results/')
    model_lis = ['DeepSeek-R1-Distill-Llama-8B', 'QwQ-32B', 'Qwen3-8B']
    dataset_lis = ['aime', 'gpqa', 'math', 'minerva', 'olympiadbench']
    min_slope_lis = [3, 5, 7, 10, 15, 20]
    threshold_lis = [0.01, 0.05, 0.1, 0.15, 0.2]

    results = []
    results_sc = []
    for model in model_lis:
        for dataset in dataset_lis:
            path = rootpath / model / dataset / f'{dataset}_step_results.jsonl'
            if not path.exists():
                print(f"File {path} does not exist. Skipping.")
                continue

            # Loading and run-length encoding do not depend on the sweep
            # parameters, so do them once per file. early_stop_slope
            # overwrites the early_stop_* fields on every call, so the
            # instance can be reused safely.
            early_stop_cot = EarlyStopCoT(path)
            print_info = f"{model} on {dataset}"

            for min_slope in min_slope_lis:
                for threshold in threshold_lis:
                    print(f"Processing {path} with min_slope={min_slope}, threshold={threshold}")
                    early_stop_cot.early_stop_slope(min_slope=min_slope, threshold=threshold)  # slope variant

                    params = {'model': model, 'dataset': dataset,
                              'min_slope': min_slope, 'threshold': threshold}
                    temp = early_stop_cot.statistics(sample_id=[0], print_info=print_info)
                    temp_sc = early_stop_cot.statistics_sc(print_info=None)

                    # Parameters are appended after the metric columns.
                    results.append({**temp, **params})
                    results_sc.append({**temp_sc, **params})
                    print('======================================================')

    pd.DataFrame(results).to_csv(rootpath / 'early_stop_cot_results_slope.csv', index=False)
    pd.DataFrame(results_sc).to_csv(rootpath / 'early_stop_cot_results_sc_slope.csv', index=False)

