import logging
import time
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm

from src.infer import (
    query_anthropic,
    query_embeddings,
    query_gemini,
    query_openai,
    query_together,
    query_vllm_embed_server,
    query_vllm_server,
)

logger = logging.getLogger(__name__)


MODELS = {
    # OpenAI models
    "gpt-3.5-turbo": {
        "openai": {"query_func": query_openai, "model_id": "gpt-3.5-turbo"}
    },
    "openai/gpt-oss-20b": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "openai/gpt-oss-20b",
        }
    },
    "gpt-5-2025-08-07": {
        "openai": {"query_func": query_openai, "model_id": "gpt-5-2025-08-07"}
    },
    "gpt-4o": {"openai": {"query_func": query_openai, "model_id": "gpt-4o"}},
    "gpt-4o-mini": {"openai": {"query_func": query_openai, "model_id": "gpt-4o-mini"}},
    "gpt-4-turbo": {"openai": {"query_func": query_openai, "model_id": "gpt-4-turbo"}},
    "gpt-4.5-preview-2025-02-27": {
        "openai": {"query_func": query_openai, "model_id": "gpt-4.5-preview-2025-02-27"}
    },
    "gpt-4-0125-preview": {
        "openai": {"query_func": query_openai, "model_id": "gpt-4-0125-preview"}
    },
    "gpt-4o-2024-11-20": {
        "openai": {"query_func": query_openai, "model_id": "gpt-4o-2024-11-20"}
    },
    "gpt-4o-2024-08-06": {
        "openai": {"query_func": query_openai, "model_id": "gpt-4o-2024-08-06"}
    },
    "gpt-4o-2024-05-13": {
        "openai": {"query_func": query_openai, "model_id": "gpt-4o-2024-05-13"}
    },
    "gpt-4-0613": {"openai": {"query_func": query_openai, "model_id": "gpt-4-0613"}},
    "o1": {"openai": {"query_func": query_openai, "model_id": "o1"}},
    "o1-mini": {"openai": {"query_func": query_openai, "model_id": "o1-mini"}},
    "o3-mini": {"openai": {"query_func": query_openai, "model_id": "o3-mini"}},
    "o3-mini-2025-01-31": {
        "openai": {"query_func": query_openai, "model_id": "o3-mini-2025-01-31"}
    },
    "o3-2025-04-16": {
        "openai": {"query_func": query_openai, "model_id": "o3-2025-04-16"}
    },
    # OpenAI Embedding models
    "text-embedding-ada-002": {
        "openai": {"query_func": query_embeddings, "model_id": "text-embedding-ada-002"}
    },
    "text-embedding-3-small": {
        "openai": {"query_func": query_embeddings, "model_id": "text-embedding-3-small"}
    },
    "text-embedding-3-large": {
        "openai": {"query_func": query_embeddings, "model_id": "text-embedding-3-large"}
    },
    # Anthropic/Claude models
    "claude-opus-4-1-20250805": {
        "anthropic": {
            "query_func": query_anthropic,
            "model_id": "claude-opus-4-1-20250805",
        }
    },
    "claude-opus-4-20250514": {
        "anthropic": {
            "query_func": query_anthropic,
            "model_id": "claude-opus-4-20250514",
        }
    },
    "claude-sonnet-4-20250514": {
        "anthropic": {
            "query_func": query_anthropic,
            "model_id": "claude-sonnet-4-20250514",
        }
    },
    "claude-3-5-sonnet-20241022": {
        "anthropic": {
            "query_func": query_anthropic,
            "model_id": "claude-3-5-sonnet-20241022",
        }
    },
    "claude-3-sonnet-20240229": {
        "anthropic": {
            "query_func": query_anthropic,
            "model_id": "claude-3-sonnet-20240229",
        }
    },
    "claude-3-7-sonnet-20250219": {
        "anthropic": {
            "query_func": query_anthropic,
            "model_id": "claude-3-7-sonnet-20250219",
        }
    },
    "claude-3-5-sonnet-20240620": {
        "anthropic": {
            "query_func": query_anthropic,
            "model_id": "claude-3-5-sonnet-20240620",
        }
    },
    "claude-3-5-haiku-20241022": {
        "anthropic": {
            "query_func": query_anthropic,
            "model_id": "claude-3-5-haiku-20241022",
        }
    },
    # DeepSeek models
    "deepseek-ai/DeepSeek-R1": {
        "together": {
            "query_func": query_together,
            "model_id": "deepseek-ai/DeepSeek-R1",
        },
    },
    "deepseek-ai/DeepSeek-R1-Distill-Llama-70B": {
        "together": {
            "query_func": query_together,
            "model_id": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
        },
    },
    "deepseek-ai/DeepSeek-V3": {
        "together": {
            "query_func": query_together,
            "model_id": "deepseek-ai/DeepSeek-V3",
        },
    },
    # Together models
    "databricks/dbrx-instruct": {
        "together": {
            "query_func": query_together,
            "model_id": "databricks/dbrx-instruct",
        }
    },
    "google/gemma-2-27b-it": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "google/gemma-2-27b-it",
        }
    },
    "google/gemma-3-27b-it-together": {
        "vllm_server": {
            "query_func": query_together,
            "model_id": "google/gemma-3-27b-it",
        }
    },
    "google/gemma-3-27b-it": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "google/gemma-3-27b-it",
        }
    },
    # Google Gemini models
    "gemini-2.5-pro-preview-03-25": {
        "gemini": {
            "query_func": query_gemini,
            "model_id": "gemini-2.5-pro-preview-03-25",
        }
    },
    "gemini-2.0-flash": {
        "gemini": {"query_func": query_gemini, "model_id": "gemini-2.0-flash"}
    },
    "gemini-2.0-flash-lite": {
        "gemini": {"query_func": query_gemini, "model_id": "gemini-2.0-flash-lite"}
    },
    # Qwen models
    "Qwen/Qwen2.5-VL-72B-Instruct": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "Qwen/Qwen2.5-VL-72B-Instruct",
        }
    },
    "Qwen/QwQ-32B": {
        "vllm_server": {"query_func": query_vllm_server, "model_id": "Qwen/QwQ-32B"}
    },
    "Qwen/QwQ-32B-together": {
        "vllm_server": {"query_func": query_together, "model_id": "Qwen/QwQ-32B"}
    },
    "Qwen/Qwen3-32B": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "Qwen/Qwen3-32B",
        }
    },
    "Qwen/Qwen2.5-7B-Instruct-Turbo": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "Qwen/Qwen2.5-7B-Instruct-Turbo",
        }
    },
    "Qwen/Qwen3-Embedding-0.6B": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "Qwen/Qwen2.5-7B-Instruct-Turbo",
        }
    },
    "Qwen/Qwen3-235B-A22B-fp8-tput": {
        "together": {
            "query_func": query_together,
            "model_id": "Qwen/Qwen3-235B-A22B-fp8-tput",
        }
    },
    # "Qwen/Qwen3-Embedding-0.6B": {
    #     "vllm_server": {
    #         "query_func": query_vllm_embed_server,
    #         "model_id": "Qwen/Qwen3-Embedding-0.6B",
    #     }
    # },
    # Meta Llama models
    "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8": {
        "together": {
            "query_func": query_together,
            "model_id": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        }
    },
    "meta-llama/Llama-4-Scout-17B-16E-Instruct": {
        "together": {
            "query_func": query_together,
            "model_id": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        }
    },
    "meta-llama/Llama-3.3-70B-Instruct": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "meta-llama/Llama-3.3-70B-Instruct",
        }
    },
    "huihui-ai/Llama-3.3-70B-Instruct-abliterated-finetuned": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "huihui-ai/Llama-3.3-70B-Instruct-abliterated-finetuned",
        }
    },
    "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
        }
    },
    "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
        }
    },
    "meta-llama/Llama-3.2-3B-Instruct-Turbo": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
        }
    },
    "meta-llama/Llama-3.2-3B-Instruct": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "meta-llama/Llama-3.2-3B-Instruct",
        }
    },
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {
        "together": {
            "query_func": query_together,
            "model_id": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
        }
    },
    "meta-llama/Llama-3.3-70B-Instruct-Turbo-together": {
        "vllm_server": {
            "query_func": query_together,
            "model_id": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
        }
    },
    "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        }
    },
    # meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo
    "meta-llama/Llama-3.1-8B-Instruct": {
        "vllm_server": {
            "query_func": query_vllm_server,
            "model_id": "meta-llama/Llama-3.1-8B-Instruct",
        }
    },
    "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {
        "vllm_server": {
            "query_func": query_together,
            "model_id": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        }
    },
    # Mistral models
    "mistralai/Mixtral-8x7B-Instruct-v0.1": {
        "together": {
            "query_func": query_together,
            "model_id": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        }
    },
    "mistralai/Mixtral-8x22B-Instruct-v0.1": {
        "together": {
            "query_func": query_together,
            "model_id": "mistralai/Mixtral-8x22B-Instruct-v0.1",
        }
    },
    "mistralai/Mistral-Small-24B-Instruct-2501": {
        "together": {
            "query_func": query_together,
            "model_id": "mistralai/Mistral-Small-24B-Instruct-2501",
        }
    },
    "moonshotai/Kimi-K2-Instruct": {
        "together": {
            "query_func": query_together,
            "model_id": "moonshotai/Kimi-K2-Instruct",
        }
    },
}


