#!/usr/bin/env python3

# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.

import os
import gc
import sys
import pickle
import argparse
import keras
import numpy as np
import tensorflow as tf
import neurite as ne
import voxelmorph as vxm
from glob import glob
import tqdm
import subprocess
import psutil
import warnings
from memory_profiler import profile
from tensorflow.keras import backend as k
warnings.filterwarnings('ignore', category=UserWarning)


# Log the interpreter and core library versions up front so that every
# training log records the software stack it was produced with.
for _name, _version in (
    ('Python version:', sys.version),
    ('TensorFlow version:', tf.__version__),
    ('Numpy version:', np.__version__),
    ('Keras version:', keras.__version__),
):
    print(_name, _version)


# Citation text: embedded in the command-line description and printed again
# once training has finished.
ref = (
    'If you find this script useful, please consider citing:\n\n'
    '\tM Hoffmann, B Billot, DN Greve, JE Iglesias, B Fischl, AV Dalca\n'
    '\tSynthMorph: learning contrast-invariant registration without acquired images\n'
    '\tIEEE Transactions on Medical Imaging (TMI), 41 (3), 543-558, 2022\n'
    '\thttps://doi.org/10.1109/TMI.2021.3116879\n'
)


# parse command line
bases = (argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter)
p = argparse.ArgumentParser(
    formatter_class=type('formatter', bases, {}),
    description=f'Train a SynthMorph model on images synthesized from label maps. {ref}',
)


# data organization parameters
# The defaults of --model-dir, --log-dir and --sub-dir reproduce the output
# layout that used to be forced by hard-coded assignments after parsing, so a
# run without arguments writes to the same place while the flags work again.
p.add_argument('--label-dir', nargs='+', help='path or glob pattern pointing to input label maps')
p.add_argument('--model-dir', default='./sxm/training_results', help='model output directory')
p.add_argument('--log-dir', default='./sxm/training_results/logs/',
               help='optional TensorBoard log directory')
p.add_argument('--sub-dir', default='sxm', help='optional subfolder for logs and model saves')

# generation parameters
p.add_argument('--same-subj', action='store_true', help='generate image pairs from same label map')
p.add_argument('--blur-std', type=float, default=1, help='maximum blurring std. dev.')
p.add_argument('--gamma', type=float, default=0.25, help='std. dev. of gamma')
p.add_argument('--vel-std', type=float, default=0.5, help='std. dev. of SVF')
p.add_argument('--vel-res', type=float, nargs='+', default=[16], help='SVF scale')
p.add_argument('--bias-std', type=float, default=0.3, help='std. dev. of bias field')
p.add_argument('--bias-res', type=float, nargs='+', default=[40], help='bias scale')
p.add_argument('--out-shape', type=int, nargs='+', help='output shape to pad to')
p.add_argument('--out-labels', default='fs_labels.npy', help='labels to optimize, see README')

# training parameters
p.add_argument('--gpu', type=str, default='0', help='ID of GPU to use')
p.add_argument('--epochs', type=int, default=1500, help='training epochs')
p.add_argument('--batch-size', type=int, default=1, help='batch size')
p.add_argument('--init-weights', help='optional weights file to initialize with')
p.add_argument('--save-freq', type=int, default=20, help='epochs between model saves')
p.add_argument('--reg-param', type=float, default=1., help='regularization weight')
p.add_argument('--lr', type=float, default=1e-4, help='learning rate')
p.add_argument('--init-epoch', type=int, default=0, help='initial epoch number')
p.add_argument('--verbose', type=int, default=1, help='0 silent, 1 bar, 2 line/epoch')

# network architecture parameters
p.add_argument('--int-steps', type=int, default=5, help='number of integration steps')
p.add_argument('--enc', type=int, nargs='+', default=[64] * 4, help='U-Net encoder filters')
# p.add_argument('--enc', type=int, nargs='+', default=[16, 32, 32, 32], help='U-Net encoder filters')
p.add_argument('--dec', type=int, nargs='+', default=[64] * 6, help='U-Net decoder filters')
# p.add_argument('--dec', type=int, nargs='+', default=[32, 32, 32, 32, 32, 16, 16], help='U-Net decoder filters')

# loss selection parameters
# The formatter appends "(default: ...)" on its own, so the help strings do not
# repeat it; `choices` rejects an unsupported loss at parse time.
p.add_argument('--image-loss', default='ncc', choices=('ncc', 'dice', 'dicencc'),
               help='image reconstruction loss')
p.add_argument('--lambda', type=float, dest='lambda_weight', default=0.01,
               help='weight of gradient')

arg = p.parse_args()


