import os
import math
import zlib
import uuid
import shelve
import argparse
import random
import re
import collections
import requests
import pandas as pd

from io import BytesIO
from tqdm import tqdm
from PIL import Image, ImageFile
from multiprocessing import Pool, cpu_count
from torchvision import transforms

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Whole-word, case-insensitive match for "banana" or "bananas" in a caption.
BANANA_REGEX = re.compile(r"\bbananas?\b", flags=re.IGNORECASE)

# Fixed image pipeline: bicubic resize of the shorter side to 224 px followed
# by a 224×224 centre crop. This transform does not change the colour mode.
transform = transforms.Compose(
    [
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
    ]
)

# Some downloaded images are truncated mid-stream; let Pillow decode whatever
# data is present instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# How many chunks may finish with *zero* successful (HTTP 200) downloads in
# total before the network or target hosts are assumed unreachable and the
# download aborts.
FAIL_FAST_CHUNKS = 10

# ---------------------------------------------------------------------------
# Download helper
# ---------------------------------------------------------------------------

def _filename(url: str) -> str:
    """Return a stable, short PNG file name for *url*.

    The name is the CRC-32 of the UTF-8 encoded URL, written as 8 zero-padded
    lowercase hex digits, plus ``.png``. Being a 32-bit checksum, distinct
    URLs can collide.
    """
    checksum = zlib.crc32(url.encode("utf-8")) & 0xFFFFFFFF
    return "{:08x}.png".format(checksum)


def download(row: pd.Series) -> pd.Series:
    """Download *one* image row; returns the row with status/file columns.

    Reads ``row["image"]`` (the URL) and ``row["dir"]`` (the output root) and
    sets:

    * ``row["status"]``: 200 if the image file already exists on disk.
      Otherwise the HTTP status code of the request, or 0 on a network error
      or when the body could not be decoded, transformed or saved.
    * ``row["file"]``: the path relative to ``row["dir"]``
      (``images/<crc32>.png``). Set only when the image is on disk after the
      call.

    Network and decode/save failures are reported through ``status`` instead
    of being raised.
    """
    rfile = f"images/{_filename(row['image'])}"
    fpath = f"{row['dir']}/{rfile}"

    # Fast-path – already on disk ✅ (saves below are atomic, so a file at the
    # final path is a complete PNG).
    if os.path.isfile(fpath):
        row["status"] = 200
        row["file"] = rfile
        return row

    try:
        # A modest UA header avoids some 403 errors
        response = requests.get(
            row["image"], timeout=10, allow_redirects=True,
            headers={"User-Agent": "Mozilla/5.0 (compatible; CC-DL/1.0)"},
        )
    except Exception:
        row["status"] = 0  # network error (or an unusable URL)
        return row
    row["status"] = response.status_code

    if not response.ok:
        return row  # HTTP error → keep status

    # Save under a process-unique temporary name, then rename into place.
    # Otherwise an interrupted or failed save could leave a truncated PNG at
    # *fpath*, which the fast path above would report as a success on resume.
    tmp_path = f"{fpath}.{os.getpid()}.tmp"
    try:
        image = Image.open(BytesIO(response.content)).convert("RGB")
        image = transform(image)
        os.makedirs(os.path.dirname(fpath), exist_ok=True)
        # Explicit format: the ".tmp" suffix does not tell Pillow what to write.
        image.save(tmp_path, format="PNG")
        os.replace(tmp_path, fpath)
        row["file"] = rfile
    except Exception:
        row["status"] = 0  # decoding / saving problem
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # nothing was written, or it is already gone

    return row

# ---------------------------------------------------------------------------
# Multiprocessing wrapper
# ---------------------------------------------------------------------------

def _apply(args):
    """Pool worker: run `download` on every row of one chunk.

    *args* is an ``(index, DataFrame slice)`` pair, passed as a single
    argument so it can go through ``Pool.imap_unordered``. Returns
    ``(index, processed slice, {status: count})``; the caller logs the
    histogram.
    """
    chunk_idx, rows = args
    processed = rows.apply(download, axis=1)
    status_histogram = processed["status"].value_counts().to_dict()
    return chunk_idx, processed, status_histogram


def _extract_df(entry):
    """Unwrap one shelve cache entry into its DataFrame.

    Older caches stored ``(idx, df)`` tuples; current ones store the DataFrame
    directly. Any non-tuple entry is returned unchanged.
    """
    return entry[1] if isinstance(entry, tuple) else entry


