# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions for working with huggingface models."""

import torch._dynamo
import torch._inductor


# Per-model keyword-argument overrides, keyed by Hugging Face model name.
# The Gemma 2 27b checkpoints are pinned to bfloat16; see
# https://huggingface.co/google/gemma-2-27b-it/discussions/29#66ac02358c3da83ca5ea978d
MODEL_KWARGS_SPECIAL_CASES = {
    "google/gemma-2-27b": {"torch_dtype": torch.bfloat16},
    "google/gemma-2-27b-it": {"torch_dtype": torch.bfloat16},
}


def suppress_dynamo_errors(model_name: str) -> None:
  """Suppresses Dynamo errors for Gemma 2 models."""
  if model_name.startswith("google/gemma-2-"):
    torch._dynamo.config.suppress_errors = True  # pylint: disable=protected-access
    # On GCP we sometimes see the following warning:
    # ```CUDAGraph supports dynamic shapes by recording a new graph for each
    #   distinct input size. Recording too many CUDAGraphs may lead to extra
    #   overhead. We have observed 51 distinct sizes. Please consider the
    #   following options for better performance: a) padding inputs to a few
    #   fixed number of shapes; or b) set
    #   torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. Set
    #   torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None
    #   to silence this warning.````
    # After a few more samples, this is followed by the error:
    # `torch._dynamo.exc.Unsupported: cache_size_limit reached`.`
    torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True  # pylint: disable=protected-access

    if model_name.startswith("google/gemma-2-27b"):
      # As of 2025-03-22, transformers 4.49.0 and nvidia pytorch 25.02,
      # Smaller Gemma 2 models work on GCP, but Gemma 2 27b consistently crashes
      # with the following error:
      # `torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor device
      # call_function <built-in function getitem>`. Disabling Dynamo appears
      # to prevent this error (though could cause performance loss.)
      torch._dynamo.config.disable = True  # pylint: disable=protected-access
