import os
import argparse
import numpy as np
from tqdm import tqdm
import tensorflow as tf
from tensorflow.keras.models import load_model

from cyclegan.utils import h5 as h5
from cyclegan.utils.cascade.cascade2p import cascade, config

# Only let TensorFlow's Python logger emit messages at ERROR level or above,
# keeping the CLI / progress-bar output readable.
tf.get_logger().setLevel('ERROR')


def load_models(model_name: str = 'Global_EXC_25Hz_smoothing100ms'):
  """Fetch a pretrained Cascade model set and load all of its Keras models.

  Args:
    model_name: name of the pretrained Cascade model set.

  Returns:
    A tuple ``(models, model_config)``. ``models`` maps every key returned by
    ``cascade.get_model_paths`` (noise levels) to the list of Keras models
    loaded from the corresponding paths; ``model_config`` is the parsed
    ``config.yaml`` of the model set.
  """
  cascade.download_model(model_name)
  # NOTE(review): the model directory is resolved relative to THIS module's
  # directory; assumes cascade.download_model stores files there - confirm.
  model_dir = os.path.join(os.path.dirname(__file__), 'pretrained_models',
                           model_name)
  model_config = config.read_config(os.path.join(model_dir, 'config.yaml'))
  paths_by_noise = cascade.get_model_paths(model_dir)
  models = {
      noise_level: [load_model(path) for path in paths]
      for noise_level, paths in paths_by_noise.items()
  }
  return models, model_config


def signals2rates(signals, models, model_config, desc: str = None):
  """Run Cascade spike inference on every sample of ``signals``.

  Args:
    signals: array-like laid out as NWC (num. samples, time steps,
      num. neurons).
    models: dict of loaded Keras models as returned by ``load_models``.
    model_config: parsed Cascade config; the keys ``batch_size``,
      ``sampling_rate``, ``before_frac``, ``windowsize``, ``noise_levels``
      and ``smoothing`` are forwarded to ``cascade.predict``.
    desc: progress-bar label; the progress bar is hidden when ``None``.

  Returns:
    float32 spike rates with the same NWC shape as ``signals``.
  """
  # NWC -> NCW (num. samples, num. neurons, time steps): each per-sample
  # slice handed to cascade.predict is then (num. neurons, time steps).
  traces = np.transpose(signals, axes=[0, 2, 1])
  rates = np.zeros(shape=traces.shape, dtype=np.float32)
  progress = tqdm(traces, desc=desc, disable=desc is None)
  for index, trace in enumerate(progress):
    rates[index] = cascade.predict(
        trace,
        models=models,
        batch_size=model_config['batch_size'],
        sampling_rate=model_config['sampling_rate'],
        before_frac=model_config['before_frac'],
        window_size=model_config['windowsize'],
        noise_levels_model=model_config['noise_levels'],
        smoothing=model_config['smoothing'],
    )
  # NCW -> NWC (num. samples, time steps, num. neurons).
  return np.transpose(rates, axes=[0, 2, 1])


def deconvolve_signals(signals,
                       model_name: str = 'Global_EXC_25Hz_smoothing100ms',
                       desc: str = None):
  """Load a Cascade model set and convert ``signals`` into spike rates.

  Args:
    signals: NWC array-like (num. samples, time steps, num. neurons).
    model_name: name of the pretrained Cascade model set to use.
    desc: progress-bar label; no progress bar is shown when ``None``.

  Returns:
    float32 spike rates with the same shape as ``signals``.
  """
  # Reset Keras' global state before loading the requested models.
  tf.keras.backend.clear_session()
  return signals2rates(signals, *load_models(model_name), desc=desc)


def deconvolve_file(signals_filename,
                    spikes_filename,
                    model_name: str = 'Global_EXC_25Hz_smoothing100ms'):
  """Deconvolve all signal datasets of an HDF5 file into spike rates.

  Reads the datasets ``x``, ``y``, ``fake_x``, ``fake_y``, ``cycle_x`` and
  ``cycle_y`` from ``signals_filename``, converts each one with
  ``signals2rates`` and writes the result under the same key to
  ``spikes_filename``. An existing ``spikes_filename`` is removed first.

  Args:
    signals_filename: path of the input HDF5 file with the signal datasets.
    spikes_filename: path of the output HDF5 file (replaced if it exists).
    model_name: name of the pretrained Cascade model set to use.

  Raises:
    FileNotFoundError: if ``signals_filename`` does not exist.
    ValueError: if both paths resolve to the same file (the output is
      deleted before writing, which would destroy the input).
    KeyError: if any required dataset is missing from ``signals_filename``.
  """
  data_keys = ('x', 'y', 'fake_x', 'fake_y', 'cycle_x', 'cycle_y')

  if not os.path.exists(signals_filename):
    raise FileNotFoundError('{} not found'.format(signals_filename))

  # The output file is deleted below; never do that to the input file.
  if os.path.realpath(signals_filename) == os.path.realpath(spikes_filename):
    raise ValueError(
        'signals_filename and spikes_filename must differ, got {}'.format(
            signals_filename))

  # Validate the input before deleting the old output or loading models, so
  # a malformed input does not destroy an existing result. An explicit raise
  # is used instead of `assert`, which is stripped under `python -O`.
  keys = h5.get_keys(filename=signals_filename)
  missing = [key for key in data_keys if key not in keys]
  if missing:
    raise KeyError('{} is missing dataset(s): {}'.format(
        signals_filename, ', '.join(missing)))

  if os.path.exists(spikes_filename):
    os.remove(spikes_filename)

  # Reset Keras' global state before loading the requested models.
  tf.keras.backend.clear_session()
  models, model_config = load_models(model_name)

  print(f'deconvolve file {signals_filename}\n')
  for key in data_keys:
    signals = h5.get(signals_filename, key=key)
    spike_rates = signals2rates(signals,
                                models,
                                model_config,
                                desc=f'deconvolve {key}')
    h5.write(spikes_filename, data={key: spike_rates})


if __name__ == '__main__':
  parser = argparse.ArgumentParser(
      description='Deconvolve the signals stored in an HDF5 file into spike '
      'rates with a pretrained Cascade model.')
  parser.add_argument('--signals_filename',
                      type=str,
                      required=True,
                      help='input HDF5 file containing the signal datasets')
  parser.add_argument('--spikes_filename',
                      type=str,
                      required=True,
                      help='output HDF5 file for the spike rates '
                      '(replaced if it exists)')
  # Same default as deconvolve_file, so existing invocations are unchanged.
  parser.add_argument('--model_name',
                      type=str,
                      default='Global_EXC_25Hz_smoothing100ms',
                      help='name of the pretrained Cascade model set')
  args = parser.parse_args()
  deconvolve_file(signals_filename=args.signals_filename,
                  spikes_filename=args.spikes_filename,
                  model_name=args.model_name)
