import os
from multiprocessing import Pool
from typing import Any

import pandas as pd
from tqdm import tqdm

from .path_exists import path_exists


def read_signal_data(exp_data_path: str, exp_info: dict[str, any]) -> dict[str, any]:
    experiment_name = exp_info['experiment_name']
    tasks = exp_info['tasks']
    modal = exp_info['modal']

    tasks_data = []
    for task in tasks:
        task_name = task['task_name']
        match task_name:
            case "ping_pong_cooperative":
                csv_path = os.path.join(exp_data_path, "ping_pong_cooperative.csv")
                score_column_name = "team_score"
            case "saturn_a":
                csv_path = os.path.join(exp_data_path, "minecraft_saturn_a.csv")
                score_column_name = "points"
            case "saturn_b":
                csv_path = os.path.join(exp_data_path, "minecraft_saturn_b.csv")
                score_column_name = "points"
            case _:
                raise ValueError(f"Cannot process {task_name} for {experiment_name}")

        if not path_exists(csv_path):
            continue

        signal_task_df = pd.read_csv(csv_path)

        assert signal_task_df["timestamp_unix"].is_monotonic_increasing

        tasks_data.append({
            "task_name": task_name,
            "data": signal_task_df,
            "stations": task['stations'],
            "score_column_name": score_column_name
        })

    return {
        "experiment_name": experiment_name,
        "tasks": tasks_data,
        "modal": modal
    }


def _multiprocess_read_signal_data(process_args):
    """Adapter for ``Pool.imap``: spread one argument tuple into ``read_signal_data``.

    ``Pool.imap`` hands each task a single object, so the
    ``(exp_data_path, exp_info)`` pair is unpacked here before forwarding.
    """
    positional = tuple(process_args)
    return read_signal_data(*positional)


def read_signal_data_all(dir_path: str,
                         experiments_info: list[dict[str, Any]],
                         num_processes: int = 1) -> list[dict[str, Any]]:
    """Load the signal data of several experiments, optionally in parallel.

    Each experiment is read from ``os.path.join(dir_path, experiment_name)``
    via ``read_signal_data``; progress is reported with a tqdm bar.

    Args:
        dir_path: Root directory containing one sub-directory per experiment.
        experiments_info: Experiment descriptions, each with at least an
            ``experiment_name`` key (see ``read_signal_data`` for the rest).
        num_processes: Number of worker processes. With ``1`` the experiments
            are read in the current process, which avoids spawning a worker
            and pickling every loaded DataFrame back to the parent.

    Returns:
        The per-experiment results, sorted by ``experiment_name``.
    """
    process_args = [(os.path.join(dir_path, exp_info["experiment_name"]), exp_info)
                    for exp_info in experiments_info]

    if num_processes == 1:
        # A single worker gives no parallelism, so skip the pool overhead.
        results = [read_signal_data(exp_data_path, exp_info)
                   for exp_data_path, exp_info in tqdm(process_args)]
    else:
        with Pool(processes=num_processes) as pool:
            results = list(tqdm(pool.imap(_multiprocess_read_signal_data, process_args),
                                total=len(process_args)))

    return sorted(results, key=lambda result: result['experiment_name'])
