"""LLM configuration for different model types"""
import argparse
from typing import Dict, Any, List, Optional, Union
import base64
import json
from openai import OpenAI
import requests
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoProcessor
import torch
import sys
from qwen_vl_utils import process_vision_info
from PIL import Image
from io import BytesIO

class DirectVLLMModel:
    """Chat client for a vLLM server reached through the OpenAI-compatible API.

    The instance keeps the connection settings, an ``OpenAI`` client bound to
    ``server_url`` and the default sampling options used by :meth:`chat`.
    """

    def __init__(self, model_name: str, server_url: str, api_key: str = "EMPTY", **kwargs):
        """Record the connection settings and build the client.

        Args:
            model_name: Model identifier sent with every request.
            server_url: Base URL handed to the ``OpenAI`` client.
            api_key: Key handed to the ``OpenAI`` client.
            **kwargs: Optional ``temperature`` (default 0.2), ``top_p``
                (default 0.9) and ``max_tokens`` (default 2048).
        """
        self.model_name = model_name
        self.server_url = server_url
        self.api_key = api_key
        self.client = OpenAI(base_url=server_url, api_key=api_key)
        self.temperature = kwargs.get('temperature', 0.2)
        self.top_p = kwargs.get('top_p', 0.9)
        self.max_tokens = kwargs.get('max_tokens', 2048)

    def chat(self, messages: List[Dict], stream: bool = False, functions: List[Dict] = None, function_call: str = "auto", **kwargs):
        """Send ``messages`` to the model and return its reply.

        Args:
            messages: Chat messages forwarded unchanged to the client.
            stream: Forwarded to the client; also selects the return shape.
            functions: Accepted for signature compatibility; not forwarded.
            function_call: Accepted for signature compatibility; not forwarded.
            **kwargs: Per-call ``temperature``, ``top_p`` or ``max_tokens``
                overrides; anything missing falls back to the instance value.

        Returns:
            The raw client response when ``stream`` is truthy, otherwise the
            ``message`` of the first choice.
        """
        # Per-call overrides win over the defaults captured at construction.
        sampling = {
            option: kwargs.get(option, getattr(self, option))
            for option in ('temperature', 'top_p', 'max_tokens')
        }

        # Function calling is intentionally left out of the request.
        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            stream=stream,
            **sampling,
        )
        return completion if stream else completion.choices[0].message


class DirectOpenAIModel:
    """Chat wrapper built on the ``OpenAI`` client for a configurable endpoint."""

    def __init__(self, model_name: str, api_key: str, base_url: str = "https://api.openai.com/v1", **kwargs):
        """Keep the connection settings, create the client and store sampling defaults.

        Args:
            model_name: Model identifier sent with every request.
            api_key: Key handed to the ``OpenAI`` client.
            base_url: Endpoint handed to the ``OpenAI`` client.
            **kwargs: Optional ``temperature`` (default 0.2), ``top_p``
                (default 0.9) and ``max_tokens`` (default 2048).
        """
        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        for option, fallback in (('temperature', 0.2), ('top_p', 0.9), ('max_tokens', 2048)):
            setattr(self, option, kwargs.get(option, fallback))

    def chat(self, messages: List[Dict], stream: bool = False, functions: List[Dict] = None, function_call: str = "auto", **kwargs):
        """Run one chat completion and hand back the reply.

        Args:
            messages: Chat messages forwarded unchanged to the client.
            stream: Forwarded to the client; also selects the return shape.
            functions: Accepted for signature compatibility; not forwarded.
            function_call: Accepted for signature compatibility; not forwarded.
            **kwargs: Per-call ``temperature``, ``top_p`` or ``max_tokens``
                overrides; anything missing falls back to the instance value.

        Returns:
            The raw client response when ``stream`` is truthy, otherwise the
            ``message`` of the first choice.
        """
        request = dict(model=self.model_name, messages=messages, stream=stream)
        request["temperature"] = kwargs.get('temperature', self.temperature)
        request["top_p"] = kwargs.get('top_p', self.top_p)
        request["max_tokens"] = kwargs.get('max_tokens', self.max_tokens)

        # Function calling is intentionally not part of the request.
        result = self.client.chat.completions.create(**request)

        if not stream:
            return result.choices[0].message
        return result


