import numpy as np

from configurations.constants import Constants
from utils.aggregations import aggregate
from utils.helpers import compile_param_list, compute_weight_from_fault, \
    estimate_fault_from_proxy, compute_weighted_average_pairwise

"""
[Iterative] Proximity-based Truth Discovery (IPTD).

Variants:
* pairwise_proximity (PP): competence_j = avg(p(i,j)); weight_j = competence_j
* pairwise_distance (PD):  fault_j = avg(d(i,j)); weight_j = max(fault) - fault_j
* AWG      (continuous mode only)
* IER      (categorical mode only)
* IER_BIN  (binary/ranking mode only)

INPUT: any data that provides a pairwise distance matrix.
"""


def IPTD(data_params, params):
    """Run [Iterative] Proximity-based Truth Discovery and aggregate the answers.

    The first pass computes per-worker proxy scores with
    ``compute_weighted_average_pairwise(pairwise_matrix, None)``. Each later
    pass reuses the weights derived from the previous pass's proxy scores.
    Iteration stops after ``params['iterations']`` passes, or earlier once the
    L2 norm of the change in the weight vector drops below
    ``Constants.convergence_limit``. The proxy scores from the last pass are
    then turned into fault estimates and normalized weights, which are passed
    to ``aggregate``.

    Args:
        data_params: dict with at least ``'dist_matrix'`` (anything supporting
            ``.astype(float)``) and ``'mode'``. ``'possible_answers'`` is
            optional (default 2). The dict is also passed unchanged to
            ``aggregate``.
        params: algorithm options. Keys read here:
            ``'iterations'`` (default 1; values below 1 are treated as 1),
            ``'variant'`` (default "PD"; "PP" uses the proxy scores directly
            as weights),
            ``'iter_weight_method'`` ("max" by default, "1" or
            "from_estimate"; ignored for the "PP" variant),
            ``'name'`` and ``'voting_rule'`` (default: ``data_params['mode']``).

    Returns:
        dict with keys 'af_name', 'af_params', 'alpha', 'outcome',
        'estimated_fault', 'weights', 'proxy_score', 'iterations' (number of
        passes actually run) and 'vr_name'.

    Raises:
        ValueError: if ``iter_weight_method`` is not a supported value and the
            variant is not "PP".
    """
    max_iterations = params.get('iterations', 1)
    possible_answers = data_params.get('possible_answers', 2)
    iter_weight_method = params.get('iter_weight_method', "max")
    mode = data_params['mode']
    PTD_variant = params.get('variant', "PD")
    name = params.get('name', "IPTD_" + str(PTD_variant) + "_" + str(max_iterations))
    voting_rule = params.get("voting_rule", mode)

    # Fail fast on a misconfigured method. Previously an unknown value was
    # silently ignored and every pass fell back to unweighted averages.
    supported_methods = ("max", "1", "from_estimate")
    if PTD_variant != "PP" and iter_weight_method not in supported_methods:
        raise ValueError("Unknown iter_weight_method %r; expected one of %s"
                         % (iter_weight_method, ", ".join(supported_methods)))

    pairwise_matrix = data_params['dist_matrix'].astype(float)

    # Private copy so the caller's params dict is never mutated.
    est_params = dict(params)
    est_params["possible_answers"] = possible_answers

    # At least one pass is required: proxy_scores and the pass count are used
    # after the loop (a value of 0 previously raised NameError).
    n_passes = max(1, max_iterations)
    iter_weights = None  # None on the first pass -> unweighted average
    step = 0
    for step in range(n_passes):
        old_weights = iter_weights
        proxy_scores = compute_weighted_average_pairwise(pairwise_matrix, iter_weights)
        if PTD_variant == "PP":
            iter_weights = proxy_scores
        elif iter_weight_method == "max":
            iter_weights = np.max(proxy_scores) - proxy_scores
        elif iter_weight_method == "1":
            iter_weights = 1 - proxy_scores
        else:  # "from_estimate" (validated above)
            iter_estimated_fault = estimate_fault_from_proxy(proxy_scores, PTD_variant, est_params)
            # NOTE(review): passes `params` (not `est_params`) here, unlike the
            # final weighting below. Confirm this is intended.
            iter_weights = compute_weight_from_fault(iter_estimated_fault, mode, params,
                                                     possible_answers=possible_answers)

        if old_weights is not None and np.linalg.norm(iter_weights - old_weights) < Constants.convergence_limit:
            break

    estimated_fault = estimate_fault_from_proxy(proxy_scores, PTD_variant, est_params)
    est_params["normalize"] = True
    weights = compute_weight_from_fault(estimated_fault, mode, est_params, possible_answers=possible_answers)
    answers = aggregate(weights, data_params, params)

    return {'af_name': name, 'af_params': compile_param_list(params), 'alpha': np.mean(proxy_scores) / 2,
            'outcome': answers,
            'estimated_fault': estimated_fault, 'weights': weights, 'proxy_score': proxy_scores,
            'iterations': step + 1, "vr_name": voting_rule}
