import numpy as np
from scipy.spatial import distance

from configurations.constants import Constants
from utils.aggregations import aggregate
from utils.helpers import bound, compile_param_list, ranking_to_pairwise_q, compute_weight_from_fault
from utils.voting_rules import voting_rules_dict, run_vr
from utils.wquantile import wmedian
"""
IDTD = Iterative Distance-from-answer Truth Discovery. Iteratively computes the aggregated answer and re-estimates 
fault based on distance from these answers.
This is a "Folk algorithm" based on ideas occurring in multiple papers.

INPUT: Any distance-based data.
"""


def IDTD_mean(data_params, params):
    """Iteratively re-weighted mean aggregation (continuous mode).

    Starting from uniform row weights, each round computes the weighted
    column mean of the data, scores every row by its mean squared deviation
    from that mean, clamps the score with ``bound`` to
    ``[Constants.fault_min, Constants.fault_max]`` and converts it into new
    row weights with ``compute_weight_from_fault``.

    Args:
        data_params: dict; ``data_params['df']`` is read as a 2-D table
            (``.shape`` and ``.values`` are used) whose rows are the
            entities being weighted.
        params: dict of options. ``"name"`` (default ``'IDTD'``) labels the
            result and ``"iterations"`` (default ``1``) is the number of
            re-weighting rounds. The dict is also forwarded to
            ``compute_weight_from_fault`` and ``compile_param_list``.

    Returns:
        dict with keys ``'af_name'``, ``'af_params'``, ``'outcome'`` (the
        weighted column mean under the final weights),
        ``'estimated_fault'`` and ``'weights'``. When no round is executed
        the weights are uniform and ``'estimated_fault'`` is the all-ones
        placeholder.
    """
    df = data_params['df']
    name = params.get("name", 'IDTD')
    mode = "continuous"
    n = df.shape[0]
    A = df.values
    c_lb = Constants.fault_min
    c_ub = Constants.fault_max

    weights = np.ones(n)
    # Placeholder so a run with zero iterations returns a value instead of
    # failing on an unbound local; IDTD_ranking starts from the same vector.
    estimated_fault = np.ones(n)

    iterations = params.get("iterations", 1)
    for _ in range(iterations):
        w_mean = np.average(A, axis=0, weights=weights)
        # Per-row mean squared deviation from the current weighted mean.
        estimated_fault = np.mean((A - w_mean) ** 2, axis=1)
        estimated_fault = bound(c_lb, c_ub, estimated_fault)
        weights = compute_weight_from_fault(estimated_fault, mode, params)

    answers = np.average(df, axis=0, weights=weights, returned=False)

    return {'af_name': name, 'af_params': compile_param_list(params), 'outcome': answers,
            'estimated_fault': estimated_fault, 'weights': weights}


def DTD_median(data_params, params):
    """Single-pass distance-from-median aggregation.

    Every row of ``data_params['df']`` is scored by its mean squared
    deviation from the unweighted column medians. That score, clamped by
    ``bound`` to ``[Constants.fault_min, Constants.fault_max]``, is the
    estimated fault and its reciprocal is the row weight. The outcome holds
    the weighted median (``wmedian``) of each column.

    Args:
        data_params: dict holding the answers table under ``'df'``. Columns
            are looked up by the labels ``0 .. k-1`` where ``k`` is the
            number of columns.
        params: dict of options; ``"name"`` (default ``'DTD'``) labels the
            result and the whole dict is passed to ``compile_param_list``.

    Returns:
        dict with keys ``'af_name'``, ``'af_params'``, ``'outcome'``,
        ``'vr_name'`` (always ``"median"``), ``'estimated_fault'`` and
        ``'weights'``.
    """
    df = data_params['df']
    name = params.get("name", 'DTD')
    voting_rule = "median"

    # Squared deviation of every entry from its column's plain median.
    deviations = df - df.median(axis=0)
    distance_from_median = np.mean(deviations ** 2, axis=1)

    estimated_fault = bound(Constants.fault_min, Constants.fault_max, distance_from_median)
    weights = 1 / estimated_fault

    num_columns = df.shape[1]
    answers = np.zeros(num_columns)
    for col in range(num_columns):
        answers[col] = wmedian(df[col], weights)

    return {
        'af_name': name,
        'af_params': compile_param_list(params),
        'outcome': answers,
        "vr_name": voting_rule,
        'estimated_fault': estimated_fault,
        'weights': weights,
    }


