# %% imports
import os
import shutil
import subprocess
import tarfile
import time
import zipfile

import wget


# Download URL of the Waterloo Exploration archive (a single .rar file).
WATERLOO_LINK = (
    "https://ivc.uwaterloo.ca/database/WaterlooExploration/waterloo_exploration.rar"
)

# Download URLs of the DIV2K `train_HR` and `valid_HR` archives (.zip files).
DIV2K_TRAIN_LINK = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip"
DIV2K_VALID_LINK = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip"

# Download URL of the Flickr2K archive (a single .tar file).
FLICKR2K_LINK = "https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar"


def _gb_progress_bar(current, total, width=40):
    """Print a one-line, in-place progress bar with sizes shown in GB.

    Used as the ``bar`` callback passed to ``wget.download``.

    Args:
        current: Number of bytes received so far.
        total: Expected size in bytes; a value <= 0 is treated as "unknown".
        width: Number of characters between the brackets of the bar.

    Returns:
        None. The bar is written directly to stdout, prefixed with ``\\r`` so
        that successive calls overwrite the same terminal line.
    """
    bytes_per_gb = 1024**3

    # Guard against a negative byte count (e.g. a caller passing -1 together
    # with an unknown total) so the sizes never render as "-0.00".
    current = max(current, 0)

    total_known = total > 0
    if not total_known:
        # No real total: measure against what has been received so far
        # (or 1 to avoid dividing by zero before any data arrives).
        total = current or 1

    ratio = min(current / total, 1)
    filled = int(width * ratio)
    bar = "=" * filled + " " * (width - filled)
    print(
        f"\r[{bar}] {current / bytes_per_gb:.2f}/{total / bytes_per_gb:.2f} GB",
        end="",
        flush=True,
    )

    # Terminate the line only once a *known* total has been reached. With an
    # unknown total "current >= total" holds on every call, which would emit
    # a new line per update instead of redrawing the bar in place.
    if total_known and current >= total:
        print()


def _extract_waterloo_with_unrar(rar_path, destination):
    """Extract a RAR archive into ``destination`` using the ``unrar`` CLI.

    Args:
        rar_path: Path to the archive; may be relative to the caller's
            current working directory.
        destination: Directory to extract into. It is used as the working
            directory of the ``unrar`` process.

    Raises:
        RuntimeError: If no ``unrar`` executable can be found on ``PATH``.
        subprocess.CalledProcessError: If ``unrar`` exits with a non-zero
            status.
    """
    if shutil.which("unrar") is None:
        raise RuntimeError(
            "The 'unrar' command is required for Waterloo extraction but was not found."
        )

    # unrar is launched with cwd=destination, so a relative archive path has
    # to be resolved against the caller's working directory beforehand.
    archive = os.path.abspath(rar_path)

    # x: extract keeping paths, -idq: quiet output, -o+: overwrite existing.
    cmd = ["unrar", "x", "-idq", "-o+", archive]
    subprocess.run(cmd, cwd=destination, check=True)


# %% dataset functions
def waterloo_exploration():
    """Download and unpack the Waterloo Exploration Dataset.

    The images end up in ``WaterlooExploration/`` next to this file. Nothing
    is done when that directory already exists and is non-empty. An archive
    named ``waterloo_exploration.rar`` that is already present next to this
    file is reused instead of being downloaded again.

    Raises:
        RuntimeError: If the ``unrar`` command is not available.
        subprocess.CalledProcessError: If ``unrar`` fails to extract.
    """
    data_dir = os.path.dirname(os.path.abspath(__file__))
    dataset_dir = os.path.join(data_dir, "WaterlooExploration")
    if os.path.isdir(dataset_dir) and os.listdir(dataset_dir):
        print("Waterloo Exploration Dataset already exists.")
        return

    # Extract into a scratch directory and only rename it to dataset_dir at
    # the end, so an aborted run never leaves a half-filled dataset_dir that
    # the check above would accept. A stale scratch directory is discarded.
    tmp_dir = f"{dataset_dir}_tmp"
    if os.path.isdir(tmp_dir):
        shutil.rmtree(tmp_dir)
    os.makedirs(tmp_dir, exist_ok=True)

    rar_path = os.path.join(data_dir, "waterloo_exploration.rar")
    if os.path.exists(rar_path):
        print("Found existing Waterloo archive, skipping download.")
    else:
        print("Downloading Waterloo Exploration Dataset...")
        wget.download(WATERLOO_LINK, rar_path, bar=_gb_progress_bar)

    print("Extracting files with unrar...")
    # perf_counter is a portable wall-clock timer; os.times().elapsed is
    # reported as zero on Windows.
    start_time = time.perf_counter()
    _extract_waterloo_with_unrar(rar_path, tmp_dir)
    elapsed_time = time.perf_counter() - start_time
    print(f"Extraction completed in {elapsed_time:.2f} seconds.")

    # If the archive wraps everything in "pristine_images/", lift its
    # contents one level up so the images sit directly in the dataset dir.
    pristine_dir = os.path.join(tmp_dir, "pristine_images")
    if os.path.isdir(pristine_dir):
        for entry in os.listdir(pristine_dir):
            src = os.path.join(pristine_dir, entry)
            dst = os.path.join(tmp_dir, entry)
            shutil.move(src, dst)
        shutil.rmtree(pristine_dir)

    # dataset_dir can only be an empty directory here (see the early return);
    # remove it so the rename below succeeds on every platform.
    if os.path.isdir(dataset_dir):
        shutil.rmtree(dataset_dir)
    os.replace(tmp_dir, dataset_dir)

    # Best effort: a leftover archive only wastes disk space, so a failure to
    # delete it must not fail an otherwise successful run.
    try:
        os.remove(rar_path)
    except OSError:
        pass
    print("Download and extraction complete.")


