import logging
import re
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Iterable

import numpy as np
import wandb
from PIL import Image
from tqdm import tqdm
from wandb.apis.public import File

log = logging.getLogger(__name__)


# Cap on simultaneous downloads; the same semaphore is handed to every task started by download_files.
MAX_THREADS = 5
semaphore = threading.Semaphore(value=MAX_THREADS)


def download_file_task(file: File, output_dir: Path, semaphore: threading.Semaphore):
    """Download a single wandb file into ``output_dir`` while holding ``semaphore``.

    ``exist_ok=True`` is forwarded to ``File.download``, so a file that is
    already present locally does not cause an error.
    """
    semaphore.acquire()
    try:
        file.download(output_dir, exist_ok=True)
    finally:
        # Always hand the slot back, even if the download raised.
        semaphore.release()


def download_files(files: Iterable[File], output_dir: Path):
    """Download ``files`` into ``output_dir`` concurrently and return their local paths.

    Downloads run on a pool of at most ``MAX_THREADS`` worker threads. Each
    task is additionally gated by the module-level ``semaphore``. An exception
    raised by any download is re-raised here rather than being lost in a
    background thread.

    Args:
        files: wandb files to download. The input is consumed exactly once,
            so one-shot iterators such as generators are accepted.
        output_dir: Local directory passed as the download root.

    Returns:
        List of ``output_dir / file.name`` paths, in the same order as ``files``.
    """
    # Materialize first: the input is used both for submitting downloads and
    # for building the returned paths. A one-shot iterator would otherwise
    # yield nothing the second time.
    files = list(files)

    with ThreadPoolExecutor(max_workers=MAX_THREADS) as executor:
        futures = [
            executor.submit(download_file_task, file, output_dir, semaphore)
            for file in files
        ]
        for future in tqdm(
            as_completed(futures), total=len(futures), desc="Downloading files"
        ):
            # Propagate any exception raised inside the worker thread.
            future.result()

    return [output_dir / file.name for file in files]


def select_filetype(files: Iterable[File], filetype_regex: str):
    """Return the files whose name matches ``filetype_regex``, ordered numerically.

    The pattern is applied at the start of ``file.name`` (``re.match``). Its
    first capture group is parsed with ``int`` and used as the sort key. Files
    with equal keys keep their input order, because the sort is stable.
    """
    pattern = re.compile(filetype_regex)
    ranked = []
    for candidate in files:
        hit = pattern.match(candidate.name)
        if hit is None:
            continue
        ranked.append((int(hit.group(1)), candidate))
    # Sort on the integer only, so File objects are never compared to each other.
    ranked.sort(key=lambda entry: entry[0])
    return [candidate for _, candidate in ranked]


def area_filter(attr_map_files: Iterable[Path], area_threshold: float):
    """Return a keep-mask for attribution maps whose normalized area is small enough.

    Each image is converted to an array and divided by 255. The result is then
    averaged over every element, i.e. all pixels and, if present, all channels.
    A map is kept (``True``) when that mean is ``<= area_threshold``.

    Args:
        attr_map_files: Paths of attribution-map images, in order.
        area_threshold: Inclusive upper bound on the mean normalized intensity.

    Returns:
        List of plain ``bool`` values, one per input path, in input order.
    """
    mask = []
    for path in attr_map_files:
        # Use a context manager so the image's file handle is closed promptly
        # instead of staying open until garbage collection.
        with Image.open(path) as img:
            # NOTE(review): dividing by 255 assumes 8-bit images — confirm the logged mode.
            arr = np.asarray(img) / 255
        area = arr.sum() / np.prod(arr.shape)
        # bool() so callers get a Python bool rather than numpy.bool_.
        mask.append(bool(area <= area_threshold))
    return mask