class DirectTransformersModel:
    """Direct Transformers model wrapper for Qwen2.5-VL with experience handling.

    Generation runs locally through a custom
    ``Qwen2_5_VLForConditionalGeneration_new`` model when experience
    trajectories (action texts paired with screenshots) are supplied;
    otherwise :meth:`chat` delegates to a ``DirectVLLMModel``.
    """

    def __init__(self, model_name: str, **kwargs):
        """Load the processor, tokenizer and model weights.

        Args:
            model_name: Model identifier; also the default checkpoint path.
            **kwargs: ``checkpoint_path`` (defaults to ``model_name``) and
                ``max_memory`` (defaults to ``None``) are honoured.
        """
        self.model_name = model_name
        # NOTE(review): sampling settings are pinned here on purpose-looking
        # constants; ``temperature``, ``top_p`` and ``max_tokens`` passed in
        # ``kwargs`` are not consulted.
        self.temperature = 0.1
        self.top_p = 0.9
        self.max_tokens = 10**3
        self.checkpoint_path = kwargs.get('checkpoint_path', model_name)
        self.max_memory = kwargs.get('max_memory', None)

        # Load processor and tokenizer
        self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=True)
        self.tokenizer = self.processor.tokenizer

        # Import the custom model class; avoid growing sys.path with a
        # duplicate entry every time a model is instantiated.
        path = 'CoMEM-Agent-train'
        if path not in sys.path:
            sys.path.append(path)
        from src_agent.training.qwenVL_inference import Qwen2_5_VLForConditionalGeneration_new
        self.model = Qwen2_5_VLForConditionalGeneration_new.from_pretrained(
            self.checkpoint_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
            max_memory=self.max_memory,
            low_cpu_mem_usage=True
        )

    @staticmethod
    def _decode_image(payload):
        """Turn an image payload into a PIL image.

        Accepts an already-loaded PIL image (returned unchanged), a mapping
        with a ``'url'`` entry, a ``data:`` URL, or a bare base64 string.
        """
        if isinstance(payload, Image.Image):
            return payload
        if isinstance(payload, dict):
            payload = payload.get('url', '')
        if isinstance(payload, str) and payload.startswith('data:'):
            # Keep only the base64 body after the "data:<mime>;base64," header.
            payload = payload.partition(',')[2]
        return Image.open(BytesIO(base64.b64decode(payload)))

    @staticmethod
    def _has_experience(experience_texts, experience_images) -> bool:
        """Return True when some trajectory has both actions and screenshots.

        ``knowledge_processor_vlm`` pairs texts with images, so experience is
        only usable when both sides are present and non-empty.
        """
        if not experience_texts or not experience_images:
            return False
        return any(
            len(actions) > 0 and len(shots) > 0
            for actions, shots in zip(experience_texts, experience_images)
        )

    def process_vision_info(self, conversation):
        """Collect the images referenced by a conversation.

        Handles OpenAI-style ``image_url`` items as well as the ``image``
        items that :meth:`generate_response_with_experience` builds for its
        default conversation.

        Returns:
            A list of PIL images in conversation order (possibly empty).
        """
        image_inputs = []

        for message in conversation:
            content = message['content']
            if not isinstance(content, list):
                continue
            for item in content:
                item_type = item.get('type')
                if item_type == 'image_url':
                    image_inputs.append(self._decode_image(item['image_url']['url']))
                elif item_type == 'image' and item.get('image') is not None:
                    image_inputs.append(self._decode_image(item['image']))

        return image_inputs

    def knowledge_processor_vlm(self, processor, inputs, texts=None, images=None, tokenizer=None, formatted_prompt=None):
        """Attach tokenised experience trajectories to ``inputs``.

        Args:
            processor: Processor used to encode each trajectory.
            inputs: Mapping that receives the ``experience_*`` entries.
            texts: Per-trajectory lists of action strings; ``None`` means none.
            images: Per-trajectory lists of image payloads; ``None`` means none.
            tokenizer: Unused; kept for signature compatibility.
            formatted_prompt: Unused; kept for signature compatibility.

        Returns:
            ``inputs`` with ``experience_input_ids``,
            ``experience_pixel_values`` and ``experience_image_grid_thw`` set.
        """
        # Default tokens for image processing
        DEFAULT_IM_START_TOKEN = "<|im_start|>"
        DEFAULT_IM_END_TOKEN = "<|im_end|>"
        DEFAULT_IMAGE_TOKEN = "<|image_pad|>"
        VISION_START_TOKEN = "<|vision_start|>"
        VISION_END_TOKEN = "<|vision_end|>"

        all_experience_input_ids = []
        all_experience_pixel_values = []
        all_experience_image_grid_thw = []
        # ``or []`` keeps the documented ``None`` defaults from crashing zip().
        for trajectory_actions, trajectory_images in zip(texts or [], images or []):
            segments = []
            trajectory_image = []
            for action, image_payload in zip(trajectory_actions, trajectory_images):
                trajectory_image.append(self._decode_image(image_payload))
                segments.append(
                    f"{DEFAULT_IM_START_TOKEN}user\n{VISION_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{VISION_END_TOKEN}{action}{DEFAULT_IM_END_TOKEN}\n"
                )
            trajectory_text = "".join(segments)
            if trajectory_image:
                e_inputs = processor(text=[trajectory_text], images=trajectory_image, padding=False, return_tensors='pt')
                e_input_ids = e_inputs['input_ids'].squeeze(0)
                all_experience_pixel_values.append(e_inputs['pixel_values'])
                all_experience_image_grid_thw.append(e_inputs['image_grid_thw'])
            else:
                # No action/image pair in this trajectory: tokenise the (empty) text only.
                e_input_ids = processor.tokenizer(trajectory_text, add_special_tokens=False, padding=False, return_tensors='pt')['input_ids'].squeeze(0)

            all_experience_input_ids.append(e_input_ids)

        inputs['experience_input_ids'] = all_experience_input_ids
        inputs['experience_pixel_values'] = all_experience_pixel_values
        inputs['experience_image_grid_thw'] = all_experience_image_grid_thw

        return inputs

    def generate_response_with_experience(self, image=None, prompt=None, experience_texts=None, experience_images=None, conversation=None, experience_embedding=None):
        """Generate response with experience texts and images.

        Args:
            image: Image used when no ``conversation`` is given.
            prompt: Text used when no ``conversation`` is given.
            experience_texts: Per-trajectory lists of action strings.
            experience_images: Per-trajectory lists of image payloads.
            conversation: Chat messages; built from ``prompt``/``image`` when falsy.
            experience_embedding: Optional value passed to the model as
                ``experience_compress_embedding``.

        Returns:
            The decoded text of the first generated sequence.
        """
        import time

        if not conversation:
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image", "image": image}
                    ],
                }
            ]

        formatted_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
        image_inputs = self.process_vision_info(conversation)

        inputs = self.processor(
            text=[formatted_prompt],
            # The processor expects ``None`` rather than an empty list for text-only input.
            images=image_inputs or None,
            return_tensors="pt",
        ).to("cuda")

        # Process experience information
        inputs_with_experience = self.knowledge_processor_vlm(
            processor=self.processor,
            inputs=inputs,
            texts=experience_texts,
            images=experience_images,
            tokenizer=self.tokenizer,
            formatted_prompt=formatted_prompt
        ).to("cuda")

        if experience_embedding is not None:
            inputs_with_experience['experience_compress_embedding'] = experience_embedding

        time0 = time.time()

        generated_ids = self.model.generate(
            **inputs_with_experience,
            max_new_tokens=self.max_tokens,
            use_cache=True,
            temperature=self.temperature,
            top_p=self.top_p,
        )

        time1 = time.time()
        print('time for 1 sample generation:', time1-time0)
        print('generated_ids', generated_ids)
        # Drop the prompt tokens so only the newly generated continuation is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs_with_experience.input_ids, generated_ids)
        ]

        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]

    def chat(self, messages: List[Dict], stream: bool = False,
             experience_texts=None, experience_images=None):
        """Chat with the model using transformers with experience support.

        Args:
            messages: Chat messages.
            stream: Must be falsy; streaming is not implemented.
            experience_texts: Per-trajectory lists of action strings.
            experience_images: Per-trajectory lists of image payloads.

        Returns:
            The message returned by ``DirectVLLMModel.chat`` when no usable
            experience is supplied, otherwise a ``ChatCompletionMessage``
            carrying the locally generated text.

        Raises:
            NotImplementedError: If ``stream`` is truthy.
        """
        if stream:
            raise NotImplementedError("Streaming not yet implemented for transformers models")

        # Both sides are required: the text check must not be overwritten by
        # the image check, and text without images cannot be encoded.
        if not self._has_experience(experience_texts, experience_images):
            print("No experience data provided, falling back to DirectVLLMModel...")
            # Fall back to DirectVLLMModel when no experience data
            vllm_model = DirectVLLMModel(
                model_name='Qwen/Qwen2.5-VL-7B-Instruct',
                server_url='http://localhost:8000/v1',
                api_key="EMPTY",
                temperature=0.2,
                top_p=0.9,
                max_tokens=self.max_tokens
            )
            return vllm_model.chat(messages, stream=False)

        print("Generating response with experience...")
        # Generate response with experience
        response_text = self.generate_response_with_experience(
            experience_texts=experience_texts,
            experience_images=experience_images,
            conversation=messages
        )

        # Create OpenAI-style response
        from openai.types.chat import ChatCompletionMessage
        return ChatCompletionMessage(
            role="assistant",
            content=response_text,
            function_call=None,
            tool_calls=None
        )


