import os
import requests
import logging
import torch
from typing import Dict, List, Union, Optional
from openai import OpenAI
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer
import litellm
from memory_system import AgenticMemorySystem
import abc
from litellm import completion
import os


logger = logging.getLogger(__name__)

class BaseClient(abc.ABC):
    """Abstract interface implemented by every LLM client in this module.

    Subclasses must override :meth:`generate_response`. ``reset`` and
    ``has_memory`` ship with defaults (a no-op and ``False`` respectively)
    that subclasses may override.
    """

    @property
    def has_memory(self):
        """Flag reporting whether the client has memory; ``False`` here."""
        return False

    def reset(self):
        """Reset the client. The base implementation does nothing."""
        return None

    @abc.abstractmethod
    def generate_response(self, prompt, model="gpt-4o", temperature=0.01, force_json=False):
        """Generate a response for ``prompt``.

        Args:
            prompt: The input prompt.
            model: Model identifier; defaults to ``"gpt-4o"``.
            temperature: Sampling temperature; defaults to ``0.01``.
            force_json: Flag requesting JSON output; defaults to ``False``.

        The abstract body itself returns ``None``.
        """
        return None


class VLLMOpenAIClient(BaseClient):
    """Client that posts chat-completion requests to an HTTP server.

    Requests go to ``<url>/v1/chat/completions`` with a single user message
    and the stop sequences ``</search>`` and ``</answer>``.
    """

    # Stop sequences sent with every request. When the response reports one
    # of them as the stop reason it is appended back onto the returned text.
    _STOP_SEQUENCES = ("</search>", "</answer>")

    def __init__(self, url: str = "http://localhost:8014", timeout: Optional[float] = 600.0):
        """
        Args:
            url: Base URL of the server. Default is "http://localhost:8014".
            timeout: Seconds to wait for the HTTP request before giving up;
                     ``None`` waits indefinitely. Default is 600.
        """
        self.url = url
        self.timeout = timeout

    def generate_response(self, prompt, model="gpt-4o", temperature=0.01, force_json=False):
        """
        Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: The input prompt to generate a response for.
            model: Model name placed in the request body.
            temperature: Sampling temperature placed in the request body.
            force_json: Accepted for interface compatibility; not used here.

        Returns:
            The stripped message content, with the matched stop sequence
            re-appended when the response's ``stop_reason`` names one.
            On any failure the error is logged and the string
            ``"Error: <message>"`` is returned instead of raising.
        """
        try:
            response = requests.post(
                self.url + "/v1/chat/completions",
                json={
                    "model": model,
                    "temperature": temperature,
                    "messages": [{"role": "user", "content": prompt}],
                    "stop": list(self._STOP_SEQUENCES),
                },
                # Without a timeout a stalled server blocks the caller forever.
                timeout=self.timeout,
            )
            # Surface HTTP failures directly instead of as a KeyError on
            # the missing 'choices' field of an error payload.
            response.raise_for_status()

            choice = response.json()['choices'][0]

            # Content may be null in the payload; treat that as empty text.
            content = (choice["message"]["content"] or "").strip()

            # 'stop_reason' is optional: use .get so its absence is not an error.
            stop_reason = choice.get("stop_reason")
            if stop_reason in self._STOP_SEQUENCES:
                content += stop_reason

            return content

        except Exception as e:
            logger.error(f"Error: {str(e)}", exc_info=True)
            return f"Error: {str(e)}"


