import json
import numpy as np
import argparse
import os
import traceback
import time
import tqdm 
import pickle
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dot, Embedding, Flatten, Dense, GlobalAveragePooling1D
from tensorflow.keras import layers
from tensorflow.keras import backend as K

import sys
sys.path.insert(0, './src')
from utils import MaskedMeanPool, reco_seq_construct1, generate_seq_trn_dat, clf_kernel_score, generate_phi
from data import generate_cl_trn_dat, ml_rating_preprocess
from metric import eval_reco, classification_metric, ml_item_label
from model import Item2Vec, GruforRec, MLP, logistic_reg


def TrnEmb(ws, ns, epochs, save_emb, log_dir, dat_dir, embedding_dim=None, n_iters=10):
    """
    Pretrain movie embeddings with Item2Vec using user viewing data.

    A fresh Item2Vec model is built and trained ``n_iters`` times on the same
    dataset. When ``save_emb`` is truthy, the ``item_embedding`` weights of each
    run are pickled to ``log_dir`` as ``ML-1m_emb_epoch{epochs}_iter{i}.pkl``.

    Args:
        ws: forwarded to ``generate_cl_trn_dat`` (presumably the context window
            size -- confirm against ``data.generate_cl_trn_dat``).
        ns: forwarded to ``generate_cl_trn_dat`` (presumably the number of
            negative samples -- confirm against ``data.generate_cl_trn_dat``).
        epochs: number of epochs for each Item2Vec fit; also part of the saved
            file name.
        save_emb: truthy to pickle the embedding matrix of every run.
        log_dir: directory the embeddings are written to (created if missing).
        dat_dir: directory containing ``ratings.dat``.
        embedding_dim: width of the Item2Vec embedding. When ``None``, a
            module-level ``embedding_dim`` is used if one exists, else 32.
        n_iters: number of independent Item2Vec training runs.

    Returns:
        dict with the sequential-recommendation train/test arrays
        (``seqs_*``, ``targets_*``, ``labels_*``) plus ``itemid2idx``,
        ``id_freq`` and ``user_hist`` for the downstream tasks.
    """
    if embedding_dim is None:
        # Older callers relied on a module-level ``embedding_dim``; honour it
        # when present instead of failing with a NameError.
        embedding_dim = globals().get('embedding_dim', 32)

    # generate pretraining data from user viewing sequence
    user_hist, itemid2idx = ml_rating_preprocess(os.path.join(dat_dir, 'ratings.dat'))
    n_items = len(itemid2idx)
    order_hist_trn, order_hist_test, id_freq = reco_seq_construct1(list(user_hist.values()), itemid2idx)

    # settings shared by the train and test sequence generation
    seq_kwargs = dict(SEED=23333, min_len=4, max_len=20, distortion=0.)
    seqs_trn, targets_trn, labels_trn = generate_seq_trn_dat(order_hist_trn,
                                                             n_items,
                                                             list(id_freq.values()),
                                                             for_trn=True,
                                                             num_ns=2,
                                                             **seq_kwargs)
    seqs_test, targets_test, labels_test = generate_seq_trn_dat(order_hist_test,
                                                                n_items,
                                                                list(id_freq.values()),
                                                                for_trn=False,
                                                                num_ns=100,
                                                                **seq_kwargs)

    # building training data
    targets, contexts, labels = generate_cl_trn_dat(list(user_hist.values()), ws, ns, n_items)
    BATCH_SIZE = 1024
    BUFFER_SIZE = 10000
    dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
    dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

    if save_emb and log_dir:
        # make sure the output directory exists before any training time is spent
        os.makedirs(log_dir, exist_ok=True)

    for iter_ in range(n_iters):
        # a new model per iteration: every run starts from a fresh initialisation
        item2vec = Item2Vec(n_items, embedding_dim)
        item2vec.compile(optimizer='adam',
                         loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                         metrics=['accuracy'])
        item2vec.fit(dataset, epochs=epochs, verbose=0)

        if save_emb:
            item_emb = item2vec.get_layer('item_embedding').get_weights()[0]
            fname = "ML-1m_emb_epoch{}_iter{}.pkl".format(epochs, iter_)
            with open(os.path.join(log_dir, fname), 'wb') as f:
                pickle.dump(item_emb, f)

    # storing environment variables for downstream task
    store_var = {
        "seqs_trn": seqs_trn, "targets_trn": targets_trn, "labels_trn": labels_trn,
        "seqs_test": seqs_test, "targets_test": targets_test, "labels_test": labels_test,
        "itemid2idx": itemid2idx, "id_freq": id_freq, "user_hist": user_hist,
    }
    return store_var


