#!/usr/bin/env python
# coding: utf-8

# In[1]:


import sys
sys.path.append('..')
import torch
import numpy as np
import lib
import argparse

import os
import wandb

# Command-line options, registered in this exact order so `--help` output is
# unchanged. Each entry is (flag, keyword arguments for add_argument).
_CLI_OPTIONS = (
    ('--k_ary', dict(type=int, default=3, help='k_ary')),
    ('--weight_decay', dict(type=float, default=1e-5, help='weight_decay')),
    ('--sample_num', dict(type=int, default=100, help='sample_num')),
    ('--emb_dim', dict(type=int, default=24, help='emb_dim')),
    ('--run_name', dict(type=str, required=True)),
    ('--gpu_id', dict(type=int, default=0, help='gpu_id')),
    ('--dataset', dict(type=str, default='movie', help='dataset',
                       choices=['mind', 'gowalla', 'movie'])),
    ('--batch_gap_to_rebuild_tree', dict(type=int, default=80000,
                                         help='batch_gap_to_rebuild_tree')),
    ('--debug', dict(action='store_true', help='debug mode', default=False)),
    ('--repeat_time', dict(type=int, default=1, help='repeat_time')),
)

parser = argparse.ArgumentParser(description='DeFoRec_k_ary')
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Hyper-parameters and run settings
# ---------------------------------------------------------------------------
k_ary = args.k_ary
max_calculate_num = 25000
get_weights_mode = 'parallel'
update_tree_gap = 4
# other values seen in this code base: top_down, softmax,
# all_negative_sampling, uniform_multiclass
sampling_method = 'uniform_k_ary'

weight_decay = args.weight_decay


def optimizer(params):
    """Build an AMSGrad Adam optimizer (lr=1e-3) using the global weight_decay."""
    return torch.optim.Adam(params, lr=1.0e-3, amsgrad=True, weight_decay=weight_decay)


data_set_name = args.dataset
device = 'cuda:{}'.format(args.gpu_id)

tree_num = 1
repeat_time = 1
runtime = 1  # original note: 10 runs in total, each corresponding to one data partition
has_processed_data = True
topk = 20
N = args.sample_num  # number of samples handed to the trainer (see TrainModel's N)
train_sample_seg_cnt = 10  # training instances are spread over this many files
parall = 4
seq_len = 70  # seq_len - 1 behaviours per window (must match featrue_groups)
min_seq_len = 15
test_user_num = 6000  # number of users held out for the test file
tree_learner_mode = 'jtm'
gamma = 0.0


item_node_share_embedding = True
raw_data_file = '../../data/{}/{}.txt'.format(data_set_name, data_set_name)

# Directory holding the preprocessed train/test/validation instance files.
data_file_prefix = '../../data/{}/processed_dataset/'.format(data_set_name)

if not has_processed_data:
    # Preprocessing writes into this directory, so make sure it exists.
    os.makedirs(data_file_prefix, exist_ok=True)
train_instances_file = data_file_prefix + 'train_instances'
test_instances_file = data_file_prefix + 'test_instances'
validation_instances_file = data_file_prefix + 'validation_instances'
# Item -> leaf-code mapping, one "item_id::leaf_code" pair per line.
kv_file = data_file_prefix + 'kv.txt'

# Directory where checkpoints and learned trees for this configuration are saved.
result_prefix = (
    '../../data/{}/DeFoRec_K_ary_one_tree/K_{}/'.format(data_set_name, k_ary)
    + 'result_of_N_{}_share_embedding_{}_wd_{}_dim_{}_gap_{}_repeat_{}/'.format(
        N, item_node_share_embedding, weight_decay, args.emb_dim,
        args.batch_gap_to_rebuild_tree, args.repeat_time))

# Sizes of the behaviour groups fed to the model; together they must cover the
# seq_len - 1 history behaviours of each instance.
featrue_groups = [20, 20, 10, 10, 2, 2, 2, 1, 1, 1]
assert sum(featrue_groups) == seq_len - 1

embed_dim = args.emb_dim

