import yaml
import logging
from datetime import datetime
import re
import numpy as np
import os
import sys
import base64
import io
from PIL import Image
import torch
import time
from OceanGym.utils.config_process import load_config
# Append the directory containing this file to sys.path so that modules
# located next to it can be resolved by the import system.
current_dir = os.path.dirname(__file__)
sys.path.append(current_dir)


def get_llm_mode():
    """Resolve which LLM backend to use from the loaded configuration.

    Returns:
        A ``(mode, config)`` tuple. ``mode`` is ``"local"`` only when the
        configuration asks for local mode *and* the configured model path
        exists on disk; otherwise it is ``"api"``. ``config`` is the object
        returned by ``load_config()``, or ``None`` if anything raised while
        loading or inspecting it.
    """
    try:
        settings = load_config()
        llm_section = settings.get("llm", {})

        # Anything other than an explicit "local" request means API mode.
        if llm_section.get("mode", "api") != "local":
            print("🌐 API mode")
            return "api", settings

        # Local mode is only honoured when the model path really exists.
        weights_dir = llm_section.get("local", {}).get("path", "")
        if not weights_dir or not os.path.exists(weights_dir):
            print(f"⚠️ Local model path does not exist: {weights_dir}, falling back to API mode")
            return "api", settings

        print(f"🔧 Local mode: {weights_dir}")
        return "local", settings

    except Exception as err:
        print(f"❌ Configuration loading failed: {err}, using default API mode")
        return "api", None

# Initialize LLM mode
# Resolved once, at import time: get_llm_mode() loads the configuration and
# prints the selected mode. CONFIG is None when configuration loading failed.
LLM_MODE, CONFIG = get_llm_mode()

# Global variables
# Lazily created singletons, filled in on first use by _get_local_service()
# and _get_api_client() respectively.
_qwen_service = None
_api_client = None

