# coding:utf8
import os
import time
import random
import numpy as np
import tensorflow.compat.v1 as tf
import pandas as pd
import sys
import pickle
import argparse

from affectOfUty.input_onlyX.py import *
from affectOfUty.split_evaluate import *
from utils import *
from pi.model.pi_OPUN import *
from pi.model.pi_capping import *
from pi.model.pi_nooff import *
from pi.model.pi_SNIPS import *
from pi.model.pi_minVar import *
from pi.model.pi_stableVar import *
from pi.model.pi_POEM import *
from pi.model.pi_POXM_topK import *
from pi.model.pi_adaptive import *
from pi.model.pi_banditNet import *
from pi.model.pi_shrinkage import *


def _parse_opun_para(text):
    # Parse --OPUN_para ("NAME-VALUE,NAME-VALUE,...") into a {NAME: float} dict.
    # Each pair is split on its FIRST '-' only, so negative values work
    # ('a--0.5' -> {'a': -0.5}); splitting on every '-' rejected them.
    # Malformed pairs raise argparse.ArgumentTypeError, which argparse reports
    # as a usage error with this message.
    result = {}
    for pair in text.split(','):
        name, sep, value = pair.partition('-')
        if not sep:
            raise argparse.ArgumentTypeError(
                "expected comma-separated NAME-VALUE pairs, got %r" % pair)
        try:
            result[name] = float(value)
        except ValueError:
            raise argparse.ArgumentTypeError(
                "value for %r is not a number: %r" % (name, value)) from None
    return result


# Command-line options. Numeric and list options are kept as raw strings here;
# the script converts them later with eval().
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="no_off_policy")
parser.add_argument("--optimal_beta_model_path", default="lr_0.00001_size_100/save_path/ckptTop20")
parser.add_argument("--epochs", default='5')
parser.add_argument("--topN", default='[5,10,20,50]')
parser.add_argument("--lr", default='0.0001')
parser.add_argument("--train_batch_size", default='32')
parser.add_argument("--usedata", default='Wiki_beta_hat')
parser.add_argument("--capping", default='1')
parser.add_argument("--hyper", default='0.65')
parser.add_argument(
    '--OPUN_para',
    type=_parse_opun_para,
    # String defaults are passed through `type`, so this becomes {'0': 0.0}.
    default='0-0',
    help='comma-separated NAME-VALUE pairs parsed into a {NAME: float} dict, '
         'e.g. a-0.5,b--1 (each pair is split on its first "-")',
)


# Read the command line, expose only GPU 0, and seed every RNG in use
# (Python, NumPy, TensorFlow graph-level) with the same fixed value.
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
for _seed_fn in (random.seed, np.random.seed, tf.set_random_seed):
    _seed_fn(1234)

# Turn the string-valued numeric/list options into Python values.
# NOTE(review): eval() runs arbitrary expressions taken from argv. This is
# acceptable only because whoever launches the script controls argv;
# ast.literal_eval would be the safer choice if plain literals suffice.
train_batch_size, topN, lr = map(eval, (args.train_batch_size, args.topN, args.lr))

hyperparameter = {}  # NOTE(review): appears unused in this file.

# Echo the run configuration: one "label value" line per option.
for _label, _value in (
    ('[model name]', args.model_name),
    ('[optimal_beta_model_path]', args.optimal_beta_model_path),
    ('[hyperparameter]', args.hyper),
    ('[data]', args.usedata),
    ('[epochs]', args.epochs),
    ('[topN]', args.topN),
    ('[learning rate]', args.lr),
    ('[train batch size]', args.train_batch_size),
    ('[if capping]', args.capping),
    ('[OPUN_para]', args.OPUN_para),
):
    print(_label, _value)

data_dict = loadData(args.usedata)



# Build the policy model selected by --model_name, restore its weights, run the
# top-N candidate-ranking evaluation on the test set, and pickle per-example
# precision/recall/NDCG to X_eval.pkl.
gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    summary_writer = tf.summary.FileWriter('./tf_summary/', sess.graph)
    try:
        user_count = data_dict['user_count']
        item_count = data_dict['item_count']

        # model_name -> zero-argument constructor. The lambdas defer eval() of
        # the string CLI options, so only the selected model's options are
        # evaluated (the same behaviour as an if/elif chain).
        model_builders = {
            'no_off_policy': lambda: nooff(user_count, item_count),
            'capping': lambda: Pi_capping(user_count, item_count, eval(args.capping)),
            'banditNet': lambda: Pi_banditNet(user_count, item_count, eval(args.hyper)),
            'SNIPS': lambda: Pi_SNIPS(user_count, item_count),
            'minVar': lambda: Pi_minVar(user_count, item_count),
            'stableVar': lambda: Pi_stableVar(user_count, item_count),
            # Here --capping is used as alpha inside the adaptive model.
            'adaptive': lambda: Pi_adaptive(user_count, item_count,
                                            eval(args.hyper), eval(args.capping)),
            'POEM': lambda: Pi_POEM(user_count, item_count,
                                    eval(args.capping), eval(args.hyper)),
            # NOTE(review): POXM receives only --hyper; --capping is not passed.
            'POXM': lambda: Pi_POXM(user_count, item_count, eval(args.hyper)),
            'shrinkage': lambda: Pi_shrinkage(user_count, item_count, eval(args.hyper)),
            'OPUN': lambda: Pi_OPUN(user_count, item_count, para=args.OPUN_para),
        }
        if args.model_name not in model_builders:
            print('[Main] Unknown model.')
            raise ValueError('Unknown model name %r; expected one of: %s'
                             % (args.model_name, ', '.join(sorted(model_builders))))
        model = model_builders[args.model_name]()

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # Restore weights from --optimal_beta_model_path. This was previously
        # overridden with '' as a debugging leftover, which silently ignored the
        # flag. Pass --optimal_beta_model_path '' to get that behaviour back.
        optimal_beta_model_path = args.optimal_beta_model_path
        model.restore(sess, optimal_beta_model_path)

        test_set = data_dict['test_set']
        precision, recall, NDCG = Split_candidate_ranking(
            sess, model, test_set, topN, data_dict['log_item'])

        # One record per test example, keyed by its position in test_set.
        # Explicit positional indexing is kept because the container types of
        # test_set and the metric results are not visible here.
        id_eval = {
            X_index: {'X': test_set[X_index],
                      'precision': precision[X_index],
                      'recall': recall[X_index],
                      'ndcg': NDCG[X_index]}
            for X_index in range(len(test_set))
        }
        with open('X_eval.pkl', 'wb') as f:
            pickle.dump(id_eval, f, pickle.HIGHEST_PROTOCOL)
    finally:
        # Flush and release the event file even if model building or
        # evaluation fails.
        summary_writer.close()