training_batch_size = 100
# Smaller evaluation batches for very wide trees (presumably to bound memory
# during beam search — confirm).
validation_batch_size = 30 if args.k_ary > 32 else 50
batch_gap_to_rebuild_tree = args.batch_gap_to_rebuild_tree

eps = 0.000001
# Make the requested GPU the current CUDA device.
if 'cpu' != device:
    torch.cuda.set_device(device)
print(result_prefix)

# In debug mode wandb runs offline so nothing is uploaded.
if args.debug:
    os.environ['WANDB_MODE'] = 'dryrun'
wandb.init(project='DeFoRec_K_ary_{}'.format(args.dataset),
           name=args.run_name,
           config=args)

# In[2]:


def presision(result_list,gt_list,top_k):
    """Return mean precision@top_k over users.

    Each recommendation row in *result_list* is compared, as a set, with the
    matching ground-truth row in *gt_list*. The total number of hits is divided
    by ``top_k * len(result_list)``.
    """
    hits = sum(len(set(predicted) & set(truth))
               for predicted, truth in zip(result_list, gt_list))
    return hits / (top_k * len(result_list))
def recall(result_list,gt_list):
    """Return mean per-user recall.

    For each user, this is the number of ground-truth items present in the
    recommendation row divided by the number of ground-truth items. The sum is
    then averaged over ``len(result_list)``.
    """
    per_user = [len(set(predicted) & set(truth)) / len(truth)
                for predicted, truth in zip(result_list, gt_list)]
    return sum(per_user) / len(result_list)
def f_measure(result_list,gt_list,top_k,eps=1.0e-9):
    """Return mean per-user F1 of recall and precision@top_k.

    Users whose recall + precision is below *eps* (no hits) contribute 0.
    The sum is averaged over ``len(result_list)``.
    """
    total = 0.0
    for predicted, truth in zip(result_list, gt_list):
        hits = len(set(predicted) & set(truth))
        rec = 1.0 * hits / len(truth)
        prec = 1.0 * hits / top_k
        if rec + prec >= eps:
            total += 2 * rec * prec / (rec + prec)
    return total / len(result_list)
def novelty(result_list,s_u,top_k):
    """Return the fraction of recommended items absent from the user's history.

    *s_u* holds one item sequence per user, aligned with *result_list*. The
    count of unseen recommended items is divided by ``top_k * len(result_list)``.
    """
    unseen = 0
    for predicted, history in zip(result_list, s_u):
        unseen += len(set(predicted) - set(history))
    return 1.0 * unseen / (top_k * len(result_list))

def hit_ratio(result_list,gt_list):
    """Return the total hits divided by the total number of ground-truth items."""
    hits = 0
    for predicted, truth in zip(result_list, gt_list):
        hits += len(set(predicted) & set(truth))
    relevant = sum(len(truth) for truth in gt_list)
    return 1.0 * hits / relevant

def NDCG(result_list,gt_list):
    """Return mean binary-relevance NDCG over users.

    For each user, an item at rank r (1-based) in the result row gains
    1/log2(r + 1) if it is in the ground truth. The sum is normalised by the
    ideal DCG of min(|ground truth|, |result row|) hits. Users with no hit
    contribute 0. The total is averaged over ``len(gt_list)``.
    """
    t = 0.0
    for ranked, gt in zip(result_list, gt_list):
        setgt = set(gt)
        # np.asfarray was removed in NumPy 2.0; build the float array explicitly.
        indicator = np.asarray([1.0 if r in setgt else 0.0 for r in ranked],
                               dtype=np.float64)
        if not indicator.any():
            continue
        ideal = np.ones(min(len(setgt), len(ranked)))
        dcg = np.sum(indicator / np.log2(np.arange(2, len(indicator) + 2, dtype=np.float64)))
        idcg = np.sum(ideal / np.log2(np.arange(2, len(ideal) + 2, dtype=np.float64)))
        t += dcg / idcg
    return t / len(gt_list)

