"""Entrypoint of the experiment."""

import os
import sys
import ast
import argparse
import numpy as np
import pandas as pd
import munch

import tensorflow as tf
import gin

from ncc import algos
from ncc import data
from ncc import compositionality_metrics


def generate_eval_protocol(model_inference_fn, dataset_iter,
                           ds_info, data_name, protocol_size):
    """Generates the communication protocol and accuracies used for eval.

    Draws ``protocol_size // 2`` batches from ``dataset_iter``, runs
    ``model_inference_fn`` on each batch's images, and collects the argmax
    messages together with the ground-truth labels.

    Args:
      model_inference_fn: Callable taking ``batch['image']`` and returning a
        ``(signal_logits, preds)`` pair of tensors.
      dataset_iter: Iterator of dict batches containing an ``'image'`` entry
        and one entry per name in ``ds_info['features_names']``.
      ds_info: Dataset info mapping; only ``'features_names'`` is read here.
      data_name: Name interpolated into the returned metric keys.
      protocol_size: Requested number of evaluated samples. The number of
        batches drawn is ``protocol_size // 2`` (the loader is presumably
        configured with 2 samples per eval batch -- TODO confirm).

    Returns:
      A ``(communication_protocol, eval_dict)`` tuple, where
      ``communication_protocol`` is the concatenation of the lists returned by
      ``compositionality_metrics.get_protocol`` and ``eval_dict`` holds
      ``eval_<data_name>_acc_all`` (fraction of samples with every feature
      predicted correctly) and ``eval_<data_name>_acc`` (per-feature
      accuracy averaged over all samples and features).

    Raises:
      ValueError: If ``protocol_size // 2`` is not positive, i.e. no batch
        would be evaluated.
    """
    # Div over 2 to support the batch size = 2 from loader.
    num_batches = protocol_size // 2
    if num_batches <= 0:
        raise ValueError(
            f'protocol_size must be at least 2 to draw a batch, '
            f'got {protocol_size}.')

    feature_names = tuple(ds_info['features_names'])
    preds_all = []
    labels_all = []
    communication_protocol = []

    for _ in range(num_batches):
        batch = next(dataset_iter)
        image = batch['image']
        # Transpose per-feature label arrays into per-sample label tuples.
        labels = list(zip(*[batch[key].numpy() for key in feature_names]))
        labels_all += labels

        signal_logits, preds = model_inference_fn(image)
        preds_all.append(preds.numpy())

        # Message symbols: argmax over the last axis of the logits; "shape"
        # is (B, message_len).
        message = tf.argmax(signal_logits, axis=-1).numpy()

        communication_protocol += compositionality_metrics.get_protocol(
            message, labels, list(feature_names)
        )

    preds_all = np.concatenate(preds_all)
    labels_all = np.array(labels_all)
    correct = preds_all == labels_all
    eval_dict = {
        f'eval_{data_name}_acc_all': np.mean(np.all(correct, axis=1)),
        f'eval_{data_name}_acc': np.mean(correct),
    }

    return communication_protocol, eval_dict


@gin.configurable(whitelist=['metrics', 'eval_protocol_size'])
def eval_cb_(model_inference_fn, epoch, dataset_iter, ds_info,
             suffix, metrics=None, eval_protocol_size=None):
    """Evaluates the model on one data split and pickles its protocol.

    Args:
      model_inference_fn: Model inference callable forwarded to
        ``generate_eval_protocol``.
      epoch: Current epoch; used only in the output pickle file name.
      dataset_iter: Iterator over the evaluated split.
      ds_info: Dataset info mapping forwarded to ``generate_eval_protocol``.
      suffix: Split name, appended to every metric key and to the file name.
      metrics: Iterable of metric objects exposing ``.name`` and
        ``.measure(protocol)``; normally bound through gin. ``None`` means
        no compositionality metrics are computed.
      eval_protocol_size: Number of samples to evaluate; normally bound
        through gin.

    Returns:
      Dict with the accuracies, one entry per metric and the number of
      evaluated protocol entries. The protocol itself is written to
      ``data_out/sender_<suffix>_<epoch>_msg.p``, so the ``data_out``
      directory must already exist.
    """
    communication_protocol, eval_dict = generate_eval_protocol(
        model_inference_fn, dataset_iter, ds_info, 'test_' + suffix,
        eval_protocol_size)

    # Without a gin binding `metrics` stays None; iterating it would raise.
    if metrics is None:
        metrics = ()
    for metric in metrics:
        eval_dict[metric.name + ' ' + suffix] = metric.measure(
            communication_protocol
        )
    eval_dict['eval samples from ' + suffix] = len(communication_protocol)

    protocol_df = pd.DataFrame(communication_protocol,
                               columns=['derivation', 'message'])
    protocol_df.to_pickle(f'data_out/sender_{suffix}_{epoch}_msg.p')

    return eval_dict


