import collections
import functools
import math
import random
from typing import List, Tuple
from absl import app
from absl import flags
from absl import logging
from typing import Optional

import tensorflow as tf
import tensorflow_federated as tff

from distributed_dp_matrix_factorization.dp_ftrl import aggregator_builder
from distributed_dp_matrix_factorization.dp_ftrl import dp_fedavg
from distributed_dp_matrix_factorization.dp_ftrl import training_loop
from utils import task_utils
from utils import utils_impl

# Snapshot of the flags already registered before this module defines its own
# (absl built-ins and flags from imported modules). Flags defined after this
# line are treated as the experiment's hyperparameters.
IRRELEVANT_FLAGS = frozenset(iter(flags.FLAGS))

flags.DEFINE_string(
    'experiment_name', 'emnist', 'The name of this experiment. Will be '
    'appended to --root_output_dir to separate experiment results.')
flags.DEFINE_string('root_output_dir', '/tmp/dpftrl/emnist',
                    'Root directory for writing experiment output.')
flags.DEFINE_integer('rounds_per_checkpoint', 100,
                     'How often to checkpoint the global model.')
flags.DEFINE_integer(
    'rounds_per_eval', 20,
    'How often to evaluate the global model on the validation dataset.')
flags.DEFINE_integer('clients_per_thread', 1, 'TFF executor configuration.')

# Training
flags.DEFINE_integer('clients_per_round', 100,
                     'How many clients to sample per round.')
flags.DEFINE_integer('client_epochs_per_round', 1,
                     'Number of epochs in the client to take per round.')
flags.DEFINE_integer('client_batch_size', 16, 'Batch size used on the client.')
flags.DEFINE_integer('total_rounds', 10, 'Number of total training rounds.')
flags.DEFINE_integer(
    'total_epochs', 1,
    'If not None, use shuffling of clients instead of random sampling.')
flags.DEFINE_integer(
    'max_elements_per_client', None, 'Maximum number of '
    'elements for each training client. If set to None, all '
    'available examples are used.')
flags.DEFINE_integer(
    'num_validation_examples', -1, 'The number of validation '
    'examples to use. If set to -1, all available examples '
    'are used.')

# Each `record_hparam_flags()` context captures the flags defined inside it;
# `_build_tff_learning_model_and_process` later reads the compression and DP
# groups back with `utils_impl.lookup_flag_values`.
with utils_impl.record_hparam_flags() as compression_flags:
  flags.DEFINE_integer(
      name='num_bits',
      default=16,
      help='Number of bits for quantization.')
  flags.DEFINE_float(
      name='beta',
      default=math.exp(-0.5),
      help='Beta for stochastic rounding.')
  flags.DEFINE_integer(
      name='k_stddevs',
      default=4,
      help='Number of stddevs to bound the signal range.')

with utils_impl.record_hparam_flags() as dp_flags:
  flags.DEFINE_float(
      name='epsilon',
      default=2.0,
      help='Epsilon for the DP mechanism. No DP used if this is None.')
  flags.DEFINE_float(
      name='delta',
      default=None,
      help='Delta for the DP mechanism. ')
  flags.DEFINE_float(
      name='l2_norm_clip',
      default=2.0,
      help='Initial L2 norm clip.')

with utils_impl.record_hparam_flags() as task_flags:
  # Registers "--task" (choices come from `task_utils`) together with the
  # per-task "--<task>_<arg>" flags that mirror the constructor arguments of
  # the `tff.simulation.baselines.*` tasks.
  task_utils.define_task_flags()

# Optimizer
flags.DEFINE_enum('client_optimizer', 'sgd', ['sgd'], 'Client optimizer.')
flags.DEFINE_float('client_lr', 0.02, 'Client learning rate.')
flags.DEFINE_float('server_lr', 1.0, 'Server learning rate.')
flags.DEFINE_float('server_momentum', 0.9, 'Server momentum for SGDM.')
# NOTE(review): not read anywhere in this module as visible; client selection
# is seeded by --client_selection_seed instead. Confirm before relying on it.
flags.DEFINE_integer('client_datasets_random_seed', 42,
                     'Random seed for client sampling.')

# Differential privacy
flags.DEFINE_float('clip_norm', 1.0, 'Clip L2 norm.')
flags.DEFINE_float('noise_multiplier', 0.01,
                   'Noise multiplier for DP algorithm.')