def _fit_eval_gru_rec(pre_emb, n_items, rnn_dim, seq_dataset, epochs, test_dat):
    """Train a GruforRec model initialised with ``pre_emb`` and return ``eval_reco``'s result on ``test_dat``."""
    model = GruforRec(rnn_dim[0], n_items,
                      hidden_units_l=[rnn_dim[0], rnn_dim[1]],
                      pre_emb=pre_emb, l2_reg=0.)
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    model.fit(seq_dataset, epochs=epochs, verbose=0)
    seqs_test, targets_test, labels_test = test_dat
    return eval_reco(model, seqs_test, targets_test, labels_test)


def EmbRec(emb_loc, store_var, rnn_dim, phi_dim, epochs):
    """
    Perform sequential recommendation using pre-trained embeddings.

    Two GruforRec models are trained and evaluated: one on the raw pre-trained
    embedding ``h`` and one on ``phi = generate_phi(h, phi_dim)``.

    Args:
        emb_loc: path of the pickled item-embedding matrix.
        store_var: dict returned by ``TrnEmb``; the ``seqs_*``, ``targets_*``,
            ``labels_*`` and ``itemid2idx`` entries are read.
        rnn_dim: sequence whose first two entries are passed to ``GruforRec``
            (``rnn_dim[0]`` as first positional argument, both as
            ``hidden_units_l``).
        phi_dim: forwarded to ``generate_phi``.
        epochs: number of training epochs for each GruforRec model.

    Returns:
        ``[result_h, result_phi]`` -- the two ``eval_reco`` results. Both are
        also appended to the module-level ``gru_res`` list, which is created
        when it does not exist yet.
    """
    # restoring pretrained emb and environment variables
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(emb_loc, 'rb') as f:
        item_emb = pickle.load(f)

    seqs_trn, targets_trn, labels_trn = store_var['seqs_trn'], store_var['targets_trn'], store_var['labels_trn']
    test_dat = (store_var['seqs_test'], store_var['targets_test'], store_var['labels_test'])
    n_items = len(store_var['itemid2idx'])

    # module-level accumulator shared across calls; create it on first use
    results = globals().setdefault('gru_res', [])

    # generating training data for reco
    BATCH_SIZE = 1024
    BUFFER_SIZE = 10000
    seq_dataset = tf.data.Dataset.from_tensor_slices(((seqs_trn, targets_trn), labels_trn))
    seq_dataset = seq_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

    # using h for downstream sequential reco
    res_h = _fit_eval_gru_rec(item_emb, n_items, rnn_dim, seq_dataset, epochs, test_dat)
    results.append(res_h)

    # using \phi_m(h) for downstream sequential reco
    # NOTE(review): phi is now what the second model is initialised with (it was
    # computed but unused before). Confirm that generate_phi's output width is
    # compatible with the rnn_dim[0] passed to GruforRec.
    phi = generate_phi(item_emb, phi_dim)
    res_phi = _fit_eval_gru_rec(phi, n_items, rnn_dim, seq_dataset, epochs, test_dat)
    results.append(res_phi)

    return [res_h, res_phi]
    
    
def _lr_genre_metrics(movies_path, emb, itemid2idx):
    """Fit ``logistic_reg`` on the item features built from ``emb`` and return the ``classification_metric`` result."""
    feat, label = ml_item_label(movies_path, emb, itemid2idx)
    emb_pred, emb_pred_score, test_label = logistic_reg(feat, label)
    return classification_metric(test_label, emb_pred, emb_pred_score, auc=False, return_res=True)