class   BlackBoxModel:
    def __init__(
        self,
        model_name: str,
        provider: str = None,
        system_prompt: str = None,
        max_retries: int = 5,
        retry_delay: float = 5.0,
        port: int = 8000,
    ):
        self.model_name = model_name
        self.provider = provider or self.get_default_provider(model_name)
        self.system_prompt = system_prompt
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.model_info = self.get_model_info(model_name, self.provider)
        self.query_func = self.model_info["query_func"]
        self.model_id = self.model_info["model_id"]

    @staticmethod
    def get_default_provider(model_name: str) -> str:
        if model_name not in MODELS:
            raise ValueError(f"Model {model_name} not found in MODELS")
        providers = list(MODELS[model_name].keys())
        if not providers:
            raise ValueError(f"No providers found for model {model_name}")
        return providers[0]

    def get_model_info(self, model_name: str, provider: str) -> dict:
        if model_name not in MODELS:
            raise ValueError(f"Model {model_name} not found in MODELS")
        if provider not in MODELS[model_name]:
            raise ValueError(
                f"Provider '{provider}' not supported for model '{model_name}'. "
                f"Available providers: {list(MODELS[model_name].keys())}"
            )
        return MODELS[model_name][provider]

    def query(self, prompt: str, **kwargs) -> str:
        if "system_prompt" not in kwargs:
            kwargs["system_prompt"] = self.system_prompt
        if "temperature" in kwargs and self.model_name in [
            "o1",
            "o1-mini",
            "o3-mini",
        ]:
            logger.warning(
                f"Temperature is not supported for {self.model_name}, ignoring"
            )

        return self.query_func(prompt=prompt, model=self.model_id, **kwargs)

    def query_parallel(
        self,
        prompts: list[str],
        max_threads: int = 50,
        show_progress: bool = True,
        **kwargs,
    ) -> list[str]:
        with ThreadPoolExecutor(max_workers=max_threads) as executor:
            if show_progress:
                results = list(
                    tqdm(
                        executor.map(
                            lambda prompt: self.query(prompt, **kwargs), prompts
                        ),
                        total=len(prompts),
                        desc=f"Querying {self.model_name}",
                    )
                )
            else:
                results = list(
                    executor.map(lambda prompt: self.query(prompt, **kwargs), prompts)
                )

        return results

    def embed(self, text, batch_size=100, **kwargs) -> list:
        embedding_model = self.model_name

        attempt = 0
        while attempt < self.max_retries:
            try:
                return query_embeddings(
                    text=text, model=embedding_model, batch_size=batch_size, **kwargs
                )
            except Exception as e:
                attempt += 1
                logger.error(
                    f"Attempt {attempt} for embeddings failed with exception: {e}"
                )
                if attempt < self.max_retries:
                    time.sleep(self.retry_delay)
                else:
                    logger.error(
                        f"All {self.max_retries} attempts failed for embeddings"
                    )
                    raise

    def embed_parallel(
        self,
        texts: list[str],
        batch_size=100,
        max_threads=10,
        show_progress: bool = True,
        **kwargs,
    ) -> list[list]:
        with ThreadPoolExecutor(max_workers=max_threads) as executor:
            if show_progress:
                results = list(
                    tqdm(
                        executor.map(
                            lambda batch: self.embed(batch, **kwargs),
                            [
                                texts[i : i + batch_size]
                                for i in range(0, len(texts), batch_size)
                            ],
                        ),
                        total=(len(texts) + batch_size - 1) // batch_size,
                        desc=f"Embedding with {self.model_name}",
                    )
                )
            else:
                results = list(
                    executor.map(
                        lambda batch: self.embed(batch, **kwargs),
                        [
                            texts[i : i + batch_size]
                            for i in range(0, len(texts), batch_size)
                        ],
                    )
                )

        embeddings = []
        for batch_result in results:
            if (
                isinstance(batch_result, list)
                and batch_result
                and isinstance(batch_result[0], list)
            ):
                embeddings.extend(batch_result)
            else:
                embeddings.append(batch_result)

        return embeddings
