import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from pickle import FALSE
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor, AutoModelForVision2Seq
from simple_input_processor import  load_and_process_existing_data, prepare_images_for_processor
import json
from peft import PeftModel
from tqdm import tqdm

class Qwen2VLInference:
    """Inference wrapper around a vision-to-sequence model with an optional LoRA adapter.

    The constructor loads the processor and model once; the remaining methods
    turn chat-style message lists (plus optional images) into generated text.
    """

    def __init__(self, model_path: str, lora_path: str = None):
        """Load the processor, the base model and, if requested, LoRA weights.

        Args:
            model_path: Path or identifier handed to ``from_pretrained`` for
                both the processor and the base model.
            lora_path: Optional path to LoRA adapter weights. When loading the
                adapter fails, the failure is printed and the base model is
                used for inference instead.

        Raises:
            RuntimeError: If loading the base model fails for a reason other
                than a "size mismatch" error.
        """
        self.model_path = model_path
        self.lora_path = lora_path

        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
        )

        try:
            self.model = AutoModelForVision2Seq.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )
            print("base model loaded successfully")
        except RuntimeError as e:
            if "size mismatch" not in str(e):
                # Not the recoverable case: re-raise with the original traceback.
                raise
            print(f"detected size mismatch error: {e}")
            print("trying to load with ignore_mismatched_sizes=True...")
            # The retry must actually pass the flag; otherwise it repeats the
            # exact call that just failed.
            self.model = AutoModelForVision2Seq.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
                ignore_mismatched_sizes=True,
            )
            print("base model loaded successfully (ignore size mismatch)")

        if lora_path:
            print(f"loading LoRA weights: {lora_path}")
            try:
                self.model = PeftModel.from_pretrained(
                    self.model,
                    lora_path,
                    device_map="auto",
                )
                print("LoRA weights loaded successfully")
            except Exception as e:
                # Deliberate best-effort: self.model is only rebound on
                # success, so it still refers to the base model here.
                print(f"LoRA weights loading failed: {e}")
                print("using base model for inference")

        # Put whichever model ended up loaded into eval mode, including the
        # base-model fallback path.
        self.model.eval()

        print("model loaded successfully!")

    def inference_single(self, messages: list, images: list = None, max_new_tokens: int = 128):
        """Generate one response for a single conversation.

        Args:
            messages: List of dicts with ``"role"`` and ``"content"`` keys,
                formatted by ``_format_messages``.
            images: Optional image inputs, forwarded to
                ``prepare_images_for_processor`` before being given to the
                processor. Falsy values mean text-only input.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded continuation (prompt tokens removed, special tokens
            skipped) with surrounding whitespace stripped.
        """
        image_objects = prepare_images_for_processor(images) if images else None

        text = self._format_messages(messages)

        processor_kwargs = {"text": [text], "return_tensors": "pt"}
        if image_objects:
            processor_kwargs["images"] = image_objects
        inputs = self.processor(**processor_kwargs)

        # Move every processor output onto the model's device.
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=self.processor.tokenizer.eos_token_id,
            )

        # generate() returns prompt + continuation; keep only the continuation.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs["input_ids"], generated_ids)
        ]

        response = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response.strip()

    def _format_messages(self, messages: list) -> str:
        """Render chat messages into the ``<|im_start|>``/``<|im_end|>`` prompt format.

        A default system turn is prepended when the first message is not a
        system message. Every ``<image>`` placeholder in a message's content is
        replaced by ``<|vision_start|><|image_pad|><|vision_end|>``. The result
        always ends with an open assistant turn.

        Args:
            messages: List of dicts with ``"role"`` and ``"content"`` keys.

        Returns:
            The formatted prompt string. For an empty list this is just the
            open assistant turn.
        """
        parts = []

        for i, message in enumerate(messages):
            role = message["role"]
            content = message["content"]

            if i == 0 and role != "system":
                parts.append("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n")

            if "<image>" in content:
                content = content.replace("<image>", "<|vision_start|><|image_pad|><|vision_end|>")

            parts.append(f"<|im_start|>{role}\n{content}<|im_end|>\n")

        parts.append("<|im_start|>assistant\n")

        return "".join(parts)

    def test_webshop_case(self, instruction: str, observation: str, image_path: str):
        """Build one WebShop test case and run inference on it.

        Args:
            instruction: Forwarded to ``create_webshop_test_case``.
            observation: Forwarded to ``create_webshop_test_case``.
            image_path: Forwarded to ``create_webshop_test_case``.

        Returns:
            The model response produced by ``inference_single``.
        """
        from simple_input_processor import create_webshop_test_case

        test_case = create_webshop_test_case(
            instruction=instruction,
            observation=observation,
            image_path=image_path,
        )

        case_input = test_case["input"]
        return self.inference_single(case_input["messages"], case_input["images"])

    def batch_inference(self, test_cases: list):
        """Run inference over many test cases, recording failures per case.

        Args:
            test_cases: List of dicts, each with an ``"input"`` dict holding
                ``"messages"`` and optionally ``"images"``, and optionally an
                ``"expected_action"`` entry.

        Returns:
            One result dict per case with keys ``"case_id"``, ``"input"``,
            ``"response"`` and ``"expected"``. When a case fails, its
            ``"response"`` is the string ``"Error: <exception>"``.
        """
        results = []

        for i, case in enumerate(tqdm(test_cases, desc="batch inference progress", unit="case")):
            try:
                case_input = case["input"]
                # "images" is optional: inference_single treats None as text-only.
                response = self.inference_single(
                    case_input["messages"],
                    case_input.get("images"),
                )
            except Exception as e:
                # Best-effort batch: one bad case is recorded, not fatal.
                tqdm.write(f"error in processing the {i+1}th test case: {e}")
                response = f"Error: {e}"

            results.append({
                "case_id": i,
                "input": case,
                "response": response,
                "expected": case.get("expected_action"),
            })

        return results

# Build the inference wrapper from the base model and the LoRA adapter.
# Both paths are placeholders and must point at real checkpoints.
inference = Qwen2VLInference(
    model_path="path/to/Qwen2.5-VL-7B-Instruct",
    lora_path="path/to/cotri-vl-model",
)

print("=" * 60)

# Load the evaluation cases.
# NOTE(review): max_messages / max_images presumably cap what each case
# keeps; confirm against simple_input_processor.
test_cases = load_and_process_existing_data(
    "path/to/webshop_3rd_qwen_eval_data_copier.json",
    max_messages=7,
    max_images=3,
)

# Full slice: every loaded case is evaluated. Narrow the slice to run a subset.
sample_cases = test_cases[:]

results = inference.batch_inference(sample_cases)

# Persist every per-case result as pretty-printed UTF-8 JSON.
with open("inference_results.json", "w", encoding="utf-8") as out_file:
    json.dump(results, out_file, ensure_ascii=False, indent=2)

print(f"inference completed, results saved to inference_results.json")