import torch
import numpy as np
from time import time
# Start of the wall-clock timer; read again at the very end of the script
# to report the total running time.
t0_script = time()
# Global torch configuration for the whole run: use a single thread for
# CPU intra-op parallelism and make float64 the default floating dtype.
torch.set_num_threads(1)
torch.set_default_dtype(torch.float64)

import sys
# Extend the module search path with the (relative) package directory so
# that the star imports of `planar` and `toyvi` below can be resolved.
# NOTE(review): the path is relative, so it depends on how the script is
# launched -- confirm the expected working directory.
sys.path.append('../../../package')
from planar import *
from toyvi import *


#------------------------------------------------------ parameter settings
# Command-line configuration of the experiment.  Every parsed value is
# copied to a module-level name used by the rest of the script and is
# echoed to stdout so that the log records the exact settings of the run.
import argparse
parser = argparse.ArgumentParser()
# NOTE(review): --replicate has no default; when omitted it is None, which
# is then passed to PlanarFlow and produces output files named 'rNone_...'.
parser.add_argument('--replicate',      type=int)
# Required: the script takes len() of this list and iterates over it, so a
# missing value would otherwise only surface later as a TypeError on None.
parser.add_argument('--num_layers_lst', type=int, nargs='+', required=True)

# Values forwarded unchanged to model.train_test(...).
parser.add_argument('--batch_size',           type=int, default=250)
parser.add_argument('--num_loops',            type=int, default=50)
parser.add_argument('--num_updates_per_loop', type=int, default=10_000)
parser.add_argument('--num_tests',            type=int, default=10)
parser.add_argument('--batch_size_test',      type=int, default=100_000)
# lr_init is the learning rate handed to Adam; lr_gamma is the decay
# factor handed to the ExponentialLR scheduler.
parser.add_argument('--lr_init',            type=float, default=0.001)
parser.add_argument('--lr_gamma',           type=float, default=0.95)

# seed_init is passed to PlanarFlow; seed_train and seed_test are passed
# to the toy model constructor.
parser.add_argument('--seed_init',  type=int, default=235711131719)
parser.add_argument('--seed_train', type=int, default=11235813)
parser.add_argument('--seed_test',  type=int, default=31415926)

args = parser.parse_args()
replicate            = args.replicate
num_layers_lst       = args.num_layers_lst

batch_size           = args.batch_size
num_loops            = args.num_loops
num_updates_per_loop = args.num_updates_per_loop
num_tests            = args.num_tests
batch_size_test      = args.batch_size_test
lr_init              = args.lr_init
lr_gamma             = args.lr_gamma

seed_init            = args.seed_init
seed_train           = args.seed_train
seed_test            = args.seed_test

# Echo the effective configuration.
print(f'replicate            = {replicate}')
print(f'num_layers_lst       = {num_layers_lst}')

print(f'batch_size           = {batch_size}')
print(f'num_loops            = {num_loops}')
print(f'num_updates_per_loop = {num_updates_per_loop}')
print(f'num_tests            = {num_tests}')
print(f'batch_size_test      = {batch_size_test}')
print(f'lr_init              = {lr_init}')
print(f'lr_gamma             = {lr_gamma}')

print(f'seed_init            = {seed_init}')
print(f'seed_train           = {seed_train}')
print(f'seed_test            = {seed_test} \n')


#------------------------------------------------------ simulation
# Train and evaluate one planar flow for every combination of toy model
# class, number of layers and reparameterization ('old' / 'new').  The
# returned model states are collected in `model_state` and the test-loss
# histories in `test_loss_hist`.
toy_classes = [Toy1VI, Toy2VI, Toy3VI, Toy4VI]
reparam_names = ['old', 'new']

# Settings shared by every flow built below.
dim = 2
linear_layer = True

num_toys = 4
model_state = {}
# Indexed as [toy, position in num_layers_lst, reparameterization, :];
# each innermost row is filled with train_test['test_loss_hist'], which
# therefore has to hold num_loops+1 values.
test_loss_hist = torch.zeros(num_toys, len(num_layers_lst), 2, num_loops+1)

for toy_idx, toy_cls in enumerate(toy_classes):
    for depth_idx, depth in enumerate(num_layers_lst):
        for reparam_idx, reparam in enumerate(reparam_names):
            # Build the flow and wrap it in the toy model.
            flow = PlanarFlow(
                dim, depth, reparam, linear_layer, replicate, seed_init)
            model = toy_cls(flow, seed_train, seed_test)

            # Adam on the flow parameters with an exponentially decaying
            # learning rate.
            optimizer = torch.optim.Adam(
                model.flow.parameters(), lr=lr_init)
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=lr_gamma)

            # Run training/testing and store its outputs.
            outcome = model.train_test(
                num_loops,
                optimizer, scheduler,
                num_updates_per_loop, batch_size,
                num_tests, batch_size_test, verbose=False)
            state_key = f'T{toy_idx}-K{depth}-{reparam}'
            model_state[state_key] = outcome['model_state']
            test_loss_hist[toy_idx, depth_idx, reparam_idx] = (
                outcome['test_loss_hist'])

    # Per-toy summary: last recorded test loss, one row per
    # reparameterization and one column per entry of num_layers_lst.
    print(f'Toy {toy_idx+1}')
    final_losses = test_loss_hist[toy_idx, :, :, -1]
    print(final_losses.T, '\n')


#------------------------------------------------------ save
# Persist the results of the run: the test-loss tensor is pickled and the
# collected model states are written with torch.save.  Both file names are
# tagged with the replicate number.
import os
import pickle
file = f'pickle/r{replicate}_.pickle'
# Create the output directory first, so that a missing folder cannot
# discard the results of the whole run at the very last step.
os.makedirs(os.path.dirname(file), exist_ok=True)
with open(file, 'wb') as f:
    pickle.dump(test_loss_hist, f, pickle.HIGHEST_PROTOCOL)


PATH = f'pt/r{replicate}_.pt'
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(model_state, PATH)


# Wall-clock time elapsed since t0_script, reported in hours.
print('Total time for running this script'
      f'  {round((time()-t0_script)/3600, 1)} hrs')

