import argparse
import time
import torch
import numpy as np
import scipy.sparse as sparse
import torch.optim as optim
from sklearn.metrics import roc_auc_score, f1_score
from model import Uni_model
import gc
import torch.nn.functional as F
import os
import metric
import utils

parser = argparse.ArgumentParser(description='PyTorch Neural Collaborative Filtering')
parser.add_argument('--dataset', type=str, default='epinions', help='target dataset to evaluate on')
# NOTE: --lr, --wd, --epochs, --test, --cross, --cluster, --log-interval,
# --save and --save_output are accepted for CLI compatibility but are not
# read anywhere in this evaluation script.
parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate')
parser.add_argument('--wd', type=float, default=0.0000, help='weight decay coefficient')
parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
parser.add_argument('--epochs', type=int, default=250, help='upper epoch limit')
parser.add_argument('--embedding_size', type=int, default=64, help='embedding_size')
parser.add_argument('--seed', type=int, default=1111, help='random seed')
parser.add_argument('--cuda', action='store_true', help='use CUDA')
parser.add_argument('--test', action='store_true', help='test')
# --cold and --pretrain are store_false flags: they default to True and
# passing the flag turns the option OFF.
parser.add_argument('--cold', action='store_false',
                    help='disable the cold option (args.cold defaults to True)')
parser.add_argument('--cross', action='store_true', help='enable the cross option')
parser.add_argument('--source', type=str, default='yelp_AZ',
                    help='source dataset; its saved model is loaded for evaluation')
parser.add_argument('--feature', type=str, default='popularity',
                    help='feature type passed to utils.load_feature')
parser.add_argument('--home_dir', type=str, default='')
parser.add_argument('--cluster', action='store_true', help='enable the cluster option')
parser.add_argument('--pretrain', action='store_false',
                    help='disable the pretrain option (args.pretrain defaults to True)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N', help='report interval')
parser.add_argument('--save', type=str, default='model.pt', help='path to save the final model')
parser.add_argument('--save_output', type=str, default='model', help='Save identifier')
parser.add_argument('--k', type=int, default=100, help='k passed to utils.load_feature')

args = parser.parse_args()
# Seed the RNGs of every numeric library imported here so runs are
# reproducible. torch.manual_seed covers torch (all devices); numpy is seeded
# separately in case project code (utils / model) draws from np.random.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
if torch.cuda.is_available() and not args.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

###############################################################################
# Load data
###############################################################################

# Run on the GPU only when explicitly requested with --cuda.
device = torch.device("cuda" if args.cuda else "cpu")
# DataLoader settings shared by every evaluation split. shuffle=False keeps
# batch order fixed; pinned host memory only helps when batches are copied to
# a CUDA device, so it is enabled only in that case.
eval_params = {'batch_size': args.batch_size,
               'shuffle': False,
               'num_workers': 6,
               'pin_memory': device.type == 'cuda'}
# Data root: honour --home_dir when given; an empty value (the default) falls
# back to the historical hard-coded "/home/".
home_dir = args.home_dir or "/home/"

source = args.source
target = args.dataset
# No pre-computed embeddings are supplied; Uni_model receives None for both.
user_embedding = None
item_embedding = None

# Only the first value returned for the source dataset is kept (it is opened
# as the checkpoint file in __main__); features and statistics come from the
# target dataset.
save, _, _, _, _ = utils.load_feature(home_dir, source, args.feature, args.k)
_, user_feature, item_feature, user_stat, item_stat = utils.load_feature(home_dir, target, args.feature, args.k)

n_items = len(item_feature)
n_users = len(user_feature)

# Evaluation splits of the target dataset. The extra 'c_c' argument selects a
# variant of the test split (presumably cold users / cold items — confirm in
# utils.EvalDatasetNeuMF).
val_dataset = utils.EvalDatasetNeuMF(home_dir, user_stat, item_stat, target, args.cold, n_users, n_items, 'val')
val_loader = torch.utils.data.DataLoader(val_dataset, **eval_params)
test_dataset = utils.EvalDatasetNeuMF(home_dir, user_stat, item_stat, target, args.cold, n_users, n_items, 'test')
cold_test_dataset = utils.EvalDatasetNeuMF(home_dir, user_stat, item_stat, target, args.cold, n_users, n_items, 'test', 'c_c')

