# coding:utf8
import argparse
import ast
import os
import random
import sys
import time

import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf

from input import *
from evaluate import *
from utils import *
from beta_hat.beta_hat_model import *

# Command-line options. Every option is registered without a `type`, so each
# value arrives as a string; numeric/list options are converted further below.
parser = argparse.ArgumentParser()
for _flag, _default in (
    ("--epochs", '5'),
    ("--topN", '[20,50]'),
    ("--lr", '0.0001'),
    ("--train_batch_size", '512'),
    ("--sample_size", '100'),
    ("--usedata", 'Wiki_beta_hat'),
):
    parser.add_argument(_flag, default=_default)

args = parser.parse_args()
# Expose only device 0 to CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Fix the seeds of the Python, NumPy and TensorFlow RNGs.
random.seed(1234)
np.random.seed(1234)
tf.set_random_seed(1234)

# Convert the string CLI options into Python values. ast.literal_eval accepts
# only literals (e.g. '512', '0.0001', '[20,50]'); unlike eval() it cannot
# execute arbitrary code smuggled in through the command line.
train_batch_size = ast.literal_eval(args.train_batch_size)
topN = ast.literal_eval(args.topN)
lr = ast.literal_eval(args.lr)

# NOTE(review): appears unused in this script; kept in case something else
# reads it.
hyperparameter = dict()

# Echo the run configuration (the raw CLI strings) before loading data.
for _label, _value in (
    ('[data]', args.usedata),
    ('[epochs]', args.epochs),
    ('[topN]', args.topN),
    ('[learning rate]', args.lr),
    ('[train batch size]', args.train_batch_size),
):
    print(_label, _value)

data_dict = loadData(args.usedata, int(args.sample_size))
uid_pos = loadUid_pos()

# Checkpoint of the pre-trained BetaHat model that is evaluated below.
optimal_beta_model_path = 'lr_0.00001/save_path/ckptTop20'
gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    model = BetaHat(data_dict['user_count'], data_dict['item_count'])

    # Create the writer only after the model is built: FileWriter serializes
    # the graph it is given at construction time, so creating it earlier
    # logged a graph without any model ops. Nothing else is written to it,
    # so close it straight away to flush the event file.
    summary_writer = tf.summary.FileWriter('./tf_summary/', sess.graph)
    summary_writer.close()

    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    model.restore(sess, optimal_beta_model_path)

    train_set = data_dict['train_set']
    print('len user', len(train_set))

    # Rank candidates on the validation and test splits with the restored
    # model and report the first three entries of each result.
    valid_pos_result, valid_user_pred = candidate_ranking(
        sess, model, data_dict['valid_set'], topN, data_dict['log_item'])
    test_pos_result, test_user_pred = candidate_ranking(
        sess, model, data_dict['test_set'], topN, data_dict['log_item'])

    print('[Valid set] Precision: {}\tRecall: {}\tNDCG: {}'.format(
        valid_pos_result[0], valid_pos_result[1], valid_pos_result[2]))
    print('[Test set] Precision: {}\tRecall: {}\tNDCG: {}'.format(
        test_pos_result[0], test_pos_result[1], test_pos_result[2]))

    # For every training user, pair the popularity value (last element of
    # uid_pos[uid]) with the model's uncertainty output for that user.
    # NOTE(review): uid is a running counter rather than being read from the
    # batch, so this assumes DataInputSyn yields users in uid order 0..N-1
    # without shuffling — confirm.
    user_info = {}
    uid = 0
    for _, uij in DataInputSyn(train_set, train_batch_size):
        uty = model.getUncertainty2(sess, uij[0])
        for i in range(len(uij[0])):
            user_info[uid] = {'pop': uid_pos[uid][-1], 'uncertainty': uty[i][0]}
            uid += 1
    print('len key', len(user_info))
    print('uid', uid)

    plot_scatter(user_info, 'save_path/best_epoch')