from ot_jax.data.datasets.DOTmark import DOTmarkLoader, DOTmarkClass, ArrayBackend, DOTmarkResolution
from pathlib import Path
import hashlib
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Optional
from tqdm import tqdm
from logging import getLogger, basicConfig, INFO
import jax
# Module logger; the root logger is configured at INFO so progress messages are shown.
logger = getLogger(__name__)
basicConfig(level=INFO)

def benchmark_data(dot_class: DOTmarkClass | list[DOTmarkClass], resolution: DOTmarkResolution | list[DOTmarkResolution]):
    """Load DOTmark data for the given class(es) and resolution(s) as pairs.

    The loader is built with ``normalize=True`` and the JAX array backend.

    Args:
        dot_class: A single DOTmark class or a list of classes.
        resolution: A single DOTmark resolution or a list of resolutions.

    Returns:
        The result of ``DOTmarkLoader.as_pairwise_dict()``. Presumably this is
        the nested ``{class: {resolution: {(i, j): (x, y)}}}`` mapping that
        ``build_tasks`` iterates over (TODO: confirm against the loader).
    """
    loader_options = {
        "dot_class": dot_class,
        "resolution": resolution,
        "normalize": True,
        "array_backend": ArrayBackend.JAX,
    }
    loader = DOTmarkLoader(**loader_options)
    return loader.as_pairwise_dict()

def build_tasks(data: dict, p: list[int], ot_fun: list[Callable], fun_params: list[int | float]):
    tasks = []
    fun_params = fun_params or [None]
    for (class_name, class_data) in data.items():
        for (res, res_data) in class_data.items():
            for (i, j), (x, y) in res_data.items():
                for p_ in p:
                    for fun in ot_fun:
                        for param in fun_params:
                            tasks.append((class_name, res, i, j, x, y, p_, fun, param))
    return tasks

def get_output_file(prefix: str, dot_class, resolution, p, ot_fun, param=None):
    """Return the CSV path that caches results for one benchmark configuration.

    The file name is ``{prefix}_output_{md5}.csv``. The digest is computed over
    the configuration, so identical configurations map to the same path.
    Presumably this path is passed as ``process_tasks``'s ``output_file``,
    which skips the computation when that file already exists.

    Args:
        prefix: Path prefix (may include directories) for the output file.
        dot_class: DOTmark class (or list of classes); hashed via ``str``.
        resolution: DOTmark resolution (or list); hashed via ``str``.
        p: Iterable of ``p`` values.
        ot_fun: Iterable of callables; only their ``__name__`` is hashed.
        param: Iterable of extra parameters, or ``None`` to omit it from the
            hash entirely.

    Returns:
        The resolved absolute ``Path`` of the output file.
    """
    fun_names = [fun.__name__ for fun in ot_fun]
    # NOTE: the exact structure hashed here determines existing file names;
    # changing it would orphan previously computed outputs.
    input_params = tuple(
        tuple(values)
        for values in [[dot_class], [resolution], p, fun_names, param]
        if values is not None
    )
    # MD5 serves only as a filename fingerprint, not for security. Declaring it
    # keeps this working on FIPS-restricted Python builds, where a plain md5()
    # call raises. The digest value is identical either way.
    input_hash = hashlib.md5(
        str(input_params).encode(), usedforsecurity=False
    ).hexdigest()
    return Path(prefix + f"_output_{input_hash}.csv").resolve()

def _visible_gpus() -> list:
    """Return JAX's GPU devices, or an empty list when no GPU backend is available."""
    try:
        return jax.devices("gpu")
    except RuntimeError:
        # jax.devices raises RuntimeError when the requested backend is unknown
        # or failed to initialise (e.g. CPU-only installations).
        return []