# TensorFlow handling
device, nb_devices = vxm.tf.utils.setup_device(arg.gpu)
print('DEVICE:', device)
assert np.mod(arg.batch_size, nb_devices) == 0, \
    f'batch size {arg.batch_size} not a multiple of the number of GPUs {nb_devices}'
assert tf.__version__.startswith('2'), f'TensorFlow version {tf.__version__} is not 2 or later'

# Prepare directories: models and logs each go into the optional subfolder.
if arg.sub_dir:
    arg.model_dir = os.path.join(arg.model_dir, arg.sub_dir)
os.makedirs(arg.model_dir, exist_ok=True)

if arg.log_dir:
    if arg.sub_dir:
        arg.log_dir = os.path.join(arg.log_dir, arg.sub_dir)
    os.makedirs(arg.log_dir, exist_ok=True)


# Report whether TensorFlow can see at least one physical GPU.
if tf.config.list_physical_devices('GPU'):
    print("TensorFlow is running on GPU")
else:
    print("TensorFlow is running on CPU")

# Ask for confirmation only when attached to an interactive terminal, so that
# batch jobs or runs with closed/redirected stdin neither hang nor hit EOFError.
if sys.stdin is not None and sys.stdin.isatty():
    input("Press Enter to continue...")


def _payload_nbytes(array):
    """Return the number of data bytes held by a NumPy array or TensorFlow tensor."""
    nbytes = getattr(array, 'nbytes', None)
    if nbytes is not None:
        return int(nbytes)
    # No `nbytes` attribute: derive the size from element count and the
    # dtype's per-element size (`tf.DType.size`).
    return int(np.prod(tuple(array.shape), dtype=np.int64)) * array.dtype.size


@profile
def generatelabelmaps(in_shape, num_label):
    """
    Draw random label maps as the voxel-wise argmax over Perlin noise channels.

    Parameters:
        in_shape: Shape of the label maps to draw, without the label axis. The
            noise is drawn with shape (*in_shape, num_label).
        num_label: Number of noise channels. Label values are channel indices,
            so they range from 0 to num_label - 1. Must not exceed 256 because
            the result is stored as uint8.

    Returns:
        NumPy uint8 array holding, for every position, the index of the noise
        channel with the largest value along the last axis.

    Raises:
        ValueError: If num_label is larger than 256, which would silently wrap
            around in the uint8 output.
    """
    if num_label > 256:
        raise ValueError(f'cannot store {num_label} labels in uint8 label maps (maximum is 256)')

    im = ne.utils.augment.draw_perlin(
        out_shape=(*in_shape, num_label),  # e.g. (1, 160, 192, 224, 36)
        scales=(32, 64), max_std=1,
    )
    # sys.getsizeof is shallow and does not include tensor data, so report the
    # actual payload sizes instead.
    print('im: ', _payload_nbytes(im), 'bytes')

    lab = tf.argmax(im, axis=-1)
    print('lab: ', _payload_nbytes(lab), 'bytes')

    # Release the noise volume, by far the largest object here, before
    # allocating the output array.
    del im
    gc.collect()

    # Convert exactly once; every extra conversion allocates another full copy.
    labels = np.array(lab, dtype='uint8')
    print('lab np uint8: ', labels.nbytes, 'bytes')

    return labels


