"""Command line interface to generate, solve, and store linear ODE systems."""
from typing import Dict

import hashlib
import pathlib

import jax
import numpy as np
import pandas as pd
from absl import app, flags, logging

import metrics
import systems


# ---------------------- System parameters
flags.DEFINE_integer('num_systems', 10, 'Number of systems to generate')
flags.DEFINE_integer('dim', 7, 'Dimension of the system')
flags.DEFINE_float('frac_zeros', 0.5, 'Fraction of zeros in the matrix')
flags.DEFINE_float('epsilon', 0.1, 'Minimum absolute value of the entries')
flags.DEFINE_bool('check_zero_rows', False, 'Check that the matrix has no zero rows')
flags.DEFINE_integer('num_iv_per_system', 3, 'Number of initial conditions per system')
flags.DEFINE_bool('unit_norm', True, 'Normalize the initial conditions to unit norm')
flags.DEFINE_integer('seed', 0, 'Random seed')
# NOTE(review): the help text below is a placeholder; the meaning of this value
# is defined by `systems.candidate_problems`, which receives it unchanged.
flags.DEFINE_integer('asymmetric', 1, 'asymmetric')
flags.DEFINE_string('method_A', 'independent', 'Class of generation of A')

# ---------------------- Solver parameters
flags.DEFINE_integer('steps', 512, 'Number of steps to take')
flags.DEFINE_float('t_lo', 0.0, 'Start time.')
flags.DEFINE_float('t_hi', 1.0, 'End time.')

# ---------------------- Metrics parameters
flags.DEFINE_float('metrics_tol', 1e-6, 'Tolerance for computing metrics.')
flags.DEFINE_float('rank_tol', None,
                   'Rank tolerance passed to the system level identifiability metrics.')
flags.DEFINE_bool('slow_sigma', True, 'Compute sigma_xx more precisely, but also more slowly.')

# ---------------------- Input/Output parameters
flags.DEFINE_string('output_dir', '../data', 'Directory to save the data')
flags.DEFINE_string('summary_sheet', '../data/summary.csv',
                    'Path of the CSV summary sheet listing all specifications')

# ---------------------- Misc
flags.DEFINE_bool('force', False, 'Recompute data even if it exists')
flags.DEFINE_bool('enable_x64', False, 'jax_enable_x64')

jax.config.parse_flags_with_absl()

FLAGS = flags.FLAGS

# Flags whose values are hashed to identify a specification and are recorded as
# columns of the summary sheet.
# NOTE(review): flags not listed here (`t_lo`, `t_hi`, `metrics_tol`,
# `slow_sigma`, `enable_x64`) do not enter the hash, so runs that differ only in
# those are treated as the same specification -- confirm this is intended.
setting_params = [
    'dim', 'frac_zeros', 'epsilon', 'check_zero_rows', 'unit_norm',
    'method_A', 'rank_tol', 'asymmetric',
]
quantity_params = ['num_systems', 'num_iv_per_system', 'steps', 'seed']


def get_hash(params: Dict):
  """Return the hexadecimal MD5 digest of the string form of `params`.

  The digest is taken over `str(params)` encoded as UTF-8, so it depends on the
  order of the dict's items as well as on their keys and values.
  """
  encoded = str(params).encode('utf-8')
  return hashlib.md5(encoded).hexdigest()