class QwenVLService:
    """Local inference service wrapping a Qwen2.5-VL checkpoint.

    The model and processor are loaded from ``model_path`` by
    :meth:`load_model` (called automatically on first use), and
    :meth:`generate_response` answers a text prompt accompanied by optional
    base64-encoded images.
    """

    def __init__(self, model_path, device="cuda", max_gpu_memory="15GB"):
        """Store settings; nothing is loaded until :meth:`load_model` runs.

        Args:
            model_path: Directory holding the checkpoint (``config.json``,
                weights, processor files).
            device: Preferred device string. After loading it is replaced by
                the device the model actually ended up on.
            max_gpu_memory: Memory cap for GPU 0, forwarded as
                ``max_memory={0: max_gpu_memory}`` to the ``device_map="auto"``
                loading strategies.
        """
        self.model_path = model_path
        self.device = device
        self.max_gpu_memory = max_gpu_memory
        self.model = None
        self.processor = None
        self.is_loaded = False

    def load_model(self):
        """Load the model and processor, trying several strategies in turn.

        On success ``is_loaded`` becomes True and ``device`` reflects where
        the model lives. On failure ``is_loaded`` is set to False, the
        traceback is printed and the exception is re-raised.

        Raises:
            Exception: Whatever the last loading strategy (or the
                environment/config inspection) raised.
        """
        import traceback

        try:
            print(f"🔄 Loading Qwen2.5-VL model from: {self.model_path}")
            model_type = self._report_environment()

            print("🔄 Starting model loading...")
            self._load_with_fallbacks(model_type)

            # device_map="auto" decides placement itself, so trust the model
            # rather than the constructor argument.
            self.device = self._model_device()
            self.is_loaded = True
            print(f"✅ Qwen2.5-VL model loaded successfully on {self.device}")

            # Check if model has generate method
            if hasattr(self.model, 'generate'):
                print("✅ Model has generate method")
            else:
                print("❌ Model missing generate method")

            # Check final memory usage
            if torch.cuda.is_available():
                final_memory = torch.cuda.memory_allocated(0) / 1024**3
                print(f"🎯 Final GPU memory usage: {final_memory:.1f} GB")

        except Exception as e:
            print(f"❌ All loading methods failed: {e}")
            traceback.print_exc()
            self.is_loaded = False
            raise

    def _report_environment(self):
        """Print GPU/transformers/checkpoint diagnostics.

        Returns:
            The ``model_type`` value from the checkpoint's ``config.json``
            (``'unknown'`` if the key is absent), or ``None`` when the file
            does not exist.
        """
        import json
        import transformers

        # Check system resources
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
            print(f"🎯 Available GPU memory: {gpu_memory:.1f} GB")
            print(f"🎯 Current GPU memory usage: {torch.cuda.memory_allocated(0) / 1024**3:.1f} GB")
        else:
            print("❌ CUDA not available!")

        print(f"📦 Transformers version: {transformers.__version__}")

        # Check model files
        config_file = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_file):
            print(f"❌ Model config not found: {config_file}")
            return None

        print(f"✅ Model config found: {config_file}")
        with open(config_file, 'r', encoding='utf-8') as f:
            config_data = json.load(f)
        model_type = config_data.get('model_type', 'unknown')
        print(f"📋 Model type: {model_type}")
        return model_type

    def _load_pretrained(self, model_cls, **placement_kwargs):
        """Load model via ``model_cls.from_pretrained`` plus the AutoProcessor."""
        from transformers import AutoProcessor

        # Drop references from a previous failed attempt before loading again.
        self.model = None
        self.processor = None

        self.model = model_cls.from_pretrained(
            self.model_path,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            **placement_kwargs,
        )
        self.processor = AutoProcessor.from_pretrained(
            self.model_path,
            trust_remote_code=True,
        )

    def _load_with_fallbacks(self, model_type):
        """Try each loading strategy in order; the last one may raise."""
        auto_placement = {
            "device_map": "auto",
            "max_memory": {0: self.max_gpu_memory},
        }

        # Qwen2.5-VL checkpoints declare model_type "qwen2_5_vl" and have a
        # dedicated class; prefer it when the checkpoint says so. Any failure
        # (including an older transformers without the class) falls through
        # to the original strategy chain below.
        if model_type == "qwen2_5_vl":
            try:
                from transformers import Qwen2_5_VLForConditionalGeneration
                print("📦 Using Qwen2_5_VLForConditionalGeneration")
                self._load_pretrained(Qwen2_5_VLForConditionalGeneration, **auto_placement)
                print("✅ Qwen2_5_VLForConditionalGeneration successful")
                return
            except Exception as e0:
                print(f"⚠️ Qwen2_5_VLForConditionalGeneration failed: {e0}")

        try:
            # Method 1: Try using Qwen2VLForConditionalGeneration
            from transformers import Qwen2VLForConditionalGeneration
            print("📦 Using Qwen2VLForConditionalGeneration")
            self._load_pretrained(Qwen2VLForConditionalGeneration, **auto_placement)
            print("✅ Method 1 (Qwen2VLForConditionalGeneration) successful")
            return
        except Exception as e1:
            print(f"⚠️ Method 1 failed: {e1}")

        try:
            # Method 2: Try AutoModelForCausalLM
            from transformers import AutoModelForCausalLM
            print("📦 Using AutoModelForCausalLM")
            self._load_pretrained(AutoModelForCausalLM, **auto_placement)
            print("✅ Method 2 (AutoModelForCausalLM) successful")
            return
        except Exception as e2:
            print(f"⚠️ Method 2 failed: {e2}")

        # Method 3: conservative approach - load on CPU, then move if possible.
        print("🔄 Trying conservative loading...")
        from transformers import AutoModelForCausalLM
        self._load_pretrained(AutoModelForCausalLM, device_map="cpu")

        print("✅ Model loaded to CPU, moving to GPU...")
        if torch.cuda.is_available():
            self.model = self.model.to("cuda")
            self.device = "cuda"
        else:
            self.device = "cpu"
        print("✅ Conservative loading successful")

    def _model_device(self):
        """Return the device string on which input tensors should be placed.

        Uses the loaded model's ``device`` attribute when it has a usable one;
        otherwise falls back to the configured ``self.device``.
        """
        model_device = getattr(self.model, "device", None)
        # "meta" holds no data, so tensors cannot be moved there.
        if model_device is None or str(model_device) == "meta":
            return self.device
        return str(model_device)

    @staticmethod
    def _decode_images(b64_image_lst):
        """Decode base64 strings into RGB PIL images, skipping undecodable ones."""
        images = []
        for position, b64_img in enumerate(b64_image_lst, start=1):
            try:
                img_data = base64.b64decode(b64_img)
                img = Image.open(io.BytesIO(img_data)).convert("RGB")
            except Exception as e:
                # Best effort: a bad image is reported and left out.
                print(f"❌ Failed to process image {position}: {e}")
                continue
            images.append(img)
            print(f"✅ Image {position} processed: {img.size}")
        return images

    @staticmethod
    def _build_messages(prompt, images):
        """Build the single-turn chat message list for the processor template."""
        if not images:
            return [{"role": "user", "content": prompt}]
        content = [{"type": "text", "text": prompt}]
        content.extend({"type": "image", "image": img} for img in images)
        return [{"role": "user", "content": content}]

    def generate_response(self, prompt, b64_image_lst, max_new_tokens=512):
        """Generate a text answer for ``prompt`` and optional images.

        Args:
            prompt: User text.
            b64_image_lst: Iterable of base64-encoded images; ``None`` is
                treated as no images. Images that fail to decode are skipped.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The decoded, stripped text of the first generated sequence.

        Raises:
            Exception: Any error from loading, preprocessing or generation is
                printed with its traceback and re-raised.
        """
        import traceback

        try:
            if not self.is_loaded:
                self.load_model()

            if b64_image_lst is None:
                b64_image_lst = []

            print(f"🤖 Generating response with {len(b64_image_lst)} images...")

            images = self._decode_images(b64_image_lst)
            messages = self._build_messages(prompt, images)

            print("📝 Preparing input...")

            # Apply chat template
            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

            # Process inputs
            inputs = self.processor(
                text=[text],
                images=images if images else None,
                padding=True,
                return_tensors="pt",
            )
            # Follow the model's real placement, not the requested device.
            inputs = inputs.to(self._model_device())

            print("🚀 Generating response...")

            tokenizer = getattr(self.processor, 'tokenizer', None)
            pad_token_id = tokenizer.eos_token_id if tokenizer is not None else None

            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.7,
                    pad_token_id=pad_token_id,
                )

            # Keep only the newly generated tokens (drop the echoed prompt).
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]

            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )

            result = output_text[0].strip()
            print(f"✅ Response generated: {len(result)} characters")
            return result

        except Exception as e:
            print(f"❌ Qwen2.5-VL generation failed: {e}")
            traceback.print_exc()
            raise

