
import numpy as np
from copy import deepcopy

import torch
from torch import Tensor
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.data import DataLoader

import e3nn
from e3nn import o3

import json
from tqdm import tqdm
import os, sys
import gzip
import pickle
import argparse
import time

sys.path.append('..')

from coordinates.protein import *
from cgnet_fibers import ClebschGordanVAE_symmetric_simple_flexible
from utils.data_utils import MNISTDataset, MNISTDatasetWithConditioning, MNISTDatasetWithConditioning__fibers
from utils.argparse_utils import *
from utils.equivariance_tests import rotate_signal, get_wigner_D_from_rot_matrix
from utils.data_getter import get_data_mnist

from typing import *

from torch.utils.tensorboard import SummaryWriter

# Row-major flattening of the 3x3 identity matrix, stored as a (1, 9) tensor.
ID_FRAME = torch.eye(3).reshape(1, 9)

# KLD_THRESHOLD is not referenced by the code visible in this script (the model-selection
# logic uses literal KL cut-offs). REC_LOSS_THRESHOLD is the mean training reconstruction
# loss below which the 'decrease_below_threshold' schedule lowers the learning rate.
KLD_THRESHOLD = 0.4
REC_LOSS_THRESHOLD = 0.11

def dict_to_device(adict, device):
    """Cast every value of ``adict`` to float32 and move it to ``device``.

    The dictionary is modified in place (each value is replaced by the
    converted tensor) and the very same dictionary object is returned.
    """
    for name, tensor in adict.items():
        # Only existing keys are reassigned, so the dict size never changes mid-iteration.
        adict[name] = tensor.float().to(device)
    return adict

