# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Functions for parsing metadata from Hugging Face and API model names."""

from collections.abc import Sequence
import functools
from typing import Any
import numpy as np
import pandas as pd

# Model identifiers known to this module; used as the default input of
# get_models_df. Names starting with one of API_PREFIXES are API models, all
# other names are open-weight models (presumably Hugging Face Hub ids -- TODO
# confirm).
MODELS = (
    ## Open weight models.
    # Mistral.
    "mistralai/Mistral-7B-v0.3",
    "mistralai/Mistral-7B-Instruct-v0.3",
    "mistralai/Mistral-Nemo-Base-2407",
    "mistralai/Mistral-Nemo-Instruct-2407",
    "mistralai/Mistral-Small-24B-Base-2501",
    "mistralai/Mistral-Small-24B-Instruct-2501",
    # Qwen 1.5.
    "Qwen/Qwen1.5-0.5B",
    "Qwen/Qwen1.5-0.5B-Chat",
    "Qwen/Qwen1.5-1.8B",
    "Qwen/Qwen1.5-1.8B-Chat",
    "Qwen/Qwen1.5-4B",
    "Qwen/Qwen1.5-4B-Chat",
    "Qwen/Qwen1.5-7B",
    "Qwen/Qwen1.5-7B-Chat",
    "Qwen/Qwen1.5-14B",
    "Qwen/Qwen1.5-14B-Chat",
    "Qwen/Qwen1.5-32B",
    "Qwen/Qwen1.5-32B-Chat",
    "Qwen/Qwen1.5-72B",
    "Qwen/Qwen1.5-72B-Chat",
    "Qwen/Qwen1.5-110B",
    "Qwen/Qwen1.5-110B-Chat",
    # Qwen 2.
    "Qwen/Qwen2-0.5B",
    "Qwen/Qwen2-0.5B-Instruct",
    "Qwen/Qwen2-1.5B",
    "Qwen/Qwen2-1.5B-Instruct",
    "Qwen/Qwen2-7B",
    "Qwen/Qwen2-7B-Instruct",
    # Mixture models seem to result in frequent crashes, so they are excluded.
    # "Qwen/Qwen2-57B-A14B",
    # "Qwen/Qwen2-57B-A14B-Instruct",
    "Qwen/Qwen2-72B",
    "Qwen/Qwen2-72B-Instruct",
    # Qwen 2.5.
    "Qwen/Qwen2.5-0.5B",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "Qwen/Qwen2.5-1.5B",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen/Qwen2.5-3B",
    "Qwen/Qwen2.5-3B-Instruct",
    "Qwen/Qwen2.5-7B",
    "Qwen/Qwen2.5-7B-Instruct",
    "Qwen/Qwen2.5-14B",
    "Qwen/Qwen2.5-14B-Instruct",
    "Qwen/Qwen2.5-32B",
    "Qwen/Qwen2.5-32B-Instruct",
    "Qwen/Qwen2.5-72B",
    "Qwen/Qwen2.5-72B-Instruct",
    # Yi 1.
    "01-ai/Yi-6B",
    "01-ai/Yi-6B-Chat",
    "01-ai/Yi-34B",
    "01-ai/Yi-34B-Chat",
    # Yi 1.5.
    "01-ai/Yi-1.5-6B",
    "01-ai/Yi-1.5-6B-Chat",
    "01-ai/Yi-1.5-9B",
    "01-ai/Yi-1.5-9B-Chat",
    "01-ai/Yi-1.5-34B",
    "01-ai/Yi-1.5-34B-Chat",
    # Gemma 1 models.
    "google/gemma-2b",
    "google/gemma-2b-it",
    "google/gemma-7b",
    "google/gemma-7b-it",
    # Gemma 2 models.
    "google/gemma-2-2b",
    "google/gemma-2-2b-it",
    "google/gemma-2-9b",
    "google/gemma-2-9b-it",
    "google/gemma-2-27b",
    "google/gemma-2-27b-it",
    # OLMo 2 models.
    "allenai/OLMo-2-1124-7B",
    "allenai/OLMo-2-1124-7B-Instruct",
    "allenai/OLMo-2-1124-13B",
    "allenai/OLMo-2-1124-13B-Instruct",
    ## API models.
    # Gemini API.
    "gemini_api/gemini-1.5-flash-8b-001",
    "gemini_api/gemini-1.5-flash-002",
    "gemini_api/gemini-1.5-pro-002",
    "gemini_api/gemini-2.0-flash-001",
    "gemini_api/gemini-2.0-flash-lite-001",
    # OpenAI API.
    "openai_api/gpt-4o-mini-2024-07-18",
    # Anthropic API.
    "anthropic_api/claude-3-5-haiku-20241022",
)

