import os
import sys
sys.path.append('./')

import math
import datetime
import itertools
import numpy as np
from typing import Any
from tabulate import tabulate
from functools import partial
from collections import OrderedDict

import jax
import jaxlib
import flax
import optax
import jax.numpy as jnp
import tensorflow as tf
import tensorflow_datasets as tfds
from flax import jax_utils, serialization
from flax.training import checkpoints, common_utils, train_state
from flax.training import dynamic_scale as dynamic_scale_lib
from tensorflow.io.gfile import GFile

from scripts import defaults
from src import input_pipeline
from src.resnet import FlaxResNet
from src.metrics import evaluate_acc, evaluate_nll


def launch(config, print_fn):
    """Fine-tune a pre-trained FlaxResNet and periodically evaluate it.

    The training objective (see ``loss_fn`` inside ``step_trn``) is the sum of
      * a per-example weighted cross-entropy against (optionally smoothed)
        one-hot labels,
      * a per-example weighted cross-entropy against the softmax predictions
        of the fixed pre-trained checkpoint (knowledge distillation), and
      * an L2-SP penalty pulling the feature extractor towards its
        pre-trained values.
    The per-example weights are drawn once, before training, from blocks of a
    single Dirichlet sample (``config.dir_*`` options) and are keyed by the
    TFDS example id, so they stay fixed for the whole run.

    Every 500 iterations the training metrics are averaged and logged and the
    model is evaluated on the validation split; whenever validation accuracy
    improves, the weights are optionally written to
    ``<config.save>/best_acc.ckpt`` and the model is evaluated on the test
    split.

    Args:
        config: parsed command-line namespace (see ``main``).
        print_fn: callable taking a single string, used for all logging.
    """
    # ``import jax`` does not reliably expose ``jax.flatten_util``; a
    # ``from`` import also avoids rebinding the name ``jax`` locally.
    from jax.flatten_util import ravel_pytree

    local_device_count = jax.local_device_count()
    # host batches are reshaped to (local_device_count, per_device_batch, ...)
    shard_shape = (local_device_count, -1)

    # setup mixed precision training if specified: float16 with dynamic loss
    # scaling on GPU, bfloat16 (no loss scaling) on TPU, float32 otherwise
    platform = jax.local_devices()[0].platform
    if config.mixed_precision and platform == 'gpu':
        dynamic_scale = dynamic_scale_lib.DynamicScale()
        model_dtype = jnp.float16
    elif config.mixed_precision and platform == 'tpu':
        dynamic_scale = None
        model_dtype = jnp.bfloat16
    else:
        dynamic_scale = None
        model_dtype = jnp.float32

    # ----------------------------------------------------------------------- #
    # Dataset
    # ----------------------------------------------------------------------- #
    if config.data_name.startswith('cifar100'):
        dataset_builder = tfds.builder('cifar100')
    else:
        dataset_builder = tfds.builder(config.data_name)
    NUM_CLASSES = defaults.NUM_CLASSES[config.data_name]

    trn_split = defaults.TRN_SPLIT[config.data_name]
    num_trn_examples = dataset_builder.info.splits[trn_split].num_examples

    # Collect the TFDS id of every training example (batch size 1). Counting
    # from 1 makes the loop stop after exactly ``num_trn_examples`` examples.
    tfdsid_iter = input_pipeline.create_val_split(
        dataset_builder, 1, split=trn_split, add_tfds_id=True)
    TFDSID_LIST = []
    for batch_idx, batch in enumerate(tfdsid_iter, start=1):
        TFDSID_LIST.append(str(batch['tfdsid'][0]))
        if batch_idx == num_trn_examples:
            break

    # Deduplicate, then sort before the seeded shuffle: the iteration order of
    # a set of strings depends on the per-process hash seed, so shuffling
    # ``list(set(...))`` directly is not reproducible for a fixed seed.
    TFDSID_LIST = sorted(set(TFDSID_LIST))
    np.random.default_rng(config.seed).shuffle(TFDSID_LIST)

    # One Dirichlet sample over 2 * dir_num_blocks components: the first
    # dir_num_blocks entries (concentration 1) are the candidate CE weights,
    # the last dir_num_blocks entries (concentration dir_alpha / dir_trunc)
    # are the candidate KD weights. Pulled to host memory once so the
    # per-example lookups below do not dispatch a device op each.
    blocked_ws = np.asarray(jax.random.dirichlet(
        jax.random.PRNGKey(config.seed), jnp.concatenate([
            jnp.ones((config.dir_num_blocks,)),
            jnp.ones((config.dir_num_blocks,))
            * config.dir_alpha / config.dir_trunc])))

    # Each example picks one CE block and one KD block with its own seeded
    # RNG; the picked weights are scaled by dir_num_blocks.
    TFDSID_TO_W_CE = {}
    TFDSID_TO_W_KD = {}
    for iii, TFDSID in enumerate(TFDSID_LIST):
        i, j = np.random.default_rng(
            config.seed + iii).choice(config.dir_num_blocks, size=2)
        TFDSID_TO_W_CE[TFDSID] = config.dir_num_blocks * float(
            blocked_ws[0 * config.dir_num_blocks + i])
        TFDSID_TO_W_KD[TFDSID] = config.dir_num_blocks * float(
            blocked_ws[1 * config.dir_num_blocks + j])

    def _pad_and_shard(x):
        """Zero-pad ``x`` to ``config.batch_size`` rows, then split the rows
        across local devices."""
        if x.shape[0] < config.batch_size:
            x = np.concatenate([x, np.zeros([
                config.batch_size - x.shape[0], *x.shape[1:]
            ], x.dtype)])
        return x.reshape(shard_shape + x.shape[1:])

    def prepare_tf_data_trn(batch):
        """Attach per-example loss weights and a validity marker, then shard."""
        batch['weight_ce'] = np.array([
            TFDSID_TO_W_CE[str(e)] for e in batch['tfdsid']])
        batch['weight_kd'] = np.array([
            TFDSID_TO_W_KD[str(e)] for e in batch['tfdsid']])
        batch['images'] = batch['images']._numpy()
        batch['labels'] = batch['labels']._numpy()
        # 1 for real examples; padding added by _pad_and_shard gets 0
        batch['marker'] = np.ones_like(batch['labels'])
        del batch['tfdsid']
        return jax.tree_util.tree_map(_pad_and_shard, batch)

    trn_iter = map(prepare_tf_data_trn, input_pipeline.create_trn_split(
        dataset_builder, config.batch_size, split=trn_split, add_tfds_id=True))
    trn_iter = jax_utils.prefetch_to_device(trn_iter, config.prefetch_factor)

    def prepare_tf_data_val(batch):
        """Convert an evaluation batch to numpy, add a validity marker, shard."""
        batch['images'] = batch['images']._numpy()
        batch['labels'] = batch['labels']._numpy()
        # 1 for real examples; padding added by _pad_and_shard gets 0
        batch['marker'] = np.ones_like(batch['labels'])
        return jax.tree_util.tree_map(_pad_and_shard, batch)

    # NOTE(review): the evaluation loops below stop after *_steps_per_epoch
    # batches and reuse the same iterator at every evaluation, which assumes
    # create_val_split yields a repeating stream -- confirm in input_pipeline.
    val_split = defaults.VAL_SPLIT[config.data_name]
    val_steps_per_epoch = math.ceil(
        dataset_builder.info.splits[val_split].num_examples / config.batch_size)
    val_iter = map(prepare_tf_data_val, input_pipeline.create_val_split(
        dataset_builder, config.batch_size, split=val_split))
    val_iter = jax_utils.prefetch_to_device(val_iter, config.prefetch_factor)

    tst_split = defaults.TST_SPLIT[config.data_name]
    tst_steps_per_epoch = math.ceil(
        dataset_builder.info.splits[tst_split].num_examples / config.batch_size)
    tst_iter = map(prepare_tf_data_val, input_pipeline.create_val_split(
        dataset_builder, config.batch_size, split=tst_split))
    tst_iter = jax_utils.prefetch_to_device(tst_iter, config.prefetch_factor)

    # ----------------------------------------------------------------------- #
    # Model
    # ----------------------------------------------------------------------- #
    model = FlaxResNet(
        image_size=224,
        depth=config.resnet_depth,
        widen_factor=config.resnet_width,
        dtype=model_dtype,
        pixel_mean=(0.48145466, 0.45782750, 0.40821073),
        pixel_std=(0.26862954, 0.26130258, 0.27577711))

    def initialize_model(key, model):
        @jax.jit
        def init(*args):
            return model.init(*args)
        return init({'params': key}, jnp.ones((1, 224, 224, 3), model.dtype))
    variables = initialize_model(jax.random.PRNGKey(config.seed), model)

    # run one forward pass with the freshly initialised variables only to log
    # input/output shapes (this consumes the first training batch)
    images = next(trn_iter)['images']
    output = jax.pmap(model.apply)({
        'params': jax_utils.replicate(variables['params']),
        'batch_stats': jax_utils.replicate(variables['batch_stats']),
        'image_stats': jax_utils.replicate(variables['image_stats'])}, images)

    log_str = f'images.shape: {images.shape}, output.shape: {output.shape}'
    print_fn(log_str)

    # load pre-trained checkpoint; the classifier head 'cls' is taken from a
    # separate checkpoint
    ckpt = checkpoints.restore_checkpoint(config.resnet_ext_init, target=None)
    ckpt['params']['cls'] = checkpoints.restore_checkpoint(
        config.resnet_cls_init, target=None)['params']['cls']

    # trainable parameters: the feature extractor starts from the checkpoint,
    # the linear head starts from zeros
    params = {
        'ext': ckpt['params']['ext'],
        'cls': jnp.zeros_like(ckpt['params']['cls'])}
    log_str = 'The number of trainable parameters: {:d}'.format(
        ravel_pytree(params)[0].size)
    print_fn(log_str)

    # ----------------------------------------------------------------------- #
    # Optimization
    # ----------------------------------------------------------------------- #
    def step_trn(state, batch, config, scheduler, dynamic_scale):
        """One data-parallel training step (run under pmap, axis 'batch').

        Returns ``(new_state, metrics, dynamic_scale)``. ``dynamic_scale`` is
        the updated loss scale (``None`` when loss scaling is disabled); the
        caller must feed it back into the next step so the scale adapts.
        """

        def _global_norm(updates):
            return jnp.sqrt(sum([jnp.sum(jnp.square(e))
                                 for e in jax.tree_util.tree_leaves(updates)]))

        def _clip_by_global_norm(updates, global_norm):
            # rescale so the global norm equals optim_global_clipping
            # whenever it is not already below it
            return jax.tree_util.tree_map(
                lambda e: jnp.where(
                    global_norm < config.optim_global_clipping, e,
                    (e / global_norm) * config.optim_global_clipping), updates)

        def loss_fn(params):
            """Total loss and ``(metrics, new_model_state)`` auxiliaries."""

            # features of the current feature extractor; batch_stats are
            # mutable so the updated statistics can be written back
            output, new_model_state = model.apply({
                'params': params['ext'],
                'batch_stats': state.batch_stats,
                'image_stats': state.image_stats}, batch['images'] / 255.0,
                mutable='batch_stats',
                use_running_average=config.resnet_bn_freeze)

            # loss_ce: per-example weighted cross-entropy, label smoothing
            smooth = config.optim_label_smoothing
            target = common_utils.onehot(batch['labels'], NUM_CLASSES)
            target = (1.0 - smooth) * target + \
                smooth * jnp.ones_like(target) / NUM_CLASSES
            source = jax.nn.log_softmax(output @ params['cls'], axis=-1)
            loss_ce = -jnp.sum(target * source, axis=-1)
            loss_ce = jnp.mean(batch['weight_ce'] * loss_ce)

            # loss_kd: per-example weighted cross-entropy against the softmax
            # predictions of the fixed pre-trained checkpoint
            priors = model.apply({
                'params': ckpt['params']['ext'],
                'batch_stats': ckpt['batch_stats'],
                'image_stats': ckpt['image_stats']}, batch['images'] / 255.0)
            target = jax.nn.softmax(priors @ ckpt['params']['cls'], axis=-1)
            loss_kd = -jnp.sum(target * source, axis=-1)
            loss_kd = jnp.mean(batch['weight_kd'] * loss_kd)

            # loss_l2sp: 0.5 * l2sp_decay * ||ext - ext_pretrained||^2
            x = params['ext']
            m = ckpt['params']['ext']
            loss_l2sp = 0.5 * sum([
                jnp.sum((e1 - e2)**2) for e1, e2 in zip(
                    jax.tree_util.tree_leaves(x),
                    jax.tree_util.tree_leaves(m),
                )]) * config.l2sp_decay

            # loss
            loss = loss_ce + loss_kd + loss_l2sp

            # log metrics
            metrics = OrderedDict({
                'loss': loss,
                'loss_ce': loss_ce,
                'loss_kd': loss_kd,
                'loss_l2sp': loss_l2sp})
            return loss, (metrics, new_model_state)

        # compute losses and gradients (gradients averaged across devices)
        if dynamic_scale:
            dynamic_scale, is_fin, aux, grads = dynamic_scale.value_and_grad(
                loss_fn, has_aux=True, axis_name='batch')(state.params)
        else:
            aux, grads = jax.value_and_grad(
                loss_fn, has_aux=True)(state.params)
            grads = jax.lax.pmean(grads, axis_name='batch')

        # weight decay regularization in PyTorch-style (added to gradients)
        grads = jax.tree_util.tree_map(
            lambda g, p: g + config.optim_weight_decay * p,
            grads, state.params)

        # compute norms of weights and gradients (before clipping)
        w_norm = _global_norm(state.params)
        g_norm = _global_norm(grads)
        if config.optim_global_clipping:
            grads = _clip_by_global_norm(grads, g_norm)

        # aux is (loss, (metrics, new_model_state))
        metrics = jax.lax.pmean(aux[1][0], axis_name='batch')
        metrics['w_norm'] = w_norm
        metrics['g_norm'] = g_norm
        metrics['lr'] = scheduler(state.step)

        # update train state
        new_state = state.apply_gradients(grads=grads)
        if not config.resnet_bn_freeze:
            new_state = new_state.replace(batch_stats=aux[1][1]['batch_stats'])
        if dynamic_scale:
            # keep the previous params/opt_state when gradients are not finite
            new_state = new_state.replace(
                opt_state=jax.tree_util.tree_map(
                    partial(jnp.where, is_fin),
                    new_state.opt_state, state.opt_state),
                params=jax.tree_util.tree_map(
                    partial(jnp.where, is_fin),
                    new_state.params, state.params))
            metrics['dyn_scale'] = dynamic_scale.scale
        return new_state, metrics, dynamic_scale

    # define optimizer with scheduler: linear warm-up over the first 10% of
    # optim_ni iterations, then cosine decay over the remaining 90%
    scheduler = optax.join_schedules(
        schedules=[
            optax.linear_schedule(
                init_value       = 0.0,
                end_value        = config.optim_lr,
                transition_steps = math.floor(0.1 * config.optim_ni)),
            optax.cosine_decay_schedule(
                init_value       = config.optim_lr,
                decay_steps      = math.floor(0.9 * config.optim_ni))
        ], boundaries=[
            math.floor(0.1 * config.optim_ni),
        ])
    optimizer = optax.sgd(
        scheduler, momentum=config.optim_momentum,
        accumulator_dtype=model_dtype)

    # build and replicate train state
    class TrainState(train_state.TrainState):
        batch_stats: Any = None
        image_stats: Any = None

    state = TrainState.create(
        apply_fn=model.apply, params=params, tx=optimizer,
        batch_stats=ckpt['batch_stats'], image_stats=ckpt['image_stats'])
    state = jax_utils.replicate(state)

    def apply_fn(images, state):
        """Logits of the current model (features @ linear head)."""
        return model.apply({
            'params': state.params['ext'],
            'batch_stats': state.batch_stats,
            'image_stats': state.image_stats}, images) @ state.params['cls']
    p_apply_fn = jax.pmap(apply_fn)

    def evaluate(data_iter, steps_per_epoch, state):
        """Return ``(accuracy, mean NLL)`` of ``state`` over the next
        ``steps_per_epoch`` batches of ``data_iter``, ignoring padding."""
        acc, nll, cnt = 0.0, 0.0, 0
        for batch_idx, batch in enumerate(data_iter, start=1):
            logits = p_apply_fn(batch['images'] / 255.0, state)
            logits = logits.reshape(-1, NUM_CLASSES)
            labels = batch['labels'].reshape(-1)
            marker = batch['marker'].reshape(-1)
            pre = jax.nn.log_softmax(logits, axis=-1)
            # padded examples (marker == 0) contribute zero
            acc += jnp.sum(jnp.where(marker, evaluate_acc(
                pre, labels, log_input=True, reduction='none'
            ), marker))
            nll += jnp.sum(jnp.where(marker, evaluate_nll(
                pre, labels, log_input=True, reduction='none'
            ), marker))
            cnt += jnp.sum(marker)
            if batch_idx == steps_per_epoch:
                break
        return acc / cnt, nll / cnt

    # run optimization
    best_acc = 0.0
    p_step_trn = jax.pmap(partial(
        step_trn, config=config, scheduler=scheduler), axis_name='batch')
    # average batch_stats across devices after each training step
    sync_stats = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x')

    if dynamic_scale:
        dynamic_scale = jax_utils.replicate(dynamic_scale)

    trn_metric = []
    for iter_idx in itertools.count(start=1):

        # rendezvous: a tiny blocking computation used as a sync point
        jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

        # terminate training
        if iter_idx == config.optim_ni + 1:
            break

        # ------------------------------------------------------------------- #
        # Train
        # ------------------------------------------------------------------- #
        log_str = '[Iter {:7d}/{:7d}] '.format(iter_idx, config.optim_ni)

        batch = next(trn_iter)
        # carry the updated loss scale over to the next step
        state, metrics, dynamic_scale = p_step_trn(
            state, batch, dynamic_scale=dynamic_scale)
        if not config.resnet_bn_freeze:
            state = state.replace(batch_stats=sync_stats(state.batch_stats))
        trn_metric.append(metrics)

        if iter_idx % 500 == 0:
            val_summarized, tst_summarized = {}, {}

            trn_metric = common_utils.get_metrics(trn_metric)
            trn_summarized = {f'trn/{k}': v for k, v in jax.tree_util.tree_map(
                lambda e: e.mean(), trn_metric).items()}
            trn_metric = []

            log_str += ', '.join(
                f'{k} {v:.3e}' for k, v in trn_summarized.items())

            # --------------------------------------------------------------- #
            # Valid
            # --------------------------------------------------------------- #
            val_acc, val_nll = evaluate(val_iter, val_steps_per_epoch, state)
            val_summarized['val/acc'] = val_acc
            val_summarized['val/nll'] = val_nll
            val_summarized['val/best_acc'] = max(
                val_summarized['val/acc'], best_acc)

            log_str += ', '
            log_str += ', '.join(
                f'{k} {v:.3e}' for k, v in val_summarized.items())

            # --------------------------------------------------------------- #
            # Save
            # --------------------------------------------------------------- #
            if best_acc < val_summarized['val/acc']:

                log_str += ' (best_acc: {:.3e} -> {:.3e})'.format(
                    best_acc, val_summarized['val/acc'])
                best_acc = val_summarized['val/acc']

                # un-replicate (take device 0's copy) before serializing
                best_ckpt = {
                    'params': state.params,
                    'batch_stats': state.batch_stats,
                    'image_stats': state.image_stats,}
                best_ckpt = jax.device_get(
                    jax.tree_util.tree_map(lambda x: x[0], best_ckpt))

                if config.save:
                    best_path = os.path.join(config.save, 'best_acc.ckpt')
                    with GFile(best_path, 'wb') as fp:
                        fp.write(serialization.to_bytes(best_ckpt))

                # ----------------------------------------------------------- #
                # Test
                # ----------------------------------------------------------- #
                tst_acc, tst_nll = evaluate(
                    tst_iter, tst_steps_per_epoch, state)
                tst_summarized['tst/acc'] = tst_acc
                tst_summarized['tst/nll'] = tst_nll

                log_str += ', '
                log_str += ', '.join(
                    f'{k} {v:.3e}' for k, v in tst_summarized.items())

            # logging current iteration
            print_fn(log_str)

            # terminate training if loss is nan
            if jnp.isnan(trn_summarized['trn/loss']):
                break


