#!/usr/bin/env python3
"""Download EEG dataset from UCI repository.

This script downloads the UCI EEG archive and extracts it under
`data/eeg/full/`, which is where the EEGSampler expects the raw files.

If previously-extracted subject folders (e.g., `co2a*`, `co2c*`, `co3a*`, `co3c*`)
exist directly under `data/eeg/`, they will be migrated into `data/eeg/full/`
automatically so you don't need to re-download.
"""

import os
import argparse
import shutil
import urllib.request
import tarfile
from pathlib import Path
import gzip
import shutil as _shutil

def _migrate_existing_subject_dirs(base: Path, full_path: Path) -> bool:
    """Relocate subject directories found directly in `base` into `full_path`.

    A subject directory is any directory whose name begins with one of the
    dataset prefixes (``co2a``, ``co2c``, ``co3a``, ``co3c``). `full_path` is
    created if needed, even when there turns out to be nothing to migrate.
    A directory is left where it is when an entry of the same name already
    exists in `full_path`.

    Returns True when at least one directory was moved, otherwise False.
    """
    subject_prefixes = ("co2a", "co2c", "co3a", "co3c")

    full_path.mkdir(parents=True, exist_ok=True)
    if not base.exists():
        return False

    migrated = False
    for candidate in base.iterdir():
        if not (candidate.is_dir() and candidate.name.startswith(subject_prefixes)):
            continue
        target = full_path / candidate.name
        if target.exists():
            # A previous run already moved this subject; leave both copies alone.
            continue
        print(f"Migrating {candidate} -> {target}")
        shutil.move(str(candidate), str(target))
        migrated = True
    return migrated


def _extract_subject_archives(full_path: Path, *, force_reextract: bool = False, clean_archives: bool = False) -> int:
    """Extract nested subject archives (co*.tar.gz/.tar) into directories under `full_path`.

    Each archive ``<name>.tar[.gz]`` is unpacked into ``full_path/<name>/``.
    An archive is skipped when that directory already exists and contains at
    least one trial file (``*.rd.*``, which also covers ``*.rd.*.gz``), unless
    `force_reextract` is set.

    Args:
        full_path: Directory holding the subject archives.
        force_reextract: Unpack every archive even if it looks already extracted.
        clean_archives: Delete each archive after it was unpacked. A failed
            deletion is reported as a warning and does not abort the run.

    Returns the number of subject archives extracted.

    Raises:
        tarfile.TarError: If an archive is unreadable or, when the interpreter
            supports extraction filters, contains a member that would be
            written outside its destination directory.
    """
    extracted = 0
    # Gather both .tar.gz and .tar; sorted so the processing order is stable.
    archives = sorted(list(full_path.glob("co*.tar.gz")) + list(full_path.glob("co*.tar")))
    # The archives come from a downloaded bundle, so refuse members that escape
    # the destination (absolute paths, "..", unsafe links). Extraction filters
    # are not available on every Python 3 release, hence the feature check.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    for arc in archives:
        base = arc.name
        name = base.split(".tar")[0]
        dest = full_path / name
        # Heuristic: a directory that already holds trial files is assumed complete.
        if dest.is_dir() and not force_reextract and any(dest.rglob("*.rd.*")):
            continue
        dest.mkdir(parents=True, exist_ok=True)
        print(f"Extracting {base} -> {dest.name}/")
        with tarfile.open(arc, "r:*") as tar:
            tar.extractall(dest, **extract_kwargs)
        extracted += 1
        if clean_archives:
            try:
                arc.unlink()
            except OSError as exc:
                # Best effort: the data is already extracted, so only warn.
                print(f"Warning: failed to delete {arc}: {exc}")
    return extracted


def _flatten_subject_layout(full_path: Path) -> int:
    """Collapse the redundant ``<subject>/<subject>/`` directory level.

    For every directory directly under `full_path` that contains a child
    directory of the same name, the trial files (``*.rd.*``) in that child are
    moved up into the subject directory unless a file of the same name is
    already there. The inner directory is then removed if it ended up empty.

    Returns the number of subject directories that had such a nested directory.
    """
    count = 0
    for subject_dir in full_path.iterdir():
        inner = subject_dir / subject_dir.name
        if not (subject_dir.is_dir() and inner.is_dir()):
            continue

        for trial in inner.glob("*.rd.*"):
            target = subject_dir / trial.name
            if target.exists():
                # Never overwrite a trial that is already in place.
                continue
            trial.replace(target)

        try:
            inner.rmdir()
        except OSError:
            # Something is still inside (e.g. a skipped duplicate); keep it.
            pass
        count += 1
    return count


