import sys

from transformers.agents.llm_engine import MessageRole, HfApiEngine, get_clean_message_list
from sport_agent.utils import load_config
import re

import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor

from qwen_vl_utils import process_vision_info
from openai import OpenAI

import torch

def load_pretrained_model(model_name):
    """Load a pretrained checkpoint together with its text front-end.

    The torch RNG is seeded with 0 before loading. Names containing "VL" are
    loaded as Qwen2-VL with flash-attention 2 and paired with an
    ``AutoProcessor``; every other name is loaded as a causal LM and paired
    with an ``AutoTokenizer``.

    Args:
        model_name: Identifier or path handed to ``from_pretrained``.

    Returns:
        ``(model, processor)`` for VL names, ``(model, tokenizer)`` otherwise.
    """
    torch.manual_seed(0)
    print("from pretrained", model_name)

    # Text-only checkpoints take the short path.
    if "VL" not in model_name:
        causal_lm = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto"
        )
        text_tokenizer = AutoTokenizer.from_pretrained(model_name)
        return causal_lm, text_tokenizer

    vision_lm = Qwen2VLForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
    vision_processor = AutoProcessor.from_pretrained(model_name)
    return vision_lm, vision_processor

def load_client_model(endpoint):
    """Create an OpenAI-compatible client for the server at ``endpoint``.

    The base URL is ``http://<endpoint>:8000/v1`` and the API key is the
    literal string "EMPTY".

    Args:
        endpoint: Host name or address of the serving endpoint.

    Returns:
        The configured ``OpenAI`` client.
    """
    return OpenAI(
        api_key="EMPTY",
        base_url=f"http://{endpoint}:8000/v1",
    )

class ModelSingleton():
    """Class-level cache of loaded models, keyed by model name.

    "Instantiating" this class never creates an instance: ``__new__`` makes
    sure ``cls.model_dict[model_name]`` holds a ``(model, tokenizer)`` pair
    and then returns the class itself, so callers read the cache through
    ``ModelSingleton(name).model_dict[name]``.

    NOTE(review): the cache key is ``model_name`` only. A second call with
    the same name but a different ``lora_path`` returns the first entry.
    """

    def __new__(cls, model_name, lora_path=None):
        """Populate the cache entry for ``model_name`` if it is missing.

        Args:
            model_name: Model identifier. Names containing "VL" are loaded
                locally; any other name gets an OpenAI-compatible client
                built from ``load_config().qwen.endpoint``.
            lora_path: Optional LoRA adapter to merge into a VL model.

        Returns:
            The class object itself (not an instance).
        """
        if not hasattr(cls, "model_dict"):
            cls.model_dict = dict()

        if model_name in cls.model_dict:
            return cls

        if "VL" in model_name:
            model, tokenizer = load_pretrained_model(model_name)
            if lora_path is not None:
                print("Load Qwen-VL from lora", lora_path)
                import time
                from peft.peft_model import PeftModel
                # NOTE(review): this pause is kept from the original code;
                # its purpose is not evident from this file.
                time.sleep(10)
                model = PeftModel.from_pretrained(model, lora_path)
                # merge_and_unload() returns the merged base model; keep that
                # result instead of the PEFT wrapper.
                model = model.merge_and_unload()
        else:
            config = load_config()
            model = load_client_model(config.qwen.endpoint)
            # Remote models are reached through the client; no tokenizer.
            tokenizer = None

        cls.model_dict[model_name] = (model, tokenizer)
        return cls

# Role remapping passed to get_clean_message_list: tool responses are
# relabelled as user turns. A SYSTEM -> USER remapping is not applied.
openai_role_conversions = {MessageRole.TOOL_RESPONSE: MessageRole.USER}

from typing import List, Optional, Union


