import numpy as np
import os
import torch
from tqdm import tqdm
from utils import rank

def NDCG_at_k(order_perfs, ideal_perfs, k):
    '''
        Mean NDCG@k over user tasks. Values are used directly as gains (larger = better).

        order_perfs: (n_user, n_learnware):
            order_perfs[i][j] means the actual performance (acc / mse) of the (j + 1)-th ranked learnware on user task U_{i+1}
        ideal_perfs: (n_user, n_learnware):
            ideal_perfs[i][j] means the (j + 1)-th best performance on user task U_{i+1}
            (a single row is broadcast across all users)
        k: rank cutoff; if k exceeds the number of columns, every column is used.

        Returns the NDCG@k averaged over users. A user whose IDCG@k is 0 yields nan/inf.
    '''
    order_perfs_k = np.asarray(order_perfs)[:, :k]
    ideal_perfs_k = np.asarray(ideal_perfs)[:, :k]
    # Discount for 0-based position j is log2(j + 2). Size it to the actual slice width
    # so that a cutoff larger than the number of columns still broadcasts.
    order_discount = np.log2(np.arange(order_perfs_k.shape[1]) + 2)
    ideal_discount = np.log2(np.arange(ideal_perfs_k.shape[1]) + 2)
    DCGk = (order_perfs_k / order_discount).sum(1)
    IDCGk = (ideal_perfs_k / ideal_discount).sum(1)
    return (DCGk / IDCGk).mean()

def NDCG_at_k_rank(cfg, rcmd_orders, model_perfs, k=None, *, verbose=True):
    '''
        NDCG of the recommended orders, with gains derived from ranks instead of raw performances.

        A learnware with rank r (as returned by utils.rank) gets gain (N - r) / (N - 1), where
        N = cfg['n_learnware']. The ideal gains are (N - 1, N - 2, ..., 0) / (N - 1).
        Presumably rank 1 is the best model, so gains run from 1 down to 0 -- TODO confirm against utils.rank.

        cfg: reads 'task' ('classification' ranks with order='ascend', anything else 'descend') and 'n_learnware'.
        rcmd_orders: integer indices into model_perfs along axis 1; row i is the recommended order for user i.
        model_perfs: per-user model performances handed to utils.rank.
        k: if truthy, return NDCG@k only; otherwise return [NDCG@1, ..., NDCG@N], each rounded to 3 decimals.
        verbose: if True (default, the previous behaviour), print the rank of every user's first recommendation.
    '''
    order = 'ascend' if cfg['task'] == 'classification' else 'descend'
    ranks = rank(model_perfs, order=order)
    rcmd_ranks = np.take_along_axis(ranks, rcmd_orders, axis=1)
    if verbose:
        # Rank of each user's top-1 recommendation, printed space-separated on one line.
        print(*rcmd_ranks[:, 0].tolist())
    N = cfg['n_learnware']
    order_perfs = (N - rcmd_ranks) / (N - 1)
    # Single ideal row (ranks 1..N), broadcast across users inside NDCG_at_k.
    ideal_perfs = (N - np.arange(1, 1 + N).reshape(1, -1)) / (N - 1)

    if k:
        return NDCG_at_k(order_perfs, ideal_perfs, k)

    return [
        round(NDCG_at_k(order_perfs, ideal_perfs, cutoff), 3)
        for cutoff in range(1, N + 1)
    ]


