import numpy as np
import pandas as pd
from sklearn.metrics import ndcg_score, dcg_score


def calculate_ndcg(X_test, y_test, y_pred, task, y_rel=None, ks=[1, 10, 100]):
    """Compute NDCG at every cutoff in ``ks`` for the rows of one task.

    Args:
        X_test: DataFrame with at least the columns ``'tid'`` (task id) and
            ``'max_rel'``. It is copied, never modified.
        y_test: True target values, assigned as column ``'y_true'``.
        y_pred: Predicted target values, assigned as column ``'y_pred'``.
        task: Label looked up in the ``'tid'`` index to select the rows.
        y_rel: Optional true relevance values. When ``None`` the true
            relevance is computed as ``max_rel - y_true``.
        ks: Iterable of cutoffs forwarded to ``ndcg_score`` as ``k``.

    Returns:
        A list with one NDCG value per entry of ``ks``, in the same order.

    Raises:
        ValueError: If the predicted relevance contains infinite values.
    """
    results = X_test.copy()

    results['y_pred'] = y_pred
    results['y_true'] = y_test

    # The predicted relevance is derived the same way whether or not the
    # caller supplies the true relevance, so it is assigned only once.
    results['y_pred_rel'] = results['max_rel'] - results['y_pred']
    if y_rel is None:
        results['y_true_rel'] = results['max_rel'] - results['y_true']
    else:
        results['y_true_rel'] = y_rel

    results = results.set_index('tid')
    group = results.loc[task]

    # Find the maximum length of the lists
    max_length = max(len(group['y_true_rel']), len(group['y_pred_rel']))

    # None of the inputs depend on k, so the padded arrays and the NaN/inf
    # handling are prepared once instead of once per cutoff.
    # ndcg_score expects 2-D input: a single row holding this task's scores.
    y_trues_np = np.array(
        [extend_list(group['y_true_rel'].values.tolist(), max_length)])
    y_preds_np = np.array(
        [extend_list(group['y_pred_rel'].values.tolist(), max_length)])

    # Missing predictions are replaced by the lowest predicted relevance.
    nan_mask = np.isnan(y_preds_np)
    if nan_mask.any():
        y_preds_np[nan_mask] = np.nanmin(y_preds_np)
    if np.isinf(y_preds_np).any():
        raise ValueError("Input contains infinite values.")

    return [ndcg_score(y_trues_np, y_preds_np, k=k) for k in ks]

def extend_list(lst, length, fill_value=0.0):
    """Return a new list: ``lst`` right-padded with ``fill_value`` to ``length``.

    When ``lst`` already has ``length`` or more items, an unpadded copy is
    returned. ``lst`` itself is never modified.
    """
    missing = length - len(lst)
    # A non-positive repeat count yields an empty padding list.
    padding = [fill_value] * missing
    return lst + padding

def calculate_mrr(X_test, y_test, y_pred, task, y_rel=None, ks=[1, 10, 100]):
    """Compute the reciprocal rank of the first best item within the top k.

    The best items are the rows of ``task`` whose ``y_true`` equals the
    minimum ``y_true`` of that task. Rows are ordered by ascending ``y_pred``
    and, for every cutoff, the 1-based position of the first best item among
    the first ``k`` rows is inverted.

    Args:
        X_test: DataFrame with at least the columns ``'tid'`` and
            ``'model_id'``. It is copied, never modified.
        y_test: True target values, assigned as column ``'y_true'``.
        y_pred: Predicted target values, assigned as column ``'y_pred'``.
        task: Label looked up in the ``'tid'`` index to select the rows.
        y_rel: Accepted for signature compatibility; not used.
        ks: Iterable of integer cutoffs.

    Returns:
        A list with one score per entry of ``ks``: ``1 / position`` of the
        first best item inside the top ``k``, or ``0.0`` when none is there.
    """
    results = X_test.copy()
    results['y_pred'] = y_pred
    results['y_true'] = y_test

    results = results.set_index('tid')

    group = results.loc[task]
    if isinstance(group, pd.Series):
        # A task with a single row comes back as a row Series; re-select with
        # a list label so the code below always works on a DataFrame.
        group = results.loc[[task]]

    # Nothing below depends on k except the final slice, so the best items
    # and the predicted ordering are computed once.
    best_rank = group['y_true'].min()
    best_items = group[group['y_true'] == best_rank]['model_id'].values
    ranked_items = group.sort_values(
        ['y_pred'], ascending=True)['model_id'].values

    mrr_scores_for_ks = []
    for k in ks:
        # 1-based position of the first best item in the top k; 0 if absent.
        rank = next(
            (position
             for position, item in enumerate(ranked_items[:k], start=1)
             if item in best_items),
            0,
        )
        mrr_scores_for_ks.append(1.0 / rank if rank > 0 else 0.0)

    return mrr_scores_for_ks

