import torch
import numpy as np 
from recbole_gnn.utils import create_dataset, data_preparation
from recbole.utils import init_seed

# Checkpoints read by the evaluation sections of this script.
ckpt_path_mf = 'Anime_DirectAU.pth'
ckpt_path_lgcn = 'Anime_LightGCN.pth'

# Last two entries of the `type` list handed to `get_norm_adj_mat` when the
# test-time message-passing graph is built.
# NOTE(review): presumably degree-normalisation exponents (-0.5 / -0.5 would be
# the symmetric normalisation) -- confirm against `get_norm_adj_mat`.
m = -0.5
n = -0.5

def ndcg_at_k(pos_index, pos_len):
    """Return NDCG at every cut-off 1..K for each row of a hit matrix.

    Args:
        pos_index: 2-D array (rows x K); a truthy entry at column ``j`` marks a
            hit at rank ``j + 1``.
        pos_len: 1-D integer array with one entry per row, giving the number of
            relevant items for that row. It is used as a slice bound, so it
            must hold integers.

    Returns:
        Float array with the same shape as ``pos_index``; column ``j`` holds
        NDCG@(j + 1). Rows whose ``pos_len`` is 0 come out as all zeros: the
        ``-1`` index below wraps to the last column, so the denominator stays
        positive while the numerator is 0.
    """
    n_cols = pos_index.shape[1]

    # Per-rank discount 1 / log2(rank + 1), replicated for every row.
    rank = np.arange(1, n_cols + 1, dtype=float)
    discount = np.broadcast_to(1.0 / np.log2(rank + 1), pos_index.shape)

    # Ideal DCG: cumulative discount, frozen once all relevant items of the
    # row have been placed (at most K of them fit in the list).
    ideal = np.cumsum(discount, axis=1)
    cutoff = np.full_like(pos_len, n_cols)
    ideal_len = np.where(pos_len > cutoff, cutoff, pos_len)
    for ideal_row, n_ideal in zip(ideal, ideal_len):
        ideal_row[n_ideal:] = ideal_row[n_ideal - 1]

    # Actual DCG: only ranks that are hits contribute their discount.
    gained = np.cumsum(np.where(pos_index, discount, 0), axis=1)

    return gained / ideal

def recall_at_k(r, all_pos_num):
    """Return per-row recall: hits in the list divided by relevant-item count.

    Args:
        r: 2-D hit matrix; it is summed along axis 1 to count hits per row.
        all_pos_num: per-row number of relevant items.

    Returns:
        ``hits / (all_pos_num + 1e-9)``; the small epsilon keeps rows with no
        relevant items from dividing by zero.
    """
    hits_per_row = r.sum(1)
    denominator = all_pos_num + 1e-9
    return hits_per_row / denominator

def test_time_mp(u_emb, i_emb, edge_index, edge_weight, n_layers=3):
    """Run message passing over already-trained embeddings at test time.

    User and item embeddings are concatenated (users first), pushed through
    ``n_layers`` applications of a single ``LightGCNConv``, and the input
    together with every layer output is averaged.

    Args:
        u_emb: user embedding tensor; its first dimension is the user count.
        i_emb: item embedding tensor; its first dimension is the item count.
        edge_index: graph connectivity forwarded unchanged to the convolution.
        edge_weight: edge weights forwarded unchanged to the convolution.
        n_layers: number of propagation steps. With 0 the concatenated input
            is returned as is.

    Returns:
        Tuple ``(user_embeddings, item_embeddings)`` obtained by splitting the
        combined result back into its user and item parts.
    """
    from recbole_gnn.model.layers import LightGCNConv

    n_users, n_items = u_emb.shape[0], i_emb.shape[0]

    # NOTE(review): 123 looks like a placeholder constructor argument; confirm
    # that LightGCNConv does not depend on it during propagation.
    conv = LightGCNConv(123).to(u_emb.device)

    current = torch.cat([u_emb, i_emb], dim=0)
    per_layer = [current]
    for _ in range(n_layers):
        current = conv(current, edge_index, edge_weight)
        per_layer.append(current)

    if n_layers == 0:
        combined = current
    else:
        # Average the input embeddings and every layer's output.
        combined = torch.stack(per_layer, dim=1).mean(dim=1)

    return torch.split(combined, [n_users, n_items])


#  -------------- Dataset Initialization --------------
print('Preparing Datasets...')
# The MF checkpoint stores the config it was trained with; that config drives
# dataset creation and splitting below.
payload = torch.load(ckpt_path_mf, map_location='cpu')
config = payload['config']
if 'valid_neg_sample_args' not in config:
    # Checkpoint config lacks evaluation sampling settings: supply them.
    # NOTE(review): sample_num 'none' presumably means "no negative sampling"
    # (full ranking) -- confirm against the RecBole version in use.
    config['valid_neg_sample_args'] = {'distribution': 'uniform', 'sample_num': 'none'}
    config['test_neg_sample_args'] = {'distribution': 'uniform', 'sample_num': 'none'}
init_seed(config['seed'], config['reproducibility'])
try:
    dataset = create_dataset(config)
except Exception as exc:
    # The original script retried the identical call once after any failure.
    # The retry is kept, but the first error is reported instead of being
    # swallowed silently; a second failure propagates.
    print('create_dataset failed ({!r}); retrying once...'.format(exc))
    dataset = create_dataset(config)
