# Libraries
import argparse
import json
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

# Custom Imports
from Hierarchical_VQ_AE import Hierarchical_VQ_AE

# train_utils lives in ./utils, so that folder must be on sys.path before importing it.
sys.path.append('./utils')
from train_utils import *

################
# Load arguments
################


# Command-line interface. Every optional hyper-parameter is described by a
# (flag, type, default) triple and registered in the order listed; the only
# positional argument is the dataset name used to build the data/result paths.
parser = argparse.ArgumentParser()
for _flag, _type, _default in (
    ('--batch', int, 256),
    ('--lr', float, 1e-4),
    ('--channel', int, 64),
    ('--n_embed_bot', int, 10),
    ('--n_embed_mid', int, 10),
    ('--embed_dim', int, 64),
    ('--num_training', int, 1),
    ('--n_res_block', int, 2),
    ('--n_res_channel', int, 32),
    ('--scale_reduc_bot', int, 4),
    ('--scale_reduc_mid', int, 4),
    ('--out_channel_deconv', int, 64),
    ('--version', str, '1'),
):
    parser.add_argument(_flag, type=_type, default=_default)
parser.add_argument('dataset_name', type=str)

args = parser.parse_args()



# Paths: inputs are read from ../data/<dataset>_TRAIN/ and ../data/<dataset>_TEST/,
# the trained weights and the run summary are written under ../results/.
save_data_path_train = '../data/' + args.dataset_name + '_TRAIN/'
save_data_path_test = '../data/' + args.dataset_name + '_TEST/'
save_model_path = '../results/trained_models'
save_dict_path = '../results/dict_info'
# Create the output folders now, so a missing directory cannot make the
# final torch.save / json.dump fail after training has already run.
os.makedirs(save_model_path, exist_ok=True)
os.makedirs(save_dict_path, exist_ok=True)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load data. map_location='cpu' keeps the tensors on the host whatever device
# they were saved from: batches are moved to `device` inside the training loop,
# and the DataLoader can only pin CPU tensors.
TS_train = torch.load(save_data_path_train + 'X_tensor.pt', map_location='cpu').to(torch.float32)
TS_test = torch.load(save_data_path_test + 'X_tensor.pt', map_location='cpu').to(torch.float32)
# Train and test splits are stacked along the first dimension and used together.
TS_tensor = torch.cat([TS_train, TS_test], dim=0)

# Architecture parameters derived from the data and the two scale reductions.
params = find_params_VQ_AE(TS_tensor, args.scale_reduc_bot, args.scale_reduc_mid)
# Decoder kernel-size lists are the encoder lists reversed:
#   mid decoder        -> reversed mid-encoder sizes followed by reversed bottom-encoder sizes
#   mid-to-bot decoder -> reversed mid-encoder sizes
#   bot decoder        -> reversed bottom-encoder sizes
ks_list_decod_mid = params['k_s_list_encod_m'][::-1] + params['k_s_list_encod_b'][::-1]
ks_list_decod_mid_to_bot = params['k_s_list_encod_m'][::-1]
ks_list_decod_bot = params['k_s_list_encod_b'][::-1]

# Instantiate the hierarchical VQ-AE on `device`. The explicit keywords below
# are merged with every entry of `params` (returned by find_params_VQ_AE); a
# key present in both raises TypeError, exactly as a direct call would.
# NOTE: `embedd_dim` and `comit_cost` are the constructor's own keyword
# spellings and must not be "corrected" here.
vq_ae_kwargs = dict(
    in_channel=1,
    channel=args.channel,
    n_res_block=args.n_res_block,
    n_res_channel=args.n_res_channel,
    embedd_dim=args.embed_dim,
    n_embed_bot=args.n_embed_bot,
    n_embed_mid=args.n_embed_mid,
    out_channel_deconv=args.out_channel_deconv,
    ks_list_decod_mid=ks_list_decod_mid,
    ks_list_decod_bot=ks_list_decod_bot,
    ks_list_decod_mid_to_bot=ks_list_decod_mid_to_bot,
    comit_cost=0.25,
)
model = Hierarchical_VQ_AE(**vq_ae_kwargs, **params)
model = model.to(device)

# Training parameters
batch_size = args.batch
num_training_updates = args.num_training
learning_rate = args.lr
optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=False)
# NOTE(review): this scheduler is never stepped by the training loop of this
# script, so it currently has no effect on the learning rate.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
torch.cuda.empty_cache()

# Wrap the series in a Dataset and batch it. The full `[:, :, :]` slice is kept
# so a tensor with fewer than three dimensions fails here. Pinned host memory
# only speeds up host-to-GPU copies (and triggers a warning without an
# accelerator), so it is enabled on CUDA only.
dataset = MyDataset(TimeSeries=TS_tensor[:, :, :])
training_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                             pin_memory=device.type == 'cuda')

##############
# Training
##############
# Run exactly `num_training_updates` optimisation steps, then save the weights.
# One DataLoader iterator is reused and re-created only when an epoch runs out,
# so every series is seen once per epoch. This replaces re-creating (and
# re-shuffling) the loader each step just to take its first batch.
if num_training_updates < 1:
    # At least one step is needed: the run summary below reads the last step's losses.
    raise ValueError(f'--num_training must be >= 1, got {num_training_updates}')

model.train()
batch_iter = iter(training_loader)
for _ in range(num_training_updates):
    try:
        batch = next(batch_iter)
    except StopIteration:
        # Epoch exhausted: start a new, freshly shuffled pass over the data.
        batch_iter = iter(training_loader)
        batch = next(batch_iter)
    batch = batch.to(device)
    optimizer.zero_grad()

    recons, vq_loss, quant_b, quant_m, encoding_indices_m, encoding_indices_b = model(batch)
    # Target is the input batch with dims 1 and 2 merged, presumably to match
    # the decoder's output layout -- TODO confirm against Hierarchical_VQ_AE.
    recon_error = F.mse_loss(recons, batch.flatten(1, 2))

    loss = recon_error + vq_loss
    loss.backward()
    optimizer.step()

model_file = (save_model_path + '/unsup_model_' + args.dataset_name
              + '_' + args.version + '.pt')
torch.save(model.state_dict(), model_file)


# Run summary written as JSON. Loss values and the count of distinct codebook
# indices come from the LAST training batch only, not from the whole dataset;
# the "_theory" entries are the configured codebook sizes, for comparison.
dict_stats = dict(
    dataset_name=args.dataset_name,
    seq_len=params['seq_len'],
    scale_reduc_bot=args.scale_reduc_bot,
    seq_bot_len=params['seq_bot_len'],
    scale_reduc_mid=args.scale_reduc_mid,
    seq_mid_len=params['seq_mid_len'],
    VQ_loss=vq_loss.item(),
    Recons_loss=recon_error.item(),
    Nb_bot_centroides_theory=args.n_embed_bot,
    Nb_mid_centroides_theory=args.n_embed_mid,
    Nb_bot_centroides=encoding_indices_b.unique().numel(),
    Nb_mid_centroides=encoding_indices_m.unique().numel(),
)

stats_file = save_dict_path + '/dict_' + args.dataset_name + '_' + args.version + '.json'
with open(stats_file, 'w') as fp:
    json.dump(dict_stats, fp)