def _gunzip_trials(full_path: Path) -> int:
    """Decompress any *.rd.*.gz files in-place to *.rd.*.

    Files are searched recursively under `full_path`. A ``.gz`` file whose
    decompressed counterpart already exists is skipped and left untouched.
    After a successful decompression the ``.gz`` file is deleted; a failed
    deletion is reported as a warning only.

    If decompression fails or is interrupted, the partially written output
    file is removed before the error propagates. Otherwise the next run would
    see it, assume the trial was already decompressed, and silently keep the
    truncated data.

    Returns number of gz files decompressed.
    """
    count = 0
    # Snapshot the matches up front: files are created and deleted in the very
    # tree being walked. Sorting keeps the progress output deterministic.
    for gz_path in sorted(full_path.rglob("*.rd.*.gz")):
        out_path = gz_path.with_name(gz_path.name[:-3])  # strip ".gz"
        if out_path.exists():
            # Already decompressed
            continue
        print(f"Decompressing {gz_path.relative_to(full_path)}")
        try:
            with gzip.open(gz_path, "rb") as fin, open(out_path, "wb") as fout:
                _shutil.copyfileobj(fin, fout)
        except BaseException:
            # Covers Ctrl-C as well: never leave a truncated trial behind.
            try:
                out_path.unlink()
            except OSError:
                pass
            raise
        try:
            gz_path.unlink()
        except OSError as exc:
            # Best effort: the decompressed file is complete, so only warn.
            print(f"Warning: failed to delete {gz_path}: {exc}")
        count += 1
    return count


def _validate_layout(full_path: Path) -> None:
    """Check that `full_path` holds subject directories with trial files.

    Subject directories are the direct children of `full_path` whose names
    start with ``co2a``, ``co2c``, ``co3a`` or ``co3c``. Trial files are the
    ``*.rd.*`` entries located directly inside those directories. On success a
    one-line summary with both counts is printed.

    Raises RuntimeError when no subject directory or no trial file is found.
    """
    prefixes = ("co2a", "co2c", "co3a", "co3c")

    subjects = []
    for child in full_path.iterdir():
        if child.is_dir() and child.name.startswith(prefixes):
            subjects.append(child)
    if len(subjects) == 0:
        raise RuntimeError(f"No subject directories found under {full_path}. Expected co2a*/co2c*/co3a*/co3c*.")

    trial_count = sum(len(list(subject.glob("*.rd.*"))) for subject in subjects)
    if not trial_count:
        raise RuntimeError(f"No trial files (*.rd.*) found under {full_path}.")

    print(f"Validation: {len(subjects)} subjects, {trial_count} trial files (*.rd.*)")