class QwenEngine(HfApiEngine):
    """Agent engine backed by a Qwen model.

    Model names containing "VL" use the locally loaded vision-language model
    through ``call_vlm``; every other name goes through the OpenAI-compatible
    client cached by ``ModelSingleton`` via ``call_llm``.
    """

    def __init__(self, model_name: str = "", lora_path: Optional[str] = None, beam_size=5):
        """Fetch (or load) the model for ``model_name`` and record its kind.

        Args:
            model_name: Model identifier; "VL" in the name selects the
                vision-language path.
            lora_path: Optional LoRA adapter path forwarded to ``ModelSingleton``.
            beam_size: Stored on the instance. ``__call__`` reads its own
                ``beam_size`` keyword and does not consult this attribute.

        NOTE(review): ``HfApiEngine.__init__`` is not invoked here, matching
        the original implementation.
        """
        module = ModelSingleton(model_name, lora_path)
        model, tokenizer = module.model_dict[model_name]
        self.has_vision = 'VL' in model_name
        # For VL models the cached "tokenizer" slot holds the processor.
        self.processor = tokenizer if self.has_vision else None
        self.model, self.tokenizer = model, tokenizer
        self.model_name = model_name
        self.beam_size = beam_size

    def call_llm(self, messages, stop_sequences=[], *args, **kwargs):
        """Send ``messages`` to the chat-completions endpoint.

        Three completions are requested (``n=3``).

        Returns:
            The raw response object from the client.

        Raises:
            AssertionError: If this engine wraps a VL model.
        """
        assert not self.has_vision, "Should use this function with Qwen LLM"

        return self.model.chat.completions.create(
            model=self.model_name,
            messages=messages,
            stop=stop_sequences,
            n=3,
        )

    @staticmethod
    def _attach_images(messages, image_paths, limit_pixels):
        """Turn the first user message into multimodal (images + text) content, in place."""
        for msg_id, msg in enumerate(messages):
            if msg["role"] != "user":
                continue
            parts = []
            for image_path in image_paths:
                part = {"type": "image", "image": image_path}
                if limit_pixels:
                    part["min_pixels"] = 100 * 28 * 28
                    part["max_pixels"] = 512 * 28 * 28
                parts.append(part)
            parts.append({"type": "text", "text": msg["content"]})
            messages[msg_id] = {"role": "user", "content": parts}
            # Only the first user turn carries the images.
            break

    @staticmethod
    def _truncate_at_stops(answer, stop_sequences):
        """Cut ``answer`` at the earliest occurrence of any non-empty stop sequence."""
        for stop in stop_sequences:
            if not stop:
                # An empty stop string would match at index 0 and erase the answer.
                continue
            stop_idx = answer.find(stop)
            if stop_idx != -1:
                answer = answer[:stop_idx]
        return answer

    def call_vlm(self, messages, stop_sequences=[], *args, **kwargs):
        """Generate ``beam_size`` sampled completions with the VL model.

        The first user message in ``messages`` is rewritten in place so its
        content is the list of images followed by the original text. Pixel
        bounds are attached to the images only when several completions are
        requested for several images.

        Keyword Args:
            image_paths: Images to attach to the first user message.
            beam_size: Number of sequences to sample (default 1).

        Returns:
            A list with ``beam_size`` decoded strings, prompt tokens removed.
            ``stop_sequences`` is not applied here.

        Raises:
            AssertionError: If this engine does not wrap a VL model.
            ValueError: If ``beam_size`` is smaller than 1.
        """
        print("call vlm")
        assert self.has_vision, "Should use this function with Qwen VL model"
        image_paths = kwargs.get("image_paths", []) or []
        beam_size = kwargs.get("beam_size", 1)
        if beam_size < 1:
            raise ValueError(f"beam_size must be >= 1, got {beam_size}")

        limit_pixels = beam_size > 1 and len(image_paths) > 1
        self._attach_images(messages, image_paths, limit_pixels)

        print("msg=", messages)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        if beam_size > 1:
            # Higher temperature so the returned samples differ from each other.
            sampling = {"temperature": 1.2, "top_p": 0.8, "top_k": 100}
        else:
            sampling = {
                "temperature": 0.7,
                "top_p": 0.8,
                "top_k": 100,
                "repetition_penalty": 1.05,
            }
        generated_ids = self.model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            num_return_sequences=beam_size,
            **sampling,
        )

        # A single prompt is encoded, so every returned sequence starts with
        # the same prompt_length prompt tokens.
        prompt_length = inputs.input_ids.shape[1]
        generated_ids_trimmed = [out_ids[prompt_length:] for out_ids in generated_ids]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        print("output_text=", output_text)

        return output_text

    def __call__(self, messages, stop_sequences=[], *args, **kwargs) -> Union[str, List[str]]:
        """Run the model on ``messages`` and return stop-trimmed text.

        Messages are cleaned with ``get_clean_message_list``, converted to
        plain dicts whose role is "system", "user" or "assistant", and sent
        to the LLM or VLM back end.

        Keyword Args:
            image_paths: Images for the VL model (default: none).
            beam_size: Number of completions wanted (default 1).

        Returns:
            The first answer as a string when ``beam_size == 1``, otherwise
            the list of all answers.
        """
        torch.cuda.empty_cache()
        image_paths = kwargs.get("image_paths", [])
        beam_size = kwargs.get("beam_size", 1)
        messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)

        # Plain string roles, so neither the chat template nor the HTTP
        # client ever sees MessageRole enum members.
        msgs = []
        for msg in messages:
            if msg["role"] == MessageRole.SYSTEM:
                role = "system"
            elif msg["role"] == MessageRole.USER:
                role = "user"
            else:
                role = "assistant"
            msgs.append({"role": role, "content": msg["content"]})

        if not self.has_vision:
            response = self.call_llm(msgs, stop_sequences=stop_sequences)
            # call_llm returns the raw response; pull out the completion texts.
            answers = [choice.message.content or "" for choice in response.choices]
        else:
            answers = self.call_vlm(
                msgs, stop_sequences=stop_sequences, image_paths=image_paths, beam_size=beam_size
            )
        print(answers)

        # Exactly one output per answer, whatever number of stops match.
        new_answers = [self._truncate_at_stops(answer, stop_sequences) for answer in answers]

        if beam_size == 1:
            return new_answers[0]
        return new_answers
