
import random
import time
import argparse
from os import mkdir
from os.path import isdir
import datetime
import pickle as pkl

import numpy as np
import matplotlib.pyplot as plt

from qiskit import QuantumCircuit
from qiskit_aer import AerSimulator
from qiskit.quantum_info import Operator

import os
from threading import Lock, Event

from skopt import gbrt_minimize, gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args

from joblib import Parallel, delayed
import time
import random

import contextlib
import io


# Top-level folder that holds one sub-folder per experiment/target.
RESULTS_FOLDER = "horz_results"

# exist_ok avoids the check-then-create race of isdir() followed by mkdir();
# it still raises FileExistsError if the path exists as a regular file.
os.makedirs(RESULTS_FOLDER, exist_ok=True)


# Set experiment parameters
experiment_type = "random"  # One of Toffoli, Adder, VQE, Ampl_Embed, or random
num_runs_per_init = 5  # number of runs/trials per qubit initialization / target state
total_calls = 4000  # n_calls (objective evaluations) per optimization run
initial_calls = 500  # distributed mode only: evaluations in the initial full-space stage
num_shots = 10000  # number of shots to make when evaluating cost function
split_method = None  # distributed mode only: 'random' shuffles parameter indices before splitting
isDistributed = False
surrogate_choice = "gbrt"  # "gbrt" or "gp" ("gbqr" / "qrf" are not available)
cost_function = "TVD"  # "TVD" (total variation distance) or "UMD" (operator-norm distance between unitaries)
intentional_corr = "horz"  # forced parameter correlation in the ansatz: None, "vert", or "horz"

# Inclusive range of experiment_inits rows (targets) to run.
STARTING_INIT = 0
ENDING_INIT = 2

backend = AerSimulator()

# Per-experiment settings:
#   isCircuit        - True: the target is a circuit's measured output distribution;
#                      False: the target is a state built from experiment_inits (Ampl_Embed)
#   NUM_QUBITS       - width of the ansatz
#   PARAMS_PER_LAYER - 2 * NUM_QUBITS (RY + RZ per qubit) or NUM_QUBITS (RY only)
#   NUM_LAYERS       - ansatz depth
#   NUM_INIT_QUBITS  - qubits prepared with RY(initial angle) before the circuit
#   NUM_INITS        - number of selectable rows of experiment_inits
#   experiment_inits - one row per target
if experiment_type == "Toffoli":
    isCircuit = True
    qasm_file = "modded_toffoli_n3.qasm"

    NUM_QUBITS = 3
    PARAMS_PER_LAYER = 2 * NUM_QUBITS
    NUM_LAYERS = 2
    NUM_INIT_QUBITS = 3
    NUM_INITS = 2**NUM_INIT_QUBITS  # number of different qubit initializations to consider. For 3 qubits there are 8 basis states

    experiment_inits = np.pi * np.array([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]])

    params_lower_bound = -np.pi
    params_upper_bound = np.pi

elif experiment_type == "Adder":
    isCircuit = True
    qasm_file = "adder_n4.qasm"

    NUM_QUBITS = 4
    PARAMS_PER_LAYER = 2 * NUM_QUBITS
    NUM_LAYERS = 3
    NUM_INIT_QUBITS = 3
    NUM_INITS = 2**NUM_INIT_QUBITS

    experiment_inits = np.pi * np.array([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]])

    params_lower_bound = -np.pi
    params_upper_bound = np.pi

elif experiment_type == "VQE":
    isCircuit = True
    qasm_file = "modded_vqe_4.qasm"

    NUM_QUBITS = 4
    PARAMS_PER_LAYER = 2 * NUM_QUBITS
    NUM_LAYERS = 3
    NUM_INIT_QUBITS = 0
    NUM_INITS = 1

    experiment_inits = np.array([[0, 0, 0]])

    params_lower_bound = -np.pi
    params_upper_bound = np.pi

elif experiment_type == "Ampl_Embed":
    isCircuit = False

    NUM_QUBITS = 4
    PARAMS_PER_LAYER = NUM_QUBITS  # RY only, so amplitudes stay real
    NUM_LAYERS = 4
    NUM_INIT_QUBITS = 0
    NUM_INITS = 10

    # Each row is an unnormalized probability vector; it is normalized per experiment.
    experiment_inits = np.load("random_target_prob_dists_unnormalized.npy")

    params_lower_bound = -np.pi
    params_upper_bound = np.pi

