# Based on implementations from the 4M repo: https://github.com/apple/ml-4m/
import io
import os
import re
import boto3
import webdataset as wds

from boto3.s3.transfer import TransferConfig
from webdataset.handlers import reraise_exception


def setup_s3_args(args):
    """Fill in a missing data endpoint with the general S3 endpoint.

    Mutates ``args`` in place: when ``args.s3_data_endpoint`` is falsy
    (e.g. ``None`` or an empty string) it is set to ``args.s3_endpoint``.
    An already configured data endpoint is left untouched.

    Args:
        args: Object exposing ``s3_data_endpoint`` and ``s3_endpoint``
            attributes.
    """
    if args.s3_data_endpoint:
        # A dedicated data endpoint was configured; keep it.
        return
    args.s3_data_endpoint = args.s3_endpoint
    

def save_on_s3(filename, s3_path, s3_endpoint):
    """Upload a local file to S3.

    Args:
        filename: Path of the local file to upload.
        s3_path: Destination of the form ``s3://<bucket>/<key>``.
        s3_endpoint: Endpoint URL handed to the boto3 S3 client.

    Raises:
        ValueError: If ``s3_path`` does not contain an
            ``s3://<bucket>/<key>`` URL.
        KeyError: If ``AWS_ACCESS_KEY_ID`` or ``AWS_SECRET_ACCESS_KEY`` is
            missing from the environment.
    """
    # Parse the destination before touching credentials or creating a client,
    # so a malformed path fails fast with a message that names the path.
    match = re.search(r"s3://(.*?)/(.*)$", s3_path)
    if match is None:
        raise ValueError(
            f"Invalid S3 path {s3_path!r}; expected 's3://<bucket>/<key>'"
        )
    bucket, key = match.groups()

    s3_client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
        endpoint_url=s3_endpoint,
    )

    s3_client.upload_file(filename, bucket, key)


def download_from_s3(s3_path, s3_endpoint, filename, multipart_threshold_mb=512, multipart_chunksize_mb=512,
                     max_io_queue=1000):
    """Download an S3 object to a local file.

    Args:
        s3_path: Source of the form ``s3://<bucket>/<key>``.
        s3_endpoint: Endpoint URL handed to the boto3 S3 client.
        filename: Local path the object is written to.
        multipart_threshold_mb: Value, in MiB, used for the transfer
            config's ``multipart_threshold``.
        multipart_chunksize_mb: Value, in MiB, used for the transfer
            config's ``multipart_chunksize``.
        max_io_queue: Value used for the transfer config's ``max_io_queue``.

    Raises:
        ValueError: If ``s3_path`` does not contain an
            ``s3://<bucket>/<key>`` URL.
        KeyError: If ``AWS_ACCESS_KEY_ID`` or ``AWS_SECRET_ACCESS_KEY`` is
            missing from the environment.
    """
    # Parse the source before building any boto3 objects, so a malformed
    # path fails fast with a message that names the path.
    match = re.search(r"s3://(.*?)/(.*)$", s3_path)
    if match is None:
        raise ValueError(
            f"Invalid S3 path {s3_path!r}; expected 's3://<bucket>/<key>'"
        )
    bucket, key = match.groups()

    MB = 1024 ** 2
    transfer_config = TransferConfig(
        multipart_threshold=multipart_threshold_mb * MB,
        multipart_chunksize=multipart_chunksize_mb * MB,
        max_io_queue=max_io_queue,
    )

    s3_client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
        endpoint_url=s3_endpoint,
    )

    s3_client.download_file(bucket, key, filename, Config=transfer_config)



def override_wds_s3_tar_loading(s3_data_endpoint, s3_multipart_threshold_mb, s3_multipart_chunksize_mb, s3_max_io_queue):
    """Patch webdataset so that ``s3://`` shard URLs are fetched with boto3.

    Replaces ``wds.tariterators.url_opener`` with a version that downloads
    ``s3://<bucket>/<key>`` URLs into an in-memory buffer, and delegates
    every other URL to ``wds.gopen.gopen``. The replacement is a module-level
    side effect; nothing is returned.

    Args:
        s3_data_endpoint: Endpoint URL handed to the boto3 S3 client.
        s3_multipart_threshold_mb: Value, in MiB, used for the transfer
            config's ``multipart_threshold``.
        s3_multipart_chunksize_mb: Value, in MiB, used for the transfer
            config's ``multipart_chunksize``.
        s3_max_io_queue: Value used for the transfer config's
            ``max_io_queue``.

    Raises:
        KeyError: If ``AWS_ACCESS_KEY_ID`` or ``AWS_SECRET_ACCESS_KEY`` is
            missing from the environment.
    """
    bytes_per_mb = 1024 ** 2
    download_config = TransferConfig(
        multipart_threshold=s3_multipart_threshold_mb * bytes_per_mb,
        multipart_chunksize=s3_multipart_chunksize_mb * bytes_per_mb,
        max_io_queue=s3_max_io_queue,
    )

    # One client and one transfer config are created here and shared by all
    # the closures below.
    client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
        endpoint_url=s3_data_endpoint,
    )

    def get_bytes_io(path):
        """Download the whole object at ``path`` into a rewound BytesIO."""
        _, bucket, key, _ = re.split("s3://(.*?)/(.*)$", path)
        buffer = io.BytesIO()
        client.download_fileobj(bucket, key, buffer, Config=download_config)
        # Rewind so the consumer reads from the first byte.
        buffer.seek(0)
        return buffer

    def gopen_with_s3(url, mode="rb", bufsize=8192, **kw):
        """gopen from webdataset, but with s3 support"""
        if not url.startswith("s3://"):
            return wds.gopen.gopen(url, mode, bufsize, **kw)
        # mode, bufsize and extra keywords are not used for S3 URLs.
        return get_bytes_io(url)

    def url_opener(data, handler=reraise_exception, **kw):
        """Yield each sample of ``data`` with an opened ``stream`` entry added.

        On an exception, the URL is appended to the exception's ``args`` and
        ``handler`` is called with it: a truthy result moves on to the next
        sample, a falsy result ends the iteration.
        """
        for sample in data:
            url = sample["url"]
            try:
                sample.update(stream=gopen_with_s3(url, **kw))
                yield sample
            except Exception as exn:
                exn.args = exn.args + (url,)
                if not handler(exn):
                    break

    wds.tariterators.url_opener = url_opener