def IDTD_ranking(data_params, params):
    """Iterative distance-based truth discovery for ranking data.

    Each round runs the configured voting rule with the current weights,
    converts the resulting ranking to pairwise form with
    ``ranking_to_pairwise_q`` and sets every row's fault to its Hamming
    distance from that pairwise outcome. Faults are clamped by ``bound`` to
    ``[Constants.fault_min_cat, Constants.fault_max_cat]`` and turned into
    weights by ``compute_weight_from_fault``. The loop stops early once the
    weights equal those of the previous round.

    Args:
        data_params: dict with ``'df'`` (one row per weighted entity) and
            ``'ranking_gt'``, which is used only when converting the
            estimated ranking to pairwise form. The dict is also handed to
            ``run_vr``.
        params: dict with the required keys ``'iterations'`` and
            ``'voting_rule'`` (a key of ``voting_rules_dict``). It is also
            expanded as keyword arguments to ``run_vr``.

    Returns:
        dict with keys ``'af_name'`` (``'IDTD_<iterations>'``),
        ``'af_params'``, ``'outcome'``, ``'estimated_fault'``,
        ``'weights'``, ``'alpha'`` (zero-based index of the last round
        executed), ``'iterations'`` (the requested maximum) and
        ``'vr_name'``.
    """
    df = data_params['df']
    num_rows = df.shape[0]
    # Used only for transforming answers to binary (pairwise) form.
    reference_ranking = data_params['ranking_gt']
    iterations = params['iterations']
    name = 'IDTD_' + str(iterations)
    VR_function = voting_rules_dict[params['voting_rule']]

    estimated_fault = np.ones(num_rows)
    weights = None
    last_iter = 0
    for step in range(iterations):
        previous_weights, last_iter = weights, step

        outcome = run_vr(VR_function, data_params, weights, **params)
        outcome_pairwise = ranking_to_pairwise_q([outcome], reference_ranking)

        for row in range(num_rows):
            estimated_fault[row] = distance.hamming(df.iloc[row][:], outcome_pairwise, w=None)
        estimated_fault = bound(Constants.fault_min_cat, Constants.fault_max_cat, estimated_fault)

        weights = compute_weight_from_fault(estimated_fault, "rankings", params)
        if np.array_equal(weights, previous_weights):
            # Weights reached a fixed point; further rounds change nothing.
            break

    answers = run_vr(VR_function, data_params, weights, **params)
    vr_name = VR_function.__name__

    return {
        'af_name': name,
        'af_params': compile_param_list(params),
        'outcome': answers,
        'estimated_fault': estimated_fault,
        'weights': weights,
        'alpha': last_iter,
        'iterations': iterations,
        'vr_name': vr_name,
    }


def IDTD(data_params, params):
    """Iterative distance-from-answer truth discovery.

    Starting from uniform row weights, each round aggregates the answers
    with ``aggregate``, sets every row's fault to its Hamming distance from
    the aggregated answers (clamped by ``bound`` to
    ``[Constants.fault_min_cat, Constants.fault_max_cat]``) and converts the
    faults to weights with ``compute_weight_from_fault``. The loop stops
    early once ``np.linalg.norm`` of the weight change falls below
    ``Constants.convergence_limit``.

    Args:
        data_params: dict with ``'df'`` (one row per weighted entity) and
            ``'possible_answers'``; the optional ``'mode'`` entry is passed
            on to ``compute_weight_from_fault``. The dict is also handed to
            ``aggregate``.
        params: dict of options. ``'iterations'`` (default ``1``) caps the
            number of rounds and ``'name'`` (default ``'IDTD'``) labels the
            result. The dict is also forwarded to ``aggregate``,
            ``compute_weight_from_fault`` and ``compile_param_list``.

    Returns:
        dict with keys ``'af_name'``, ``'af_params'``, ``'outcome'`` (the
        aggregate under the final weights), ``'estimated_fault'``,
        ``'weights'``, ``'alpha'`` (zero-based index of the last round
        executed, ``0`` if none ran) and ``'iterations'`` (number of rounds
        actually executed).
    """
    iterations = params.get('iterations', 1)
    name = params.get('name', 'IDTD')
    mode = data_params.get("mode", None)

    df = data_params['df']
    possible_answers = data_params['possible_answers']
    n = df.shape[0]
    weights = np.ones(n)
    # Placeholder so a run with zero iterations returns a value instead of
    # failing on an unbound local; IDTD_ranking starts from the same vector.
    fault = np.ones(n)
    last_iter = 0
    rounds_done = 0
    for step in range(iterations):
        old_weights = weights
        last_iter = step
        rounds_done = step + 1
        answers = aggregate(weights, data_params, params)
        distances = np.zeros(n)
        for i in range(n):
            # TODO: this part is the only one still assuming a specific domain
            distances[i] = distance.hamming(df.iloc[i][:], answers, w=None)
        fault = bound(Constants.fault_min_cat, Constants.fault_max_cat, distances)
        weights = compute_weight_from_fault(fault, mode, params, possible_answers=possible_answers)
        # old_weights is always an array here (weights starts as np.ones),
        # so no None check is needed before measuring the change.
        if np.linalg.norm(weights - old_weights) < Constants.convergence_limit:
            break
    answers = aggregate(weights, data_params, params)

    return {'af_name': name, 'af_params': compile_param_list(params), 'outcome': answers,
            'estimated_fault': fault, 'weights': weights, 'alpha': last_iter,
            'iterations': rounds_done}
