import numpy as np
import os
import math

from utils.distance_jax import *
from data.data_utils import *
from model.mpe_jax import *
from data.mol_data import *

def generate_data(input_type, dat_name, n_qubits, n_train, n_test, rseed, g_range=None, n_atoms=None, n_rings=None):
  """Generate the target ("real") dataset together with train/test input states.

  Args:
    input_type: ensemble of input states: 'line', 'circle', 'product' (Haar
      product states) or 'diffusion' (random basis states); any other value
      falls back to Haar-random states.
    dat_name: target dataset, one of 'cluster0', 'multi_cluster', 'line',
      'circle', 'tfim' or 'qm9'.
    n_qubits: number of qubits.
    n_train: number of training input states.
    n_test: number of test input states.
    rseed: base seed. Train inputs use rseed + 27, test inputs rseed + 2728,
      and the seeded target generators ('cluster0', 'multi_cluster', 'tfim')
      use rseed + 72. The 'line' / 'circle' target generators are called
      without a seed.
    g_range: range forwarded to the TFIM ground-state generator (only used
      when dat_name == 'tfim'); defaults to [0.0, 1.0].
    n_atoms: molecule-size filter, only used when dat_name == 'qm9'.
    n_rings: ring-count filter, only used when dat_name == 'qm9'.

  Returns:
    Tuple (real_states, train_input_states, test_input_states). The target set
    is requested with 10 * max(n_train, n_test) states; for 'qm9' it is at
    most that many rows of the filtered dataset.

  Raises:
    NameError: if dat_name is not one of the supported dataset names.
  """
  valid_dat_names = ('cluster0', 'multi_cluster', 'line', 'circle', 'tfim', 'qm9')
  # Validate up front so a typo fails before any (potentially slow) state generation.
  if dat_name not in valid_dat_names:
    raise NameError(f"Data name {dat_name} used does not match any of: {', '.join(valid_dat_names)}")
  # None sentinel instead of a shared mutable list default.
  if g_range is None:
    g_range = [0.0, 1.0]

  # The target pool is 10x larger than the biggest input set.
  n_full_data = 10 * max(n_train, n_test)

  # Input states
  if input_type == 'line':
    train_input_states = gen_line_data_with_jax(nstates = n_train, nqubits = n_qubits, seed = rseed + 27)
    test_input_states = gen_line_data_with_jax(nstates = n_test, nqubits = n_qubits, seed = rseed + 2728)
  elif input_type == 'circle':
    train_input_states = gen_circle_data_with_jax(nstates = n_train, nqubits = n_qubits, seed = rseed + 27)
    test_input_states = gen_circle_data_with_jax(nstates = n_test, nqubits = n_qubits, seed = rseed + 2728)
  elif input_type == 'product':
    train_input_states = gen_Haar_product_states(nstates = n_train, nqubits = n_qubits, seed = rseed + 27)
    test_input_states = gen_Haar_product_states(nstates = n_test, nqubits = n_qubits, seed = rseed + 2728)
  elif input_type == 'diffusion':
    # Random basis states; the projected ensemble framework is applied later.
    train_input_states = gen_rand_basis_states(nstates = n_train, nqubits = n_qubits, seed = rseed + 27)
    test_input_states = gen_rand_basis_states(nstates = n_test, nqubits = n_qubits, seed = rseed + 2728)
  else:
    # Fallback: Haar-random input states
    train_input_states = gen_Haar_states(nstates = n_train, nqubits = n_qubits, seed = rseed + 27)
    test_input_states = gen_Haar_states(nstates = n_test, nqubits = n_qubits, seed = rseed + 2728)

  # Target ("real") states
  if dat_name == 'cluster0':
    real_states = gen_cluster_0(nstates = n_full_data, nqubits = n_qubits, scale = 0.06, seed = rseed + 72)
  elif dat_name == 'multi_cluster':
    real_states = generate_multi_clustered_states(n_qubits, seed=rseed + 72, N=n_full_data, scale=0.05)
  elif dat_name == 'line':
    real_states = gen_line_data_with_jax(nstates = n_full_data, nqubits = n_qubits)
  elif dat_name == 'circle':
    real_states = gen_circle_data_with_jax(nstates = n_full_data, nqubits = n_qubits)
  elif dat_name == 'tfim':
    real_states = gen_tfim_ground_states_qt(N = n_full_data, g_range=g_range, n = n_qubits, seed = rseed + 72, use_dmrg=False)
  else:
    # dat_name == 'qm9' (guaranteed by the validation above)
    full_dataset = QDrugDataset(dat_name, n_qubits, load_from_cache=True, file_path='../datasets/mol/', n_atoms=n_atoms)
    dataset, _ = filter_dataset_by_properties(full_dataset, target_num_rings=n_rings, target_n_atoms=n_atoms, min_mol_weight=None, max_mol_weight=None)
    # Exclude the last column, which is not part of the quantum state.
    real_states = dataset[:n_full_data, :-1]
  return real_states, train_input_states, test_input_states


