# code adapted from github package: https://github.com/cytomining/copairs
import os.path
import itertools
import pandas as pd
import numpy as np
import logging
import re

from typing import Union, List, Callable, Tuple
from multiprocessing.pool import ThreadPool
from tqdm.autonotebook import tqdm
from spc.map.matching import Matcher, UnpairedException


# utils
def evaluate_and_filter(df, columns) -> Tuple[pd.DataFrame, List[str]]:
    """Apply the query expressions found in ``columns`` and resolve column names.

    Every entry of ``columns`` is handled in order:

    * an entry that is the name of a column of ``df`` is kept as is;
    * any other entry is treated as a query expression (e.g. ``"dose > 1"``):
      it is applied with ``DataFrame.query`` and the column names found on the
      left of its comparison operators are collected instead.

    Args:
        df: dataframe to filter.
        columns: iterable of column names and/or query expressions.

    Returns:
        A tuple ``(filtered_df, parsed_cols)`` where ``filtered_df`` is ``df``
        after all query expressions were applied and ``parsed_cols`` lists the
        column names gathered from the entries (duplicates are not removed).

    Raises:
        ValueError: if an entry is neither a column name nor an expression
            referencing at least one existing column, or if evaluating the
            expression fails (the original error is attached as the cause).
    """
    parsed_cols = []
    for entry in columns:
        if entry in df.columns:
            parsed_cols.append(entry)
            continue

        # Identifiers sitting directly to the left of a comparison operator.
        referenced = re.findall(r"(\w+)\s*[=<>!]+", entry)
        known = [name for name in referenced if name in df.columns]
        if not known:
            raise ValueError(f"Invalid query or column name: {entry}")

        try:
            df = df.query(entry)
        except Exception as err:
            # Keep the pandas error as the cause so the real problem
            # (syntax error, undefined name, ...) stays visible.
            raise ValueError(f"Invalid query expression: {entry}") from err
        parsed_cols.extend(known)

    return df, parsed_cols


def flatten_str_list(*args):
    """Merge all arguments into one flat list of unique strings.

    Each argument may be a single string, a dict whose values are iterables
    of strings (all values are flattened), or any other iterable of strings.

    Returns:
        A list without duplicates, ordered by first appearance so that the
        result is the same on every run.
    """
    # A dict is used as an insertion-ordered set; a plain ``set`` would make
    # the output order depend on string hashing.
    seen = {}
    for arg in args:
        if isinstance(arg, str):
            seen[arg] = None
        elif isinstance(arg, dict):
            flattened = itertools.chain.from_iterable(arg.values())
            seen.update(dict.fromkeys(flattened))
        else:
            seen.update(dict.fromkeys(arg))
    return list(seen)


def validate_pipeline_input(meta, feats, columns):
    """Check that metadata and features can be fed to the pipeline.

    Args:
        meta: metadata dataframe.
        feats: feature array, one row per row of ``meta``.
        columns: metadata columns that must be fully populated.

    Raises:
        ValueError: if any selected metadata column contains a null, if
            ``meta`` and ``feats`` differ in number of rows, or if ``feats``
            contains NaN.
    """
    selected = meta[columns]
    if not selected.notna().all(axis=None):
        raise ValueError("metadata columns should not have null values.")

    n_meta, n_feats = len(meta), len(feats)
    if n_meta != n_feats:
        raise ValueError("meta and feats have different number of rows")

    if np.any(np.isnan(feats)):
        raise ValueError("features should not have null values.")


def to_cutoffs(counts: np.ndarray):
    """Convert group sizes into the start index of each group.

    The result is the exclusive running sum of ``counts``:
    ``[3, 2, 4] -> [0, 3, 5]``.

    Args:
        counts: 1-D array (or sequence) of group sizes.

    Returns:
        Array with the same length and dtype as ``counts``; empty when
        ``counts`` is empty.
    """
    counts = np.asarray(counts)
    # Starting from zeros makes the first cutoff 0 and lets the slice
    # assignment below degrade gracefully for empty and single-item input.
    cutoffs = np.zeros_like(counts)
    cutoffs[1:] = counts.cumsum()[:-1]
    return cutoffs


# Functions to calculate similarity between pairs of replicate embeddings
def parallel_map(par_func, items):
    """Execute par_func(i) for every i in items using ThreadPool and tqdm.

    The return values of ``par_func`` are discarded, so it is expected to act
    through side effects. Nothing is done when ``items`` is empty.

    Args:
        par_func: callable taking one element of ``items``.
        items: sized iterable of work items.
    """
    num_items = len(items)
    if num_items == 0:
        # A pool of size 0 is invalid and would also divide by zero below.
        return

    # os.cpu_count() returns None when the number of CPUs is undetermined.
    pool_size = min(num_items, os.cpu_count() or 1)
    chunksize = num_items // pool_size
    with ThreadPool(pool_size) as pool:
        tasks = pool.imap_unordered(par_func, items, chunksize=chunksize)
        # Draining the iterator is what drives the progress bar.
        for _ in tqdm(tasks, total=num_items, leave=False):
            pass