def MAP(result_list,gt_list,topk):
    """Return the per-user mean of precision@i for i = 1..topk, averaged over users.

    Note: this averages precision at every cutoff 1..topk, not only at hit
    positions as textbook average precision does. Rows shorter than *topk*
    reuse their full hit count for the larger cutoffs. The total is divided by
    ``len(gt_list)``.
    """
    t = 0.0
    for ranked, gt in zip(result_list, gt_list):
        setgt = set(gt)
        # np.asfarray was removed in NumPy 2.0; build the float array explicitly.
        indicator = np.asarray([1.0 if r in setgt else 0.0 for r in ranked],
                               dtype=np.float64)
        t += np.mean([indicator[:i].sum(-1) / i for i in range(1, topk + 1)], axis=-1)
    return t / len(gt_list)


# In[3]:

from lib.Generate_Data_and_Tree import _read,_gen_train_sample,_gen_test_sample,_init_tree,_gen_discriminator_samples
import gc
import numpy as np
if not has_processed_data:
    # Build all instance files and the initial tree from the raw behaviour log.
    # The first test_user_num users are held out as test users.
    behavior_dict, train_sample, test_sample, validation_sample = _read(raw_data_file, test_user_num)
    # Training instances are written into train_sample_seg_cnt files so that no
    # single file gets too large. Each instance holds seq_len - 1 behaviours plus
    # one label. `stat` records the click frequency of each item.
    stat = _gen_train_sample(train_sample, train_instances_file,
                             test_sample=test_sample,
                             validation_sample=validation_sample,
                             train_sample_seg_cnt=train_sample_seg_cnt,
                             parall=parall, seq_len=seq_len, min_seq_len=min_seq_len)
    _gen_test_sample(test_sample, test_instances_file, seq_len=seq_len, min_seq_len=min_seq_len)
    _gen_test_sample(validation_sample, validation_instances_file, seq_len=seq_len, min_seq_len=min_seq_len)

    ids, codes = _init_tree(train_sample, test_sample, validation_sample, stat, kv_file=kv_file)
    # Release the raw data (including the validation split) before training.
    del behavior_dict, train_sample, test_sample, validation_sample, stat
    gc.collect()
else:
    # Reload the item -> leaf-code mapping, one "item_id::leaf_code" per line.
    ids, codes = [], []
    with open(kv_file) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            item_id, leaf_code = line.split('::')[:2]
            ids.append(int(item_id))
            codes.append(int(leaf_code))
    ids = np.array(ids, dtype=np.int32)
    codes = np.array(codes, dtype=np.int32)
print('min item id is {}, max item id is {}'.format(ids.min(), ids.max()))
print('min leaf node code is {}, max leaf node code is {}'.format(codes.min(), codes.max()))

# Every tree starts from the same item/code assignment.
ids_list = [ids for _ in range(tree_num)]
codes_list = [codes for _ in range(tree_num)]
item_num = len(ids_list[0])
print('item number is {}'.format(item_num))


# In[4]:


from pandas import DataFrame
from IPython.display import clear_output
import matplotlib.pyplot as plt
import numpy as np

from lib.generate_training_batches import Train_instance
# Set up the data sources used below. The training batches are iterated by
# the main loop, and validation batches are pulled one at a time with next().
# Test instances and the item -> instance index map are read up front.
train_instances = Train_instance(parall=parall)
training_batch_generator = train_instances.training_batches(
    train_instances_file,
    train_sample_seg_cnt,
    batchsize=training_batch_size,
)
validation_batch_generator = train_instances.validation_batches(
    validation_instances_file,
    batchsize=validation_batch_size,
)
test_instances = train_instances.read_test_instances_file(test_instances_file)
training_instance_index_pair = train_instances.get_item_instance_pair_index(
    train_instances_file,
    train_sample_seg_cnt,
)


# In[5]:


from lib.trainer import TrainModel
# Model / trainer configuration, gathered in one place before construction.
_train_model_kwargs = dict(
    embed_dim=embed_dim,
    feature_groups=featrue_groups,
    all_training_instance=train_instances.training_data,
    item_user_pair_dict=training_instance_index_pair,
    parall=parall,
    optimizer=optimizer,
    N=N,
    sampling_method=sampling_method,
    tree_learner_mode=tree_learner_mode,
    item_node_share_embedding=item_node_share_embedding,
    device=device,
    gamma=gamma,
    k=k_ary,
    max_calculate_num=max_calculate_num,
    get_weights_mode=get_weights_mode,
)
train_model = TrainModel(ids, codes, **_train_model_kwargs)