def EmbClf(emb_loc, store_var, dat_dir, phi_dim):
    """
    Perform movie genre classification using pre-trained embeddings.

    Logistic regression is run twice: on the raw pre-trained embedding ``h``
    and on ``phi = generate_phi(h, phi_dim)``.

    Args:
        emb_loc: path of the pickled item-embedding matrix.
        store_var: dict returned by ``TrnEmb``; only ``itemid2idx`` is read.
        dat_dir: directory containing ``movies.dat``.
        phi_dim: forwarded to ``generate_phi``.

    Returns:
        ``[metrics_h, metrics_phi]`` -- the two ``classification_metric``
        results. Their ``f1_micro`` / ``f1_macro`` entries are also appended to
        the module-level ``f1_micro_lr`` / ``f1_macro_lr`` lists, which are
        created when they do not exist yet.
    """
    # restoring pretrained emb and environment variables
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(emb_loc, 'rb') as f:
        item_emb = pickle.load(f)

    itemid2idx = store_var['itemid2idx']
    movies_path = os.path.join(dat_dir, 'movies.dat')

    # module-level accumulators shared across calls; create them on first use
    micro_scores = globals().setdefault('f1_micro_lr', [])
    macro_scores = globals().setdefault('f1_macro_lr', [])

    # using h for downstream classification
    res_h = _lr_genre_metrics(movies_path, item_emb, itemid2idx)
    micro_scores.append(res_h['f1_micro'])
    macro_scores.append(res_h['f1_macro'])

    # using \phi_m(h) for downstream classification
    phi = generate_phi(item_emb, phi_dim)
    res_phi = _lr_genre_metrics(movies_path, phi, itemid2idx)
    micro_scores.append(res_phi['f1_micro'])
    macro_scores.append(res_phi['f1_macro'])

    return [res_h, res_phi]
    

    
def default_emb_loc(log_dir, epochs, iter_=0):
    """
    Return the path of a saved embedding pickle inside ``log_dir``.

    Follows the ``ML-1m_emb_epoch{epochs}_iter{iter_}.pkl`` naming pattern used
    in this file when embeddings are saved.
    """
    return os.path.join(log_dir, "ML-1m_emb_epoch{}_iter{}.pkl".format(epochs, iter_))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--log_dir", type=str, default="log/ml_seq_reco/")
    parser.add_argument("--dat_dir", type=str, default="data/ml-1m/")
    parser.add_argument("--emb_dim", type=int, default=32)
    # ``type=list`` would split the raw argument string into single characters;
    # read whitespace-separated integers instead, e.g. ``--rnn_dim 32 16``.
    parser.add_argument("--rnn_dim", type=int, nargs=2, default=[32, 16])
    parser.add_argument("--phi_dim", type=int, nargs='+', default=[32, 16])
    parser.add_argument("--ns", type=int, default=3)
    parser.add_argument("--ws", type=int, default=3)
    parser.add_argument("--epochs", type=int, default=32)
    parser.add_argument("--save_emb", type=int, default=0, choices=[0,1]) # whether to save trained embeddings
    parser.add_argument("--GPU", type=int, default=1)
    # pickled embedding used by the downstream tasks
    parser.add_argument("--emb_loc", type=str, default=None)

    args = parser.parse_args()

    # Fail before the expensive pretraining when no embedding file can be located.
    if args.emb_loc is None and not args.save_emb:
        parser.error("--emb_loc is required unless --save_emb 1 is given")

    # setting up GPU
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        gpu_idx = args.GPU
        if not -len(gpus) <= gpu_idx < len(gpus):
            # e.g. the default index 1 on a single-GPU machine
            print("GPU index {} is not available ({} GPU(s) found); using GPU 0".format(args.GPU, len(gpus)))
            gpu_idx = 0
        try:
            tf.config.experimental.set_memory_growth(gpus[gpu_idx], True)
            tf.config.experimental.set_visible_devices(gpus[gpu_idx], 'GPU')
        except RuntimeError as e:
            # raised when the devices have already been initialised
            print(e)

    # Module-level state shared with the functions above: the embedding width
    # and the result accumulators of the downstream tasks.
    embedding_dim = args.emb_dim
    gru_res = []
    f1_micro_lr = []
    f1_macro_lr = []

    store_var = TrnEmb(args.ws, args.ns, args.epochs, args.save_emb, args.log_dir, args.dat_dir)

    # use the user-supplied embedding, else the first one saved by this run
    emb_loc = args.emb_loc or default_emb_loc(args.log_dir, args.epochs)

    EmbRec(emb_loc, store_var, args.rnn_dim, args.phi_dim, args.epochs)

    EmbClf(emb_loc, store_var, args.dat_dir, args.phi_dim)

    print("sequential reco results:", gru_res)
    print("genre classification f1_micro:", f1_micro_lr)
    print("genre classification f1_macro:", f1_macro_lr)