# coding:utf8
import argparse
import ast
import os
import pickle
import random
import sys
import time

import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf

from input import *
from toPlotData.split_evaluate import *
from utils import *

from Model.CE import CE
from Model.IPS_Cap import Capping
from Model.banditNet import banditNet
from Model.SNIPS import SNIPS
from Model.minVar import minVar
from Model.stableVar import stableVar
from Model.adaptive import adaptive
from Model.POEM import POEM
from Model.POXM_topK import POXM
from Model.shrinkage import shrinkage
from Model.UIPS import UIPS
from Model.UIPS_O import UIPS_O
from Model.UIPS_P import UIPS_P


def parse_float_pairs(text):
  """Parse a ``"key:value,key:value"`` string into ``{key: float(value)}``.

  Surrounding whitespace on keys is stripped (``float`` already ignores it
  on values). A pair without exactly one ``:`` or with a non-numeric value
  raises ValueError, which argparse reports as a usage error naming this
  function.
  """
  pairs = {}
  for item in text.split(','):
    key, value = item.split(':')
    pairs[key.strip()] = float(value)
  return pairs


parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="OPUN")
# --topN, --train_batch_size, --capping and --hyper are stored as raw
# strings; the script converts them to Python values where they are used.
parser.add_argument("--topN", default='[50]')
parser.add_argument("--train_batch_size", default='512')
parser.add_argument("--usedata", default='Kuai')
parser.add_argument("--capping", default='1')
parser.add_argument("--hyper", default='0.65')
parser.add_argument(
    '--OPUN_para',
    type=parse_float_pairs,
    # argparse runs string defaults through `type` too, so this default
    # becomes {'0': 0.0}.
    default='0:0',
    help='comma-separated name:value pairs of model hyper-parameters '
         '(values parsed as float), e.g. lambda_:0.1,gamma:0.5',
)


args = parser.parse_args()
# Expose only GPU 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Seed the Python, NumPy and TensorFlow RNGs with the same fixed value.
random.seed(1234)
np.random.seed(1234)
tf.set_random_seed(1234)

# These options arrive as strings such as '512' or '[50]'. literal_eval
# accepts only Python literals, so unlike eval() it cannot run code passed
# on the command line.
# NOTE: train_batch_size is not used anywhere else in this script.
train_batch_size = ast.literal_eval(args.train_batch_size)
topN = ast.literal_eval(args.topN)

print('[model name]', args.model_name)

data_dict = loadData(args.usedata)



# Keys of the hyper-parameter dict handed to the UIPS model for 'OPUN'.
UIPS_PARA_KEYS = ('lambda_', 'eta_', 'lambdaDiff', 'normalize_phi_sa',
                  'gamma', 'cappingThre', 'cappingFirstEpoch')


def zero_uips_para():
  """Return a fresh dict mapping every key in UIPS_PARA_KEYS to 0.0."""
  return dict.fromkeys(UIPS_PARA_KEYS, 0.0)


