# coding:utf8
import os
import time
import pickle
import random
import numpy as np
import tensorflow.compat.v1 as tf
import pandas as pd
import sys
import argparse
from affectOfUty.input_onlyX import *
from utils import *
from beta_star.beta_star_model import *
from beta_star.beta_star_model_v2 import *


def _parse_para(text):
    """Parse a ``key:value[,key:value...]`` string into a dict of floats.

    Whitespace around keys/values is ignored and empty entries (e.g. from a
    trailing comma) are skipped, so ``"temp:0.5, "`` -> ``{'temp': 0.5}``.
    Raises ValueError on an entry without ``:`` or with a non-numeric value,
    which argparse reports as an invalid ``--para`` argument.
    """
    para = {}
    for item in text.split(','):
        item = item.strip()
        if not item:
            continue
        key, value = item.split(':', 1)
        para[key.strip()] = float(value)
    return para


# Command-line configuration. Numeric flags are converted by argparse itself
# (``type=``) rather than eval()-ing the raw strings afterwards.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_name",
    default="Beta_Star",
    help="'beta_star_v2' selects BetaStar_V2; any other value uses BetaStart",
)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--train_batch_size", type=int, default=32)
parser.add_argument("--usedata", default='Wiki_beta_hat')
parser.add_argument(
    '--para',
    type=_parse_para,
    default='temp:0.5',  # string default: argparse runs it through _parse_para
    help='comma-separated key:value model hyper-parameters, e.g. temp:0.5',
)
parser.add_argument("--optimal_beta_model_path", default='beta_hat10/lr_0.00001_size_100/save_path/ckptTop20')
args = parser.parse_args()

# Pin the visible GPU and seed every RNG in use so runs are reproducible.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
random.seed(1234)
np.random.seed(1234)
tf.set_random_seed(1234)

train_batch_size = args.train_batch_size
lr = args.lr


# Test data for the beta_hat model (selected by --usedata) and the paired
# beta_star data consumed alongside it by DataInputOnlyX.
data_dict = loadData(args.usedata)
data_dict_start = loadData('Wiki_beta_star')
# beta_hat model restored from --optimal_beta_model_path via ImportGraph.
beta_hat = ImportGraph(args.optimal_beta_model_path, data_dict)

gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    if args.model_name == 'beta_star_v2':
        model = BetaStar_V2(data_dict['user_count'], data_dict['item_count'], data_dict['feature_count'], v2_para=args.para)
    else:
        model = BetaStart(data_dict['user_count'], data_dict['item_count'], data_dict['feature_count'], para=args.para)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # NOTE: this is the beta_star checkpoint, restored into `model`; it is
    # unrelated to args.optimal_beta_model_path (used for beta_hat above).
    optimal_beta_model_path = 'optimal_beta_start_v1/ckptTop20'
    model.restore(sess, optimal_beta_model_path)

    # Running example index -> {'X', 'num_pos', 'avg_uty', 'avg_diff'}.
    id_uty_diff = {}
    X_index = 0
    # DataInputOnlyX yields the positive examples of each test batch.
    for _, uij, uij_start in DataInputOnlyX(data_dict['test_set'], data_dict_start['test_set'], train_batch_size):
        # getAllProb returns values for all items; the positive ones are
        # selected per example below.
        beta_hat_prob, beta_uncertainty = beta_hat.getAllProb(uij)
        beta_hat_prob = beta_hat_prob[0]
        beta_uncertainty = beta_uncertainty[0]
        _, beta_start_prob = model.run_eval(sess, uij_start, lr)
        for i, x in enumerate(uij[0]):
            num_items = len(uij[1][i])
            pos_hat_prob = np.array([beta_hat_prob[i][k] for k in range(num_items)])
            pos_start_prob = np.array([beta_start_prob[i][k] for k in range(num_items)])
            # NOTE(review): a num_pos of 0 would make avg_diff inf/nan or
            # raise ZeroDivisionError -- confirm it can never be 0.
            num_pos = uij[2][i]
            id_uty_diff[X_index] = {
                'X': x,
                'num_pos': num_pos,
                'avg_uty': beta_uncertainty[i],
                # sum of squared differences divided by num_pos
                'avg_diff': np.sum(np.square(pos_hat_prob - pos_start_prob)) / num_pos,
            }
            X_index += 1
    with open('X_uty_diff.pkl', 'wb') as f:
        pickle.dump(id_uty_diff, f, pickle.HIGHEST_PROTOCOL)