if __name__ == '__main__':
    # Command-line interface of the training script. Every hyperparameter of the run is
    # exposed as a flag; list-valued flags are given as comma-separated strings and are
    # converted by the helper types imported from utils.argparse_utils.
    parser = argparse.ArgumentParser()

    # --- model type and input ---
    parser.add_argument('--is_vae', type=str_to_bool, default=True)
    parser.add_argument('--model', type=str, default='cgvae_symmetric_simple_flexible')
    parser.add_argument('--input_type', type=str, default='RR-avg_sqrt_power') # RR-avg_sqrt_power
    parser.add_argument('--w3j_filepath', type=str, default='../cg_coefficients/w3j_matrices-lmax=14-version=0.5.0.pkl')

    # --- network architecture ---
    parser.add_argument('--net_lmax', type=int, default=10)
    parser.add_argument('--latent_dim', type=int, default=16)
    parser.add_argument('--n_cg_blocks', type=int, default=6)
    parser.add_argument('--lmax_list', type=comma_sep_int_list, default='10,10,8,4,2,1')
    parser.add_argument('--ch_size_list', type=comma_sep_int_list, default='16,16,16,16,16,16')
    parser.add_argument('--ls_nonlin_rule_list', type=comma_sep_str_list, default='efficient,efficient,efficient,efficient,efficient,efficient')
    parser.add_argument('--ch_nonlin_rule_list', type=comma_sep_str_list, default='elementwise,elementwise,elementwise,elementwise,elementwise,elementwise')
    parser.add_argument('--do_initial_linear_projection', type=str_to_bool, default=True) # just always keep it true, current code breaks otherwise
    parser.add_argument('--ch_initial_linear_projection', type=int, default=16)

    parser.add_argument('--filter_symmetric', type=str_to_bool, default=True)
    parser.add_argument('--linearity_first', type=str_to_bool, default=False)

    # --- normalization ---
    parser.add_argument('--use_batch_norm', type=str_to_bool, default=True)
    parser.add_argument('--norm_type', type=str, default='signal') # None, layer, signal, layer_and_signal
    parser.add_argument('--normalization', type=str, default='component') # norm, component -> only considered if norm_type is not none
    parser.add_argument('--norm_balanced', type=str_to_bool_or_float, default=False)
    parser.add_argument('--norm_affine', type=str_to_str_or_bool_or_comma_sep_tuple_of_both, default='per_l') # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
    parser.add_argument('--norm_nonlinearity', type=str, default=None) # identity, relu, swish, sigmoid -> only for layer_norm
    parser.add_argument('--norm_location', type=str, default='between') # first, between, last

    # --- optimization and loss schedules ---
    parser.add_argument('--use_additive_skip_connections', type=str_to_bool, default=True)
    parser.add_argument('--weight_decay', type=str_to_bool, default=False) # boolean switch; the decay value itself is fixed where the optimizer is built
    parser.add_argument('--x_rec_loss_fn', type=str, default='mse')
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--learn_frame', type=str_to_bool, default=True)
    parser.add_argument('--lr', type=float, default=0.001) # OLD: 0.0001
    parser.add_argument('--lr_schedule', type=str, default='log_decrease_until_end_of_warmup', choices=['constant', 'log_decrease_until_end_of_warmup', 'log_decrease_until_end_by_1_OM', 'log_decrease_until_end_by_2_OM', 'log_decrease_until_end_by_3_OM', 'linear_decrease_until_end_of_warmup', 'decrease_below_threshold', 'decrease_after_warmup', 'decrease_at_half'])
    parser.add_argument('--n_epochs', type=int, default=60)
    parser.add_argument('--lambdas', type=comma_sep_float_list, default='50.0,0.2') # reconstruction weight, KL weight
    parser.add_argument('--lambdas_schedule', type=str, default='constant', choices=['zeros', 'constant', 'drop_kl_at_half', 'linear_up_anneal_kl'])
    parser.add_argument('--no_kl_epochs', type=int, default=25)
    parser.add_argument('--warmup_kl_epochs', type=int, default=10)

    parser.add_argument('--seed', type=int, default=420420420)

    # --- bookkeeping ---
    # --hash and --experiments_dir are joined into the output paths right after parsing, so a
    # missing value used to surface as an opaque TypeError inside os.path.join. Marking them
    # as required makes argparse report the missing flag instead.
    parser.add_argument('--hash', type=str, required=True, help='Unique identifier for the run. Usually a hash of the hyperparameters.')
    parser.add_argument('--experiments_dir', type=str, required=True, help='Root directory under which the local_*, wandb_* and tensorboard_* run folders are created.')
    parser.add_argument('--experiments_suffix', type=str, default='equiv_fibers')
    parser.add_argument('--use_wandb', type=str_to_bool, default=False)
    parser.add_argument('--use_tensorboard', type=str_to_bool, default=False)

    args = parser.parse_args()

    print('HERE', file=sys.stderr)

    # One output directory per logging backend, all keyed by the experiment suffix and run hash.
    local_experiment_dir = os.path.join(args.experiments_dir, 'local_%s' % (args.experiments_suffix), args.hash)
    wandb_experiment_dir = os.path.join(args.experiments_dir, 'wandb_%s' % (args.experiments_suffix), args.hash)
    tensorboard_experiment_dir = os.path.join(args.experiments_dir, 'tensorboard_%s' % (args.experiments_suffix), args.hash)
    for experiment_dir in (local_experiment_dir, wandb_experiment_dir, tensorboard_experiment_dir):
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(experiment_dir, exist_ok=True)

    if args.use_tensorboard:
        writer = SummaryWriter(log_dir=tensorboard_experiment_dir)

    # The final signal normalization is switched on when the input type carries no
    # '-<normalization>' suffix, or when that suffix is exactly 'sqrt_power'; every other
    # suffix ('None', 'avg_sqrt_power', ..., or anything unrecognized) switches it off.
    input_type_parts = args.input_type.split('-')
    do_final_signal_norm = len(input_type_parts) == 1 or input_type_parts[1] == 'sqrt_power'

    print('do_final_signal_norm: {}'.format(do_final_signal_norm))

    # Hyperparameters to persist; paths and run identifiers are excluded.
    hyps_dict = args_to_dict(args, ignore_params=set(['experiments_dir', 'hash', 'experiments_suffix', 'w3j_filepath']))
    hyps_dict['do_final_signal_norm'] = do_final_signal_norm

    if args.use_wandb:
        import wandb
        # A run has to be started with wandb.init before wandb.watch / wandb.log can be used;
        # merely assigning wandb.config does not start one.
        # NOTE(review): no project/entity is passed, so W&B falls back to its defaults
        # (e.g. the WANDB_PROJECT environment variable) -- confirm this is what is wanted.
        wandb.init(dir=wandb_experiment_dir, name=args.hash, config=hyps_dict)

    with open(os.path.join(local_experiment_dir, 'hparams.json'), 'w') as f:
        json.dump(hyps_dict, f, indent=2)

    # ---- data preparation and loading ----
    # Seeded generator handed to the dataloaders so that shuffling is reproducible.
    rng = torch.Generator().manual_seed(args.seed)

    # Irreps of the spherical harmonics up to degree net_lmax (second argument is the parity).
    data_irreps = o3.Irreps.spherical_harmonics(args.net_lmax, 1)
    print(f'Data irreps: {data_irreps}', file=sys.stderr)

    data, s2_data, (ba_grid, xyz_grid) = get_data_mnist(args.input_type, get_grids=True, get_s2=False, lmax=args.net_lmax)

    # Each split provides projections, rotations and labels; only the labels are wrapped in a tensor here.
    # (An earlier, disabled variant trained on the concatenation of the train and valid splits.)
    train_split, valid_split = data['train'], data['valid']

    train_data, train_rot = train_split['projections'], train_split['rotations']
    train_labels = torch.tensor(train_split['labels'])

    valid_data, valid_rot = valid_split['projections'], valid_split['rotations']
    valid_labels = torch.tensor(valid_split['labels'])

    train_dataset = MNISTDatasetWithConditioning__fibers(train_data, data_irreps, train_labels, train_rot)
    valid_dataset = MNISTDatasetWithConditioning__fibers(valid_data, data_irreps, valid_labels, valid_rot)

    print(f'{len(train_dataset):d} training neighborhoods', file=sys.stderr)
    print(f'{len(valid_dataset):d} validation neighborhoods', file=sys.stderr)

    # Both loaders share batch size, generator and drop_last; only the training loader shuffles.
    loader_kwargs = dict(batch_size=args.batch_size, generator=rng, drop_last=True)
    train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    valid_dataloader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Running on {}'.format(device), file=sys.stderr)

    ## get w3j matrices
    # NOTE(security): pickle.load can execute arbitrary code from the file; only point
    # --w3j_filepath at files from a trusted source.
    with gzip.open(args.w3j_filepath, 'rb') as f:
        w3j_matrices = pickle.load(f)

    # Convert only the entries whose three indices are all within net_lmax into float32
    # tensors on the target device; entries with a larger index are left untouched.
    # `device` is always 'cuda' or 'cpu' here, so no device-less fallback is needed.
    for key in w3j_matrices:
        if key[0] <= args.net_lmax and key[1] <= args.net_lmax and key[2] <= args.net_lmax:
            w3j_matrices[key] = torch.tensor(w3j_matrices[key]).float().to(device)
            w3j_matrices[key].requires_grad = False
    
    # Instantiate the model selected by --model. Only the symmetric, flexible CG-VAE is available.
    if args.model == 'cgvae':
        raise NotImplementedError('vanilla, non-symmetric cgvae not implemented')
    elif args.model == 'cgvae_symmetric_simple_flexible':
        cgvae = ClebschGordanVAE_symmetric_simple_flexible(data_irreps,
                                                        args.latent_dim,
                                                        args.net_lmax,
                                                        args.n_cg_blocks,
                                                        args.ch_size_list,
                                                        args.ls_nonlin_rule_list,
                                                        args.ch_nonlin_rule_list,
                                                        args.do_initial_linear_projection,
                                                        args.ch_initial_linear_projection,
                                                        w3j_matrices,
                                                        device,
                                                        lmax_list=args.lmax_list,
                                                        use_additive_skip_connections=args.use_additive_skip_connections,
                                                        use_batch_norm=args.use_batch_norm,
                                                        norm_type=args.norm_type, # None, layer, signal
                                                        normalization=args.normalization, # norm, component -> only considered if norm_type is not none
                                                        norm_balanced=args.norm_balanced,
                                                        norm_affine=args.norm_affine, # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                                                        norm_nonlinearity=args.norm_nonlinearity, # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                                                        norm_location=args.norm_location, # first, between, last
                                                        linearity_first=args.linearity_first, # currently only works with this being false
                                                        filter_symmetric=args.filter_symmetric, # whether to exclude duplicate pairs of l's from the tensor product nonlinearity
                                                        x_rec_loss_fn=args.x_rec_loss_fn, # mse, mse_normalized, cosine
                                                        do_final_signal_norm=do_final_signal_norm,
                                                        learn_frame=args.learn_frame,
                                                        is_vae=args.is_vae).to(device)
    else:
        # Fail here with a clear message instead of a NameError on `cgvae` further down.
        raise ValueError('Unknown model: %s' % (args.model))

    if args.use_wandb:
        wandb.watch(cgvae, log_freq=10, log_graph=False)

    # Total number of scalar entries over all parameter tensors of the model.
    num_params = sum(param.numel() for param in cgvae.parameters())
    print('There are %d parameters' % (num_params), file=sys.stderr)

    # A single Adam optimizer over all model parameters; --weight_decay switches on a
    # fixed weight decay of 1e-5, otherwise Adam's own default is used.
    adam_kwargs = {'lr': args.lr}
    if args.weight_decay:
        adam_kwargs['weight_decay'] = 1e-5
    optimizer_all = Adam(cgvae.parameters(), **adam_kwargs)

    # The training code works on a list of optimizers; here it holds just the one.
    optimizers = [optimizer_all]

    # Loss weights: the first entry of --lambdas scales the reconstruction loss,
    # the second one the KL divergence.
    x_lambda, kl_lambda = args.lambdas
    
    def optimizing_step(x_reconst_loss: Tensor, kl_divergence: Tensor,
                        x_lambda: float, kl_lambda: float,
                        optimizers: List):
        """Backpropagate the weighted loss and advance every optimizer by one step.

        The loss is ``x_lambda * x_reconst_loss + kl_lambda * kl_divergence``. Gradients
        are NOT zeroed here; the caller is expected to do that before the forward pass.

        Args:
            x_reconst_loss: scalar reconstruction loss (attached to the autograd graph).
            kl_divergence: scalar KL divergence (attached to the autograd graph).
            x_lambda: weight of the reconstruction loss.
            kl_lambda: weight of the KL divergence.
            optimizers: non-empty list of optimizers to step.

        Returns:
            The weighted total loss tensor.

        Raises:
            ValueError: if ``optimizers`` is empty.
        """
        if not optimizers:
            raise ValueError('optimizing_step needs at least one optimizer')
        loss = x_lambda * x_reconst_loss + kl_lambda * kl_divergence
        # One backward pass fills the gradients of all parameters; each optimizer then
        # updates its own parameter groups. Previously `loss` was only defined for
        # exactly one optimizer, so any other list length failed on `return loss`.
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()
        return loss

    # Per-epoch weight of the KL term: entry `e` is the weight used during epoch `e`,
    # so every schedule must provide exactly n_epochs entries.
    if args.lambdas_schedule == 'linear_up_anneal_kl':
        # no KL for `no_kl_epochs`, then a linear ramp over `warmup_kl_epochs`, then constant
        kl_lambda_per_epoch = list(np.zeros(args.no_kl_epochs)) + list(np.linspace(0.0, kl_lambda, args.warmup_kl_epochs)) + list(np.full(args.n_epochs - args.warmup_kl_epochs - args.no_kl_epochs, kl_lambda))
        print(kl_lambda_per_epoch, file=sys.stderr)
    elif args.lambdas_schedule == 'constant':
        kl_lambda_per_epoch = np.full(args.n_epochs, kl_lambda)
    elif args.lambdas_schedule == 'drop_kl_at_half':
        # Full KL weight for the first half of training, zero afterwards. The previous
        # version produced 2 * n_epochs entries with the weight switched on for all of the
        # first n_epochs, so the drop never took place.
        n_epochs_with_kl = args.n_epochs // 2
        kl_lambda_per_epoch = list(np.full(n_epochs_with_kl, kl_lambda)) + list(np.full(args.n_epochs - n_epochs_with_kl, 0.0))
    elif args.lambdas_schedule == 'zeros':
        kl_lambda_per_epoch = np.full(args.n_epochs, 0.0)
    else:
        raise ValueError('Unknown lambdas_schedule: %s' % (args.lambdas_schedule))

    # Counters: index of the current logging record across the whole run, and first epoch.
    global_record_i = 0
    epoch_start = 0

    # Best validation reconstruction loss (and the KL divergence at that point) among the
    # records whose validation KL is below 0.4 / 0.5 / 0.6 respectively.
    best_loss_04, best_loss_05, best_loss_06 = np.inf, np.inf, np.inf
    best_kl_04, best_kl_05, best_kl_06 = np.inf, np.inf, np.inf

    # Best validation reconstruction loss for schedules that do not use the KL cut-offs.
    # `best_loss` is compared against inside the training loop and was never initialized,
    # which raised a NameError for the 'zeros' and 'drop_kl_at_half' lambda schedules.
    best_loss = np.inf
    best_kl = np.inf

    # Overall lowest validation losses and the KL divergence observed alongside them.
    lowest_rec_loss = np.inf
    lowest_rec_loss_kl = np.inf
    lowest_total_loss = np.inf
    lowest_total_loss_kl = np.inf
    lowest_total_loss_with_final_kl = np.inf
    lowest_total_loss_with_final_kl_kl = np.inf

    # Whether the 'decrease_below_threshold' schedule has already lowered the learning rate.
    have_decreased_lr = False

    # Mantissa and exponent of the initial learning rate in scientific notation,
    # e.g. 3e-4 -> (3.0, -4). '%e' keeps six decimals of the mantissa.
    init_lr_scale = float(('%e' % (args.lr)).split('e')[0])
    init_lr_exponent = int(('%e' % (args.lr)).split('e')[1])

    # Number of orders of magnitude by which the '..._by_N_OM' schedules decrease the lr.
    log_decrease_orders = {'log_decrease_until_end_by_1_OM': 1,
                           'log_decrease_until_end_by_2_OM': 2,
                           'log_decrease_until_end_by_3_OM': 3}

    # Precomputed learning rate per epoch (n_epochs entries) for the list-based schedules.
    if args.lr_schedule == 'log_decrease_until_end_of_warmup':
        # log-spaced decrease by one order of magnitude over the first `no_kl_epochs`, then constant
        lr_list = list(init_lr_scale*np.logspace(init_lr_exponent, init_lr_exponent-1, args.no_kl_epochs)) + list(init_lr_scale*np.full(args.n_epochs - args.no_kl_epochs, float('1e%d' % (init_lr_exponent-1))))
    elif args.lr_schedule == 'linear_decrease_until_end_of_warmup':
        # linear decrease to a tenth of the lr over the first `no_kl_epochs`, then constant
        lr_list = list(np.linspace(args.lr, args.lr * 0.1, args.no_kl_epochs)) + list(np.full(args.n_epochs - args.no_kl_epochs, args.lr * 0.1))
    elif args.lr_schedule in log_decrease_orders:
        # The mantissa is applied here as well; without it an lr such as 3e-4 would start
        # at 1e-4. For a mantissa of 1 (e.g. the default 0.001) the values are unchanged.
        n_orders = log_decrease_orders[args.lr_schedule]
        lr_list = list(init_lr_scale*np.logspace(init_lr_exponent, init_lr_exponent-n_orders, args.n_epochs))
    else:
        # 'constant', 'decrease_below_threshold', 'decrease_after_warmup' and
        # 'decrease_at_half' have no precomputed list. Leaving lr_list undefined made the
        # print below fail with a NameError for these schedules.
        lr_list = None

    if lr_list is not None:
        print(lr_list, file=sys.stderr)

    times_per_epoch_to_record = 5
    # Validate and log every `steps_to_record` training batches. max(1, ...) keeps the modulo
    # test below well-defined when an epoch has fewer than `times_per_epoch_to_record` batches.
    steps_to_record = max(1, len(train_dataloader) // times_per_epoch_to_record)

    # Schedules whose per-epoch learning rates are precomputed in lr_list. Previously only the
    # first of them was actually applied to the optimizers; the others silently trained at a
    # constant learning rate.
    lr_list_schedules = ('log_decrease_until_end_of_warmup',
                         'linear_decrease_until_end_of_warmup',
                         'log_decrease_until_end_by_1_OM',
                         'log_decrease_until_end_by_2_OM',
                         'log_decrease_until_end_by_3_OM')
    # NOTE(review): the 'decrease_at_half' lr schedule is accepted on the command line but
    # nothing in this loop acts on it -- confirm whether it still needs to be implemented.

    for epoch in range(epoch_start, args.n_epochs):
        print('Epoch %d/%d' % (epoch+1, args.n_epochs), file=sys.stderr)

        # Running statistics over the training batches seen since the last record.
        train_rec_loss, train_kl, train_total_loss, train_total_loss_with_final_kl = [], [], [], []
        train_mean, train_log_var = {'Mean': [], 'Min': [], 'Max': []}, {'Mean': [], 'Min': [], 'Max': []}
        record_i = 1
        kl_lambda = kl_lambda_per_epoch[epoch]

        # NOTE(review): the comparison uses warmup_kl_epochs alone, while the
        # 'linear_up_anneal_kl' ramp only ends after no_kl_epochs + warmup_kl_epochs epochs --
        # confirm which epoch is intended.
        if args.lr_schedule == 'decrease_after_warmup' and args.lambdas_schedule == 'linear_up_anneal_kl' and epoch == args.warmup_kl_epochs: # reduce learning rate after kl warmup
            for optimizer in optimizers:
                for g in optimizer.param_groups:
                    g['lr'] *= 0.1

        elif args.lr_schedule in lr_list_schedules:
            for optimizer in optimizers:
                for g in optimizer.param_groups:
                    g['lr'] = lr_list[epoch]

        start_time = time.time()
        for i, (X, X_vec, y, rot) in enumerate(train_dataloader):
            X = dict_to_device(X, device)
            X_vec = X_vec.float().to(device)

            # NOTE: this is a shortcut that only works because now the l=1 wigner-D matrix is equivalent
            # to the rotation matrix that parametrizes it! But that is not always the case. Change this.
            frame = rot.float().view(-1, 3, 3).to(device)

            for optimizer in optimizers:
                optimizer.zero_grad()
            cgvae.train()  # validation below switches to eval mode, so switch back every step
            _, x_reconst_loss, kl_divergence, _, x_reconst, ((mean, log_var), _, _) = cgvae(X, x_vec=X_vec, frame=frame)
            total_loss = optimizing_step(x_reconst_loss, kl_divergence,
                                            x_lambda, kl_lambda,
                                            optimizers)
            # Same loss, but always weighted with the final KL weight from --lambdas, so that
            # the value is comparable across epochs of an annealed schedule.
            total_loss_with_final_kl = x_lambda * x_reconst_loss + args.lambdas[1] * kl_divergence

            train_total_loss.append(total_loss.item())
            train_total_loss_with_final_kl.append(total_loss_with_final_kl.item())
            train_rec_loss.append(x_reconst_loss.item())
            train_kl.append(kl_divergence.item())
            mean_np = mean.cpu().detach().numpy()
            log_var_np = log_var.cpu().detach().numpy()
            for key, stat_func in zip(['Mean', 'Min', 'Max'], [np.mean, np.min, np.max]):
                train_mean[key].append(stat_func(mean_np, axis=-1))
                train_log_var[key].append(stat_func(log_var_np, axis=-1))

            if i % steps_to_record == (steps_to_record - 1):
                # ---- validation pass over the whole validation loader ----
                valid_rec_loss, valid_kl, valid_total_loss, valid_total_loss_with_final_kl = [], [], [], []
                valid_mean, valid_log_var = {'Mean': [], 'Min': [], 'Max': []}, {'Mean': [], 'Min': [], 'Max': []}

                cgvae.eval()
                # No gradients are needed for validation; no_grad avoids building autograd graphs.
                with torch.no_grad():
                    for X_valid, X_vec_valid, _, rot_valid in valid_dataloader:
                        X_valid = dict_to_device(X_valid, device)
                        X_vec_valid = X_vec_valid.float().to(device)
                        # Same l=1 shortcut for the frame as in the training step above.
                        frame_valid = rot_valid.float().view(-1, 3, 3).to(device)

                        _, rec_loss_valid, kl_div_valid, _, _, ((mean_valid, log_var_valid), _, _) = cgvae(X_valid, x_vec=X_vec_valid, frame=frame_valid)

                        valid_total_loss.append((x_lambda * rec_loss_valid + kl_lambda * kl_div_valid).item())
                        valid_total_loss_with_final_kl.append((x_lambda * rec_loss_valid + args.lambdas[1] * kl_div_valid).item())
                        valid_rec_loss.append(rec_loss_valid.item())
                        valid_kl.append(kl_div_valid.item())
                        mean_valid_np = mean_valid.cpu().detach().numpy()
                        log_var_valid_np = log_var_valid.cpu().detach().numpy()
                        for key, stat_func in zip(['Mean', 'Min', 'Max'], [np.mean, np.min, np.max]):
                            valid_mean[key].append(stat_func(mean_valid_np, axis=-1))
                            valid_log_var[key].append(stat_func(log_var_valid_np, axis=-1))

                end_time = time.time()

                # Averages over the current window (training) and the full validation pass.
                mean_train_rec_loss = np.mean(train_rec_loss)
                mean_train_kl = np.mean(train_kl)
                mean_valid_rec_loss = np.mean(valid_rec_loss)
                mean_valid_kl = np.mean(valid_kl)
                mean_valid_total_loss_with_final_kl = np.mean(valid_total_loss_with_final_kl)

                print('%d/%d' % (record_i, times_per_epoch_to_record), end = ' - ', file=sys.stderr)
                print('TRAIN:: ', end='', file=sys.stderr)
                print('rec loss: %.5f' % mean_train_rec_loss, end=' -- ', file=sys.stderr)
                print('kl-div: %.5f' % mean_train_kl, end=' - ', file=sys.stderr)
                print('total loss: %.5f' % np.mean(train_total_loss), end=' - ', file=sys.stderr)
                print('Loss: %.5f' % np.mean(train_total_loss_with_final_kl), end=' - ', file=sys.stderr)
                print('VALID:: ', end='', file=sys.stderr)
                print('rec loss: %.5f' % mean_valid_rec_loss, end=' - ', file=sys.stderr)
                print('kl-div: %.5f' % mean_valid_kl, end=' - ', file=sys.stderr)
                print('total loss: %.5f' % np.mean(valid_total_loss), end=' - ', file=sys.stderr)
                print('Loss: %.5f' % mean_valid_total_loss_with_final_kl, end=' - ', file=sys.stderr)
                print('Time (s): %.1f' % (end_time - start_time), file=sys.stderr)

                if args.use_tensorboard or args.use_wandb:
                    # All scalars of this record. The former 'train/sf-*' and 'train/sf_rec-*'
                    # tensorboard entries were dropped: their lists were never filled, so
                    # np.hstack raised on them and every tensorboard run crashed here.
                    scalars = {'train/rec_loss': mean_train_rec_loss,
                               'train/kl_div': mean_train_kl,
                               'valid/rec_loss': mean_valid_rec_loss,
                               'valid/kl_div': mean_valid_kl}
                    for key in ['Mean', 'Min', 'Max']:
                        scalars['train/mean-%s' % key] = np.mean(np.hstack(train_mean[key]))
                        scalars['train/log_var-%s' % key] = np.mean(np.hstack(train_log_var[key]))
                        scalars['valid/mean-%s' % key] = np.mean(np.hstack(valid_mean[key]))
                        scalars['valid/log_var-%s' % key] = np.mean(np.hstack(valid_log_var[key]))

                    if args.use_tensorboard:
                        for tag, value in scalars.items():
                            writer.add_scalar(tag, value, global_step=global_record_i)

                    if args.use_wandb:
                        # One call per record: every wandb.log call advances the W&B step, so
                        # logging metric by metric spread one record over many steps.
                        wandb.log(scalars)

                if args.lr_schedule == 'decrease_below_threshold' and mean_train_rec_loss < REC_LOSS_THRESHOLD and not have_decreased_lr:
                    for optimizer in optimizers:
                        for g in optimizer.param_groups:
                            g['lr'] *= 0.1
                    have_decreased_lr = True

                # record best model on validation rec loss
                if args.is_vae:
                    if args.lambdas_schedule in ['constant', 'linear_up_anneal_kl']:
                        # Only after the initial no-KL epochs, and only once the validation KL is
                        # below the respective cut-off. Checkpointing for these three "bests" is
                        # disabled; only the metrics are tracked.
                        if epoch >= args.no_kl_epochs and mean_valid_kl < 0.4:
                            if mean_valid_rec_loss < best_loss_04:
                                best_loss_04 = mean_valid_rec_loss
                                best_kl_04 = mean_valid_kl
                        if epoch >= args.no_kl_epochs and mean_valid_kl < 0.5:
                            if mean_valid_rec_loss < best_loss_05:
                                best_loss_05 = mean_valid_rec_loss
                                best_kl_05 = mean_valid_kl
                        if epoch >= args.no_kl_epochs and mean_valid_kl < 0.6:
                            if mean_valid_rec_loss < best_loss_06:
                                best_loss_06 = mean_valid_rec_loss
                                best_kl_06 = mean_valid_kl
                    else:
                        if mean_valid_rec_loss < best_loss:
                            best_loss = mean_valid_rec_loss
                            best_kl = mean_valid_kl
                            torch.save(cgvae.state_dict(), os.path.join(local_experiment_dir, 'best_model.pt'))

                if mean_valid_rec_loss < lowest_rec_loss:
                    lowest_rec_loss = mean_valid_rec_loss
                    lowest_rec_loss_kl = mean_valid_kl
                    torch.save(cgvae.state_dict(), os.path.join(local_experiment_dir, 'lowest_rec_loss_model.pt'))

                if mean_valid_total_loss_with_final_kl < lowest_total_loss_with_final_kl:
                    lowest_total_loss_with_final_kl = mean_valid_total_loss_with_final_kl
                    lowest_total_loss_with_final_kl_kl = mean_valid_kl
                    torch.save(cgvae.state_dict(), os.path.join(local_experiment_dir, 'lowest_total_loss_with_final_kl_model.pt'))

                record_i += 1
                global_record_i += 1

                # Start a new window. The total-loss lists are reset together with the others;
                # before, they kept accumulating over the epoch, so the printed total loss
                # covered a different window than the printed rec loss and kl-div.
                train_rec_loss, train_kl, train_total_loss, train_total_loss_with_final_kl = [], [], [], []
                train_mean, train_log_var = {'Mean': [], 'Min': [], 'Max': []}, {'Mean': [], 'Min': [], 'Max': []}
                start_time = time.time()

    # record final model (more regularized than reported best model)
    # torch.save(cgvae.state_dict(), os.path.join(local_experiment_dir, 'final_model.pt'))

    # record hyperparameters and final best metrics
    # 'final_*' are the averages of the last validation pass that was run.
    metrics_dict = {'lowest_rec_loss': lowest_rec_loss,
                    'kl_at_lowest_rec_loss': lowest_rec_loss_kl,
                    'best_model_rec_loss_04': best_loss_04,
                    'kl_at_best_model_04': best_kl_04,
                    'best_model_rec_loss_05': best_loss_05,
                    'kl_at_best_model_05': best_kl_05,
                    'best_model_rec_loss_06': best_loss_06,
                    'kl_at_best_model_06': best_kl_06,
                    'final_rec_loss': np.mean(valid_rec_loss),
                    'final_kld': np.mean(valid_kl),
                    'lowest_total_loss_with_final_kl': lowest_total_loss_with_final_kl,
                    'lowest_total_loss_with_final_kl_kl': lowest_total_loss_with_final_kl_kl
                    }

    if args.use_tensorboard:
        # add_hparams only accepts bool/int/float/str/tensor values; anything else (presumably
        # the list-valued hyperparameters and None -- verify against args_to_dict) is logged
        # via its string representation instead of making the call fail.
        tensorboard_hparams = {name: (value if isinstance(value, (bool, int, float, str, torch.Tensor)) else str(value))
                               for name, value in hyps_dict.items()}
        writer.add_hparams(tensorboard_hparams, metrics_dict)
        # Flush pending events and release the event file before the script exits.
        writer.close()

    with open(os.path.join(local_experiment_dir, 'validation_metrics.json'), 'w+') as f:
        json.dump(metrics_dict, f, indent=2)