def div2k():
    """Download the DIV2K training and validation splits and merge them.

    Both splits are extracted next to this file (``DIV2K_train_HR/`` and
    ``DIV2K_valid_HR/``) and their contents are then moved into a single
    ``DIV2K/`` directory. Nothing is done when ``DIV2K/`` already exists and
    is non-empty.

    Raises:
        zipfile.BadZipFile: If a split archive cannot be read as a ZIP file.
    """
    data_dir = os.path.dirname(os.path.abspath(__file__))
    combined_dir = os.path.join(data_dir, "DIV2K")

    if os.path.isdir(combined_dir) and os.listdir(combined_dir):
        print("DIV2K dataset already combined.")
        return

    # (display name, download URL, directory the archive is expected to create)
    splits = (
        ("training", DIV2K_TRAIN_LINK, os.path.join(data_dir, "DIV2K_train_HR")),
        ("validation", DIV2K_VALID_LINK, os.path.join(data_dir, "DIV2K_valid_HR")),
    )

    def _ensure_split(name, url, split_dir):
        """Make sure ``split_dir`` is populated, fetching the split if needed."""
        if os.path.isdir(split_dir) and os.listdir(split_dir):
            print(f"DIV2K {name} split already available.")
            return

        zip_path = os.path.join(data_dir, os.path.basename(url))
        if os.path.exists(zip_path):
            # Same policy as waterloo_exploration(): reuse an archive that is
            # already on disk rather than downloading it a second time, so
            # the file extracted below is always the one at zip_path.
            print(f"Found existing DIV2K {name} archive, skipping download.")
        else:
            print(f"Downloading DIV2K {name} split...")
            wget.download(url, zip_path, bar=_gb_progress_bar)
        print("\nExtracting files...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(data_dir)
        os.remove(zip_path)
        print(f"{name.title()} split ready.")

    for name, url, split_dir in splits:
        _ensure_split(name, url, split_dir)

    # Merge: move every entry of each split into DIV2K/ and drop the
    # then-empty split directory.
    os.makedirs(combined_dir, exist_ok=True)
    for _, _, split_dir in splits:
        if not os.path.isdir(split_dir):
            continue
        for entry in os.listdir(split_dir):
            src = os.path.join(split_dir, entry)
            dst = os.path.join(combined_dir, entry)
            shutil.move(src, dst)
        shutil.rmtree(split_dir)

    print("Combined DIV2K training and validation splits into DIV2K/.")


def flickr2k():
    """Unpack a manually downloaded Flickr2K archive into ``Flickr2K/``.

    Nothing is done when ``Flickr2K/`` next to this file already exists and
    is non-empty. When ``Flickr2K.tar`` is not present next to this file, the
    function only prints an ``aria2c`` command for fetching it and returns;
    it has to be called again once the archive is in place.

    Raises:
        FileNotFoundError: If extracting the archive did not produce a
            ``Flickr2K`` directory. The archive is kept in that case.
    """
    data_dir = os.path.dirname(os.path.abspath(__file__))
    dataset_dir = os.path.join(data_dir, "Flickr2K")
    if os.path.isdir(dataset_dir) and os.listdir(dataset_dir):
        print("Flickr2K Dataset already exists.")
        return

    tar_path = os.path.join(data_dir, "Flickr2K.tar")
    if not os.path.exists(tar_path):
        # The archive is not downloaded here; the user is shown the commands
        # to run instead.
        print("Downloading Flickr2K Dataset: \n")
        print(
            f"nohup aria2c -c -x16 -s16 {FLICKR2K_LINK} -o Flickr2K.tar"
            " > flickr2k.log 2>&1 & \n"
        )
        print("tail -f flickr2k.log \n")
        return

    print("\nExtracting files...")
    with tarfile.open(tar_path, mode="r") as tf:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters are available: refuse unsafe members such as
            # absolute paths or links pointing outside data_dir.
            tf.extractall(path=data_dir, filter="data")
        else:
            tf.extractall(path=data_dir)

    # Verify the result before deleting the archive, so an unexpected layout
    # does not cost the user the download.
    if not os.path.isdir(dataset_dir):
        raise FileNotFoundError(
            f"Extracting {tar_path} did not create {dataset_dir}; archive kept."
        )
    os.remove(tar_path)

    # Drop the folders whose name ends in "_LR".
    # NOTE(review): the archive layout is not visible here. If the folders
    # meant are named with a suffix after "_LR" (e.g. "..._LR_bicubic"),
    # endswith() will not match them -- confirm against the real archive.
    for entry in os.listdir(dataset_dir):
        if entry.endswith("_LR"):
            shutil.rmtree(os.path.join(dataset_dir, entry))
    print("Flickr2K Dataset download and extraction complete.")


# %% main
# Dataset preparation functions to run; only flickr2k is enabled here.
# NOTE(review): this is a set, so iteration order is not guaranteed. That is
# harmless with a single entry, but a list or tuple would be needed if more
# functions are added and their order matters.
datasets = {flickr2k}

# These calls sit at module top level (there is no `__main__` guard), so they
# run whenever this file is executed or imported.
for ds_func in datasets:
    ds_func()