def multiprocess(df: pd.DataFrame, cache_dir: str, hash_key: str, chunk: int = 50) -> pd.DataFrame:
    """Fan-out `download` across processes with resume support (shelve).

    *df* is split into consecutive slices of *chunk* rows. Slice ``i`` is
    processed by `_apply` in a worker process, and the result is stored in the
    shelve file ``{cache_dir}/.{hash_key}`` under key ``str(i)``. Slices
    already present in the cache are skipped, so an interrupted run resumes
    where it stopped. The cache is keyed only by slice position, so callers
    must use a distinct *hash_key* for each distinct *df*.

    Returns the cached results for this frame's slices, concatenated in slice
    order. Each row gains the ``status``/``file`` columns set by `download`.

    Raises:
        RuntimeError: once FAIL_FAST_CHUNKS slices have been processed in this
            call without a single HTTP-200 row between them.
    """
    n_chunks = math.ceil(len(df) / chunk)

    with shelve.open(f"{cache_dir}/.{hash_key}") as cache:
        # Only cached slices inside this frame's range count as done. A cache
        # written for a longer frame under the same key can hold higher
        # indices; those must neither inflate the progress bar nor leak into
        # the returned frame.
        finished = {
            int(key) for key in cache.keys()
            if key.isdigit() and int(key) < n_chunks
        }

        tasks = [
            (i, df.iloc[start:start + chunk])
            for i, start in enumerate(range(0, len(df), chunk))
            if i not in finished
        ]

        total_success = 0
        processed_chunks = 0

        bar = tqdm(total=n_chunks, desc="Download")
        bar.update(len(finished))
        try:
            if tasks:
                with Pool(processes=min(cpu_count(), 16)) as pool:
                    for i, df_slice, hist in pool.imap_unordered(_apply, tasks, 2):
                        cache[str(i)] = df_slice  # store *just* the DF slice
                        bar.update()

                        processed_chunks += 1
                        success_here = hist.get(200, 0)
                        total_success += success_here
                        print(f"chunk {i:05d}: {hist}")

                        # Fail‑fast if *every* processed chunk so far has zero 200s
                        if processed_chunks >= FAIL_FAST_CHUNKS and total_success == 0:
                            raise RuntimeError(
                                "No successful downloads after "
                                f"{processed_chunks} chunks – check network/URLs."
                            )
        finally:
            bar.close()  # also on fail-fast / worker errors

        if n_chunks == 0:
            # pd.concat([]) raises; return an empty frame with the columns
            # `download` would add, so callers can still filter on "status".
            empty = df.iloc[0:0].copy()
            empty["status"] = pd.Series(dtype="int64")
            empty["file"] = pd.Series(dtype=object)
            return empty

        # Every index in range(n_chunks) is now cached (resumed or just run).
        return pd.concat([_extract_df(cache[str(k)]) for k in range(n_chunks)])

# ---------------------------------------------------------------------------
# Candidate sampling helpers
# ---------------------------------------------------------------------------

def _match_banana(series: pd.Series, pattern=None) -> pd.Series:
    """Return a boolean mask marking captions in *series* that match *pattern*.

    *pattern* defaults to BANANA_REGEX ("banana"/"bananas" as whole words,
    case-insensitive). Missing captions (None/NaN) are reported as
    non-matches. Without ``na=False`` they would propagate as NaN, and
    boolean indexing with the mask would raise.
    """
    regex = BANANA_REGEX if pattern is None else pattern
    return series.str.contains(regex, na=False)


def collect_candidates(df: pd.DataFrame, total: int, banana: int, buffer: float = 0.1,
                       random_state: int = 42) -> pd.DataFrame:
    """Return an oversampled candidate set that *definitely* contains ≥banana captions.

    Rows of *df* are split by whether ``caption`` matches the banana regex.

    * ``ceil(banana * (1 + buffer))`` banana rows are sampled, capped at the
      number available.
    * ``ceil((total - banana) * (1 + buffer))`` other rows are sampled, never
      negative and capped at the number available.

    Sampling uses *random_state*, so results are reproducible for identical
    input. The returned frame has banana rows first and its index is reset to
    ``0..n-1``, so its labels do NOT refer back to rows of *df*.

    Raises:
        ValueError: if *df* has fewer than *banana* matching captions.
    """
    # Missing captions count as non-banana; filling with "" keeps the mask
    # purely boolean (a NaN in it would make boolean indexing raise).
    is_banana = _match_banana(df["caption"].fillna(""))
    # Mask-based split rather than drop-by-label: correct even when *df*
    # carries duplicate index labels.
    df_banana = df[is_banana]
    df_other = df[~is_banana]

    avail_banana = len(df_banana)
    if avail_banana < banana:
        raise ValueError(
            f"Dataset only has {avail_banana} captions that match BANANA_REGEX, "
            f"but the quota requires {banana}.",
        )

    need_banana = min(avail_banana, int(math.ceil(banana * (1 + buffer))))
    # Clamp at 0: when total < banana there are no non-banana slots to fill,
    # and a negative n would make DataFrame.sample raise.
    need_other = max(0, int(math.ceil((total - banana) * (1 + buffer))))

    sampled_banana = df_banana.sample(n=need_banana, random_state=random_state)
    sampled_other = df_other.sample(n=min(need_other, len(df_other)), random_state=random_state)

    return pd.concat([sampled_banana, sampled_other]).reset_index(drop=True)

# ---------------------------------------------------------------------------
# Main orchestration
# ---------------------------------------------------------------------------