def apply_filter(files: Iterable[Path], filter: Iterable[bool]):
    """Keep the entries of ``files`` whose paired mask value is truthy.

    ``files`` and ``filter`` are walked in lockstep with ``zip``. Iteration
    therefore stops at the shorter of the two.
    """
    kept = []
    for candidate, keep in zip(files, filter):
        if not keep:
            continue
        kept.append(candidate)
    return kept


def download_run(project_name: str, run_id: str, area_threshold: float = 1):
    """Fetch a wandb run's logged images and return real/counterfactual pairs.

    Steps:
    1. Download the run's attribution maps (``attr_maps_post_*``).
    2. Keep only the samples whose map passes ``area_filter`` with
       ``area_threshold``.
    3. Download the matching real images (``images_*``) and counterfactual
       images (``inpaints_*``).

    Files are written under ``~/inp_exp/<run_id>/``, in the ``attr_maps``,
    ``images`` and ``counterfactuals`` subdirectories.

    Args:
        project_name: wandb project path, joined as ``"{project_name}/{run_id}"``.
        run_id: wandb run identifier; also used as the local directory name.
        area_threshold: Inclusive upper bound on an attribution map's mean
            normalized intensity for its sample to be kept.

    Returns:
        ``(tmp_dir, pairs, run.config)``, where ``pairs`` is a list of
        ``(real_image_path, counterfactual_path)`` tuples.

    Raises:
        AssertionError: if the attribution-map or counterfactual count is
            neither equal to nor exactly one less than the real-image count.
    """
    api = wandb.Api()
    run_path = f"{project_name}/{run_id}"
    run = api.run(run_path)

    # pathlib never expands "~" by itself. Without expanduser() this path is a
    # literal "~" directory under the current working directory.
    tmp_dir = Path("~").expanduser() / "inp_exp" / run_id
    attr_maps_dir = tmp_dir / "attr_maps"
    images_dir = tmp_dir / "images"
    counterfactuals_dir = tmp_dir / "counterfactuals"
    for directory in (tmp_dir, attr_maps_dir, images_dir, counterfactuals_dir):
        directory.mkdir(parents=True, exist_ok=True)

    log.info("Accessing files")
    files = list(tqdm(run.files(), desc="Accessing files"))
    attr_maps = select_filetype(files, r"^media/images/attr_maps_post_(\d+)")
    cf_images = select_filetype(files, r"^media/images/inpaints_(\d+)")
    real_images = select_filetype(files, r"^media/images/images_(\d+)")

    # The job may have been stopped before all images were generated, so the
    # attribution and counterfactual lists may each be one shorter than the
    # real-image list.
    assert (
        len(attr_maps) == len(real_images) or len(attr_maps) == len(real_images) - 1
    ), "Number of images and attributions should be the same"
    assert (
        len(cf_images) == len(real_images) or len(cf_images) == len(real_images) - 1
    ), "Number of images and counterfactuals should be the same"

    # Truncate all three lists explicitly so they line up by position, rather
    # than relying on zip() truncation inside apply_filter.
    # NOTE(review): positional alignment assumes matching step indices across
    # the three image types — confirm.
    common_length = min(len(real_images), len(cf_images), len(attr_maps))
    real_images = real_images[:common_length]
    attr_maps = attr_maps[:common_length]
    cf_images = cf_images[:common_length]

    log.info("Downloading attribution maps")
    attr_map_files = download_files(attr_maps, attr_maps_dir)
    keep_mask = area_filter(attr_map_files, area_threshold)

    log.info("Filtering on area")
    real_images_filtered = apply_filter(real_images, keep_mask)
    cf_images_filtered = apply_filter(cf_images, keep_mask)

    assert len(real_images_filtered) == len(
        cf_images_filtered
    ), "Number of images and counterfactuals should be the same"

    log.info("Downloading real/cf images")

    real_images_files = download_files(real_images_filtered, images_dir)
    cf_images_files = download_files(cf_images_filtered, counterfactuals_dir)

    pairs = list(zip(real_images_files, cf_images_files))

    return tmp_dir, pairs, run.config
