import os
import sys
import argparse
import json
import torch
import numpy as np
from collections import defaultdict
import gc
from tqdm import tqdm
# Default values for the split-related CLI options built in parse_args():
# `alpha` fills --alpha_split, `minus` fills --minus_score and `mean` fills the
# --{problem,computation,verification}_split_mean_score defaults.
alpha = 0.1
minus = 0.5
mean = 0.4

# Put PROJECT_ROOT (when set and not already present) on sys.path before
# importing the project-local `env` module.
project_root = os.environ.get("PROJECT_ROOT")
if project_root and project_root not in sys.path:
    sys.path.append(project_root)
from env import er_model

# Add <project_root>/tools/PMPD/pmpd/modules to sys.path before importing `clarification`.
# NOTE(review): if PROJECT_ROOT is unset, project_root is None and this appends the
# literal path "None/tools/PMPD/pmpd/modules"; the import then only succeeds if
# `clarification` is importable from elsewhere on sys.path — confirm intended.
sys.path.append(f"{project_root}/tools/PMPD/pmpd/modules")
from clarification import dewey_text_type

# Read from the environment; not referenced anywhere else in this file.
data_root = os.environ.get("DATA_ROOT")

def _str2bool(value):
    """Convert a command-line boolean spelling ("true", "False", "1", "no", ...) to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is ``True`` for every
    non-empty string (including ``"False"``), so boolean options could never be
    switched off from the command line. This converter parses the text instead.
    Non-string defaults (``True``/``None``) are never passed through ``type`` by
    argparse, so they are unaffected.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


def parse_args():
    """Parse the command-line options of this evaluation script.

    Comma-separated options (``--naive_bit``, ``--alpha_split``, ``--minus_score`` and
    the ``*_split_*_score`` options) are returned as raw strings; the script splits
    them into lists later. Defaults for ``--alpha_split``, ``--minus_score`` and the
    ``*_split_mean_score`` options are built from the module-level ``alpha``,
    ``minus`` and ``mean`` constants. Boolean options accept true/false, yes/no, 1/0.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="qwen7b")
    parser.add_argument("--dataset", type=str, default="math500")
    parser.add_argument("--reward_model", type=str, default="prm")
    parser.add_argument("--scheduler", type=str, default="descent_dewey_score")
    parser.add_argument("--prompt_type", type=str, default="better")
    parser.add_argument("--device", type=str, default="cuda:7")
    parser.add_argument("--max_steps", type=int, default=32768)
    parser.add_argument("--prefill_bit", type=int, default=4)
    parser.add_argument("--past_key_values", type=_str2bool, default=None)
    parser.add_argument("--naive_bit", type=str, default="4,3")
    parser.add_argument("--high_bit_steps", type=int, default=512)
    parser.add_argument("--part", type=str, default="cot")
    parser.add_argument("--do_sample", type=_str2bool, default=True)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--output", type=str, default="tt/math443.json", help="Path to save the results JSON file.")
    parser.add_argument("--num_samples", type=int, default=29, help="Number of samples to evaluate; pass -1 to evaluate all samples.")
    parser.add_argument("--prune_path", type=str, default=None, help="Name of the run directory under src/code/cot_split/results/ whose data.json holds the prune function. If not specified, will not prune.")
    parser.add_argument("--split", type=_str2bool, default=True, help="Whether to enable splitting (forwarded to er_model as 'split').")
    parser.add_argument("--sol_precision", type=int, default=4)
    parser.add_argument("--windows", type=int, default=1)
    parser.add_argument("--alpha_split", type=str, default=f"{alpha},{alpha},{alpha}", help="Alpha values for split probability calculation")
    parser.add_argument("--minus_score", type=str, default=f"{minus},{minus},{minus}", help="Minus score values for processing")
    parser.add_argument("--problem_split_mean_score", type=str, default=f"{mean},{mean}")
    parser.add_argument("--problem_split_max_score", type=str, default="0,0")
    parser.add_argument("--computation_split_mean_score", type=str, default=f"{mean},{mean}")
    parser.add_argument("--computation_split_max_score", type=str, default="0,0")
    parser.add_argument("--verification_split_mean_score", type=str, default=f"{mean},{mean}")
    parser.add_argument("--verification_split_max_score", type=str, default="0,0")
    parser.add_argument("--descent_prompt", type=str, default="nothink", help="Path to custom descent_prompt json file.")
    parser.add_argument("--xverify", type=_str2bool, default=True)
    return parser.parse_args()