def build_model(args, data_dict):
  """Construct the model selected by ``args.model_name``.

  Args:
    args: parsed command-line namespace; reads ``model_name`` and, depending
      on the model, ``capping``, ``hyper`` (literal strings) or ``OPUN_para``.
    data_dict: dict from ``loadData``; reads ``user_count``, ``item_count``,
      ``cate_count`` and ``cate_list``.

  Returns:
    ``(model, run_dir)`` where ``run_dir`` is the checkpoint run directory,
    relative to the checkpoint root, holding that model's selected weights.

  Raises:
    ValueError: if ``args.model_name`` is not a supported model.
  """
  uips_kwargs = {key: data_dict[key]
                 for key in ('user_count', 'item_count', 'cate_count', 'cate_list')}
  base = tuple(uips_kwargs.values())
  name = args.model_name

  # --capping / --hyper are parsed only by the branches that use them, so an
  # odd value for an unused option never raises.
  if name == 'CE':
    return CE(*base), 'no_off_policy/cap0_hyper0_epoch20_batchsize512_lr_0.000001'
  if name == 'capping':
    return (Capping(*base, ast.literal_eval(args.capping)),
            'capping/cap1_hyper0_epoch10_batchsize512_lr_0.000001')
  if name == 'banditNet':
    return (banditNet(*base, ast.literal_eval(args.hyper)),
            'banditNet/cap0_hyper0.95_epoch10_batchsize512_lr_0.00001')
  if name == 'SNIPS':
    return SNIPS(*base), 'SNIPS/cap0_hyper0_epoch10_batchsize512_lr_0.000001'
  if name == 'minVar':
    return minVar(*base), 'minVar/cap0_hyper0_epoch10_batchsize512_lr_0.000001'
  if name == 'stableVar':
    return stableVar(*base), 'stableVar/cap0_hyper0_epoch10_batchsize512_lr_0.000001'
  if name == 'adaptive':
    return (adaptive(*base, ast.literal_eval(args.hyper), ast.literal_eval(args.capping)),
            'adaptive/cap1_hyper0.004_epoch10_batchsize512_lr_0.000001')
  if name == 'POEM':
    # NOTE(review): POEM receives (capping, hyper) whereas adaptive and POXM
    # receive (hyper, capping); presumably matches POEM's signature - verify.
    return (POEM(*base, ast.literal_eval(args.capping), ast.literal_eval(args.hyper)),
            'POEM/cap1_hyper100.0_epoch10_batchsize512_lr_0.000001')
  if name == 'POXM':
    return (POXM(*base, ast.literal_eval(args.hyper), ast.literal_eval(args.capping)),
            'POXM/cap0.95_hyper0.6_epoch10_batchsize512_lr_0.00001')
  if name == 'shrinkage':
    return (shrinkage(*base, ast.literal_eval(args.hyper)),
            'shrinkage/cap0_hyper10_epoch10_batchsize512_lr_0.000001')
  if name == 'OPUN':
    # NOTE(review): OPUN always uses all-zero parameters and ignores
    # --OPUN_para, while UIPS_O / UIPS_P below use --OPUN_para - confirm.
    return UIPS(**uips_kwargs, para=zero_uips_para()), 'UIPS'
  if name == 'UIPS_O':
    return UIPS_O(**uips_kwargs, para=args.OPUN_para), 'UIPS_O'
  if name == 'UIPS_P':
    return UIPS_P(**uips_kwargs, para=args.OPUN_para), 'UIPS_P'
  raise ValueError('[Main] Unknown model: %r' % name)


gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
  # NOTE(review): checkpoint root and output directory are hard-coded to the
  # Kuai dataset regardless of --usedata.
  root_path = 'uauc_Kuai/'
  tail_path = '/save_path/ckptTop50'
  model, run_dir = build_model(args, data_dict)
  optimal_beta_model_path = root_path + run_dir + tail_path

  # Initialise all variables, then overwrite them from the checkpoint.
  sess.run(tf.global_variables_initializer())
  sess.run(tf.local_variables_initializer())
  model.restore(sess, optimal_beta_model_path)

  test_set = data_dict['test_set']
  precision, recall, NDCG = Split_candidate_ranking(
      sess, model, data_dict['mask'], test_set, data_dict['all_item'], topN,
      data_dict['log_item'], data_dict['mask_valid'], True)
  # Per-example metrics keyed by test-set position; the metric sequences are
  # indexed by that same position.
  id_eval = {
      X_index: {'X': test_set[X_index],
                'precision': precision[X_index],
                'recall': recall[X_index],
                'ndcg': NDCG[X_index]}
      for X_index in range(len(test_set))
  }
  # Create the output directory so a fresh checkout does not fail on open().
  os.makedirs('Kuai_val', exist_ok=True)
  with open('Kuai_val/X_eval' + args.model_name + '.pkl', 'wb') as f:
    pickle.dump(id_eval, f, pickle.HIGHEST_PROTOCOL)