def _get_local_service():
    """Return the shared QwenVLService, creating and loading it on first call.

    The instance is stored in the module-level ``_qwen_service`` before
    ``load_model()`` runs, so if loading raises, the (unloaded) instance stays
    cached and later calls return it without retrying.

    Raises:
        ValueError: When ``CONFIG`` is empty or missing.
        Exception: Any error from constructing or loading the service is
            printed and re-raised.
    """
    global _qwen_service

    # Fast path: already created.
    if _qwen_service is not None:
        return _qwen_service

    try:
        if not CONFIG:
            raise ValueError("No local configuration found")

        local_settings = CONFIG.get("llm", {}).get("local", {})
        weights_dir = local_settings.get("path", "")
        target_device = local_settings.get("device", "cuda")

        _qwen_service = QwenVLService(weights_dir, target_device)
        _qwen_service.load_model()
        print("✅ Local Qwen model loaded successfully")

    except Exception as err:
        print(f"❌ Local model loading failed: {err}")
        raise err

    return _qwen_service

def _get_api_client():
    """Return the shared API client info, creating it on first call.

    Credentials are resolved in this order:

    1. ``CONFIG["llm"]["api"]`` (keys ``api_key``, ``base_url``, ``model``);
    2. ``CONFIG["defaults"]["llm"]`` when the section above is empty;
    3. when no configuration was loaded at all, the ``OPENAI_API_KEY`` and
       ``OPENAI_BASE_URL`` environment variables.

    Returns:
        A dict with ``'client'`` (an ``openai.OpenAI`` instance) and
        ``'model'`` (the model name to request).

    Raises:
        KeyError: If the ``defaults`` fallback section is incomplete.
        RuntimeError: If no configuration is available and
            ``OPENAI_API_KEY`` is not set.
    """
    global _api_client
    if _api_client is None:
        from openai import OpenAI

        if CONFIG:
            # Prioritize new llm configuration
            api_config = CONFIG.get("llm", {}).get("api", {})
            if api_config:
                api_key = api_config.get("api_key")
                base_url = api_config.get("base_url")
                model_name = api_config.get("model")
            else:
                # Fall back to defaults configuration
                legacy = CONFIG["defaults"]["llm"]
                api_key = legacy["api_key"]
                base_url = legacy["base_url"]
                model_name = legacy["model"]
        else:
            # SECURITY: an API key used to be hard-coded here. Secrets must
            # not live in source control; the old key should be considered
            # leaked and rotated. Read the credential from the environment.
            api_key = os.environ.get("OPENAI_API_KEY")
            if not api_key:
                raise RuntimeError(
                    "No LLM configuration could be loaded and OPENAI_API_KEY "
                    "is not set; cannot initialize the API client"
                )
            base_url = os.environ.get("OPENAI_BASE_URL", "https://www.dmxapi.cn/v1")
            model_name = "gpt-4o-mini"

        _api_client = {
            'client': OpenAI(api_key=api_key, base_url=base_url),
            'model': model_name
        }
        print("✅ API client initialized successfully")
    return _api_client