@gin.configurable
def run_experiment(seed=123, restore_pretrain_fname=None):
    """Main function for running the experiment.

    Seeds TensorFlow and NumPy, loads the datasets through
    ``data.loader.load`` and trains with ``algos.noisy_channel``, evaluating
    periodically on the train, in-sample and (when available) out-of-sample
    splits.

    Args:
      seed: Seed passed to ``tf.random.set_seed`` and ``np.random.seed``.
      restore_pretrain_fname: Optional file name, resolved relative to the
        ``DATA_DIR`` environment variable, forwarded to the sender.
    """
    tf.random.set_seed(seed)
    np.random.seed(seed)

    np.set_printoptions(linewidth=190, threshold=sys.maxsize)
    # eval_cb_ pickles the protocols into this directory.
    os.makedirs('data_out', exist_ok=True)

    base_dir = os.environ.get('DATA_DIR', '')

    (ds_train, ds_train_for_eval, ds_test_in_sample, ds_test_out_of_sample,
     ds_info) = data.loader.load(data_dir=base_dir)

    ds_test_all = None
    if ds_test_out_of_sample:
        # Mix both test splits proportionally to how many feature
        # combinations each one contains.
        oos = len(ds_info['features_list_oos'])
        intr = len(ds_info['features_list_train'])
        total = oos + intr
        ds_test_all = tf.data.experimental.sample_from_datasets(
            [ds_test_in_sample, ds_test_out_of_sample],
            [intr / total, oos / total]
        )

    if restore_pretrain_fname:
        # os.path.join keeps the name relative when DATA_DIR is unset
        # (plain '+ "/" +' would produce the absolute path '/<name>').
        restore_pretrain_fname = os.path.join(base_dir,
                                              restore_pretrain_fname)

    sender_kwargs = dict(
        message_len=ds_info['message_length'],
        alphabet_size=ds_info['sender_alphabet_size'],
        restore_pretrain_fname=restore_pretrain_fname
    )

    receiver_kwargs = dict(
        message_len=ds_info['message_length'],
        alphabet_size=ds_info['receiver_alphabet_size'],
        activation_fn=tf.keras.activations.elu,
    )

    # Iterators are created once so successive evaluations keep advancing
    # through the data instead of restarting from the first batch.
    ds_train_for_eval_iter = iter(ds_train_for_eval)
    ds_test_in_sample_iter = iter(ds_test_in_sample)
    ds_test_out_of_sample_iter = (
        iter(ds_test_out_of_sample) if ds_test_out_of_sample else None)
    ds_test_all_iter = iter(ds_test_all) if ds_test_all is not None else None

    def eval_cb(model_inference_fn, epoch):
        """Evaluates on every available split and merges the result dicts."""
        # In train evaluation
        eval_train = eval_cb_(model_inference_fn, epoch,
                              ds_train_for_eval_iter, ds_info, 'in_train')
        # In-sample evaluation
        eval_test = eval_cb_(model_inference_fn, epoch,
                             ds_test_in_sample_iter, ds_info, 'in_sample')
        eval_oos = {}
        eval_all = {}
        # Gate on the iterator itself so a missing out-of-sample split can
        # never hand a None iterator to eval_cb_.
        if ds_test_out_of_sample_iter is not None:
            # Out-of-sample evaluation
            eval_oos = eval_cb_(model_inference_fn, epoch,
                                ds_test_out_of_sample_iter,
                                ds_info, 'out_of_sample')

            eval_all = eval_cb_(model_inference_fn, epoch,
                                ds_test_all_iter,
                                ds_info, 'test_all')

        return {**eval_train, **eval_test, **eval_oos, **eval_all}

    algos.noisy_channel(
        ds=ds_train,
        ds_info=ds_info,
        sender_kwargs=sender_kwargs,
        receiver_kwargs=receiver_kwargs,
        softmax_fn=algos.core.gumbel_softmax,
        learned_tau=True,
        smooth_coef=0.995,
        eval_cb=eval_cb
    )

