import os
import sys
import argparse
import json
import torch
import numpy as np
from collections import defaultdict
import gc
from tqdm import tqdm

# Make the project importable: when the PROJECT_ROOT environment variable is
# set and not yet on sys.path, append it so the `env` import below can be
# resolved from that directory. This must run before the import.
project_root = os.environ.get("PROJECT_ROOT")
if project_root and project_root not in sys.path:
    sys.path.append(project_root)
from env import er_model

# Value of the DATA_ROOT environment variable (None when unset).
data_root = os.environ.get("DATA_ROOT")

def _str2bool(value):
    """Convert a command-line string such as ``"True"`` or ``"0"`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so a flag could never be
    switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy, matching what ``bool("")`` used to return.
    if normalized in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Boolean options (``--past_key_values``, ``--do_sample``, ``--split``,
    ``--xverify``) take an explicit value, e.g. ``--do_sample False``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="qwen38")
    parser.add_argument("--dataset", type=str, default="math500")
    parser.add_argument("--reward_model", type=str, default="prm")
    parser.add_argument("--scheduler", type=str, default="naive")
    parser.add_argument("--prompt_type", type=str, default="better")
    parser.add_argument("--device", type=str, default="cuda:1")
    parser.add_argument("--max_steps", type=int, default=32768)
    parser.add_argument("--prefill_bit", type=int, default=4)
    parser.add_argument("--past_key_values", type=_str2bool, default=None)
    # Comma-separated bit widths; the caller splits and converts them.
    parser.add_argument("--naive_bit", type=str, default="4,3")
    parser.add_argument("--high_precision_steps", type=int, default=3300,
                        help="Number of tokens to use high precision before switching to low precision")
    parser.add_argument("--part", type=str, default="cot")
    parser.add_argument("--do_sample", type=_str2bool, default=True)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--output", type=str, default=None, help="Path to save the results JSON file. If not specified, will use auto-generated filename.")
    parser.add_argument("--num_samples", type=int, default=199, help="Number of samples to evaluate. Use -1 for all samples.")
    parser.add_argument("--split", type=_str2bool, default=False, help="Whether to enable split mode")
    parser.add_argument("--xverify", type=_str2bool, default=False)
    return parser.parse_args()

# Parse the command line once; the rest of the script (including
# evaluate_naive_scheduler) reads the module-level `args`.
args = parse_args()

# Set the CUDA device at the beginning of the script.
# Device strings without an explicit index (e.g. "cuda" or "cpu") carry
# nothing to select, so device selection is skipped for them rather than
# crashing with an IndexError on the missing ":<index>" part.
device_parts = args.device.split(":")
gpu_idx = int(device_parts[1]) if len(device_parts) > 1 else None
if gpu_idx is not None:
    torch.cuda.set_device(gpu_idx)

# Keyword arguments forwarded to er_model; also stored in the output JSON
# under "config".
args_dict = {
    "model": args.model,
    "dataset": args.dataset,
    "reward_model": args.reward_model,
    "prompt_type": args.prompt_type,
    "scheduler": args.scheduler,
    "device": args.device,
    "max_steps": args.max_steps,
    "prefill_bit": args.prefill_bit,
    "past_key_values": args.past_key_values,
    # "4,3" -> [4, 3]
    "naive_bit": [int(bit) for bit in args.naive_bit.split(",")],
    "high_bit_steps": args.high_precision_steps,
    "part": args.part,
    "do_sample": args.do_sample,
    "temperature": args.temperature,
    "split": args.split,
    "xverify": args.xverify,
}

print("Initializing model with naive scheduler...")
print(f"Configuration: {args_dict}")

model = er_model(**args_dict)
# -1 on the command line means "all samples", passed on as None.
num_samples = None if args.num_samples == -1 else args.num_samples

def _mean(values):
    """Return the arithmetic mean of *values*, or 0 when *values* is empty."""
    return sum(values) / len(values) if values else 0


def _summarize_precision(cot_precision):
    """Summarise a ``{bit_width: token_count}`` mapping.

    Returns:
        tuple: ``(avg_precision, precision_stats)`` where ``precision_stats``
        holds ``total_tokens``, ``avg_precision`` and a per-bit
        ``distribution`` (entries for bit width 0 are omitted).
    """
    # NOTE(review): the total includes any tokens recorded under bit width 0,
    # although they are left out of the distribution -- confirm this is the
    # intended denominator.
    total_tokens = sum(cot_precision.values())
    weighted_sum = sum(bit * count for bit, count in cot_precision.items())
    avg_precision = weighted_sum / total_tokens if total_tokens > 0 else 0

    distribution = {
        bit: {
            # Guard against every recorded count being 0, which would
            # otherwise divide by zero.
            'probability': count / total_tokens if total_tokens > 0 else 0,
            'count': count,
        }
        for bit, count in sorted(cot_precision.items())
        if bit != 0
    }
    precision_stats = {
        'total_tokens': total_tokens,
        'avg_precision': avg_precision,
        'distribution': distribution,
    }
    return avg_precision, precision_stats


def evaluate_naive_scheduler(model, num_samples):
    """Evaluate using naive scheduler.

    Generates a completion for every prompt returned by
    ``model.dataset.get_prompt(index=num_samples)``, splits each completion
    into a thinking chain and an answer around ``model.think_token`` /
    ``model.think_end_token``, accumulates per-bit-width token counts from
    ``outputs.cot_precision`` and scores the answers.

    Args:
        model: wrapper exposing ``dataset``, ``model``, ``tokenizer``,
            ``device``, ``max_steps``, ``prefill_bit``, ``past_key_values``,
            ``kw_dict``, ``think_token``, ``think_end_token`` and
            ``xverify_path``.
        num_samples: forwarded to ``get_prompt`` as ``index``.

    Returns:
        tuple: ``(results, content)`` -- ``results`` is the metrics dict and
        ``content`` the list of decoded raw generations, one per prompt.

    Side effects:
        When CUDA is available, ``model.model``, ``model.reward_model`` and
        ``model.tokenizer`` are deleted before scoring to free GPU memory.
        Reads the module-level ``args`` for the reported scheduler settings.
    """
    content = []
    cot_precision = defaultdict(int)  # bit width -> number of tokens
    list_prob_correct = []
    list_prob_false = []
    real_time_token_counts = []
    answer_token_len = []
    thinking_chain_token_len = []
    answer = []
    thinking_chain = []
    no_thinking_chain_count = 0

    prompt = model.dataset.get_prompt(index=num_samples)

    for item in tqdm(prompt, desc="Processing prompts with naive scheduler", unit="prompt"):
        model.model.reset(item, "")
        messages = [{"role": "user", "content": item}]
        text = model.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = model.tokenizer([text], return_tensors="pt").to(model.device)
        input_ids = model_inputs.input_ids
        prompt_len = len(input_ids[0])

        outputs = model.model.generate(
            input_ids=input_ids,
            max_new_tokens=model.max_steps,
            prefill_bit=model.prefill_bit,
            past_key_values=model.past_key_values,
            **model.kw_dict,
        )

        real_time_token_counts.append(len(outputs.input_ids[0]) - prompt_len)

        for key, value in outputs.cot_precision.items():
            cot_precision[key] += value

        # Keep the last two prompt tokens in the slice.
        # NOTE(review): presumably the chat template already ends with the
        # opening think token -- confirm.
        output_ids = outputs.input_ids[0][prompt_len - 2:]

        content.append(model.tokenizer.decode(output_ids))
        output_ids_list = output_ids.tolist()

        try:
            # list.index raises ValueError when a marker token is missing.
            start_idx = output_ids_list.index(model.think_token)
            end_idx = output_ids_list.index(model.think_end_token)

            thinking_chain_ids = output_ids[start_idx + 1:end_idx]
            solution_ids = output_ids[end_idx + 1:]

            answer_token_len.append(len(solution_ids))
            thinking_chain_token_len.append(len(thinking_chain_ids))

            thinking_chain_text = model.tokenizer.decode(thinking_chain_ids, skip_special_tokens=True)
            solution_text = model.tokenizer.decode(solution_ids, skip_special_tokens=True)

            thinking_chain.append(thinking_chain_text)
            answer.append(solution_text)
        except ValueError:
            # Keep `answer` aligned with the prompts using placeholders.
            no_thinking_chain_count += 1
            thinking_chain.append("Error Thinking Chain")
            answer.append("Error Answer")

    avg_thinking_chain_len = _mean(thinking_chain_token_len)
    avg_solution_len = _mean(answer_token_len)
    avg_precision, precision_stats = _summarize_precision(cot_precision)

    # Read the per-sample probabilities before the model may be deleted, and
    # do so unconditionally so the name is bound on CPU-only hosts as well.
    all_prob_means = getattr(model.model, 'list_prob', [])

    if torch.cuda.is_available():
        del model.model
        del model.reward_model
        del model.tokenizer
        # Collect garbage first so memory held by unreachable objects can be
        # returned to the driver by empty_cache().
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    if model.xverify_path is not None:
        accuracy, correct_list = model.dataset.eval_math_is_correct(answer)
        # NOTE(review): assumes all_prob_means[i] belongs to the i-th prompt
        # -- confirm against the generator implementation.
        for i, is_correct in enumerate(correct_list):
            if i < len(all_prob_means):
                prob_mean = all_prob_means[i]
                if is_correct == 1:
                    list_prob_correct.append(prob_mean)
                else:
                    list_prob_false.append(prob_mean)
    else:
        accuracy = model.dataset.result_eval(answer)

    precisions = model.kw_dict["precisions"]
    high_precision_tokens = cot_precision.get(max(precisions), 0)
    low_precision_tokens = cot_precision.get(min(precisions), 0)
    total_tokens = high_precision_tokens + low_precision_tokens

    naive_stats = {
        "high_precision_tokens": high_precision_tokens,
        "low_precision_tokens": low_precision_tokens,
        "high_precision_ratio": high_precision_tokens / total_tokens if total_tokens > 0 else 0,
        "low_precision_ratio": low_precision_tokens / total_tokens if total_tokens > 0 else 0,
        "high_bit_steps": args.high_precision_steps,
        "naive_bit_config": args.naive_bit,
    }

    results = {
        "accuracy": accuracy,
        "no_thinking_chain_count": no_thinking_chain_count,
        "cot_precision": avg_precision,
        "avg_thinking_chain_len": avg_thinking_chain_len,
        "avg_solution_len": avg_solution_len,
        "precision_stats": precision_stats,
        "avg_prob_correct": _mean(list_prob_correct),
        "avg_prob_false": _mean(list_prob_false),
        "naive_scheduler_stats": naive_stats,
        "real_time_token_counts": real_time_token_counts,
        "avg_tokens_per_question": _mean(real_time_token_counts),
        "total_tokens_generated": sum(real_time_token_counts),
    }

    return results, content

# `num_samples` is None only when "all samples" was requested; test identity
# so that an explicit 0 is not reported as "all".
print(f"Starting evaluation with {num_samples if num_samples is not None else 'all'} samples...")
results, content = evaluate_naive_scheduler(model, num_samples)

# Derive a default output path from the run configuration unless one was given.
if args.output is None:
    num_samples_str = "all" if args.num_samples == -1 else str(args.num_samples)
    output_filename = f"{args.dataset}_Steps{args.high_precision_steps}_num{num_samples_str}_{args.model}.json"
    output_path = f"results/{output_filename}"
else:
    output_path = args.output

output_data = {
    "config": args_dict,
    "results": results,
}

# Create the output directory if it doesn't exist. A bare filename has an
# empty dirname, and os.makedirs("") raises FileNotFoundError, so only create
# a directory when the path actually contains one.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(output_data, f, indent=4, ensure_ascii=False)

print(f"\nResults saved to: {output_path}")
print("Naive Scheduler Configuration:")
print(f"  - High precision bits: {max(args_dict['naive_bit'])}")
print(f"  - Low precision bits: {min(args_dict['naive_bit'])}")
print(f"  - High precision steps: {args.high_precision_steps}")
print(f"  - High precision ratio: {results['naive_scheduler_stats']['high_precision_ratio']:.4f}")
print(f"  - Low precision ratio: {results['naive_scheduler_stats']['low_precision_ratio']:.4f}")
print(f"  - Average precision: {results['cot_precision']:.4f}")
print(f"  - Accuracy: {results['accuracy']:.4f}")