def ask_llm(prompt, b64_image_lst, max_retries=3):
    """
    Query the configured LLM backend and return its answer as a string.

    In local mode the request goes to the local Qwen service; if anything in
    that path raises, the call is transparently retried through the API
    backend. In every other mode the API backend is used directly.
    ``max_retries`` is forwarded to ``ask_llm_api``.
    """
    if LLM_MODE != "local":
        print("🌐 API mode")
        return ask_llm_api(prompt, b64_image_lst, max_retries)

    try:
        local_service = _get_local_service()
        if not local_service.is_loaded:
            raise RuntimeError("Local model not properly loaded")

        print("🏠 Local mode")
        answer = local_service.generate_response(prompt, b64_image_lst, max_new_tokens=512)

        if isinstance(answer, str):
            return answer.strip()
        # Tolerate a future local model that returns a response object.
        if hasattr(answer, "choices"):
            return answer.choices[0].message.content.strip()
        return str(answer)

    except Exception as err:
        print(f"❌ Local model call failed: {err}")
        print("🔄 Falling back to API mode...")
        return ask_llm_api(prompt, b64_image_lst, max_retries)

def ask_llm_api(prompt, b64_image_lst, max_retries=3):
    """Send the prompt and images to the chat-completions API; return text.

    Args:
        prompt: User text.
        b64_image_lst: Iterable of base64-encoded PNG images, each sent as a
            ``data:image/png;base64,...`` URL. ``None`` means no images.
        max_retries: Number of retries after the first attempt. Negative
            values are treated as 0, so at least one attempt is always made.

    Returns:
        The stripped reply text. A reply whose message content is ``None``
        yields an empty string.

    Raises:
        Exception: The error from the final attempt, once all retries with
            exponential back-off (1s, 2s, 4s, ...) are exhausted.
    """
    api_info = _get_api_client()
    client = api_info['client']
    model_name = api_info['model']

    # The request payload does not change between attempts; build it once.
    content_lst = [{"type": "text", "text": prompt}]
    content_lst.extend(
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image}"},
        }
        for image in (b64_image_lst or [])
    )
    messages = [{"role": "user", "content": content_lst}]

    total_attempts = max(max_retries, 0) + 1
    for attempt in range(total_attempts):
        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                max_tokens=1024
            )
            # Accept either a plain string or a response object.
            if isinstance(response, str):
                return response.strip()
            if hasattr(response, "choices"):
                # content may be None; honour the "always a string" contract
                # instead of failing on None.strip().
                return (response.choices[0].message.content or "").strip()
            return str(response)
        except Exception as e:
            if attempt < total_attempts - 1:
                delay = 2 ** attempt
                logging.warning("API call failed, retrying in %s seconds: %s", delay, e)
                time.sleep(delay)
            else:
                raise

def found_target(target_info, marker="@@@"):
    """Report whether an LLM reply contains the target-found marker.

    Args:
        target_info: The reply to inspect. ``None`` is treated as "not found"
            rather than raising.
        marker: The sentinel to look for; defaults to ``"@@@"``.

    Returns:
        True if ``marker`` occurs in ``target_info``, else False.
    """
    if target_info is None:
        return False
    return marker in target_info