_AGGREGATOR_METHOD = flags.DEFINE_enum(
    'aggregator_method', 'tree_aggregation',
    list(aggregator_builder.AGGREGATION_METHODS),
    'Enum indicating the aggregator method to use.')

flags.DEFINE_string(
    'lr_momentum_matrix_name', None,
    'Name of the mechanism (and partial path to stored matrix) '
    'for --aggregator_method=lr_momentum_matrix')

# The default is drawn once at import time, so unless the flag is set
# explicitly every run selects clients with a different seed.
_CLIENT_SELECTION_SEED = flags.DEFINE_integer(
    'client_selection_seed',
    random.getrandbits(32),
    'Random seed for client selection.',
)

_RESHUFFLE_EACH_EPOCH = flags.DEFINE_boolean(
    'reshuffle_each_epoch',
    False,
    (
        'Requires --total_epochs >= 1. If set, reshuffle mapping of clients '
        'to rounds on each epoch.'
    ),
)

FLAGS = flags.FLAGS
# Hyperparameters are every flag registered after the `IRRELEVANT_FLAGS`
# snapshot, i.e. the flags this module defines itself.
HPARAM_FLAGS = [name for name in FLAGS if name not in IRRELEVANT_FLAGS]


# model
def create_1m_cnn_model(only_digits: bool = False, seed: Optional[int] = 0):
  """Returns a Keras CNN for EMNIST with slightly fewer than 2**20 weights.

  The architecture closely follows the usual EMNIST conv + dropout baseline
  (`create_conv_dropout_model`) but is sized to stay just under 2**20 (about
  one million) parameters. That matters for pipelines built on the randomized
  Hadamard transform. Such pipelines concatenate the weights, gradients or
  deltas into one vector and pad it up to the next power of two.

  Used in https://arxiv.org/abs/2102.06387.

  With `only_digits=False` the model has 1,018,174 trainable parameters. With
  `only_digits=True` only the final dense layer shrinks.

  Args:
    only_digits: If True, the output layer has 10 units (digits-only EMNIST).
      Otherwise it has 62 units (the full EMNIST character set).
    seed: Seed for the Glorot-uniform kernel initializer shared by all layers.

  Returns:
    An uncompiled `tf.keras.Model` (a `Sequential`) that takes `(28, 28, 1)`
    inputs and ends in a softmax over the classes.
  """
  channel_layout = 'channels_last'
  kernel_init = tf.keras.initializers.GlorotUniform(seed=seed)
  num_classes = 10 if only_digits else 62

  # Arguments shared by both convolutional layers.
  conv_kwargs = dict(
      kernel_size=(3, 3),
      activation='relu',
      data_format=channel_layout,
      kernel_initializer=kernel_init)

  layers = [
      tf.keras.layers.Conv2D(32, input_shape=(28, 28, 1), **conv_kwargs),
      tf.keras.layers.MaxPool2D(pool_size=(2, 2), data_format=channel_layout),
      tf.keras.layers.Conv2D(64, **conv_kwargs),
      tf.keras.layers.Dropout(0.25),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(
          128, activation='relu', kernel_initializer=kernel_init),
      tf.keras.layers.Dropout(0.5),
      tf.keras.layers.Dense(
          num_classes,
          activation=tf.nn.softmax,
          kernel_initializer=kernel_init),
  ]
  return tf.keras.models.Sequential(layers)

def _client_optimizer_fn(name, learning_rate):
  if name == 'sgd':
    return tf.keras.optimizers.SGD(learning_rate)
  else:
    raise ValueError('Unknown client optimizer name {}'.format(name))
  
def _sample_client_ids(
    num_clients: int,
    client_data: tff.simulation.datasets.ClientData,
    round_num: int,
    epoch: int,
) -> Tuple[List, int]:  # pylint: disable=g-bare-generic
  """Draws `num_clients` distinct client ids using the global `random` state.

  Args:
    num_clients: How many ids to draw, without replacement.
    client_data: Source of the candidate ids (`client_data.client_ids`).
    round_num: Ignored; accepted for signature compatibility.
    epoch: Returned unchanged alongside the sample.

  Returns:
    A `(sampled_ids, epoch)` pair.
  """
  del round_num  # The draw does not depend on the round.
  sampled_ids = random.sample(client_data.client_ids, num_clients)
  return sampled_ids, epoch



