class ClassicalShadow:
    """Randomized single-qubit-basis measurement of a circuit, one shot at a time.

    For every qubit one of three 2x2 unitaries stored in ``self.U``
    (``I``, ``H``, ``S^dagger H``) is drawn at random, the matching gates are
    appended to a copy of the circuit, all qubits are measured once, and the
    per-qubit outcomes are turned into 2x2 matrices by ``M_inverse``.

    Attributes:
        backend: ``AerSimulator`` instance used to execute circuits.
        circuit: the circuit handed to the constructor; it is never modified
            by ``approx`` or ``approx_trace`` (both work on copies).
        num_of_qubits: number of qubits that get a random basis change.
        Zero, One: column vectors ``[[1], [0]]`` and ``[[0], [1]]``.
        U: list ``[I, H, S^dagger H]`` indexed by the random choice 0/1/2.
    """

    def __init__(self, circuit, num_of_qubits):
        """Store the circuit and build the basis states and unitaries.

        Args:
            circuit: circuit object providing ``copy``, ``id``, ``h``, ``s``
                and ``measure_all`` (presumably a Qiskit ``QuantumCircuit``).
            num_of_qubits: how many qubits (indices ``0..n-1``) to rotate.
        """
        self.backend = AerSimulator()
        self.circuit = circuit
        self.num_of_qubits = num_of_qubits
        self.Zero = np.matrix([[1], [0]])
        self.One = np.matrix([[0], [1]])

        identity = np.matrix([[1, 0], [0, 1]])
        hadamard = 1/np.sqrt(2) * np.matrix([[1, 1], [1, -1]])
        phase = np.matrix([[1, 0], [0, 1j]])

        # Index 0 -> I, 1 -> H, 2 -> S^dagger H (getH() is the conjugate
        # transpose of an np.matrix).
        self.U = [identity, hadamard, phase.getH() @ hadamard]

    def _bits_to_kets(self, bitstring):
        """Map a result string ``b_{n-1}...b_0`` to ``[b_0, ..., b_{n-1}]`` kets.

        The string is read right-to-left so that list index ``i`` holds the
        readout of qubit ``i``.  A ``'0'`` character maps to ``self.Zero``;
        every other character maps to ``self.One``.
        """
        # NOTE(review): a key with several classical registers would contain
        # separator characters, which are treated as '1' here -- confirm the
        # circuits passed in carry no classical registers of their own.
        return [self.Zero if bit == '0' else self.One
                for bit in reversed(bitstring)]

    def _append_random_basis(self, qc):
        """Append one randomly chosen basis change per qubit to ``qc``.

        Args:
            qc: circuit that is modified in place.

        Returns:
            List with the chosen entry of ``self.U`` for each qubit, in qubit
            order.
        """
        selected_U = []
        for qubit in range(self.num_of_qubits):
            # 300 is a multiple of 3, so the remainder is uniform on {0, 1, 2}.
            idx = np.random.randint(300) % 3
            selected_U.append(self.U[idx])
            if idx == 0:
                qc.id(qubit)  # identity
            elif idx == 1:
                qc.h(qubit)  # Hadamard
            else:
                # A gate V evolves the state as V rho V^dagger, but the
                # wanted evolution is U^dagger rho U with U = S^dagger H.
                # Hence the circuit applies U^dagger = H S: first S, then H.
                qc.s(qubit)
                qc.h(qubit)
        return selected_U

    def run(self, qc, shots=1):
        """Transpile ``qc`` for the backend, execute it and decode the result.

        Args:
            qc: circuit to execute.
            shots: number of repetitions; the normal mode uses 1.

        Returns:
            For ``shots <= 1``: a tuple ``(result_sim, ret)`` where
            ``result_sim`` is the first key of the counts dictionary as a
            string (``b_{n-1}...b_0``) and ``ret`` is ``[b_0, ..., b_{n-1}]``
            with each entry being ``self.Zero`` or ``self.One``.
            For ``shots > 1``: ``None`` ("debug mode").
        """
        qc_compiled = transpile(qc, self.backend)
        job_sim = self.backend.run(qc_compiled, shots=shots)

        if shots > 1:
            # debug mode
            print("Debug Mode")
            # NOTE(review): the counts are fetched but discarded and None is
            # returned, so callers that unpack two values cannot use this
            # mode -- confirm whether the counts should be returned instead.
            job_sim.result().get_counts(qc_compiled)
            return None

        counts = job_sim.result().get_counts(qc_compiled)
        result_sim = str(next(iter(counts)))
        return result_sim, self._bits_to_kets(result_sim)

    def approx(self):
        """Take one randomized measurement and return the resulting matrix.

        Works on a copy of ``self.circuit`` so the method can be called
        repeatedly; each call is an independent draw.

        Returns:
            ``M_inverse`` results of all qubits combined with ``np.kron``,
            qubit 0 being the left-most factor.
        """
        # NOTE(review): with qubit 0 as the left-most Kronecker factor the
        # ordering may be the reverse of the simulator's own state ordering
        # -- confirm what callers compare this matrix against.

        # Steps 1-2: random basis change on a fresh copy, then measure.
        qc = self.circuit.copy()
        selected_U = self._append_random_basis(qc)
        qc.measure_all()

        # Step 3: run the circuit once.
        _, bhat = self.run(qc)

        # Step 4: per-qubit inverse-channel matrices.
        rho_hats = [self.M_inverse(selected_U[i], bhat[i])
                    for i in range(self.num_of_qubits)]

        # Step 5: rho_hat = rho_hat_0 (x) rho_hat_1 (x) ... (x) rho_hat_{n-1}
        return reduce(np.kron, rho_hats)

    def approx_trace(self, num_runs=1):
        """Average ``M_inverse_trace`` over ``num_runs`` randomized measurements.

        Args:
            num_runs: number of independent measurement rounds (>= 1).

        Returns:
            Real part of the mean of the per-run trace contributions.

        Raises:
            ValueError: if ``num_runs`` is smaller than 1.
        """
        if num_runs < 1:
            raise ValueError("num_runs must be at least 1")

        total_trace_estimate = 0
        for _ in range(num_runs):
            # Fresh copy per run so gates and measurements never accumulate.
            qc = self.circuit.copy()
            selected_U = self._append_random_basis(qc)
            qc.measure_all()
            _, bhat = self.run(qc)
            total_trace_estimate += self.M_inverse_trace(selected_U, bhat)

        return np.real(total_trace_estimate / num_runs)

    def M_inverse(self, U, b):
        """Return the 2x2 matrix built from the vector ``v = U @ b``.

        With ``v = (d0, d1)`` the result is::

            [[2|d0|^2 - |d1|^2,  3 d0 conj(d1)  ],
             [3 d1 conj(d0),     2|d1|^2 - |d0|^2]]

        which equals ``3 |v><v| - I`` whenever ``v`` has unit norm.

        Args:
            U: 2x2 matrix.
            b: 2x1 column vector.

        Returns:
            ``np.matrix`` of shape (2, 2).
        """
        d = np.array(U @ b).flatten()
        p0 = d[0] * np.conjugate(d[0])
        p1 = d[1] * np.conjugate(d[1])
        rho_hat11 = 2.0 * p0 - 1.0 * p1
        rho_hat12 = 3.0 * d[0] * np.conjugate(d[1])
        rho_hat21 = 3.0 * d[1] * np.conjugate(d[0])
        rho_hat22 = 2.0 * p1 - 1.0 * p0
        return np.matrix([[rho_hat11, rho_hat12], [rho_hat21, rho_hat22]])

    def M_inverse_trace(self, U, measurement):
        """Return the product over qubits of ``|U[i] @ measurement[i]|^2``.

        Each factor is the trace of the corresponding ``M_inverse`` matrix,
        and Tr(A (x) B) = Tr(A) * Tr(B), so the product is the trace of the
        Kronecker product.  For unitary ``U[i]`` and unit-norm outcomes every
        factor is 1.

        Args:
            U: sequence of 2x2 matrices, one per qubit.
            measurement: sequence of 2x1 column vectors, one per qubit.

        Returns:
            The scalar product of the squared norms (1 for empty input).
        """
        trace_contribution = 1
        for i, outcome in enumerate(measurement):
            state = np.array(np.dot(U[i], outcome)).flatten()
            trace_contribution *= (state[0] * np.conjugate(state[0]) +
                                   state[1] * np.conjugate(state[1]))
        return trace_contribution