# Maps a model-name prefix to the name of its model family. get_model_info
# walks this dict in insertion order and uses the first key that is a prefix of
# the model name, so a key that extends another key must be listed before the
# shorter one.
MODEL_PREFIX_TO_FAMILY_NAME = {
    ## Open weight models.
    "mistralai/Mistral-": "Mistral",
    "google/gemma-2-": "Gemma 2",
    # Gemma 1 string is a prefix of Gemma 2 string, so it needs to come after.
    "google/gemma-": "Gemma 1",
    "allenai/OLMo-2-1124-": "OLMo 2",
    "Qwen/Qwen1.5-": "Qwen 1.5",
    "Qwen/Qwen2-": "Qwen 2",
    "Qwen/Qwen2.5-": "Qwen 2.5",
    "01-ai/Yi-1.5-": "Yi 1.5",
    # Yi 1 string is a prefix of Yi 1.5 string, so it needs to come after.
    "01-ai/Yi-": "Yi 1",
    ## API models.
    "gemini_api/gemini-2.0": "Gemini 2.0",  # flash, flash-lite, pro
    "gemini_api/gemini-1.5": "Gemini 1.5",  # flash, flash-8b, pro
    "openai_api/gpt-4o": "GPT-4o",  # gpt-4o and gpt-4o-mini
    "anthropic_api/claude-3-5": "Claude 3.5",  # haiku, sonnet, opus
}


# Name prefixes that mark a model as an API model (see is_api_model).
API_PREFIXES = frozenset((
    "gemini_api/",
    "openai_api/",
    "anthropic_api/",
))


def is_api_model(model_name: str) -> bool:
  """Returns True if `model_name` starts with one of the API_PREFIXES."""
  # str.startswith accepts a tuple and matches if any element is a prefix.
  return model_name.startswith(tuple(API_PREFIXES))


# Size tokens that parse_param_count looks up verbatim instead of parsing them
# as "<number><unit>". The comment above each entry names an example model.
PARAM_STRING_SPECIAL_CASES = {
    # mistralai/Mistral-Nemo-Base-2407
    "Nemo": 12e9,
    # mistralai/Mixtral-8x7B-v0.1
    "8x7B": 42e9,
    # mistralai/Mixtral-8x22B-v0.1
    "8x22B": 176e9,
    # mistralai/Mistral-Small-24B-Base-2501
    "Small-24B": 24e9,
    # Qwen/Qwen2-57B-A14B
    "57B-A14B": 57e9,
}

# Model names that get_model_info reports with is_mixture=True.
MIXTURE_MODELS = frozenset({
    "Qwen/Qwen2-57B-A14B",
    "Qwen/Qwen2-57B-A14B-Instruct",
})

# Name suffixes for which get_model_info sets instruction_tuned=True.
IT_SUFFIXES = frozenset(("-chat", "-Chat", "-Instruct", "-it"))
# Name suffixes of non-instruction-tuned variants (presumably "PT" stands for
# pretrained -- TODO confirm); get_model_info strips them from the name.
PT_SUFFIXES = frozenset(("-Base", "-pt"))


def parse_param_count(param_string: str) -> float:
  """Converts a parameter-size token such as "7B" into a parameter count.

  Tokens listed in PARAM_STRING_SPECIAL_CASES are looked up verbatim. Any other
  token must be a number followed by a one-letter unit, "M" (1e6) or "B" (1e9),
  in either case. An underscore in the number is read as a decimal point, so
  "0_5B" is parsed like "0.5B".

  Args:
    param_string: The size token, e.g. "7B", "0.5B", "2b" or "Nemo".

  Returns:
    The parameter count as a float.

  Raises:
    ValueError: If the token is empty, ends in an unknown unit, or its numeric
      part is not a valid number.
  """
  if param_string in PARAM_STRING_SPECIAL_CASES:
    return PARAM_STRING_SPECIAL_CASES[param_string]
  if not param_string:
    # Without this guard, indexing the last character raises an IndexError.
    raise ValueError("Cannot parse a param count from an empty string.")

  abbrev_to_float = {"m": 1e6, "b": 1e9}
  number_abbrev = param_string[-1].lower()
  if number_abbrev not in abbrev_to_float:
    raise ValueError(
        f"Unknown abbreviation {number_abbrev} in input {param_string}."
    )

  number_string = param_string[:-1].replace("_", ".")
  try:
    number = float(number_string)
  except ValueError as e:
    # Re-raise with the full token so the failing model name can be traced.
    raise ValueError(
        f"Cannot parse number {number_string!r} in input {param_string!r}."
    ) from e
  return number * abbrev_to_float[number_abbrev]


