import os
import importlib
import argparse
import sys
import pathlib
import pickle
import numpy as np
from time import strftime
from shutil import copyfile
import gzip

import tensorflow as tf
import tensorflow.contrib.eager as tfe

import utilities
from utilities import log

from utilities_tf import load_batch_gcnn


def load_batch_tf(x):
    """
    Wraps ``load_batch_gcnn`` in ``tf.py_func`` so it can be used inside a
    ``tf.data`` pipeline (this script passes it to ``Dataset.map``).

    Parameters
    ----------
    x : tf.Tensor
        Forwarded unchanged to ``load_batch_gcnn``. In this script it is a
        batch of sample file paths.

    Returns
    -------
    list of tf.Tensor
        Seven tensors with dtypes float32, int32, float32, float32, int32,
        int32, int32, in that order.
    """
    # One dtype per value produced by load_batch_gcnn; downstream code unpacks
    # the outputs as (c, ei, ev, v, n_cs, n_vs, cands).
    output_dtypes = [
        tf.float32,  # c
        tf.int32,    # ei
        tf.float32,  # ev
        tf.float32,  # v
        tf.int32,    # n_cs
        tf.int32,    # n_vs
        tf.int32,    # cands / best_cands
    ]
    return tf.py_func(load_batch_gcnn, [x], output_dtypes)


def pretrain(model, dataloader):
    """
    Pre-normalizes a model (i.e., its PreNormLayer layers) over the given samples.

    After ``model.pre_train_init()``, the dataloader is iterated once per
    round. Each batch is fed to ``model.pre_train`` until it returns a falsy
    value or the data runs out. ``model.pre_train_next()`` is then called,
    and the procedure stops as soon as it returns None.

    Parameters
    ----------
    model : model.BaseModel
        A base model, which may contain some model.PreNormLayer layers.
    dataloader : tf.data.Dataset
        Dataset to use for pre-training the model. It is re-iterated on every
        round. Each element is a 7-tuple whose last field (candidates) is not
        used here.

    Returns
    -------
    int
        Number of PreNormLayer layers processed (i.e. the number of times
        ``pre_train_next`` returned a non-None value).
    """
    model.pre_train_init()
    n_layers = 0
    while True:
        for batch in dataloader:
            # The trailing candidate field is not needed for normalization.
            c, ei, ev, v, n_cs, n_vs, _ = batch
            batched_states = (c, ei, ev, v, n_cs, n_vs)

            # A falsy return stops this pass over the data early (presumably
            # the current layer has gathered enough statistics — confirm in
            # model.BaseModel.pre_train).
            if not model.pre_train(batched_states, tf.convert_to_tensor(True)):
                break

        # None signals that no layer is left to process.
        if model.pre_train_next() is None:
            break
        n_layers += 1

    return n_layers


