import argparse
import json
import math
import os
import re

import shortuuid
import torch
from PIL import Image
from tqdm import tqdm

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

def split_list(lst, n):
    """Split a list into at most ``n`` (roughly) equal-sized chunks.

    Chunks are contiguous slices of ``ceil(len(lst) / n)`` items, so fewer
    than ``n`` chunks are returned when the list is short.

    :param lst: Sequence to split.
    :param n: Requested number of chunks; must be positive.
    :return: List of slices of ``lst``; empty when ``lst`` is empty.
    :raises ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n}")
    if not lst:
        # A zero chunk size would make range() fail with an unrelated error.
        return []
    chunk_size = math.ceil(len(lst) / n)
    return [lst[start:start + chunk_size] for start in range(0, len(lst), chunk_size)]

# A standalone letter followed by an option delimiter, e.g. "A.", "(B)", "c:".
_OPTION_LABEL_RE = re.compile(r"(?<![A-Za-z0-9])\(?([A-Za-z])[.):：、）]")
# Plain integer or decimal number with an optional sign.
_NUMERIC_ANSWER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?")


def get_question_type(question_data):
    """Guess the question type from the format of the question and answer.

    Checks are applied in this order:

    1. ``"multiple_choice"`` when the question text contains both an "A" and
       a "B" option label (a standalone letter followed by ``.``, ``)`` or
       ``:``, case-insensitive).
    2. ``"yes_no"`` when the answer is "yes" or "no" (case-insensitive).
    3. ``"numerical"`` when the answer is an integer or decimal number.
    4. ``"open_ended"`` otherwise.

    :param question_data: Mapping with optional "question" and "answer"
        entries; non-string values are converted with ``str``.
    :return: One of the four type names listed above.
    """
    raw_question = question_data.get("question")
    raw_answer = question_data.get("answer")
    question_text = "" if raw_question is None else str(raw_question)
    answer_text = "" if raw_answer is None else str(raw_answer).strip().lower()

    # Look for real option labels instead of the bare letters "a" and "b",
    # which occur in almost every English sentence.
    labels = {label.upper() for label in _OPTION_LABEL_RE.findall(question_text)}
    if {"A", "B"} <= labels:
        return "multiple_choice"

    if answer_text in ("yes", "no"):
        return "yes_no"

    if _NUMERIC_ANSWER_RE.fullmatch(answer_text):
        return "numerical"

    return "open_ended"

def get_chunk(lst, n, k):
    """Return the ``k``-th of the chunks produced by ``split_list(lst, n)``.

    When the list is too short to fill ``n`` chunks, the trailing shards
    have no data; an empty list is returned for them so that every worker
    index below ``n`` is usable.

    :param lst: Sequence to split.
    :param n: Number of chunks requested.
    :param k: Index of the chunk to return.
    :return: The selected chunk, or ``[]`` if that shard has no items.
    :raises IndexError: If ``k`` is not a valid index for ``n`` chunks.
    """
    if k >= n:
        raise IndexError(f"chunk index {k} out of range for {n} chunks")
    chunks = split_list(lst, n)
    if k >= len(chunks):
        return []
    return chunks[k]
from PIL import Image

def stitch_images(image_paths, direction='horizontal', gap=10, bg_color=(255, 255, 255)):
    """
    Stitches multiple images together into one canvas.

    Images are first scaled (keeping aspect ratio) to a common height for
    horizontal stitching or a common width for vertical stitching, using
    the smallest such dimension among the inputs.

    :param image_paths: List of paths to images.
    :param direction: 'horizontal'; any other value stitches vertically.
    :param gap: Gap between images in pixels.
    :param bg_color: Background color for the gap.
    :return: A single stitched PIL Image object, or None if no paths were given.
    """
    images = []
    for path in image_paths:
        # convert() loads the pixels into a new image, so the underlying file
        # can be closed right away instead of lingering until garbage collection.
        with Image.open(path) as source:
            images.append(source.convert('RGB'))

    if not images:
        return None

    horizontal = direction == 'horizontal'
    total_gap = gap * (len(images) - 1)

    if horizontal:
        row_height = min(img.height for img in images)
        # max(1, ...) keeps very flat images from collapsing to a zero width.
        images = [
            img.resize((max(1, int(img.width * row_height / img.height)), row_height))
            for img in images
        ]
        canvas_size = (sum(img.width for img in images) + total_gap, row_height)
    else:  # vertical
        column_width = min(img.width for img in images)
        images = [
            img.resize((column_width, max(1, int(img.height * column_width / img.width))))
            for img in images
        ]
        canvas_size = (column_width, sum(img.height for img in images) + total_gap)

    stitched_image = Image.new('RGB', canvas_size, bg_color)

    cursor = 0
    for img in images:
        stitched_image.paste(img, (cursor, 0) if horizontal else (0, cursor))
        cursor += (img.width if horizontal else img.height) + gap

    return stitched_image

# Prompt suffix per question type returned by get_question_type():
# (label used in the log line, instruction appended to the question).
_QUESTION_TYPE_PROMPTS = {
    "multiple_choice": (
        "Multiple Choice",
        "\nLet's think step by step.Based on the question and image, select the single best option.Please select the most appropriate answer from options "
        "\nFinal Answer:",
    ),
    "yes_no": (
        "Yes/No",
        "\nLet's think step by step.Please answer with only 'Yes' or 'No'.",
    ),
    "numerical": (
        "Numerical",
        "\nLet's think step by step. The question involves measuring the precise distance in 3D space through a 2D image.Please provide a numerical answer.",
    ),
}


def _zero_nans(value):
    """Return ``value`` with NaNs replaced by zeros if it is a tensor containing NaN."""
    if isinstance(value, torch.Tensor) and torch.isnan(value).any():
        return torch.where(torch.isnan(value), torch.zeros_like(value), value)
    return value


def _create_safe_attention_forward(original_forward):
    """Wrap a bound attention ``forward`` so NaNs in its outputs become zeros."""
    def safe_forward(self, *call_args, **call_kwargs):
        try:
            outputs = original_forward(*call_args, **call_kwargs)
            if isinstance(outputs, tuple):
                return tuple(_zero_nans(output) for output in outputs)
            return _zero_nans(outputs)
        except Exception as e:
            print(f"   ❌ Error in attention forward: {e}")
            # Return zero tensor as fallback.
            # NOTE(review): this returns a bare float16 tensor; if the wrapped
            # forward normally returns a tuple, the caller may still fail --
            # confirm against the attention implementation in use.
            if hasattr(self, 'hidden_size'):
                batch_size = call_args[0].shape[0] if call_args else 1
                seq_len = call_args[0].shape[1] if call_args else 1
                return torch.zeros(batch_size, seq_len, self.hidden_size,
                                   device=call_args[0].device if call_args else 'cuda',
                                   dtype=torch.float16)
            raise

    return safe_forward


def _patch_attention_layers(model):
    """Install the NaN-guarding forward on every ``self_attn`` module; return the count."""
    patched_count = 0
    if hasattr(model, 'get_model'):
        decoder = model.get_model()
        for layer in getattr(decoder, 'layers', ()):
            if hasattr(layer, 'self_attn'):
                attention = layer.self_attn
                attention.forward = _create_safe_attention_forward(attention.forward).__get__(
                    attention, type(attention)
                )
                patched_count += 1
    return patched_count


def _repair_corrupted_embeddings(model, model_base, torch_dtype):
    """Copy embedding and lm_head weights from the base model when the embeddings are degenerate."""
    embed_std = model.get_input_embeddings().weight.std().item()
    # Written as "not <" so a NaN std is reported the same way as a normal one.
    if not embed_std < 1e-6:
        print(f"  Embedding weights look normal (std: {embed_std:.6f})")
        return

    print(" Embedding weights are corrupted (all zeros)! Attempting to fix...")
    try:
        from transformers import AutoModelForCausalLM
        print("   Loading base model to get clean weights...")
        base_model = AutoModelForCausalLM.from_pretrained(
            model_base,
            torch_dtype=torch_dtype,
            device_map="auto"
        )

        model_vocab_size = model.get_input_embeddings().weight.shape[0]
        base_vocab_size = base_model.get_input_embeddings().weight.shape[0]
        print(f"   Model vocab: {model_vocab_size}, Base vocab: {base_vocab_size}")

        if model_vocab_size != base_vocab_size:
            print("   🔧 Handling vocab size mismatch...")
        shared_rows = min(model_vocab_size, base_vocab_size)

        def copy_rows(target, source):
            # Copy the overlapping rows; rows that exist only in the model
            # are initialised with the mean row of the source matrix.
            target[:shared_rows].copy_(source[:shared_rows])
            if model_vocab_size > base_vocab_size:
                mean_row = source.mean(dim=0, keepdim=True)
                target[base_vocab_size:].copy_(
                    mean_row.expand(model_vocab_size - base_vocab_size, -1)
                )

        with torch.no_grad():
            copy_rows(model.get_input_embeddings().weight, base_model.get_input_embeddings().weight)
            copy_rows(model.lm_head.weight, base_model.lm_head.weight)

        print("   Successfully fixed embedding and lm_head weights!")

        # Verify fix
        fixed_embed_std = model.get_input_embeddings().weight.std().item()
        print(f"  Fixed embedding std: {fixed_embed_std:.6f}")

        del base_model  # Free memory
        torch.cuda.empty_cache()

    except Exception as e:
        # Best effort: evaluation continues with the weights as loaded.
        print(f"   Failed to fix embeddings: {e}")
        print("    Model may not work properly!")


def _load_non_lora_trainables(model, model_path):
    """Load ``non_lora_trainables.bin`` from ``model_path`` into the model if the file exists."""
    non_lora_trainables_path = os.path.join(model_path, "non_lora_trainables.bin")
    if not os.path.exists(non_lora_trainables_path):
        print("non_lora_trainables.bin not found. Skipping.")
        return

    print(" Found and loading non-lora trainable weights...")
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    non_lora_state_dict = torch.load(non_lora_trainables_path, map_location='cpu')
    incompatible_keys = model.load_state_dict(non_lora_state_dict, strict=False)

    if incompatible_keys.missing_keys:
        print(f"WARN: Some keys were not found in the model: {incompatible_keys.missing_keys}")
    if incompatible_keys.unexpected_keys:
        print(f"WARN: Some keys from the file were not used: {incompatible_keys.unexpected_keys}")

    print(" Successfully loaded non-lora trainable weights.")


def _check_lm_head(model):
    """Run a tiny text-only forward pass and report whether the logits look degenerate."""
    print(" Checking lm_head state...")

    test_input = torch.tensor([[1, 2, 3]], device=model.device)
    with torch.no_grad():
        test_logits = model(test_input).logits

    if test_logits.std().item() < 1e-6:
        print("  LM head appears broken, keeping original dtype...")
    else:
        print(" LM head working normally")
        print(" Keeping lm_head in original dtype to avoid type mismatch")


def _report_vision_tower_path(model):
    """Print whether the vision tower path stored in the model config exists on disk."""
    if hasattr(model, 'config') and hasattr(model.config, 'mm_vision_tower'):
        vision_tower_path = model.config.mm_vision_tower
        print(f"Path to vision tower from model config: '{vision_tower_path}'")

        if vision_tower_path and os.path.exists(vision_tower_path):
            print("Vision tower path CHECK: OK, path exists.")
        else:
            print("Vision tower path CHECK: FAILED, path does not exist or is empty!")
    else:
        print("Model config does not have the 'mm_vision_tower' attribute.")


def _load_question_image(line):
    """Return the RGB image for a question record (stitched if several paths), or None."""
    paths = line.get('img_paths')
    if not paths:
        return None
    if isinstance(paths, str):
        # Tolerate a single path given as a bare string instead of a list.
        paths = [paths]

    try:
        if len(paths) > 1:
            print(f"INFO: Stitching {len(paths)} images together.")
            image = stitch_images(paths, direction='horizontal')
        else:
            with Image.open(paths[0]) as source:
                image = source.convert('RGB')
    except FileNotFoundError as e:
        # A missing file downgrades the question to text-only, for both the
        # single-image and the stitched case.
        print(f"错误：在 {e.filename} 未找到图片")
        return None

    if image is not None and str(line.get('index_origin', '')).endswith('-flip'):
        print(f"INFO: Dynamically flipping image for index {line['index']} based on '-flip' flag.")
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
    return image


def _generate_output_ids(model, tokenizer, input_ids, images, image_sizes, temperature):
    """Generate output ids for one prompt, falling back to text-only generation on failure."""
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    shared_kwargs = dict(
        use_cache=True,
        pad_token_id=pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    vocab_size = len(tokenizer)

    with torch.inference_mode():
        print(" Generation debug info:")
        print(f"   Input IDs shape: {input_ids.shape}")
        print(f"   Input IDs range: [{input_ids.min().item()}, {input_ids.max().item()}]")
        print(f"   Tokenizer vocab size: {vocab_size}")
        print(f"   Model vocab size: {model.config.vocab_size}")

        # The image placeholder id is not a vocabulary entry, so it is left
        # out of the range check to avoid a spurious warning on every image prompt.
        text_ids = input_ids[input_ids != IMAGE_TOKEN_INDEX]
        if text_ids.numel() > 0 and (text_ids.max().item() >= vocab_size or text_ids.min().item() < 0):
            print("  WARNING: Input IDs contain tokens outside vocab range!")
            print(f"   Input ID range: [{text_ids.min().item()}, {text_ids.max().item()}]")
            print(f"   Vocab size: {vocab_size}")

            mask_too_large = input_ids >= vocab_size
            if mask_too_large.any():
                print(f"   Clamping {mask_too_large.sum().item()} tokens that are too large")
                input_ids = input_ids.masked_fill(mask_too_large, vocab_size - 1)

            text_ids = input_ids[input_ids != IMAGE_TOKEN_INDEX]
            print(f"   Final input ID range: [{text_ids.min().item()}, {text_ids.max().item()}]")

        if images is None:
            # Text-only question: greedy decoding.
            return model.generate(input_ids, do_sample=False, max_new_tokens=128, **shared_kwargs)

        print("🔍 Testing forward pass for NaN values...")
        try:
            with torch.no_grad():
                test_logits = model(input_ids, images=images, image_sizes=image_sizes).logits

            if torch.isnan(test_logits).any():
                print("  WARNING: Forward pass produces NaN! Trying text-only generation...")
                text_only_input = input_ids.clone()
                image_token_mask = text_only_input == IMAGE_TOKEN_INDEX
                if image_token_mask.any():
                    # Replace the first image placeholder with the token for the word "image".
                    image_word_tokens = tokenizer.encode("image", add_special_tokens=False)
                    first_image_pos = torch.where(image_token_mask)[1][0]
                    text_only_input[0, first_image_pos] = image_word_tokens[0] if image_word_tokens else tokenizer.unk_token_id

                print("   Fallback to text-only generation...")
                return model.generate(text_only_input, do_sample=False, max_new_tokens=128, **shared_kwargs)

            print(" Forward pass clean, proceeding with multimodal generation...")
            if temperature > 0:
                sampling_kwargs = dict(do_sample=True, temperature=temperature, top_p=0.95, top_k=50)
            else:
                # Sampling needs a strictly positive temperature; 0 means greedy.
                sampling_kwargs = dict(do_sample=False)
            return model.generate(
                input_ids,
                images=images,
                image_sizes=image_sizes,
                max_new_tokens=128,
                **sampling_kwargs,
                **shared_kwargs,
            )
        except Exception as e:
            print(f"  Forward pass test failed: {e}")
            print("   Falling back to text-only generation...")

            text_only_input = input_ids.clone()
            image_token_mask = text_only_input == IMAGE_TOKEN_INDEX
            if image_token_mask.any():
                # Drop the image placeholders entirely (batch size is 1).
                text_only_input = text_only_input[~image_token_mask].unsqueeze(0)

            return model.generate(text_only_input, do_sample=False, max_new_tokens=512, **shared_kwargs)


def eval_model(args, torch_dtype):
    """Answer every question in one chunk of the question file and write the answers as JSON lines.

    Steps: load the model, run the weight and attention safeguards, read and
    chunk the questions, then for each question build the prompt (with an
    optional image and a type-specific instruction), generate a response and
    append one JSON record to the answers file.

    :param args: Parsed CLI namespace; reads ``model_path``, ``model_base``,
        ``load_4bit``, ``question_file``, ``num_chunks``, ``chunk_idx``,
        ``answers_file``, ``conv_mode`` and ``temperature``.
    :param torch_dtype: Torch dtype used to load the model and cast image tensors.
    :return: None. Returns early, after printing troubleshooting tips, when
        the model cannot be loaded.
    """
    # Setup
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)

    # Load Model
    try:
        tokenizer, model, image_processor, _ = load_pretrained_model(
            model_path,
            args.model_base,
            model_name,
            device_map="auto",
            torch_dtype=torch_dtype,
            load_4bit=args.load_4bit
        )

        print(" Checking tokenizer and model compatibility...")
        print(f"   Tokenizer vocab size: {len(tokenizer)}")
        print(f"   Model vocab size: {model.config.vocab_size}")
        print(f"   Model embedding size: {model.get_input_embeddings().weight.shape[0]}")
        print(f"   Model lm_head size: {model.get_output_embeddings().weight.shape[0]}")

        _repair_corrupted_embeddings(model, args.model_base, torch_dtype)

        if len(tokenizer) != model.get_output_embeddings().weight.shape[0]:
            print(" Vocab size mismatch detected! Resizing model embeddings...")
            model.resize_token_embeddings(len(tokenizer))
            print(f"Model embeddings resized to {len(tokenizer)}")

        # Patch attention layers
        patched_count = _patch_attention_layers(model)
        print(f" Patched {patched_count} attention layers with NaN protection")

    except Exception as e:
        print(f" ERROR loading model: {str(e)}")
        print("\nTroubleshooting tips:")
        print("1. For LoRA models, ensure you provide --model-base pointing to the base model")
        print("2. Check that all required model files exist in the specified directories")
        print("3. Verify that the model path contains the necessary tokenizer files")
        return

    _load_non_lora_trainables(model, model_path)
    if hasattr(model, 'lm_head'):
        _check_lm_head(model)
    _report_vision_tower_path(model)

    with open(os.path.expanduser(args.question_file), "r", encoding="utf-8") as question_file:
        questions = json.load(question_file)
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

    # Prepare answers file
    answers_file = os.path.expanduser(args.answers_file)
    answers_dir = os.path.dirname(answers_file)
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)

    with open(answers_file, "w", encoding="utf-8") as ans_file:
        # Process each question
        for line in tqdm(questions):
            idx = line["index"]
            qs = line["question"]  # the "question" field already contains the full text
            cur_prompt = qs
            question_type = get_question_type(line)

            # Process image if it exists
            image = _load_question_image(line)
            if image is not None:
                images_tensor = process_images([image], image_processor, model.config)[0]
                images = images_tensor.unsqueeze(0).to(model.device, dtype=torch_dtype)
                image_sizes = [image.size]

                if getattr(model.config, 'mm_use_im_start_end', False):
                    qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
                else:
                    qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
                cur_prompt = '<image>' + '\n' + cur_prompt
            else:
                images = None
                image_sizes = None

            # Append the type-specific instruction, if any.
            if question_type in _QUESTION_TYPE_PROMPTS:
                type_label, instruction = _QUESTION_TYPE_PROMPTS[question_type]
                print(f"INFO: Detected {type_label} Question. Applying specific prompt.")
                qs = qs + instruction
                cur_prompt = cur_prompt + instruction

            # Build conversation prompt
            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            # Tokenize the prompt; place it on the model's device rather than
            # assuming CUDA, matching how the image tensor is placed.
            input_ids = tokenizer_image_token(
                prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
            ).unsqueeze(0).to(model.device)

            # Generate response
            output_ids = _generate_output_ids(
                model, tokenizer, input_ids, images, image_sizes, args.temperature
            )
            print("output_ids:", output_ids)
            print("raw decode (with special tokens):", tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0])
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

            # Save the answer
            ans_id = shortuuid.uuid()
            ans_dict = {
                "question_id": idx,
                "prompt": cur_prompt,
                "text": outputs,
                "answer_id": ans_id,
                "model_id": model_name,
                "metadata": {}
            }
            print("=" * 50)
            print(f"[QID] {idx}")
            print(f"[Prompt]\n{cur_prompt}")
            print(f"[Output]\n{outputs}")
            print("=" * 50)

            ans_file.write(json.dumps(ans_dict) + "\n")
            # Flush per record so partial results survive an interrupted run.
            ans_file.flush()

def _parse_cli_args():
    """Define the command-line options of the evaluation script and parse them."""
    cli = argparse.ArgumentParser()

    # Model and data locations.
    cli.add_argument("--model-path", type=str, default="facebook/opt-350m")
    cli.add_argument("--model-base", type=str, default=None)
    cli.add_argument("--image-folder", type=str, default="")
    cli.add_argument("--question-file", type=str, default="tables/question.json")
    cli.add_argument("--answers-file", type=str, default="answer.jsonl")

    # Prompting, sharding and sampling.
    cli.add_argument("--conv-mode", type=str, default="llava_v0")
    cli.add_argument("--num-chunks", type=int, default=1)
    cli.add_argument("--chunk-idx", type=int, default=0)
    cli.add_argument("--temperature", type=float, default=0.2)
    cli.add_argument("--answer-prompter", action="store_true")
    cli.add_argument("--single-pred-prompt", action="store_true")

    # Precision / quantization switches.
    cli.add_argument("--load-4bit", action="store_true", help="Load model in 4-bit quantization")
    cli.add_argument("--fp16", action="store_true", help="Use float16 for inference.")
    cli.add_argument("--bf16", action="store_true", help="Use bfloat16 for inference.")

    return cli.parse_args()


def _select_torch_dtype(cli_args):
    """Map the --fp16 / --bf16 flags to a torch dtype; float32 when neither is set."""
    if cli_args.fp16 and cli_args.bf16:
        raise ValueError("Cannot use both --fp16 and --bf16.")
    if cli_args.fp16:
        return torch.float16
    if cli_args.bf16:
        return torch.bfloat16
    return torch.float32


if __name__ == "__main__":
    cli_args = _parse_cli_args()
    # Determine torch dtype for inference, then run the evaluation.
    eval_model(cli_args, _select_torch_dtype(cli_args))