from functools import partial
from itertools import combinations

import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsp

#import scipy as sp
#from scipy.stats import unitary_group

import tensorcircuit as tc
from utils.tc_utils import *
from utils.utils import *

from opt_einsum import contract

# Select JAX as the tensorcircuit backend; the value returned by
# set_backend is kept in the module-level name K.
K = tc.set_backend('jax')
# Use complex64 as tensorcircuit's dtype (the arrays built in this file are
# cast to jnp.complex64 as well).
tc.set_dtype('complex64')

@partial(jax.jit, static_argnames=('n_ancilla', 'n_qubits'))
def _random_measure_jax(inputs: jnp.ndarray,
                        n_ancilla: int,
                        n_qubits: int,
                        key: jax.random.PRNGKey) -> jnp.ndarray:
    """
    Sample one ancilla measurement outcome per batch row and return the
    renormalised system amplitudes belonging to that outcome.

    The flat amplitude index is interpreted as
    ``ancilla_outcome * 2**n_qubits + system_index``.

    Args:
      inputs: [batch, 2**(n_ancilla + n_qubits)] complex amplitudes.
      n_ancilla: number of ancilla qubits (static under jit).
      n_qubits: number of system qubits (static under jit).
      key: PRNG key; one uniform number is drawn per batch row.

    Returns: [batch, 2**n_qubits] post-measurement states (normalized).

    Note: a row whose amplitudes are all zero is not guarded against and
    yields NaN, because both normalisations divide by zero.
    """
    batch = inputs.shape[0]
    dim_anc = 2**n_ancilla
    dim_sys = 2**n_qubits

    # |amplitude|^2 laid out as [batch, ancilla outcome, system index];
    # summing over the system axis gives the ancilla outcome weights.
    amps = inputs.reshape((batch, dim_anc, dim_sys))
    m_probs = jnp.sum(jnp.abs(amps)**2, axis=2)      # [batch, 2**n_ancilla]

    # Inverse-CDF sampling of the ancilla outcome.
    rnd = jax.random.uniform(key, (batch,), minval=0.0, maxval=1.0)
    cum = jnp.cumsum(m_probs, axis=1)
    # Normalise by the cumulative sum's own last entry so that the final
    # entry is exactly 1. rnd is drawn from [0, 1), so at least one entry
    # always satisfies cum > rnd. Dividing by an independently computed
    # jnp.sum can leave the last entry slightly below 1 through rounding, in
    # which case the comparison is all False and argmax would silently
    # report outcome 0.
    cum = cum / cum[:, -1:]
    # First index where cum > rnd.
    m_res = jnp.argmax(cum > rnd[:, None], axis=1).astype(jnp.int32)

    # Gather the 2**n_qubits amplitudes that share the sampled outcome.
    idx = dim_sys * m_res[:, None] + jnp.arange(dim_sys)
    post = jnp.take_along_axis(inputs, idx, axis=1)  # [batch, 2**n_qubits]

    norm = jnp.linalg.norm(post, axis=1, keepdims=True)
    return post / norm

# NOTE: this is a module-level function, so it must not be wrapped in
# @staticmethod; before Python 3.10 a staticmethod object is not callable and
# calling this name would raise TypeError.
@partial(jax.jit, static_argnames=('n_ancilla','n_qubits'))
def _random_measure_pure(inputs, n_ancilla, n_qubits, subkey: jax.random.PRNGKey):
  """
    Side-effect-free measurement step with explicit PRNG key handling.

    The key is passed in as an argument and split inside the function so the
    whole step can be jitted.

    Args:
      inputs: amplitudes forwarded unchanged to _random_measure_jax
      n_ancilla: number of ancilla qubits (static under jit)
      n_qubits: number of data qubits (static under jit)
      subkey: PRNGKey that is split into a measurement key and a follow-up key
    Returns: (post_states, next_key); next_key is the half of the split that
      was not consumed by the measurement.
  """
  next_key, meas_key = jax.random.split(subkey)
  post_states = _random_measure_jax(inputs, n_ancilla, n_qubits, meas_key)
  return post_states, next_key

# NOTE: this is a module-level function, so it must not be wrapped in
# @staticmethod; before Python 3.10 a staticmethod object is not callable and
# calling this name would raise TypeError.
@partial(jax.jit, static_argnames=('n_ancilla','n_qubits','forward_circuit_vmap'))
def forward_output_pure(inputs, params, 
        n_ancilla, n_qubits,
        forward_circuit_vmap: callable,
        subkey: jax.random.PRNGKey):
  """
    Forward process at step t: apply the circuit, then measure the ancilla.

    Args: 
      inputs: the input data set at step t
      params: the parameters of the forward circuit at step t
      n_ancilla: number of ancilla qubits (static under jit)
      n_qubits: number of data qubits (static under jit)
      forward_circuit_vmap: callable invoked as forward_circuit_vmap(inputs, params);
        it is a static jit argument, so it has to be hashable
      subkey: PRNGKey for random measurement
    Returns: (output_t, next_key), the post-measurement states and the key to
      use for the next call.
  """
  # Outputs through quantum circuits before measurement
  output_full = forward_circuit_vmap(inputs, params)
  # Perform the ancilla measurement and get the post-measurement state
  output_t, next_key = _random_measure_pure(output_full, n_ancilla, n_qubits, subkey)
  return output_t, next_key