test_loader = torch.utils.data.DataLoader(test_dataset, **eval_params)
cold_test_loader = torch.utils.data.DataLoader(cold_test_dataset, **eval_params)

###############################################################################
# Build the model
###############################################################################


# Target-side model built from the target dataset's features and moved to the
# selected device. No checkpoint is loaded into it by this script.
model = Uni_model(
    args.embedding_size,
    n_users,
    n_items,
    user_feature,
    item_feature,
    user_embedding,
    item_embedding,
    args.cold,
    args.pretrain,
).to(device)


def evaluate(source, model, eval_loader, k=10):
    """Score every batch of ``eval_loader`` and compute AUC and top-k metrics.

    Both networks are put in eval mode, and scores are produced by
    ``utils.cross(source, model, ...)`` with gradient tracking disabled.
    Batches are moved to the module-level ``device``.

    Args:
        source: the source-side model (must provide ``.eval()``).
        model: the target-side model (must provide ``.eval()``).
        eval_loader: iterable of ``(users, items, label, u_stat, i_stat)``
            tensors. ``label`` must have at least two dims; scores are
            reshaped to ``(-1, label.shape[1])`` to line up with it.
        k: ranking cutoff passed to the NDCG/Recall metrics (default 10).

    Returns:
        ``(auc, best_lambd, truth, pred, ndcg_at_k, recall_at_k)``:
        ``truth``/``pred`` are ``(N, 1)`` numpy arrays of all labels/scores,
        ``best_lambd`` is the constant placeholder ``-1``, and the two ranking
        metrics are 0-dim tensors holding the mean of the concatenated
        per-batch outputs of ``metric.*``.

    Raises:
        ValueError: if ``eval_loader`` yields no batches, or (from
            ``roc_auc_score``) if only one label class is present.
    """
    model.eval()
    source.eval()
    truths = []
    preds = []
    n_k = []
    r_k = []
    with torch.no_grad():
        for data in eval_loader:
            data = [x.to(device, non_blocking=True) for x in data]
            (users, items, label, u_stat, i_stat) = data

            scores = utils.cross(source, model, users, items, u_stat, i_stat)
            # Reshape the flat scores so each row lines up with a row of label.
            scores = scores.view(-1, label.shape[1])

            n_k.append(metric.NDCG_binary_at_k_batch_torch(scores, label, k, device=device))
            r_k.append(metric.Recall_at_k_batch_torch(scores, label, k))

            preds.append(scores.view(-1, 1).cpu().numpy())
            truths.append(label.view(-1, 1).cpu().numpy())

    if not preds:
        # np.concatenate would otherwise fail with an unhelpful message.
        raise ValueError("evaluate() received an eval_loader that yielded no batches")

    pred = np.concatenate(preds)
    truth = np.concatenate(truths)
    auc = roc_auc_score(truth, pred)
    # Placeholder kept so callers can unpack a lambda value; nothing is tuned.
    best_lambd = -1
    return auc, best_lambd, truth, pred, torch.mean(torch.cat(n_k)), torch.mean(torch.cat(r_k))


if __name__ == '__main__':
    # Evaluation only: this script does no training.
    print('Start testing')
    # The file at `save` holds a whole pickled source model (evaluate calls
    # .eval() on it). map_location lets a checkpoint saved on GPU be loaded on
    # a CPU-only run and vice versa.
    # NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
    # checkpoints. On torch>=2.6 (weights_only defaults to True), loading a whole
    # module may also need weights_only=False — confirm the torch version in use.
    with open(save, 'rb') as f:
        source_model = torch.load(f, map_location=device)
    # Regular test split first, then the 'c_c' test split; one line each.
    for loader in (test_loader, cold_test_loader):
        auc, _, _, _, n10, r10 = evaluate(source_model, model, loader)
        print('| auc: {:5.3f} | r10: {:5.3f} | n10: {:5.3f}'.format(auc, r10, n10))
