import numpy as np
import warnings

# NumPy 2.0 removed the top-level ``np.ComplexWarning`` alias; the class now
# lives in ``np.exceptions`` (available since NumPy 1.25). Some dependencies
# still look it up as ``np.ComplexWarning``, so restore that alias -- but point
# it at NumPy's *real* class. Warning filters match by ``issubclass``, so a
# freshly defined, unrelated class would make the ``ignore`` filter below match
# nothing NumPy actually emits (and, on NumPy < 2.0, would clobber the real one).
_np_exceptions = getattr(np, 'exceptions', None)
if hasattr(_np_exceptions, 'ComplexWarning'):
    ComplexWarning = _np_exceptions.ComplexWarning
elif hasattr(np, 'ComplexWarning'):
    # Older NumPy (< 1.25): the class only exists at the top level.
    ComplexWarning = np.ComplexWarning
else:
    class ComplexWarning(Warning):
        """Re-create the old NumPy ComplexWarning class (fallback only)."""
del _np_exceptions

np.ComplexWarning = ComplexWarning
# Silence NumPy's complex-to-real casting warnings for the whole run.
warnings.simplefilter('ignore', ComplexWarning)

import os
import argparse

from model.mpe_utils import generate_data
from model.train_utils import train_MPE
from model.mpe_jax import MPE
from utils.loginit import get_module_logger

if __name__ == "__main__":
  # Check for command line arguments
  parser = argparse.ArgumentParser()
  parser.add_argument('--save_dir', type=str, default='../results/del_qdm_test1')
  parser.add_argument('--load_params', type=str, default='params_file')
  
  parser.add_argument('--n_layers', type=int, default=10, help='number layers for forward circuit')
  parser.add_argument('--n_diff_steps', type=int, default=10, help='number steps for diffusion')

  # For data
  parser.add_argument('--dat_name', type=str, default='circle', help='name of the data: cluster0, line, circle, tfim')
  parser.add_argument('--input_type', type=str, default='rand', help='type of the input')
  parser.add_argument('--n_qubits', type=int, default=4, help='Number of data qubits')
  parser.add_argument('--n_ancilla', type=int, default=4, help='Number of ancilla qubits')
  
  parser.add_argument('--n_train', type=int, default=10, help='Number of training data')
  parser.add_argument('--n_test', type=int, default=10, help='Number of test data')

  # For  training
  parser.add_argument('--n_outer_epochs', type=int, default=1000, help='Number of outer training epoch')
  parser.add_argument('--batch_size', type=int, default=100, help='Batch size for training')
  parser.add_argument('--round_epochs', type=int, default=10, help='Number of epochs to round')
  parser.add_argument('--dist_type', type=str, default='wass', help='Type of distance: wass or mmd')  


  # For update matrix
  parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
  parser.add_argument('--mag', type=float, default=1.0, help='Magnitude of initial parameters')
  parser.add_argument('--vendi_lambda', type=float, default=0.0, help='Vendi loss lambda')

  #parser.add_argument('--ep', type=float, default=0.01, help='Step size')

  # For gen circuit type
  parser.add_argument('--gen_circuit_type', type=str, default='rxycz', help='type of generator circuit')

  # For system
  parser.add_argument('--rseed', type=int, default=0, help='Random seed')
  parser.add_argument('--bloch', type=int, default=0, help='Plot bloch')
  parser.add_argument('--threads', type=int, default=1, help='Number of threads')

  args = parser.parse_args()

  save_dir, n_layers, n_diff_steps = args.save_dir, args.n_layers, args.n_diff_steps
  n_train, n_test, n_outer_epochs = args.n_train, args.n_test, args.n_outer_epochs
  dat_name, input_type, n_qubits, n_ancilla, rseed = args.dat_name, args.input_type, args.n_qubits, args.n_ancilla, args.rseed
  plot_bloch, load_params = args.bloch, args.load_params

  gen_circuit_type, n_threads = args.gen_circuit_type, args.threads
  batch_size, round_epochs = args.batch_size, args.round_epochs
  lr, mag, dist_type, vendi_lambda = args.lr, args.mag, args.dist_type, args.vendi_lambda

  # Create folder to save results
  log_dir = os.path.join(save_dir,'log' )
  res_dir = os.path.join(save_dir,'res')
  
  os.makedirs(log_dir, exist_ok=True)
  os.makedirs(res_dir, exist_ok=True)

  basename = f'{dat_name}_{gen_circuit_type}_qubits_{n_qubits}_{n_ancilla}_steps_{n_diff_steps}_{dist_type}_lays_{n_layers}_in_{input_type}_dat_{n_train}_{n_test}_epoch_{n_outer_epochs}_{batch_size}_lr_{lr}_init_{mag}_vd_{vendi_lambda}_seed_{rseed}'
  
  log_filename = os.path.join(log_dir, f'{basename}.log')
  logger = get_module_logger(__name__, log_filename, level='info')
  
  logger.info(log_filename)
  logger.info(args)
  
  # set random seed
  np.random.seed(rseed)

  # Create data for inputs and training
  real_states, train_input_states, test_input_states = generate_data(input_type, dat_name, n_qubits, n_train, n_test, rseed)
  print('Shapes', real_states.shape, train_input_states.shape)

  # Create data for inputs and training
  save_file = os.path.join(res_dir, f'{basename}')
  
  model = MPE(n_qubits=n_qubits, n_ancilla=n_ancilla, T = n_diff_steps, n_layers=n_layers, forward_circuit_type=gen_circuit_type, rseed=rseed+2810)
  
  train_MPE(logger, model, save_file, real_states, train_input_states, test_input_states,\
              n_outer_epochs, lr, rseed+1234, plot_bloch, dist_type, batch_size=batch_size, \
              round_epochs=round_epochs, mag=mag, vendi_lambda=vendi_lambda)
  
  
