from google.cloud import storage
import os


def upload_blob(bucket_name, source_file_name, destination_blob_name):
    """
    Uploads a single local file to a Cloud Storage bucket.

    Args:
        bucket_name: GCP bucket name
        source_file_name: Local path of the file to upload
        destination_blob_name: Object name to create in the bucket
    """
    target_blob = storage.Client().bucket(bucket_name).blob(destination_blob_name)

    # Generous timeout (300, in the client library's timeout units) so large
    # files are not cut off by the default request timeout.
    target_blob.upload_from_filename(source_file_name, timeout=300)

    print(f"File {source_file_name} uploaded to {destination_blob_name}.")


def download_blob(bucket_name, source_blob_name, destination_file_name=None):
    """
    Downloads a blob from the bucket, either to disk or into memory.

    Args:
        bucket_name: GCP bucket name
        source_blob_name: Name of the blob in GCP
        destination_file_name: Optional local path to save the file.
                               If falsy (None or ""), the content is returned
                               directly instead of being written to disk.

    Returns:
        None when the blob was written to ``destination_file_name``;
        otherwise the blob's content as bytes.
    """
    source = storage.Client().bucket(bucket_name).blob(source_blob_name)

    # In-memory path: no destination file was requested.
    if not destination_file_name:
        payload = source.download_as_bytes()
        print(f"Downloaded blob {source_blob_name} as bytes.")
        return payload

    source.download_to_filename(destination_file_name)
    print(f"Blob {source_blob_name} downloaded to {destination_file_name}.")
    return None


def upload_model_to_gcp(bucket_name, model_directory, save_name: str = ""):
    """
    Uploads an entire model directory to GCP, preserving its directory layout.

    Every file under ``model_directory`` is uploaded to
    ``<save_name>/<relative path>`` in the bucket. Object names always use "/"
    as the separator, regardless of the local OS. Files whose relative path
    starts with ``checkpoint-`` or contains ``/checkpoint-`` are skipped.

    Args:
        bucket_name: GCP bucket name
        model_directory: Local path to the model directory
        save_name: Object-name prefix to upload under. Trailing slashes are
            ignored; an empty string uploads to the bucket root (without a
            leading "/").

    Returns:
        The destination prefix in GCP where the model was uploaded
        (``save_name`` with any trailing slashes removed).

    Raises:
        FileNotFoundError: If ``model_directory`` is not an existing directory.
    """
    # os.walk silently yields nothing for a missing path; fail loudly instead
    # of reporting a successful (empty) upload.
    if not os.path.isdir(model_directory):
        raise FileNotFoundError(f"Model directory not found: {model_directory}")

    # Normalize the prefix so object names never contain "//" or start with "/".
    destination_prefix = str(save_name).rstrip("/")

    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)

    # Walk through all files in the model directory
    for root, _, files in os.walk(model_directory):
        for filename in files:
            local_path = os.path.join(root, filename)
            # Relative path keeps the directory structure; GCS object names use
            # "/" on every OS, so convert native separators (e.g. "\" on Windows).
            relative_path = os.path.relpath(local_path, model_directory).replace(
                os.sep, "/"
            )

            # Skip checkpoint directories
            if "/checkpoint-" in relative_path or relative_path.startswith(
                "checkpoint-"
            ):
                print(f"Skipping checkpoint file: {relative_path}")
                continue

            if destination_prefix:
                destination_blob_name = f"{destination_prefix}/{relative_path}"
            else:
                destination_blob_name = relative_path

            # Upload file
            blob = bucket.blob(destination_blob_name)
            blob.upload_from_filename(local_path)
            print(f"File {local_path} uploaded to {destination_blob_name}")

    return destination_prefix


def download_model_from_gcp(
    bucket_name, gcp_model_path, local_destination, skip_checkpoints=False
):
    """
    Downloads a model directory from GCP.

    Every object under ``gcp_model_path/`` is written beneath
    ``local_destination`` with its relative layout preserved.

    Args:
        bucket_name: GCP bucket name
        gcp_model_path: Path to the model in GCP (e.g., "aime_qwen2.5-7b_gcp_test").
            Treated as a directory: only objects under ``gcp_model_path/`` are
            downloaded, so sibling prefixes such as ``<path>_v2/`` are excluded.
            An empty string downloads the whole bucket.
        local_destination: Local directory to download to (created if missing)
        skip_checkpoints: If True, skip objects whose relative path starts with
            ``checkpoint-`` or contains ``/checkpoint-``.

    Returns:
        Path to the downloaded model directory (``local_destination``)

    Raises:
        ValueError: If an object name would resolve to a path outside
            ``local_destination`` (e.g. it contains "..").
    """
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)

    # Create local destination directory if it doesn't exist
    os.makedirs(local_destination, exist_ok=True)
    destination_root = os.path.abspath(local_destination)

    # Match on a directory boundary. A bare prefix "model" would also match
    # "model_v2/...", and those files would be written here as "_v2/...".
    prefix = gcp_model_path
    if prefix and not prefix.endswith("/"):
        prefix += "/"

    for blob in bucket.list_blobs(prefix=prefix):
        # Get relative path from the model prefix
        relative_path = blob.name[len(prefix) :].lstrip("/")
        # Skip the prefix itself and zero-byte "folder" placeholder objects
        # (names ending in "/"): there is no file content to write for them.
        if not relative_path or relative_path.endswith("/"):
            continue

        # Skip checkpoint directories
        if skip_checkpoints and (
            "/checkpoint-" in relative_path or relative_path.startswith("checkpoint-")
        ):
            print(f"Skipping checkpoint file: {blob.name}")
            continue

        local_path = os.path.join(local_destination, relative_path)
        # Refuse object names (e.g. containing "..") that would escape the
        # destination directory.
        resolved_path = os.path.abspath(local_path)
        if os.path.commonpath([destination_root, resolved_path]) != destination_root:
            raise ValueError(
                f"Refusing to write {blob.name} outside {local_destination}"
            )

        # Create local directory structure if needed
        os.makedirs(os.path.dirname(local_path), exist_ok=True)

        # Download file
        blob.download_to_filename(local_path)
        print(f"Downloaded {blob.name} to {local_path}")

    return local_destination