elif experiment_type == "random":
    isCircuit = True

    NUM_QUBITS = 4
    PARAMS_PER_LAYER = 2 * NUM_QUBITS
    NUM_LAYERS = 3
    NUM_INIT_QUBITS = 0
    NUM_INITS = 3

    # experiment_inits contains the parameters of the random circuits used for testing
    experiment_inits = np.load("random_circuit_params.npy")

    params_lower_bound = -np.pi
    params_upper_bound = np.pi

else:
    exit("Invalid Experiment Type")

# Reject selections outside the available targets before any per-experiment folder is created.
# Also checks the real number of rows, which may be smaller than NUM_INITS for loaded .npy files.
if not (0 <= STARTING_INIT <= ENDING_INIT < min(NUM_INITS, len(experiment_inits))):
    exit("Invalid Init Selection")

# Quantities derived from the experiment settings above.
N_OUTPUT = 2 ** NUM_QUBITS  # number of computational-basis outcomes (length of an output distribution)
theta_dim = NUM_LAYERS * PARAMS_PER_LAYER  # total number of ansatz angles being optimized
n_threads = NUM_LAYERS  # distributed mode: each worker handles theta_dim / n_threads = PARAMS_PER_LAYER angles

# One bounded real dimension per angle, named theta0, theta1, ... so use_named_args can map them back.
init_param_space = [
    Real(params_lower_bound, params_upper_bound, name='theta%d' % param_num)
    for param_num in range(theta_dim)
]


def generate_trial_circ(params, layers=NUM_LAYERS, initial_params=None, num_init_qubits=0, fixed_corr=None):
  """
  Build a layered ansatz of RY, RZ and CNOT gates on NUM_QUBITS qubits, parameterized by params.

  Each layer applies an RY (and, depending on the mode, an RZ) to every qubit, followed by a
  CNOT chain 0->1->...->NUM_QUBITS-1. How params is indexed depends on fixed_corr:
    - None:   every layer has its own angles. Layer l, qubit q uses
              params[l * PARAMS_PER_LAYER + q * (PARAMS_PER_LAYER // NUM_QUBITS)] for RY and the next
              entry for RZ; RZ is only added when PARAMS_PER_LAYER == 2 * NUM_QUBITS.
              len(params) is expected to be PARAMS_PER_LAYER * layers.
    - "vert": all qubits in a layer share angles; layer l uses params[2*l] (RY) and params[2*l + 1] (RZ).
    - "horz": every layer reuses the same angles; qubit q uses params[q] (RY) and
              params[q + NUM_QUBITS] (RZ).

  params: flat sequence of rotation angles (see above).
  layers: number of rotation + CNOT-chain layers.
  initial_params: RY angles that prepare the first num_init_qubits qubits (default: all zero).
  num_init_qubits: number of qubits initialized with RY before the ansatz; the rest start in |0>.
  fixed_corr: None, "vert" or "horz". Any other value terminates the program via exit().

  Returns the (unmeasured) QuantumCircuit.
  """
  if fixed_corr not in (None, "vert", "horz"):
    exit("Invalid corr choice")
  if initial_params is None:
    initial_params = [0] * num_init_qubits

  # Uncorrelated layout: per-qubit stride is 2 (RY, RZ) or 1 (RY only).
  stride = PARAMS_PER_LAYER // NUM_QUBITS
  # We only want Rz gates when synthesizing logical circuits. Otherwise, use Ry only to keep amplitudes real.
  has_rz = PARAMS_PER_LAYER == 2 * NUM_QUBITS

  circuit = QuantumCircuit(NUM_QUBITS)

  # State initialization, using only Ry.
  for qubit_num in range(num_init_qubits):
    circuit.ry(initial_params[qubit_num], qubit_num)

  for layer in range(layers):
    for qubit_num in range(NUM_QUBITS):
      if fixed_corr is None:
        ry_idx = layer * PARAMS_PER_LAYER + qubit_num * stride
        rz_idx = ry_idx + 1 if has_rz else None
      elif fixed_corr == "vert":
        ry_idx = layer * 2
        rz_idx = ry_idx + 1
      else:  # "horz"
        # RZ angles follow the NUM_QUBITS RY angles (previously a hard-coded offset of 4,
        # which only matched 4-qubit ansatzes).
        # NOTE(review): RZ is applied even when has_rz is False (e.g. Ampl_Embed); confirm intended.
        ry_idx = qubit_num
        rz_idx = qubit_num + NUM_QUBITS

      circuit.ry(params[ry_idx], qubit_num)
      if rz_idx is not None:
        circuit.rz(params[rz_idx], qubit_num)

    # CX on each consecutive qubit pair (no wrap-around CX from the last qubit to qubit 0).
    for qubit_num in range(NUM_QUBITS - 1):
      circuit.cx(qubit_num, qubit_num + 1)
  return circuit