def batch_processing(
    pairwise_op: Callable[[np.ndarray, np.ndarray], np.ndarray],
):
    """Turn a row-wise pairwise operation into a batched, threaded one.

    The returned function takes ``(feats, pair_ix, batch_size)``: it slices
    ``pair_ix`` into consecutive batches of ``batch_size`` rows, gathers the
    two feature rows of every pair, applies ``pairwise_op`` to each batch via
    ``parallel_map`` and returns one float32 value per pair.
    """

    def batched_fn(feats: np.ndarray, pair_ix: np.ndarray, batch_size: int):
        total = len(pair_ix)
        out = np.empty(total, dtype=np.float32)
        batch_starts = np.arange(0, total, batch_size)

        def par_func(start):
            batch = pair_ix[start : start + batch_size]
            left = feats[batch[:, 0]]
            right = feats[batch[:, 1]]
            # The last batch may be shorter than batch_size.
            out[start : start + len(left)] = pairwise_op(left, right)

        parallel_map(par_func, batch_starts)
        return out

    return batched_fn


@batch_processing
def pairwise_cosine(x_sample: np.ndarray, y_sample: np.ndarray) -> np.ndarray:
    """Cosine similarity between matching rows of x_sample and y_sample.

    Because of the decorator, the module-level name is called as
    ``pairwise_cosine(feats, pair_ix, batch_size)``; this body only sees one
    batch of already gathered rows.
    """
    # Scale every row to unit length, then take the row-wise dot product.
    x_unit = x_sample / np.linalg.norm(x_sample, axis=1, keepdims=True)
    y_unit = y_sample / np.linalg.norm(y_sample, axis=1, keepdims=True)
    return (x_unit * y_unit).sum(axis=1)


# Main functions
def build_rank_lists(pos_pairs, neg_pairs, pos_sims, neg_sims):
    """Build, for every paired index, its relevance list ranked by similarity.

    Every pair contributes one entry to each of its two members. Entries are
    grouped by member index and, inside a group, ordered by decreasing
    similarity; positive pairs are labelled 1 and negative pairs 0.

    Returns:
        paired_ix: sorted unique indices that take part in at least one pair.
        rel_k_list: labels of all entries, groups laid out contiguously in
            the order of ``paired_ix``.
        counts: number of entries in each group.
    """
    n_pos, n_neg = pos_pairs.size, neg_pairs.size
    relevance = np.zeros(n_pos + n_neg, dtype=np.int32)
    relevance[:n_pos] = 1

    members = np.concatenate([pos_pairs.ravel(), neg_pairs.ravel()])
    # Each similarity is duplicated so it lines up with both pair members.
    sims = np.concatenate([np.repeat(s, 2) for s in (pos_sims, neg_sims)])

    # lexsort uses the last key as primary: member index first, then
    # ascending (1 - similarity), i.e. descending similarity.
    order = np.lexsort([1 - sims, members])
    paired_ix, counts = np.unique(members, return_counts=True)
    return paired_ix, relevance[order], counts


def ap_contiguous(rel_k_list, counts):
    """Compute average precision for groups stored back to back.

    Args:
        rel_k_list: 0/1 relevance labels of all groups concatenated, each
            group already in ranked order.
        counts: number of entries in each group.

    Returns:
        ap_scores: one average precision value per group (the division by the
            group's number of positives is not guarded against zero).
        null_confs: array with one row per group holding
            ``[number of positives, group size]``.
    """
    starts = to_cutoffs(counts)
    num_pos = np.add.reduceat(rel_k_list, starts)

    # Positives contributed by the preceding group (0 for the first one);
    # its running sum is what must be removed from the global cumsum.
    prev_pos = np.empty_like(num_pos)
    prev_pos[0] = 0
    prev_pos[1:] = num_pos[:-1]

    # Hits and rank restarted at every group boundary.
    hits = rel_k_list.cumsum() - np.repeat(prev_pos.cumsum(), counts)
    rank = np.arange(1, len(rel_k_list) + 1) - np.repeat(starts, counts)
    precision_at_k = hits / rank

    # Only ranks holding a relevant entry contribute to the sum.
    ap_scores = np.add.reduceat(precision_at_k * rel_k_list, starts) / num_pos
    null_confs = np.stack([num_pos, counts], axis=1)
    return ap_scores, null_confs


def _gather_pairs(matcher, sameby: List, diffby: List, kind: str) -> np.ndarray:
    """Collect all pairs returned by the matcher into one int32 array of pairs."""
    grouped = matcher.get_all_pairs(sameby=sameby, diffby=diffby)
    total = sum(len(pairs) for pairs in grouped.values())
    if total == 0:
        raise UnpairedException(f"Unable to find {kind} pairs.")
    return np.fromiter(
        itertools.chain.from_iterable(grouped.values()),
        dtype=np.dtype((np.int32, 2)),
        count=total,
    )


