#!/usr/bin/env python3
import os
import gc
import time
import time
import torch
import requests
from pathlib import Path
from src.logger import setup_logger
from vllm import LLM, SamplingParams


# Logger configuration is taken from the environment: AGENT_LLM_MODEL picks the
# log sub-directory and RUN_ID names the per-run log file. Note that this
# module-level ``hf_model`` is independent of CustomModel's ``hf_model`` argument.
hf_model = os.environ.get("AGENT_LLM_MODEL", "default_model")
run_id = os.environ.get("RUN_ID", "default_run")

logger = setup_logger(
    __name__,
    log_filename=f"run_{run_id}.log",
    log_subdir=hf_model,
    level="INFO",
)

# Registry of local model checkpoints that CustomModel loads in-process with
# vLLM, keyed by the name passed as ``hf_model``. Every path can be overridden
# through its own environment variable; the defaults assume weights live under
# /models. The "-thinking" Qwen3 entries default to the same weights as their
# plain counterparts.
_HF_MODEL_SPECS = (
    # (model name, override env var, default path)
    ("Llama-3.1-8B-it", "LLAMA_3_1_8B_PATH", "/models/meta-llama/Llama-3.1-8B-Instruct"),
    ("Llama-3.1-70B-it", "LLAMA_3_1_70B_PATH", "/models/meta-llama/Llama-3.1-70B-Instruct"),
    ("Llama-3.3-70B-it", "LLAMA_3_3_70B_PATH", "/models/meta-llama/Llama-3.3-70B-Instruct"),
    ("Qwen2.5-7B-Instruct", "QWEN_2_5_7B_PATH", "/models/Qwen/Qwen2.5-7B-Instruct"),
    ("Qwen2.5-14B-Instruct", "QWEN_2_5_14B_PATH", "/models/Qwen/Qwen2.5-14B-Instruct"),
    ("Qwen2.5-32B-Instruct", "QWEN_2_5_32B_PATH", "/models/Qwen/Qwen2.5-32B-Instruct"),
    ("Qwen2.5-72B-Instruct", "QWEN_2_5_72B_PATH", "/models/Qwen/Qwen2.5-72B-Instruct"),
    ("QwQ-32B", "QWQ_32B_PATH", "/models/Qwen/QwQ-32B"),
    ("Qwen3-32B", "QWEN_3_32B_PATH", "/models/Qwen/Qwen3-32B"),
    ("Qwen3-32B-thinking", "QWEN_3_32B_THINKING_PATH", "/models/Qwen/Qwen3-32B"),
    ("Qwen3-8B", "QWEN_3_8B_PATH", "/models/Qwen/Qwen3-8B"),
    ("Qwen3-8B-thinking", "QWEN_3_8B_THINKING_PATH", "/models/Qwen/Qwen3-8B"),
    ("internlm-2.5-7B-chat", "INTERNLM_2_5_7B_PATH", "/models/internlm/internlm2_5-7b-chat"),
    ("internlm-2.5-20B-chat", "INTERNLM_2_5_20B_PATH", "/models/internlm/internlm2_5-20b-chat"),
)

HF_MAP = {
    name: os.getenv(env_var, default_path)
    for name, env_var, default_path in _HF_MODEL_SPECS
}


