import hydra
from omegaconf import OmegaConf
from utils import parallel, set_seed
import numpy as np
from metric import NDCG_at_k_rank, ensemble_average_accuracy
from market import Market
from user import User
import numpy as np
from tqdm import trange

@hydra.main(version_base=None, config_path='configs', config_name='config')
def main(cfg=None):
    """Run the 'RNCMETime' experiment: average spec timings and NDCG@1.

    Steps:
      1. Convert the Hydra config to a plain dict, hoist the ``task`` and
         ``specification`` sub-configs to the top level and call ``set_seed``.
      2. Build the market, submit learnwares ``1..n_learnware`` and create
         users ``1..n_user``.
      3. Average the first ``max_m`` entries of every learnware's and every
         user's ``spec.time`` and save both averaged curves.
      4. Compare each user spec with each learnware spec, rank the learnwares
         per user for every index ``m < max_m`` and save the NDCG@1 values.

    Args:
        cfg: Configuration object passed in by the ``hydra.main`` decorator.

    Side effects:
        Writes ``learnware_time.npy``, ``user_time.npy`` and ``NDCGat1.npy``
        (relative paths) and prints every 100th entry of each result.
    """
    cfg = OmegaConf.to_container(cfg, resolve=True)
    # Flatten the nested sub-configs into the top level, then pin the
    # specification name used by this experiment.
    cfg.update(cfg['task'])
    cfg.update(cfg['specification'])
    cfg['specification'] = 'RNCMETime'
    set_seed(cfg['seed'])
    n_learnware = cfg['n_learnware']
    n_user = cfg['n_user']

    market = Market(cfg)
    # Learnware and user ids are 1-based.
    for learnware_id in range(1, 1 + n_learnware):
        market.submit_learnware(learnware_id)
    print('')
    users = [User(cfg, user_id) for user_id in range(1, 1 + n_user)]

    # Only the first element of each evaluate() result is kept.
    model_perfs = np.array([market.evaluate(user)[0] for user in users])

    # Number of leading entries considered for both the timing curves and the
    # spec comparison.
    # NOTE(review): the original read cfg['reduced_size'] and immediately
    # overrode it with this constant; confirm whether the config value should
    # be honoured instead.
    max_m = 2000

    # Element-wise mean of the first max_m timing entries over all learnwares.
    # market.learnwares is indexed 0..n_learnware-1 here, as in the original.
    learnware_time = np.zeros(max_m)
    for learnware_idx in range(n_learnware):
        learnware_time += market.learnwares[learnware_idx].spec.time[:max_m]
    learnware_time /= n_learnware

    # Same average over all users.
    user_time = np.zeros(max_m)
    for user in users:
        user_time += user.spec.time[:max_m]
    user_time /= n_user

    np.save('learnware_time.npy', learnware_time)
    np.save('user_time.npy', user_time)
    print(learnware_time[::100])
    print(user_time[::100])

    # One compare() result per (user, learnware) pair, user-major order.
    # Presumably each result holds one distance per reduced size (at least
    # max_m entries are required by the indexing below) -- TODO confirm.
    spec_dists = []
    for user_idx, user in enumerate(users):
        for learnware_idx in trange(n_learnware):
            spec_dists.append(
                user.spec.compare(
                    market.learnwares[learnware_idx].spec,
                    max_m=max_m,
                    user_id=user_idx + 1,
                    learnware_id=learnware_idx + 1,
                )
            )
    # Rearranged so the leading axis indexes m, followed by
    # (n_user, n_learnware).
    spec_dists = np.array(spec_dists).T.reshape(-1, n_user, n_learnware)
    # For every m and user: learnware indices sorted by ascending distance.
    rcmd_orders = spec_dists.argsort(axis=2)

    NDCG_at_ks = [
        NDCG_at_k_rank(cfg, rcmd_orders[m], model_perfs, k=1)
        for m in range(max_m)
    ]

    np.save('NDCGat1.npy', np.array(NDCG_at_ks))

    print(NDCG_at_ks[::100])

    # for gamma in range(1001):
    #     rcmd_orders = test_gamma(gamma / 10, cfg)
    #     print(f'gamma={gamma/10}:', NDCG_at_k_spec(cfg, rcmd_orders, model_perfs)[:5])
    # print(NDCG_at_k_spec(cfg, rcmd_orders, model_perfs)[:5])
    # if cfg['task'] == 'classification':
    #     print(ensemble_average_accuracy(cfg, rcmd_orders))


# Script entry point: main() is called without arguments; cfg is expected to
# be supplied by the @hydra.main decorator applied to main.
if __name__ == '__main__':
    main()