class OpenAIClient(BaseClient):
    """Client that requires an OpenAI API key and answers prompts via litellm.

    NOTE(review): despite the class name, ``generate_response`` currently
    calls litellm's ``completion`` with the hard-coded model
    "gemini/gemini-2.0-flash"; ``self.client`` is created but not used by it.
    Confirm whether this redirection is intentional.
    """

    def __init__(self):
        """
        Load environment variables from a .env file and build the OpenAI client.

        Raises:
            ValueError: If OPENAI_API_KEY is not set.
        """
        load_dotenv()
        openai_key = os.getenv('OPENAI_API_KEY')
        if not openai_key:
            raise ValueError("OpenAI API key not found. Please set OPENAI_API_KEY in your .env file.")
        self.client = OpenAI(api_key=openai_key)

    def generate_response(self, prompt, model="gpt-4o", temperature=0.01, force_json=False):
        """
        Generate a response for ``prompt``.

        Args:
            prompt: The input prompt, sent as a single user message.
            model: Accepted for interface compatibility; not forwarded.
            temperature: Accepted for interface compatibility; not forwarded.
            force_json: Accepted for interface compatibility; not forwarded.

        Returns:
            The stripped content of the first choice. On any failure the
            error is logged and the string ``"Error: <message>"`` is returned
            instead of raising.
        """
        # NOTE(review): an earlier implementation (previously kept here as
        # commented-out code) called self.client.chat.completions.create with
        # `model`, `temperature` and, when force_json was set,
        # response_format={"type": "json_object"}. It was replaced by the
        # fixed-model litellm call below, which ignores those arguments.
        try:
            response = completion(
                model="gemini/gemini-2.0-flash",
                messages=[{"role": "user", "content": prompt}]
            )
            return response['choices'][0]['message']['content'].strip()

        except Exception as e:
            # Record the failure; previously it was only visible in the
            # returned string.
            logging.getLogger(__name__).error(f"Error: {str(e)}", exc_info=True)
            return f"Error: {str(e)}"
        



class QwenClient(BaseClient):
    """
    Client for using Qwen 2.5 3B model for text generation.

    This class provides an interface similar to OpenAIClient for running
    the LLM loop with a local Qwen 2.5 3B model instead of OpenAI models.
    """

    def __init__(self, model_name: str = "Qwen/Qwen2.5-3B", device: str = "cuda"):
        """
        Initialize the Qwen client with the specified model.

        Args:
            model_name: The name or path of the Qwen model to use.
                        Default is "Qwen/Qwen2.5-3B".
            device: The device to load the model on. Default is "cuda".
                    Falls back to "cpu" when "cuda" is requested but
                    unavailable.
        """
        # Check if CUDA is available when device is set to cuda
        if device == "cuda" and not torch.cuda.is_available():
            logging.getLogger(__name__).warning("CUDA not available, falling back to CPU")
            device = "cpu"

        self.device = device
        self.model_name = model_name

        # Load tokenizer and model (float16 on CUDA, float32 otherwise)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            trust_remote_code=True
        ).to(device)

    def generate_response(self, prompt: str, model: Optional[str] = None,
                          temperature: float = 1.0, force_json: bool = False) -> str:
        """
        Generate a response from the Qwen model for the given prompt.

        Args:
            prompt: The input prompt to generate a response for.
            model: Ignored parameter (for compatibility with OpenAIClient).
            temperature: Sampling temperature. Default is 1.0. Values <= 0
                         disable sampling.
            force_json: Whether to force JSON output format. Currently not supported.

        Returns:
            The generated response text (prompt tokens excluded, stripped).
            On any failure the error is logged and the string
            ``"Error: <message>"`` is returned instead of raising.
        """
        try:
            # Prepare inputs
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            # Fall back to EOS only when no pad token is defined; a plain
            # `or` would also discard a legitimate pad id of 0.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            do_sample = temperature > 0
            generation_kwargs = {
                "max_new_tokens": 1024,
                "do_sample": do_sample,
                "pad_token_id": pad_token_id,
            }
            # Temperature only applies to sampling; omit it when sampling is
            # disabled so a non-positive value is never handed to generate().
            if do_sample:
                generation_kwargs["temperature"] = temperature

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(**inputs, **generation_kwargs)

            # Decode and return the generated text (excluding the prompt)
            prompt_length = inputs.input_ids.shape[1]
            response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

            return response.strip()

        except Exception as e:
            logging.getLogger(__name__).error(f"Error: {str(e)}", exc_info=True)
            return f"Error: {str(e)}"