def calculate_score(X_test, y_test, y_pred, y_rel, task, ks=[1, 10, 100]):
    """Return the best true relevance found among the top-k predicted rows.

    Rows of ``task`` are ordered by ascending ``y_pred``; for every cutoff
    the maximum ``y_rel`` value among the first ``k`` rows is reported.

    Args:
        X_test: DataFrame with at least the column ``'tid'``. It is copied,
            never modified.
        y_test: True target values, assigned as column ``'y_true'``.
        y_pred: Predicted target values, assigned as column ``'y_pred'``.
        y_rel: True relevance values, assigned as column ``'y_true_rel'``.
        task: Label looked up in the ``'tid'`` index to select the rows.
        ks: Iterable of integer cutoffs.

    Returns:
        A list with one value per entry of ``ks``, in the same order.

    Raises:
        AssertionError: If the rows sharing the minimum ``y_true`` do not
            all carry the same ``y_rel`` value.
        ValueError: If a cutoff selects no rows (``max`` of an empty slice).
    """
    results = X_test.copy()
    results['y_pred'] = y_pred
    results['y_true'] = y_test

    results['y_true_rel'] = y_rel

    results = results.set_index('tid')

    group = results.loc[task]
    if isinstance(group, pd.Series):
        # A task with a single row comes back as a row Series; re-select with
        # a list label so the code below always works on a DataFrame.
        group = results.loc[[task]]

    scores_for_ks = []
    ranked_rel = None
    for k in ks:
        if ranked_rel is None:
            # Prepared on the first cutoff only: none of it depends on k.
            best_rank = group['y_true'].min()
            best_score = group[group['y_true'] == best_rank]['y_true_rel'].values

            # Sanity check: tied best rows must agree on their relevance.
            assert all(x == best_score[0] for x in best_score)

            ranked_rel = group.sort_values(
                ['y_pred'], ascending=True)['y_true_rel'].values

        # Considering only the first k elements
        scores_for_ks.append(max(ranked_rel[:k]))

    return scores_for_ks

def calculate_ttb(X_test, y_test, y_pred, y_time, task, ks=[1, 10, 100]):
    """Accumulate ``y_time`` in predicted order until a best item is reached.

    The best items are the rows of ``task`` whose ``y_true`` equals the
    minimum ``y_true`` of that task. Rows are ordered by ascending ``y_pred``
    and, for every cutoff, the ``y_time`` of the first ``k`` rows is summed
    up to and including the first best item.

    Args:
        X_test: DataFrame with at least the columns ``'tid'`` and
            ``'model_id'``. It is copied, never modified.
        y_test: True target values, assigned as column ``'y_true'``.
        y_pred: Predicted target values, assigned as column ``'y_pred'``.
        y_time: Per-row time values, assigned as column ``'y_time'``.
        task: Label looked up in the ``'tid'`` index to select the rows.
        ks: Iterable of integer cutoffs.

    Returns:
        A list with one value per entry of ``ks``: the accumulated time when
        a best item is inside the top ``k``, otherwise the sum of ``y_time``
        over every row of the task.
    """
    results = X_test.copy()
    results['y_pred'] = y_pred
    results['y_true'] = y_test
    results['y_time'] = y_time

    results = results.set_index('tid')

    group = results.loc[task]
    if isinstance(group, pd.Series):
        # A task with a single row comes back as a row Series; re-select with
        # a list label so the code below always works on a DataFrame.
        group = results.loc[[task]]

    # Nothing below depends on k except the final slice, so the best items
    # and the predicted ordering are computed once.
    best_rank = group['y_true'].min()
    best_items = group[group['y_true'] == best_rank]['model_id'].values
    ranked = group.sort_values(['y_pred'], ascending=True)
    # Plain column arrays avoid building one Series per row via iterrows().
    ranked_ids = ranked['model_id'].values
    ranked_times = ranked['y_time'].values

    times_for_ks = []
    for k in ks:
        cumulative_time = 0
        found = False
        for item, elapsed in zip(ranked_ids[:k], ranked_times[:k]):
            cumulative_time += elapsed
            if item in best_items:  # Finding the first true in the predictions
                found = True
                break

        if not found:
            # No best item in the top k: charge the time of the whole task.
            cumulative_time = group['y_time'].sum()

        times_for_ks.append(cumulative_time)

    return times_for_ks