# Class that performs full parameter space optimization, followed by concurrent subspace optimization.
class DistributedGBRTMinimizer:
  """
  Two-stage GBRT optimizer: a full-space search followed by concurrent subspace searches.

  Stage 1 (run_full_optimization) runs gbrt_minimize over the whole space. Stage 2
  (optimize_thread, launched n_threads times by run_optimization on joblib's threading backend)
  gives each thread a slice of the parameter indices, warm-starts it with the stage-1 history, and
  optimizes that slice while the remaining coordinates come from the shared best point.
  shared_best_params / shared_best_value are read and written under self.lock.

  Uses module-level names defined elsewhere in this script: objective_function, theta_dim,
  init_param_space, DRIVE_PATH, NUM_QUBITS, PARAMS_PER_LAYER, NUM_LAYERS.
  """

  def __init__(self, space, n_threads, experiment_num):
    """
    space: list of skopt dimensions for the full parameter space.
    n_threads: number of concurrent subspace optimizations.
    experiment_num: run index, used only in output file names.
    """
    self.space = space

    self.n_threads = n_threads
    self.experiment_num = experiment_num
    self.initial_times = []  # per-iteration wall times of the full-space stage

    self.shared_best_params = [0] * len(space)
    self.shared_best_value = float('inf')
    self.lock = Lock()
    self.outputs = []  # one newline-joined objective log per finished thread
    self.initial_history = None  # (x_iters, func_vals) of the full-space stage; set by run_full_optimization

  def threaded_objective(self, x):
    """
    Evaluate objective_function at the full-space point x (values ordered theta0, theta1, ...).

    Takes the point positionally because skopt's use_named_args wrapper accepts exactly one
    argument and therefore cannot decorate an instance method (the bound self would fill it).
    """
    return objective_function(list(x))

  def update_shared_parameters(self, full_params, partial_objective):
    """
    Replace the shared best point with full_params if partial_objective improves on it.
    """
    with self.lock:
      print('update_shared_parameters: acquired lock')
      if partial_objective < self.shared_best_value:
        print('update_shared_parameters: updated global shared values')
        self.shared_best_value = partial_objective
        self.shared_best_params = full_params[:]

  def _partial_space_indices(self, thread_id, subspace_indices):
    """
    Return the contiguous slice of subspace_indices owned by thread_id.

    The indices are split into self.n_threads equal chunks; the last thread also takes any
    remainder so no parameter is left unoptimized when the count is not divisible by n_threads.
    """
    per_thread = len(subspace_indices) // self.n_threads
    start = thread_id * per_thread
    stop = len(subspace_indices) if thread_id == self.n_threads - 1 else start + per_thread
    return subspace_indices[start:stop]

  def _result_prefix(self, n_calls, split_method):
    """Common path prefix (under DRIVE_PATH) of every file this optimizer writes."""
    return (f'{DRIVE_PATH}/num_qubits_{NUM_QUBITS}_paramsperlayer_{PARAMS_PER_LAYER}'
            f'_numlayers_{NUM_LAYERS}_split_{split_method}_distributed_n_calls_{n_calls}')

  def optimize_thread(self, thread_id, n_calls, random_state, subspace_indices, split_method):
    """
    Optimize this thread's slice of the parameters and save the results under DRIVE_PATH.

    thread_id: index of this worker in [0, self.n_threads).
    n_calls: number of new objective evaluations for gbrt_minimize.
    random_state: seed forwarded to gbrt_minimize.
    subspace_indices: ordering of range(theta_dim); see _partial_space_indices for the split.
    split_method: only used in output file names.

    Each evaluation offers its full point to update_shared_parameters. Writes the objective trace,
    iteration times, final shared best params and the last surrogate model.

    Raises RuntimeError if run_full_optimization has not been run first.
    """
    if self.initial_history is None:
      raise RuntimeError("run_full_optimization must be called before optimize_thread")

    print(f'Started thread {thread_id}')
    partial_space_indices = self._partial_space_indices(thread_id, subspace_indices)
    partial_space = [self.space[idx] for idx in partial_space_indices]
    output = []  # text log of every objective value this thread evaluates
    iteration_times = []  # per-iteration wall times of this thread's search

    def partial_objective(params):
      # Start from the current shared best point and overwrite only this thread's coordinates.
      with self.lock:
        full_params = self.shared_best_params[:]
      for full_idx, value in zip(partial_space_indices, params):
        full_params[full_idx] = value
      obj_value = objective_function(full_params)
      output.append(f'Thread {thread_id}, Objective: {obj_value}')
      print(f'Thread {thread_id}, Calling update_shared_parameters')
      self.update_shared_parameters(full_params, obj_value)
      return obj_value

    def timing_callback(res):
      # iteration_times is a closure list mutated in place, so no global/nonlocal is needed.
      tick = time.time()
      iteration_times.append(tick - timing_callback.start_time)
      timing_callback.start_time = tick

    # Warm start: project every stage-1 point onto this thread's subspace, paired with the
    # objective values observed for the corresponding full points.
    x_iters, func_vals = self.initial_history
    relevant_history = [[x[idx] for idx in partial_space_indices] for x in x_iters]

    timing_callback.start_time = time.time()
    start_time = time.time()
    result = gbrt_minimize(partial_objective, partial_space, n_calls=n_calls, random_state=random_state,
                           x0=relevant_history, y0=func_vals,
                           callback=[timing_callback], acq_func="EI")

    # Reported time includes the shared full-space stage.
    time_elapsed = time.time() - start_time + sum(self.initial_times)

    self.outputs.append("\n".join(output))

    with self.lock:
      optimal_thread_params = self.shared_best_params[:]

    run_prefix = (f'{self._result_prefix(n_calls, split_method)}_{time_elapsed:.2f}_seconds'
                  f'_thread_{thread_id}_run_{self.experiment_num}')
    np.save(f'{run_prefix}.npy', np.array(result.func_vals))
    np.save(f'{run_prefix}_times.npy', np.array(self.initial_times + iteration_times))
    np.save(f'{run_prefix}_params.npy', np.array(optimal_thread_params))

    print(result.models[-1])
    with open(f'{run_prefix}_models.pkl', "wb") as f:
      pkl.dump(result.models[-1], f, protocol=5)

  def run_full_optimization(self, n_calls, random_state):
    """
    Run the initial full-parameter-space optimization.

    Seeds shared_best_params / shared_best_value with the best point found. Stores the evaluation
    history in initial_history (used to warm-start every thread). Stores per-iteration wall times
    in initial_times.
    """
    @use_named_args(init_param_space)
    def full_objective(**params):
      param_values = [params['theta%d' % i] for i in range(theta_dim)]
      return objective_function(param_values)

    def timing_callback(res):
      # Called by skopt after each iteration; the partial result res is unused.
      tick = time.time()
      self.initial_times.append(tick - timing_callback.start_time)
      timing_callback.start_time = tick

    timing_callback.start_time = time.time()
    result = gbrt_minimize(full_objective, self.space, n_calls=n_calls, random_state=random_state,
                           callback=[timing_callback], acq_func="EI")
    self.shared_best_params = result.x
    self.shared_best_value = result.fun
    self.initial_history = (result.x_iters, result.func_vals)

  def run_optimization(self, n_calls=5000, n_initial_calls=500, random_state=0, n_jobs=-1, split_method='random'):
    """
    Run both stages, redirecting stdout/stderr to per-run log files under DRIVE_PATH.

    n_calls: evaluations per thread in the subspace stage.
    n_initial_calls: evaluations in the full-space stage.
    random_state: seed for both stages.
    n_jobs: joblib worker count; -1 means os.cpu_count().
    split_method: 'random' shuffles the parameter indices before they are divided among threads;
      any other value keeps them in order.
    """
    prefix = self._result_prefix(n_calls, split_method)
    out_path = f'{prefix}_{self.experiment_num}.txt'
    err_path = f'{prefix}_{self.experiment_num}_err.txt'
    with open(out_path, 'a') as f_out, open(err_path, 'a') as f_err:
      with contextlib.redirect_stdout(f_out), contextlib.redirect_stderr(f_err):
        if n_jobs == -1:
          n_jobs = os.cpu_count()

        self.run_full_optimization(n_initial_calls, random_state)

        subspace_indices = list(range(theta_dim))
        if split_method == 'random':
          random.shuffle(subspace_indices)

        # Threads (not processes) so all workers share self.shared_best_params and self.lock.
        Parallel(n_jobs=n_jobs, backend='threading')(
            delayed(self.optimize_thread)(i, n_calls, random_state, subspace_indices, split_method)
            for i in range(self.n_threads)
        )

        print('Best parameters found:', list(self.shared_best_params))
        print('Best objective value:', self.shared_best_value)