def distance_evolution(dist_file, X0, Xout, eval_mode=False):
  """Compute (or load cached) per-step distances to a reference set and plot them.

  For every step t, the states Xout[t] are compared against the reference
  states X0 with natural_distance_jax and wass_distance_jax, and their
  diversity is measured with vendi_score. The stacked array
  [mmd; wass; vendi] of shape (3, T) is saved to `dist_file`; if that file
  already exists it is loaded instead of recomputed. The MMD and Wasserstein
  curves (not the Vendi score) are plotted on a log scale and saved as a PDF.

  Args:
    dist_file: path of the .npy cache file. The figure is written to the same
      path with a trailing '.npy' replaced by '.pdf'.
    X0: reference states; axis 0 indexes samples.
    Xout: states per step, indexed as Xout[t, sample, ...].
    eval_mode: if falsy, each step is subsampled without replacement (using
      the global NumPy RNG) to the size of X0, capped at the number of
      available samples; if truthy, all samples of each step are used.
  """
  if os.path.isfile(dist_file):
    dists = np.load(dist_file)
  else:
    n_steps, n_full = Xout.shape[0], Xout.shape[1]
    # Match the reference set size; sampling without replacement cannot
    # draw more points than exist, so cap at n_full (avoids a ValueError).
    n_sample = min(X0.shape[0], n_full)

    mmd = np.zeros(n_steps)
    wass = np.zeros(n_steps)
    vendi = np.zeros(n_steps)

    for t in range(n_steps):
      if eval_mode:
        Xt = Xout[t, :]
      else:
        idx = np.random.choice(n_full, n_sample, replace=False)
        Xt = Xout[t, idx]
      mmd[t] = natural_distance_jax(X0, Xt)
      wass[t] = wass_distance_jax(X0, Xt)
      vendi[t] = vendi_score(Xt)

    dists = np.vstack((mmd, wass, vendi))
    np.save(dist_file, dists)

  # Plot the distance curves
  lw = 3
  mkz = 8
  putils.setPlot(fontsize=30, labelsize=30, lw=lw)
  fig, axs = plt.subplots(1, 1, figsize=(12,8), squeeze=False)
  axs = axs.ravel()
  ax = axs[0]
  putils.set_axes_facecolor(axs)
  ax.set_title(os.path.basename(dist_file), fontsize=12)
  ax.plot(dists[0], 'o--', mfc='white', markersize=mkz, c=putils.RED_m, lw=lw,
          label=r'$\mathcal{D}_{\rm MMD}(\mathcal{S}_t,\mathcal{S}^\prime_0)$')
  ax.plot(dists[1], 'o--', mfc='white', markersize=mkz, c=putils.BLUE_m, lw=lw,
          label=r'$W(\mathcal{S}_t,\mathcal{S}^\prime_0)$')
  ax.set_yscale('log')
  putils.set_axes_tick1([ax], xlabel='$t$', ylabel='Dist.', legend=True, tick_minor=False, top_right_spine=True, w=3, tick_length_unit=5)

  # Strip only a trailing '.npy' so '.npy' elsewhere in the path is untouched.
  fig_file = dist_file[:-len('.npy')] if dist_file.endswith('.npy') else dist_file
  plt.tight_layout()
  plt.savefig('{}.{}'.format(fig_file, 'pdf'), bbox_inches = 'tight', dpi=300)
  plt.show()
  plt.close(fig)


