import os
import sys
import argparse
import json
import torch
import time
from collections import defaultdict
# Default magnitudes used to build the comma-separated CLI defaults of
# --alpha_split, --minus_score and the *_split_mean_score options.
alpha, minus, mean = 0.1, 0.5, 0.4

# Make PROJECT_ROOT importable (once) so the local `env` package resolves.
project_root = os.environ.get("PROJECT_ROOT")
if project_root:
    if project_root not in sys.path:
        sys.path.append(project_root)
from env import er_model

data_root = os.environ.get("DATA_ROOT")

def str2bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for
    every non-empty string (including ``"False"``), so boolean options need
    an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options for this evaluation run.

    Comma-separated list options (``--naive_bit``, ``--alpha_split``, ...) are
    returned as raw strings; callers split and convert them.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="qwen7b")
    parser.add_argument("--dataset", type=str, default="math500")
    parser.add_argument("--reward_model", type=str, default="prm")
    parser.add_argument("--scheduler", type=str, default="descent_dewey_score")
    parser.add_argument("--prompt_type", type=str, default="better")
    parser.add_argument("--device", type=str, default="cuda:5")
    parser.add_argument("--max_steps", type=int, default=32768)
    parser.add_argument("--prefill_bit", type=int, default=4)
    parser.add_argument("--past_key_values", type=str2bool, default=None)
    parser.add_argument("--naive_bit", type=str, default="4,3")
    parser.add_argument("--high_bit_steps", type=int, default=512)
    parser.add_argument("--part", type=str, default="cot")
    parser.add_argument("--do_sample", type=str2bool, default=True)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--output", type=str, default="tt/math443.json", help="Path to save the results JSON file.")
    parser.add_argument("--num_samples", type=int, default=29, help="Number of samples to evaluate. Use -1 to evaluate all samples.")
    parser.add_argument("--prune_path", type=str, default=None, help="Name of a cot_split results directory whose data.json provides the prune function. If not specified, will not prune.")
    parser.add_argument("--split", type=str2bool, default=True, help="Boolean `split` option forwarded to er_model (true/false).")
    parser.add_argument("--sol_precision", type=int, default=4)
    parser.add_argument("--windows", type=int, default=1)
    parser.add_argument("--alpha_split", type=str, default=f"{alpha},{alpha},{alpha}", help="Alpha values for split probability calculation")
    parser.add_argument("--minus_score", type=str, default=f"{minus},{minus},{minus}", help="Minus score values for processing")
    parser.add_argument("--problem_split_mean_score", type=str, default=f"{mean},{mean}")
    parser.add_argument("--problem_split_max_score", type=str, default="0,0")
    parser.add_argument("--computation_split_mean_score", type=str, default=f"{mean},{mean}")
    parser.add_argument("--computation_split_max_score", type=str, default="0,0")
    parser.add_argument("--verification_split_mean_score", type=str, default=f"{mean},{mean}")
    parser.add_argument("--verification_split_max_score", type=str, default="0,0")
    parser.add_argument("--descent_prompt", type=str, default="nothink", help="Path to custom descent_prompt json file.")
    parser.add_argument("--xverify", type=str2bool, default=False)
    return parser.parse_args()

args = parse_args()

# Optional pruning statistics: the "column_averages" entry of an earlier
# cot_split run's data.json, forwarded to er_model as prune_func.
if args.prune_path is not None:
    if not project_root:
        raise SystemExit("PROJECT_ROOT must be set when --prune_path is given.")
    prune_path = f"{project_root}/src/code/cot_split/results/{args.prune_path}/data.json"
    with open(prune_path, "r", encoding="utf-8") as f:
        prune_func = json.load(f)["column_averages"]
else:
    prune_func = None
    print("prune_func is None")

# Set the CUDA device at the beginning of the script. Device strings without an
# explicit index (e.g. "cuda" or "cpu") are left to torch's defaults instead of
# crashing while parsing a "cuda:<idx>" suffix.
_device = torch.device(args.device)
gpu_idx = _device.index
if _device.type == "cuda" and gpu_idx is not None:
    torch.cuda.set_device(gpu_idx)


def _parse_list(text, cast):
    """Split a comma-separated CLI value and convert every item with *cast*."""
    return [cast(item) for item in text.split(",")]