def main(_):
  """Generate, solve, and store systems for the current flag specification.

  The values of the flags named in `setting_params` and `quantity_params` are
  hashed; the hash identifies the run in the summary sheet and names the output
  file `<output_dir>/<hash>.h5`. If the hash is already listed in the summary
  sheet the function returns early unless `--force` is set. Otherwise it
  registers the specification as unfinished, generates systems and initial
  values, solves them, computes system and data level metrics, stores all of it
  via the `systems` / `metrics` helpers, and finally marks the specification as
  finished in the summary sheet.

  Args:
    _: Positional command line arguments forwarded by `app.run` (unused).
  """
  jax.config.update('jax_enable_x64', FLAGS.enable_x64)
  FLAGS.alsologtostderr = True

  summary_sheet = pathlib.Path(FLAGS.summary_sheet)
  tracked_params = setting_params + quantity_params
  # Keep the iteration order of `flag_values_dict()`: `get_hash` is sensitive
  # to the order of the items.
  values = {k: v for k, v in FLAGS.flag_values_dict().items()
            if k in tracked_params}
  logging.info(f'Check if current specification {values} already exists...')

  if summary_sheet.exists():
    df = pd.read_csv(summary_sheet)
  else:
    df = pd.DataFrame(columns=tracked_params + ['hash', 'finished'])
    logging.info('Creating dataframe...')

  cur_hash = get_hash(values)
  already_listed = cur_hash in df['hash'].values
  if already_listed and not FLAGS.force:
    logging.info('Specification already exists. Use `force` arg to recompute.')
    return
  if already_listed:
    # Forced recomputation: drop the stale row(s) so the sheet keeps exactly
    # one entry per hash instead of accumulating duplicates.
    df = df[df['hash'] != cur_hash]

  values['hash'] = cur_hash
  values['finished'] = False

  df = pd.concat([df, pd.DataFrame([values])], ignore_index=True)
  # The sheet is written before the output directory is created below, so make
  # sure its parent directory exists; `to_csv` does not create directories.
  summary_sheet.parent.mkdir(parents=True, exist_ok=True)
  df.to_csv(summary_sheet, index=False)
  logging.info(f"Added the specification {cur_hash} to the result sheet.")

  output_dir = pathlib.Path(FLAGS.output_dir)
  output_dir.mkdir(parents=True, exist_ok=True)
  outpath = output_dir / f"{cur_hash}.h5"
  # NOTE(review): with --force an existing file at `outpath` is not removed
  # here; whether the store_* helpers overwrite or append is not visible from
  # this file.

  logging.info(f"Save all output to {outpath}...")

  # `log_dir` is a string flag, so hand it a plain string rather than a Path.
  FLAGS.log_dir = str(output_dir)
  logging.get_absl_handler().use_absl_log_file(program_name="generate_data")

  logging.info(f"Set random seed {FLAGS.seed}...")
  np.random.seed(FLAGS.seed)
  key = jax.random.PRNGKey(FLAGS.seed)

  logging.info(
      f"Generate {FLAGS.num_systems} systems and {FLAGS.num_iv_per_system} "
      "initial conditions per system with:")
  logging.info(f"\tdim: {FLAGS.dim}")
  logging.info(f"\tfrac_zeros: {FLAGS.frac_zeros}")
  logging.info(f"\tepsilon: {FLAGS.epsilon}")
  logging.info(f"\tcheck_zero_rows: {FLAGS.check_zero_rows}")
  logging.info(f"\tunit_norm: {FLAGS.unit_norm}")
  logging.info(f"\tasymmetric: {FLAGS.asymmetric}")
  logging.info(f"\tmethod_A: {FLAGS.method_A}")
  logging.info(f"\trank_tol: {FLAGS.rank_tol}")

  # Generate and store systems and initial conditions
  A, x0 = systems.candidate_problems(
      key, FLAGS.dim, FLAGS.frac_zeros, FLAGS.num_systems,
      FLAGS.num_iv_per_system, FLAGS.asymmetric, FLAGS.epsilon,
      FLAGS.method_A, FLAGS.check_zero_rows, FLAGS.unit_norm)
  logging.info(f"Store systems and initial conditions to {outpath}...")
  systems.store_systems_and_iv(outpath, A, x0)

  # Solve all IV problems and store solutions
  num_problems = FLAGS.num_systems * FLAGS.num_iv_per_system
  logging.info(f"Solve all {num_problems} problems for {FLAGS.steps} steps...")
  xs, tt = systems.solve_system(A, x0, FLAGS.steps, (FLAGS.t_lo, FLAGS.t_hi))
  logging.info(f"Store solutions to {outpath}...")
  systems.store_solutions(outpath, xs, tt)

  # Compute and store system level metrics
  logging.info("Compute system level identifiability metrics...")
  sys_metrics = metrics.system_level_identifiability(
      A, x0, FLAGS.rank_tol, FLAGS.metrics_tol)
  logging.info(f"Store system level identifiability metrics to {outpath}...")
  metrics.store_metrics(outpath, sys_metrics, level='system')

  # Compute and store data level metrics
  logging.info("Compute data level identifiability metrics...")
  data_metrics = metrics.data_level_identifiability(
      xs, tt, FLAGS.metrics_tol, FLAGS.slow_sigma)
  logging.info(f"Store data level identifiability metrics to {outpath}...")
  metrics.store_metrics(outpath, data_metrics, level='data')

  # Updating summary sheet to indicate that everything finished correctly
  df.loc[df['hash'] == cur_hash, 'finished'] = True
  df.to_csv(summary_sheet, index=False)

  logging.info("DONE")


# Script entry point: hand `main` to absl's `app.run` when executed directly.
if __name__ == '__main__':
    app.run(main)