from collections import defaultdict

def prom_dicts(lista_diccionarios):
    """Average the values stored under each key across several dicts.

    Each key is averaged over the dicts that actually contain it, so a key
    missing from some dicts is not treated as zero there.

    Args:
        lista_diccionarios: Iterable of dicts mapping keys to numeric values.

    Returns:
        A dict mapping every key seen to the mean of its values.
    """
    totals = {}
    occurrences = {}

    for record in lista_diccionarios:
        for key, value in record.items():
            # Sums start from 0.0 so the accumulated total is a float.
            totals[key] = totals.get(key, 0.0) + value
            occurrences[key] = occurrences.get(key, 0) + 1

    return {key: total / occurrences[key] for key, total in totals.items()}

def to_one_row(df, metric_name, approach_name, ks=(1, 10, 100)):
    """Flatten a per-cutoff (value, sd) table into a single-row DataFrame.

    Args:
        df: DataFrame with one row per cutoff and exactly two columns: the
            metric value followed by its standard deviation. It is not
            modified.
        metric_name: Name used to build the output column labels.
        approach_name: Index label of the single output row.
        ks: Cutoff labels, one per row of ``df`` and in row order. Defaults
            to ``(1, 10, 100)``.

    Returns:
        A one-row DataFrame indexed by ``approach_name`` with the columns
        ``'<metric>@<k>'`` and ``'<metric>@<k>_sd'`` for every cutoff.

    Raises:
        ValueError: If ``df`` does not have ``len(ks)`` rows and two columns.
    """
    sd_name = metric_name + '_sd'

    # Relabel a copy so the caller's frame keeps its own index and columns.
    table = df.copy()
    table.index = list(ks)
    table.columns = [metric_name, sd_name]

    new_data = {}
    for idx in table.index:
        new_data[f'{metric_name}@{idx}'] = table.loc[idx, metric_name]
        new_data[f'{metric_name}@{idx}_sd'] = table.loc[idx, sd_name]

    return pd.DataFrame([new_data], index=[approach_name])

def get_dict_results(model_tuples, metric, metric_kwargs,
                     ctype_index, X_test, y_test, y_rel=None):
    """Evaluate ``metric`` for every model on every index subset.

    ``metric`` is always called with four positional arguments, in the order
    features, true values, predictions, ``y_rel`` (or ``None``), followed by
    ``metric_kwargs`` as keyword arguments.

    Args:
        model_tuples: Iterable of ``(model_name, y_pred)`` pairs; ``y_pred``
            must support ``.loc`` selection.
        metric: Callable invoked as described above.
        metric_kwargs: Extra keyword arguments forwarded to ``metric``.
        ctype_index: Mapping from a subset name to the index labels that
            select that subset via ``.loc``.
        X_test: Features, selectable with ``.loc``.
        y_test: True values, selectable with ``.loc``.
        y_rel: Optional extra values, selectable with ``.loc``.

    Returns:
        A dict with one entry ``'<model>_<subset>'`` per model and subset
        and, when more than one subset exists, an extra entry ``'<model>'``
        computed on the full data.
    """
    dic_metric = {}
    for model, y_pred in model_tuples:
        for ctype, ix in ctype_index.items():
            print(model, ctype)
            subset_args = [X_test.loc[ix], y_test.loc[ix], y_pred.loc[ix]]
            subset_args.append(None if y_rel is None else y_rel.loc[ix])
            subset_value = metric(*subset_args, **metric_kwargs)
            dic_metric[model + '_' + ctype] = subset_value

        # With a single subset the overall score would only duplicate it.
        if len(ctype_index) > 1:
            print(model, 'all')
            overall_value = metric(X_test, y_test, y_pred, y_rel,
                                   **metric_kwargs)
            dic_metric[model] = overall_value
    return dic_metric