# Iterate through the selected targets (STARTING_INIT..ENDING_INIT inclusive). For each one: build the
# target distribution (or unitary), define the cost function, and run num_runs_per_init optimizations.
import warnings
warnings.filterwarnings(action='once')

for exp_num in range(STARTING_INIT, ENDING_INIT + 1):

    exp_init = experiment_inits[exp_num, :]

    if isCircuit:
        # Initial RY angles, or (for "random") the ansatz parameters of the target circuit.
        initial_angles = exp_init
    else:
        # Ampl_Embed: normalize the row to probabilities; the target uses their non-negative square roots.
        target_probs = exp_init / np.sum(exp_init)
        target_state = np.sqrt(target_probs)
        initial_angles = None
        print(target_state)

    print("Experiment Num: ", exp_num)
    print("Initialization/Target State: ", exp_init)

    # Each target gets its own result folder; refuse to overwrite an existing one.
    DRIVE_PATH = RESULTS_FOLDER + f"/{NUM_LAYERS}layer_{total_calls}calls_{split_method}_{surrogate_choice}_{experiment_type}_{cost_function}_state" + str(exp_num)
    print(DRIVE_PATH)

    if isdir(DRIVE_PATH):
        exit("This experiment already exists")
    mkdir(DRIVE_PATH)

    now = datetime.datetime.now()
    stamp = now.isoformat().replace("-", "_").replace(":", "_").replace(".", "_")
    print(f'stamp: {stamp}')

    with open(DRIVE_PATH + '/info.txt', "w", encoding="utf-8") as f:
        f.write(experiment_type + "\n")
        f.write(("Initial angles: " if isCircuit else "Target State: ") + str(exp_init) + "\n")
        f.write(f"Num shots per cost func eval: {num_shots}\n")
        f.write(f"Is distributed : {isDistributed}\n")
        f.write(f"Split Method: {split_method}\n")
        f.write(f"Cost func: {cost_function}\n")
        # Newline-terminated so the timestamp below lands on its own line.
        f.write(f"Forced Param corr: {intentional_corr}\n")
        f.write(stamp + "\n")

    if isCircuit:
        if experiment_type == "random":
            # The target is itself an instance of the ansatz, built with the default (uncorrelated) layout.
            qasm_circuit = generate_trial_circ(initial_angles, layers=NUM_LAYERS, initial_params=[0], num_init_qubits=NUM_INIT_QUBITS)
        else:
            qasm_circuit = QuantumCircuit.from_qasm_file(qasm_file)

        true_circuit = QuantumCircuit(NUM_QUBITS, NUM_QUBITS)
        for qubit_num in range(NUM_INIT_QUBITS):
            true_circuit.ry(initial_angles[qubit_num], qubit_num)
        for gate in qasm_circuit:
            true_circuit.append(gate)

        print(true_circuit)

        if cost_function == "UMD":
            # Target unitary, taken before any measurement is added.
            true_matrix = np.array(Operator(true_circuit).data)

        if experiment_type == "random":
            # Measure into the classical bits already allocated on true_circuit.
            true_circuit.measure_all(add_bits=False)

        job = backend.run(true_circuit, shots=num_shots)
        results = job.result().get_counts()
        counts = {int(k, 2): v for k, v in results.items()}
        print(counts)

        # Shot histogram over all N_OUTPUT basis states (unobserved outcomes count as 0).
        quantiles = [counts.get(k, 0) for k in range(N_OUTPUT)]
        print(f'quantiles: {quantiles}')
        obs_y = np.array(quantiles) / sum(quantiles)
        print(f'obs_y: {obs_y}')
        observed_quantiles = obs_y
    else:
        observed_quantiles = np.abs(target_state) ** 2

    # Define the black-box model and the objective minimized by the optimizer. Both are module-level
    # names, which DistributedGBRTMinimizer looks up at call time.
    if cost_function == "TVD":

        def full_model(theta):
            """Sample the ansatz at theta and return its empirical distribution over the N_OUTPUT basis states."""
            circuit = generate_trial_circ(theta, layers=NUM_LAYERS, initial_params=initial_angles, num_init_qubits=NUM_INIT_QUBITS, fixed_corr=intentional_corr)
            circuit.measure_all()

            job = backend.run(circuit, shots=num_shots)
            results = job.result().get_counts()
            counts = {int(k, 2): v for k, v in results.items()}

            quantiles = [counts.get(k, 0) for k in range(N_OUTPUT)]
            # The epsilon guards the division if no counts were returned.
            return np.array(quantiles) / (sum(quantiles) + 1e-12)

        def objective_function(params):
            """Total variation distance between the target and ansatz output distributions."""
            generated_quantiles = full_model(params)
            return np.sum(np.abs(observed_quantiles - generated_quantiles)) / 2

    elif cost_function == "UMD":

        def full_model(theta):
            """Unitary matrix of the (unmeasured) ansatz at theta."""
            # Uses the same forced correlation as the TVD model, matching what info.txt records.
            circuit = generate_trial_circ(theta, layers=NUM_LAYERS, initial_params=initial_angles, num_init_qubits=NUM_INIT_QUBITS, fixed_corr=intentional_corr)
            return np.array(Operator(circuit).data)

        def objective_function(params):
            """Spectral norm of I - U(params)^dagger @ U_target; zero when the two unitaries are equal."""
            generated_matrix = full_model(params)
            matrix_diff = np.eye(N_OUTPUT) - generated_matrix.T.conj() @ true_matrix
            return np.linalg.norm(matrix_diff, ord=2)

    else:
        exit("Invalid cost function")

    # Starting points: the first num_runs_per_init rows, truncated to theta_dim parameters.
    init_params = np.load("num_qubits_4_4layer_init_params.npy")
    init_params = init_params[:num_runs_per_init, :theta_dim].tolist()

    @use_named_args(init_param_space)
    def objective(**params):
        param_values = [params['theta%d' % i] for i in range(theta_dim)]
        return objective_function(param_values)

    def timing_callback(res):
        """skopt callback: append the wall time of the iteration that just finished to iteration_times."""
        tick = time.time()
        iteration_times.append(tick - timing_callback.start_time)
        timing_callback.start_time = tick

    if isDistributed:
        for i in range(num_runs_per_init):
            print(f'Experiment {i}')
            optimizer = DistributedGBRTMinimizer(init_param_space, n_threads, i)
            optimizer.run_optimization(n_calls=total_calls, n_initial_calls=initial_calls, random_state=0, n_jobs=n_threads, split_method=split_method)
    else:
        for idx, x0 in enumerate(init_params):
            print(f'Experiment {idx}')
            act_start_time = time.time()
            # Reset per run so the first iteration is timed from this run's start, not from the
            # previous run's last callback.
            timing_callback.start_time = act_start_time
            iteration_times = []

            if surrogate_choice == "gbrt":
                result = gbrt_minimize(objective, init_param_space, x0=x0, n_calls=total_calls, random_state=0, acq_func="EI", callback=[timing_callback])
            elif surrogate_choice == "gp":
                result = gp_minimize(objective, init_param_space, x0=x0, n_calls=total_calls, random_state=0, acq_func="EI", callback=[timing_callback])
            elif surrogate_choice == "gbqr":
                exit("gbqr not available")
            elif surrogate_choice == "qrf":
                exit("qrf not available")
            else:
                exit("incorrect surrogate choice")

            # Wall time of this run only (not cumulative across runs).
            time_elapsed = time.time() - act_start_time

            # Saved for every available surrogate; surrogate_choice is part of the file name.
            run_prefix = f'{DRIVE_PATH}/num_qubits_{NUM_QUBITS}_paramsperlayer_{PARAMS_PER_LAYER}_numlayers_{NUM_LAYERS}_{surrogate_choice}_fullspace_n_calls_{total_calls}_{time_elapsed:.2f}_seconds_run_{idx}'
            np.save(f'{run_prefix}.npy', result.func_vals)
            np.save(f'{run_prefix}_params.npy', result.x)
            np.save(f'{run_prefix}_times.npy', np.array(iteration_times))
            np.save(f'{run_prefix}_points.npy', result.x_iters)

            with open(f'{run_prefix}_models.pkl', "wb") as f:
                pkl.dump(result.models[-1], f, protocol=5)