train_data, valid_data, test_data = data_preparation(config, dataset)

# Field names come from the config so every lookup agrees on them.
user_field = config['USER_ID_FIELD']
item_field = config['ITEM_ID_FIELD']

# (user, item) id columns of each split's interactions.
train_user_id = train_data._dataset.inter_feat[user_field]
train_item_id = train_data._dataset.inter_feat[item_field]
valid_user_id = valid_data._dataset.inter_feat[user_field]
valid_item_id = valid_data._dataset.inter_feat[item_field]
test_user_id = test_data._dataset.inter_feat[user_field]
test_item_id = test_data._dataset.inter_feat[item_field]

# Interaction count per distinct user id in the full dataset. Only its length
# is used further down, as the number of user rows to evaluate.
degree = dataset.inter_feat[user_field].unique(return_counts=True)[1]
#  -------------- End of Dataset Initialization --------------

#  -------------- Shared evaluation helpers --------------
TOP_K = 20  # cut-off used for both NDCG@K and Recall@K
RESULT_TEMPLATE = '{} results: NDCG@{}:{}, Recall@{}:{}'


def _evaluate_embeddings(user_emb, item_emb, top_k=TOP_K):
    """Rank all items for every user and return per-user NDCG@k and Recall@k.

    Reads the module-level split ids (``train_*``, ``valid_*``, ``test_*``)
    and ``degree`` prepared in the dataset-initialization section.

    Args:
        user_emb: user embedding tensor, one row per user id.
        item_emb: item embedding tensor, one row per item id.
        top_k: length of the recommendation list.

    Returns:
        Tuple ``(ndcg, recall)`` of 1-D numpy arrays with one entry for each
        of the first ``degree.shape[0]`` user rows.
    """
    scores = torch.matmul(user_emb.cpu().float(), item_emb.cpu().t().float())
    # Interactions already seen in train/valid must never be recommended.
    scores[train_user_id, train_item_id] = -np.inf
    scores[valid_user_id, valid_item_id] = -np.inf
    top_items = torch.topk(scores, top_k).indices.cpu()

    eval_users = torch.as_tensor(test_user_id).long()
    eval_items = torch.as_tensor(test_item_id).long()

    # Dense ground-truth lookup: relevant[u, i] is True when (u, i) is a test
    # interaction. Gathering it at the top-k item ids yields the hit matrix in
    # one vectorised step instead of a per-user / per-item Python loop.
    relevant = torch.zeros(scores.shape, dtype=torch.bool)
    relevant[eval_users, eval_items] = True

    # The original loops covered user rows 0 .. degree.shape[0] - 1; keep
    # exactly that range so the reported numbers do not change.
    n_eval_users = degree.shape[0]
    hits = torch.gather(relevant, 1, top_items)[:n_eval_users].numpy()
    # Number of test interactions per user row (duplicates counted).
    pos_num = torch.bincount(eval_users, minlength=n_eval_users)[:n_eval_users].numpy()

    ndcg = ndcg_at_k(hits, pos_num)[:, -1]
    recall = np.nan_to_num(recall_at_k(hits, pos_num), nan=0)
    return ndcg, recall


def _report(model_name, user_emb, item_emb):
    """Evaluate one pair of embedding tables and print its metrics line."""
    ndcg, recall = _evaluate_embeddings(user_emb, item_emb)
    # NOTE(review): NDCG is averaged over every evaluated row while recall
    # skips row 0. Kept as in the original so reported numbers are unchanged;
    # presumably row 0 is a padding user and both means should skip it --
    # confirm before changing.
    print(RESULT_TEMPLATE.format(model_name, TOP_K, ndcg.mean(), TOP_K, recall[1:].mean()))
#  -------------- End of shared evaluation helpers --------------


#  -------------- Calculating Performance of TAG-CF --------------
print('Running TAG-CF...')
# The MF checkpoint is loaded once and reused for the plain MF run below;
# nothing here modifies the loaded embeddings in place.
mf_payload = torch.load(ckpt_path_mf, map_location='cpu')
mf_user_e = mf_payload['other_parameter']['restore_user_e']
mf_item_e = mf_payload['other_parameter']['restore_item_e']
edge_index, edge_weight = train_data._dataset.get_norm_adj_mat(type=[0, 1, 0, 0, m, n])
# TAG-CF: one step of test-time message passing over the MF embeddings.
user_e, item_e = test_time_mp(mf_user_e.cpu(), mf_item_e.cpu(), edge_index, edge_weight, 1)
_report('TAG-CF', user_e, item_e)
#  -------------- End of Calculating Performance of TAG-CF --------------


#  -------------- Calculating Performance of MF --------------
print('Running MF...')
_report('MF', mf_user_e, mf_item_e)
#  -------------- End of Calculating Performance of MF --------------

#  -------------- Calculating Performance of LightGCN --------------
print('Running LightGCN...')
lgcn_payload = torch.load(ckpt_path_lgcn, map_location='cpu')
_report(
    'LightGCN',
    lgcn_payload['other_parameter']['restore_user_e'],
    lgcn_payload['other_parameter']['restore_item_e'],
)
#  -------------- End of Calculating Performance of LightGCN --------------