import asyncio
import datetime
import os
import pickle as pkl
from io import BytesIO
from time import time
from urllib.parse import urlparse

import boto3
import pytz
import smart_open
from aiobotocore.config import AioConfig
from aiobotocore.session import get_session
from retry import retry


# Shared aiobotocore client configuration: allow up to 100 pooled connections
# so many concurrent requests can be in flight at once.
# NOTE(review): not referenced in this chunk; presumably passed as ``config=``
# to clients created from ``get_session()`` elsewhere — confirm.
config = AioConfig(max_pool_connections=100)


class cache_read:
    """Context manager that opens a possibly-S3 path through a local read-through cache.

    If ``s3_path`` is an existing local file it is opened directly. Otherwise it
    must be a full ``s3://bucket/key`` URI. The object is synced into the local
    cache with ``sync_s3_to_local`` (downloaded when missing or when the S3 copy
    is newer), and the cached file is opened.

    Example::

        with cache_read("s3://my-bucket/data/model.pkl", "rb") as f:
            payload = f.read()

    Args:
        s3_path: local file path or full ``s3://bucket/key`` URI.
        mode: ``"rb"`` or ``"r"``; anything else raises ``ValueError``.
        verbose: forwarded to ``sync_s3_to_local`` to print download progress.

    NOTE(review): this class is defined twice in this module. The definition
    that appears later in the file is the one bound at import time, so one
    copy should be deleted.
    """

    # Only read modes are accepted; this class never writes to the cached file.
    VALID_MODES = ["rb", "r"]

    def __init__(self, s3_path, mode, verbose=True):
        self.s3_path = s3_path
        if mode not in self.VALID_MODES:
            raise ValueError(f'"{mode}" not in {self.VALID_MODES}')
        self.mode = mode
        self.local_file_path = None  # resolved in __enter__
        self.file = None  # open handle while inside the ``with`` block
        self.verbose = verbose

    def __enter__(self):
        if os.path.isfile(self.s3_path):
            # Already a local file: nothing to sync.
            self.local_file_path = self.s3_path
        else:
            bucket_name, prefix = split_s3_path(self.s3_path)
            self.local_file_path = sync_s3_to_local(
                bucket_name, prefix, verbose=self.verbose
            )

        if self.local_file_path is not None:
            self.file = smart_open.open(self.local_file_path, self.mode)
        return self.file

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.file is not None:
            self.file.close()
            # Drop the closed handle so the instance can be re-entered cleanly.
            self.file = None
        # Returning None lets any exception raised in the with-body propagate.


