import os
import sys
import argparse
from tqdm import tqdm
import torch

# When PROJECT_ROOT is set, make it importable (e.g. for the project-local
# `env` module) without adding a duplicate sys.path entry.
project_root = os.getenv("PROJECT_ROOT")
if project_root and project_root not in sys.path:
    sys.path.append(project_root)
from env import Agent, Math500Dataset,AMC23Dataset,Gsm8kDataset,AimeDataset,GpqaDataset

# Directory holding model checkpoints and datasets; None when $DATA_ROOT is unset.
data_root = os.getenv("DATA_ROOT")
# Default for CustomAgent's insert_token_a (generation step where </think> is inserted).
a = 4000
# Default for CustomAgent's wait_threshold_b (threshold for the " Wait" intervention).
b = 500
# CUDA device index used to build the default "cuda:<index>" device strings.
cuda = 6
class CustomAgent(Agent):
    """``Agent`` with optional thinking-budget interventions during decoding.

    Two optional knobs change how a response is generated:

    * ``insert_token_a`` -- if the model has not emitted ``</think>`` after
      ``insert_token_a`` generated tokens, a ``</think>`` token is appended at
      that point to end the thinking phase.
    * ``wait_threshold_b`` -- once ``</think>`` has appeared, the 1st through
      ``(wait_threshold_b - 1)``-th following tokens are replaced by the
      ``" Wait"`` token (only when ``" Wait"`` is a literal vocabulary entry).

    When both knobs are ``None`` the model's own ``generate`` is used instead.
    ``self.model``, ``self.tokenizer`` and ``self.is_anyprecision`` are used
    below; presumably they are set by ``Agent.__init__`` (not visible here).
    """

    def __init__(self, model_path, is_anyprecision=False, device=f"cuda:{cuda}", insert_token_a=None, wait_threshold_b=None):
        """Load the model through ``Agent`` and resolve the marker token ids.

        Args:
            model_path: Checkpoint path forwarded to ``Agent``.
            is_anyprecision: Forwarded to ``Agent``; also makes ``__call__``
                pass its ``precision`` argument to ``generate``.
            device: Device string forwarded to ``Agent``.
            insert_token_a: Generation step at which ``</think>`` is forced,
                or ``None`` to disable.
            wait_threshold_b: Answer-phase threshold for the ``" Wait"``
                replacement, or ``None`` to disable.
        """
        import warnings

        super().__init__(model_path, is_anyprecision, device)
        self.insert_token_a = insert_token_a
        self.wait_threshold_b = wait_threshold_b

        vocab = self.tokenizer.get_vocab()
        # ``[-1]`` keeps the marker's own id even if the tokenizer prepends extra tokens (e.g. BOS).
        self.think_end_token_id = self.tokenizer.encode("</think>")[-1] if "</think>" in vocab else None
        # NOTE(review): byte-level BPE vocabularies usually store a leading space as a marker
        # character (e.g. "ĠWait"), so this literal lookup may fail and leave the Wait
        # intervention inactive -- confirm for the tokenizer in use.
        self.wait_token_id = self.tokenizer.encode(" Wait")[-1] if " Wait" in vocab else None

        # Surface configurations that would otherwise silently do nothing.
        if insert_token_a is not None and self.think_end_token_id is None:
            warnings.warn("'</think>' is not in the tokenizer vocabulary; insert_token_a has no effect.", stacklevel=2)
        if wait_threshold_b is not None and self.wait_token_id is None:
            warnings.warn("' Wait' is not a vocabulary entry; wait_threshold_b will not change any tokens.", stacklevel=2)

    def __call__(self, prompt, do_sample=True, precision=8):
        """Generate a response to a single user ``prompt``.

        Args:
            prompt: User message wrapped in a one-turn chat template.
            do_sample: Sample with temperature 0.6 if true, decode greedily otherwise.
            precision: Passed to ``generate`` only for any-precision models
                (ignored by the custom decoding loop).

        Returns:
            ``(response_text, num_generated_tokens)``.
        """
        messages = [
            {"role": "user", "content": prompt},
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

        # Only ``.sequences`` is read, so per-step scores are not requested; with
        # ``output_scores=True`` every step's scores would be retained.
        kw_args = {"max_new_tokens": 32768, "return_dict_in_generate": True}
        if do_sample:
            kw_args["do_sample"] = True
            kw_args["temperature"] = 0.6
        else:
            kw_args["do_sample"] = False
        if self.is_anyprecision:
            kw_args["precision"] = precision

        if self.insert_token_a is not None or self.wait_threshold_b is not None:
            return self._custom_generate(model_inputs, **kw_args)

        generated_output = self.model.generate(
            **model_inputs,
            **kw_args,
        )
        prompt_len = model_inputs.input_ids.shape[1]
        new_ids = generated_output.sequences[:, prompt_len:]
        response = self.tokenizer.batch_decode(new_ids, skip_special_tokens=True)[0]
        return response, len(new_ids[0])

    def _stop_token_ids(self):
        """Return the set of token ids that end the custom decoding loop.

        Combines ``tokenizer.eos_token_id`` with any ``eos_token_id`` found on
        ``self.model.generation_config`` (if that attribute exists), mirroring
        the stop tokens ``model.generate`` reads from its generation config.
        """
        candidates = [self.tokenizer.eos_token_id]
        config = getattr(self.model, "generation_config", None)
        config_eos = getattr(config, "eos_token_id", None)
        if isinstance(config_eos, (list, tuple)):
            candidates.extend(config_eos)
        else:
            candidates.append(config_eos)
        return {int(token) for token in candidates if token is not None}

    def _custom_generate(self, model_inputs, **kw_args):
        """Token-by-token decoding with the ``</think>`` / ``" Wait"`` interventions.

        The KV cache returned by the model is reused so each step only feeds the
        tokens appended since the previous forward pass; if the model returns no
        cache, the whole sequence is re-fed instead.

        NOTE(review): only temperature is applied when sampling, whereas
        ``model.generate`` may also apply generation-config defaults (e.g.
        top-k/top-p), so the two paths can sample differently.

        Args:
            model_inputs: Tokenizer output with ``input_ids`` and
                ``attention_mask``; batch size 1 is assumed (``.item()`` below).
            **kw_args: ``max_new_tokens``, ``do_sample`` and ``temperature``
                are read; other keys are ignored.

        Returns:
            ``(response_text, num_generated_tokens)``; the count includes an
            inserted ``</think>`` token, if any.
        """
        input_ids = model_inputs.input_ids
        attention_mask = model_inputs.attention_mask
        prompt_len = input_ids.shape[1]

        max_new_tokens = kw_args.get("max_new_tokens", 32768)
        do_sample = kw_args.get("do_sample", False)
        temperature = kw_args.get("temperature", 0.6)
        stop_ids = self._stop_token_ids()
        mask_one = torch.ones((1, 1), dtype=attention_mask.dtype, device=attention_mask.device)

        generated_ids = input_ids.clone()
        past_key_values = None
        cached_len = 0  # number of leading positions of generated_ids already in the KV cache
        think_end_found = False
        tokens_after_think = 0

        for step in range(max_new_tokens):
            # Force the end of thinking after ``insert_token_a`` steps, unless the
            # model already closed the thinking phase on its own.
            if (
                self.insert_token_a is not None
                and step == self.insert_token_a
                and not think_end_found
                and self.think_end_token_id is not None
            ):
                think_end = torch.full(
                    (1, 1), self.think_end_token_id, dtype=generated_ids.dtype, device=generated_ids.device
                )
                generated_ids = torch.cat([generated_ids, think_end], dim=1)
                attention_mask = torch.cat([attention_mask, mask_one], dim=1)
                think_end_found = True
                continue  # the inserted token is fed to the model on the next step

            if past_key_values is None:
                cached_len = 0  # no cache available: feed the whole sequence
            with torch.no_grad():
                outputs = self.model(
                    input_ids=generated_ids[:, cached_len:],
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
            past_key_values = getattr(outputs, "past_key_values", None)
            cached_len = generated_ids.shape[1]

            # Cast to float so softmax/multinomial do not run in reduced precision.
            next_token_logits = outputs.logits[:, -1, :].float()
            if do_sample:
                probs = torch.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            token_id = next_token.item()
            if self.think_end_token_id is not None and token_id == self.think_end_token_id and not think_end_found:
                think_end_found = True
                tokens_after_think = 0
            elif think_end_found:
                tokens_after_think += 1

            # NOTE(review): this replaces each of the first (b - 1) answer tokens with
            # " Wait", while main()'s --b help says "Insert Wait if tokens after </think>
            # is less than b" -- confirm the intended semantics.
            if (
                self.wait_threshold_b is not None
                and self.wait_token_id is not None
                and think_end_found
                and 0 < tokens_after_think < self.wait_threshold_b
            ):
                token_id = self.wait_token_id
                next_token = torch.full_like(next_token, token_id)

            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            attention_mask = torch.cat([attention_mask, mask_one], dim=1)

            if token_id in stop_ids:
                break

        new_ids = generated_ids[0, prompt_len:]
        response = self.tokenizer.decode(new_ids, skip_special_tokens=True)
        return response, len(new_ids)

def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string -- including
    ``"False"`` -- into ``True``, so boolean flags need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _optional_int(value):
    """Parse an int command-line value, where ``none`` (any case) means ``None``."""
    if str(value).strip().lower() == "none":
        return None
    return int(value)


def main():
    """Generate answers for one dataset with ``CustomAgent`` and write a summary.

    The summary is written to
    ``./qwen38/<dataset>_<model>_resultsnum9<_xverify|_baseline>[_a<A>][_b<B>].txt``.
    ``--a`` / ``--b`` default to the module-level ``a`` / ``b`` values.
    """
    # Dataset name -> (dataset class, sub-directory of DATA_ROOT).
    dataset_specs = {
        "math500": (Math500Dataset, "MATH-500"),
        "amc23": (AMC23Dataset, "AMC-23"),
        "gsm8k": (Gsm8kDataset, "gsm8k"),
        "aime": (AimeDataset, "aime-2021-2025"),
        "gpqa": (GpqaDataset, "gpqa_diamond_mc"),
    }

    parser = argparse.ArgumentParser(description='Run model evaluation with custom token insertion')
    parser.add_argument('--a', type=_optional_int, default=a,
                        help='Insert </think> at the a-th generated token ("none" disables; default: %(default)s)')
    parser.add_argument('--b', type=_optional_int, default=b,
                        help='Insert Wait if tokens after </think> is less than b ("none" disables; default: %(default)s)')
    parser.add_argument('--device', type=str, default=f'cuda:{cuda}', help='Device to use')
    parser.add_argument('--eval_device', type=str, default=f'cuda:{cuda}',
                        help='Device passed to the dataset / xVerify evaluator')
    parser.add_argument('--num_samples', type=int, default=15, help='Number of samples to evaluate')
    parser.add_argument('--xverify', type=_str2bool, nargs='?', const=True, default=False,
                        help='Whether to use xverify for evaluation')
    parser.add_argument('--dataset', choices=sorted(dataset_specs), default='math500',
                        help='Dataset to evaluate (default: %(default)s)')
    parser.add_argument('--model_path', type=str, default=f"{data_root}/Qwen3-8B",
                        help='Model checkpoint path (default: %(default)s)')

    args = parser.parse_args()
    # Distinct local names: assigning to ``a``/``b`` here would make them local
    # and break the ``default=a`` / ``default=b`` lookups above.
    insert_a = args.a
    wait_b = args.b

    model_path = args.model_path
    model_name = os.path.basename(model_path)

    model = CustomAgent(
        model_path=model_path,
        is_anyprecision=False,
        device=args.device,
        insert_token_a=insert_a,
        wait_threshold_b=wait_b,
    )

    xverify_path = f"{data_root}/xVerify-9B-C" if args.xverify else None

    # Build only the dataset being evaluated; each dataset receives xverify_path
    # and a device, so constructing all of them may load several evaluators.
    dataset_name = args.dataset
    dataset_cls, dataset_subdir = dataset_specs[dataset_name]
    dataset = dataset_cls(
        dataset_path=f"{data_root}/{dataset_subdir}",
        prompt_type="better",
        shuffle=True,
        xverify_path=xverify_path,
        device=args.eval_device,
    )

    prompts = dataset.get_prompt(index=args.num_samples)

    answers = []
    lengths = []
    for i, prompt in enumerate(tqdm(prompts, desc="Processing prompt with custom token insertion")):
        answer, length = model(prompt)
        print(f"Sample {i}: {answer}")
        answers.append(answer)
        lengths.append(length)

    # Get evaluation results based on the xverify option.
    if args.xverify:
        results, _correct_list = dataset.eval_xverify(answers)
    else:
        results = dataset.result_eval(answers)

    # Guard against an empty prompt list instead of dividing by zero.
    mean_length = sum(lengths) / len(lengths) if lengths else 0.0

    results_dir = "./qwen38"
    os.makedirs(results_dir, exist_ok=True)

    param_suffix = ""
    if insert_a is not None:
        param_suffix += f"_a{insert_a}"
    if wait_b is not None:
        param_suffix += f"_b{wait_b}"

    eval_suffix = "_xverify" if args.xverify else "_baseline"
    result_file = os.path.join(results_dir, f"{dataset_name}_{model_name}_resultsnum9{eval_suffix}{param_suffix}.txt")

    with open(result_file, "w", encoding='utf-8') as f:
        f.write("Evaluation Results:\n")
        f.write("=" * 50 + "\n")
        f.write(f"Dataset: {dataset_name}\n")
        f.write(f"Model: {model_name}\n")
        f.write(f"Parameter a (insert </think> at token): {insert_a}\n")
        f.write(f"Parameter b (wait threshold): {wait_b}\n")
        f.write(f"Use xverify: {args.xverify}\n")
        f.write(f"Accuracy: {results}\n")
        f.write(f"Mean Length: {mean_length}\n")
        f.write("\n" + "=" * 50 + "\n")

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