def process(model, dataloader, top_k, optimizer=None, pos_weight=9):
    """
    Runs one pass over ``dataloader`` and returns the sample-weighted mean loss.

    When ``optimizer`` is given, the model runs in training mode and a
    gradient step is applied after every batch. Otherwise the model is only
    evaluated.

    Parameters
    ----------
    model : model.BaseModel
        Policy model, called as ``model(batched_states, training_flag)``.
    dataloader : iterable
        Yields 7-tuples ``(c, ei, ev, v, n_cs, n_vs, best_cands)``, as produced
        by ``load_batch_tf``.
    top_k : list of int
        Currently unused (top-k accuracy is not computed). Kept so existing
        callers keep working.
    optimizer : tf.train.Optimizer, optional
        If given, gradients of the weighted training loss are applied with it.
    pos_weight : float, optional
        Training-mode loss weight for entries whose label is 1. Entries
        labelled 0 get weight 1. Defaults to 9, the previously hard-coded
        value. Not used in evaluation mode.

    Returns
    -------
    float
        Average of the per-batch losses, each weighted by the number of
        samples in its batch.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no samples (avoids a bare ZeroDivisionError).
    """
    training = optimizer is not None
    mean_loss = 0.0
    n_samples_processed = 0
    for batch in dataloader:
        c, ei, ev, v, n_cs, n_vs, best_cands = batch
        # Replace the per-sample sizes by their totals so the whole batch is
        # fed to the model as a single merged instance (prevent padding).
        batched_states = (c, ei, ev, v,
                          tf.reduce_sum(n_cs, keepdims=True),
                          tf.reduce_sum(n_vs, keepdims=True))
        batch_size = len(n_cs.numpy())

        if training:
            with tf.GradientTape() as tape:
                scores = model(batched_states, tf.convert_to_tensor(True))  # training mode
                # Stack [1 - s, s] and transpose into a two-column tensor used as
                # softmax logits. NOTE(review): assumes the model output is
                # (1, n) with values in [0, 1] — confirm against model.GCNPolicy.
                logits = tf.transpose(tf.concat((1 - scores, scores), axis=0))
                # Up-weight label-1 entries; label-0 entries keep weight 1.
                weights = tf.cast(best_cands, tf.float32) * (pos_weight - 1) + 1
                loss = tf.losses.sparse_softmax_cross_entropy(
                    labels=best_cands, logits=logits, weights=weights)

            grads = tape.gradient(target=loss, sources=model.variables)
            optimizer.apply_gradients(zip(grads, model.variables))
        else:
            scores = model(batched_states, tf.convert_to_tensor(False))  # eval mode
            logits = tf.transpose(tf.concat((1 - scores, scores), axis=0))
            # The evaluation loss is deliberately unweighted.
            loss = tf.losses.sparse_softmax_cross_entropy(
                labels=best_cands, logits=logits)

        mean_loss += loss.numpy() * batch_size
        n_samples_processed += batch_size

    if n_samples_processed == 0:
        raise ValueError("process(): dataloader yielded no samples")
    return mean_loss / n_samples_processed


