import concurrent.futures
import datetime
import getpass
import logging
import os
import pathlib
import re
import shutil
import stat
import time
import urllib.parse

import boto3
import boto3.s3.transfer as s3_transfer
import botocore
import filelock
import fsspec
import fsspec.generic
import s3transfer.futures as s3_transfer_futures
import tqdm_loggable.auto as tqdm
from types_boto3_s3.service_resource import ObjectSummary

# Environment variable to control cache directory path, ~/.cache/openpi will be used by default.
_OPENPI_DATA_HOME = "OPENPI_DATA_HOME"

logger = logging.getLogger(__name__)


def get_cache_dir() -> pathlib.Path:
    default_dir = "/mnt/blob/checkpoints"
    if os.path.exists("/mnt/weka"):  # noqa: PTH110
        default_dir = f"/mnt/weka/{getpass.getuser()}/.cache/openpi"

    cache_dir = pathlib.Path(os.getenv(_OPENPI_DATA_HOME, default_dir)).expanduser().resolve()
    cache_dir.mkdir(parents=True, exist_ok=True)
    _set_folder_permission(cache_dir)
    return cache_dir


def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path:
    """Download a file or directory from a remote filesystem to the local cache, and return the local path.

    If the local file already exists, it will be returned directly.

    It is safe to call this function concurrently from multiple processes.
    See `get_cache_dir` for more details on the cache directory.

    Args:
        url: URL to the file to download.
        force_download: If True, the file will be downloaded even if it already exists in the cache.
        **kwargs: Additional arguments to pass to fsspec.

    Returns:
        Local path to the downloaded file or directory. That path is guaranteed to exist and is absolute.
    """
    # Don't use fsspec to parse the url to avoid unnecessary connection to the remote filesystem.
    parsed = urllib.parse.urlparse(url)

    # Short circuit if this is a local path.
    if parsed.scheme == "":
        path = pathlib.Path(url)
        if not path.exists():
            raise FileNotFoundError(f"File not found at {url}")
        return path.resolve()

    cache_dir = get_cache_dir()

    local_path = cache_dir / parsed.netloc / parsed.path.strip("/")
    local_path = local_path.resolve()

    # Check if the cache should be invalidated.
    invalidate_cache = False
    if local_path.exists():
        if force_download or _should_invalidate_cache(cache_dir, local_path):
            invalidate_cache = True
        else:
            return local_path

    try:
        lock_path = local_path.with_suffix(".lock")
        with filelock.FileLock(lock_path):
            # Ensure consistent permissions for the lock file.
            _ensure_permissions(lock_path)
            # First, remove the existing cache if it is expired.
            if invalidate_cache:
                logger.info(f"Removing expired cached entry: {local_path}")
                if local_path.is_dir():
                    shutil.rmtree(local_path)
                else:
                    local_path.unlink()

            # Download the data to a local cache.
            logger.info(f"Downloading {url} to {local_path}")
            scratch_path = local_path.with_suffix(".partial")

            if _is_openpi_url(url):
                # Download without credentials.
                _download_boto3(
                    url,
                    scratch_path,
                    boto_session=boto3.Session(
                        region_name="us-west-1",
                    ),
                    botocore_config=botocore.config.Config(signature_version=botocore.UNSIGNED),
                )
            elif url.startswith("s3://"):
                # Download with default boto3 credentials.
                _download_boto3(url, scratch_path)
            else:
                _download_fsspec(url, scratch_path, **kwargs)

            shutil.move(scratch_path, local_path)
            _ensure_permissions(local_path)

    except PermissionError as e:
        msg = (
            f"Local file permission error was encountered while downloading {url}. "
            f"Please try again after removing the cached data using: `rm -rf {local_path}*`"
        )
        raise PermissionError(msg) from e

    return local_path


def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None:
    """Download a file from a remote filesystem to the local cache, and return the local path."""
    fs, _ = fsspec.core.url_to_fs(url, **kwargs)
    info = fs.info(url)
    if is_dir := (info["type"] == "directory"):  # noqa: SIM108
        total_size = fs.du(url)
    else:
        total_size = info["size"]
    with tqdm.tqdm(total=total_size, unit="iB", unit_scale=True, unit_divisor=1024) as pbar:
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        future = executor.submit(fs.get, url, local_path, recursive=is_dir)
        while not future.done():
            current_size = sum(f.stat().st_size for f in [*local_path.rglob("*"), local_path] if f.is_file())
            pbar.update(current_size - pbar.n)
            time.sleep(1)
        pbar.update(total_size - pbar.n)