def inject_dict_to_gin(dict_):
    """Converts a flat ``{binding: value}`` dict into gin bindings and parses them.

    The special key ``'imports'`` must map to an iterable of module names;
    each becomes an ``import <module>`` statement. Every other entry becomes
    ``<key> = <value>``. String values are wrapped in double quotes unless
    they start with one of ``@ % { ( [``, in which case they are passed
    through verbatim as gin references/macros or literal containers.

    Args:
      dict_: Mapping from gin binding names (e.g. ``'load.name'``) to values.
    """
    raw_prefixes = ('@', '%', '{', '(', '[')
    gin_bindings = []
    for key, value in dict_.items():
        if key == 'imports':
            gin_bindings.extend(f'import {module_str}' for module_str in value)
            continue

        # startswith() also handles the empty string, which value[0] did not.
        if isinstance(value, str) and not value.startswith(raw_prefixes):
            binding = f'{key} = "{value}"'
        else:
            binding = f'{key} = {value}'
        gin_bindings.append(binding)

    gin.parse_config(gin_bindings)


def nest_params(params, prefixes):
    """Groups prefixed keys of ``params`` into nested Munch dicts, in place.

    For each ``prefix`` in ``prefixes``, every key of ``params`` that starts
    with it is popped and re-inserted, with the prefix stripped, into a new
    ``munch.Munch``. That Munch is stored under ``params[prefix[:-1]]``, i.e.
    the prefix minus its final character (typically a trailing ``_``). The
    nested entry is created even when no key matches. Returns None.

    Example:
      params = dict(
        param0=value0,
        prefix0_param1=value1,
        prefix0_param2=value2
      )
      nest_params(params, ("prefix0_",))
      # params is now:
      # {
      #   "param0": value0,
      #   "prefix0": {
      #     "param1": value1,
      #     "param2": value2
      #   }
      # }
    """
    for prefix in prefixes:
        cut = len(prefix)
        # Snapshot the matching keys first so popping does not disturb
        # iteration over params.
        matching_keys = [key for key in params if key.startswith(prefix)]
        nested = munch.Munch()
        for key in matching_keys:
            nested[key[cut:]] = params.pop(key)
        params[prefix[:-1]] = nested

