import os
# os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import functools
import pickle
import time
import datetime
import json
import numpy as np
import jax
from jax import numpy as jnp
import flax
from flax import serialization
from objax.jaxboard import SummaryWriter, Summary
import tensorflow as tf

from ProtLig_GPCRclassA.metrics import log_confusion_matrix

# from ProtLig_GPCRclassA.amino_GNN.loader import AminoDatasetPrecompute, AminoLoader, AminoCollatePrecompute
from ProtLig_GPCRclassA.amino_GNN.dataset import AminoDatasetPrecompute
from ProtLig_GPCRclassA.amino_GNN.collate import AminoCollatePrecompute
from ProtLig_GPCRclassA.amino_GNN.element import AminoElementPrecompute
from ProtLig_GPCRclassA.amino_GNN.loader import AminoLoader, get_tf_loader

from ProtLig_GPCRclassA.utils import _serialize_hparam, prefetch_to_device, get_last_datetime, get_last_state

from ProtLig_GPCRclassA.amino_GNN.make_loss_func import make_loss_func

from ProtLig_GPCRclassA.amino_GNN.base.make_init import make_init_model, get_tf_specs
from ProtLig_GPCRclassA.amino_GNN.make_create_optimizer import make_create_optimizer
from ProtLig_GPCRclassA.amino_GNN.make_compute_metrics import log_metrics_from_epoch, make_compute_metrics
from ProtLig_GPCRclassA.amino_GNN.base.train.make_train_epoch import make_train_epoch
from ProtLig_GPCRclassA.amino_GNN.base.train.make_valid_epoch import make_valid_epoch

from ProtLig_GPCRclassA.amino_GNN.make_regularization_loss import make_regularization_loss

from ProtLig_GPCRclassA.amino_GNN.select_model import get_model_by_name

import logging

# jax.config.update('jax_platform_name', 'cpu')
# print(jnp.ones(3).device_buffer.device())


def _make_amino_dataset(hparams, model, data_csv, weight_col, padding_n_node, is_train):
    """Build an ``AminoDatasetPrecompute`` with the settings shared by every split.

    Args:
        hparams: run settings dict (column names, feature/graph options).
        model: model instance; its ``atom_features`` / ``bond_features`` are
            forwarded to the dataset.
        data_csv: path of the CSV for this split.
        weight_col: sample-weight column name for this split (may be None).
        padding_n_node: node padding of the split; the line-graph size limit is
            ``LINE_GRAPH_MAX_SIZE_MULTIPLIER * padding_n_node``.
        is_train: when True, ``CLASS_ALPHA`` is forwarded as ``class_alpha``;
            validation splits do not pass it (dataset default applies).

    Returns:
        The constructed ``AminoDatasetPrecompute``.
    """
    kwargs = dict(data_csv = data_csv,
                  mols_csv = os.path.join(hparams['DATA_PARENT_DIR'], hparams['MOLS_CSV']),
                  mol_id_col = hparams['MOL_ID_COL'],
                  mol_col = hparams['MOL_COL'],
                  seq_id_col = hparams['SEQ_ID_COL'], # Gene is only sequence id.
                  label_col = hparams['LABEL_COL'],
                  weight_col = weight_col,
                  atom_features = model.atom_features,
                  bond_features = model.bond_features,
                  line_graph_max_size = hparams['LINE_GRAPH_MAX_SIZE_MULTIPLIER'] * padding_n_node,
                  self_loops = hparams['SELF_LOOPS'],
                  line_graph = hparams['LINE_GRAPH'],
                  auxiliary_label_cols = hparams['AUXILIARY_LABEL_COLS'],
                  auxiliary_weight_cols = hparams['AUXILIARY_WEIGHT_COLS'],
                  mol_global_cols = hparams['MOL_GLOBAL_COLS'],
                  seq_global_cols = hparams['SEQ_GLOBAL_COLS'])
    if is_train:
        kwargs['class_alpha'] = hparams['CLASS_ALPHA']
    return AminoDatasetPrecompute(**kwargs)


def _make_amino_loader(dataset, batch_size, collate, n_partitions):
    """Wrap ``dataset`` in a shuffling, drop-last ``AminoLoader``.

    Each loader gets a fresh collate function from ``collate.make_collate()``
    and a shuffle key seeded from the current wall-clock second.
    """
    return AminoLoader(dataset,
                       batch_size = batch_size,
                       collate_fn = collate.make_collate(),
                       shuffle = True,
                       rng = jax.random.PRNGKey(int(time.time())),
                       drop_last = True,
                       n_partitions = n_partitions)