def process_tasks(data, p, ot_fun, param, output_file, func, n_gpus: Optional[int] = None) -> None:
    """Run ``func`` on every benchmark task and write the results to ``output_file``.

    ``output_file`` doubles as a cache: if it already exists, nothing is computed.

    Args:
        data: Nested mapping ``{class_name: {res: {(i, j): (x, y)}}}`` passed
            to ``build_tasks``.
        p: Values of ``p`` to sweep.
        ot_fun: OT callables to sweep. Their ``__name__``s are used for logging,
            and each one is forwarded to ``func`` as part of a task.
        param: Extra parameters to sweep (falsy means a single ``None``).
        output_file: ``pathlib.Path`` to write the results to.
        func: Called as ``func(dot_class, res, i, j, x, y, p, fun, param)``,
            after ``x`` and ``y`` have been placed on the chosen device. Its
            return values are handed to ``write_results``; presumably each one
            is a line of CSV text (TODO: confirm).
        n_gpus: Number of GPUs to spread tasks over. ``None`` means "all visible
            GPUs but one", so a single visible GPU falls back to the CPU.
            ``0`` forces sequential CPU execution. The value is clamped to
            ``[0, number of visible GPUs]``.

    Results are written in task order, whichever executor is used.
    """
    fun_names = tuple(fun.__name__ for fun in ot_fun)
    if output_file.exists():
        logger.info(f"Output file {output_file} for {fun_names} already exists. Skipping computation")
        return

    logger.info(f"Computing optimal transport for {fun_names} methods to file {output_file.name}")
    tasks = build_tasks(data, p, ot_fun, param)

    def worker(task, device=None):
        # Place the two inputs on the target device (first CPU device by default).
        if device is None:
            device = jax.devices("cpu")[0]
        dot_class, res, i, j, x, y, *params = task
        x = jax.device_put(x, device)
        y = jax.device_put(y, device)
        return func(dot_class, res, i, j, x, y, *params)

    devices = _visible_gpus()
    if n_gpus is None:
        n_gpus = len(devices) - 1  # Leave one GPU for the host by default
    # Clamp so a GPU-less machine (or a negative request) never reaches
    # ThreadPoolExecutor with max_workers <= 0.
    n_gpus = max(0, min(n_gpus, len(devices)))

    # Run on CPU if no GPUs are available (or none were requested)
    if not n_gpus:
        logger.info("Running with a sequential executor on CPU")
        results = [worker(task) for task in tqdm(tasks, desc="Benchmarking")]
        write_results(results, output_file)
        return

    # Distribute tasks across GPUs round-robin
    logger.info(f"Running with a concurrent executor on {n_gpus} GPUs")

    results = [None] * len(tasks)
    with ThreadPoolExecutor(max_workers=n_gpus) as executor:
        future_to_index = {
            executor.submit(worker, task, devices[idx % n_gpus]): idx
            for idx, task in enumerate(tasks)
        }

        with tqdm(total=len(tasks), desc="Benchmarking") as pbar:
            for future in as_completed(future_to_index):
                # Store by task index so the output order is deterministic.
                results[future_to_index[future]] = future.result()
                pbar.update(1)
    write_results(results, output_file)
    
def write_results(results, output_file) -> None:
    """Write ``results`` to ``output_file`` so the file appears all at once.

    ``results`` is passed to ``writelines``, so each item must be a string
    that carries its own line terminator if one is wanted. The data is first
    written to a sibling ``.tmp`` file, which then atomically replaces
    ``output_file``. This way an interrupted run never leaves a truncated
    file behind. That matters because an existing output file is treated as
    a finished computation and is skipped on the next run.

    Args:
        results: Iterable of strings to write.
        output_file: Destination ``pathlib.Path``.
    """
    logger.info(f"Writing results to {output_file}")
    tmp_file = output_file.with_name(output_file.name + ".tmp")
    try:
        with tmp_file.open("w", encoding="utf-8") as f:
            f.writelines(results)
        tmp_file.replace(output_file)
    except BaseException:
        # Don't leave a half-written temp file around; re-raise the original error.
        tmp_file.unlink(missing_ok=True)
        raise