import argparse
import json
import re
from vllm import LLM, SamplingParams
import sys
import torch
import gc
import wandb
from tqdm.auto import tqdm
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Open-ended upper bound for slice indices (e.g. the default ``end`` / ``--end``).
# Python slicing clamps it to the sequence length, so ``xs[start:MAX_INT]``
# simply means "from ``start`` to the end".
MAX_INT = sys.maxsize


def extract_answer(dataset: str, sentence: str) -> str:
    """Pull the first recognised answer label out of a model completion.

    The completion is stripped and lower-cased, then searched for the label
    vocabulary of ``dataset`` (``true``/``false`` for BoolQ, ``answer1`` ..
    ``answer5`` for the multiple-choice QA sets, and so on). The earliest
    match in the text wins.

    Returns:
        The matched label, or ``""`` when the completion contains none.

    Raises:
        ValueError: if ``dataset`` is not a supported benchmark name.
    """
    text = sentence.strip().lower()

    choice_labels = r'answer1|answer2|answer3|answer4|answer5'
    label_table = (
        (('boolq',), r'true|false'),
        (('piqa',), r'solution1|solution2'),
        (('social_i_qa', 'ARC-Challenge', 'ARC-Easy', 'openbookqa'), choice_labels),
        (('hellaswag',), r'ending1|ending2|ending3|ending4'),
        (('winogrande',), r'option1|option2'),
    )
    for names, pattern in label_table:
        if dataset in names:
            break
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")

    # re.search returns the leftmost match, i.e. the first element findall would give.
    first = re.search(pattern, text)
    return first.group(0) if first is not None else ""