# In[7]:


import matplotlib.pyplot as plt
import numpy as np
# Metric history buffers. They appear unused in this script, since metrics are
# logged to wandb; they are kept for interactive use.
loss_history, dev_precision_history, dev_recall_history, dev_f_measure_history, dev_novelty_history = [], [], [], [], []
total_precision_history, total_recall_history, total_f_measure_history, total_novelty_history = [], [], [], []
# Directory for model checkpoints and learned trees; no-op if it already exists.
os.makedirs(result_prefix, exist_ok=True)


# In[8]:


import time

def _ranking_metrics(prefix, results, labels, histories, k):
    """Compute the ranking metrics logged to wandb for one evaluation pass.

    Args:
        prefix: wandb key prefix, e.g. 'Test' or 'Validation'.
        results: per-user recommended item ids (list of lists).
        labels: per-user ground-truth item ids, row-aligned with *results*.
        histories: per-user input behaviour sequences, used by novelty().
        k: recommendation length used by precision / f_measure / novelty / MAP.

    Returns:
        dict mapping '<prefix>/<metric>' to its value.
    """
    return {
        prefix + '/precision': presision(results, labels, k),
        prefix + '/recall': recall(results, labels),
        prefix + '/f_measure': f_measure(results, labels, k),
        prefix + '/novelty': novelty(results, histories, k),
        prefix + '/hit_ratio': hit_ratio(results, labels),
        prefix + '/NDCG': NDCG(results, labels),
        prefix + '/MAP': MAP(results, labels, k),
    }


time_start_updating_network = time.time()
for (batch_x, batch_y) in training_batch_generator:
    loss = train_model.update_network_model(batch_x, batch_y)
    wandb.log({'Train/loss': loss.item()})

    # Every 500 batches: rank items for all test users and log test metrics.
    if train_model.batch_num % 500 == 0:
        train_model.network_model.eval()
        bs = validation_batch_size
        bs_count = (len(test_instances) - 1) // bs + 1
        all_result = np.zeros((len(test_instances), topk), dtype=np.int32)
        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            for i in range(bs_count):
                bs_user = test_instances[i * bs:(i + 1) * bs]
                if sampling_method == 'uniform_k_ary':
                    all_result[i * bs:(i + 1) * bs] = train_model.predict_k_ary(
                        bs_user, 100, topk, k_ary, forest=False)
                else:
                    all_result[i * bs:(i + 1) * bs] = train_model.predict(
                        bs_user, 100, topk, forest=False)
        resutl_history = all_result.tolist()
        wandb.log(_ranking_metrics('Test', resutl_history, train_instances.test_labels,
                                   test_instances.tolist(), topk))
        train_model.network_model.train()

    # Every 100 batches: evaluate on a single validation batch.
    if train_model.batch_num % 100 == 0:
        train_model.network_model.eval()
        validation_batch, validation_index = next(validation_batch_generator)
        # Ground truth for exactly the users in this batch. The test labels are
        # not row-aligned with validation users, so they must not be used here.
        gt_history = [train_instances.validation_labels[i.item()] for i in validation_index]
        with torch.no_grad():
            if sampling_method == 'uniform_k_ary':
                resutl_history = train_model.predict_k_ary(
                    validation_batch, 100, topk, k_ary, forest=False).tolist()
            else:
                resutl_history = train_model.predict(
                    validation_batch, 100, topk, forest=False).tolist()
        # NOTE(review): assumes validation_batch supports .tolist() like
        # test_instances (array/tensor of user histories) — confirm.
        wandb.log(_ranking_metrics('Validation', resutl_history, gt_history,
                                   validation_batch.tolist(), topk))
        train_model.network_model.train()

    # After the configured number of batches, save the network and the current
    # item -> leaf-code mapping, then stop training.
    if train_model.batch_num >= batch_gap_to_rebuild_tree:
        model_path = result_prefix + "{}_tree_{}_network_model.pth".format(
            tree_learner_mode, len(train_model.tree_list))
        if os.path.exists(model_path):
            os.remove(model_path)
        torch.save(train_model.network_model, model_path)

        tree_path = result_prefix + "{}_tree_{}_tree.txt".format(
            tree_learner_mode, len(train_model.tree_list))
        with open(tree_path, 'w') as f:
            f.writelines(str(item_id) + "::" + str(code) + '\n'
                         for item_id, code in train_model.tree.item_id_leaf_code.items())
        break