def NDCG_at_k_spec(cfg, rcmd_orders, model_perfs, k=None):
    '''
        NDCG of the recommended orders, with gains taken from the actual model performances.

        cfg: reads 'task' and 'n_learnware'.
            For 'regression', performances are errors (lower is better). They are inverted to 1 / perf and
            scaled so that each user's best model has gain 1. Order and ideal gains therefore come from
            the same transformed values. A zero error yields an infinite gain.
            For any other task, raw performances are used as gains (higher is better).
        rcmd_orders: integer indices into model_perfs along axis 1; row i is the recommended order for user i.
        model_perfs: (n_user, n_learnware) performances of every learnware on every user task.
        k: if truthy, return NDCG@k only; otherwise return [NDCG@1, ..., NDCG@n_learnware], each rounded to 3 decimals.
    '''
    gains = np.asarray(model_perfs, dtype=float)
    if cfg['task'] == 'regression':
        # Invert errors so that larger means better, then normalise each user's row to a max of 1.
        # Doing this before building both order and ideal gains keeps the ideal ranking consistent.
        gains = 1 / gains
        gains = gains / gains.max(axis=1, keepdims=True)

    order_perfs = np.take_along_axis(gains, np.asarray(rcmd_orders), axis=1)
    ideal_perfs = -np.sort(-gains, axis=1)  # per-user gains, best first

    if k:
        return NDCG_at_k(order_perfs, ideal_perfs, k)

    return [
        round(NDCG_at_k(order_perfs, ideal_perfs, cutoff), 3)
        for cutoff in range(1, cfg['n_learnware'] + 1)
    ]

def ensemble_average_accuracy(cfg, learnware_ids, ensemble='mean', max_k=None):
    '''
        User-averaged accuracy of ensembling the first k recommended learnwares, for k = 1..max_k.
        Only for classification tasks.

        cfg: unused; kept for interface compatibility.
        learnware_ids: (n_user, n_rcmd) learnware indices per user. Row u - 1 belongs to user u,
            and its first k entries form the size-k ensemble.
        ensemble: 'vote' uses the mode of the predicted labels from logs/model_preds/{user}.pt.
            'mean' takes the argmax of the averaged probabilities from logs/model_probs/{user}.pt.
            'weighted_mean' does the same, but the model at 0-based position i is weighted by 1 / log2(i + 2).
        max_k: largest ensemble size; defaults to learnware_ids.shape[1].

        Labels are read from logs/user_task_labels/{user}.pt, with users numbered from 1.
        Returns a list of max_k floats: element k - 1 is the mean accuracy of the size-k ensemble, rounded to 3 decimals.
        Raises ValueError for an unknown ensemble method, before any file is read.
    '''
    if ensemble not in ('vote', 'mean', 'weighted_mean'):
        raise ValueError(f'Unknown ensemble method {ensemble}.')
    if max_k is None:
        max_k = learnware_ids.shape[1]
    average_accs = []
    for user_id, learnware_ids_user in enumerate(tqdm(learnware_ids), start=1):
        user_task_labels = torch.load(os.path.join('logs', 'user_task_labels', f'{user_id}.pt'))
        user_average_accs = []
        if ensemble == 'vote':
            model_preds = torch.load(os.path.join('logs', 'model_preds', f'{user_id}.pt'))[learnware_ids_user]
            for k in range(1, 1 + max_k):
                ensemble_preds_k = torch.mode(model_preds[:k], 0).values
                user_average_accs.append((ensemble_preds_k == user_task_labels).float().mean().item())
        else:
            # Only the first max_k recommended models ever join an ensemble.
            model_probs = torch.load(os.path.join('logs', 'model_probs', f'{user_id}.pt'))[learnware_ids_user][:max_k]
            if ensemble == 'weighted_mean':
                # Size the weights to the rows actually present, so max_k != n_rcmd still broadcasts.
                # Put them on the tensor's device so CUDA-loaded probabilities also work.
                weights = 1 / np.log2(np.arange(model_probs.shape[0]) + 2)
                weights = torch.as_tensor(weights, device=model_probs.device).view(-1, 1, 1)
                model_probs = model_probs * weights
            for k in range(1, 1 + max_k):
                mean_model_probs_k = model_probs[:k].mean(0)
                ensemble_preds_k = torch.argmax(mean_model_probs_k, dim=1)
                user_average_accs.append((ensemble_preds_k == user_task_labels).float().mean().item())
        average_accs.append(user_average_accs)
    return [round(acc, 3) for acc in np.array(average_accs).mean(0).tolist()]