def create_direct_vllm_model(args: argparse.Namespace, model_name: str = None) -> DirectVLLMModel:
    """Create a direct vLLM model instance pointed at the OpenRouter endpoint.

    Args:
        args: Parsed arguments; ``args.max_tokens`` is always read and
            ``args.model`` is read when ``model_name`` is falsy.
        model_name: Optional short alias or full model id. Known aliases are
            expanded through the table below; anything else is used verbatim.

    Returns:
        A ``DirectVLLMModel`` configured with temperature 0.2 and top_p 0.9.
    """
    import os  # local import: only this factory needs the environment lookup

    model_name_map = {
        'cogagent': 'zai-org/cogagent-9b-20241220',
        'qwen2.5-vl': 'qwen/qwen-2.5-vl-7b-instruct',
        'qwen2-vl': 'qwen/qwen-2-vl-7b-instruct',
        'ui-tars': 'bytedance/ui-tars-1.5-7b',
        'qwen2.5-vl-32b': 'qwen/qwen2.5-vl-32b-instruct',
        'gpt': 'openai/gpt-4o-2024-11-20',
        'gemini': 'google/gemini-2.5-pro',
        'claude': 'anthropic/claude-sonnet-4',
        'glm': 'thudm/glm-4.1v-9b-thinking'
    }

    requested = model_name if model_name else args.model
    model_name_ = model_name_map.get(requested, requested)
    server_url = 'https://openrouter.ai/api/v1'
    # Take the key from the environment instead of a blank literal; an unset
    # variable keeps the previous empty-string behaviour.
    api_key = os.environ.get('OPENROUTER_API_KEY', '')
    print('model_name_', model_name_)
    print('server_url', server_url)
    # Never echo the credential itself; only report whether one is present.
    print('api_key', '<set>' if api_key else '<empty>')

    return DirectVLLMModel(
        model_name=model_name_,
        server_url=server_url,
        api_key=api_key,
        temperature=0.2,
        top_p=0.9,
        max_tokens=args.max_tokens,
    )