class VLLMQwenClient(BaseClient):
    """
    Client for using Qwen 2.5 3B model with VLLM for accelerated inference.

    This class provides the same interface as QwenClient but uses VLLM
    for faster text generation, particularly beneficial for batch inference.
    """

    def __init__(self, model_name: str = "Qwen/Qwen2.5-3B", tensor_parallel_size: int = 1,
                 gpu_memory_utilization: float = 0.7, max_model_len: int = 8196):
        """
        Initialize the VLLM Qwen client with the specified model.

        Args:
            model_name: The name or path of the Qwen model to use.
                        Default is "Qwen/Qwen2.5-3B".
            tensor_parallel_size: Number of GPUs to use for tensor parallelism.
                                  Default is 1.
            gpu_memory_utilization: Fraction of GPU memory to use. Default is 0.7.
            max_model_len: Maximum sequence length for the model. Default is 8196.
                           NOTE(review): 8196 may be a typo for 8192 - confirm.

        Raises:
            ImportError: If the vllm package cannot be imported.
            RuntimeError: If no CUDA device is available.
        """
        try:
            from vllm import LLM, SamplingParams
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError("VLLM is not installed. Please install it using 'pip install vllm'.") from err

        if not torch.cuda.is_available():
            raise RuntimeError("VLLM requires CUDA. No CUDA devices found.")

        self.model_name = model_name

        # Set VLLM engine to use xformers backend for Qwen models
        # This is recommended in the Search-R1 scripts
        # (process-wide side effect: the environment variable stays set).
        os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

        # Initialize VLLM model
        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
            max_model_len=max_model_len
        )

        # Default sampling parameters. generate_response builds its own
        # per-call parameters and does not read this attribute.
        self.default_sampling_params = SamplingParams(
            max_tokens=1024,
            temperature=1.0
        )

    def generate_response(self, prompt: str, model: Optional[str] = None,
                          temperature: float = 1.0, force_json: bool = False) -> str:
        """
        Generate a response from the Qwen model using VLLM for the given prompt.

        Args:
            prompt: The input prompt to generate a response for.
            model: Ignored parameter (for compatibility with OpenAIClient).
            temperature: Sampling temperature. Default is 1.0.
            force_json: Whether to force JSON output format. Currently not supported.

        Returns:
            The stripped text of the first output of the first request.
            On any failure the error is logged and the string
            ``"Error: <message>"`` is returned instead of raising.
        """
        try:
            from vllm import SamplingParams

            # Create sampling parameters
            sampling_params = SamplingParams(
                max_tokens=1024,
                temperature=temperature,
                top_p=0.95,
                frequency_penalty=0.0,
                presence_penalty=0.0
            )

            # Generate response with VLLM
            outputs = self.llm.generate(prompt, sampling_params)

            # Extract the generated text
            response = outputs[0].outputs[0].text

            return response.strip()

        except Exception as e:
            logging.getLogger(__name__).error(f"Error: {str(e)}", exc_info=True)
            return f"Error: {str(e)}"


class LiteLLMClient(BaseClient):
    """Client that sends prompts through ``litellm.completion``.

    Models prefixed with "openai/" are called as-is; any model without an
    "openai/" or "openrouter/" prefix is prefixed with "openrouter/".
    """

    # Environment variables that must be present at construction time.
    _REQUIRED_ENV = ("OPENAI_API_KEY", "OPENROUTER_API_KEY")

    def __init__(self):
        """
        Enable litellm's ``drop_params`` and check the required API keys.

        Raises:
            AssertionError: If OPENAI_API_KEY or OPENROUTER_API_KEY is unset.
        """
        litellm.drop_params = True
        for name in self._REQUIRED_ENV:
            # Raised explicitly (rather than via `assert`) so the check still
            # runs under `python -O`; type and message are unchanged.
            if name not in os.environ:
                raise AssertionError(f"{name} is not set")

    def generate_response(self, prompt, model="openai/gpt-4o-mini", temperature=0.7, force_json=False):
        """
        Generate a response for ``prompt``.

        Args:
            prompt: The input prompt, sent as one user message with a text part.
            model: Model identifier; see the class docstring for prefix rules.
            temperature: Sampling temperature. Default is 0.7.
            force_json: When true, requests ``{"type": "json_object"}`` output.

        Returns:
            The stripped content of the first choice. On any failure the
            error is logged and the string ``"Error: <message>"`` is returned
            instead of raising.
        """
        request = {
            "temperature": temperature,
            "top_p": 1,
        }

        if not model.startswith("openai/"):
            # The provider option is only sent for non-"openai/" models.
            request["provider"] = {"sort": "throughput"}
            if not model.startswith("openrouter/"):
                model = "openrouter/" + model

        if force_json:
            request["response_format"] = {"type": "json_object"}

        try:
            # Format messages properly with content type
            messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
            response = litellm.completion(model=model, messages=messages, **request)
            return response.choices[0].message.content.strip()

        except Exception as e:
            logging.getLogger(__name__).error(f"Error: {str(e)}", exc_info=True)
            return f"Error: {str(e)}"