def main():
    """Command-line entry point: parse arguments, set up logging, and train.

    When ``--save`` is given, a new output directory is created and every log
    line is also appended to ``<save>/<timestamp>.log``. An already existing
    ``--save`` path is rejected through ``parser.error`` so that a previous
    run's logs and checkpoints are never overwritten.
    """

    # used to name the log file of this run
    TIME_STAMP = datetime.datetime.now().strftime('%Y%m%d%H%M%S')

    parser = defaults.default_argument_parser()

    parser.add_argument(
        '--l2sp_decay', required=True, type=float,
        help='l2sp reg')

    parser.add_argument(
        '--dir_num_blocks', default=10, type=int,
        help='the number of blocks (default: 10)')
    parser.add_argument(
        '--dir_alpha', default=1.0, type=float,
        help='concentration (default: 1.0)')
    parser.add_argument(
        '--dir_trunc', default=1.0, type=float,
        help='truncation limit (default: 1.0)')

    parser.add_argument(
        '--resnet_ext_init', required=True, type=str,
        help='path to the pre-trained *.ckpt (required)')
    parser.add_argument(
        '--resnet_cls_init', required=True, type=str,
        help='path to the pre-trained *.ckpt (required)')
    parser.add_argument(
        '--resnet_bn_freeze', default=False, type=defaults.str2bool,
        help='freeze batch_stats if specified (default: False)')

    parser.add_argument(
        '--optim_ni', default=5000, type=int,
        help='the number of training iterations (default: 5000)')
    parser.add_argument(
        '--optim_lr', default=0.01, type=float,
        help='base learning rate (default: 0.01)')
    parser.add_argument(
        '--optim_momentum', default=0.9, type=float,
        help='momentum coefficient (default: 0.9)')
    parser.add_argument(
        '--optim_weight_decay', default=0.0, type=float,
        help='weight decay coefficient (default: 0.0)')

    parser.add_argument(
        '--optim_label_smoothing', default=0.0, type=float,
        help='label smoothing regularization (default: 0.0)')
    parser.add_argument(
        '--optim_global_clipping', default=None, type=float,
        help='global norm for the gradient clipping (default: None)')

    parser.add_argument(
        '--save', default=None, type=str,
        help='save the *.log and *.ckpt files if specified (default: None)')
    parser.add_argument(
        '--seed', default=None, type=int,
        help='random seed for training (default: None)')

    parser.add_argument(
        '--mixed_precision', default=False, type=defaults.str2bool,
        help='run mixed precision training if specified (default: False)')

    args = parser.parse_args()

    # without an explicit seed, mix the pid, the wall clock and OS entropy
    if args.seed is None:
        args.seed = (
            os.getpid()
            + int(datetime.datetime.now().strftime('%S%f'))
            + int.from_bytes(os.urandom(2), 'big'))

    if args.save is not None:
        # create the directory in one step (no check-then-create race) and
        # refuse to reuse an existing path
        try:
            os.makedirs(args.save)
        except FileExistsError:
            parser.error(f'already existing args.save = {args.save}')

    def print_fn(s):
        """Print ``s`` with a timestamp prefix; also append it to the log."""
        s = datetime.datetime.now().strftime('[%Y-%m-%d %H:%M:%S] ') + s
        if args.save is not None:
            log_path = os.path.join(args.save, f'{TIME_STAMP}.log')
            with open(log_path, 'a', encoding='utf-8') as fp:
                fp.write(s + '\n')
        print(s, flush=True)

    log_str = tabulate([
        ('sys.platform', sys.platform),
        ('Python', sys.version.replace('\n', '')),
        ('JAX', jax.__version__
            + ' @' + os.path.dirname(jax.__file__)),
        ('jaxlib', jaxlib.__version__
            + ' @' + os.path.dirname(jaxlib.__file__)),
        ('Flax', flax.__version__
            + ' @' + os.path.dirname(flax.__file__)),
        ('Optax', optax.__version__
            + ' @' + os.path.dirname(optax.__file__)),
    ]) + '\n'
    log_str = f'Environments:\n{log_str}'
    print_fn(log_str)

    arg_items = vars(args)
    max_k_len = max(map(len, arg_items.keys()))
    log_str = ''.join(
        f'- args.{k.ljust(max_k_len)} : {v}\n' for k, v in arg_items.items())
    log_str = f'Command line arguments:\n{log_str}'
    print_fn(log_str)

    if jax.local_device_count() > 1:
        log_str = (
            'Multiple local devices are detected:\n'
            f'{jax.local_devices()}\n')
        print_fn(log_str)

    launch(args, print_fn)


# run only when executed as a script, not when imported
if __name__ == '__main__':
    main()
