import numpy as np
import warnings

# NumPy 2.0 removed the top-level ``np.ComplexWarning`` alias; the class now
# lives only in ``numpy.exceptions`` (available since NumPy 1.25). Some
# dependencies still look up ``np.ComplexWarning``, so re-expose it here.
# Reuse NumPy's *own* class whenever it exists: a freshly defined subclass
# would never match the warnings NumPy actually emits, which would make the
# "ignore" filter below a no-op.
try:
    from numpy.exceptions import ComplexWarning
except ImportError:
    try:
        # NumPy < 1.25 still exposes the class at the top level.
        from numpy import ComplexWarning
    except ImportError:
        class ComplexWarning(Warning):
            """Stand-in used only when NumPy exposes no ComplexWarning class."""
            pass

np.ComplexWarning = ComplexWarning
# Silence complex -> real casting warnings for the rest of the run.
warnings.simplefilter('ignore', ComplexWarning)

import os
import argparse

from model.mpe_utils import generate_data
from model.train_utils import train_MPE
from model.mpe_jax import MPE
from utils.loginit import get_module_logger


def check_state_norms(real_states: np.ndarray, tolerance: float = 1e-10, *, logger=None) -> tuple:
    """Check whether each row of ``real_states`` has unit L2 norm, and normalize it.

    A feature vector can only be amplitude-encoded as a quantum state if its
    L2 norm is 1. This reports which rows already satisfy that and returns a
    normalized copy of the input.

    Args:
        real_states (np.ndarray): Array whose first axis indexes molecules, e.g.
            of shape (n_molecules, 2^n_qbits). Any trailing axes are flattened
            when computing a row's norm.
        tolerance (float): Absolute tolerance for treating a norm as 1
            (default: 1e-10).
        logger: Optional logger used for zero-norm warnings. If omitted, uses a
            module-level ``logger`` global when one exists (as when this file
            runs as a script), otherwise ``logging.getLogger(__name__)``.

    Returns:
        tuple: (is_valid, norms, normalized_states)
            - is_valid (list[bool]): True where ``|norm - 1| < tolerance``.
            - norms (np.ndarray): float64 array with the L2 norm of each row.
            - normalized_states (np.ndarray): Same shape as ``real_states``,
              with each row divided by its norm. Zero-norm rows cannot be
              normalized; they are replaced by the uniform unit vector and a
              warning is logged. Integer input is promoted to float64; float
              and complex dtypes are preserved.
    """
    import logging
    import math

    if logger is None:
        # ``logger`` is only defined at module level when this file is run as a
        # script; fall back to the standard logging module when imported.
        logger = globals().get("logger") or logging.getLogger(__name__)

    states = np.asarray(real_states)
    n_molecules = states.shape[0]
    row_size = math.prod(states.shape[1:])  # 1 for a 1-D input
    rows = states.reshape(n_molecules, row_size)

    norms = np.linalg.norm(rows, axis=1).astype(np.float64)
    is_valid = (np.abs(norms - 1.0) < tolerance).tolist()

    # Divide zero-norm rows by 1 (not 0) to avoid 0/0; they are overwritten below.
    zero_norm = norms == 0
    normalized_rows = rows / np.where(zero_norm, 1.0, norms)[:, np.newaxis]
    for i in np.flatnonzero(zero_norm):
        logger.warning("Molecule %d has zero norm; assigning uniform vector", i)
    if zero_norm.any():
        normalized_rows[zero_norm] = 1.0 / np.sqrt(max(row_size, 1))

    # Keep float/complex precision; integers must not be truncated on store.
    out_dtype = states.dtype if np.issubdtype(states.dtype, np.inexact) else np.float64
    normalized_states = normalized_rows.reshape(states.shape).astype(out_dtype, copy=False)
    return is_valid, norms, normalized_states


