# coding:utf8
import argparse
import ast
import os
import pickle
import random
import sys
import time

import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf

from affectOfUty.input_onlyX import *
from affectOfUty.split_evaluate import *
from utils import *
from pi.model.pi_UIPS import *
from pi.model.pi_IPS_Cap import *
from pi.model.pi_CE import *
from pi.model.pi_SNIPS import *
from pi.model.pi_minVar import *
from pi.model.pi_stableVar import *
from pi.model.pi_POEM import *
from pi.model.pi_POXM_topK import *
from pi.model.pi_adaptive import *
from pi.model.pi_banditNet import *
from pi.model.pi_shrinkage import *
from pi.model.pi_UIPS_O import *
from pi.model.pi_UIPS_P import *


def _parse_opun_para(text):
    """Parse the ``--OPUN_para`` string into a ``{name: float}`` dict.

    ``text`` is a comma-separated list of ``name-value`` pairs, e.g.
    ``'a-0.5,b-2'`` -> ``{'a': 0.5, 'b': 2.0}``. Each pair is split on its
    first ``-`` only, so negative values and exponents such as ``'a--1'``
    or ``'b-1e-3'`` are accepted. A repeated name keeps its last value.

    Raises:
        argparse.ArgumentTypeError: a pair contains no ``-`` or its value is
            not a float; argparse reports this as a usage error.
    """
    para = {}
    for pair in text.split(','):
        name, sep, value = pair.partition('-')
        if not sep:
            raise argparse.ArgumentTypeError(
                "expected 'name-value' pairs separated by commas, got %r" % pair)
        try:
            para[name] = float(value)
        except ValueError:
            raise argparse.ArgumentTypeError(
                "value for %r is not a number: %r" % (name, value)) from None
    return para


parser = argparse.ArgumentParser()
# Policy model to evaluate (e.g. CE, SNIPS, POEM, UIPS).
parser.add_argument("--model_name", default="no_off_policy")
# The numeric options below are kept as strings and converted later in the script.
parser.add_argument("--topN", default='[5]')
parser.add_argument("--train_batch_size", default='32')
parser.add_argument("--usedata", default='Wiki_beta_hat')
parser.add_argument("--capping", default='1')
parser.add_argument("--hyper", default='0.65')
# Hyper-parameter dict passed as `para` to the UIPS / UIPS_O / UIPS_P models.
# argparse also applies `type` to the string default, yielding {'0': 0.0}.
parser.add_argument(
    '--OPUN_para',
    type=_parse_opun_para,
    default='0-0',
    help="comma-separated name-value pairs of UIPS hyper-parameters, "
         "e.g. a-0.5,b-2 (negative values allowed: a--1)",
)


args = parser.parse_args()
# Respect a GPU selection already present in the environment; default to GPU 0.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
# Seed the Python, NumPy and TensorFlow graph-level RNGs.
random.seed(1234)
np.random.seed(1234)
tf.set_random_seed(1234)

# Numeric CLI options arrive as strings such as '32' or '[5]'.
# ast.literal_eval parses such literals without executing arbitrary code,
# unlike eval().
# NOTE(review): train_batch_size is not referenced elsewhere in this view.
train_batch_size = ast.literal_eval(args.train_batch_size)
topN = ast.literal_eval(args.topN)
data_dict = loadData(args.usedata)