def _make_tf_loader(dataset, batch_size, hparams, id_mapping_table, seq_embedding_lookup):
    """Build a shuffling, drop-last tf loader for ``dataset`` via ``get_tf_loader``."""
    return get_tf_loader(dataset,
                         batch_size = batch_size,
                         use_cache = hparams['CACHE'],
                         shuffle = True,
                         shuffle_buffer_size = hparams['SHUFFLE_BUFFER_SIZE'],
                         drop_last = True,
                         id_mapping_table = id_mapping_table,
                         seq_embedding_lookup = seq_embedding_lookup)


def main_train(hparams):
    """Build data pipelines and a model from ``hparams``, then train and validate it.

    Each epoch runs one training pass and one validation pass.
    Scalar summaries are written to ``train/`` and ``validation/``.
    Every ``SAVE_FREQUENCY`` epochs, the serialized state is pickled to
    ``ckpts/state_e<epoch>.pkl``. Training runs from ``state.epoch`` (taken
    from the checkpoint when ``RESTORE_FILE`` is set) up to ``N_EPOCH``
    inclusive.

    When ``SIZE_CUT_DIRNAME`` is set, the run switches to the "big" loaders
    from epoch ``BIG_SWITCH_EPOCH`` on. These loaders use the train/valid
    datasets combined (via ``+``) with the big splits, the ``BIG_*`` paddings
    and ``BIG_BATCH_SIZE``.

    All outputs go to a fresh directory:
    ``<LOGGING_PARENT_DIR>/<DATACASE>/<model class>[/<SLURM_ARRAY_JOB_ID>]/<timestamp>``.

    Args:
        hparams: dict of run settings, mutated in place. ``BIG_SWITCH_EPOCH``
            is set to ``N_EPOCH + 1`` when there is no size cut. Edge paddings
            are enlarged by the node paddings when ``SELF_LOOPS`` is set. The
            mutated dict is what gets written to ``hparams_logs.json``.

    Returns:
        None.

    Raises:
        NotImplementedError: if ``N_PARTITIONS > 0`` (the pmap path is not
            validated yet).
        ValueError: if ``SELF_LOOPS`` is combined with non-empty ``BOND_FEATURES``.
    """
    # Reject unsupported configurations before creating directories or
    # mutating ``hparams``.
    if hparams['N_PARTITIONS'] > 0:
        raise NotImplementedError('pmap needs to be checked...')
    if hparams['SELF_LOOPS'] and len(hparams['BOND_FEATURES']) > 0:
        raise ValueError('Can not have both bond features and self_loops.')

    use_big = hparams['SIZE_CUT_DIRNAME'] is not None
    if not use_big:
        # Push the switch past the last epoch so the big loaders are never used.
        hparams.update({'BIG_SWITCH_EPOCH' : hparams['N_EPOCH'] + 1})

    csv_dir = os.path.join(hparams['DATA_PARENT_DIR'], hparams['DATACASE'])
    if use_big:
        csv_dir = os.path.join(csv_dir, hparams['SIZE_CUT_DIRNAME'])
    data_train_csv = os.path.join(csv_dir, hparams['TRAIN_CSV_NAME'])
    data_valid_csv = os.path.join(csv_dir, hparams['VALID_CSV_NAME'])
    big_data_train_csv = os.path.join(csv_dir, hparams['BIG_TRAIN_CSV_NAME'])
    big_data_valid_csv = os.path.join(csv_dir, hparams['BIG_VALID_CSV_NAME'])

    print('\n\n----->\tWARNING: Runnung tests....\n\n')
    logdir = os.path.join(hparams['LOGGING_PARENT_DIR'], hparams['DATACASE'])
    restore_file = hparams['RESTORE_FILE']

    model_class = get_model_by_name(hparams['MODEL_NAME'])
    model = model_class(atom_features = hparams['ATOM_FEATURES'],
                        bond_features = hparams['BOND_FEATURES'],
                        out_features = hparams['OUT_FEATURES'])
    logdir = os.path.join(logdir, model.__class__.__name__)
    if hparams['SLURM_JOB_ARRAY']:
        logdir = os.path.join(logdir, hparams['SLURM_ARRAY_JOB_ID'])

    if hparams['SELF_LOOPS']:
        # Self-loops add one edge per node, so edge padding grows by the node padding.
        hparams['PADDING_N_EDGE'] = hparams['PADDING_N_EDGE'] + hparams['PADDING_N_NODE']
        if use_big:
            hparams['BIG_PADDING_N_EDGE'] = hparams['BIG_PADDING_N_EDGE'] + hparams['BIG_PADDING_N_NODE']

    logger = logging.getLogger('main_train')
    logger.setLevel(logging.INFO)
    _datetime = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    logdir = os.path.join(logdir, _datetime)
    os.makedirs(logdir)
    os.mkdir(os.path.join(logdir, 'ckpts'))
    handlers = [logging.FileHandler(os.path.join(logdir, 'run.log')),
                logging.StreamHandler(sys.stdout)]
    for handler in handlers:
        logger.addHandler(handler)

    h5file = None
    train_writer = None
    valid_writer = None
    try:
        logger.info('jax_version = {}'.format(jax.__version__))
        logger.info('flax_version = {}'.format(flax.__version__))
        logger.info('from_disk = {}'.format(hparams['PYTABLE_FROM_DISK']))
        logger.info('model_name = {}'.format(hparams['MODEL_NAME']))
        logger.info('loader_output_type = {}'.format(hparams['LOADER_OUTPUT_TYPE']))

        # ---------
        # Datasets:
        # ---------
        import tables
        h5file = tables.open_file(hparams['H5FILE'], mode = 'r', title=hparams['H5FILE_TITLE'])
        h5_table = h5file.root.amino.table

        collate = AminoCollatePrecompute(h5_table,
                                         padding_n_node = hparams['PADDING_N_NODE'],
                                         padding_n_edge = hparams['PADDING_N_EDGE'],
                                         n_partitions = hparams['N_PARTITIONS'],
                                         from_disk = hparams['PYTABLE_FROM_DISK'],
                                         line_graph = hparams['LINE_GRAPH'])
        element = AminoElementPrecompute(bert_table = h5_table,
                                         padding_n_node = hparams['PADDING_N_NODE'],
                                         padding_n_edge = hparams['PADDING_N_EDGE'],
                                         from_disk = hparams['PYTABLE_FROM_DISK'],
                                         seq_lookup = hparams['CACHE_SEQ_LOOKUP'])
        if use_big:
            big_collate = AminoCollatePrecompute(h5_table,
                                                 padding_n_node = hparams['BIG_PADDING_N_NODE'],
                                                 padding_n_edge = hparams['BIG_PADDING_N_EDGE'],
                                                 n_partitions = hparams['N_PARTITIONS'],
                                                 from_disk = hparams['PYTABLE_FROM_DISK'],
                                                 line_graph = hparams['LINE_GRAPH'])
            big_element = AminoElementPrecompute(bert_table = h5_table,
                                                 padding_n_node = hparams['BIG_PADDING_N_NODE'],
                                                 padding_n_edge = hparams['BIG_PADDING_N_EDGE'],
                                                 from_disk = hparams['PYTABLE_FROM_DISK'],
                                                 seq_lookup = hparams['CACHE_SEQ_LOOKUP'])

        if hparams['CACHE_SEQ_LOOKUP']:
            id_mapping_table, seq_embedding_lookup = element.create_seq_embedding_lookups()
        else:
            id_mapping_table, seq_embedding_lookup = None, None

        # Without PYTABLE_FROM_DISK the table is no longer needed from here on;
        # otherwise it stays open for training and is closed in ``finally``.
        if not hparams['PYTABLE_FROM_DISK']:
            h5file.close()
            logger.info('Table closed...: {}'.format(hparams['H5FILE']))

        dataset = _make_amino_dataset(hparams, model, data_train_csv, hparams['WEIGHT_COL'],
                                      hparams['PADDING_N_NODE'], is_train = True)
        logger.info('Train dataset size: {}'.format(len(dataset)))
        loader = _make_amino_loader(dataset, hparams['BATCH_SIZE'], collate, hparams['N_PARTITIONS'])

        valid_dataset = _make_amino_dataset(hparams, model, data_valid_csv, hparams['VALID_WEIGHT_COL'],
                                            hparams['PADDING_N_NODE'], is_train = False)
        logger.info('Valid dataset size: {}'.format(len(valid_dataset)))
        valid_loader = _make_amino_loader(valid_dataset, hparams['BATCH_SIZE'], collate, hparams['N_PARTITIONS'])

        big_loader = None
        big_valid_loader = None
        if use_big:
            _big_dataset = _make_amino_dataset(hparams, model, big_data_train_csv, hparams['WEIGHT_COL'],
                                               hparams['BIG_PADDING_N_NODE'], is_train = True)
            big_dataset = dataset + _big_dataset
            logger.info('Big train dataset size: {}'.format(len(big_dataset)))
            big_loader = _make_amino_loader(big_dataset, hparams['BIG_BATCH_SIZE'], big_collate, hparams['N_PARTITIONS'])

            _big_valid_dataset = _make_amino_dataset(hparams, model, big_data_valid_csv, hparams['VALID_WEIGHT_COL'],
                                                     hparams['BIG_PADDING_N_NODE'], is_train = False)
            big_valid_dataset = valid_dataset + _big_valid_dataset
            logger.info('Big valid dataset size: {}'.format(len(big_valid_dataset)))
            big_valid_loader = _make_amino_loader(big_valid_dataset, hparams['BIG_BATCH_SIZE'], big_collate, hparams['N_PARTITIONS'])

        if hparams['LOADER_OUTPUT_TYPE'] == 'tf':
            # The tf pipeline replaces the AminoLoaders built above.
            dataset.element_preprocess = element.make_element_preprocess()
            valid_dataset.element_preprocess = element.make_element_preprocess()
            loader = _make_tf_loader(dataset, hparams['BATCH_SIZE'], hparams,
                                     id_mapping_table, seq_embedding_lookup)
            valid_loader = _make_tf_loader(valid_dataset, hparams['BATCH_SIZE'], hparams,
                                           id_mapping_table, seq_embedding_lookup)
            if use_big:
                big_dataset.element_preprocess = big_element.make_element_preprocess()
                big_valid_dataset.element_preprocess = big_element.make_element_preprocess()
                big_loader = _make_tf_loader(big_dataset, hparams['BIG_BATCH_SIZE'], hparams,
                                             id_mapping_table, seq_embedding_lookup)
                big_valid_loader = _make_tf_loader(big_valid_dataset, hparams['BIG_BATCH_SIZE'], hparams,
                                                   id_mapping_table, seq_embedding_lookup)

        # ----------------
        # Initializations:
        # ----------------
        key1, _ = jax.random.split(jax.random.PRNGKey(int(time.time())), 2)
        key_params, _key_num_steps, key_num_steps, key_dropout = jax.random.split(key1, 4)

        start = time.time()
        logger.info('Initializing...')
        init_model = make_init_model(model,
                                     batch_size = hparams['BATCH_SIZE'],
                                     n_partitions = hparams['N_PARTITIONS'],
                                     seq_embedding_size = hparams['SEQ_EMBEDDING_SIZE'],
                                     num_node_features = len(hparams['ATOM_FEATURES']),
                                     num_edge_features = len(hparams['BOND_FEATURES']),
                                     self_loops = hparams['SELF_LOOPS'],
                                     line_graph = hparams['LINE_GRAPH'],
                                     seq_max_length = hparams['SEQ_MAX_LENGTH'],
                                     padding_n_node = hparams['PADDING_N_NODE'],
                                     padding_n_edge = hparams['PADDING_N_EDGE'])
        params = init_model(rngs = {'params' : key_params, 'dropout' : key_dropout, 'num_steps' : _key_num_steps})
        end = time.time()
        logger.info('TIME: init_model: {}'.format(end - start))

        # Scheduler transition length, expressed in optimizer steps per epoch.
        transition_steps = hparams['OPTIMIZATION']['TRANSITION_EPOCHS']*(len(dataset)/hparams['BATCH_SIZE'])
        if use_big:
            transition_steps += hparams['OPTIMIZATION']['TRANSITION_EPOCHS']*(len(_big_dataset)/hparams['BIG_BATCH_SIZE'])

        create_optimizer = make_create_optimizer(model,
                                                 option = hparams['OPTIMIZATION']['OPTION'],
                                                 warmup_steps = hparams['OPTIMIZATION']['WARMUP_STEPS'],
                                                 transition_steps = transition_steps)
        init_state, scheduler = create_optimizer(params,
                                                 rngs = {'dropout' : key_dropout, 'num_steps' : key_num_steps},
                                                 learning_rate = hparams['LEARNING_RATE'])

        # Restore params:
        if restore_file is not None:
            logger.info('Restoring parameters from {}'.format(restore_file))
            # NOTE(review): pickle.load can execute arbitrary code; only restore
            # checkpoints from trusted locations.
            with open(restore_file, 'rb') as pklfile:
                bytes_output = pickle.load(pklfile)
            state = serialization.from_bytes(init_state, bytes_output)
            logger.info('Parameters restored...')
        else:
            state = init_state

        reg_loss_func_embed = make_regularization_loss(params_path = ['params/atomic_num_embed/node_embed/embedding',
                                                                      'params/chiral_tag_embed/node_embed/embedding',
                                                                      'params/hybridization_embed/node_embed/embedding',
                                                                      'params/X_proj_non_embeded/kernel',
                                                                      'params/X_proj_non_embeded/bias',
                                                                      'params/bond_type_embed/edge_embed/embedding',
                                                                      'params/E_proj_non_embeded/kernel',
                                                                      'params/E_proj_non_embeded/bias',
                                                                      'params/stereo_embed/edge_embed/embedding',
                                                                      ], alpha = 0.01, option = 'l1')
        reg_loss_func_kernel = make_regularization_loss(params_path = ['kernel',
                                                                       ], alpha = 0.01, option = 'l2')
        def reg_loss_func(params):
            return reg_loss_func_embed(params) + reg_loss_func_kernel(params)

        # N_PARTITIONS > 0 (pmap) was rejected up front, so only the
        # single-device epoch functions are built.
        train_epoch = make_train_epoch(is_weighted = hparams['WEIGHT_COL'] is not None,
                                       loss_option = hparams['LOSS_OPTION'],
                                       init_rngs = state.rngs,
                                       logger = logger,
                                       aux_loss_option = hparams['AUXILIARY_LOSS_OPTION'],
                                       reg_loss_func = reg_loss_func,
                                       loader_output_type = hparams['LOADER_OUTPUT_TYPE'],
                                       num_classes = model.out_features)
        valid_epoch = make_valid_epoch(loss_option = hparams['LOSS_OPTION'],
                                       logger = logger,
                                       aux_loss_option = hparams['AUXILIARY_LOSS_OPTION'],
                                       loader_output_type = hparams['LOADER_OUTPUT_TYPE'],
                                       num_classes = model.out_features)

        # Log hyperparams (after the in-place updates above):
        hparams_logs = {key: _serialize_hparam(value) for key, value in hparams.items()}
        with open(os.path.join(logdir, 'hparams_logs.json'), 'w') as jsonfile:
            json.dump(hparams_logs, jsonfile)

        # Training:
        logger.info('Training...')
        train_writer = SummaryWriter(os.path.join(logdir, 'train'))
        valid_writer = SummaryWriter(os.path.join(logdir, 'validation'))

        epoch = state.epoch
        while epoch <= hparams['N_EPOCH']:
            logger.info('Epoch:  {}'.format(epoch))
            if epoch == hparams['BIG_SWITCH_EPOCH']:
                logger.info('Switching to Big loader...')
            # Big loaders only exist when a size cut is configured.
            on_big = use_big and epoch >= hparams['BIG_SWITCH_EPOCH']

            # TRAIN:
            start = time.time()
            state, batch_metrics = train_epoch(state, big_loader if on_big else loader)
            end = time.time()
            logger.info('TIME: train_epoch: {}'.format(end - start))
            batch_metrics_np = jax.device_get(batch_metrics)
            lr = scheduler(state.step)

            summary = Summary()
            summary.scalar('learning_rate', lr)
            summary = log_metrics_from_epoch(name = 'train',
                                             epoch = epoch,
                                             hparams = hparams,
                                             metrics_np = batch_metrics_np,
                                             logger = logger,
                                             summary = summary)
            summary.scalar('regularization_loss', reg_loss_func(state.params))
            train_writer.write(summary, step = int(epoch))
            train_writer.writer.flush()

            # VALID:
            start = time.time()
            state, valid_metrics = valid_epoch(state, big_valid_loader if on_big else valid_loader)
            end = time.time()
            logger.info('TIME: valid_epoch: {}'.format(end - start))
            valid_metrics_np = jax.device_get(valid_metrics)

            summary = Summary()
            summary = log_metrics_from_epoch(name = 'valid',
                                             epoch = epoch,
                                             hparams = hparams,
                                             metrics_np = valid_metrics_np,
                                             logger = logger,
                                             summary = summary)
            valid_writer.write(summary, step = int(epoch))
            valid_writer.writer.flush()

            # Save current state:
            if epoch % hparams['SAVE_FREQUENCY'] == 0:
                bytes_output = serialization.to_bytes(state)
                with open(os.path.join(logdir, 'ckpts', 'state_e' + str(epoch) + '.pkl'), 'wb') as pklfile:
                    pickle.dump(bytes_output, pklfile)
                logger.info('State {} saved...'.format('state_e' + str(epoch)))

            # Advance the epoch counter stored in the state; it drives the loop.
            state = state.replace(epoch = jax.tree_util.tree_map(lambda x: x + 1, state.epoch))
            epoch = state.epoch

        logger.info('Finished...')
        return None
    finally:
        for writer in (train_writer, valid_writer):
            if writer is not None:
                writer.close()
        # With PYTABLE_FROM_DISK the table is read during training and is
        # only released here.
        if h5file is not None and h5file.isopen:
            h5file.close()
        # The named logger is process-global: detach this run's handlers so a
        # later call neither duplicates output nor writes into this run's log.
        for handler in handlers:
            logger.removeHandler(handler)
            handler.close()