# Script for pruning and fine-tuning the base models with NNCF.


import os.path as osp
import sys, time
from pathlib import Path
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.utils import to_categorical

from examples.common.paths import configure_paths
from examples.common.sample_config import create_sample_config
from examples.tensorflow.classification.datasets.builder import DatasetBuilder
from examples.tensorflow.common.argparser import get_common_argument_parser
from examples.tensorflow.common.callbacks import get_callbacks
from examples.tensorflow.common.callbacks import get_progress_bar
from examples.tensorflow.common.distributed import get_distribution_strategy
from examples.tensorflow.common.experimental_patcher import patch_if_experimental_quantization
from examples.tensorflow.common.export import export_model
from examples.tensorflow.common.logger import logger
from examples.tensorflow.common.model_loader import get_model
from examples.tensorflow.common.optimizer import build_optimizer
from examples.tensorflow.common.scheduler import build_scheduler
from examples.tensorflow.common.utils import SummaryWriter
from examples.tensorflow.common.utils import create_code_snapshot
from examples.tensorflow.common.utils import get_run_name
from examples.tensorflow.common.utils import get_saving_parameters
from examples.tensorflow.common.utils import print_args
from examples.tensorflow.common.utils import serialize_cli_args
from examples.tensorflow.common.utils import serialize_config
from examples.tensorflow.common.utils import set_seed
from examples.tensorflow.common.utils import write_metrics
from nncf.config.utils import is_accuracy_aware_training
from nncf.tensorflow import create_compression_callbacks
from nncf.tensorflow.helpers.model_creation import create_compressed_model
from nncf.tensorflow.helpers.model_manager import TFModelManager
from nncf.tensorflow.initialization import register_default_init_args
from nncf.tensorflow.utils.state import TFCompressionState
from nncf.tensorflow.utils.state import TFCompressionStateLoader
from nncf import NNCFConfig
from nncf.tensorflow import create_compressed_model, register_default_init_args

from sklearn.metrics import classification_report
# Presumably a multiplier for turning a [0, 1] fraction into a percentage;
# NOTE(review): not referenced anywhere in the visible code.
to_percentage = 100

def _as_class_indices(labels):
    """Decode a label/score array into a 1-D array of class indices.

    A 2-D array (one row of scores or a one-hot vector per sample) is decoded
    with an arg-max over axis 1. A 1-D array is assumed to already hold class
    indices and is returned unchanged.
    """
    labels = np.asarray(labels)
    if labels.ndim == 1:
        return labels
    return np.argmax(labels, axis=1)


def process_results(predictions, ground_true):
    """Print and return accuracy plus macro-averaged precision/recall/F1.

    Args:
        predictions: model outputs; 2-D rows of per-class scores (decoded with
            arg-max) or 1-D predicted class indices.
        ground_true: reference labels; 2-D one-hot/score rows (decoded with
            arg-max) or 1-D class indices.

    Returns:
        dict with keys "accuracy", "precision", "recall", "f1-score" and
        "support", taken from sklearn's ``classification_report`` (the last
        four from its "macro avg" entry). Previously the values were only
        printed; returning them lets callers log or compare results.
    """
    predictions_classes = _as_class_indices(predictions)
    true_classes = _as_class_indices(ground_true)
    report = classification_report(true_classes, predictions_classes, output_dict=True)
    macro_avg = report.get("macro avg")
    results = {
        "accuracy": report.get("accuracy"),
        "precision": macro_avg.get("precision"),
        "recall": macro_avg.get("recall"),
        "f1-score": macro_avg.get("f1-score"),
        "support": macro_avg.get("support"),
    }
    print(
        "accuracy : {:.4f}; precision : {:.4f}; recall : {:.4f}; f1-score : {:.4f}; support : {} ".format(
            results["accuracy"],
            results["precision"],
            results["recall"],
            results["f1-score"],
            results["support"],
        )
    )
    return results
    

def get_argument_parser():
    """Build the command-line parser for this sample.

    Starts from the shared NNCF TensorFlow sample parser (with the precision,
    checkpoint-frequency and print-frequency options turned off) and adds the
    classification-specific options: --dataset, --test-every-n-epochs and
    --pretrained.
    """
    cli_parser = get_common_argument_parser(
        precision=False,
        save_checkpoint_freq=False,
        print_freq=False,
    )

    # (flags, keyword options) pairs, registered in this exact order.
    extra_options = (
        (
            ("--dataset",),
            dict(help="Dataset to use.", choices=["imagenet2012", "cifar100", "cifar10"], default=None),
        ),
        (
            ("--test-every-n-epochs",),
            dict(default=1, type=int, help="Enables running validation every given number of epochs"),
        ),
        (
            ("--pretrained",),
            dict(
                dest="pretrained",
                help="Use pretrained models from the tf.keras.applications",
                action="store_true",
            ),
        ),
    )
    for flags, options in extra_options:
        cli_parser.add_argument(*flags, **options)

    return cli_parser