class MPE():
  """
    Chain of T circuit-plus-measurement steps acting on batches of states.

    Each step pads the data-qubit states with zeros up to the full register
    dimension, runs the vmapped forward circuit and measures the ancilla
    through forward_output_pure. The instance owns a PRNG key that is
    advanced on every step.
  """

  def __init__(self, n_qubits, n_ancilla, T, n_layers, forward_circuit_type='rxycz', rseed=0):
    """
      n_qubits: number of data qubits
      n_ancilla: number of ancilla qubits
      T: number of steps
      n_layers: number of layers in forward circuit
      forward_circuit_type: type of forward circuit; one of 'rxycz', 'ryzcz',
        'rxyzcz' or 'SU2-full'
      rseed: integer seed of the PRNG key used for the random measurements

      Raises ValueError for an unsupported forward_circuit_type.
    """
    self.n_qubits = n_qubits
    self.n_ancilla = n_ancilla
    self.n_tot = n_qubits + n_ancilla
    self.T = T
    self.n_layers = n_layers
    self.forward_circuit_type = forward_circuit_type

    # Number of trainable parameters of one forward circuit
    if self.forward_circuit_type in ('rxycz', 'ryzcz'):
      self.forward_n_params = 2 * self.n_tot * self.n_layers
    elif self.forward_circuit_type == 'rxyzcz':
      self.forward_n_params = 3 * self.n_tot * self.n_layers
    elif self.forward_circuit_type == 'SU2-full':
      ent1_count = self.n_tot // 2
      ent2_count = (self.n_tot - 1) // 2
      params_per_layer = 3 * self.n_tot + 3 * (ent1_count + ent2_count)
      self.forward_n_params = params_per_layer * self.n_layers
    else:
      raise ValueError(f"Unsupported forward_circuit_type: {self.forward_circuit_type}")

    # Bind all static arguments via partial
    batched_circuit = partial(
      generator_circuit,
      total_qubits=self.n_tot,
      n_layers=self.n_layers,
      circuit_type=self.forward_circuit_type
    )

    # Map over the leading axis of the first argument (the states) while the
    # second argument (the parameters) is shared by the whole batch.
    vmapped_circuit = jax.vmap(batched_circuit, in_axes=(0, None))
    self.forward_circuit_vmap = jax.jit(vmapped_circuit)

    self._key = jax.random.PRNGKey(rseed)

  @staticmethod
  def _check_param_steps(params, n_steps, name):
    """Raise ValueError when params has fewer leading rows than n_steps."""
    # JAX clamps out-of-range integer indices instead of raising, so a too
    # short parameter array would silently reuse its last row.
    if params.ndim >= 1 and params.shape[0] < n_steps:
      raise ValueError(
        f"{name} provides parameters for {params.shape[0]} step(s), "
        f"but {n_steps} are required"
      )

  def _ancilla_padding(self, n_dat):
    """Zero block that extends n_dat data-qubit states to the full register."""
    dim_anc = 2**self.n_tot - 2**self.n_qubits
    return jnp.zeros((n_dat, dim_anc), dtype=jnp.complex64)

  def set_forward_states_diff(self, forward_states_diff):
    """Store forward_states_diff on the instance as a complex64 array."""
    self.forward_states_diff = jnp.asarray(forward_states_diff, dtype=jnp.complex64)

  def forward_output_t(self, inputs, params):
    """
      Forward process at step t
      Args: 
        inputs: the input data set at step t
        params: the parameters of the forward circuit at step t
      
      Non-jit wrapper that:
        1) pulls the stored key,
        2) calls the pure-JIT version to get (output_t, new_key),
        3) writes back new_key into self._key, and
        4) returns output_t.
    """
    old_key = self._key
    output_t, new_key = forward_output_pure(inputs, params, self.n_ancilla, self.n_qubits, self.forward_circuit_vmap, old_key)
    self._key = new_key  # update the key for the next call
    return output_t

  def prepare_input_t(self, inputs_T, params_cul, t, n_dat):
    """
    Prepare the input samples for step t
    Args:
      inputs_T: the input state at the beginning 
      params_cul: all circuit parameters till step t (at least t rows)
      t: step 
      n_dat: number of data
    Output: input samples (combined) and not combined input for step t.
      The combined samples are also stored in self.input_t_plus.
    Raises: ValueError if params_cul has fewer than t rows.
    """
    params_cul = jnp.asarray(params_cul, dtype=jnp.float32)
    self._check_param_steps(params_cul, t, 'params_cul')

    # Tensor with |0>^{n_ancilla} state and inputs_T is
    # the concat of the input matrix with zeros (remaining dimension)
    pad_zeros = self._ancilla_padding(n_dat)
    self.input_t_plus = jnp.concatenate([inputs_T, pad_zeros], axis=1)
    output = inputs_T
    for step in range(t):
      output = self.forward_output_t(self.input_t_plus, params_cul[step])
      self.input_t_plus = jnp.concatenate([output, pad_zeros], axis=1)
    output = output.astype(jnp.complex64)
    return self.input_t_plus, output

  def forward_gen_states(self, inputs_T, params_tot):
    """
    Generate the dataset in the forward process
    with training data set
    Args:
      inputs_T: the input state at the beginning
      params_tot: all circuit parameters (at least self.T rows)
    Output: stacked states, inputs_T followed by the output of each of the
      self.T steps
    Raises: ValueError if params_tot has fewer than self.T rows.
    """
    params_tot = jnp.asarray(params_tot, dtype=jnp.float32)
    self._check_param_steps(params_tot, self.T, 'params_tot')

    pad_zeros = self._ancilla_padding(len(inputs_T))
    input_t_plus = jnp.concatenate([inputs_T, pad_zeros], axis=1)
    states = [inputs_T]
    for step in range(self.T):
      output = self.forward_output_t(input_t_plus, params_tot[step])
      input_t_plus = jnp.concatenate([output, pad_zeros], axis=1)
      states.append(output)
    return jnp.stack(states)
  