def batch_data(data_list, batch_size=1):
    """Split ``data_list`` into consecutive chunks of at most ``batch_size`` items.

    Every item appears exactly once and in order. The final chunk holds the
    remainder and may be shorter than ``batch_size``. An empty input yields
    no chunks at all.

    Args:
        data_list: A sliceable sequence (e.g. a list of prompts).
        batch_size: Maximum chunk length; must be >= 1.

    Returns:
        A list of slices of ``data_list``.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # Fixed-size chunking: the remainder gets its own short chunk instead of
    # being folded into the previous one, so no chunk exceeds batch_size and
    # an empty list does not produce a single empty batch.
    return [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]


def generate_prompt(instruction, input=None):
    """Build an Alpaca-style prompt for ``instruction``.

    When ``input`` is truthy, an extra "### Input:" section carrying it is
    placed between the instruction and the response header, and a longer
    preamble is used. The prompt always ends with "### Response:" plus a
    newline so the model continues from there.
    """
    if not input:
        header = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
        sections = [f"### Instruction:\n{instruction}"]
    else:
        header = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
        sections = [f"### Instruction:\n{instruction}", f"### Input:\n{input}"]
    sections.append("### Response:\n")
    # Sections are separated by one blank line.
    return "\n\n".join([header] + sections)


def commonsense_test_fallback(model, dataset_name, data_path, start=0, end=MAX_INT, batch_size=1):
    """Evaluate ``model`` with plain ``transformers`` generation when vLLM is unavailable.

    Prompts are run one at a time with greedy decoding (at most 32 new tokens),
    scored with ``extract_answer``, and the accuracy is logged to W&B unless
    ``WANDB_MODE`` is ``disabled``.

    Args:
        model: Hugging Face model ID or local checkpoint directory.
        dataset_name: Dataset key understood by ``extract_answer``.
        data_path: JSON file containing a list of records with ``instruction``
            and ``answer`` fields.
        start: First record index to evaluate.
        end: One past the last record index (defaults to "until the end").
        batch_size: Accepted for signature parity with ``commonsense_test``;
            not used here.

    Returns:
        Accuracy as a float; ``0.0`` when the model fails to load or the
        selected slice is empty.
    """
    print("Using fallback evaluation with transformers pipeline...")

    # Release memory a failed vLLM attempt may still hold before loading another copy.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    with open(data_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    dataset = dataset[start:end]
    instructions = [data.get('instruction') for data in dataset]
    answers = [data.get('answer') for data in dataset]

    try:
        print(f"Loading model: {model}")

        tokenizer = AutoTokenizer.from_pretrained(model)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # If a prompt is too long, drop its beginning rather than the trailing
        # "### Response:" cue the model has to continue from.
        tokenizer.truncation_side = "left"

        model_obj = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            # NOTE(review): only GPU 0 is listed with a 70GiB cap, so placement is
            # presumably restricted to that device — confirm this suits
            # multi-GPU or smaller-GPU hosts.
            max_memory={0: "70GiB"},
        )
        model_obj.eval()

        print("Successfully loaded model")
        print(f"Model device: {next(model_obj.parameters()).device}")
    except Exception as e:
        print(f"Failed to load model with transformers: {e}")
        return 0.0

    # Inputs go to the device of the model's first parameter; look it up once
    # rather than per prompt.
    input_device = next(model_obj.parameters()).device
    max_input_length = 1024  # conservative prompt-length budget
    res_completions = []

    print("\nGenerating responses...")
    try:
        for i, instruction in enumerate(tqdm(instructions, desc="Generating responses", ncols=100)):
            try:
                inputs = tokenizer(
                    generate_prompt(instruction),
                    return_tensors="pt",
                    truncation=True,
                    max_length=max_input_length,
                    padding=False,
                )
                inputs = {k: v.to(input_device) for k, v in inputs.items()}

                with torch.no_grad():
                    # ``inputs`` already carries ``attention_mask``; passing it again
                    # as an explicit keyword raises TypeError ("got multiple values
                    # for keyword argument"), which previously failed every prompt.
                    outputs = model_obj.generate(
                        **inputs,
                        max_new_tokens=32,
                        do_sample=False,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        use_cache=True,
                    )

                # Decode only the newly generated tokens, not the echoed prompt.
                input_length = inputs['input_ids'].shape[1]
                generated_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
                res_completions.append(generated_text.strip())
            except Exception as e:
                # Best effort: a failed prompt is scored as an empty (wrong) completion.
                print(f"Generation failed for instruction {i}: {str(e)[:100]}")
                res_completions.append("")

            # Clear the CUDA cache every 100 prompts to prevent memory buildup.
            if (i + 1) % 100 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()
    finally:
        # Free the model even if generation is interrupted.
        del model_obj
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\nEvaluating responses...")
    result = []
    invalid_outputs = []
    for instruction, completion, answer in tqdm(
        zip(instructions, res_completions, answers),
        total=len(instructions),
        desc="Evaluating answers",
        ncols=100,
    ):
        pred = extract_answer(dataset_name, completion)
        is_correct = pred == answer
        result.append(is_correct)

        # An empty prediction means no answer label could be parsed at all.
        if not is_correct and not pred:
            invalid_outputs.append(
                {'instruction': instruction, 'output': completion, 'answer': answer, 'pred': pred}
            )

    acc = sum(result) / len(result) if result else 0.0

    # Only log to wandb if it's not disabled.
    if os.environ.get('WANDB_MODE', 'online') != 'disabled':
        try:
            wandb.log({
                f"eval/{dataset_name}_acc": acc,
            })
        except Exception as e:
            print(f"Failed to log to wandb: {e}")

    print(f'Invalid outputs count: {len(invalid_outputs)}')
    print(f'Evaluation range: start={start}, end={end}')
    print(f'Total evaluated: {len(result)}, Accuracy: {acc:.4f}')

    # Show a few unparseable outputs to help diagnose prompt/format problems.
    if invalid_outputs:
        print("\nSample invalid outputs for debugging:")
        for i, sample in enumerate(invalid_outputs[:3]):
            print(f"Sample {i+1}:")
            print(f"  Output: '{sample['output']}'")
            print(f"  Expected: '{sample['answer']}'")
            print(f"  Predicted: '{sample['pred']}'")

    return acc


def _load_vllm(model, tensor_parallel_size):
    """Create a vLLM engine for ``model`` (a local directory or a Hub model ID).

    For a local directory, a second attempt is made with its absolute path if
    the first one fails. Any failure for a non-directory ``model`` is re-raised.
    """
    is_local = os.path.isdir(model)
    try:
        if is_local:
            return LLM(
                model=model,
                tensor_parallel_size=tensor_parallel_size,
                trust_remote_code=True,
                download_dir=None,
                load_format="auto",
            )
        return LLM(model=model, tensor_parallel_size=tensor_parallel_size)
    except Exception as e:
        print(f"Failed to load model with vLLM: {e}")
        print("Trying alternative loading method...")
        if not is_local:
            raise
    # Retry outside the ``except`` block so the first attempt's traceback (and
    # anything it references) is released before loading again.
    return LLM(
        model=os.path.abspath(model),
        tensor_parallel_size=tensor_parallel_size,
        trust_remote_code=True,
    )


def _generate_with_vllm(llm, batch_instructions, sampling_params):
    """Run every batch of instructions through ``llm`` and return the completion texts.

    Assumes ``llm.generate`` returns outputs in prompt order (as vLLM documents),
    so the result lines up index-for-index with the flattened instructions.
    """
    completions = []
    for prompts in tqdm(batch_instructions,
                        total=len(batch_instructions),
                        desc="Generating responses",
                        ncols=100):
        if not isinstance(prompts, list):
            prompts = [prompts]
        formatted_prompts = [generate_prompt(instruction) for instruction in prompts]
        for output in llm.generate(formatted_prompts, sampling_params):
            completions.append(output.outputs[0].text)
    return completions


def commonsense_test(model, dataset_name, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
    """Evaluate ``model`` on a commonsense benchmark, preferring vLLM.

    If the vLLM engine cannot be created or generation fails, evaluation is
    delegated to ``commonsense_test_fallback`` (transformers-based), and that
    function's result is returned.

    Args:
        model: Hugging Face model ID or local checkpoint directory.
        dataset_name: Dataset key understood by ``extract_answer``.
        data_path: JSON file containing a list of records with ``instruction``
            and ``answer`` fields.
        start: First record index to evaluate.
        end: One past the last record index (defaults to "until the end").
        batch_size: Number of prompts per ``llm.generate`` call.
        tensor_parallel_size: vLLM tensor-parallel degree.

    Returns:
        Accuracy as a float; ``0.0`` when the selected slice is empty.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    with open(data_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    dataset = dataset[start:end]
    instructions = [data.get('instruction') for data in dataset]
    answers = [data.get('answer') for data in dataset]
    batch_instructions = batch_data(instructions, batch_size=batch_size)

    llm = None
    res_completions = None
    try:
        # Stop as soon as the model starts a new prompt section.
        stop_tokens = ["Instruction:", "Instruction", "Response:", "Response"]
        sampling_params = SamplingParams(temperature=0.1, top_p=0.75, top_k=40, max_tokens=32, stop=stop_tokens)
        llm = _load_vllm(model, tensor_parallel_size)

        print("\nGenerating responses...")
        res_completions = _generate_with_vllm(llm, batch_instructions, sampling_params)
    except Exception as e:
        if llm is None:
            print(f"vLLM setup failed: {e}")
        else:
            print(f"vLLM generation failed: {e}")
        print("Falling back to transformers pipeline...")

    # Drop our reference to the engine so a fallback model can reuse the memory.
    # NOTE(review): vLLM may keep some GPU memory reserved even after the engine
    # is collected — confirm on the target vLLM version.
    del llm
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if res_completions is None:
        # Deliberately outside the ``except`` block: an error raised by the
        # fallback is no longer caught here and the fallback re-run a second time.
        return commonsense_test_fallback(model, dataset_name, data_path, start, end, batch_size)

    print("\nEvaluating responses...")
    result = []
    invalid_outputs = []
    for instruction, completion, answer in tqdm(
        zip(instructions, res_completions, answers),
        total=len(instructions),
        desc="Evaluating answers",
        ncols=100,
    ):
        pred = extract_answer(dataset_name, completion)
        is_correct = pred == answer
        result.append(is_correct)

        # An empty prediction means no answer label could be parsed at all.
        if not is_correct and not pred:
            invalid_outputs.append(
                {'instruction': instruction, 'output': completion, 'answer': answer, 'pred': pred}
            )

    # Guard against an empty slice (e.g. start beyond the dataset length).
    acc = sum(result) / len(result) if result else 0.0

    # Only log to wandb if it's not disabled.
    if os.environ.get('WANDB_MODE', 'online') != 'disabled':
        try:
            wandb.log({
                f"eval/{dataset_name}_acc": acc,
            })
        except Exception as e:
            print(f"Failed to log to wandb: {e}")

    print(f'Invalid outputs count: {len(invalid_outputs)}')
    print(f'Evaluation range: start={start}, end={end}')
    print(f'Total evaluated: {len(result)}, Accuracy: {acc:.4f}')

    return acc


def _init_wandb(run_dir, wandb_mode):
    """Start (or resume) a W&B run for this evaluation.

    The run ID is read from ``<run_dir>/wandb_run_id.txt`` when that file exists,
    and the run is resumed with it. If initialisation fails in ``online`` mode,
    an ``offline`` run is attempted. If no run can be started at all,
    ``WANDB_MODE`` is set to ``disabled`` so later code skips ``wandb.log``.
    The project name comes from ``WANDB_PROJECT`` (default ``"project_name"``).
    """
    project = os.environ.get('WANDB_PROJECT', 'project_name')
    wandb_run_id = None
    try:
        try:
            with open(os.path.join(run_dir, "wandb_run_id.txt"), "r") as f:
                wandb_run_id = f.read().strip()
        except FileNotFoundError:
            print("WandB run ID file not found, will create new run")

        resume = {"id": wandb_run_id, "resume": "allow"} if wandb_run_id else {}
        wandb.init(
            project=project,
            mode=wandb_mode,
            settings=wandb.Settings(init_timeout=120),  # 2-minute timeout
            **resume,
        )
        return
    except Exception as e:
        print(f"Failed to initialize wandb: {e}")

    if wandb_mode == 'online':
        print("Falling back to offline mode...")
        resume = {"id": wandb_run_id, "resume": "allow"} if wandb_run_id else {}
        try:
            wandb.init(project=project, mode="offline", **resume)
            return
        except Exception as e2:
            print(f"Even offline mode failed: {e2}")

    # No run could be started: stop evaluators from calling wandb.log on an
    # uninitialised run.
    print("Disabling wandb logging...")
    os.environ['WANDB_MODE'] = 'disabled'


def parse_args(argv=None):
    """Parse command-line flags, fill in the default data path and set up W&B.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``), mainly for tests.

    Returns:
        The parsed ``argparse.Namespace``.

    Side effects:
        May start a W&B run. May set ``WANDB_MODE=disabled`` in ``os.environ``
        when no run could be started.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True,
                        help="Path to the model")
    parser.add_argument("--dataset", type=str, required=True,
                        choices=["boolq", "piqa", "social_i_qa", "hellaswag",
                                 "winogrande", "ARC-Challenge", "ARC-Easy", "openbookqa"],
                        help="Dataset to evaluate on")
    parser.add_argument("--data_file", type=str, default=None,
                        help="Path to the dataset file")
    parser.add_argument("--start", type=int, default=0,
                        help="Start index for evaluation")
    parser.add_argument("--end", type=int, default=MAX_INT,
                        help="End index for evaluation")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size for evaluation")
    parser.add_argument("--tensor_parallel_size", type=int, default=1,
                        help="Tensor parallel size for model")
    parser.add_argument("--run_dir", type=str,
                        help="Directory containing the wandb run ID")

    args = parser.parse_args(argv)

    # Default to the conventional per-dataset test split.
    if args.data_file is None:
        args.data_file = f'data/commonsense/{args.dataset}/test.json'

    wandb_mode = os.environ.get('WANDB_MODE', 'online')
    if args.run_dir and wandb_mode != 'disabled':
        _init_wandb(args.run_dir, wandb_mode)
    else:
        print("WandB logging disabled or no run directory specified")

    return args


if __name__ == "__main__":
    # Command-line entry point: parse flags (which also sets up W&B), then run one evaluation.
    args = parse_args()
    eval_config = {
        "model": args.model,
        "dataset_name": args.dataset,
        "data_path": args.data_file,
        "start": args.start,
        "end": args.end,
        "batch_size": args.batch_size,
        "tensor_parallel_size": args.tensor_parallel_size,
    }
    commonsense_test(**eval_config)