args_dict = {
    "model": args.model,
    "dataset": args.dataset,
    "reward_model": args.reward_model,
    "prompt_type": args.prompt_type,
    "scheduler": args.scheduler,
    "device": args.device,
    "max_steps": args.max_steps,
    "prefill_bit": args.prefill_bit,
    "past_key_values": args.past_key_values,
    "naive_bit": _parse_list(args.naive_bit, int),
    "high_bit_steps": args.high_bit_steps,
    "part": args.part,
    "do_sample": args.do_sample,
    "temperature": args.temperature,
    "prune_func": prune_func,
    "split": args.split,
    "sol_precision": args.sol_precision,
    "windows": args.windows,
    "alpha_split": _parse_list(args.alpha_split, float),
    "minus_score": _parse_list(args.minus_score, float),
    "problem_split_mean_score": _parse_list(args.problem_split_mean_score, float),
    "problem_split_max_score": _parse_list(args.problem_split_max_score, float),
    "computation_split_mean_score": _parse_list(args.computation_split_mean_score, float),
    "computation_split_max_score": _parse_list(args.computation_split_max_score, float),
    "verification_split_mean_score": _parse_list(args.verification_split_mean_score, float),
    "verification_split_max_score": _parse_list(args.verification_split_max_score, float),
    "descent_prompt": args.descent_prompt,
    "xverify": args.xverify,
}

# Wall-clock start of the run (covers model construction as well as evaluation).
total_start_time = time.time()

model = er_model(**args_dict)
# -1 on the command line is mapped to None for dataset.get_prompt
# (presumably "all samples" — confirm against the dataset implementation).
num_samples = None if args.num_samples == -1 else args.num_samples

class TimedRewardModel:
    """Callable wrapper that measures the time spent in a reward function.

    Every call is forwarded unchanged to ``original_reward_model``; its
    duration (from ``time.perf_counter``, a monotonic clock) is accumulated
    in ``total_time``, ``call_count`` and ``times``.
    """

    def __init__(self, original_reward_model):
        self.original_reward_model = original_reward_model
        self.total_time = 0.0  # seconds summed over all recorded calls
        self.call_count = 0
        self.times = []  # per-call durations in seconds, in call order

    def __call__(self, system_prompt, user, answer):
        """Invoke the wrapped reward function, record its duration, return its result."""
        start_time = time.perf_counter()
        result = self.original_reward_model(system_prompt, user, answer)
        call_time = time.perf_counter() - start_time

        self.total_time += call_time
        self.call_count += 1
        self.times.append(call_time)

        return result

    def reset(self):
        """Clear all accumulated timing statistics (e.g. between questions)."""
        self.total_time = 0.0
        self.call_count = 0
        self.times = []

    def get_stats(self):
        """Return total/average call time, call count and the per-call durations.

        ``avg_time`` is 0 when no call has been recorded yet.
        """
        return {
            "total_time": self.total_time,
            "call_count": self.call_count,
            "avg_time": self.total_time / self.call_count if self.call_count > 0 else 0,
            "times": self.times,
        }

# Wrap the generator's reward function (when one is attached) so that every
# reward call is timed by TimedRewardModel.
if getattr(model.model, "reward_func", None) is not None:
    model.model.reward_func = TimedRewardModel(model.model.reward_func)

def _mean(values, default=0):
    """Return the arithmetic mean of *values*, or *default* when it is empty."""
    return sum(values) / len(values) if values else default


def _split_by_correctness(values, correct_list):
    """Partition *values* into ``(correct, false)`` lists using parallel 0/1 flags.

    Items are paired positionally and pairing stops at the shorter sequence,
    so values without a matching correctness flag (or vice versa) are ignored.
    """
    correct_values, false_values = [], []
    for value, is_correct in zip(values, correct_list):
        (correct_values if is_correct == 1 else false_values).append(value)
    return correct_values, false_values


def _reward_timer(inner_model):
    """Return ``inner_model.reward_func`` if it is a TimedRewardModel, else None."""
    reward_func = getattr(inner_model, "reward_func", None)
    if reward_func is not None and isinstance(reward_func, TimedRewardModel):
        return reward_func
    return None


def _reset_reward_timer(inner_model):
    """Zero the counters of the TimedRewardModel attached to *inner_model*, if any."""
    timer = _reward_timer(inner_model)
    if timer is not None:
        timer.total_time = 0.0
        timer.call_count = 0
        timer.times = []


def _reward_timer_stats(inner_model):
    """Return ``get_stats()`` of the TimedRewardModel on *inner_model*, or None."""
    timer = _reward_timer(inner_model)
    return timer.get_stats() if timer is not None else None