def _build_model(model_name, user_count, item_count, hyper, capping, opun_para):
    """Instantiate the policy model selected by ``--model_name``.

    Args:
        model_name: value of ``--model_name``.
        user_count, item_count: sizes taken from the loaded data dict.
        hyper, capping: raw ``--hyper`` / ``--capping`` strings. They are
            parsed with ``ast.literal_eval``, and only by the models that
            take them.
        opun_para: parsed ``--OPUN_para`` dict used by the UIPS variants.

    Returns:
        ``(model, checkpoint_path)``, where ``checkpoint_path`` is the path
        later handed to ``model.restore``.

    Raises:
        AssertionError: ``model_name`` is not a known model.
    """
    def num(text):
        return ast.literal_eval(text)

    uc, ic = user_count, item_count
    # Factories are lambdas so that only the selected model is built.
    builders = {
        'CE': (lambda: CE(uc, ic),
               'pi_model/nooff/save_path/ckptTop5'),
        'capping': (lambda: Pi_capping(uc, ic, num(capping)),
                    'pi_model/capping/save_path/ckptTop5'),
        'banditNet': (lambda: Pi_banditNet(uc, ic, num(hyper)),
                      'pi_model/banditNet/save_path/ckptTop5'),
        'SNIPS': (lambda: Pi_SNIPS(uc, ic),
                  'pi_model/SNIPS/save_path/ckptTop5'),
        'minVar': (lambda: Pi_minVar(uc, ic),
                   'pi_model/minVar/save_path/ckptTop5'),
        'stableVar': (lambda: Pi_stableVar(uc, ic),
                      'pi_model/stableVar/save_path/ckptTop5'),
        'adaptive': (lambda: Pi_adaptive(uc, ic, num(hyper), num(capping)),
                     'pi_model/adaptive/save_path/ckptTop5'),
        'POEM': (lambda: Pi_POEM(uc, ic, num(capping), num(hyper)),
                 'pi_model/POEM/save_path/ckptTop5'),
        'POXM': (lambda: Pi_POXM(uc, ic, num(hyper)),
                 'pi_model/POXM/save_path/ckptTop5'),
        'shrinkage': (lambda: Pi_shrinkage(uc, ic, num(hyper)),
                      'pi_model/shrinkage/save_path/ckptTop5'),
        'UIPS': (lambda: UIPS(uc, ic, para=opun_para),
                 'pi_model/OPUN/save_path/ckptTop5'),
        'UIPS_O': (lambda: UIPS_O(uc, ic, para=opun_para),
                   'pi_model/U_base/save_path/ckptTop5'),
        'UIPS_P': (lambda: UIPS_P(uc, ic, para=opun_para),
                   'pi_model/pos_U_base/save_path/ckptTop5'),
    }
    # The parser's default name 'no_off_policy' matched no model before.
    # CE restores from the 'nooff' checkpoint directory, so treat it as
    # CE's alias. NOTE(review): confirm this is the intended mapping.
    builders['no_off_policy'] = builders['CE']

    try:
        factory, ckpt_path = builders[model_name]
    except KeyError:
        print('[Main] Unknown model.')
        raise AssertionError('Unknown model %r; expected one of %s'
                             % (model_name, sorted(builders))) from None
    return factory(), ckpt_path


print('[model name]', args.model_name)
gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    model, optimal_beta_model_path = _build_model(
        args.model_name, data_dict['user_count'], data_dict['item_count'],
        args.hyper, args.capping, args.OPUN_para)

    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    model.restore(sess, optimal_beta_model_path)

    # FileWriter records the graph passed to its constructor at construction
    # time. It is therefore created only after the model's ops exist (assumed
    # to be built in the model constructor). It is then closed so the event
    # file is flushed.
    summary_writer = tf.summary.FileWriter('./tf_summary/', sess.graph)
    summary_writer.close()

    test_set = data_dict['test_set']
    precision, recall, NDCG = Split_candidate_ranking(
        sess, model, test_set, topN, data_dict['log_item'])
    # One record per test example, keyed by its position in test_set.
    # Positional indexing is kept because test_set's concrete type is not
    # visible here.
    id_eval = {
        X_index: {'X': test_set[X_index],
                  'precision': precision[X_index],
                  'recall': recall[X_index],
                  'ndcg': NDCG[X_index]}
        for X_index in range(len(test_set))
    }
    with open('X_eval' + args.model_name + '.pkl', 'wb') as f:
        pickle.dump(id_eval, f, pickle.HIGHEST_PROTOCOL)