# @profile
def customsynthmorph(in_shape, num_label, batch_size=1, same_subj=False, flip=True):
    """
    Endless generator of random label-map pairs for SynthMorph registration.

    A fresh set of label maps is drawn with `generatelabelmaps` for every batch
    and source/target maps are sampled from that set with replacement.

    Parameters:
        in_shape: Shape passed on to `generatelabelmaps`. The first entry is
            treated as the number of maps to draw and is stripped to obtain the
            spatial volume shape.
        num_label: Number of labels passed on to `generatelabelmaps`.
        batch_size: Batch size. Default is 1.
        same_subj: Whether the same label map is returned as the source and target for further
            augmentation. Default is False.
        flip: Whether axes are flipped randomly. Default is True.

    Yields:
        Tuple `([src, trg], [void, void])` where `src` and `trg` are label maps
        with a trailing singleton feature axis and `void` is an all-zero
        float32 array of shape (batch_size, *vol_shape, num_dim).

    Example shapes for in_shape=(1, 160, 192, 224) and batch_size=1:
        vol_shape =====>  (160, 192, 224)
        num_dims =====>  3
        void shape =====>  (1, 160, 192, 224, 3)
        label_maps =====>  1 (160, 192, 224)
        ind =====>  [0 0]
        x =====>  2 (160, 192, 224)
        x =====>  2 (160, 192, 224, 1)
        src =====>  1 (160, 192, 224, 1)
        trg =====>  1 (160, 192, 224, 1)

    NOTE(review): indices are sampled with replacement from the maps of a single
    draw, so source and target can be the same map even when same_subj is False
    (always, if only one map is drawn per batch). Confirm this is intended.
    """
    vol_shape = in_shape[1:]
    num_dim = len(vol_shape)

    print("Starting...")
    # NOTE(review): this runs on the first `next()` call, i.e. possibly after the
    # model has been built. Kept as is; confirm that resetting Keras state at
    # that point is intended.
    keras.backend.clear_session()

    # "True" moved image and warp, that will be ignored by SynthMorph losses.
    void = np.zeros((batch_size, *vol_shape, num_dim), dtype='float32')

    rand = np.random.default_rng()
    prop = dict(replace=False, shuffle=False)

    print('vol: ', sys.getsizeof(vol_shape), 'bytes')
    print('num_dim: ', sys.getsizeof(num_dim), 'bytes')
    print('void shape: ', void.nbytes, 'bytes')

    while True:
        label_maps = generatelabelmaps(in_shape, num_label)
        ind = rand.integers(len(label_maps), size=2 * batch_size)

        x = [label_maps[i] for i in ind]
        if same_subj:
            # Reuse the first half as both source and target.
            x = x[:batch_size] * 2
        x = np.stack(x)[..., None]

        if flip:
            # Pick a random subset of spatial axes; +1 skips the batch axis.
            axes = rand.choice(num_dim, size=rand.integers(num_dim + 1), **prop)
            x = np.flip(x, axis=tuple(int(ax) + 1 for ax in axes))

        src = x[:batch_size, ...]
        trg = x[batch_size:, ...]

        # `x`, `src` and `trg` are views, for which sys.getsizeof only reports
        # the array header; `.nbytes` gives the size of the data they span.
        print('label_maps: ', label_maps.nbytes, 'bytes')
        print('ind: ', ind.nbytes, 'bytes')
        print('x: ', x.nbytes, 'bytes')
        print('src: ', src.nbytes, 'bytes')
        print('trg: ', trg.nbytes, 'bytes')

        # Drop local references before yielding; `src`/`trg` keep their data alive.
        del label_maps, x

        gc.collect()
        yield [src, trg], [void] * 2




# labels and label maps
# Label maps are synthesized on the fly, so the input labels are simply 0..35.
labels_in = list(range(36))

# Spatial shape of the synthesized label maps (no batch axis).
in_shape = (160, 192, 224)


# The generator receives the shape with a leading batch axis, e.g. (1, 160, 192, 224).
gen = customsynthmorph(
    in_shape=(arg.batch_size,) + in_shape,
    num_label=len(labels_in),
    batch_size=arg.batch_size,
    same_subj=arg.same_subj,
    flip=True,
)


# Output labels. Previously `arg.out_labels` was overwritten with a hard-coded
# absolute path to a .txt file, which made --out-labels dead and always ended
# in the fallback branch. Now every input label is optimized unless the user
# explicitly passes a label file (a value equal to the parser default counts
# as "not given").
if arg.out_labels == p.get_default('out_labels'):
    labels_out = labels_in
elif arg.out_labels.endswith('.npy'):
    labels_out = sorted(x for x in np.load(arg.out_labels) if x in labels_in)
elif arg.out_labels.endswith('.pickle'):
    # NOTE(review): unpickling executes arbitrary code; only use trusted files.
    with open(arg.out_labels, 'rb') as f:
        labels_out = {k: v for k, v in pickle.load(f).items() if k in labels_in}
else:
    # Unrecognized extension: optimize all input labels.
    labels_out = labels_in


# model configuration
# Keyword arguments for the two label-to-image generation models.
gen_args = dict(
    in_shape=in_shape,
    out_shape=arg.out_shape,
    in_label_list=labels_in,
    out_label_list=labels_out,
    warp_std=arg.vel_std,
    warp_res=arg.vel_res,
    blur_std=arg.blur_std,
    bias_std=arg.bias_std,
    bias_res=arg.bias_res,
    gamma_std=arg.gamma,
)

# Keyword arguments for the registration network.
reg_args = dict(
    int_steps=arg.int_steps,
    int_resolution=2,
    svf_resolution=2,
    nb_unet_features=(arg.enc, arg.dec),
)



# build model
# Reject an unsupported loss before any graph construction takes place.
if arg.image_loss not in ('ncc', 'dice', 'dicencc'):
    raise ValueError('Image loss should be "dice", "ncc" or "dicencc" (default: "ncc"), but found "%s"' % arg.image_loss)