# Reasoning parts reported by calculate_part_rewards_and_bits, in output order.
_PARTS = ("problem", "problem_formulation", "computation", "verification", "answer")
# Thinking-chain sentence types that are scored; other types are ignored.
_SENTENCE_PARTS = ("problem_formulation", "computation", "verification")
_THINK_OPEN = "<think>"
_THINK_CLOSE = "</think>"
# Fallback delimiter used when no <think>...</think> pair is present.
_ASSISTANT_MARKER = "<|im_start|>assistant<|im_end|>\n<|im_start|>assistant<|im_end|>\n"
# Rough characters-to-tokens ratio used for the per-part token estimates.
_TOKENS_PER_CHAR = 0.5


def _split_thinking_and_answer(output_text):
    """Split a decoded generation into ``(thinking_chain, answer_part)``.

    Prefers an explicit ``<think>...</think>`` pair. Otherwise, falls back to the
    doubled assistant marker. Otherwise, the whole text is the answer and the chain
    is empty. Input without a ``find`` method (non-string) yields ``("", output_text)``.
    """
    try:
        think_start = output_text.find(_THINK_OPEN)
        think_end = output_text.find(_THINK_CLOSE)
        if think_start != -1 and think_end != -1:
            thinking = output_text[think_start + len(_THINK_OPEN):think_end].strip()
            return thinking, output_text[think_end + len(_THINK_CLOSE):].strip()

        start = output_text.find(_ASSISTANT_MARKER)
        if start == -1:
            return "", output_text
        start += len(_ASSISTANT_MARKER)
        end = output_text.find(_ASSISTANT_MARKER, start)
        if end == -1:
            return output_text[start:].strip(), ""
        # The answer keeps the closing marker text, as in the original split.
        return output_text[start:end].strip(), output_text[end:].strip()
    except (AttributeError, TypeError):
        return "", output_text


def calculate_part_rewards_and_bits(content, dataset, reward_model, model, cot_precision_dict):
    """Compute average rewards and approximate bit widths for each reasoning part.

    Each generation in *content* is paired with its prompt from
    ``dataset.get_prompt()``; ``zip`` stops at the shorter of the two. The prompt is
    scored as the "problem" part and the text after the thinking chain as the
    "answer" part. Every ``.``-separated sentence of the thinking chain is
    classified with ``dewey_text_type``. Problem_formulation, computation and
    verification sentences are scored and accumulated under that part. Token counts
    are estimated as 0.5 tokens per character.

    NOTE: bit widths are not tracked per part. Every part with a non-zero token
    estimate gets the same token-weighted average bit of *cot_precision_dict*.

    Args:
        content: Decoded generations.
        dataset: Object whose ``get_prompt()`` returns the prompt texts.
        reward_model: Callable invoked as ``reward_model("", problem_text, text)``.
            If it raises, problem/answer rewards fall back to 0.0 and the sentence
            is skipped.
        model: Unused; kept for interface compatibility.
        cot_precision_dict: Mapping ``bit -> token count``.

    Returns:
        dict with ``avg_<part>_reward``, ``<part>_reward_count``,
        ``avg_<part>_bit`` and ``<part>_bit_count`` for each part in ``_PARTS``.
    """
    part_rewards = {part: [] for part in _PARTS}
    part_token_counts = {part: 0 for part in _PARTS}

    prompts = dataset.get_prompt()

    for output_text, problem_text in zip(content, prompts):
        thinking_chain, answer_part = _split_thinking_and_answer(output_text)

        # Reward of the original problem text on its own.
        try:
            part_rewards["problem"].append(reward_model("", problem_text, ""))
            part_token_counts["problem"] += len(problem_text) * _TOKENS_PER_CHAR
        except Exception:
            part_rewards["problem"].append(0.0)

        # Reward of the final answer part.
        try:
            part_rewards["answer"].append(reward_model("", problem_text, answer_part))
            part_token_counts["answer"] += len(answer_part) * _TOKENS_PER_CHAR
        except Exception:
            part_rewards["answer"].append(0.0)

        if not thinking_chain:
            continue

        for sentence in (piece.strip() for piece in thinking_chain.split(".")):
            if not sentence:
                continue
            text_type, _ = dewey_text_type(sentence)
            # Only these types are recorded, so don't pay for a reward-model call
            # whose result would be discarded.
            if text_type not in _SENTENCE_PARTS:
                continue
            try:
                sentence_reward = reward_model("", problem_text, sentence)
            except Exception:
                continue
            part_rewards[text_type].append(sentence_reward)
            part_token_counts[text_type] += len(sentence) * _TOKENS_PER_CHAR

    result = {}
    for part, rewards in part_rewards.items():
        result[f"avg_{part}_reward"] = np.mean(rewards) if rewards else 0.0
        result[f"{part}_reward_count"] = len(rewards)

    # Token-weighted average bit over the whole CoT, or None when there is nothing
    # to attribute bits to.
    avg_cot_bit = None
    if sum(part_token_counts.values()) > 0 and cot_precision_dict:
        weighted_sum = sum(bit * count for bit, count in cot_precision_dict.items())
        total_cot_tokens = sum(cot_precision_dict.values())
        avg_cot_bit = weighted_sum / total_cot_tokens if total_cot_tokens > 0 else 0

    for part, token_count in part_token_counts.items():
        if avg_cot_bit is not None and token_count > 0:
            result[f"avg_{part}_bit"] = avg_cot_bit
            result[f"{part}_bit_count"] = int(token_count)
        else:
            result[f"avg_{part}_bit"] = 0.0
            result[f"{part}_bit_count"] = 0

    return result