class CustomModel:
    """
    Interface to LLM models.
    Operates in one of two modes, chosen by the constructor arguments:
    1. hf mode: loads a local model listed in HF_MAP and generates in-process with vLLM
    2. close mode: sends chat requests over HTTP to a closed-source model endpoint
    """
    
    def __init__(self, hf_model=None, close_name=None, tensor_parallel_size=None):
        """
        Initialize the model interface.

        ``hf_model`` takes precedence when both names are given.

        Args:
            hf_model (str, optional): Key of HF_MAP naming a local model that is
                loaded in-process with vLLM.
            close_name (str, optional): Name of the closed-source model; stored
                as ``self.model_name`` for HTTP chat requests.
            tensor_parallel_size (int, optional): Number of GPUs vLLM shards the
                model across (hf mode only). Defaults to the number of visible
                CUDA devices, with a minimum of 1.

        Raises:
            ValueError: If neither ``hf_model`` nor ``close_name`` is given, or
                ``hf_model`` is not a key of HF_MAP.
            FileNotFoundError: If the path mapped to ``hf_model`` does not exist.
            Exception: Any error raised while constructing the vLLM engine is
                logged together with its traceback and re-raised unchanged.
        """
        self.model = None
        self.tokenizer = None

        if hf_model:
            self.mode = "hf"
            self.model_name = hf_model
            logger.info(f"Initializing vLLM with model: {hf_model}")

            if hf_model not in HF_MAP:
                logger.error(f"Model {hf_model} not found in HF_MAP. Available models: {list(HF_MAP.keys())}")
                raise ValueError(f"Model {hf_model} not found in HF_MAP")

            vllm_model_path = HF_MAP[hf_model]
            logger.info(f"Model path: {vllm_model_path}")

            if not os.path.exists(vllm_model_path):
                logger.error(f"Model path does not exist: {vllm_model_path}")
                self._log_local_model_dirs()
                raise FileNotFoundError(f"Model path not found: {vllm_model_path}")

            self._log_cuda_devices()

            try:
                if tensor_parallel_size is None:
                    # Shard across every visible GPU, but never pass 0 when none is visible.
                    tensor_parallel_size = max(1, torch.cuda.device_count())

                logger.info("Starting vLLM initialization...")
                self.llm = LLM(
                    model=vllm_model_path,
                    gpu_memory_utilization=0.85,
                    max_model_len=16384,
                    tensor_parallel_size=tensor_parallel_size,
                    # Allows the checkpoint to run its own Python code; only point
                    # HF_MAP at trusted weights.
                    trust_remote_code=True,
                    dtype="float16",
                )
                logger.info("vLLM initialization completed successfully")
            except Exception as e:
                logger.error(f"vLLM initialization failed: {str(e)}")
                logger.error(f"Error type: {type(e).__name__}")
                import traceback
                logger.error(f"Full traceback: {traceback.format_exc()}")
                raise
        elif close_name:
            self.mode = "close"
            self.model_name = close_name
            logger.info(f"Initializing closed-source model: {close_name}")
        else:
            logger.error("No mode parameters (hf_model or close_name) provided")
            raise ValueError("Either hf_model or close_name must be provided")

        # Throttle state for _cleanup_gpu_memory: time of the last cleanup and
        # the minimum gap between cleanups, both in seconds.
        self.last_memory_cleanup = time.time()
        self.memory_cleanup_interval = 300

    def _log_local_model_dirs(self):
        """Log the directories under MODELS_DIR (and its Qwen/ subfolder) to help diagnose a bad HF_MAP path."""
        models_dir = os.getenv("MODELS_DIR", "/models")
        if not os.path.exists(models_dir):
            logger.error(f"No {models_dir} directory found")
            return
        available_models = [p for p in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, p))]
        logger.error(f"Available models in {models_dir}: {available_models}")
        qwen_dir = os.path.join(models_dir, "Qwen")
        if os.path.exists(qwen_dir):
            qwen_models = [p for p in os.listdir(qwen_dir) if os.path.isdir(os.path.join(qwen_dir, p))]
            logger.error(f"Available Qwen models: {qwen_models}")

    def _log_cuda_devices(self):
        """Log CUDA availability and, per visible device, its name, total memory and memory allocated by this process."""
        if not torch.cuda.is_available():
            logger.warning("CUDA not available!")
            return
        device_count = torch.cuda.device_count()
        logger.info("CUDA available: True")
        logger.info(f"CUDA device count: {device_count}")
        for i in range(device_count):
            props = torch.cuda.get_device_properties(i)
            logger.info(f"GPU {i}: {props.name}, Memory: {props.total_memory / 1e9:.2f} GB")
            logger.info(f"GPU {i} current memory: {torch.cuda.memory_allocated(i) / 1e9:.2f} GB")

    def _load_model_from_path(self, model_path):
        """
        Placeholder loader for model weights stored at ``model_path``.

        No weights are read here: the path is only logged and
        ``self._init_dummy_model()`` is invoked in their place. Any exception
        is logged and then re-raised unchanged.

        Args:
            model_path (str): Path to model weights
        """
        try:
            logger.info(f"Model would be loaded from {model_path}")
            self._init_dummy_model()
        except Exception as load_error:
            failure_message = f"Error loading model from {model_path}: {str(load_error)}"
            logger.error(failure_message)
            raise
        
    def _cleanup_gpu_memory(self):
        """Regularly clean up GPU memory to prevent fragmentation"""
        current_time = time.time()
        if current_time - self.last_memory_cleanup > self.memory_cleanup_interval:
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    gc.collect()
                    logger.info(f"GPU memory cleanup completed. Available: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
                self.last_memory_cleanup = current_time
            except Exception as e:
                logger.warning(f"GPU memory cleanup failed: {str(e)}")
    
    def generate(self, prompt, max_tokens=1024, temperature=0.7, top_p=0.9):
        """
        Produce a response to ``prompt`` with the backend chosen at construction.

        Throttled GPU memory housekeeping runs first. The call is then routed on
        ``self.mode``: ``"hf"`` goes to the in-process vLLM engine, ``"close"``
        to the HTTP chat endpoint, and any other value to ``_generate_direct``.

        Args:
            prompt (str): Text to respond to.
            max_tokens (int): Upper bound on the number of generated tokens.
            temperature (float): Sampling temperature; lower values are more deterministic.
            top_p (float): Nucleus-sampling probability mass.

        Returns:
            str: The generated text response
        """
        self._cleanup_gpu_memory()

        backends = {
            "hf": self._generate_via_hf,
            "close": self._generate_via_close_models,
        }
        backend = backends.get(self.mode)
        if backend is None:
            # NOTE(review): __init__ only ever sets mode to "hf" or "close", so
            # this fallback should be unreachable.
            backend = self._generate_direct
        return backend(prompt, max_tokens, temperature, top_p)
    def _generate_via_hf(self, prompt, max_tokens, temperature, top_p):
        """
        Run a single chat turn through the in-process vLLM engine.

        The prompt becomes the user message after a fixed system message; the
        text of the first completion of the first request output is returned.
        """
        params = SamplingParams(max_tokens=max_tokens, temperature=temperature, top_p=top_p)
        conversation = [
            {"role": "system", "content": "You are an autonomous agent helping to achieve a high-level command-line task."},
            {"role": "user", "content": f"{prompt}"},
        ]
        # NOTE(review): HF_MAP maps both "Qwen3-*" and "Qwen3-*-thinking" to the
        # same default weights, and nothing here varies the chat template per
        # model name — confirm thinking mode is switched elsewhere (vLLM's
        # ``chat_template_kwargs``, e.g. ``enable_thinking``, is one option).
        request_outputs = self.llm.chat(conversation, params)
        first_request = request_outputs[0]
        return first_request.outputs[0].text

    def _generate_via_close_models(self, prompt, max_tokens, temperature, top_p):
        """
        Generate text by querying the Ollama API.
        
        Args:
            prompt (str): The input text
            max_tokens (int): Maximum tokens to generate
            temperature (float): Temperature parameter
            top_p (float): Top-p parameter
            
        Returns:
            str: The generated text
        """
        base_url = os.getenv("API_BASE_URL", "http://localhost:8000")
        api_key = os.getenv("API_KEY", "")
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"  
        }

        data = {
            "model": self.model_name, 
            "messages": [
                {
                    "role": "system",
                    "content": "You are an autonomous agent helping to achieve a high-level command-line task."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_new_tokens": max_tokens,
            "top_p": top_p,
            "temperature": temperature
        }
        start_time = time.time()
        
        response = requests.post(base_url, headers=headers, json=data)

        
        if response.status_code == 200:
            elapsed_time = time.time() - start_time
            logger.debug(f"Response received from API completion in {elapsed_time:.2f}s")
            
            result = response.json()
            result = result["choices"][0]["message"]["content"]
            return result
        else:
            logger.warning(f"Request failed with status code {response.status_code}")