from countsplit import get_cot,ave_length,check_block_fields
from agent_split import split_agent
import os
import sys
from tqdm import tqdm
import argparse
import json
import torch
import gc
from transformers import AutoTokenizer
from tools.PMPD.pmpd.modules.clarification import dewey_text_type

# Optional PROJECT_ROOT environment variable: when it is set and not already on
# sys.path it is appended, so modules located there become importable.
# NOTE(review): presumably this is what makes the `env` import below resolvable
# -- confirm against the deployment layout.
project_root = os.environ.get("PROJECT_ROOT")
if project_root and project_root not in sys.path:
    sys.path.append(project_root)
# Must stay after the sys.path adjustment above.
from env import Reward_Skywork,Reward_Prm,Gsm8kDataset,MathDataset

# Base directory interpolated into the default model / dataset paths in this
# file. os.environ.get returns None when DATA_ROOT is unset, in which case
# those f-string paths start with the literal text "None/".
data_root = os.environ.get("DATA_ROOT")

def split_reward_argparser():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: carries ``data_path``, ``dataset``, ``reward``,
        ``agent_path``, ``cuda_id``, ``n``, ``save_path`` and ``split_method``.
        Every option is optional and falls back to the default listed below.
    """
    # (flag, value type, default) -- registered in this order.
    option_specs = (
        ("--data_path", str, "simple_split/math_qwen7b_4_4,3.log"),
        ("--dataset", str, "math"),
        ("--reward", str, "prm"),  # prm, skywork
        ("--agent_path", str, f"{data_root}/Qwen2.5-7B-Instruct/"),
        ("--cuda_id", int, 1),
        ("--n", int, 1),
        ("--save_path", str, "dewey_split/math_qwen7b_4_4,3.png"),
        ("--split_method", str, "original"),  # "agent" or "original"
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()

# Dataset classes, looked up by the --dataset option.
dataset = dict(gsm8k=Gsm8kDataset, math=MathDataset)

# Path handed to the selected dataset class, looked up by the --dataset option.
dataset_path = dict(
    gsm8k=f"{data_root}/gsm8k",
    math=f"{data_root}/efficient-reasoning/competition_math",
)

# Reward-model wrapper classes, looked up by the --reward option.
reward_path = dict(skywork=Reward_Skywork, prm=Reward_Prm)

# Path handed to the selected reward-model wrapper, looked up by the --reward option.
reward_model_path = dict(
    skywork=f"{data_root}/Skywork-Reward-Llama-3.1-8B-v0.2",
    prm=f"{data_root}/Qwen2.5-Math-PRM-7B",
)

# Block categories covered by the per-type statistics, mapped to the label used
# in the printed / written report lines. The insertion order fixes the order of
# the report lines and of the keys in the JSON outputs.
_BLOCK_LABELS = {
    "problem_formulation": "Problem formulation block",
    "computation": "Computation block",
    "verification": "Verification block",
    "conclusion": "Conclusion block",
}


def _mean(values):
    """Return the arithmetic mean of ``values``, or 0 when it is empty."""
    return sum(values) / len(values) if values else 0


def _population_variance(values):
    """Return the population variance of ``values``, or 0 when it is empty."""
    if not values:
        return 0
    # The mean is computed once; recomputing it per element would be quadratic.
    center = sum(values) / len(values)
    return sum((value - center) ** 2 for value in values) / len(values)


def _question_stats(rewards, tokens):
    """Summarise the rewards / token counts of one block type within one question."""
    return {
        "average": _mean(rewards),
        "count": len(rewards),
        "avg_tokens": _mean(tokens),
        "max_reward": max(rewards) if rewards else 0,
    }


def split_reward(data_path, dataset, agent_path, cuda_id, save_path, split_method, reward_name=None):
    """Score every split block of each answer and aggregate statistics per block type.

    The chain-of-thought data is read with ``get_cot`` from
    ``"./raw_answer/" + data_path`` and split with ``split_agent``. Each block
    is scored by the reward model, classified with ``dewey_text_type`` and
    measured in tokens with the tokenizer loaded from ``agent_path``.

    Files written to ``./newdewey/<save_path without extension>/``:
        * ``split_text_question_<k>.txt``  -- blocks, rewards and summary of question k
        * ``question_<k>.json``            -- per-type statistics of question k
        * ``overall_block_statistics.json`` -- per-type totals over all questions
        * ``overall_stats.json``           -- the totals plus ``max_rewards`` and
          ``total_block_counts`` (identical to the returned dict)
        * ``blockword_stats.json``         -- average reward / count per block word

    Args:
        data_path: Path of the raw answer file, relative to ``./raw_answer/``.
        dataset: Object providing ``get_prompt()``; the result is indexed by
            question position.
        agent_path: Model path given to ``split_agent`` and to
            ``AutoTokenizer.from_pretrained``.
        cuda_id: GPU index; the reward model is created on ``cuda:<cuda_id>``
            and the value is also forwarded to ``split_agent``.
        save_path: Its extension is stripped and the remainder names the
            output directory below ``./newdewey``.
        split_method: Forwarded unchanged to ``split_agent``.
        reward_name: Key into ``reward_path`` / ``reward_model_path``. When
            ``None``, the ``reward`` attribute of the module-level ``args``
            (set by the script entry point) is used, or ``"prm"`` if ``args``
            does not exist.

    Returns:
        dict: ``{"block_statistics": {...}, "total_questions": int}``.

    Notes:
        Blocks whose type is not one of the four known categories are left out
        of the per-type statistics but still count in the block-word
        statistics. Rewards are assumed to be plain real numbers (they are
        formatted with ``:.4f`` and JSON-serialised) -- TODO confirm against
        the reward model wrappers.
    """
    if reward_name is None:
        # Backward compatibility: the script entry point stores the parsed
        # options in the module-level name `args`. Importing callers have no
        # such global, so fall back to the CLI default instead of failing.
        cli_args = globals().get("args")
        reward_name = cli_args.reward if cli_args is not None else "prm"

    prompt = dataset.get_prompt()
    cot = get_cot("./raw_answer/" + data_path)
    print(ave_length(cot))
    reward_model = reward_path[reward_name](reward_model_path[reward_name], f"cuda:{cuda_id}")
    split_results = split_agent(cot, agent_path, cuda_id, split_method)

    tokenizer = AutoTokenizer.from_pretrained(agent_path, trust_remote_code=True)

    # Raw per-block rewards / token counts over all questions, per block type.
    total_rewards = {block_type: [] for block_type in _BLOCK_LABELS}
    total_tokens = {block_type: [] for block_type in _BLOCK_LABELS}
    # Rewards of every block, grouped by the block word from dewey_text_type.
    blockword_rewards = {}

    # Create output directory if it doesn't exist, newdewey is the directory name
    output_dir = os.path.join("./newdewey", os.path.splitext(save_path)[0])
    os.makedirs(output_dir, exist_ok=True)

    for i, blocks in enumerate(tqdm(split_results, desc="reward calculating")):
        question_no = i + 1
        question_rewards = {block_type: [] for block_type in _BLOCK_LABELS}
        question_tokens = {block_type: [] for block_type in _BLOCK_LABELS}

        text_output_path = os.path.join(output_dir, f"split_text_question_{question_no}.txt")
        with open(text_output_path, 'w', encoding='utf-8') as text_file:
            text_file.write(f"text question {question_no}\n")
            text_file.write(f"Prompt: {prompt[i]}\n\n")

            for j, block in enumerate(blocks):
                block_reward = reward_model(system_prompt="", user=prompt[i], answer=block)

                text_file.write(f"====== split {j+1} =====\n")
                text_file.write(f"Block content:\n{block}\n")
                text_file.write(f"Reward value: {block_reward:.4f}\n\n")

                block_type, block_word = dewey_text_type(block)
                block_tokens = len(tokenizer.encode(block))

                # Unknown block types are skipped here on purpose.
                if block_type in question_rewards:
                    question_rewards[block_type].append(block_reward)
                    question_tokens[block_type].append(block_tokens)

                blockword_rewards.setdefault(block_word, []).append(block_reward)

            classification_stats = {
                block_type: _question_stats(question_rewards[block_type], question_tokens[block_type])
                for block_type in _BLOCK_LABELS
            }
            for block_type in _BLOCK_LABELS:
                total_rewards[block_type].extend(question_rewards[block_type])
                total_tokens[block_type].extend(question_tokens[block_type])

            # Save statistics for the current question to JSON file
            question_file = os.path.join(output_dir, f"question_{question_no}.json")
            with open(question_file, 'w') as f:
                json.dump(classification_stats, f, indent=4)

            # The same report lines go to stdout and into the text file.
            report_lines = []
            for block_type, label in _BLOCK_LABELS.items():
                stats = classification_stats[block_type]
                report_lines.append(
                    f"{label}: average reward = {stats['average']:.4f}, count = {stats['count']}, "
                    f"average token count = {stats['avg_tokens']:.2f}, max reward = {stats['max_reward']:.4f}"
                )

            print(f"\nQuestion {question_no} Block classification statistics:")
            for line in report_lines:
                print(line)

            text_file.write("\nBlock classification statistics:\n")
            for line in report_lines:
                text_file.write(line + "\n")

    # Totals use the raw block rewards, so the variance reflects the spread of
    # individual blocks rather than of per-question averages.
    overall_stats = {
        "block_statistics": {
            block_type: {
                "total_average": _mean(total_rewards[block_type]),
                "total_count": len(total_rewards[block_type]),
                "total_avg_tokens": _mean(total_tokens[block_type]),
                "variance": _population_variance(total_rewards[block_type]),
            }
            for block_type in _BLOCK_LABELS
        },
        "total_questions": len(split_results),
    }

    # Save overall statistics results
    overall_stats_file = os.path.join(output_dir, "overall_block_statistics.json")
    with open(overall_stats_file, 'w') as f:
        json.dump(overall_stats, f, indent=4)

    print("\nOverall Block statistics results:")
    for block_type, label in _BLOCK_LABELS.items():
        totals = overall_stats["block_statistics"][block_type]
        print(
            f"{label}: total average reward = {totals['total_average']:.4f}, total count = {totals['total_count']}, "
            f"total average token count = {totals['total_avg_tokens']:.2f}, variance = {totals['variance']:.4f}"
        )

    # The maximum is taken over the observed rewards (0 only when a type never
    # occurred), so all-negative rewards are not clamped to 0.
    overall_stats["block_statistics"].update({
        "max_rewards": {
            block_type: max(total_rewards[block_type]) if total_rewards[block_type] else 0
            for block_type in _BLOCK_LABELS
        },
        "total_block_counts": {
            block_type: len(total_rewards[block_type]) for block_type in _BLOCK_LABELS
        },
    })

    result_data = {
        "block_statistics": overall_stats["block_statistics"],
        "total_questions": overall_stats["total_questions"]
    }

    # Save overall statistics JSON file
    json_save_path = os.path.join(output_dir, "overall_stats.json")
    with open(json_save_path, 'w') as f:
        json.dump(result_data, f, indent=4)

    print("\nOverall statistics results:")

    # Average reward and number of occurrences of each block word.
    blockword_stats = {
        word: {"average_reward": _mean(rewards), "count": len(rewards)}
        for word, rewards in blockword_rewards.items()
    }
    print("\nBlock_word statistics:")
    for word, stats in blockword_stats.items():
        print(f"block_word: {word}, average reward: {stats['average_reward']:.4f}, count: {stats['count']}")

    blockword_stats_file = os.path.join(output_dir, "blockword_stats.json")
    with open(blockword_stats_file, 'w') as f:
        json.dump(blockword_stats, f, indent=4)

    return result_data

if __name__ == "__main__":
    # `args` must remain a module-level name: split_reward() looks up the
    # selected reward model through it.
    args = split_reward_argparser()
    dataset_cls = dataset[args.dataset]
    dataset = dataset_cls(dataset_path[args.dataset], "better")
    split_reward(
        data_path=args.data_path,
        dataset=dataset,
        agent_path=args.agent_path,
        cuda_id=args.cuda_id,
        save_path=args.save_path,
        split_method=args.split_method,
    )