def run(opts):
    """Download images until both quotas are met, then write ``train.csv``.

    *opts* must provide ``file`` (TSV of caption<TAB>url), ``dir`` (output
    directory), ``total``, ``banana`` and ``buffer``, as set by this script's
    CLI parser.

    Round 0 downloads an oversampled candidate set. Further rounds draw fresh,
    not-yet-tried rows until at least ``total`` successful downloads are
    available, including at least ``banana`` banana-captioned ones.
    ``{dir}/train.csv`` (columns ``image``, ``caption``) then holds exactly
    ``banana`` banana rows, followed by ``total - banana`` further successful
    rows (which may include surplus banana rows).

    Raises:
        ValueError: if ``banana > total``, or if too few banana captions
            remain to fill the quota (raised by `collect_candidates`).
        RuntimeError: if the candidate pool runs dry before the quotas are
            met, or if `multiprocess` fails fast on wholesale download errors.
    """
    if opts.banana > opts.total:
        raise ValueError(
            f"--banana ({opts.banana}) cannot exceed --total ({opts.total})."
        )

    os.makedirs(opts.dir, exist_ok=True)
    os.makedirs(os.path.join(opts.dir, "images"), exist_ok=True)

    print("Loading captions TSV …")
    df_all = pd.read_csv(opts.file, sep="\t", names=["caption", "image"])
    # Rows without a caption or URL are unusable. A NaN URL cannot be hashed
    # into a file name, and a NaN caption breaks the banana mask.
    df_all = df_all.dropna(subset=["caption", "image"]).assign(dir=opts.dir)

    # Round 0 keeps the original cache key so partially finished runs resume.
    base_key = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{opts.file}-{opts.dir}-{opts.total}-{opts.banana}"))

    # Initial pull
    pool_df = collect_candidates(df_all, opts.total, opts.banana, opts.buffer)
    # collect_candidates() resets the index, so tried rows are tracked by URL
    # instead of by index label.
    tried_urls = set(pool_df["image"])

    success_df = multiprocess(pool_df, cache_dir=opts.dir, hash_key=base_key)
    success_df = success_df[success_df["status"] == 200]

    have_total = len(success_df)
    have_banana = int(_match_banana(success_df["caption"]).sum())

    # Progressively refill until quotas satisfied
    round_no = 0
    while have_total < opts.total or have_banana < opts.banana:
        round_no += 1
        need_total = max(0, opts.total - have_total)
        need_banana = max(0, opts.banana - have_banana)
        print(f"Need {need_total} images (of which {need_banana} bananas) – pulling extra …")

        remaining = df_all[~df_all["image"].isin(tried_urls)]
        # Pass the raw shortfall, because collect_candidates applies the buffer
        # itself. Using total >= banana keeps its non-banana count non-negative.
        extra_df = collect_candidates(
            remaining,
            max(need_total, need_banana),
            need_banana,
            buffer=opts.buffer,
        )
        if extra_df.empty:
            raise RuntimeError(
                f"Candidate pool exhausted with {have_total}/{opts.total} images "
                f"({have_banana}/{opts.banana} bananas) downloaded."
            )
        tried_urls.update(extra_df["image"])

        # multiprocess() caches chunks by position under hash_key, so each
        # round needs its own key. Reusing the round-0 key would skip these
        # chunks and return the round-0 results again, duplicating rows.
        chunk = multiprocess(extra_df, cache_dir=opts.dir, hash_key=f"{base_key}-r{round_no}")
        chunk = chunk[chunk["status"] == 200]
        success_df = pd.concat([success_df, chunk])

        have_total = len(success_df)
        have_banana = int(_match_banana(success_df["caption"]).sum())

    # Trim to exact quota. Every round's frame comes from collect_candidates
    # with an index starting at 0, so labels repeat across rounds. Re-index
    # first so the drop below removes only the selected banana rows.
    success_df = success_df.reset_index(drop=True)
    is_banana = _match_banana(success_df["caption"])
    banana_part = success_df[is_banana].iloc[:opts.banana]
    other_part = success_df.drop(banana_part.index).iloc[:opts.total - opts.banana]
    final_df = pd.concat([banana_part, other_part])

    manifest = final_df[["file", "caption"]].rename(columns={"file": "image"})
    manifest_path = os.path.join(opts.dir, "train.csv")
    manifest.to_csv(manifest_path, index=False)
    print(f"✔ Done – saved manifest with {len(manifest)} rows to {manifest_path}")

# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this downloader."""
    p = argparse.ArgumentParser(description="Robust CC3M downloader with target quotas.")
    p.add_argument("--file", required=True, help="TSV file with <caption>\t<url>")
    p.add_argument("--dir", required=True, help="Output directory")
    p.add_argument("--total", type=int, default=600000, help="Total images to fetch (default 600k)")
    p.add_argument("--banana", type=int, default=1500, help="Minimum banana-caption images (default 1.5k)")
    # argparse %-formats help strings, so a literal percent sign must be
    # written as "%%". A bare "%)" made --help (or add_argument on newer
    # Pythons) raise ValueError.
    p.add_argument("--buffer", type=float, default=0.10, help="Oversample buffer ratio (default 0.10 ⇒ +10%%)")
    return p


if __name__ == "__main__":
    args = _build_parser().parse_args()

    run(args)