def _build_tff_learning_model_and_process():
  """Builds the EMNIST task, the DP-FTRL iterative process and eval helpers.

  All configuration is read from the module-level absl `FLAGS`.

  Returns:
    A 6-tuple `(iterative_process, test_set, validation_set, train_set,
    evaluate_fn, train_dataset_computation)`:
      iterative_process: Process from `dp_fedavg.build_dpftrl_fedavg_process`,
        wired to the DP aggregator selected by `--aggregator_method`.
      test_set: `task.datasets.get_centralized_test_data()`.
      validation_set: `test_set.take(--num_validation_examples)`.
      train_set: EMNIST training `ClientData` with train preprocessing applied.
      evaluate_fn: `evaluate_fn(model_weights, dataset)`, which returns an
        `OrderedDict` mapping each metric name to its value.
      train_dataset_computation: `tff.tf_computation` mapping a client id
        (`tf.string`) to that client's preprocessed training dataset.
  """
  train_client_spec = tff.simulation.baselines.ClientSpec(
      num_epochs=FLAGS.client_epochs_per_round,
      batch_size=FLAGS.client_batch_size,
      max_elements=FLAGS.max_elements_per_client)

  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
      only_digits=False)

  # Single pass, shuffle_buffer_size=1 (i.e. no shuffling) for evaluation.
  eval_client_spec = tff.simulation.baselines.ClientSpec(
      num_epochs=1, batch_size=64, shuffle_buffer_size=1)

  emnist_preprocessing = tff.simulation.baselines.emnist.emnist_preprocessing
  train_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
      train_client_spec, emnist_task='character_recognition')
  eval_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
      eval_client_spec, emnist_task='character_recognition')

  @tff.tf_computation(tf.string)
  def train_dataset_computation(client_id):
    client_train_data = emnist_train.dataset_computation(client_id)
    return train_preprocess_fn(client_train_data)

  task_datasets = tff.simulation.baselines.task_data.BaselineTaskDatasets(
      train_data=emnist_train,
      test_data=emnist_test,
      validation_data=None,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  def emnist_model_fn():
    # The default `only_digits=False` (62 outputs) matches
    # `load_data(only_digits=False)` above.
    return tff.learning.from_keras_model(
        keras_model=create_1m_cnn_model(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        input_spec=task_datasets.element_type_structure)

  task = tff.simulation.baselines.baseline_task.BaselineTask(
      task_datasets, emnist_model_fn)

  train_set = task.datasets.train_data.preprocess(
      task.datasets.train_preprocess_fn)

  test_set = task.datasets.get_centralized_test_data()
  # NOTE(review): `test_set` has presumably already gone through
  # `eval_preprocess_fn` (batch_size=64), in which case `take` counts batches
  # rather than examples. The default of -1 keeps everything either way.
  validation_set = test_set.take(FLAGS.num_validation_examples)

  # Hyperparameter groups recorded at flag-definition time.
  compression_dict = utils_impl.lookup_flag_values(compression_flags)
  dp_dict = utils_impl.lookup_flag_values(dp_flags)
  model_trainable_variables = task.model_fn().trainable_variables

  client_optimizer_fn = functools.partial(
      _client_optimizer_fn,
      name=FLAGS.client_optimizer,
      learning_rate=FLAGS.client_lr)

  # NOTE(review): raises TypeError when --total_epochs is None; confirm the
  # intended value for the independent-sampling mode.
  sqrt_num_parts = FLAGS.total_epochs ** 0.5
  aggregator_factory = aggregator_builder.build_aggregator(
      aggregator_method=_AGGREGATOR_METHOD.value,
      model_fn=task.model_fn,
      clip_norm=FLAGS.clip_norm,
      noise_multiplier=FLAGS.noise_multiplier,
      clients_per_round=FLAGS.clients_per_round,
      num_rounds=FLAGS.total_rounds,
      noise_seed=None,
      momentum=FLAGS.server_momentum,
      compression_flags=compression_dict,
      client_template=model_trainable_variables,
      dp_flags=dp_dict,
      sqrt_num_parts=sqrt_num_parts,
      lr_momentum_matrix_name=FLAGS.lr_momentum_matrix_name)

  # The matrix-momentum aggregators receive `momentum` directly above, so the
  # server optimizer is given zero momentum for them. Presumably this avoids
  # applying momentum twice -- confirm against `aggregator_builder`.
  if _AGGREGATOR_METHOD.value in ('opt_momentum_matrix', 'lr_momentum_matrix'):
    server_optimizer_momentum_value = 0
  else:
    server_optimizer_momentum_value = FLAGS.server_momentum

  iterative_process = dp_fedavg.build_dpftrl_fedavg_process(
      emnist_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      server_learning_rate=FLAGS.server_lr,
      server_momentum=server_optimizer_momentum_value,
      server_nesterov=False,
      use_experimental_simulation_loop=True,
      dp_aggregator_factory=aggregator_factory,
  )

  def evaluate_fn(model_weights, dataset):
    """Evaluates `model_weights` on `dataset` using a fresh Keras model."""
    keras_model = create_1m_cnn_model()
    model_weights.assign_weights_to(keras_model)
    test_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    metrics = dp_fedavg.keras_evaluate(keras_model, dataset, test_metrics)
    return collections.OrderedDict(
        (test_metric.name, metric.numpy())
        for test_metric, metric in zip(test_metrics, metrics))

  return (iterative_process, test_set, validation_set, train_set, evaluate_fn,
          train_dataset_computation)


def train_and_eval():
  """Builds the DP-FTRL process and runs the federated training loop.

  The function performs these steps:
    1. Logs every hyperparameter flag.
    2. Composes the iterative process with the per-client dataset computation,
       so each round is driven by client ids.
    3. Selects clients by independent per-round sampling or by epoch
       shuffling.
    4. Delegates the loop to `training_loop.run`.

  Client selection:
    * `--total_epochs` None or 0: `--clients_per_round` clients are sampled
      independently every round from a `random.Random` seeded with
      `--client_selection_seed`.
    * Otherwise: `training_loop.ClientIDShufflerMulti` assigns clients to
      rounds within each epoch (optionally reshuffling every epoch).
  """
  logging.info('Show FLAGS for debugging:')
  for f in HPARAM_FLAGS:
    logging.info('%s=%s', f, FLAGS[f].value)

  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in HPARAM_FLAGS
  ])

  (iterative_process, test_set, validation_set, train_set, evaluate_fn,
   train_dataset_computation) = _build_tff_learning_model_and_process()
  # Make the process consume client ids. Each id is turned into a
  # preprocessed dataset by `train_dataset_computation`.
  iterative_process = (
      tff.simulation.compose_dataset_computation_with_learning_process(
          dataset_computation=train_dataset_computation,
          process=iterative_process))

  if not FLAGS.total_epochs:  # None or 0: independent sampling per round.
    rng = random.Random(_CLIENT_SELECTION_SEED.value)

    def client_dataset_ids_fn(round_num: int):
      del round_num  # Sampling does not depend on the round.
      return rng.sample(train_set.client_ids, FLAGS.clients_per_round), 0

    logging.info(
        'Sampling %s clients independently each round for max %d rounds',
        FLAGS.clients_per_round,
        FLAGS.total_rounds,
    )
    total_epochs = 0
  else:
    client_dataset_ids_fn = training_loop.ClientIDShufflerMulti(
        FLAGS.clients_per_round,
        train_set.client_ids,
        reshuffle_each_epoch=_RESHUFFLE_EACH_EPOCH.value,
        seed=_CLIENT_SELECTION_SEED.value,
    )
    logging.info(
        'Shuffle clients within epoch for max %d epochs and %d rounds',
        FLAGS.total_epochs,
        FLAGS.total_rounds,
    )
    total_epochs = FLAGS.total_epochs

  training_loop.run(
      iterative_process,
      client_dataset_ids_fn,
      validation_fn=functools.partial(evaluate_fn, dataset=validation_set),
      total_epochs=total_epochs,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      train_eval_fn=None,
      test_fn=functools.partial(evaluate_fn, dataset=test_set),
      root_output_dir=FLAGS.root_output_dir,
      hparam_dict=hparam_dict,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
      rounds_per_train_eval=2000)

def main(argv):
  """absl entry point: rejects positional args, then trains and evaluates.

  Args:
    argv: Command-line arguments left over after absl flag parsing.

  Raises:
    app.UsageError: If any positional argument besides the program name is
      present.
  """
  extra_args = argv[1:]
  if extra_args:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  # Enable memory growth on every visible GPU.
  for device in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(device, True)

  train_and_eval()


if __name__ == '__main__':
  app.run(main)
