# Libraries
import torch
import argparse 
from sklearn.feature_extraction.text import CountVectorizer
import sys

# Custom Imports
from Hierarchical_VQ_AE import Hierarchical_VQ_AE

sys.path.append('./utils')
from train_utils import *

# Command-line interface. The hyper-parameters below build the model whose
# saved state_dict is loaded later, so they have to be compatible with it.
_INT_OPTIONS = (
    ('--channel', 64),
    ('--n_embed_bot', 32),
    ('--n_embed_mid', 32),
    ('--embed_dim', 64),
    ('--n_res_block', 2),
    ('--n_res_channel', 32),
    ('--scale_reduc_bot', 4),
    ('--scale_reduc_mid', 4),
    ('--out_channel_deconv', 64),
)

parser = argparse.ArgumentParser()
for _flag, _default in _INT_OPTIONS:
    parser.add_argument(_flag, type=int, default=_default)
# Suffix identifying which trained checkpoint / output files to use.
parser.add_argument('--version', type=str, default='1')
# Positional: dataset folder prefix under ../data/ (…_TRAIN/ and …_TEST/).
parser.add_argument('dataset_name', type=str)
args = parser.parse_args()


# Paths (input splits and the directory holding trained models / outputs)
save_data_path_train = f'../data/{args.dataset_name}_TRAIN/'
save_data_path_test = f'../data/{args.dataset_name}_TEST/'
save_model_path = '../results/trained_models'

# Load both splits as float32 and stack them along the first dimension
# (presumably the sample axis — TODO confirm against the data-prep script).
TS_train, TS_test = (
    torch.load(split_dir + '/X_tensor.pt').to(torch.float32)
    for split_dir in (save_data_path_train, save_data_path_test)
)
TS_tensor = torch.cat([TS_train, TS_test], dim=0)

# Find architecture parameters from the data shape and the requested
# per-level scale reductions.
params = find_params_VQ_AE(TS_tensor, args.scale_reduc_bot, args.scale_reduc_mid)

# Decoder kernel-size lists are the encoder ones in reverse order; the full
# mid-level decoder goes through the mid stages, then the bot stages.
_ks_encod_mid = params['k_s_list_encod_m']
_ks_encod_bot = params['k_s_list_encod_b']
ks_list_decod_mid_to_bot = _ks_encod_mid[::-1]
ks_list_decod_bot = _ks_encod_bot[::-1]
ks_list_decod_mid = _ks_encod_mid[::-1] + _ks_encod_bot[::-1]

# Instance model. CLI hyper-parameters and the derived kernel-size lists are
# combined with the remaining architecture parameters returned by
# find_params_VQ_AE. Keyword spellings (`embedd_dim`, `comit_cost`) are kept
# exactly as passed to the constructor.
model_kwargs = dict(
    in_channel=1,
    channel=args.channel,
    n_res_block=args.n_res_block,
    n_res_channel=args.n_res_channel,
    embedd_dim=args.embed_dim,
    n_embed_bot=args.n_embed_bot,
    n_embed_mid=args.n_embed_mid,
    out_channel_deconv=args.out_channel_deconv,
    ks_list_decod_mid=ks_list_decod_mid,
    ks_list_decod_bot=ks_list_decod_bot,
    ks_list_decod_mid_to_bot=ks_list_decod_mid_to_bot,
    comit_cost=0.25,
    **params,
)
model = Hierarchical_VQ_AE(**model_kwargs)

# Restore trained weights onto the CPU.
# NOTE: torch.load unpickles the file — only load checkpoints you produced/trust.
checkpoint_path = save_model_path + '/unsup_model_' + args.dataset_name + '_' + args.version + '.pt'
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)

#Forward
# Inference only: eval() switches off training-mode behaviour (dropout,
# batch-norm statistic updates, codebook updates if the model has them), and
# no_grad() avoids building an autograd graph over the whole dataset.
model.eval()
with torch.no_grad():
    recons, vq_loss, quant_b, quant_m, encoding_indices_m, encoding_indices_b = model(TS_tensor)

# One vectorizer per hierarchy level; tokenisation is delegated to
# custom_analyzer_bis (defined in train_utils).
vectorizer_bot = CountVectorizer(input='content', encoding='utf-8', decode_error='strict', analyzer=custom_analyzer_bis)
vectorizer_mid = CountVectorizer(input='content', encoding='utf-8', decode_error='strict', analyzer=custom_analyzer_bis)

# Turn codebook indices into text documents. The +1 shifts indices to start
# at 1 (presumably so 0 is never emitted as a code — confirm in
# turn_array_into_text).
text_bot = turn_array_into_text(encoding_indices_b.detach().cpu().unsqueeze(1).numpy() + 1)
text_mid = turn_array_into_text(encoding_indices_m.detach().cpu().unsqueeze(1).numpy() + 1)

def _feature_names(vectorizer):
    """Return the fitted vocabulary of *vectorizer* as a list.

    ``get_feature_names`` was deprecated in scikit-learn 1.0 and removed in
    1.2 in favour of ``get_feature_names_out``; support both so the script
    runs on old and new versions.
    """
    if hasattr(vectorizer, 'get_feature_names_out'):
        return list(vectorizer.get_feature_names_out())
    return vectorizer.get_feature_names()


def _fit_and_save_ngrams(vectorizer, text, level):
    """Fit *vectorizer* on *text*, save vocabulary and counts, and return them.

    Writes to ``save_model_path``:
      - ``ngrams_<level>_name_<dataset>_<version>.pt``: vocabulary as an
        integer tensor (tokens are converted with ``int``, so the analyzer
        is assumed to emit numeric strings — TODO confirm);
      - ``ngrams_<level>_<dataset>_<version>.pt``: dense count matrix, one
        row per document in *text*, one column per vocabulary entry.

    Returns ``(sparse_counts, names_tensor, dense_counts_tensor)``.
    """
    counts = vectorizer.fit_transform(text)
    names = torch.tensor([int(ele) for ele in _feature_names(vectorizer)])
    torch.save(names, save_model_path + '/ngrams_' + level + '_name_' + args.dataset_name + '_' + args.version + '.pt')
    counts_dense = torch.tensor(counts.toarray())
    torch.save(counts_dense, save_model_path + '/ngrams_' + level + '_' + args.dataset_name + '_' + args.version + '.pt')
    return counts, names, counts_dense


# Extract and save bot unigrams and bigrams
X_bot_centroides, tensor_names_bot, X_bot_centroides_to_save = _fit_and_save_ngrams(vectorizer_bot, text_bot, 'bot')

# Extract and save mid unigrams and bigrams
X_mid_centroides, tensor_names_mid, X_mid_centroides_to_save = _fit_and_save_ngrams(vectorizer_mid, text_mid, 'mid')












