# coding:utf8
import os
import time
import pickle
import random
import numpy as np
import tensorflow.compat.v1 as tf
import pandas as pd
import sys
import argparse
from affectOfUty.input_onlyX import *
from utils import *
from beta_star.beta_star_model import *
from beta_star.beta_star_model_v2 import *


# Command-line configuration. The script restores a pretrained BetaHat model
# and writes per-item statistics to item_info.pkl (see the session block).
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="Beta_Star")
# Numeric options are converted by argparse (string defaults included) rather
# than eval()-ing raw command-line text, which would execute arbitrary code.
parser.add_argument("--lr", type=float, default='0.0001')
parser.add_argument("--train_batch_size", type=int, default='512')
parser.add_argument("--usedata", default='Wiki_beta_hat')
parser.add_argument(
    '--para',
    type=lambda x: {k:float(v) for k,v in (i.split(':') for i in x.split(','))},
    default='temp:0.5',
    help='comma-separated name:value pairs parsed into a dict of floats, e.g. temp:0.5'
)
parser.add_argument("--optimal_beta_model_path", default='beta_hat10/lr_0.00001_size_100/save_path/ckptTop20')
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Fix all RNG seeds for reproducibility.
random.seed(1234)
np.random.seed(1234)
tf.set_random_seed(1234)

train_batch_size = args.train_batch_size
lr = args.lr


# Honor --usedata (its default matches the previously hard-coded dataset name).
data_dict = loadData(args.usedata)

gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    # Build the model, initialize all variables, then overwrite them with the
    # checkpoint given by --optimal_beta_model_path.
    model = BetaHat(data_dict['user_count'], data_dict['item_count'], hidden_dim=64)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    model.restore(sess, args.optimal_beta_model_path)

    # item -> running totals over displayed training samples of that item:
    #   'frequency'   number of displayed samples
    #   'uncertainty' sum of model.getSUncertainty outputs for those samples
    #   'beta_hat'    sum of the predicted beta_hat probability for the item
    # These are sums, not averages; divide by 'frequency' downstream if needed.
    item_info = {}
    for _, uij in DataInputSyn(data_dict['train_set'], train_batch_size):
        users, items, displays = uij[0], uij[1], uij[-1]
        beta_hat_prob = model.run_eval(sess, users)[0]  # [B, #item]
        # NOTE(review): shape documented as [B, 1], so each beta_uncertainty[i]
        # may be a 1-element array rather than a scalar — confirm before
        # relying on the pickled 'uncertainty' values being plain floats.
        beta_uncertainty = model.getSUncertainty(sess, uij)  # [B, 1]
        for i in range(len(users)):
            if displays[i] == 0:  # sample was not displayed; skip it
                continue
            item = items[i]
            if item not in item_info:
                item_info[item] = {'frequency': 0, 'uncertainty': 0.0, 'beta_hat': 0.0}
            stats = item_info[item]
            stats['frequency'] += 1
            stats['uncertainty'] += beta_uncertainty[i]
            stats['beta_hat'] += beta_hat_prob[i][item]

    with open('item_info.pkl', 'wb') as f:
        pickle.dump(item_info, f, pickle.HIGHEST_PROTOCOL)