def eval_MPE(save_file, model, real_states, input_states, params_cul, plot_bloch=False):
  """Evaluate the forward process of an MPE model.

  Runs model.forward_gen_states on `input_states`, keeps only the first
  2**model.n_qubits entries along the last axis, and passes the per-step
  states to distance_evolution in eval mode; distances are cached in
  f'{save_file}_DIST.npy'. Optionally draws one Bloch-sphere figure per step.

  Args:
    save_file: path prefix for every output file.
    model: MPE model providing forward_gen_states and n_qubits.
    real_states: reference states the forward states are compared against.
    input_states: initial states of the forward process.
    params_cul: parameters forwarded to model.forward_gen_states.
    plot_bloch: when > 0, plot each step's states on the Bloch sphere.
  """
  trajectory = model.forward_gen_states(input_states, params_cul)
  trajectory = trajectory[:, :, :2**model.n_qubits]
  distance_evolution(f'{save_file}_DIST.npy', real_states, trajectory, eval_mode=True)

  if not (plot_bloch > 0):
    return
  for step, step_states in enumerate(trajectory):
    plot_Bloch_sphere(f'{save_file}_t_{step}', step_states, f'Forward states step={step}')


def plot_loss_hist_all(save_file, loss_hist_all, dist_hist_all, vendi_hist_all):
  """Plot total, distance and Vendi loss histories of all training stages.

  Row i of each history array holds the per-epoch values of stage i. Stages
  are drawn back-to-back along a shared epoch axis (stage i starts where
  stage i-1 ended). The first row of `loss_hist_all` that is entirely zero
  (or empty) ends the plot. This assumes unused stages are left as
  zero-filled rows of a preallocated buffer (TODO confirm against the
  training loop). The figure is saved to f'{save_file}_LOSS.pdf'.

  Args:
    save_file: path prefix of the output figure.
    loss_hist_all: total loss, indexed [stage, epoch].
    dist_hist_all: distance loss, same layout as loss_hist_all.
    vendi_hist_all: Vendi loss, same layout as loss_hist_all.
  """
  putils.setPlot(fontsize=26, labelsize=26, lw=2)
  fig, axs = plt.subplots(3, 1, figsize=(24, 18), sharex=True, sharey=False, squeeze=False)
  axs = axs.ravel()
  ax, bx, cx = axs[0], axs[1], axs[2]

  offset = 0
  for i in range(loss_hist_all.shape[0]):
    loss = loss_hist_all[i]
    # Stop at the first all-zero (unused) stage. `max(loss) == 0` would also
    # stop at a real stage whose (possibly negative) losses peak at exactly 0,
    # and would raise on an empty row.
    if not np.any(loss):
      break
    xs = offset + np.arange(loss.shape[0])
    ax.plot(xs, loss, lw=2, color=putils.BLUE_m, alpha=0.7)
    bx.plot(xs, dist_hist_all[i], lw=2, color=putils.VERMILLION_m, alpha=0.7)
    cx.plot(xs, vendi_hist_all[i], lw=2, color=putils.GREEN_m, alpha=0.7)
    offset += loss.shape[0]

  putils.set_axes_tick1(axs, xlabel=r'$\rm Epochs$', ylabel='Loss', legend=False, tick_minor=True, top_right_spine=True, w=3, tick_length_unit=5)
  # Per-panel labels override the shared 'Loss' label set above.
  ax.set_ylabel('Total loss', fontsize=30)
  bx.set_ylabel('Distance loss', fontsize=30)
  cx.set_ylabel('Vendi loss', fontsize=30)
  plt.tight_layout()
  plt.savefig('{}.{}'.format(f'{save_file}_LOSS', 'pdf'), bbox_inches = 'tight', dpi=300)
  plt.show()
  plt.close(fig)