def _str2bool(value):
    """Parses a command-line boolean such as 'true'/'false', 'yes'/'no', '1'/'0'.

    argparse's ``type=bool`` returns True for every non-empty string,
    including 'False', so boolean flags need this explicit parser.

    Raises:
      argparse.ArgumentTypeError: If ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        f'Expected a boolean value, got {value!r}.')


def main():
    """Parses command-line flags, binds them into gin and runs the experiment.

    Every flag whose name contains a '.' (e.g. ``--load.name``) is forwarded
    to gin as a ``<configurable>.<parameter>`` binding via
    ``inject_dict_to_gin``. The parsed arguments are printed, then
    ``run_experiment()`` is called and its result returned.
    """
    parser = argparse.ArgumentParser(description='Run experiment.')

    # Parameters of dataset:
    # =========================================================================
    parser.add_argument('--load.name', type=str, default='shapes3d')
    parser.add_argument('--load.batch_size', type=int, default=64)
    parser.add_argument('--load.train_features', type=ast.literal_eval,
            default=\
            '{"label_shape": [0, 1, 2, 3], "label_object_hue": [0, 1, 2, 3]}')
    parser.add_argument('--load.train_features_helper',
            type=str, default='@ncc.data.helpers.features_cartesian')
    # To remove out-of-sample data from training set one can use:
    # parser.add_argument('--load.out_of_sample_features',
    #         type=ast.literal_eval,
    #         default=\
    #        "{'label_shape': [0, 1, 2, 3], 'label_object_hue': [0, 1, 2, 3]}")
    # parser.add_argument('--load.out_of_sample_features_helper',
    #       type=str, default='@ncc.data.helpers.features_zip')
    parser.add_argument('--load.normalize', type=_str2bool, default=True)
    parser.add_argument('--load.count_dataset_cardinality', type=_str2bool,
            default=False)
    parser.add_argument('--load.sender_alphabet_size', type=int, default=5)
    parser.add_argument('--load.receiver_alphabet_size', type=int, default=8)
    parser.add_argument('--load.features_transform_fn', type=str, default=None)
    parser.add_argument('--load.image_transform_fn', type=str, default=None)
    parser.add_argument('--scramble_image.tile_size', type=ast.literal_eval,
            default=None)

    # Parameters of evaluation:
    # ==========================================================================
    parser.add_argument('--eval_cb_.eval_protocol_size', type=int, default=1024)
    parser.add_argument('--eval_cb_.metrics', type=str,
            default='[@TopographicSimilarity(), @ContextIndependence(),\
            @OneMessagePerClassWrapper(), @PositionalDisentanglement()]')
    parser.add_argument('--OneMessagePerClassWrapper.metric', type=str,
            default='@ConflictCount()')

    # Parameters for training:
    # ==========================================================================
    parser.add_argument('--noisy_channel.total_epochs', type=int,
            default=100000)
    parser.add_argument('--noisy_channel.noise', type=float, default=0.1)
    parser.add_argument('--noisy_channel.tau', type=float, default=2.)
    parser.add_argument('--noisy_channel.tau_min', type=float, default=0.5)
    parser.add_argument('--noisy_channel.learning_rate', type=float,
            default=1e-4)
    parser.add_argument('--noisy_channel.kl_coef', type=float, default=1e-2)
    parser.add_argument('--noisy_channel.ent_coef', type=float, default=0.0)
    parser.add_argument('--noisy_channel.straight_through', type=str,
            default='@random_scheduler')
    parser.add_argument('--noisy_channel.log_every', type=int, default=1000)
    parser.add_argument('--noisy_channel.eval_every', type=int, default=2000)
    parser.add_argument('--noisy_channel.save_every', type=int, default=5000)
    parser.add_argument('--noisy_channel.new_noise', type=float, default=0.001)

    # Parameters for random scheduler:
    # ==========================================================================
    parser.add_argument('--random_scheduler.probability', type=float,
            default=0.5)

    # Parameters for Run:
    # ==========================================================================
    parser.add_argument('--run_experiment.seed', type=int, default=0)

    # Parameters for Sender:
    # ==========================================================================
    parser.add_argument('--Sender.input_dim', type=ast.literal_eval,
            default='(128, 128, 3)')
    parser.add_argument('--Sender.weight_decay', type=float, default=3e-4)
    parser.add_argument('--Sender.embedding', type=float, default=64.0)
    parser.add_argument('--Sender.padding', type=str, default='SAME')
    parser.add_argument('--Sender.max_pool_strides', type=ast.literal_eval,
            default='[2, 2]')
    parser.add_argument('--Sender.max_pool_sizes', type=ast.literal_eval,
            default='[2, 2]')
    parser.add_argument('--Sender.conv_strides', type=ast.literal_eval,
            default='[1, 1]')
    parser.add_argument('--Sender.conv_kernel_sizes', type=ast.literal_eval,
            default='[3, 3]')
    parser.add_argument('--Sender.conv_filters', type=ast.literal_eval,
            default='[8, 8]')

    # Parameters for Receiver:
    # ==========================================================================
    parser.add_argument('--Receiver.hidden_sizes', type=ast.literal_eval,
            default='[64]')
    parser.add_argument('--Receiver.weight_decay', type=float, default=3e-4)
    parser.add_argument('--Receiver.filter_size', type=int, default=64)

    args = vars(parser.parse_args())
    # Only dotted flags ('<configurable>.<param>') are gin bindings.
    gin_params = {name: value for name, value in args.items() if '.' in name}
    inject_dict_to_gin(gin_params)

    print(args)
    return run_experiment()


if __name__ == '__main__':
    # Run the experiment only when executed as a script, not on import.
    main()