def evaluate_with_timing(model, num_samples):
    """Generate answers for the dataset prompts and collect accuracy, timing and token stats.

    For every prompt from ``model.dataset.get_prompt(index=num_samples)`` this
    runs ``model.model.generate``, splits the decoded output into a thinking
    chain and a solution at ``model.think_token`` / ``model.think_end_token``,
    and records generation and reward-model time per question. After the loop
    the generator's probability traces are read, the generator is released
    (when CUDA is available), and all answers are scored in one batch by the
    dataset.

    Args:
        model: the ``er_model`` wrapper (provides ``model``, ``tokenizer``,
            ``dataset``, the generation settings and ``xverify_path``).
        num_samples: forwarded to ``dataset.get_prompt`` as ``index``.

    Returns:
        ``(results, content)`` — a dict of metrics and the list of decoded raw
        outputs, one per prompt.
    """
    content = []
    cot_precision = defaultdict(int)

    total_generation_times = []
    reward_model_times = []

    prompt = model.dataset.get_prompt(index=num_samples)
    answer_token_len = []
    thinking_chain_token_len = []
    answer = []
    thinking_chain = []  # collected for inspection; not included in results
    no_thinking_chain_count = 0

    real_time_token_counts = []

    for item_idx, item in enumerate(prompt):
        print(f"Processing question {item_idx + 1}/{len(prompt)}")
        question_start_time = time.perf_counter()

        # Reset the reward counters so the reward stats below are per question.
        _reset_reward_timer(model.model)

        model.model.reset(item, "")
        messages = [{"role": "user", "content": item}]
        text = model.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = model.tokenizer([text], return_tensors="pt").to(model.device)
        input_ids = model_inputs.input_ids
        prompt_len = len(input_ids[0])

        generation_start_time = time.perf_counter()
        outputs = model.model.generate(
            input_ids=input_ids,
            max_new_tokens=model.max_steps,
            prefill_bit=model.prefill_bit,
            past_key_values=model.past_key_values,
            **model.kw_dict,
        )
        generation_time = time.perf_counter() - generation_start_time
        total_generation_times.append(generation_time)

        full_ids = outputs.input_ids[0]
        real_time_token_counts.append(len(full_ids) - prompt_len)

        for bit, count in outputs.cot_precision.items():
            cot_precision[bit] += count

        # NOTE(review): the slice starts two tokens before the end of the prompt,
        # presumably so an opening think token emitted by the chat template is
        # part of the decoded output — confirm.
        output_ids = full_ids[prompt_len - 2:]
        content.append(model.tokenizer.decode(output_ids))
        output_ids_list = output_ids.tolist()

        try:
            start_idx = output_ids_list.index(model.think_token)
            end_idx = output_ids_list.index(model.think_end_token)
        except ValueError:
            # At least one of the think markers is missing from the output.
            no_thinking_chain_count += 1
            thinking_chain.append("Error Thinking Chain")
            answer.append("Error Answer")
        else:
            thinking_chain_ids = output_ids[start_idx + 1:end_idx]
            solution_ids = output_ids[end_idx + 1:]

            answer_token_len.append(len(solution_ids))
            thinking_chain_token_len.append(len(thinking_chain_ids))

            thinking_chain.append(model.tokenizer.decode(thinking_chain_ids, skip_special_tokens=True))
            answer.append(model.tokenizer.decode(solution_ids, skip_special_tokens=True))

        question_total_time = time.perf_counter() - question_start_time

        reward_stats = _reward_timer_stats(model.model)
        reward_time = reward_stats["total_time"] if reward_stats is not None else 0.0
        reward_model_times.append(reward_time)

        # Answers are scored in one batch after the loop, so nothing is spent here.
        evaluation_time = 0.0
        other_time = question_total_time - reward_time - evaluation_time

        print(f"Question {item_idx + 1} - Total: {question_total_time:.3f}s, Generation: {generation_time:.3f}s, Reward: {reward_time:.3f}s, Evaluation: {evaluation_time:.3f}s, Other: {other_time:.3f}s")
        if reward_stats is not None:
            print(f"  Reward calls: {reward_stats['call_count']}, Avg reward time: {reward_stats['avg_time']:.3f}s")

    avg_total_time = _mean(total_generation_times)
    avg_reward_time = _mean(reward_model_times)
    avg_thinking_chain_len = _mean(thinking_chain_token_len)
    avg_solution_len = _mean(answer_token_len)

    weighted_sum = sum(bit * count for bit, count in cot_precision.items())
    # NOTE(review): despite its name this total also counts bit-0 tokens, which
    # add nothing to weighted_sum — confirm that is the intended denominator.
    non_zero_tokens = sum(cot_precision.values())
    avg_precision = weighted_sum / non_zero_tokens if non_zero_tokens > 0 else 0

    precision_stats = {
        "total_tokens": non_zero_tokens,
        "avg_precision": avg_precision,
        "distribution": {
            bit: {"probability": count / non_zero_tokens}
            for bit, count in sorted(cot_precision.items())
            if bit != 0 and non_zero_tokens > 0
        },
    }

    # Read the optional per-question probability traces off the generator.
    # This must not depend on CUDA: the scoring below always uses them.
    all_prob_means = getattr(model.model, "list_prob", [])
    all_split_probs = getattr(model.model, "list_split_prob", [])
    all_split_prob_15 = getattr(model.model, "list_split_prob_15", [])

    if torch.cuda.is_available():
        import gc

        # Drop the heavy objects and return their GPU memory before scoring;
        # collect garbage first so empty_cache can release the freed blocks.
        del model.model
        del model.reward_model
        del model.tokenizer
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    evaluation_start_time = time.perf_counter()
    correct_list = []  # per-answer 0/1 flags; only produced on the xverify path
    if model.xverify_path is not None:
        accuracy, correct_list = model.dataset.eval_math_is_correct(answer)
    else:
        accuracy = model.dataset.result_eval(answer)
    total_evaluation_time = time.perf_counter() - evaluation_start_time

    # Without per-answer flags (non-xverify path) these splits stay empty.
    list_prob_correct, list_prob_false = _split_by_correctness(all_prob_means, correct_list)
    list_split_correct, list_split_false = _split_by_correctness(all_split_probs, correct_list)

    num_positions = 15
    split_prob_15_correct = [[] for _ in range(num_positions)]
    split_prob_15_false = [[] for _ in range(num_positions)]
    for split_prob_15, is_correct in zip(all_split_prob_15, correct_list):
        buckets = split_prob_15_correct if is_correct == 1 else split_prob_15_false
        for pos, prob_value in zip(range(num_positions), split_prob_15):
            buckets[pos].append(prob_value)
    avg_split_prob_15_correct = [_mean(bucket, 0.0) for bucket in split_prob_15_correct]
    avg_split_prob_15_false = [_mean(bucket, 0.0) for bucket in split_prob_15_false]

    # Spread the batch scoring time evenly over the questions.
    num_questions = len(prompt)
    per_question_evaluation = total_evaluation_time / num_questions if num_questions else 0
    evaluation_times = [per_question_evaluation] * num_questions
    # NOTE(review): "other" time here is derived from generation time, whereas
    # the per-question log line derives it from the whole question time.
    other_times = [
        generation - reward - evaluation
        for generation, reward, evaluation in zip(total_generation_times, reward_model_times, evaluation_times)
    ]
    avg_evaluation_time = _mean(evaluation_times)
    avg_other_time = _mean(other_times)

    results = {
        "accuracy": accuracy,
        "no_thinking_chain_count": no_thinking_chain_count,
        "cot_precision": avg_precision,
        "avg_thinking_chain_len": avg_thinking_chain_len,
        "avg_solution_len": avg_solution_len,
        "precision_stats": precision_stats,
        "text_type_stats": {},  # never populated by this function
        "avg_prob_correct": _mean(list_prob_correct),
        "avg_prob_false": _mean(list_prob_false),
        "avg_split_prob_correct": _mean(list_split_correct),
        "avg_split_prob_false": _mean(list_split_false),
        "avg_split_prob_15_correct": avg_split_prob_15_correct,
        "avg_split_prob_15_false": avg_split_prob_15_false,
        "timing_stats": {
            "avg_total_time_per_question": avg_total_time,
            "avg_reward_model_time_per_question": avg_reward_time,
            "avg_evaluation_time_per_question": avg_evaluation_time,
            "avg_other_time_per_question": avg_other_time,
            "total_generation_times": total_generation_times,
            "reward_model_times": reward_model_times,
            "evaluation_times": evaluation_times,
            "other_times": other_times,
            "total_questions_processed": num_questions,
            "total_evaluation_time": total_evaluation_time,
        },
        "token_stats": {
            "real_time_token_counts": real_time_token_counts,
            "avg_tokens_per_question": _mean(real_time_token_counts),
            "total_tokens_generated": sum(real_time_token_counts),
            "min_tokens_per_question": min(real_time_token_counts) if real_time_token_counts else 0,
            "max_tokens_per_question": max(real_time_token_counts) if real_time_token_counts else 0,
            "avg_thinking_chain_tokens": avg_thinking_chain_len,
            "avg_solution_tokens": avg_solution_len,
            "thinking_chain_token_lengths": thinking_chain_token_len,
            "solution_token_lengths": answer_token_len,
        },
    }

    return results, content

results, content = evaluate_with_timing(model, num_samples)

# Wall-clock duration of the whole run, measured from before model construction.
total_end_time = time.time()
total_runtime = total_end_time - total_start_time

results["total_runtime"] = total_runtime

output_data = {
    "config": args_dict,
    "results": results,
}
output_path = args.output

# Create the parent directory if the path has one. A bare filename has an empty
# dirname, for which os.makedirs would raise FileNotFoundError.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

with open(output_path, "w", encoding="utf-8") as f:
    json.dump(output_data, f, indent=4, ensure_ascii=False)