if __name__ == "__main__":
    # ---- Command-line interface -------------------------------------------
    # Each entry is (flag, type, default, help). The order matters: it fixes the
    # order of options in --help and of attributes on the parsed namespace.
    cli_options = (
        ('--save_dir', str, '../results/del_qdm_test1', None),
        ('--load_params', str, 'params_file', None),
        ('--n_layers', int, 10, 'number layers for forward circuit'),
        ('--n_diff_steps', int, 10, 'number steps for diffusion'),
        # Data
        ('--dat_name', str, 'qm9', 'name of the data: qm9'),
        ('--n_atoms', int, 8, 'Number of atoms to filter molecules'),
        ('--n_rings', int, 1, 'Number of rings to filter molecules'),
        ('--input_type', str, 'rand', 'type of the input'),
        ('--n_qubits', int, 7, 'Number of data qubits'),
        ('--n_ancilla', int, 3, 'Number of ancilla qubits'),
        ('--n_train', int, 100, 'Number of training data'),
        ('--n_test', int, 100, 'Number of test data'),
        # Training
        ('--n_outer_epochs', int, 10, 'Number of outer training epoch'),
        ('--batch_size', int, 100, 'Batch size for training'),
        ('--round_epochs', int, 10, 'Number of epochs to round'),
        ('--dist_type', str, 'wass', 'Type of distance: wass or mmd'),
        ('--lr', float, 0.0005, 'Learning rate'),
        ('--mag', float, 0.001, 'Magnitude of initial parameters'),
        ('--vendi_lambda', float, 0.0, 'Vendi loss lambda'),
        # Generator circuit
        ('--gen_circuit_type', str, 'ryzcz', 'type of generator circuit'),
        # System
        ('--rseed', int, 0, 'Random seed'),
        ('--bloch', int, 0, 'Plot bloch'),
        ('--threads', int, 1, 'Number of threads'),
    )
    parser = argparse.ArgumentParser()
    for flag, arg_type, default, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default, help=help_text)
    args = parser.parse_args()

    # ---- Output locations -------------------------------------------------
    log_dir = os.path.join(args.save_dir, 'log')
    res_dir = os.path.join(args.save_dir, 'res')
    for out_dir in (log_dir, res_dir):
        os.makedirs(out_dir, exist_ok=True)

    # Run identifier encoding every hyper-parameter; shared by log and results.
    basename = (
        f'{args.dat_name}_atoms_{args.n_atoms}_rings_{args.n_rings}'
        f'_{args.gen_circuit_type}_qubits_{args.n_qubits}_{args.n_ancilla}'
        f'_steps_{args.n_diff_steps}_{args.dist_type}_lays_{args.n_layers}'
        f'_in_{args.input_type}_dat_{args.n_train}_{args.n_test}'
        f'_epoch_{args.n_outer_epochs}_{args.batch_size}_lr_{args.lr}'
        f'_init_{args.mag}_vd_{args.vendi_lambda}_seed_{args.rseed}'
    )

    log_filename = os.path.join(log_dir, f'{basename}.log')
    logger = get_module_logger(__name__, log_filename, level='info')
    logger.info(log_filename)
    logger.info(args)

    np.random.seed(args.rseed)

    # ---- Data -------------------------------------------------------------
    # Only QM9 is supported.
    if args.dat_name != 'qm9':
        raise ValueError(f"Unknown dataset name: {args.dat_name}")

    real_states, train_input_states, test_input_states = generate_data(
        args.input_type, args.dat_name, args.n_qubits, args.n_train, args.n_test, args.rseed,
        n_atoms=args.n_atoms, n_rings=args.n_rings,
    )
    # Sanity-check that each target vector has unit norm (a valid quantum state).
    is_valid, norms, normalized_states = check_state_norms(real_states)
    logger.info(f"Checked {len(real_states)} molecules: {sum(is_valid)} have norm=1 (within tolerance)")

    print('Shapes', real_states.shape, train_input_states.shape)
    save_file = os.path.join(res_dir, basename)

    # ---- Model & training -------------------------------------------------
    model = MPE(
        n_qubits=args.n_qubits,
        n_ancilla=args.n_ancilla,
        T=args.n_diff_steps,
        n_layers=args.n_layers,
        forward_circuit_type=args.gen_circuit_type,
        rseed=args.rseed + 2810,
    )

    train_MPE(
        logger, model, save_file, real_states, train_input_states, test_input_states,
        args.n_outer_epochs, args.lr, args.rseed + 1234, args.bloch, args.dist_type,
        batch_size=args.batch_size, round_epochs=args.round_epochs,
        mag=args.mag, vendi_lambda=args.vendi_lambda,
    )
  
  