strategy = 'MirroredStrategy' if nb_devices > 1 else 'get_strategy'
with getattr(tf.distribute, strategy)().scope():

    # generation: two independent label-to-image models, one per input
    gen_model_1 = ne.models.labels_to_image(**gen_args, id=0)
    gen_model_2 = ne.models.labels_to_image(**gen_args, id=1)
    ima_1, map_1 = gen_model_1.outputs
    ima_2, map_2 = gen_model_2.outputs

    # registration
    inputs = gen_model_1.inputs + gen_model_2.inputs
    reg_args['inshape'] = ima_1.shape[1:-1]
    reg_args['input_model'] = tf.keras.Model(inputs, outputs=(ima_1, ima_2))
    model = vxm.networks.VxmDense(**reg_args)
    flow = model.references.pos_flow

    # Warp the first label map and the first image with the predicted flow.
    # The two layers need distinct names: in 'dicencc' mode both become part
    # of the model, where a shared name would collide.
    pred_map1 = vxm.layers.SpatialTransformer(interp_method='linear', name='pred')([map_1, flow])
    pred_ima1 = vxm.layers.SpatialTransformer(interp_method='linear', name='pred_ima')([ima_1, flow])

    # Constant offset added to image losses; presumably shifts them to be
    # non-negative -- TODO confirm against the vxm loss definitions.
    const = tf.ones(shape=arg.batch_size // nb_devices)

    if arg.image_loss == 'ncc':
        terms = [vxm.losses.NCC().loss(ima_2, pred_ima1) + const]
        weights = [1, arg.lambda_weight]
    elif arg.image_loss == 'dice':
        terms = [vxm.losses.Dice().loss(map_2, pred_map1) + const]
        weights = [1, arg.lambda_weight]
    else:  # 'dicencc'
        terms = [
            vxm.losses.NCC().loss(ima_2, pred_ima1),
            vxm.losses.Dice().loss(map_2, pred_map1) + const,
        ]
        weights = [1, 1, arg.lambda_weight]

    # The smoothness term always comes last and pairs with the last weight.
    terms.append(vxm.losses.Grad('l2', loss_mult=arg.reg_param).loss(None, flow))

    # `loss_weights` in compile() only scales losses attached to model outputs
    # through `loss=`; terms registered via add_loss() are summed as they are.
    # Apply the weights here so that --lambda actually takes effect. The
    # effective gradient weight is therefore reg_param * lambda.
    for weight, term in zip(weights, terms):
        model.add_loss(weight * term)

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=arg.lr))

    model.summary()



# callbacks
steps_per_epoch = 100

# Checkpoint file pattern: `{epoch:05d}` is left as a placeholder to be filled
# in later, the image loss name is baked in now.
save_name = os.path.join(arg.model_dir, '{{epoch:05d}}_{}.h5'.format(arg.image_loss))

# The save frequency is passed as a step count: steps per epoch times the
# number of epochs between saves.
_steps_between_saves = steps_per_epoch * arg.save_freq
save = tf.keras.callbacks.ModelCheckpoint(save_name, save_freq=_steps_between_saves)

# class MemoryCallback(tf.keras.callbacks.Callback):
#     def on_epoch_end(self, epoch, logs=None):
#         # GPU memory
#         result = subprocess.check_output(['nvidia-smi', '--format=csv', '--query-gpu=memory.used,memory.free,utilization.gpu'])
#         print('\nGPU Memory and Utilisation:\n', result.decode('utf-8'))

#         # CPU memory
#         memory_info = psutil.virtual_memory()
#         print('\nCPU Memory: used = {} MB, free = {} MB'.format(memory_info.used // 1024**2, memory_info.available // 1024**2))

#         # CPU utilization
#         cpu_utilization = psutil.cpu_percent(interval=1)
#         print('\nCPU Utilization: {}%'.format(cpu_utilization))

class ClearMemory(tf.keras.callbacks.Callback):
    """Keras callback that releases memory at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Run the Python garbage collector, then clear the Keras backend session."""
        # Collect first, reset the Keras session second.
        for release in (gc.collect, k.clear_session):
            release()
# Checkpointing first, memory cleanup second.
callbacks = [save]
callbacks.append(ClearMemory())
# NOTE: TensorBoard logging is currently disabled, so no callback consumes
# arg.log_dir here.
print(save_name)


# initialize and fit
if arg.init_weights:
    model.load_weights(arg.init_weights)

# Save the starting state under the initial epoch number.
_initial_checkpoint = save_name.format(epoch=arg.init_epoch)
model.save(_initial_checkpoint)

_fit_options = dict(
    initial_epoch=arg.init_epoch,
    epochs=arg.epochs,
    callbacks=callbacks,
    steps_per_epoch=steps_per_epoch,
    verbose=arg.verbose,
)
model.fit(gen, **_fit_options)
print(f'\nThank you for using SynthMorph! {ref}')



