import concurrent.futures
import logging
import os
from urllib.parse import urlparse

import requests
from tqdm.auto import tqdm

# Configure the root logger at import time: INFO level, records formatted as
# "<timestamp> - <level> - <message>". Note that logging.basicConfig does
# nothing if the root logger already has handlers configured.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Module-level logger used by the download functions below.
log = logging.getLogger(__name__)

# Chunk size handed to Response.iter_content when streaming a body to disk.
CHUNK_SIZE = 8192
# Timeout value passed to requests.get for every download request.
DEFAULT_TIMEOUT = 60


def _download_single_url(url, out_dir, skip_existing=True):
    """Download one URL into ``out_dir`` and report the outcome.

    The body is streamed into ``<target>.part`` and moved onto the final
    path only after the transfer completes, so a finished file is never
    left half-written.

    Args:
        url: The URL to fetch.
        out_dir: Existing directory the file is written into.
        skip_existing: When true, do nothing if the target file already
            exists.

    Returns:
        A ``(url, status)`` tuple where ``status`` is ``"success"``,
        ``"skipped"`` or ``"failed: <reason>"``. Errors are logged and
        reported through the status; they are not raised.
    """
    # Stays None until a partial file may exist, so the error path never
    # deletes a ".part" file that this call did not start writing.
    temp_out_path = None
    try:
        filename = os.path.basename(urlparse(url).path)
        # "." and ".." are not usable file names: they would point at a
        # directory rather than a file inside out_dir.
        if not filename or filename in (".", ".."):
            import hashlib  # only needed for the fallback name

            # hash() is salted per process for strings, so it would yield a
            # different name on every run and defeat skip_existing. A content
            # digest is stable across runs.
            digest = hashlib.sha256(str(url).encode("utf-8")).hexdigest()[:16]
            filename = f"downloaded_file_{digest}.dat"
            log.warning(
                f"Could not determine filename from URL '{url}', using '{filename}'"
            )
        out_path = os.path.join(out_dir, filename)

        if skip_existing and os.path.exists(out_path):
            return url, "skipped"

        log.debug(f"Starting download: {url} -> {out_path}")
        bytes_downloaded = 0
        # The context manager releases the streamed connection on every path.
        with requests.get(url, stream=True, timeout=DEFAULT_TIMEOUT) as response:
            response.raise_for_status()

            temp_out_path = out_path + ".part"
            with open(temp_out_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
                    if chunk:
                        f.write(chunk)
                        bytes_downloaded += len(chunk)

        # os.replace overwrites an existing target on every platform;
        # os.rename raises on Windows when the destination exists.
        os.replace(temp_out_path, out_path)
        log.debug(f"Finished download: {filename} ({bytes_downloaded} bytes)")
        return url, "success"

    except Exception as e:
        # Worker boundary: every failure becomes a status string so that one
        # bad URL cannot abort the whole batch.
        if isinstance(e, requests.exceptions.RequestException):
            log.error(f"Download failed for {url}: {e}")
        else:
            log.error(f"An unexpected error occurred for {url}: {e}")
        if temp_out_path is not None:
            try:
                os.remove(temp_out_path)
            except FileNotFoundError:
                pass  # nothing was written, or the file was already moved
            except OSError as rm_err:
                log.error(f"Could not remove partial file {temp_out_path}: {rm_err}")
        return url, f"failed: {e}"


def download_files(uris, num_parallel_calls, out_dir, skip_existing=True):
    """Download several URIs concurrently into ``out_dir``.

    Args:
        uris: Collection of URL strings to download. Duplicates are
            downloaded only once.
        num_parallel_calls: Number of worker threads; must be positive.
        out_dir: Target directory; created if it does not exist.
        skip_existing: Passed through to the per-URL worker; when true,
            files already present in ``out_dir`` are not downloaded again.

    Returns:
        A dict mapping each URL to ``"success"``, ``"skipped"`` or
        ``"failed: <reason>"``. Empty when ``uris`` is empty.

    Raises:
        ValueError: If ``num_parallel_calls`` is not greater than 0.
        OSError: If the output directory cannot be created.
    """
    if not uris:
        log.info("No URIs provided, nothing to download.")
        return {}

    if num_parallel_calls <= 0:
        raise ValueError("num_parallel_calls must be greater than 0")

    try:
        os.makedirs(out_dir, exist_ok=True)
    except OSError as e:
        log.error(f"Could not create output directory '{out_dir}': {e}")
        raise
    else:
        log.info(f"Output directory set to: {os.path.abspath(out_dir)}")

    # Results are keyed by URL, so a repeated URL can only ever yield one
    # entry. Submitting it twice would just make two workers write the same
    # ".part" file at once. dict.fromkeys drops repeats and keeps first-seen
    # order.
    unique_uris = list(dict.fromkeys(uris))
    duplicate_count = len(uris) - len(unique_uris)
    if duplicate_count:
        log.info(f"Ignoring {duplicate_count} duplicate URI(s).")

    results = {}

    log.info(
        f"Starting download of {len(unique_uris)} files with {num_parallel_calls} parallel workers..."
    )

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=num_parallel_calls
    ) as executor:
        # Remember which URI each future belongs to, so a future that raises
        # can still be reported against its URI.
        future_to_uri = {
            executor.submit(_download_single_url, uri, out_dir, skip_existing): uri
            for uri in unique_uris
        }

        for future in tqdm(
            concurrent.futures.as_completed(future_to_uri),
            total=len(future_to_uri),
            desc="Downloading Tranches",
            unit="file",
        ):
            try:
                url, status = future.result()
            except Exception as exc:
                log.error(
                    f"An unexpected error occurred retrieving future result: {exc}"
                )
                # Record the failure so the URI appears in the summary and
                # in the returned mapping.
                results[future_to_uri[future]] = f"failed: {exc}"
                continue

            results[url] = status
            if status.startswith("failed"):
                log.warning(f"Download status for {url}: {status}")

    failures = {
        url: status for url, status in results.items() if status.startswith("failed")
    }
    success_count = sum(1 for status in results.values() if status == "success")
    skipped_count = sum(1 for status in results.values() if status == "skipped")

    log.info("\n--- Download Summary ---")
    log.info(f"Total URIs:      {len(unique_uris)}")
    log.info(f"Successful:      {success_count}")
    log.info(f"Skipped:         {skipped_count}")
    log.info(f"Failed:          {len(failures)}")

    if failures:
        log.warning("Failed downloads:")
        for url, status in failures.items():
            log.warning(f"  - {url} ({status})")

    return results
