from types import SimpleNamespace
import concurrent.futures
from openai import OpenAI, AzureOpenAI
import httpx
import json
import time
from dotenv import load_dotenv
import os
from typing import Optional, Any, Dict, List, Union, Type, TypeVar
from pydantic import BaseModel, ValidationError

# Module-level debug flag; not read in this part of the file (presumably used elsewhere).
debug = True
# Load .env before anything below reads os.environ; override=True lets values
# from .env replace variables that are already set in the process environment.
load_dotenv(override=True)

# TLS certificate verification for the shared Azure HTTP client. It stays OFF by
# default to preserve existing behavior (NOTE(review): presumably needed for a
# proxy / self-signed endpoint — confirm). Set AZURE_OPENAI_VERIFY_SSL=true
# (or 1 / yes) to turn verification back on.
_verify_ssl = os.getenv("AZURE_OPENAI_VERIFY_SSL", "false").strip().lower() in ("1", "true", "yes")
custom_http_client = httpx.Client(verify=_verify_ssl)

# Type variable for the Pydantic model class used as a structured-output schema.
T = TypeVar("T", bound=BaseModel)

class LLMBaseAgent:  # Adapted and borrowed from external sources
    """Chat-completion agent that picks its API client from the model name.

    GPT-4o / GPT-4.1 models use an ``AzureOpenAI`` client configured from the
    ``AZURE_*`` environment variables. Gemma-3-27B, Qwen3-14B and Llama-3.1
    models use a plain ``OpenAI`` client pointed at a ``VLLM_BASE_URL_*``
    endpoint.
    """

    def __init__(self, kwargs: dict):
        """Create the agent and its API client.

        Args:
            kwargs: Generation settings (``model``, ``temperature``,
                ``max_tokens``, ...). Missing ones are filled with defaults,
                and the final settings are printed to stdout.

        Raises:
            ValueError: If ``model`` matches none of the supported backends.
        """
        self.args = SimpleNamespace(**kwargs)
        self._set_default_args()
        self._print_args()
        model = self.args.model.lower()
        if "gpt-4o" in model or "gpt-4.1" in model:
            self.client = AzureOpenAI(
                api_version=os.getenv("AZURE_API_VERSION"),
                azure_endpoint=os.getenv("AZURE_OPENAI_API_BASE_URL"),
                api_key=os.getenv("AZURE_OPENAI_API_KEY"),
                http_client=custom_http_client,
            )
        # api_key="None" is a placeholder: the OpenAI client needs some key value,
        # and presumably the self-hosted servers do not check it.
        elif "google/gemma-3-27b-it" in model:
            self.client = OpenAI(base_url=os.getenv("VLLM_BASE_URL_1"), api_key="None")
        elif "qwen/qwen3-14b" in model:
            self.client = OpenAI(base_url=os.getenv("VLLM_BASE_URL_2"), api_key="None")
        elif "llama-3.1" in model:
            # NOTE(review): shares VLLM_BASE_URL_2 with Qwen3-14B — confirm intended.
            self.client = OpenAI(base_url=os.getenv("VLLM_BASE_URL_2"), api_key="None")
        else:
            raise ValueError(f"Model {self.args.model} not supported")

    def _print_args(self):
        """Print every configured setting, one ``key: value`` per line."""
        print("Initialized LLMBaseAgent with the following parameters:")
        for key, value in vars(self.args).items():
            print(f"  {key}: {value}")

    def _set_default_args(self):
        """Fill in any generation setting the caller did not supply.

        Only model, temperature and max_tokens are forwarded to the API.
        Other sampling options (top_p, frequency/presence penalties) are
        intentionally not set.
        """
        defaults = {"model": "gpt-4o", "temperature": 0.9, "max_tokens": 4096}
        for name, value in defaults.items():
            if not hasattr(self.args, name):
                setattr(self.args, name, value)

    def generate(
        self,
        messages: List[Dict[str, Any]],
        ret: Optional[Type[T]] = None,
        *,
        max_retries: Optional[int] = None,
        retry_delay: float = 2.0,
    ):
        """Send one chat request, retrying after any exception.

        Args:
            messages: Chat messages forwarded as-is to the API (OpenAI chat
                format, i.e. dicts with "role" and "content").
            ret: Optional Pydantic model class passed as ``response_format``,
                so the SDK parses the reply into that model.
            max_retries: Maximum number of retries after the first failed
                attempt. ``None`` (the default) retries indefinitely.
            retry_delay: Seconds to sleep between attempts.

        Returns:
            The object returned by ``client.beta.chat.completions.parse``.

        Raises:
            Exception: The last error raised by the client, once
                ``max_retries`` retries have failed.
        """
        # The request does not change between attempts, so build it once.
        request = dict(
            model=self.args.model,
            messages=messages,
            temperature=self.args.temperature,
            max_tokens=self.args.max_tokens,
        )
        if ret is not None:
            request["response_format"] = ret

        failures = 0
        while True:
            try:
                return self.client.beta.chat.completions.parse(**request)
            except Exception as e:
                print(f"Error processing messages: {e}")
                failures += 1
                # Without a bound, a permanent error (e.g. a rejected request)
                # would loop forever; with one, surface the real error.
                if max_retries is not None and failures > max_retries:
                    raise
                time.sleep(retry_delay)

    def generate_batch(
        self,
        msg_batches: List[List[Dict[str, Any]]],
        ret: Optional[Type[T]] = None,
        *,
        max_workers: Optional[int] = None,
        max_retries: Optional[int] = None,
        retry_delay: float = 2.0,
    ):
        """Run ``generate`` on each conversation concurrently in a thread pool.

        Args:
            msg_batches: One message list per request.
            ret: Optional Pydantic model class forwarded to every ``generate`` call.
            max_workers: Thread-pool size. ``None`` uses the
                ``ThreadPoolExecutor`` default.
            max_retries: Forwarded to ``generate``.
            retry_delay: Forwarded to ``generate``.

        Returns:
            A list of responses in the same order as ``msg_batches``.

        Raises:
            Exception: When ``max_retries`` is set, the first error from a
                ``generate`` call propagates to the caller.
        """
        def process_message(messages):
            return self.generate(messages, ret, max_retries=max_retries, retry_delay=retry_delay)

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            return list(executor.map(process_message, msg_batches))