# In[9]:


# single last tree
import time

resutl_history = []
topk = 120
# Cutoffs at which metrics are reported (only those <= topk are used).
EVAL_CUTOFFS = [20, 40, 60, 80, 100, 120]
metric_list = ['precision', 'recall', 'f_measure', 'novelty', 'hit_ratio', 'NDCG', 'MAP']


def _eval_batch_size(beam_size):
    """Pick the number of test users per prediction call for this beam size.

    Wider trees and larger beams get smaller batches (presumably to bound
    memory — confirm).
    """
    if args.k_ary > 32:
        return 10
    if args.k_ary > 10 and beam_size >= 200:
        return 10
    if args.k_ary > 6 and beam_size >= 300:
        return 20
    return 50


print('k_ary:{} gap:{}'.format(k_ary, update_tree_gap))

st = time.time()
for beam_size in [200, 300, 400, 500, 100]:
    print('******************* beam_size: {} *****************'.format(beam_size))

    bs = _eval_batch_size(beam_size)
    bs_count = (len(test_instances) - 1) // bs + 1
    if beam_size == 100:
        topk = 100

    # Evaluate with the most recent network / tree pair.
    all_result = np.zeros((len(test_instances), topk), dtype=np.int32)
    train_model.network_model = train_model.model_list[-1]
    train_model.tree = train_model.tree_list[-1]
    train_model.network_model.eval()

    for j in range(bs_count):
        bs_user = test_instances[j * bs:(j + 1) * bs]
        if sampling_method == 'uniform_k_ary':
            with torch.no_grad():
                all_result[j * bs:(j + 1) * bs] = train_model.predict_k_ary(
                    bs_user, beam_size, topk, k_ary, forest=False)
        else:
            raise NotImplementedError

    # Reverse column order. Presumably predict_k_ary returns items in ascending
    # score order, so this puts the best item first (NDCG/MAP are order
    # sensitive) — confirm.
    all_result = all_result[:, ::-1]

    labels = train_instances.test_labels
    histories = test_instances.tolist()  # novelty is measured against user history
    # Only report cutoffs the beam actually produced. With topk=100, an @120
    # cutoff would divide by 120 with only 100 predictions.
    cutoffs = [k for k in EVAL_CUTOFFS if k <= topk]
    metrics_dict = {metric: [] for metric in metric_list}
    for k in cutoffs:
        top_k_result = all_result[:, :k].tolist()
        metrics_dict['precision'].append(presision(top_k_result, labels, k))
        metrics_dict['recall'].append(recall(top_k_result, labels))
        metrics_dict['f_measure'].append(f_measure(top_k_result, labels, k))
        metrics_dict['novelty'].append(novelty(top_k_result, histories, k))
        metrics_dict['hit_ratio'].append(hit_ratio(top_k_result, labels))
        metrics_dict['NDCG'].append(NDCG(top_k_result, labels))
        metrics_dict['MAP'].append(MAP(top_k_result, labels, k))

    # One single-row wandb table per metric, with one column per cutoff.
    for metric in metric_list:
        table = wandb.Table(columns=[metric + '@' + str(k) for k in cutoffs],
                            data=[metrics_dict[metric]])
        wandb.log({'Beam_Size ' + str(beam_size) + '/' + metric: table})

print('predcit cost time {:.4f}s'.format(time.time() - st))
wandb.finish()