def get_llm_info(config):
    """Derive ``(mode, model_name, log_name)`` from a configuration mapping.

    Local mode is reported only when ``config["llm"]["mode"]`` is ``"local"``
    and ``config["llm"]["local"]["path"]`` exists on disk. Otherwise the API
    model name is taken from ``config["llm"]["api"]["model"]`` or, when that
    section is empty, from ``config["defaults"]["llm"]["model"]``, defaulting
    to ``"gpt-4o-mini"``. Any exception (for example ``config`` being
    ``None``) is printed and results in the API defaults.
    """
    fallback_model = "gpt-4o-mini"

    try:
        print(f"🔍 Debug: config structure = {config.keys()}")

        llm_section = config.get("llm", {})
        print(f"🔍 Debug: llm_config = {llm_section}")

        selected_mode = llm_section.get("mode", "api")
        print(f"🔍 Debug: detected mode = {selected_mode}")

        if selected_mode == "local":
            # Local mode only counts when the model directory is present.
            weights_dir = llm_section.get("local", {}).get("path", "")
            if not (weights_dir and os.path.exists(weights_dir)):
                print(f"⚠️ Local model path does not exist: {weights_dir}, falling back to API mode")
            else:
                model_name, log_name = "qwen2.5-vl", "qwen2.5-vl-local"
                print(f"✅ Returning local mode info: {selected_mode}, {model_name}, {log_name}")
                return selected_mode, model_name, log_name

        # API mode (requested explicitly, or reached as a fallback).
        api_section = llm_section.get("api", {})
        if api_section:
            model_source = api_section
        else:
            model_source = config.get("defaults", {}).get("llm", {})
        model_name = model_source.get("model", fallback_model)

        log_name = f"{model_name}-api"
        print(f"✅ Returning API mode info: api, {model_name}, {log_name}")
        return "api", model_name, log_name

    except Exception as err:
        import traceback
        print(f"❌ Failed to get LLM info: {err}")
        traceback.print_exc()
        return "api", fallback_model, f"{fallback_model}-api"

def get_llm_status():
    """Return a dict describing the current LLM setup.

    Keys:
        mode: The value of ``LLM_MODE``.
        use_local: True when ``LLM_MODE`` is ``"local"``.
        config_loaded: True when ``CONFIG`` is not ``None``.
        model_loaded: Present only in local mode; whether the local service
            reports itself as loaded.

    Note:
        In local mode this calls ``_get_local_service()``, which creates and
        loads the model if it has not been created yet.
    """
    status = {
        "mode": LLM_MODE,
        "use_local": LLM_MODE == "local",
        "config_loaded": CONFIG is not None
    }

    if LLM_MODE == "local":
        try:
            service = _get_local_service()
            status["model_loaded"] = service.is_loaded if service else False
        except Exception:
            # Only ordinary errors mean "not loaded"; KeyboardInterrupt and
            # SystemExit must propagate (loading can take a long time and the
            # user may want to abort it).
            status["model_loaded"] = False

    return status

def test_llm_integration():
    """Run a smoke test of the LLM pipeline and the target-marker check.

    Prints the current status, sends one short prompt through ``ask_llm`` and
    exercises ``found_target`` on a canned reply.

    Returns:
        True when every step completed without raising, False otherwise (the
        error and its traceback are printed).
    """
    import traceback

    try:
        print("=" * 60)
        print("🧪 Testing LLM integration")

        # Report how the module is currently configured.
        current_status = get_llm_status()
        print(f"📊 Current status: {current_status}")

        # One round trip with a trivial, image-free question.
        question = "Hello, please answer briefly: What is 1+1?"
        print(f"👤 Test prompt: {question}")
        answer = ask_llm(question, [])
        print("✅ Test successful!")
        print(f"🤖 Response: {answer}")

        # The marker check must fire on a reply that contains it.
        marker_found = found_target("This is a test response @@@")
        outcome = '✅ Success' if marker_found else '❌ Failed'
        print(f"🔍 Target detection test: {outcome}")

    except Exception as err:
        print(f"❌ Test failed: {err}")
        traceback.print_exc()
        return False

    return True

if __name__ == "__main__":
    # Script entry point: announce the active backend, then run the smoke test.
    mode_label = "Local Qwen2.5-VL" if LLM_MODE == "local" else "API"
    print(f"🔧 LLM Mode: {mode_label}")
    test_llm_integration()