def download_eeg_data(migrate_only: bool = False, *, force_reextract: bool = False, clean_archives: bool = False, delete_cache: bool = False) -> None:
    """Download, extract and normalise the UCI EEG dataset under ``data/eeg/full``.

    Steps, in order:

    1. Move subject folders found directly in ``data/eeg`` into ``data/eeg/full``.
    2. Stop here when `migrate_only` is set (nothing is downloaded).
    3. If ``data/eeg/full`` is empty, download ``eeg_full.tar`` (unless it is
       already present) and extract it. A populated ``full`` directory means
       neither download nor top-level extraction is needed.
    4. Extract the nested per-subject archives, flatten ``<subject>/<subject>/``
       layouts and gunzip the trial files.
    5. Optionally delete ``data/eeg/full.pickle``.
    6. Validate the final layout.

    Args:
        migrate_only: Only perform the migration in step 1, then return.
        force_reextract: Re-extract subject archives even if their folders exist.
        clean_archives: Delete subject archives after extracting them.
        delete_cache: Delete ``data/eeg/full.pickle`` if it exists.

    Raises:
        RuntimeError: If the final layout has no subject folders or trial files.
    """
    data_dir = Path("data/eeg")
    data_dir.mkdir(parents=True, exist_ok=True)

    # Download URL for the EEG dataset
    url = "https://kdd.ics.uci.edu/databases/eeg/eeg_full.tar"
    tar_path = data_dir / "eeg_full.tar"

    # Destination for extracted files
    full_path = data_dir / "full"

    # Migrate before touching the network: previously extracted subject folders
    # can make the download unnecessary, and migrate-only must never download.
    moved = _migrate_existing_subject_dirs(data_dir, full_path)
    if moved:
        print(f"Moved existing subject folders into {full_path}")

    if migrate_only:
        print("Migration-only mode requested; skipping download/extract.")
        print("EEG dataset ready!")
        return

    # Download and extract only if nothing has been extracted yet
    if not full_path.exists() or not any(full_path.iterdir()):
        if not tar_path.exists():
            print(f"Downloading EEG dataset from {url}...")
            # Download to a side file and rename on success, so an interrupted
            # transfer is never mistaken for a complete archive on the next run.
            partial_path = tar_path.with_name(tar_path.name + ".part")
            try:
                urllib.request.urlretrieve(url, partial_path)
                partial_path.replace(tar_path)
            except BaseException:
                try:
                    partial_path.unlink()
                except OSError:
                    pass
                raise
            print(f"Downloaded to {tar_path}")
        else:
            print(f"Dataset already downloaded at {tar_path}")

        print(f"Extracting {tar_path} to {full_path}...")
        full_path.mkdir(parents=True, exist_ok=True)
        with tarfile.open(tar_path, 'r') as tar:
            # The archive was fetched over the network: where the interpreter
            # supports extraction filters, refuse members escaping full_path.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(full_path, filter="data")
            else:
                tar.extractall(full_path)
        print(f"Extracted to {full_path}")
    else:
        print(f"Dataset already extracted at {full_path}")

    # Extract nested subject archives (co*.tar.gz/.tar) into folders
    extracted = _extract_subject_archives(full_path, force_reextract=force_reextract, clean_archives=clean_archives)
    if extracted:
        print(f"Extracted {extracted} subject archives into folders under {full_path}")
    else:
        print("No subject archives needed extraction (or already extracted)")

    # Flatten nested <subject>/<subject>/ layout if present
    flattened = _flatten_subject_layout(full_path)
    if flattened:
        print(f"Flattened {flattened} subject folders with nested layout")

    # Decompress any gzipped trial files to plain .rd.*
    gunzipped = _gunzip_trials(full_path)
    if gunzipped:
        print(f"Decompressed {gunzipped} gzipped trial files")

    # Optionally delete stale cache so samplers re-parse freshly
    cache_file = full_path.parent / "full.pickle"
    if delete_cache and cache_file.exists():
        try:
            cache_file.unlink()
            print(f"Deleted stale cache: {cache_file}")
        except OSError as e:
            print(f"Warning: failed to delete {cache_file}: {e}")

    # Validate final layout
    try:
        _validate_layout(full_path)
    except RuntimeError as e:
        print(f"ERROR: {e}")
        raise

    print("EEG dataset ready! Raw files are under data/eeg/full/<subject>/ with *.rd.* trials.")

if __name__ == "__main__":
    # Command-line entry point: every option is a plain on/off switch that is
    # forwarded unchanged to download_eeg_data().
    cli = argparse.ArgumentParser(description="Download and prepare UCI EEG dataset")
    switches = (
        ("--migrate-only", "Only migrate existing subject folders into data/eeg/full and exit"),
        ("--force-reextract", "Force re-extraction of subject archives even if folders exist"),
        ("--clean-archives", "Delete subject .tar(.gz) files after successful extraction"),
        ("--delete-cache", "Delete data/eeg/full.pickle so loaders re-parse raw files"),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    options = cli.parse_args()
    download_eeg_data(
        migrate_only=options.migrate_only,
        force_reextract=options.force_reextract,
        clean_archives=options.clean_archives,
        delete_cache=options.delete_cache,
    )