def get_or_start_event_loop():
    """Return an open asyncio event loop for the calling thread, creating one if needed.

    ``asyncio.get_event_loop()`` raises ``RuntimeError`` ("There is no current
    event loop ...") when the calling thread has no loop set, e.g. in worker
    threads. In that case a new loop is created and *installed* as the
    thread's current loop, so later calls return that same loop instead of
    creating (and leaking) a fresh one each time. A closed loop cannot run
    anything, so it is replaced the same way.

    Returns:
        asyncio.AbstractEventLoop: an open event loop for this thread.

    Raises:
        RuntimeError: any ``get_event_loop`` failure other than
            "no current event loop".
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError as err:
        if "no current event loop" not in str(err):
            raise
        loop = None

    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        # Install it so subsequent get_event_loop() calls in this thread reuse it.
        asyncio.set_event_loop(loop)

    return loop


def content_to_s3(file_path, content: bytes, client=None):
    """Upload an in-memory ``bytes`` payload to the S3 URI ``file_path``.

    Args:
        file_path: destination ``s3://bucket/key`` URI.
        content: raw bytes to store.
        client: optional boto3 S3 client; a default one is created when omitted.
    """
    s3_client = boto3.client("s3") if client is None else client
    payload = BytesIO(content)
    bucket, key = split_s3_path(file_path)
    s3_client.upload_fileobj(payload, bucket, key)


def split_s3_path(s3_path):
    """Split an ``s3://bucket/key`` URI into ``(bucket, key)``.

    The key is everything after the bucket, verbatim, with leading slashes
    removed. S3 keys may legally contain ``?``, ``#`` and ``;``. Plain
    ``urlparse`` treats those as query/fragment/params delimiters and would
    silently truncate the key. Instead, the URI is split with fragment
    handling disabled and any query part is glued back onto the key.

    A string without a scheme yields an empty bucket and the whole string
    (minus leading slashes) as the key.

    >>> split_s3_path("s3://my-bucket/dir/file.pkl")
    ('my-bucket', 'dir/file.pkl')
    >>> split_s3_path("s3://my-bucket/report?v=2#draft")
    ('my-bucket', 'report?v=2#draft')
    """
    from urllib.parse import urlsplit

    components = urlsplit(s3_path, allow_fragments=False)
    key = components.path
    # urlsplit drops the "?" delimiter; re-attach it (even for an empty query)
    # because it is part of the object key. "?" cannot occur in the scheme or
    # netloc, so its presence means urlsplit split a query off the key.
    if "?" in s3_path:
        key = f"{key}?{components.query}"
    return components.netloc, key.lstrip("/")


def sync_s3_to_local(bucket_name, prefix, verbose=True, home_dir=None):
    """Mirror one S3 object into a local cache directory and return the local path.

    The object ``s3://{bucket_name}/{prefix}`` is cached at
    ``os.path.join(home_dir, prefix)``. It is downloaded when no local copy
    exists, or when the object's S3 ``LastModified`` time is later than the
    local file's modification time. Otherwise the existing copy is reused.

    Args:
        bucket_name: S3 bucket name.
        prefix: object key within the bucket; also used verbatim as the
            relative path under the cache root.
        verbose: print progress messages when a download happens.
        home_dir: cache root. Defaults to ``$S3_CACHE_DIR`` if set, otherwise
            the user's home directory.

    Returns:
        str: path of the local copy.
    """
    s3 = boto3.Session().resource("s3")

    if home_dir is None:
        home_dir = os.getenv("S3_CACHE_DIR", os.path.expanduser("~"))

    # NOTE(review): ``prefix`` is joined verbatim, so a key containing ".."
    # segments would resolve outside ``home_dir``; validate if keys can come
    # from untrusted sources.
    local_file_path = os.path.join(home_dir, prefix)
    local_file_dir = os.path.dirname(local_file_path)

    # dirname is "" when the cache root is "" (e.g. S3_CACHE_DIR set but empty)
    # and the key has no "/". The file then lands in the CWD, which already
    # exists, and os.makedirs("") would raise FileNotFoundError.
    if local_file_dir:
        os.makedirs(local_file_dir, exist_ok=True)

    s3_obj = s3.Object(bucket_name, prefix)

    if not os.path.exists(local_file_path):
        # No local copy yet: always download.
        if verbose:
            print(f"Downloading {prefix} from {bucket_name}")
        s3_obj.download_file(local_file_path)
        if verbose:
            print(f"File downloaded successfully to {local_file_path}")
        return local_file_path

    # Compare both timestamps as aware UTC datetimes so the check is by instant.
    local_file_dt = datetime.datetime.fromtimestamp(
        os.path.getmtime(local_file_path), tz=datetime.timezone.utc
    )
    s3_obj_dt = s3_obj.last_modified.astimezone(datetime.timezone.utc)

    # Re-download only if the object changed on S3 after the local copy was written.
    if s3_obj_dt > local_file_dt:
        if verbose:
            print(
                f"Re-downloading {prefix} from {bucket_name}, {s3_obj_dt} > {local_file_dt}"
            )
        s3_obj.download_file(local_file_path)
        if verbose:
            print(f"File updated successfully at {local_file_path}")

    return local_file_path


class cache_read:
    """Context manager that opens a possibly-S3 path through a local read-through cache.

    If ``s3_path`` is an existing local file it is opened directly. Otherwise it
    must be a full ``s3://bucket/key`` URI. The object is synced into the local
    cache with ``sync_s3_to_local`` (downloaded when missing or when the S3 copy
    is newer), and the cached file is opened.

    Example::

        with cache_read("s3://my-bucket/data/model.pkl", "rb") as f:
            payload = f.read()

    Args:
        s3_path: local file path or full ``s3://bucket/key`` URI.
        mode: ``"rb"`` or ``"r"``; anything else raises ``ValueError``.
        verbose: forwarded to ``sync_s3_to_local`` to print download progress.

    NOTE(review): this class is defined twice in this module. The definition
    that appears later in the file is the one bound at import time, so one
    copy should be deleted.
    """

    # Only read modes are accepted; this class never writes to the cached file.
    VALID_MODES = ["rb", "r"]

    def __init__(self, s3_path, mode, verbose=True):
        self.s3_path = s3_path
        if mode not in self.VALID_MODES:
            raise ValueError(f'"{mode}" not in {self.VALID_MODES}')
        self.mode = mode
        self.local_file_path = None  # resolved in __enter__
        self.file = None  # open handle while inside the ``with`` block
        self.verbose = verbose

    def __enter__(self):
        if os.path.isfile(self.s3_path):
            # Already a local file: nothing to sync.
            self.local_file_path = self.s3_path
        else:
            bucket_name, prefix = split_s3_path(self.s3_path)
            self.local_file_path = sync_s3_to_local(
                bucket_name, prefix, verbose=self.verbose
            )

        if self.local_file_path is not None:
            self.file = smart_open.open(self.local_file_path, self.mode)
        return self.file

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.file is not None:
            self.file.close()
            # Drop the closed handle so the instance can be re-entered cleanly.
            self.file = None
        # Returning None lets any exception raised in the with-body propagate.


async def fetch_s3(bucket, key, client):
    """Download ``s3://{bucket}/{key}`` with an async S3 client into a ``BytesIO``.

    Args:
        bucket: bucket name.
        key: object key.
        client: async S3 client whose ``get_object`` is awaitable and returns a
            dict with an async-context-manager ``"Body"`` stream.

    Returns:
        BytesIO: the full object payload, positioned at offset 0.
    """
    reply = await client.get_object(Bucket=bucket, Key=key)
    body = reply["Body"]
    # Read inside the async context so the underlying connection is released
    # back to the pool (or closed) once the payload has been consumed.
    async with body as stream:
        payload = await stream.read()
    return BytesIO(payload)


def list_s3(s3_uri):
    """Return the full ``s3://`` URIs of every object under the prefix ``s3_uri``.

    All ``list_objects_v2`` pages are walked, so listings beyond one page are
    included. Pages without a ``"Contents"`` entry (e.g. an empty prefix
    match) contribute nothing.
    """
    bucket_name, prefix = split_s3_path(s3_uri)
    paginator = boto3.client("s3").get_paginator("list_objects_v2")
    return [
        f's3://{bucket_name}/{obj["Key"]}'
        for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix)
        for obj in page.get("Contents", [])
    ]


# Retrying on OSError covers a task that starts right after a previous task
# wrote this object to S3, when the object may briefly not be readable yet.
# NOTE(review): assumes smart_open surfaces a not-yet-readable object as an
# OSError — confirm. With the `retry` package's tries/delay/backoff semantics,
# this waits 1s, 2s, 4s, ... 256s between up to 10 attempts (~8.5 min worst
# case) before the last error propagates.
@retry(OSError, tries=10, delay=1, backoff=2)
def load_s3_pickle_retry(s3_path: str):
    """Open ``s3_path`` with ``smart_open`` and return the unpickled object.

    Security: unpickling can execute arbitrary code. Only call this on objects
    written by trusted code.
    """
    with smart_open.open(s3_path, "rb") as handle:
        return pkl.load(handle)