def create_direct_openai_model(args: argparse.Namespace, model_name: str = None) -> DirectOpenAIModel:
    """Build a ``DirectOpenAIModel`` for the official OpenAI endpoint.

    ``model_name`` defaults to ``args.model``; the API key comes from
    ``args.openai_api_key`` when that attribute exists, otherwise ``None``.
    """
    chosen_model = args.model if model_name is None else model_name
    key = getattr(args, 'openai_api_key', None)
    return DirectOpenAIModel(
        model_name=chosen_model,
        api_key=key,
        base_url="https://api.openai.com/v1",
        temperature=0.2,
        top_p=0.9,
        max_tokens=args.max_tokens,
    )


def create_direct_transformers_model(args: argparse.Namespace, model_name: str = None) -> DirectTransformersModel:
    """Build a ``DirectTransformersModel`` from parsed arguments.

    The model alias (``model_name`` or, when that is ``None``, ``args.model``)
    is expanded through the alias table; unknown aliases are used as-is.
    ``checkpoint_path`` and ``max_memory`` are taken from ``args`` when those
    attributes exist, falling back to the resolved name and ``None``.
    """
    model_name_map = {
        'agent-qformer': '',
    }
    alias = args.model if model_name is None else model_name
    model_name = model_name_map.get(alias, alias)

    return DirectTransformersModel(
        model_name=model_name,
        checkpoint_path=getattr(args, 'checkpoint_path', model_name),
        temperature=0.1,
        top_p=0.001,
        max_tokens=args.max_tokens,
        max_memory=getattr(args, 'max_memory', None),
    )