class Mem0Client(BaseClient):
    """Client that prepends retrieved memories to each prompt.

    Every call searches an ``AgenticMemorySystem`` for entries related to the
    message, puts them in the system prompt, and completes via litellm.
    """

    # Environment variables that must be present at construction time.
    _REQUIRED_ENV = ("OPENAI_API_KEY", "OPENROUTER_API_KEY")

    def __init__(self):
        """
        Check the required API keys and create an empty memory system.

        Raises:
            AssertionError: If OPENAI_API_KEY or OPENROUTER_API_KEY is unset.
        """
        for name in self._REQUIRED_ENV:
            # Raised explicitly (rather than via `assert`) so the check still
            # runs under `python -O`; type and message are unchanged.
            if name not in os.environ:
                raise AssertionError(f"{name} is not set")
        litellm.drop_params = True
        # Initialize the memory system
        self.memory_system = self._build_memory_system()
        # One entry per call: the memory text that was put into the prompt.
        self.memories = []

    @staticmethod
    def _build_memory_system():
        """Create a fresh AgenticMemorySystem with this client's fixed settings."""
        return AgenticMemorySystem(
            model_name='all-MiniLM-L6-v2',  # Embedding model for ChromaDB
            llm_backend="openai",           # LLM backend (openai/ollama)
            llm_model="gpt-4o-mini"         # LLM model name
        )

    def chat_with_memories(self, message: str, model: str, temperature: float = 0.01, force_json: bool = False, user_id: str = "default_user") -> str:
        """
        Answer ``message`` using the top 3 memories found for it.

        Args:
            message: The user message; also used as the memory search query.
            model: Model identifier. "openai/" models are called as-is; a
                   model without an "openai/" or "openrouter/" prefix is
                   prefixed with "openrouter/".
            temperature: Sampling temperature. Default is 0.01.
            force_json: When true, requests ``{"type": "json_object"}`` output.
            user_id: Currently unused.

        Returns:
            The stripped content of the first choice. Exceptions from the
            memory search or from litellm propagate to the caller.
        """
        # Retrieve relevant memories
        relevant_memories = self.memory_system.search_agentic(message, k=3)
        memories_str = "\n".join(f"- {entry['content']}" for entry in relevant_memories)
        self.memories.append(memories_str)

        # Generate Assistant response
        system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}"
        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": message}]

        config = {
            "temperature": temperature,
            "top_p": 0.95,
        }
        if force_json:
            config["response_format"] = {"type": "json_object"}

        if not model.startswith("openai/"):
            # The provider option is only sent for non-"openai/" models.
            config["provider"] = {"sort": "throughput"}
            if not model.startswith("openrouter/"):
                model = "openrouter/" + model

        response = litellm.completion(model=model, messages=messages, **config)
        return response.choices[0].message.content.strip()

    def generate_response(self, prompt, model="openai/gpt-4o-mini", temperature=0.01, force_json=False):
        """Generate a response for ``prompt`` by delegating to chat_with_memories."""
        return self.chat_with_memories(prompt, model=model, temperature=temperature, force_json=force_json)

    @property
    def has_memory(self):
        """Always ``True`` for this client."""
        return True

    def reset(self):
        """Replace the memory system with a fresh one and clear ``memories``."""
        self.memory_system = self._build_memory_system()
        self.memories = []