def get_config_from_argv(argv, parser):
    """Parse ``argv`` with ``parser``, build the sample config and set up its output paths.

    Returns:
        The sample config produced by ``create_sample_config``, after
        ``configure_paths`` has been applied using the run name.
    """
    parsed_args = parser.parse_args(args=argv)
    sample_config = create_sample_config(parsed_args, parser)
    run_name = get_run_name(sample_config)
    configure_paths(sample_config, run_name)
    return sample_config



def get_num_classes(dataset):
    """Return (and print) the number of output classes for a dataset name.

    Mapping:
        "cifar100" -> 100, "cifar10" -> 10. Any name containing
        "imagenet2012", any unrecognised name, and ``None`` -> 1000.

    Args:
        dataset: dataset name, or None when no dataset was selected (None is
            the default of the ``--dataset`` CLI option).

    Returns:
        int: the number of classes.
    """
    # Check for None first: the substring test `"imagenet2012" in None`
    # raised TypeError, even though None is the CLI default for --dataset.
    if dataset is None or "imagenet2012" in dataset:
        num_classes = 1000
    elif dataset == "cifar100":
        num_classes = 100
    elif dataset == "cifar10":
        num_classes = 10
    else:
        num_classes = 1000

    print("The sample is started with {} classes".format(num_classes))
    return num_classes



def get_model_accuracy(model_fn, model_params, nncf_config, validation_dataset, validation_steps):
    """Evaluate a model built by ``TFModelManager`` and return its top-1 accuracy in percent.

    The model is compiled with a single ``CategoricalAccuracy`` metric named
    "acc@1" and evaluated for ``validation_steps`` steps on
    ``validation_dataset``; the metric value is scaled by 100.
    """
    with TFModelManager(model_fn, nncf_config, **model_params) as model:
        top1_metric = tf.keras.metrics.CategoricalAccuracy(name="acc@1")
        model.compile(metrics=[top1_metric])
        eval_results = model.evaluate(validation_dataset, steps=validation_steps, return_dict=True)
        top1_fraction = eval_results["acc@1"]
        return 100 * top1_fraction