if __name__ == '__main__':
    # Training entry point: parse CLI arguments, build the tf.data pipelines,
    # pre-normalize the model, then train with learning-rate decay and early
    # stopping, keeping the parameters with the best validation loss.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        choices=['setcover', 'cauctions', 'facilities', 'indset', 'item'],
    )
    parser.add_argument(
        '-m', '--model',
        help='GCNN model to be trained.',
        type=str,
        default='baseline',
    )
    parser.add_argument(
        '-s', '--seed',
        help='Random generator seed.',
        type=utilities.valid_seed,
        default=0,
    )
    parser.add_argument(
        '-g', '--gpu',
        help='CUDA GPU id (-1 for CPU).',
        type=int,
        default=1,
    )
    args = parser.parse_args()

    ### HYPER PARAMETERS ###
    max_epochs = 500
    epoch_size = 60  # training batches per epoch; each epoch draws epoch_size * batch_size samples
    batch_size = 16  # default:32
    pretrain_batch_size = batch_size
    valid_batch_size = batch_size
    lr = 0.001
    patience = 15  # epochs without improvement between learning-rate decays (x0.2)
    early_stopping = 30  # epochs without improvement before training stops
    top_k = [1, 3, 5, 10]

    # Every problem type stores its samples under samples/<problem>. Deriving
    # the path (rather than using a partial lookup table) keeps every argparse
    # choice usable.
    problem_folder = f'samples/{args.problem}'

    running_dir = f"trained_models/{args.problem}/{args.model}/{args.seed}"

    os.makedirs(running_dir, exist_ok=True)

    ### LOG ###
    logfile = os.path.join(running_dir, 'log.txt')

    log(f"max_epochs: {max_epochs}", logfile)
    log(f"epoch_size: {epoch_size}", logfile)
    log(f"batch_size: {batch_size}", logfile)
    log(f"pretrain_batch_size: {pretrain_batch_size}", logfile)
    log(f"valid_batch_size : {valid_batch_size }", logfile)
    log(f"lr: {lr}", logfile)
    log(f"patience : {patience }", logfile)
    log(f"early_stopping : {early_stopping }", logfile)
    log(f"top_k: {top_k}", logfile)
    log(f"problem: {args.problem}", logfile)
    log(f"gpu: {args.gpu}", logfile)
    log(f"seed {args.seed}", logfile)

    ### NUMPY / TENSORFLOW SETUP ###
    if args.gpu == -1:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpu}'
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    tf.enable_eager_execution(config)

    rng = np.random.RandomState(args.seed)
    tf.set_random_seed(rng.randint(np.iinfo(int).max))

    ### SET-UP DATASET ###
    # Sort the glob results: their order is filesystem-dependent, and both the
    # seeded rng.choice below and the pretrain subset depend on it.
    train_files = sorted(str(x) for x in pathlib.Path(f'{problem_folder}/train').glob('sample_*.pkl'))
    valid_files = sorted(str(x) for x in pathlib.Path(f'{problem_folder}/valid').glob('sample_*.pkl'))
    if not train_files or not valid_files:
        raise FileNotFoundError(
            f"no sample_*.pkl files found in {problem_folder}/train "
            f"and/or {problem_folder}/valid")

    valid_data = tf.data.Dataset.from_tensor_slices(valid_files)
    valid_data = valid_data.batch(valid_batch_size)
    valid_data = valid_data.map(load_batch_tf)
    valid_data = valid_data.prefetch(1)

    # Every 10th training file is used for pre-normalization.
    pretrain_files = train_files[::10]
    pretrain_data = tf.data.Dataset.from_tensor_slices(pretrain_files)
    pretrain_data = pretrain_data.batch(pretrain_batch_size)
    pretrain_data = pretrain_data.map(load_batch_tf)
    pretrain_data = pretrain_data.prefetch(1)

    ### MODEL LOADING ###
    # Import models/<model>/model.py; reload in case a different `model`
    # module was imported earlier in this process.
    sys.path.insert(0, os.path.abspath(f'models/{args.model}'))
    import model
    importlib.reload(model)
    model = model.GCNPolicy()
    del sys.path[0]

    ### TRAINING LOOP ###
    # The optimizer reads `lr` through the lambda on each step, so the
    # `lr *= 0.2` decay below takes effect without rebuilding it.
    optimizer = tf.train.AdamOptimizer(learning_rate=lambda: lr)  # dynamic LR trick
    best_loss = np.inf
    # Initialise up front: the no-improvement branch increments it, and it
    # must exist even if the first validation loss does not improve (e.g. NaN).
    plateau_count = 0
    for epoch in range(max_epochs + 1):
        log(f"EPOCH {epoch}...", logfile)

        # TRAIN
        if epoch == 0:
            # Epoch 0 only pre-normalizes the model, then compiles its call.
            n = pretrain(model=model, dataloader=pretrain_data)
            log(f"PRETRAINED {n} LAYERS", logfile)
            # model compilation
            model.call = tfe.defun(model.call, input_signature=model.input_signature)
        else:
            # bugfix: tensorflow's shuffle() seems broken, so resample with numpy
            epoch_train_files = rng.choice(train_files, epoch_size * batch_size, replace=True)
            train_data = tf.data.Dataset.from_tensor_slices(epoch_train_files)
            train_data = train_data.batch(batch_size)
            train_data = train_data.map(load_batch_tf)
            train_loss = process(model, train_data, top_k, optimizer)
            log(f"TRAIN LOSS: {train_loss:0.3f} ", logfile)

        # TEST
        valid_loss = process(model, valid_data, top_k, None)
        log(f"VALID LOSS: {valid_loss:0.3f} ", logfile)

        if valid_loss < best_loss:
            plateau_count = 0
            best_loss = valid_loss
            model.save_state(os.path.join(running_dir, 'best_params.pkl'))
            log(f"  best model so far", logfile)
        else:
            plateau_count += 1
            if plateau_count % early_stopping == 0:
                log(f"  {plateau_count} epochs without improvement, early stopping", logfile)
                break
            if plateau_count % patience == 0:
                lr *= 0.2
                log(f"  {plateau_count} epochs without improvement, decreasing learning rate to {lr}", logfile)

    # Re-evaluate the best checkpoint so the final log reflects saved weights.
    model.restore_state(os.path.join(running_dir, 'best_params.pkl'))
    valid_loss = process(model, valid_data, top_k, None)
    log(f"BEST VALID LOSS: {valid_loss:0.3f} ", logfile)