@functools.cache
def get_model_info(model_name: str) -> dict[str, Any]:
  """Parses a model name, returning its family, size and tuning metadata.

  The family is the value of the first MODEL_PREFIX_TO_FAMILY_NAME entry whose
  key is a prefix of `model_name`. For API models (names starting with one of
  API_PREFIXES) nothing else is parsed. For all other models the rest of the
  name is stripped of known version/date suffixes and of an instruction-tuned
  or PT suffix, and what remains is passed to parse_param_count.

  Results are memoized with functools.cache, so repeated calls with the same
  name return the same dict object; callers should treat it as read-only.

  Args:
    model_name: Full model identifier, e.g. "Qwen/Qwen2-7B-Instruct".

  Returns:
    A dict with the keys:
      family: Family name taken from MODEL_PREFIX_TO_FAMILY_NAME.
      api_type: The matching API prefix without its trailing "/", or "local".
      param_count: Parameter count as a float; None for API models.
      log_param_count: log10 of param_count; None for API models.
      instruction_tuned: Whether a suffix from IT_SUFFIXES was found; always
        True for API models.
      is_mixture: Whether the name is in MIXTURE_MODELS; None for API models.

  Raises:
    ValueError: If no known family prefix matches, if the name carries both an
      IT_SUFFIXES and a PT_SUFFIXES suffix, or if the parameter count cannot be
      parsed.
  """
  # First matching prefix wins; the dict is ordered so that longer prefixes
  # precede the shorter prefixes they extend.
  match = next(
      (
          (prefix, family)
          for prefix, family in MODEL_PREFIX_TO_FAMILY_NAME.items()
          if model_name.startswith(prefix)
      ),
      None,
  )
  if match is None:
    raise ValueError(f"Unknown model family for model: {model_name}")
  prefix, family = match

  # API models carry no size information in their names.
  for api_type in API_PREFIXES:
    if model_name.startswith(api_type):
      return dict(
          family=family,
          api_type=api_type.removesuffix("/"),
          param_count=None,
          log_param_count=None,
          instruction_tuned=True,
          is_mixture=None,
      )

  version_string = model_name[len(prefix) :]
  # Trailing format/version/date tags, removed in this order: "-hf", Mistral
  # "-v0.1"/"-v0.3", Mistral Nemo "-2407", Mistral Small "-2501".
  for version_suffix in ("-hf", "-v0.1", "-v0.3", "-2407", "-2501"):
    version_string = version_string.removesuffix(version_suffix)

  instruction_tuned = False
  for suffix in IT_SUFFIXES:
    if version_string.endswith(suffix):
      version_string = version_string.removesuffix(suffix)
      instruction_tuned = True
      break
  for suffix in PT_SUFFIXES:
    if version_string.endswith(suffix):
      if instruction_tuned:
        # Raised explicitly rather than asserted so that the check still runs
        # when Python is started with -O.
        raise ValueError(
            f"Model name {model_name} has both an instruction-tuned suffix"
            f" and the suffix {suffix}."
        )
      version_string = version_string.removesuffix(suffix)
      break
  version_string = version_string.removesuffix("-0724")  # OLMo
  param_count = parse_param_count(version_string)

  return dict(
      family=family,
      api_type="local",
      param_count=param_count,
      log_param_count=np.log10(param_count),
      instruction_tuned=instruction_tuned,
      is_mixture=model_name in MIXTURE_MODELS,
  )


def get_models_df(models: Sequence[str] = MODELS) -> pd.DataFrame:
  """Builds a dataframe with one row of model info per model name.

  Args:
    models: Model names to include; defaults to MODELS.

  Returns:
    A DataFrame with a `model_name` column plus the fields returned by
    get_model_info, one row per entry of `models` in input order.
  """
  records = []
  for name in models:
    record = {"model_name": name}
    record.update(get_model_info(name))
    records.append(record)
  return pd.DataFrame(records)