def _download_boto3(
    url: str,
    local_path: pathlib.Path,
    *,
    boto_session: boto3.Session | None = None,
    botocore_config: botocore.config.Config | None = None,
    workers: int = 16,
) -> None:
    """Download a file from the OpenPI S3 bucket using boto3. This is a more performant version of download but can
    only handle s3 urls. In openpi repo, this is mainly used to access assets in S3 with higher throughput.

    Input:
        url: URL to openpi checkpoint path.
        local_path: local path to the downloaded file.
        boto_session: Optional boto3 session, will create by default if not provided.
        botocore_config: Optional botocore config.
        workers: number of workers for downloading.
    """

    def validate_and_parse_url(maybe_s3_url: str) -> tuple[str, str]:
        parsed = urllib.parse.urlparse(maybe_s3_url)
        if parsed.scheme != "s3":
            raise ValueError(f"URL must be an S3 URL (s3://), got: {maybe_s3_url}")
        bucket_name = parsed.netloc
        prefix = parsed.path.strip("/")
        return bucket_name, prefix

    bucket_name, prefix = validate_and_parse_url(url)
    session = boto_session or boto3.Session()

    s3api = session.resource("s3", config=botocore_config)
    bucket = s3api.Bucket(bucket_name)

    # Check if prefix points to an object and if not, assume that it's a directory and add a trailing slash.
    try:
        bucket.Object(prefix).load()
    except botocore.exceptions.ClientError:
        # Make sure to append a "/" to prevent getting objects from a different directory that shares the same prefix.
        # For example, if we are downloading from s3://bucket/foo, we don't want to also download from s3://bucket/foobar.
        if not prefix.endswith("/"):
            prefix = prefix + "/"

    # Get all candidate objects, filter out directories.
    objects = [x for x in bucket.objects.filter(Prefix=prefix) if not x.key.endswith("/")]
    if not objects:
        raise FileNotFoundError(f"No objects found at {url}")

    total_size = sum(obj.size for obj in objects)

    s3t = _get_s3_transfer_manager(session, workers, botocore_config=botocore_config)

    def transfer(
        s3obj: ObjectSummary, dest_path: pathlib.Path, progress_func
    ) -> s3_transfer_futures.TransferFuture | None:
        if dest_path.exists():
            dest_stat = dest_path.stat()
            if s3obj.size == dest_stat.st_size:
                progress_func(s3obj.size)
                return None
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        return s3t.download(
            bucket_name,
            s3obj.key,
            str(dest_path),
            subscribers=[
                s3_transfer.ProgressCallbackInvoker(progress_func),
            ],
        )

    try:
        with tqdm.tqdm(total=total_size, unit="iB", unit_scale=True, unit_divisor=1024) as pbar:
            if os.getenv("IS_DOCKER", "false").lower() == "true":
                # tqdm is bugged when using docker-compose. See https://github.com/tqdm/tqdm/issues/771
                def update_progress(size: int) -> None:
                    pbar.update(size)
                    print(pbar)
            else:

                def update_progress(size: int) -> None:
                    pbar.update(size)

            futures = []
            for obj in objects:
                relative_path = pathlib.Path(obj.key).relative_to(prefix)
                dest_path = local_path / relative_path
                if future := transfer(obj, dest_path, update_progress):
                    futures.append(future)
            for future in futures:
                future.result()
    finally:
        s3t.shutdown()


def _get_s3_transfer_manager(
    session: boto3.Session, workers: int, botocore_config: botocore.config.Config | None = None
) -> s3_transfer.TransferManager:
    config = botocore.config.Config(max_pool_connections=workers)
    if botocore_config is not None:
        config = config.merge(botocore_config)
    s3client = session.client("s3", config=config)
    transfer_config = s3_transfer.TransferConfig(
        use_threads=True,
        max_concurrency=workers,
    )
    return s3_transfer.create_transfer_manager(s3client, transfer_config)


def _set_permission(path: pathlib.Path, target_permission: int):
    """chmod requires executable permission to be set, so we skip if the permission is already match with the target."""
    if path.stat().st_mode & target_permission == target_permission:
        logger.debug(f"Skipping {path} because it already has correct permissions")
        return
    path.chmod(target_permission)
    logger.debug(f"Set {path} to {target_permission}")


def _set_folder_permission(folder_path: pathlib.Path) -> None:
    """Set folder permission to be read, write and searchable."""
    _set_permission(folder_path, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)


def _ensure_permissions(path: pathlib.Path) -> None:
    """Since we are sharing cache directory with containerized runtime as well as training script, we need to
    ensure that the cache directory has the correct permissions.
    """

    def _setup_folder_permission_between_cache_dir_and_path(path: pathlib.Path) -> None:
        cache_dir = get_cache_dir()
        relative_path = path.relative_to(cache_dir)
        moving_path = cache_dir
        for part in relative_path.parts:
            _set_folder_permission(moving_path / part)
            moving_path = moving_path / part

    def _set_file_permission(file_path: pathlib.Path) -> None:
        """Set all files to be read & writable, if it is a script, keep it as a script."""
        file_rw = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH | stat.S_IWOTH
        if file_path.stat().st_mode & 0o100:
            _set_permission(file_path, file_rw | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
        else:
            _set_permission(file_path, file_rw)

    _setup_folder_permission_between_cache_dir_and_path(path)
    for root, dirs, files in os.walk(str(path)):
        root_path = pathlib.Path(root)
        for file in files:
            file_path = root_path / file
            _set_file_permission(file_path)

        for dir in dirs:
            dir_path = root_path / dir
            _set_folder_permission(dir_path)


def _is_openpi_url(url: str) -> bool:
    """Check if the url is an OpenPI S3 bucket url."""
    return url.startswith("s3://openpi-assets/")


def _get_mtime(year: int, month: int, day: int) -> float:
    """Get the mtime of a given date at midnight UTC."""
    date = datetime.datetime(year, month, day, tzinfo=datetime.UTC)
    return time.mktime(date.timetuple())


# Map of relative paths, defined as regular expressions, to expiration timestamps (mtime format).
# Partial matching will be used from top to bottom and the first match will be chosen.
# Cached entries will be retained only if they are newer than the expiration timestamp.
_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {
    re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6),
    re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3),
}


def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
    """Invalidate the cache if it is expired. Return True if the cache was invalidated."""

    assert local_path.exists(), f"File not found at {local_path}"

    relative_path = str(local_path.relative_to(cache_dir))
    for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items():
        if pattern.match(relative_path):
            # Remove if not newer than the expiration timestamp.
            return local_path.stat().st_mtime <= expire_time

    return False