def create_direct_model(args: argparse.Namespace):
    """Pick and build the model wrapper that matches ``args``.

    ``gpt-4o`` / ``gpt-4o-mini`` use the OpenAI wrapper; otherwise a truthy
    ``args.use_continuous_memory`` selects the Transformers wrapper, and
    everything else goes through the vLLM wrapper.
    """
    if args.model in ('gpt-4o', 'gpt-4o-mini'):
        return create_direct_openai_model(args)
    if args.use_continuous_memory:
        return create_direct_transformers_model(args)
    # Default to vLLM
    return create_direct_vllm_model(args)


def load_grounding_model_vllm(args: argparse.Namespace):
    """
    Build the OpenAI-compatible client used as the grounding model.

    Args:
        args: Arguments object (not read by this function)

    Returns:
        An ``OpenAI`` client whose base URL is the OpenRouter API endpoint
        and whose API key is the empty string.
    """
    from openai import OpenAI

    # The client constructor requires an api_key argument, so an empty one is supplied.
    return OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key="",
    )

def load_tool_llm(args: argparse.Namespace) -> DirectVLLMModel:
    """Return the tool LLM: the vLLM wrapper built for the ``qwen2.5-vl`` alias."""
    # Alternative kept from the original: create_direct_openai_model(args, model_name='gpt-4o-mini')
    return create_direct_vllm_model(args, model_name='qwen2.5-vl')

def load_openai_llm(args: argparse.Namespace) -> DirectOpenAIModel:
    """Load the OpenAI-backed LLM (``gpt-4o-mini``).

    Args:
        args: Parsed arguments forwarded to ``create_direct_openai_model``.

    Returns:
        The ``DirectOpenAIModel`` produced by ``create_direct_openai_model``.
    """
    openai_model = create_direct_openai_model(args, model_name='gpt-4o-mini')
    return openai_model