def run(config):
    """Compress a pre-trained CIFAR-10 classifier with NNCF, optionally fine-tune it,
    then evaluate it on a saved test set and export the compressed model.

    Args:
        config: sample config built by ``get_config_from_argv``. This function
            reads ``config.nncf_config``, ``config.log_dir``, ``config.mode`` and
            ``config.config``, and passes ``config`` itself to ``build_scheduler``,
            ``build_optimizer`` and ``is_accuracy_aware_training``.

    Side effects:
        - Loads ``classifier-ResNet50-cifar10-on-50000.h5`` and the two
          ``cifar10-*-test-to-tensorrt-10000.npy`` arrays via relative paths,
          i.e. from the current working directory.
        - Loads the CIFAR-10 training split with ``tf.keras.datasets``.
        - Writes logs under ``config.log_dir``.
        - Exports ``<name>-classifier-ResNet50-cifar10-on-50000.h5``, where
          ``<name>`` is the config path's file name up to its first '.'.
    """
    # NOTE: the generic sample setup below (distribution strategy, model
    # factory, dataset builders) is disabled; this script instead hard-codes a
    # ResNet50 / CIFAR-10 pipeline.
    # if config.disable_tensor_float_32_execution:
    #     tf.config.experimental.enable_tensor_float_32_execution(False)

    # strategy = get_distribution_strategy(config)
    # if config.metrics_dump is not None:
    #     write_metrics(0, config.metrics_dump)

    # set_seed(config)

    # model_fn, model_params = get_model(
    #     config.model,
    #     input_shape=config.get("input_info", {}).get("sample_size", None),
    #     num_classes=config.get("num_classes", get_num_classes(config.dataset)),
    #     pretrained=config.get("pretrained", False),
    #     weights=config.get("weights", None),
    # )

    # train_builder, validation_builder = get_dataset_builders(config, strategy.num_replicas_in_sync)
    # train_dataset, validation_dataset = train_builder.build(), validation_builder.build()

   

    # Hard-coded schedule sizes. train_steps is only used by build_scheduler and
    # for the callbacks' initial_step; validation_steps is unused in the active
    # code (the evaluation that used it is commented out below).
    # NOTE(review): with batch(128), 390 counts only the full batches of the
    # 50000 CIFAR-10 training images (the last partial batch makes 391) — confirm.
    train_epochs = 25
    train_steps = 390
    validation_steps = 78
    # Pre-trained base model to be compressed (relative path).
    model = tf.keras.models.load_model("classifier-ResNet50-cifar10-on-50000.h5")

    # nncf_config = NNCFConfig.from_json("/home/bhvora/nncf/examples/tensorflow/classification/configs/pruning/densenet121-cifar10-pruning_50%_L2.json")
    
    # Training data: CIFAR-10 train split, pixels divided by 255, labels
    # one-hot encoded, batched by 128 (no shuffling or augmentation).
    (train_images, train_labels), _ = tf.keras.datasets.cifar10.load_data()
    
    train_images = train_images / 255.0
    print("Train image object: {}, Train label object: {}".format(train_images.shape, train_labels.shape))
    train_labels = to_categorical(train_labels)
    print("Train image object: {}, Train label object: {}".format(train_images.shape, train_labels.shape))

    train_dataset = tf.data.Dataset.from_tensor_slices((tf.cast(train_images, tf.float32),tf.cast(train_labels, tf.int64)))
    train_dataset = train_dataset.batch(128)

    # Register the training dataset as the data loader for NNCF's default
    # compression initialization.
    # resume_training = config.ckpt_path is not None
    nncf_config = register_default_init_args(
        nncf_config=config.nncf_config, data_loader=train_dataset, batch_size=128
    )
    # Resuming is disabled, so compression always starts from scratch.
    compression_state = None
    # if resume_training
    #     compression_state = load_compression_state(config.ckpt_path)

    # with TFModelManager(model_fn, nncf_config, **model_params) as model:
        # with strategy.scope():
    compression_ctrl, compress_model = create_compressed_model(model, nncf_config, compression_state)
    compression_callbacks = create_compression_callbacks(compression_ctrl, log_dir=config.log_dir)

    scheduler = build_scheduler(config=config, steps_per_epoch=train_steps)
    optimizer = build_optimizer(config=config, scheduler=scheduler)

    loss_obj = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)

    # The compression controller's loss is added to the model's losses, so it
    # is optimised together with the cross-entropy loss.
    compress_model.add_loss(compression_ctrl.loss)

    # Top-1/top-5 accuracy plus separate trackers for the CE and compression losses.
    metrics = [
        tf.keras.metrics.CategoricalAccuracy(name="acc@1"),
        tf.keras.metrics.TopKCategoricalAccuracy(k=5, name="acc@5"),
        tfa.metrics.MeanMetricWrapper(loss_obj, name="ce_loss"),
        tfa.metrics.MeanMetricWrapper(compression_ctrl.loss, name="cr_loss"),
    ]

    compress_model.compile(
        optimizer=optimizer, loss=loss_obj, metrics=metrics, run_eagerly= False
    )

    compress_model.summary()

    # The checkpoint bundles the model with its compression state.
    checkpoint = tf.train.Checkpoint(
        model=compress_model, compression_state=TFCompressionState(compression_ctrl)
    )

    initial_epoch = 0
    # if resume_training:
    #     initial_epoch = resume_from_checkpoint(
    #         checkpoint=checkpoint, ckpt_path=config.ckpt_path, steps_per_epoch=train_steps
    #     )

    # NOTE(review): ckpt_dir is not passed — confirm whether get_callbacks
    # still saves checkpoints without it.
    callbacks = get_callbacks(
        include_tensorboard=True,
        track_lr=True,
        profile_batch=0,
        initial_step=initial_epoch * train_steps,
        log_dir=config.log_dir,
        # ckpt_dir=config.checkpoint_save_dir,
        checkpoint=checkpoint,
    )

    callbacks.append(get_progress_bar(stateful_metrics=["loss"] + [metric.name for metric in metrics]))
    callbacks.extend(compression_callbacks)

    # validation_kwargs = {
    #     "validation_data": validation_dataset,
    #     "validation_steps": validation_steps,
    #     "validation_freq": config.test_every_n_epochs,
    # }

    if "train" in config.mode:
        if is_accuracy_aware_training(config):
            # NOTE(review): no validation data is passed (validation_kwargs is
            # commented out), yet result_dict_to_val_metric_fn reads "acc@1"
            # from validation results — confirm accuracy_aware_fit works this way.
            # It also receives config.nncf_config rather than the local
            # nncf_config; presumably the same object — verify.
            print("starting an accuracy-aware training loop...")
            result_dict_to_val_metric_fn = lambda results: 100 * results["acc@1"]
            statistics = compress_model.accuracy_aware_fit(
                train_dataset,
                compression_ctrl,
                nncf_config=config.nncf_config,
                callbacks=callbacks,
                initial_epoch=initial_epoch,
                # steps_per_epoch=train_steps,
                # tensorboard_writer=SummaryWriter(config.log_dir, "accuracy_aware_training"),
                log_dir=config.log_dir,
                result_dict_to_val_metric_fn=result_dict_to_val_metric_fn
                # **validation_kwargs,
            )
            # print(f"Compressed model statistics:\n{statistics.to_str()}")
        else:
            print("training...")
            compress_model.fit(
                train_dataset,
                epochs=train_epochs,
                # steps_per_epoch=train_steps,
                initial_epoch=initial_epoch,
                callbacks=callbacks
                # **validation_kwargs,
            )

    # print("evaluation...")
    # statistics = compression_ctrl.statistics()
    # print(statistics.to_str())
    # eval_model = compress_model

    # Evaluation and export run regardless of config.mode.
    x_test = np.load('cifar10-x-test-to-tensorrt-10000.npy')
    y_test = np.load('cifar10-y-test-to-tensorrt-10000.npy')

    input_tensor=tf.constant(x_test.astype('float32'))

    # The first predict call is an untimed warm-up (presumably to exclude
    # one-time tracing/initialization); the reported time covers one full
    # predict call over the whole test array.
    compress_model.predict(input_tensor)
    start_time=time.time()
    predictions = compress_model.predict(input_tensor)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print("NNCF stats on benign test examples with inference in {:.2f} ms.".format(elapsed_time*1000))
    process_results(predictions, y_test)
    # Output-name prefix: the config path's file name, cut at its first '.'.
    optimization_str = str(config.config).split('/')[-1].split('.')[0]
    compression_ctrl.export_model( optimization_str +"-classifier-ResNet50-cifar10-on-50000.h5", save_format='h5')

    # results = eval_model.evaluate(
    #     validation_dataset,
    #     steps=validation_steps,
    #     callbacks=[get_progress_bar(stateful_metrics=["loss"] + [metric.name for metric in metrics])],
    #     verbose=1,
    # )

    # if config.metrics_dump is not None:
    #     write_metrics(results[1], config.metrics_dump)

    # if "export" in config.mode:
    #     save_path, save_format = get_saving_parameters(config)
    #     export_model(compression_ctrl.strip(), save_path, save_format)
    #     print("Saved to {}".format(save_path))