def average_precision(meta: pd.DataFrame, feats: np.ndarray,
                      pos_sameby: List, pos_diffby: List, neg_sameby: List, neg_diffby: List,
                      batch_size: int = 20000, verbose: bool = False
                      ) -> pd.DataFrame:
    """Compute the average precision of every row that takes part in a pair.

    Args:
        meta: metadata, one row per row of ``feats``.
        feats: feature matrix.
        pos_sameby: columns that must match for a positive pair.
        pos_diffby: columns that must differ for a positive pair.
        neg_sameby: columns that must match for a negative pair.
        neg_diffby: columns that must differ for a negative pair.
        batch_size: number of pairs whose similarity is computed per batch.
        verbose: when true, progress messages are sent to the ``"map score"``
            logger at INFO level; otherwise this function logs nothing.

    Returns:
        A copy of ``meta`` with a fresh 0..n-1 index and the columns
        ``average_precision``, ``n_pos_pairs`` and ``n_total_pairs``. Rows
        that appear in no pair keep a count of 0 and receive no
        ``average_precision`` value.

    Raises:
        ValueError: if the inputs fail ``validate_pipeline_input``.
        UnpairedException: if no positive or no negative pair is found.
    """
    columns = flatten_str_list(pos_sameby, pos_diffby, neg_sameby, neg_diffby)
    # NOTE(review): validation indexes ``meta`` with the raw entries, so an
    # entry that is a query expression rather than a column name presumably
    # fails here before evaluate_and_filter gets to parse it - confirm.
    validate_pipeline_input(meta, feats, columns)

    # Critical: pair indices are positional, so the index must be 0..n-1.
    meta = meta.reset_index(drop=True).copy()

    logger = logging.getLogger("map score")

    def log(message: str) -> None:
        # Gate only our own messages; calling logging.disable() here would
        # silence logging for the whole process and never turn it back on.
        if verbose:
            logger.info(message)

    log("Indexing metadata...")
    matcher = Matcher(*evaluate_and_filter(meta, columns), seed=0)

    log("Finding positive pairs...")
    pos_pairs = _gather_pairs(matcher, pos_sameby, pos_diffby, "positive")

    log("Finding negative pairs...")
    neg_pairs = _gather_pairs(matcher, neg_sameby, neg_diffby, "negative")

    log("Computing positive similarities...")
    pos_sims = pairwise_cosine(feats, pos_pairs, batch_size)

    log("Computing negative similarities...")
    neg_sims = pairwise_cosine(feats, neg_pairs, batch_size)

    log("Building rank lists...")
    paired_ix, rel_k_list, counts = build_rank_lists(
        pos_pairs, neg_pairs, pos_sims, neg_sims
    )

    log("Computing average precision...")
    ap_scores, null_confs = ap_contiguous(rel_k_list, counts)

    log("Creating result DataFrame...")
    meta["n_pos_pairs"] = 0
    meta["n_total_pairs"] = 0
    meta.loc[paired_ix, "average_precision"] = ap_scores
    meta.loc[paired_ix, "n_pos_pairs"] = null_confs[:, 0]
    meta.loc[paired_ix, "n_total_pairs"] = null_confs[:, 1]
    log("Finished.")
    return meta


def mean_average_precision_score(metadata: pd.DataFrame,
                                 embeddings: np.ndarray,
                                 pos_sameby: Union[List, str],
                                 pos_diffby: Union[List, str],
                                 neg_sameby: Union[List, str],
                                 neg_diffby: Union[List, str],
                                 batch_size: int = 20000,
                                 verbose: bool = False,
                                 ) -> pd.DataFrame:
    """Convenience wrapper around ``average_precision``.

    Each of the four pair constraints may be given as a single column name
    instead of a list; a bare string is wrapped into a one-element list.

    Args:
        metadata: metadata, one row per row of ``embeddings``.
        embeddings: feature matrix.
        pos_sameby: column(s) that must match for a positive pair.
        pos_diffby: column(s) that must differ for a positive pair.
        neg_sameby: column(s) that must match for a negative pair.
        neg_diffby: column(s) that must differ for a negative pair.
        batch_size: forwarded to ``average_precision``.
        verbose: forwarded to ``average_precision``.

    Returns:
        The table returned by ``average_precision``, unchanged (no
        aggregation is applied here).
    """
    pos_sameby, pos_diffby, neg_sameby, neg_diffby = (
        [spec] if isinstance(spec, str) else spec
        for spec in (pos_sameby, pos_diffby, neg_sameby, neg_diffby)
    )
    return average_precision(
        metadata,
        embeddings,
        pos_sameby,
        pos_diffby,
        neg_sameby,
        neg_diffby,
        batch_size=batch_size,
        verbose=verbose,
    )