args = parse_args()

# Optionally load a pre-computed prune function ("column_averages") from a previous
# cot_split run; otherwise run without pruning.
if args.prune_path is not None:
    prune_path = f"{project_root}/src/code/cot_split/results/{args.prune_path}/data.json"
    with open(prune_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    prune_func = data["column_averages"]
else:
    prune_func = None
    print("prune_func is None")

# Set the CUDA device at the beginning of the script. Only an explicit
# "cuda:<index>" selects a GPU. "cuda" or "cpu" used to crash with an IndexError
# on the split; they now leave the default device untouched.
_device_type, _, _device_index = args.device.partition(":")
gpu_idx = int(_device_index) if _device_type == "cuda" and _device_index else None
if gpu_idx is not None:
    torch.cuda.set_device(gpu_idx)


def _csv_to_floats(text):
    """Parse a comma-separated string such as ``"0.1,0.1"`` into a list of floats."""
    return [float(value) for value in text.split(",")]


# Keyword arguments for er_model; also written to the output JSON as "config".
args_dict = {
    "model": args.model,
    "dataset": args.dataset,
    "reward_model": args.reward_model,
    "prompt_type": args.prompt_type,
    "scheduler": args.scheduler,
    "device": args.device,
    "max_steps": args.max_steps,
    "prefill_bit": args.prefill_bit,
    "past_key_values": args.past_key_values,
    "naive_bit": [int(bit) for bit in args.naive_bit.split(",")],
    "high_bit_steps": args.high_bit_steps,
    "part": args.part,
    "do_sample": args.do_sample,
    "temperature": args.temperature,
    "prune_func": prune_func,
    "split": args.split,
    "sol_precision": args.sol_precision,
    "windows": args.windows,
    "alpha_split": _csv_to_floats(args.alpha_split),
    "minus_score": _csv_to_floats(args.minus_score),
    "problem_split_mean_score": _csv_to_floats(args.problem_split_mean_score),
    "problem_split_max_score": _csv_to_floats(args.problem_split_max_score),
    "computation_split_mean_score": _csv_to_floats(args.computation_split_mean_score),
    "computation_split_max_score": _csv_to_floats(args.computation_split_max_score),
    "verification_split_mean_score": _csv_to_floats(args.verification_split_mean_score),
    "verification_split_max_score": _csv_to_floats(args.verification_split_max_score),
    "descent_prompt": args.descent_prompt,
    "xverify": args.xverify,
}
model = er_model(**args_dict)
# -1 on the command line means "all samples"; None is passed on to dataset.get_prompt.
num_samples = None if args.num_samples == -1 else args.num_samples

# Evaluate method that calculates part rewards before the reward_model is deleted.
def _mean_or(values, default):
    """Return ``sum(values) / len(values)``, or *default* when *values* is empty."""
    return sum(values) / len(values) if values else default


def _release_model_resources(model):
    """Delete ``model.model``, ``model.reward_model`` and ``model.tokenizer`` and free the CUDA cache.

    Attributes that are already missing are skipped instead of raising AttributeError.
    """
    for attr in ("model", "reward_model", "tokenizer"):
        if hasattr(model, attr):
            delattr(model, attr)
    # Collect garbage first: empty_cache() can only return cached blocks that are no
    # longer referenced, so objects awaiting collection must be freed beforehand.
    gc.collect()
    torch.cuda.synchronize()
    torch.cuda.empty_cache()


def evaluate_with_part_rewards(model, num_samples):
    """Generate an answer for every prompt, grade the answers and collect statistics.

    For each prompt from ``model.dataset.get_prompt(index=num_samples)``:
    - the generator ``model.model`` is reset and run;
    - the per-bit token counts it reports in ``cot_precision`` are accumulated;
    - the decoded output is split at ``model.think_token`` / ``model.think_end_token``
      into thinking chain and answer.
    Part rewards are computed with ``calculate_part_rewards_and_bits`` while the
    reward model is still loaded. When CUDA is available, the generator, reward model
    and tokenizer are then deleted from *model* to free GPU memory. Answers are graded
    with ``dataset.eval_math_is_correct`` when ``model.xverify_path`` is set, and with
    ``dataset.result_eval`` otherwise.

    Args:
        model: The ``er_model`` wrapper. Attributes read here: ``model``,
            ``tokenizer``, ``dataset``, ``device``, ``max_steps``, ``prefill_bit``,
            ``past_key_values``, ``kw_dict``, ``record_token_count``, ``think_token``,
            ``think_end_token``, ``xverify_path`` and (optionally) ``reward_model``.
        num_samples: Forwarded to ``dataset.get_prompt(index=...)``; the script
            passes ``None`` to mean all samples.

    Returns:
        ``(results, content)``: the metrics dict and the list of decoded generations.
    """
    content = []
    cot_precision = defaultdict(int)
    # Never populated here; kept so the results schema stays unchanged.
    text_type_stats = defaultdict(int)
    answer = []
    answer_token_len = []
    thinking_chain_token_len = []
    no_thinking_chain_count = 0
    # Generated-token count per question (filled only when record_token_count is set).
    real_time_token_counts = []

    prompt = model.dataset.get_prompt(index=num_samples)
    for item in tqdm(prompt, desc="Processing prompts", unit="prompt"):
        model.model.reset(item, "")
        messages = [{"role": "user", "content": item}]
        text = model.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        input_ids = model.tokenizer([text], return_tensors="pt").to(model.device).input_ids
        outputs = model.model.generate(
            input_ids=input_ids,
            max_new_tokens=model.max_steps,
            prefill_bit=model.prefill_bit,
            past_key_values=model.past_key_values,
            **model.kw_dict,
        )

        if model.record_token_count:
            real_time_token_counts.append(len(outputs.input_ids[0]) - len(input_ids[0]))

        for bit, count in outputs.cot_precision.items():
            cot_precision[bit] += count

        # Also keep the last two prompt tokens. Presumably this is so that a think token
        # appended by the chat template is part of the decoded output; confirm.
        output_ids = outputs.input_ids[0][len(input_ids[0]) - 2:]
        content.append(model.tokenizer.decode(output_ids))

        output_ids_list = output_ids.tolist()
        try:
            start_idx = output_ids_list.index(model.think_token)
            end_idx = output_ids_list.index(model.think_end_token)
        except ValueError:
            # The generation lacks a think-start or think-end token.
            no_thinking_chain_count += 1
            answer.append("Error Answer")
            continue

        thinking_chain_ids = output_ids[start_idx + 1:end_idx]
        solution_ids = output_ids[end_idx + 1:]
        answer_token_len.append(len(solution_ids))
        thinking_chain_token_len.append(len(thinking_chain_ids))
        answer.append(model.tokenizer.decode(solution_ids, skip_special_tokens=True))

    # Part rewards need the reward model, so compute them before it is released below.
    part_rewards = {}
    if getattr(model, "reward_model", None) is not None:
        try:
            part_rewards = calculate_part_rewards_and_bits(
                content, model.dataset, model.reward_model, model, cot_precision
            )
            print("\nPart rewards and bit value statistics:")
            for key, value in part_rewards.items():
                if key.startswith("avg_"):
                    print(f"{key}: {value:.4f}")
                elif key.endswith("_count"):
                    print(f"{key}: {value}")
        except Exception as e:
            print(f"Warning: Error calculating part rewards and bits: {e}")

    avg_thinking_chain_len = _mean_or(thinking_chain_token_len, 0)
    avg_solution_len = _mean_or(answer_token_len, 0)

    # Despite its name, non_zero_tokens counts every recorded token, including bit 0.
    non_zero_tokens = sum(cot_precision.values())
    weighted_sum = sum(bit * count for bit, count in cot_precision.items())
    avg_precision = weighted_sum / non_zero_tokens if non_zero_tokens > 0 else 0
    precision_stats = {
        "total_tokens": non_zero_tokens,
        "avg_precision": avg_precision,
        "distribution": {
            bit: {"probability": count / non_zero_tokens}
            for bit, count in sorted(cot_precision.items())
            if bit != 0 and non_zero_tokens > 0
        },
    }

    # Read the generator's per-sample probability traces unconditionally. Previously
    # they were bound only when CUDA was available, causing a NameError on CPU.
    all_prob_means = getattr(model.model, "list_prob", [])
    all_split_probs = getattr(model.model, "list_split_prob", [])
    all_split_prob_15 = getattr(model.model, "list_split_prob_15", [])
    if torch.cuda.is_available():
        _release_model_resources(model)

    list_prob_correct, list_prob_false = [], []
    list_split_correct, list_split_false = [], []
    split_prob_15_correct = [[] for _ in range(15)]
    split_prob_15_false = [[] for _ in range(15)]

    if model.xverify_path is not None:
        accuracy, correct_list = model.dataset.eval_math_is_correct(answer)
        # zip() stops at the shorter sequence, like the original "i < len(...)" guards.
        for is_correct, prob_mean in zip(correct_list, all_prob_means):
            target = list_prob_correct if is_correct == 1 else list_prob_false
            target.append(prob_mean)
        if all_split_probs:
            for is_correct, split_prob in zip(correct_list, all_split_probs):
                target = list_split_correct if is_correct == 1 else list_split_false
                target.append(split_prob)
        if all_split_prob_15:
            for is_correct, split_prob_15 in zip(correct_list, all_split_prob_15):
                buckets = split_prob_15_correct if is_correct == 1 else split_prob_15_false
                # Zipping with the 15 buckets keeps only the first 15 positions.
                for bucket, prob_value in zip(buckets, split_prob_15):
                    bucket.append(prob_value)
    else:
        accuracy = model.dataset.result_eval(answer)

    avg_split_prob_15_correct = [_mean_or(bucket, 0.0) for bucket in split_prob_15_correct]
    avg_split_prob_15_false = [_mean_or(bucket, 0.0) for bucket in split_prob_15_false]

    results = {
        "accuracy": accuracy,
        "no_thinking_chain_count": no_thinking_chain_count,
        "cot_precision": avg_precision,
        "avg_thinking_chain_len": avg_thinking_chain_len,
        "avg_solution_len": avg_solution_len,
        "precision_stats": precision_stats,
        "text_type_stats": dict(text_type_stats),
        "avg_prob_correct": _mean_or(list_prob_correct, 0),
        "avg_prob_false": _mean_or(list_prob_false, 0),
        "avg_split_prob_correct": _mean_or(list_split_correct, 0),
        "avg_split_prob_false": _mean_or(list_split_false, 0),
        "avg_split_prob_15_correct": avg_split_prob_15_correct,
        "avg_split_prob_15_false": avg_split_prob_15_false,
    }
    results.update(part_rewards)

    if model.record_token_count:
        results["real_time_token_counts"] = real_time_token_counts
        results["avg_tokens_per_question"] = _mean_or(real_time_token_counts, 0)
        results["total_tokens_generated"] = sum(real_time_token_counts)

    return results, content

# Run the evaluation and save the configuration together with the results.
results, content = evaluate_with_part_rewards(model, num_samples)

output_data = {
    "config": args_dict,
    "results": results,
}
output_path = args.output

# Create the output directory if it doesn't exist. For a bare filename,
# os.path.dirname() returns "" and os.makedirs("") would raise FileNotFoundError.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

with open(output_path, "w", encoding="utf-8") as f:
    json.dump(output_data, f, indent=4, ensure_ascii=False)

print(f"\nResults saved to: {output_path}")