# def export():
#     model = tf.keras.models.load_model(  "/home/bhvora/comp_robust/output/cifar10/BM/classifier-DenseNet121-cifar10-on-50000.h5")

#     nncf_config = NNCFConfig.from_json("/home/bhvora/nncf/examples/tensorflow/classification/configs/pruning/densenet121-cifar10-pruning_50%_L2.json")
    
#     (train_images, train_labels), _ = tf.keras.datasets.cifar10.load_data()
    
#     train_images = train_images / 255.0
#     print("Train image object: {}, Train label object: {}".format(train_images.shape, train_labels.shape))
#     train_labels = to_categorical(train_labels)
#     print("Train image object: {}, Train label object: {}".format(train_images.shape, train_labels.shape))

#     train_dataset = tf.data.Dataset.from_tensor_slices((tf.cast(train_images, tf.float32),tf.cast(train_labels, tf.int64)))
#     train_dataset = train_dataset.batch(128)

#     compression_ctrl, compress_model = create_compressed_model(model, nncf_config)

#     metrics = [
#         tf.keras.metrics.CategoricalAccuracy(name="acc@1"),
#         tf.keras.metrics.TopKCategoricalAccuracy(k=5, name="acc@5"),
#     ]
#     loss_obj = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)

#     compress_model.compile(loss=loss_obj, metrics=metrics)
#     compress_model.summary()

#    # checkpoint = tf.train.Checkpoint(model=compress_model, compression_state=TFCompressionState(compression_ctrl))
   

#    compression_ctrl.export_model( optimization_str+ "-classifier-DenseNet121-cifar10-on-50000.h5", save_format='h5')
    

    # save_path, save_format = get_saving_parameters(config)
    # export_model(compression_ctrl.strip(), save_path, save_format)
    # print("Saved to {}".format(save_path))


def main(argv):
    """Script entry point.

    Parses ``argv`` and builds the config, then applies the experimental
    quantization patch. It saves the NNCF config and the CLI arguments to the
    log directory before starting ``run``.
    """
    arg_parser = get_argument_parser()
    config = get_config_from_argv(argv, arg_parser)

    patch_if_experimental_quantization(config.nncf_config)

    # Persist what this run was launched with, next to its logs.
    serialize_config(config.nncf_config, config.log_dir)
    serialize_cli_args(arg_parser, argv, config.log_dir)

    run(config)


if __name__ == "__main__":
    # Everything after the script name is forwarded to the CLI parser.
    main(argv=sys.argv[1:])
