# -*- coding: utf-8 -*-
import numpy as np
import torch
import pdb
from sklearn.metrics import roc_auc_score
import pdb
import arguments

from dataset import load_data
from matrix_factorization_logistic_gpu_uidr import *
from utils import gini_index, ndcg_func, get_user_wise_ctr, rating_mat_to_sample, binarize, shuffle, minU, precision_func, recall_func
def mse_func(x, y):
    """Return the mean of the squared element-wise differences of x and y."""
    return np.mean((x - y) ** 2)


def acc_func(x, y):
    """Return the number of positions where x == y, divided by len(x)."""
    return np.sum(x == y) / len(x)


def mae_func(x, y):
    """Return the mean of the absolute element-wise differences of x and y."""
    return np.mean(np.abs(x - y))



def train_and_eval(dataset_name, train_args, model_args):
    """Train an MF_UDR model on one dataset and print its test-set metrics.

    Args:
        dataset_name: One of "coat", "yahoo" or "kuai".
        train_args: Dict with the keys "gamma", "batch_size", "tol" and
            "verbose"; each is forwarded to ``MF_UDR.fit``.
        model_args: Dict with the embedding sizes ("embedding_k_pred",
            "embedding_k_impu", "embedding_k_prop", "embedding_k_base_prop"),
            the weights "L1".."L5", and the "lr_*" / "lamb_*" entries for the
            pred, prop, base_prop and impu parts; forwarded to the ``MF_UDR``
            constructor and to ``fit``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    top_k_list = [5]
    top_k_names = ("precision_5", "recall_5", "ndcg_5", "f1_5")
    if dataset_name == "coat":
        # coat is loaded as two rating matrices that are flattened to samples.
        train_mat, test_mat = load_data("coat")
        x_train, y_train = rating_mat_to_sample(train_mat)
        x_test, y_test = rating_mat_to_sample(test_mat)
        num_user = train_mat.shape[0]
        num_item = train_mat.shape[1]

    elif dataset_name == "yahoo":
        x_train, y_train, x_test, y_test = load_data("yahoo")
        x_train, y_train = shuffle(x_train, y_train)
        # Column 0 / column 1 of x_train are used as user / item indices below.
        num_user = x_train[:, 0].max() + 1
        num_item = x_train[:, 1].max() + 1

    elif dataset_name == "kuai":
        x_train, y_train, x_test, y_test = load_data("kuai")
        num_user = x_train[:, 0].max() + 1
        num_item = x_train[:, 1].max() + 1
        # kuai is the only dataset evaluated at a cutoff of 50 instead of 5.
        top_k_list = [50]
        top_k_names = ("precision_50", "recall_50", "ndcg_50", "f1_50")

    else:
        # Fail early with a clear message instead of hitting an unbound
        # num_user / x_train further down.
        raise ValueError(
            "unsupported dataset_name {!r}; expected 'coat', 'yahoo' or 'kuai'".format(
                dataset_name))

    # NOTE(review): the seeds are set after the yahoo shuffle above; if
    # shuffle() draws from NumPy's global RNG, that shuffle is not covered by
    # this seed -- confirm whether this ordering is intended.
    np.random.seed(2020)
    torch.manual_seed(2020)

    print("# user: {}, # item: {}".format(num_user, num_item))
    # binarize: kuai passes 1 as the second argument, the others rely on
    # binarize's default.
    if dataset_name == "kuai":
        y_train = binarize(y_train, 1)
        y_test = binarize(y_test, 1)
    else:
        y_train = binarize(y_train)
        y_test = binarize(y_test)

    # UDR
    mf_udr = MF_UDR(
        num_user, num_item,
        embedding_k_pred=model_args["embedding_k_pred"],
        embedding_k_impu=model_args["embedding_k_impu"],
        embedding_k_prop=model_args["embedding_k_prop"],
        embedding_k_base_prop=model_args["embedding_k_base_prop"])
    mf_udr.cuda()
    mf_udr.fit(x_train[:, 0], x_train[:, 1], y_train,
               L1=model_args["L1"], L2=model_args["L2"], L3=model_args["L3"],
               L4=model_args["L4"], L5=model_args["L5"],
               lr_pred=model_args["lr_pred"], lamb_pred=model_args["lamb_pred"],
               lr_prop=model_args["lr_prop"], lamb_prop=model_args["lamb_prop"],
               lr_base_prop=model_args["lr_base_prop"],
               lamb_base_prop=model_args["lamb_base_prop"],
               lr_impu=model_args["lr_impu"], lamb_impu=model_args["lamb_impu"],
               gamma=train_args["gamma"], batch_size=train_args["batch_size"],
               tol=train_args["tol"], verbose=train_args["verbose"])

    test_pred = mf_udr.predict(x_test)
    mse_mf = mse_func(y_test, test_pred)
    mae_mf = mae_func(y_test, test_pred)
    precisions = precision_func(mf_udr, x_test, y_test, top_k_list)
    recalls = recall_func(mf_udr, x_test, y_test, top_k_list)
    auc = roc_auc_score(y_test, test_pred)
    ndcgs = ndcg_func(mf_udr, x_test, y_test, top_k_list)

    # Average each ranking metric once and reuse it for F1 and the report.
    mean_precision = np.mean(precisions[top_k_names[0]])
    mean_recall = np.mean(recalls[top_k_names[1]])
    mean_ndcg = np.mean(ndcgs[top_k_names[2]])

    # F1 is the harmonic mean of the averaged precision and recall.
    f1 = 2 / (1 / mean_precision + 1 / mean_recall)

    print("***"*5 + "[MF]" + "***"*5)
    print("[MF] test mse:", mse_mf)
    # Fixed label: this line reports the MAE (it used to say "mse").
    print("[MF] test mae:", mae_mf)
    print("[MF] test auc:", auc)
    print("[MF] {}:{:.6f}".format(top_k_names[2].replace("_", "@"), mean_ndcg))
    print("[MF] {}:{:.6f}".format(top_k_names[3].replace("_", "@"), f1))
    print("[MF] {}:{:.6f}".format(top_k_names[0].replace("_", "@"), mean_precision))
    print("[MF] {}:{:.6f}".format(top_k_names[1].replace("_", "@"), mean_recall))
    # NOTE(review): gi / gu are computed but never printed or returned; the
    # calls are kept so the helpers still run exactly as before.
    user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)
    gi, gu = gini_index(user_wise_ctr)
    print("***"*5 + "[MF-UDR]" + "***"*5)

def para(args):
    """Attach the per-dataset hyper-parameter presets to ``args``.

    Reads ``args.dataset``; for "coat", "yahoo" or "kuai" it sets
    ``args.train_args`` and ``args.model_args`` to that dataset's preset
    dictionaries. Any other value leaves ``args`` untouched.

    Returns:
        The same ``args`` object that was passed in.
    """
    presets = {
        "coat": (
            dict(batch_size=128, gamma=0.05, tol=1e-4, verbose=False),
            dict(
                embedding_k_pred=4, embedding_k_impu=8, embedding_k_prop=4,
                embedding_k_base_prop=4,
                L1=10, L2=100, L3=0.05, L4=1, L5=1,
                lr_pred=0.05, lamb_pred=0.001,
                lr_prop=0.005, lamb_prop=0.01,
                lr_base_prop=0.005, lamb_base_prop=0.01,
                lr_impu=0.005, lamb_impu=0.001,
            ),
        ),
        "yahoo": (
            dict(batch_size=1024, gamma=0.001, tol=1e-5, verbose=False),
            dict(
                embedding_k_pred=32, embedding_k_impu=32, embedding_k_prop=32,
                embedding_k_base_prop=32,
                L1=100, L2=1000, L3=1, L4=0.1, L5=1,
                lr_pred=0.005, lamb_pred=5e-5,
                lr_prop=0.005, lamb_prop=5e-5,
                lr_base_prop=0.005, lamb_base_prop=5e-5,
                lr_impu=0.005, lamb_impu=5e-5,
            ),
        ),
        "kuai": (
            dict(batch_size=1024, gamma=0.001, tol=1e-5, verbose=False),
            dict(
                embedding_k_pred=4, embedding_k_impu=16, embedding_k_prop=16,
                embedding_k_base_prop=16,
                L1=1e4, L2=5e7, L3=1e-3, L4=1e1, L5=5e1,
                lr_pred=0.1, lamb_pred=1e-3,
                lr_prop=0.1, lamb_prop=1e-5,
                lr_base_prop=0.1, lamb_base_prop=1e-3,
                lr_impu=0.1, lamb_impu=1e-3,
            ),
        ),
    }

    # Match with == (not a hash lookup) so any value of args.dataset is
    # compared the same way an if/elif chain would compare it.
    requested = args.dataset
    selected = next(
        (bundle for name, bundle in presets.items() if requested == name),
        None,
    )
    if selected is not None:
        args.train_args, args.model_args = selected
    return args
    

if __name__ == "__main__":
    # Parse the command line, attach the preset hyper-parameters for the
    # chosen dataset (para returns the same object), then run the experiment.
    cli_args = para(args=arguments.parse_args())
    train_and_eval(
        dataset_name=cli_args.dataset,
        train_args=cli_args.train_args,
        model_args=cli_args.model_args,
    )

