# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions for reading and writing data, possibly from cloud."""

import functools
import os
from absl import logging
from google.auth import exceptions as auth_exceptions
from google.cloud import storage
from google.cloud.storage import transfer_manager


GCS_PREFIX = "gs://"


def is_gcs(file_path: str) -> bool:
  """Returns True if `file_path` is a GCS URI, i.e. starts with "gs://"."""
  return file_path.startswith(GCS_PREFIX)


def parse_gcs_path(file_path: str) -> tuple[str, str]:
  """Splits a GCS URI into its bucket name and blob name.

  Example: "gs://my-bucket/dir/file.txt" -> ("my-bucket", "dir/file.txt").
  The blob name may be empty, e.g. for "gs://my-bucket/".

  Args:
    file_path: A path of the form "gs://<bucket>/<blob>".

  Returns:
    A `(bucket_name, blob_name)` tuple.

  Raises:
    ValueError: If `file_path` is not a GCS path, has an empty bucket name, or
      has no "/" separating the bucket from the blob name.
  """
  if not is_gcs(file_path):
    raise ValueError(f"Not a GCS path: {file_path}")
  bucket_name, sep, blob_name = file_path[len(GCS_PREFIX) :].partition("/")
  # Without this check "gs://bucket" would fail with an opaque tuple-unpacking
  # error, and "gs:///blob" would silently yield an empty bucket name.
  if not sep or not bucket_name:
    raise ValueError(
        f"Malformed GCS path (expected gs://<bucket>/<blob>): {file_path}"
    )
  return bucket_name, blob_name


@functools.cache
def get_gcs_client() -> storage.Client:
  """Returns a memoized `storage.Client`, constructing it on first use.

  `functools.cache` stores only successful return values, so once a client has
  been built every later call returns that same instance; a failed attempt is
  retried on the next call.

  Raises:
    auth_exceptions.DefaultCredentialsError: Re-raised from `storage.Client()`
      after logging a hint (presumably raised when no application default
      credentials are available).
  """
  try:
    client = storage.Client()
  except auth_exceptions.DefaultCredentialsError:
    hint = (
        "Couldn't create a GCS client. If you're running locally, did you"
        " copy the application default credentials file?"
    )
    logging.error(hint)
    raise
  return client


class OpenFile:
  """Shared interface for reading and writing both local and GCS files.

  Use as a context manager; `__enter__` returns the underlying file object:

    with OpenFile("gs://bucket/path/data.parquet", "wb") as f:
      df.to_parquet(f)

  Paths starting with "gs://" are opened with `Blob.open` through the shared
  GCS client; all other paths use the built-in `open`. When opening a local
  file in a mode that can create it ("w", "a" or "x"), missing parent
  directories are created first.

  Attributes:
    file_path: Local path or "gs://<bucket>/<blob>" URI.
    mode: File mode passed through to `open` / `Blob.open`.
    file_obj: The open file object while inside the `with` block, else None.
  """

  def __init__(self, file_path, mode="r"):
    self.file_path = file_path
    self.mode = mode
    self.file_obj = None

  def __enter__(self):
    if is_gcs(self.file_path):
      bucket_name, blob_name = parse_gcs_path(self.file_path)
      # Initialize GCS client and get blob.
      blob = get_gcs_client().bucket(bucket_name).blob(blob_name)
      if self.mode == "wb":
        # ignore_flush to avoid error when running pandas df.to_parquet.
        self.file_obj = blob.open(mode=self.mode, ignore_flush=True)
      else:
        self.file_obj = blob.open(mode=self.mode)
    else:
      parent_dir = os.path.dirname(self.file_path)
      # Only create parent directories when the file itself may be created,
      # so reading never has filesystem side effects. Skip when there is no
      # directory component: os.makedirs("") raises FileNotFoundError for a
      # bare file name such as "out.csv".
      if parent_dir and any(flag in self.mode for flag in "wax"):
        os.makedirs(parent_dir, exist_ok=True)
      self.file_obj = open(self.file_path, mode=self.mode)
    return self.file_obj

  def __exit__(self, exc_type, exc_val, exc_tb):
    # Clear the attribute first so no stale, closed handle is left behind.
    file_obj, self.file_obj = self.file_obj, None
    if file_obj is not None:
      file_obj.close()
    # Implicitly returns None, so exceptions raised in the block propagate.


def download_model_from_gcs(
    model_name: str,
    source_dir: str,
    destination_dir: str,
    max_download_attempts: int = 5,
) -> None:
  """Downloads a huggingface model from GCS.

  Every blob stored at or under `<source_dir>/<model_name>` is downloaded into
  `destination_dir`, keeping its path relative to `source_dir` (so the files
  land under `<destination_dir>/<model_name>/...`). Blobs that fail are
  retried on the next attempt; blobs that already succeeded are not fetched
  again.

  Args:
    model_name: Path of the model directory (or single blob) relative to
      `source_dir`.
    source_dir: GCS directory holding the model, e.g. "gs://bucket/models". A
      trailing "/" is optional.
    destination_dir: Local directory to download into.
    max_download_attempts: Maximum number of download passes before giving up.

  Raises:
    ValueError: If `source_dir` is not a well-formed GCS path.
    FileNotFoundError: If no blobs exist for the model.
    RuntimeError: If some blobs still fail after `max_download_attempts`
      attempts. The last download error is chained as the cause.
  """
  bucket_name, gcs_source_dir = parse_gcs_path(source_dir)
  # Normalize to a "directory" prefix ending in "/" so the relative names
  # computed below never start with "/". A leading "/" would make
  # os.path.join(destination_dir, name) in the transfer manager discard
  # destination_dir and target an absolute path instead.
  if gcs_source_dir and not gcs_source_dir.endswith("/"):
    gcs_source_dir += "/"
  gcs_bucket = get_gcs_client().bucket(bucket_name)
  gcs_path = os.path.join(gcs_source_dir, model_name).rstrip("/")

  # list_blobs matches on a raw string prefix, so "models/bert" would also
  # match "models/bert-large/...". Keep only the blob named exactly gcs_path or
  # blobs inside that directory. Also skip zero-byte "directory" placeholder
  # blobs (names ending in "/"), which cannot be written to a local file.
  dir_prefix = gcs_path + "/" if gcs_path else ""
  pending = [
      blob.name[len(gcs_source_dir) :]
      for blob in gcs_bucket.list_blobs(prefix=gcs_path)
      if (blob.name == gcs_path or blob.name.startswith(dir_prefix))
      and not blob.name.endswith("/")
  ]
  if not pending:
    raise FileNotFoundError(
        f"No blobs found for model {model_name!r} under"
        f" {GCS_PREFIX}{bucket_name}/{gcs_path}"
    )

  n_attempts = 0
  last_error = None
  while pending:
    if n_attempts >= max_download_attempts:
      raise RuntimeError(
          "Too many failed download attempts, aborting."
      ) from last_error
    n_attempts += 1
    logging.info(
        "Downloading %d blobs from %s, attempt %d/%d...",
        len(pending),
        gcs_path,
        n_attempts,
        max_download_attempts,
    )
    results = transfer_manager.download_many_to_path(
        gcs_bucket,
        pending,
        blob_name_prefix=gcs_source_dir,
        destination_directory=destination_dir,
    )
    failed = []
    for name, result in zip(pending, results):
      # The results list is either `None` or an exception for each blob in
      # the input list, in order.
      if isinstance(result, Exception):
        logging.warning(
            "Failed to download %s due to exception, retrying: %s",
            name,
            repr(result),
        )
        failed.append(name)
        last_error = result
    # Only the blobs that failed this round are retried.
    pending = failed
