import itertools
from multiprocessing import Pool
from operator import itemgetter
from typing import Any, Optional

from scipy.stats import pearsonr
from tqdm import tqdm


def _compute_correlation(x_values, y_values) -> float:
    """
    Compute the correlation between two lists of values.
    @param x_values: x value array
    @param y_values: y value array
    @return: correlation between x and y values
    """
    return pearsonr(x_values, y_values).correlation


def _mean_pairwise_correlation(task_data_df, station_pairs, modal: str, channel: str) -> Optional[float]:
    """
    Average one channel's correlation over every station pair that has data for that channel.
    @param task_data_df: task DataFrame; looked up via columns named "<station>_<modal>_<channel>"
    @param station_pairs: (station1, station2) pairs to correlate
    @param modal: modality segment of the column name
    @param channel: channel segment of the column name
    @return: mean correlation over the usable pairs, or None when no pair could be correlated
    """
    # All columns belong to the same DataFrame, so one row-count check covers every pair;
    # a correlation needs at least two samples.
    if len(task_data_df) < 2:
        return None

    columns = task_data_df.columns
    correlations = []
    for station1, station2 in station_pairs:
        column1 = f"{station1}_{modal}_{channel}"
        column2 = f"{station2}_{modal}_{channel}"
        # Pairs where either station lacks this channel are ignored rather than counted.
        if column1 in columns and column2 in columns:
            correlations.append(_compute_correlation(task_data_df[column1], task_data_df[column2]))

    # NOTE(review): a nan correlation from any single pair propagates into the mean — confirm intended.
    return sum(correlations) / len(correlations) if correlations else None


def compute_channel_correlation_score(exp_data: dict[str, Any], channels: list[str]) -> dict[str, Any]:
    """
    For every task of one experiment, compute the per-channel correlation between stations
    and pick up the task's final score.
    @param exp_data: experiment dict with keys 'experiment_name', 'modal' and 'tasks'; each task is a
        dict with 'task_name', 'stations', 'data' (a DataFrame) and 'score_column_name'
    @param channels: channel names to evaluate
    @return: dict with 'experiment_name', 'modal' and 'task_correlation_results'. The last is one entry
        per task that has at least one column containing a channel name. Each entry holds the mean
        correlation per channel (None when no station pair was usable) and the last value of the
        task's score column (None when the task has no rows).
    """
    tasks = exp_data['tasks']
    modal = exp_data['modal']

    tasks_results = []
    for task in tasks:
        task_name = task['task_name']
        stations = task['stations']
        task_data_df = task['data']
        score_column_name = task['score_column_name']

        # Skip tasks that have no column mentioning any requested channel (substring match).
        if not any(channel in column for column in task_data_df.columns for channel in channels):
            continue

        # Unordered station pairs, e.g. lion-tiger, lion-leopard, tiger-leopard.
        # Materialized once so the same pairs can be reused for every channel.
        station_pairs = list(itertools.combinations(stations, 2))

        correlation_per_channel = {
            channel: _mean_pairwise_correlation(task_data_df, station_pairs, modal, channel)
            for channel in channels
        }

        # The task's score is the last value of its score column; an empty table has no last
        # value, so report None instead of raising IndexError and aborting the whole run.
        score_values = task_data_df[score_column_name].values
        final_score = score_values[-1] if len(score_values) > 0 else None

        tasks_results.append({
            "task_name": task_name,
            "correlation_per_channel": correlation_per_channel,
            "score": final_score
        })

    return {
        "experiment_name": exp_data['experiment_name'],
        "modal": modal,
        "task_correlation_results": tasks_results
    }


def _multiprocess_compute_channel_correlation_score(process_arg):
    """
    Pool worker entry point: expand one argument tuple and score that experiment.
    Must stay a module-level function so multiprocessing can pickle it for the workers.
    @param process_arg: (experiment_data, channels) tuple
    @return: the result of compute_channel_correlation_score for that experiment
    """
    call_args = tuple(process_arg)
    return compute_channel_correlation_score(*call_args)


def compute_channel_correlation_score_all(experiments_data: list[dict[str, Any]],
                                          channels: list[str],
                                          num_processes: int = 1) -> list[dict[str, Any]]:
    """
    Run compute_channel_correlation_score over every experiment, optionally in parallel.
    @param experiments_data: experiment dicts, as accepted by compute_channel_correlation_score
    @param channels: channel names to evaluate for every experiment
    @param num_processes: number of worker processes passed to multiprocessing.Pool; with exactly 1
        the work runs in the current process without starting a Pool
    @return: one result dict per experiment, sorted by 'experiment_name'
    """
    process_args = [(experiment_data, channels) for experiment_data in experiments_data]

    if num_processes == 1:
        # A single-worker pool only adds process start-up plus pickling of every task DataFrame,
        # so do the work in-process.
        results = [_multiprocess_compute_channel_correlation_score(process_arg)
                   for process_arg in tqdm(process_args)]
    else:
        with Pool(processes=num_processes) as pool:
            results = list(tqdm(
                pool.imap(_multiprocess_compute_channel_correlation_score, process_args),
                total=len(process_args)
            ))

    results.sort(key=itemgetter('experiment_name'))

    return results
