import os
import re
import random
import argparse
import glob
import json
import torch
from typing import List, Dict, Any
from tqdm import tqdm
from datasets import load_dataset
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from human_eval.data import write_jsonl, read_problems, stream_jsonl


def seed_everything(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch with ``seed``. When
    CUDA is available the CUDA generators are seeded as well and cuDNN is
    switched to deterministic, non-benchmarking mode. Finally the
    ``PYTHONHASHSEED`` environment variable is set to the same value.

    Args:
        seed: Integer seed applied to all generators (default 42).
    """
    # CPU-side generators, seeded in a fixed order.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Seed the current device and then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Trade cuDNN autotuning for reproducible kernel selection.
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)


def extract_final_answer(answer):
    """Extract the last number mentioned in ``answer`` as a normalized string.

    If the text contains ``\\boxed{...}``, only the content of the first such
    box (up to its first closing brace) is searched. Dollar signs are removed,
    then the last match of one of these forms is taken:

    * a number with thousands separators, e.g. ``1,234.50``
    * a simple fraction ``a/b``
    * a plain integer or decimal, optionally negative

    Normalization:

    * thousands separators are dropped;
    * integers are returned in canonical form (``"007"`` -> ``"7"``) using
      exact integer arithmetic, so arbitrarily large values are not corrupted
      by a round-trip through ``float``;
    * decimals with no fractional part collapse to an integer (``"2.0"`` ->
      ``"2"``);
    * fractions are evaluated; an exactly divisible fraction yields an
      integer string, otherwise the float quotient.

    Args:
        answer: Free-form text that may contain a numeric answer.

    Returns:
        The normalized number as a string, the raw matched text if it cannot
        be evaluated (e.g. ``"1/0"``), or ``None`` when no number is found.
    """
    boxed_match = re.search(r'\\boxed\{([^}]+)\}', answer)
    if boxed_match:
        answer = boxed_match.group(1)

    clean_text = re.sub(r'[$]', '', answer)
    pattern = r'-?\d{1,3}(?:,\d{3})+(?:\.\d+)?|-?\d+/\d+|-?\d+(?:\.\d+)?'
    numbers = re.findall(pattern, clean_text)

    if not numbers:
        return None

    last_num = numbers[-1].replace(',', '')

    try:
        if '/' in last_num:
            # The pattern guarantees exactly "int/int", so work with exact
            # integers and only fall back to a float for a true fraction.
            numerator, denominator = (int(part) for part in last_num.split('/'))
            if numerator % denominator == 0:
                return str(numerator // denominator)
            return str(numerator / denominator)

        if '.' not in last_num:
            # Pure integer: avoid float, which loses precision above 2**53.
            return str(int(last_num))

        num = float(last_num)
        return str(int(num)) if num.is_integer() else str(num)
    except (ValueError, ZeroDivisionError, OverflowError):
        # Division by zero, values too large to convert, etc.: hand back the
        # matched text unchanged rather than failing the whole evaluation.
        return last_num

def create_chat_prompt(system_prompt: str, fewshot_examples: List[Dict], test_question: str, tokenizer, model_name: str) -> str:
    """Render a few-shot chat conversation through the tokenizer's chat template.

    The conversation is: one system message, then a user/assistant pair per
    few-shot example (problem statement plus its test cases, answered by the
    reference code), and finally a user message containing ``test_question``.
    The final user message is also printed to stdout.

    Args:
        system_prompt: Content of the leading system message.
        fewshot_examples: Dicts with ``'prompt'``, ``'test_list'`` and
            ``'code'`` keys.
        test_question: The problem the model should solve.
        tokenizer: Object exposing ``apply_chat_template``.
        model_name: When equal to ``"qwen3"``, ``enable_thinking=False`` is
            forwarded to the chat template.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns for the
        conversation with ``tokenize=False`` and ``add_generation_prompt=True``.
    """
    def _user_turn(body):
        # Shared instruction wrapper for both the shots and the real question.
        return f"Write a solution to the following problem and make sure that it passes the tests:\n```python\n{body}\n```\n"

    conversation = [{"role": "system", "content": system_prompt}]

    for shot in fewshot_examples:
        shot_body = f"{shot['prompt']}\n\nTest examples:" + "".join(
            f"\n{case}" for case in shot['test_list']
        )
        conversation += [
            {"role": "user", "content": _user_turn(shot_body)},
            {"role": "assistant", "content": shot['code']},
        ]

    formatted_prompt = _user_turn(test_question)
    conversation.append({"role": "user", "content": formatted_prompt})
    print("Final prompt:", formatted_prompt)

    template_kwargs = {"tokenize": False, "add_generation_prompt": True}
    if model_name == "qwen3":
        template_kwargs["enable_thinking"] = False
    return tokenizer.apply_chat_template(conversation, **template_kwargs)

def single_generate_with_hf(
    model,
    tokenizer,
    prompts: List[str],
    generation_config: GenerationConfig,
    device: torch.device,
    max_model_len: int = 4096
) -> List[str]:
    """Generate one completion per prompt, one prompt at a time.

    Each prompt is tokenized (truncated to ``max_model_len`` tokens), moved to
    ``device`` and passed to ``model.generate`` under ``torch.no_grad()``.
    Only the tokens after the prompt are decoded, with special tokens
    skipped. Any exception for a prompt is printed and an empty string is
    recorded for it, so the result always lines up with ``prompts``.

    NOTE(review): prompts longer than ``max_model_len`` are truncated by the
    tokenizer; which end is dropped depends on the tokenizer's configuration.
    Confirm the final question is not being cut off for long few-shot prompts.

    Args:
        model: Object exposing ``generate``.
        tokenizer: Callable tokenizer that also exposes ``decode`` and
            ``pad_token_id``.
        prompts: Fully formatted prompt strings.
        generation_config: Generation settings forwarded to ``generate``.
        device: Device the input tensors are moved to.
        max_model_len: Maximum number of prompt tokens kept.

    Returns:
        A list of generated strings, one per prompt, in input order.
    """
    results = []
    progress = tqdm(prompts, desc="Generating completions")

    for idx, prompt_text in enumerate(progress):
        try:
            encoded = tokenizer(
                prompt_text,
                return_tensors="pt",
                truncation=True,
                max_length=max_model_len,
            )
            model_inputs = {}
            for name, tensor in encoded.items():
                model_inputs[name] = tensor.to(device)

            with torch.no_grad():
                sequences = model.generate(
                    **model_inputs,
                    generation_config=generation_config,
                    pad_token_id=tokenizer.pad_token_id,
                )

                # Drop the echoed prompt so only newly generated tokens remain.
                prompt_len = model_inputs["input_ids"].shape[1]
                text = tokenizer.decode(sequences[0][prompt_len:], skip_special_tokens=True)
        except Exception as e:
            print(f"Error generating for prompt {idx}: {e}")
            text = ""

        results.append(text)

    return results

def postprocess_completion(completion):
    """Reduce a raw model completion to the code that should be evaluated.

    The following steps are applied in order, each printing a diagnostic line
    when it triggers:

    1. Remove carriage returns and surrounding whitespace.
    2. If a ```` ```python ```` fence is present, keep the text from the first
       fence up to the next closing ```` ``` ```` (all opening fences are
       removed). A missing closing fence is reported but tolerated.
    3. If a ``__main__`` guard is present, drop it and everything after it.
       Both ``"__main__"`` and ``'__main__'`` spellings are recognised.
    4. Drop everything from ``# Example usage`` onward.
    5. For each of the markers ``The solution is:`` and ``The answer is:``,
       keep only the text after the first occurrence (with all occurrences of
       the marker removed), cut at a following ``\\n\\nThe answer is:`` if one
       exists.

    The function is safe to apply to its own output.

    Args:
        completion: Raw generated text.

    Returns:
        The cleaned completion string.
    """
    completion = completion.replace("\r", "").strip()

    fence = '```python'
    if fence in completion:
        print("completion matches ```python")
        completion = completion[completion.index(fence):].strip()
        completion = completion.replace(fence, '')
        closing = completion.find('```')
        if closing != -1:
            completion = completion[:closing].strip()
        else:
            print("wrong completion")

    if '__name__ == "__main__"' in completion or "__name__ == '__main__'" in completion:
        print("completion matches __name__ == \"__main__\"")
        # Accept either quoting style and cut at whichever guard comes first.
        guards = ('if __name__ == "__main__":', "if __name__ == '__main__':")
        guard_positions = [
            pos for pos in (completion.find(guard) for guard in guards) if pos != -1
        ]
        if guard_positions:
            completion = completion[:min(guard_positions)].strip()
        else:
            print("wrong completion")

    if "# Example usage" in completion:
        print("completion matches # Example usage")
        completion = completion[:completion.index('# Example usage')].strip()

    for marker in ("The solution is:", "The answer is:"):
        if marker not in completion:
            continue
        print(f"completion matches {marker}")
        completion = completion[completion.index(marker):].strip()
        completion = completion.replace(marker, '')
        answer_start = completion.find('\n\nThe answer is:')
        if answer_start != -1:
            completion = completion[:answer_start].strip()
        else:
            completion = completion.strip()
            print("maybe wrong completion")

    return completion

def evaluate_model_hf(args, model, tokenizer, generation_config: GenerationConfig, device: torch.device,
                     seed: int = 42, log_file: str = "evaluation_log.txt", num_shots: int = 3):
    """Generate and save completions for a slice of the HumanEval problems.

    Steps:

    1. Derive a short model name from ``args.model_path`` (used to pick
       chat-template options).
    2. Read the problems, sort the task ids and take ``[args.start:args.end]``.
       Task ids and prompts are sliced together so each completion is stored
       under the task it was generated for.
    3. Optionally sample ``num_shots`` few-shot examples from the ``'prompt'``
       entry of the ``nlile/mbpp`` (``sanitized``) dataset. Loading failures
       are reported and evaluation continues zero-shot.
    4. Build chat prompts, generate completions one at a time, post-process
       them and write JSONL files to ``args.output_dir``.
    5. Merge the JSONL files and append a summary to ``log_file``.

    Args:
        args: Parsed CLI namespace; reads ``model_path``, ``start``, ``end``,
            ``output_dir`` and ``prefix``.
        model: Causal LM passed through to generation.
        tokenizer: Tokenizer used for chat templating and generation.
        generation_config: Generation settings.
        device: Device for the input tensors.
        seed: Accepted for interface compatibility; not used in this function.
        log_file: Path of the log file the summary is appended to.
        num_shots: Number of few-shot examples to include (0 disables them).
    """
    print("Starting model evaluation with HuggingFace Transformers...")
    if 'OLMoE-1B' in args.model_path:
        model_name = 'olmoe'
    elif 'Qwen3' in args.model_path:
        model_name = 'qwen3'
    elif 'DeepSeek' in args.model_path:
        model_name = 'deepseek'
    else:
        print("No config for model name.")
        model_name = 'unknown'

    problems = read_problems()
    # Slice the ids first and derive the prompts from them; slicing only the
    # prompts would pair completions with the wrong task ids when start > 0.
    task_ids = sorted(problems.keys())[args.start:args.end]
    prompts = [problems[task_id]['prompt'] for task_id in task_ids]

    fewshot_examples_data = []
    if num_shots > 0:
        try:
            prompt_dataset = load_dataset("nlile/mbpp", "sanitized")['prompt']
            fewshot_pool = [
                {
                    'prompt': example['prompt'],
                    'code': example['code'],
                    'test_list': example['test_list'],
                }
                for example in prompt_dataset
            ]

            if len(fewshot_pool) >= num_shots:
                fewshot_examples_data = random.sample(fewshot_pool, num_shots)
            else:
                fewshot_examples_data = fewshot_pool
                print(f"Warning: Only {len(fewshot_pool)} examples available for few-shot")

            print(f"Loaded {len(fewshot_examples_data)} examples for few-shot learning")
        except Exception as e:
            # Best effort: fall back to zero-shot rather than aborting the run.
            print(f"Error loading few-shot examples: {e}")
            fewshot_examples_data = []

    system_prompt = "You are a helpful assistant."
    all_prompts = []

    for prompt in prompts:
        # Four-space indents are sent to the model as tabs and converted back
        # after generation (see below).
        prompt = prompt.replace('    ', '\t')
        chat_prompt = create_chat_prompt(system_prompt, fewshot_examples_data, prompt, tokenizer, model_name)
        all_prompts.append(chat_prompt)

    print("Starting single generation with HuggingFace...")
    completions = single_generate_with_hf(
        model, tokenizer, all_prompts, generation_config, device, max_model_len=1024
    )

    completion_seqs = []
    for i, (task_id, completion) in enumerate(zip(task_ids, completions)):
        print(f"Processing {i+1}/{len(completions)}: {task_id}")
        print(f"Raw completion: {completion[:200]}...")
        all_code = completion.split("### Response:")[-1]
        completion = all_code.replace('\t', '    ')

        completion_seq = postprocess_completion(completion)

        completion_seqs.append({
            'task_id': task_id,
            'completion': completion_seq,
            'all_code': all_code,
        })

        if (i + 1) % 10 == 0:
            # NOTE(review): these partial files repeat records that are also
            # written to the full file below; confirm that is intended given
            # that the merge step reads every *.jsonl in the directory.
            output_file = f"{args.output_dir}/{args.prefix}_{i+1}.jsonl"
            write_jsonl(output_file, completion_seqs[-10:])

    output_file = f"{args.output_dir}/{args.prefix}_{args.start}_{args.end}.jsonl"
    write_jsonl(output_file, completion_seqs)
    merge_output_files(args)

    with open(log_file, "a") as f:
        f.write(f"Number of evaluation samples: {len(completion_seqs)}\n")
        f.write(f"Few-shot examples: {num_shots}\n")
        f.write("Inference method: HuggingFace Transformers (batch_size=1)\n\n")

def merge_output_files(args):
    """Merge the JSONL result files in ``args.output_dir`` into one final file.

    Every ``*.jsonl`` file in the directory is read in sorted filename order,
    each record's ``'completion'`` is run through ``postprocess_completion``,
    and all records are written to ``{output_dir}/{prefix}_final.jsonl``.

    The final file itself is excluded from the inputs; otherwise a second run
    in the same directory would fold the previous merged output back into the
    new one.

    NOTE(review): records that exist in more than one input file are written
    more than once. Confirm whether de-duplication is wanted.

    Args:
        args: Parsed CLI namespace; reads ``output_dir`` and ``prefix``.
    """
    final_output = f"{args.output_dir}/{args.prefix}_final.jsonl"
    final_abspath = os.path.abspath(final_output)

    sorted_files = sorted(
        path
        for path in glob.glob(f"{args.output_dir}/*.jsonl")
        if os.path.abspath(path) != final_abspath
    )
    outputs = []

    for code_file in tqdm(sorted_files, total=len(sorted_files)):
        for code in stream_jsonl(code_file):
            code['completion'] = postprocess_completion(code['completion'])
            outputs.append(code)

    print(f"Saving final results to {final_output}")
    write_jsonl(final_output, outputs)




def main():
    """Command-line entry point: load a model and run the evaluation.

    Parses the CLI options, seeds the random number generators, creates the
    output and log directories, loads the tokenizer and model from
    ``--model_path``, builds the generation configuration, writes a log
    header and calls ``evaluate_model_hf``.

    Decoding mode:
        * If ``--temperature`` is given, sampling is enabled exactly when it
          is greater than zero.
        * If ``--temperature`` is omitted, sampling follows ``--do_sample``.
        * ``--top_p`` is only applied when sampling; ``--top_k`` is only
          applied when it is a positive integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model")
    parser.add_argument("--eval_samples", type=int, default=None, help="Number of samples to evaluate")
    parser.add_argument("--split", type=str, default="test", choices=["test", "train"], help="Dataset split")
    parser.add_argument("--prefix", type=str, default="")
    parser.add_argument("--num_shots", type=int, default=0, help="Number of few-shot examples")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for vLLM generation")

    parser.add_argument("--start", type=int, help="Start index for evaluation")
    parser.add_argument("--end", type=int, help="End index for evaluation")

    # Generation parameters
    parser.add_argument("--do_sample", action="store_true", help="Whether to use sampling")
    parser.add_argument("--temperature", type=float, default=None, help="Generation temperature")
    parser.add_argument("--max_new_tokens", type=int, default=None, help="Maximum new tokens to generate")
    parser.add_argument("--top_p", type=float, default=None, help="Top-p (nucleus) sampling parameter")
    parser.add_argument("--top_k", type=int, default=None, help="Top-k sampling parameter")

    # File paths
    parser.add_argument("--log_dir", type=str, required=True, help="Directory for log files")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory for output files")

    args = parser.parse_args()
    seed_everything(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto" if torch.cuda.is_available() else None,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --temperature and --top_k default to None, so they must be checked
    # before being compared with 0.
    if args.temperature is not None:
        use_sampling = args.temperature > 0
    else:
        use_sampling = args.do_sample
    top_k = args.top_k if args.top_k is not None and args.top_k > 0 else None

    generation_config = GenerationConfig(
        temperature=args.temperature,
        top_p=args.top_p if use_sampling else None,
        top_k=top_k,
        max_new_tokens=args.max_new_tokens,
        do_sample=use_sampling,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"Generation config: {generation_config}")

    log_file = os.path.join(args.log_dir, f"{args.prefix}_hf_fewshot_{args.num_shots}_batch1_log.txt")

    # Start a fresh log for this run; evaluate_model_hf appends to it later.
    with open(log_file, "w") as f:
        f.write(f"Model Path: {args.model_path}\n")
        f.write("Inference method: HuggingFace Transformers\n")
        f.write(f"Few-shot examples: {args.num_shots}\n")
        f.write("Batch size: 1\n")
        f.write(f"Device: {device}\n")
        f.write(f"Generation config: {generation_config}\n")
        f.write(f"Eval Samples: {'All' if args.eval_samples is None else args.eval_samples}\n\n")

    evaluate_model_hf(
        args,
        model,
        tokenizer,
        generation_config,
        device,
        seed=args.seed,
        log_file=log_file,
        num_shots=args.num_shots
    )

    print("Evaluation completed!")

# Run the evaluation only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()