# coding:utf8
import argparse
import ast
import os
import random
import sys
import time

import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf

from input import *
from evaluate import *
from utils import *
from beta_hat.beta_hat_model import *
from off_evaluation import *

parser = argparse.ArgumentParser()
parser.add_argument("--policy_name", default='capping')
parser.add_argument("--ckpt_path", default='beta_hat10/lr_0.00001_size_100/save_path/ckptTop20')
parser.add_argument("--epochs", default='5')
parser.add_argument("--topN", default='[20,50]')
parser.add_argument("--lr", default='0.0001')
parser.add_argument("--train_batch_size", default='512')
parser.add_argument("--sample_size", default='1000')
parser.add_argument("--usedata", default='Wiki_beta_hat_temp2_mse')
parser.add_argument("--hidden_dim", default=64)
parser.add_argument(
    '--para',
    type=lambda x: {k:float(v) for k,v in (i.split(':') for i in x.split(','))},
    default='0:0',
    help='comma-separated field:position pairs, e.g. Date:0,Amount:2,Payee:5,Memo:9'
)


args = parser.parse_args()

# Restrict TensorFlow to the first GPU and seed every RNG the script touches
# (Python, NumPy, TensorFlow graph-level) so repeated runs are comparable.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
random.seed(1234)
np.random.seed(1234)
tf.set_random_seed(1234)

# These flags arrive as strings. ast.literal_eval accepts the same literals the
# previous eval() calls did (e.g. '512', '0.0001', '[20,50]') but, unlike eval,
# cannot execute arbitrary code supplied on the command line.
train_batch_size = ast.literal_eval(args.train_batch_size)
topN = ast.literal_eval(args.topN)
lr = ast.literal_eval(args.lr)

# NOTE(review): not read anywhere in this script; kept in case other code relies on it.
hyperparameter = {}

print('[data]', args.usedata)
print('[ckpt_path]', args.ckpt_path)
print('[epochs]', args.epochs)
print('[topN]', args.topN)
print('[learning rate]', args.lr)
print('[train batch size]', args.train_batch_size)

# loadData comes from a project module; the returned dict is read below for
# 'user_count', 'item_count', 'valid_set', 'test_set' and 'log_item'.
data_dict = loadData(args.usedata, int(args.sample_size))
print('[main] Finish loading data.')

# allow_growth makes TensorFlow allocate GPU memory on demand instead of
# reserving the whole device up front.
gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    # Build the graph and initialise all variables, then overwrite them with
    # the trained weights from --ckpt_path. int() guards against a string
    # hidden_dim coming from the command line.
    model = BetaHat(data_dict['user_count'], data_dict['item_count'],
                    hidden_dim=int(args.hidden_dim))
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    model.restore(sess, args.ckpt_path)
    print('Finish restoring!')

    # Ranking metrics at each cutoff in topN. Per the prints below, indices
    # 0/1/2 of the first return value are precision/recall/NDCG.
    valid_pos_result, valid_user_pred = candidate_ranking(
        sess, model, data_dict['valid_set'], topN, data_dict['log_item'])
    test_pos_result, test_user_pred = candidate_ranking(
        sess, model, data_dict['test_set'], topN, data_dict['log_item'])

    print('[Valid set] Precision: {}\tRecall: {}\tNDCG: {}'.format(
            valid_pos_result[0], valid_pos_result[1], valid_pos_result[2]))
    print('[Test set] Precision: {}\tRecall: {}\tNDCG: {}'.format(
            test_pos_result[0], test_pos_result[1], test_pos_result[2]))

    # Off-policy evaluation of --policy_name. Element 0 is printed as "GT" and
    # element 1 as "ES" (presumably ground-truth vs. estimated value — confirm
    # against policy_evaluation_mse).
    valid_mse = policy_evaluation_mse(sess, model, args.policy_name, data_dict['valid_set'],
                                      data_dict['item_count'], para=args.para)
    test_mse = policy_evaluation_mse(sess, model, args.policy_name, data_dict['test_set'],
                                     data_dict['item_count'], para=args.para)

    print('~~~~~~~~~~MSE Evaluation ~~~~~~~~~~')
    print('[Valid set] GT:', valid_mse[0], ", ES: ", valid_mse[1])
    print('[Test set] GT:', test_mse[0], ", ES: ", test_mse[1])

