import torch
import numpy as np
from time import time
# Wall-clock start of the script; read again at the very end to report the
# total run time.
t0_script = time()
# Restrict PyTorch to a single intra-op thread.
# NOTE(review): presumably so that many replicates can run side by side
# without oversubscribing cores -- confirm.
torch.set_num_threads(1)
# Floating-point tensors are created as float64 unless a dtype is given.
torch.set_default_dtype(torch.float64)

import sys
# Make the project package importable. The path is relative, so the script
# is expected to be launched from its own directory.
sys.path.append('../../../package')
# NOTE(review): wildcard imports; PlanarFlow, LMVI and LogisticVI used in the
# simulation section presumably come from these two modules -- confirm.
from planar import *
from glmvi import *


#------------------------------------------------------ parameter settings
# Command-line configuration of the experiment. Every option is unpacked into
# a module-level variable of the same name and echoed to stdout so that the
# settings are recorded in the job log.
import argparse
parser = argparse.ArgumentParser()
# NOTE(review): --replicate has no default, so it is None when omitted; it is
# passed to PlanarFlow and embedded in the output file names ('rNone_...').
# Confirm whether it should be required as well.
parser.add_argument('--replicate',      type=int)
# Required: the simulation section takes len(num_layers_lst) and iterates over
# it, so a missing value would otherwise surface as an opaque TypeError on
# None instead of a usage message.
parser.add_argument('--num_layers_lst', type=int, nargs='+', required=True)
parser.add_argument('--dim',            type=int, default=10)
# Two values: the first is used for the LM model, the second for the
# Logistic model.
parser.add_argument('--n_data',         type=int, default=[10,20], nargs=2)
parser.add_argument('--p_nonzero',      type=int, default=2)
parser.add_argument('--rho_X',        type=float, default=0.5)
parser.add_argument('--prior_scale',  type=float, default=0.1)

# Training / testing schedule and Adam + ExponentialLR settings.
parser.add_argument('--batch_size',           type=int, default=250)
parser.add_argument('--num_loops',            type=int, default=50)
parser.add_argument('--num_updates_per_loop', type=int, default=10_000)
parser.add_argument('--num_tests',            type=int, default=10)
parser.add_argument('--batch_size_test',      type=int, default=100_000)
parser.add_argument('--lr_init',            type=float, default=0.001)
parser.add_argument('--lr_gamma',           type=float, default=0.95)

# Random seeds forwarded to the flow and the GLM model constructors.
parser.add_argument('--seed_init',     type=int, default=235711131719)
parser.add_argument('--seed_train',    type=int, default=11235813)
parser.add_argument('--seed_test',     type=int, default=31415926)
parser.add_argument('--seed_glm_data', type=int, default=42)

args = parser.parse_args()
replicate            = args.replicate
num_layers_lst       = args.num_layers_lst
dim                  = args.dim
n_data               = args.n_data
p_nonzero            = args.p_nonzero
rho_X                = args.rho_X
prior_scale          = args.prior_scale

batch_size           = args.batch_size
num_loops            = args.num_loops
num_updates_per_loop = args.num_updates_per_loop
num_tests            = args.num_tests
batch_size_test      = args.batch_size_test
lr_init              = args.lr_init
lr_gamma             = args.lr_gamma

seed_init            = args.seed_init
seed_train           = args.seed_train
seed_test            = args.seed_test
seed_glm_data        = args.seed_glm_data

# Echo the effective settings.
print(f'replicate            = {replicate}')
print(f'num_layers_lst       = {num_layers_lst}')
print(f'dim                  = {dim}')
print(f'n_data               = {n_data}')
print(f'p_nonzero            = {p_nonzero}')
print(f'rho_X                = {rho_X}')
print(f'prior_scale          = {prior_scale}')

print(f'batch_size           = {batch_size}')
print(f'num_loops            = {num_loops}')
print(f'num_updates_per_loop = {num_updates_per_loop}')
print(f'num_tests            = {num_tests}')
print(f'batch_size_test      = {batch_size_test}')
print(f'lr_init              = {lr_init}')
print(f'lr_gamma             = {lr_gamma}')

print(f'seed_init            = {seed_init}')
print(f'seed_train           = {seed_train}')
print(f'seed_test            = {seed_test}')
print(f'seed_glm_data        = {seed_glm_data} \n')


#------------------------------------------------------ simulation
# Train and test one planar-flow model for every combination of regression
# family, number of flow layers and reparameterization, collecting the
# test-loss history and the model state returned by each run.
regressions = [(LMVI, 'LM'), (LogisticVI, 'Logistic')]
reparams = ['old', 'new']
num_regressions = len(regressions)
linear_layer = True

test_loss_hist = torch.zeros(
    num_regressions, len(num_layers_lst), len(reparams), num_loops+1)
model_state = {}

for glm_idx, (GLM, name) in enumerate(regressions):
    for depth_idx, num_layers in enumerate(num_layers_lst):
        for reparam_idx, reparam in enumerate(reparams):
            # Normalizing flow, then the GLM model that wraps it
            flow = PlanarFlow(
                dim, num_layers, reparam, linear_layer, replicate, seed_init)
            model = GLM(
                n_data[glm_idx], p_nonzero, rho_X, prior_scale, flow,
                seed_train, seed_test, seed_glm_data)

            # Adam on the flow parameters with an exponentially decayed
            # learning rate
            optimizer = torch.optim.Adam(model.flow.parameters(), lr=lr_init)
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=lr_gamma)

            # Run training and testing, then file the results under this
            # configuration
            result = model.train_test(
                num_loops,
                optimizer, scheduler,
                num_updates_per_loop, batch_size,
                num_tests, batch_size_test, verbose=False)
            run_key = f'{name}-K{num_layers}-{reparam}'
            model_state[run_key] = result['model_state']
            test_loss_hist[glm_idx, depth_idx, reparam_idx] = result['test_loss_hist']

    # Last recorded test loss for this family: one row per
    # reparameterization, one column per entry of num_layers_lst.
    print(name)
    print(test_loss_hist[glm_idx, :, :, -1].T, '\n')


#------------------------------------------------------ save
# Persist the test-loss history (pickle) and the collected model states
# (torch.save), then report the total wall-clock time of the script.
import os
import pickle

# The output directories are created on demand: otherwise a missing
# 'pickle/' or 'pt/' directory raises FileNotFoundError only after the whole
# training run has finished, and the results are lost.
file = f'pickle/r{replicate}_.pickle'
PATH = f'pt/r{replicate}_.pt'

try:
    os.makedirs(os.path.dirname(file), exist_ok=True)
    with open(file, 'wb') as f:
        pickle.dump(test_loss_hist, f, pickle.HIGHEST_PROTOCOL)
finally:
    # Attempt to save the model state even if writing the loss history
    # failed, so that one failing write does not discard both results; the
    # original error is still raised afterwards.
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    torch.save(model_state, PATH)


print(f'Total time for running this script' \
      f'  {round((time()-t0_script)/3600, 1)} hrs')

