
import random
import time
import argparse
from os import mkdir
from os.path import isdir
import datetime
import pickle as pkl

import numpy as np
import matplotlib.pyplot as plt

from qiskit import QuantumCircuit
from qiskit_aer import AerSimulator

import os
from threading import Lock, Event

from skopt import gbrt_minimize
from skopt.space import Real
from skopt.utils import use_named_args

from joblib import Parallel, delayed
import time
import random

import contextlib
import io

# Root directory under which every experiment writes its own sub-folder.
RESULTS_FOLDER = "AmplEmbed_Mod_Fast_results"

# Create the root if it is missing. exist_ok=True avoids the check-then-create
# race of a separate isdir()/mkdir() pair (e.g. two runs started at once).
os.makedirs(RESULTS_FOLDER, exist_ok=True)


# ---- Experiment selection ----
# Must match one of the branches of the experiment-type dispatch below:
# "Toffoli", "Adder", "VQE", "Ampl_Embed" or "Ampl_Embed_Constr".
experiment_type = "Ampl_Embed_Constr"

# Inclusive range of initialization / target-state indices to process.
STARTING_INIT, ENDING_INIT = 0, 9

# ---- Optimization budget ----
num_runs_per_init = 5  # independent trials per qubit initialization / target state
initial_calls = 500    # n_calls of the initial full-parameter-space optimization
total_calls = 4000     # n_calls handed to each subspace optimizer afterwards
num_shots = 10000      # simulator shots per cost-function evaluation

# ---- Behaviour switches ----
# 'random' shuffles the parameter indices before they are sliced across workers;
# any other value (e.g. 'layer') keeps them in order, giving contiguous slices.
split_method = 'layer'
verbose = False

# Simulator instance shared by the script; circuits are executed through backend.run(...).
backend = AerSimulator()

# Per-experiment configuration.
#
# Every branch defines: isCircuit, NUM_QUBITS, PARAMS_PER_LAYER, NUM_LAYERS,
# NUM_INIT_QUBITS, NUM_INITS, experiment_inits, params_lower_bound and
# params_upper_bound. Circuit experiments (isCircuit = True) also define qasm_file,
# the reference circuit whose simulated output distribution is the target.
# Non-circuit experiments load unnormalized target distributions from disk instead.
#
# PARAMS_PER_LAYER == 2 * NUM_QUBITS means one Ry and one Rz angle per qubit per
# layer; PARAMS_PER_LAYER == NUM_QUBITS means Ry only.
if experiment_type == "Toffoli":
    isCircuit = True
    qasm_file = "modded_toffoli_n3.qasm"

    NUM_QUBITS = 3
    PARAMS_PER_LAYER = 2 * NUM_QUBITS
    NUM_LAYERS = 2
    NUM_INIT_QUBITS = 3
    # Number of different qubit initializations: 3 qubits -> 8 basis states.
    NUM_INITS = 2**NUM_INIT_QUBITS

    # Ry angles (0 or pi) per initialized qubit; the first column varies fastest.
    experiment_inits = np.pi * np.array([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]])

    params_lower_bound = -np.pi
    params_upper_bound = np.pi

elif experiment_type == "Adder":
    isCircuit = True
    qasm_file = "adder_n4.qasm"

    NUM_QUBITS = 4
    PARAMS_PER_LAYER = 2 * NUM_QUBITS
    NUM_LAYERS = 3
    NUM_INIT_QUBITS = 3
    NUM_INITS = 2**NUM_INIT_QUBITS

    # Ry angles (0 or pi) per initialized qubit; the first column varies fastest.
    experiment_inits = np.pi * np.array([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]])

    params_lower_bound = -np.pi
    params_upper_bound = np.pi

elif experiment_type == "VQE":
    isCircuit = True
    qasm_file = "modded_vqe_4.qasm"

    NUM_QUBITS = 4
    PARAMS_PER_LAYER = 2 * NUM_QUBITS
    NUM_LAYERS = 3
    NUM_INIT_QUBITS = 0
    NUM_INITS = 1

    # Single placeholder row: with NUM_INIT_QUBITS = 0 no initialization gate reads it.
    experiment_inits = np.array([[0, 0, 0]])

    params_lower_bound = -np.pi
    params_upper_bound = np.pi

elif experiment_type == "Ampl_Embed":
    isCircuit = False

    NUM_QUBITS = 4
    PARAMS_PER_LAYER = NUM_QUBITS
    NUM_LAYERS = 4
    NUM_INIT_QUBITS = 0
    NUM_INITS = 10

    # One unnormalized target distribution per row; normalized by the driver loop.
    experiment_inits = np.load("random_target_prob_dists_unnormalized.npy")

    params_lower_bound = -np.pi
    params_upper_bound = np.pi

elif experiment_type == "Ampl_Embed_Constr":
    isCircuit = False

    NUM_QUBITS = 4
    PARAMS_PER_LAYER = NUM_QUBITS
    NUM_LAYERS = 4
    NUM_INIT_QUBITS = 4
    NUM_INITS = 10

    # One unnormalized target distribution per row; normalized by the driver loop.
    experiment_inits = np.load("random_target_prob_dists_unnormalized.npy")
    # Every qubit gets an initial Ry(pi/2) before the trainable layers.
    init_angle_vals = np.full(NUM_QUBITS, np.pi / 2)

    # Bounds are the +-pi/2 range divided by the number of layers
    # (presumably to bound the accumulated rotation -- TODO confirm).
    params_lower_bound = (-np.pi / 2) / NUM_LAYERS
    params_upper_bound = (np.pi / 2) / NUM_LAYERS

else:
    # SystemExit instead of exit(): exit() is injected by the `site` module for
    # interactive use and is not guaranteed to exist in every interpreter setup.
    raise SystemExit("Invalid Experiment Type")

# Quantities derived from the experiment configuration above.
theta_dim = NUM_LAYERS * PARAMS_PER_LAYER  # total number of trainable circuit parameters
N_OUTPUT = 1 << NUM_QUBITS                 # number of computational basis states (2 ** NUM_QUBITS)
n_threads = NUM_LAYERS                     # one subspace worker per layer

# Search space: one bounded real dimension per circuit parameter, named theta0, theta1, ...
init_param_space = [
    Real(params_lower_bound, params_upper_bound, name=f'theta{i}')
    for i in range(theta_dim)
]


def generate_trial_circ(params, layers=NUM_LAYERS, initial_params=(0, 0, 0), num_init_qubits=0):
  """
  Build the layered trial circuit on NUM_QUBITS qubits, parameterized by params.

  Each layer applies an Ry gate to every qubit (followed by an Rz gate when
  PARAMS_PER_LAYER == 2 * NUM_QUBITS) and then a chain of CNOTs between consecutive
  qubits. No measurement is added; callers add it themselves.

  params: flat sequence of rotation angles, laid out layer by layer. It is assumed that
    len(params) >= PARAMS_PER_LAYER * layers.
  layers: number of layers to build.
  initial_params: Ry angles for the state-initialization gates; only the first
    num_init_qubits entries are read (so it may be None when num_init_qubits is 0).
  num_init_qubits: number of qubits to be initialized using Ry gates. All other qubits
    are left in the 0 state.

  Returns the constructed QuantumCircuit.
  """
  circuit = QuantumCircuit(NUM_QUBITS)

  # First do state initialization, using only Ry.
  for qubit_num in range(num_init_qubits):
    circuit.ry(initial_params[qubit_num], qubit_num)

  # Loop invariants: how many parameters each qubit consumes per layer, and whether
  # an Rz follows the Ry. Rz gates are only wanted when synthesizing logical circuits;
  # otherwise Ry alone keeps the amplitudes real.
  params_per_qubit = PARAMS_PER_LAYER // NUM_QUBITS
  use_rz = PARAMS_PER_LAYER == 2 * NUM_QUBITS

  for layer in range(layers):
    base_idx = layer * PARAMS_PER_LAYER

    # Single-qubit rotations for this layer.
    for qubit_num in range(NUM_QUBITS):
      param_idx = base_idx + qubit_num * params_per_qubit
      circuit.ry(params[param_idx], qubit_num)
      if use_rz:
        circuit.rz(params[param_idx + 1], qubit_num)

    # Add a CX gate for each consecutive qubit pair. The chain is deliberately left
    # open: closing the entanglement loop would need circuit.cx(NUM_QUBITS - 1, 0).
    for qubit_num in range(NUM_QUBITS - 1):
      circuit.cx(qubit_num, qubit_num + 1)

  return circuit

# Class that performs full parameter space optimization, followed by concurrent subspace optimization.
class DistributedGBRTMinimizer:
  """
  Two-stage minimizer built on skopt's gbrt_minimize.

  Stage 1 (run_full_optimization) optimizes all parameters at once. Stage 2
  (optimize_thread, one call per worker) re-optimizes one slice of the parameters per
  worker, while all workers share the best full parameter vector found so far.

  The methods rely on module-level names that must exist when they are called:
  objective_function, DRIVE_PATH, NUM_QUBITS, PARAMS_PER_LAYER and NUM_LAYERS.
  """

  def __init__(self, space, n_threads, experiment_num):
    """
    space: list of skopt dimensions, one per parameter of the full problem.
    n_threads: number of subspace workers; the parameters are split into this many slices.
    experiment_num: run index, used only in output file names.
    """
    self.space = space

    self.n_threads = n_threads
    self.experiment_num = experiment_num
    # Wall-clock duration of each stage-1 iteration, filled by run_full_optimization.
    self.initial_times = []

    # Best full parameter vector / objective value seen so far. Read and written by
    # the workers under self.lock.
    self.shared_best_params = [0] * len(space)
    self.shared_best_value = float('inf')
    self.lock = Lock()

  def _file_prefix(self, split_method, n_calls):
    """Return the path prefix shared by every output file of this configuration."""
    return (f'{DRIVE_PATH}/num_qubits_{NUM_QUBITS}_paramsperlayer_{PARAMS_PER_LAYER}'
            f'_numlayers_{NUM_LAYERS}_split_{split_method}_distributed_n_calls_{n_calls}')

  @staticmethod
  def _make_timing_callback(durations):
    """
    Return a skopt callback that appends the wall-clock time of each iteration to durations.

    The clock starts when this factory is called, so create the callback immediately
    before starting the optimizer.
    """
    last_tick = time.time()

    def timing_callback(res):
      nonlocal last_tick
      now = time.time()
      durations.append(now - last_tick)
      last_tick = now

    return timing_callback

  def threaded_objective(self, x):
    """
    Evaluate the full objective function at the parameter vector x.

    x: sequence holding one value per dimension of self.space, in order.
    Raises ValueError if len(x) does not match the number of dimensions.

    This used to be wrapped in @use_named_args, whose wrapper accepts exactly one
    positional argument; applied to a method that made the bound method impossible
    to call with a parameter vector.
    """
    if len(x) != len(self.space):
      raise ValueError(f'Expected {len(self.space)} parameters, got {len(x)}')
    return objective_function(list(x), verbose=False)

  def update_shared_parameters(self, full_params, partial_objective, verbose=False):
    """
    Update shared parameters if new objective value is better.

    full_params: full parameter vector that produced the value.
    partial_objective: objective value obtained for full_params.
    """
    # The compare-and-replace must be atomic with respect to the other workers.
    with self.lock:
      if verbose:
        print('update_shared_parameters: acquired lock')
      if partial_objective < self.shared_best_value:
        if verbose:
          print('update_shared_parameters: updated global shared values')
        self.shared_best_value = partial_objective
        # Store a copy so later in-place edits by the caller cannot alter the shared state.
        self.shared_best_params = full_params[:]

  def optimize_thread(self, thread_id, n_calls, random_state, subspace_indices, split_method, verbose=False):
    """
    Optimize this worker's slice of the parameters and update the shared best parameters.

    thread_id: worker index in [0, n_threads); selects the slice of subspace_indices.
    n_calls: n_calls passed to gbrt_minimize for the subspace optimization.
    random_state: seed passed to gbrt_minimize.
    subspace_indices: ordering of all full-space parameter indices; worker k owns
      entries [k * m, (k + 1) * m) with m = len(space) // n_threads.
    split_method: only used in output file names here.

    Requires run_full_optimization to have been called first (reads self.initial_history).
    Writes the objective values, iteration times, the shared best parameters at the time
    this worker finishes, and the last surrogate model to files under DRIVE_PATH.
    """
    if verbose:
      print(f'Started thread {thread_id}')
    params_per_thread = len(self.space) // self.n_threads
    first = thread_id * params_per_thread
    partial_space_indices = subspace_indices[first:first + params_per_thread]
    partial_space = [self.space[idx] for idx in partial_space_indices]
    iteration_times = []

    # Objective over the subspace: all other parameters are taken from the current
    # shared best vector, this worker's parameters from the optimizer's proposal.
    def partial_objective(params, verbose=verbose):
      with self.lock:
        full_params = self.shared_best_params[:]
      for full_idx, value in zip(partial_space_indices, params):
        full_params[full_idx] = value
      obj_value = objective_function(full_params, verbose=verbose)
      if verbose:
        print(f'Thread {thread_id}, Calling update_shared_parameters')
      self.update_shared_parameters(full_params, obj_value, verbose=verbose)
      return obj_value

    # Warm start: the stage-1 evaluation history projected onto this worker's slice.
    # NOTE(review): the y0 values were measured with the other parameters at their
    # stage-1 values, not at the current shared best -- confirm this is acceptable.
    x_history, y_history = self.initial_history
    relevant_history = [[x[idx] for idx in partial_space_indices] for x in x_history]

    # Perform subspace optimization.
    timing_callback = self._make_timing_callback(iteration_times)
    start_time = time.time()
    result = gbrt_minimize(partial_objective, partial_space, n_calls=n_calls, random_state=random_state,
                           x0=relevant_history, y0=y_history,
                           callback=[timing_callback], acq_func="EI")

    # Reported time covers this worker's stage 2 plus the shared stage 1.
    time_elapsed = time.time() - start_time
    time_elapsed += sum(self.initial_times)

    with self.lock:
      optimal_thread_params = self.shared_best_params[:]

    stem = (f'{self._file_prefix(split_method, n_calls)}_{time_elapsed:.2f}_seconds'
            f'_thread_{thread_id}_run_{self.experiment_num}')
    np.save(f'{stem}.npy', np.array(result.func_vals))
    np.save(f'{stem}_times.npy', np.array(self.initial_times + iteration_times))
    np.save(f'{stem}_params.npy', np.array(optimal_thread_params))

    final_model = result.models[-1]
    if verbose:
      print(final_model)
    with open(f'{stem}_models.pkl', "wb") as f:
      pkl.dump(final_model, f, protocol=5)

  def run_full_optimization(self, n_calls, random_state):
    """
    Run the initial full parameter space optimization.

    Stores the best point and value as the shared best, records per-iteration times in
    self.initial_times and keeps the evaluation history in self.initial_history as
    (x_iters, func_vals) for the subspace workers to warm-start from.
    """
    timing_callback = self._make_timing_callback(self.initial_times)
    result = gbrt_minimize(self.threaded_objective, self.space, n_calls=n_calls, random_state=random_state,
                           callback=[timing_callback], acq_func="EI")
    self.shared_best_params = list(result.x)
    self.shared_best_value = result.fun
    self.initial_history = (result.x_iters, result.func_vals)

  def run_optimization(self, n_calls=5000, n_initial_calls=500, random_state=0, n_jobs=-1, split_method='random', verbose=False):
    """
    Run the full two-stage optimization and save the best parameters and value.

    n_calls: n_calls for each subspace optimizer (stage 2).
    n_initial_calls: n_calls for the full-space optimization (stage 1).
    random_state: seed passed to every gbrt_minimize call.
    n_jobs: number of joblib workers; -1 means one per CPU.
    split_method: 'random' shuffles the parameter indices before slicing them across
      workers; any other value keeps them in order.

    While running, stdout and stderr are redirected (appended) to per-run text files
    under DRIVE_PATH.
    """
    run_tag = f'{self._file_prefix(split_method, n_calls)}_{self.experiment_num}'
    with open(f'{run_tag}.txt', 'a') as f_out, open(f'{run_tag}_err.txt', 'a') as f_err:
      with contextlib.redirect_stdout(f_out), contextlib.redirect_stderr(f_err):
        if n_jobs == -1:
          # os.cpu_count() may return None when the count cannot be determined.
          n_jobs = os.cpu_count() or 1

        self.run_full_optimization(n_initial_calls, random_state)

        subspace_indices = list(range(len(self.space)))
        if split_method == 'random':
          # NOTE(review): uses the global `random` state, so the split is not
          # controlled by random_state.
          random.shuffle(subspace_indices)

        # Run the subspace optimizations concurrently, one task per slice.
        Parallel(n_jobs=n_jobs, backend='threading')(
            delayed(self.optimize_thread)(i, n_calls, random_state, subspace_indices, split_method, verbose)
            for i in range(self.n_threads)
        )
        if verbose:
          print('Best parameters found:', list(self.shared_best_params))
          print('Best objective value:', self.shared_best_value)

        np.save(f'{run_tag}_opt_params.npy', np.array(self.shared_best_params))
        np.save(f'{run_tag}_opt_obj_val.npy', np.array([self.shared_best_value]))


# Iterate through all the experiments that have to be done (one per initialization /
# target state) and run each for the required number of trials.

import warnings

# CHECK TO SEE IF WE NEED THE FOLLOWING
warnings.filterwarnings(action='once')


def full_model(theta, verbose=False):
    """
    Black-box model: simulate the trial circuit for parameters theta.

    Reads the module-level initial_angles of the experiment currently being run.
    Returns a length-N_OUTPUT array with the measured relative frequency of every
    basis state, indexed by the integer value of the measured bitstring.
    """
    circuit = generate_trial_circ(theta, layers=NUM_LAYERS, initial_params=initial_angles, num_init_qubits=NUM_INIT_QUBITS)
    circuit.measure_all()

    job = backend.run(circuit, shots=num_shots)
    counts = {int(bits, 2): hits for bits, hits in job.result().get_counts().items()}

    # Outcomes that were never observed get a count of zero.
    tallies = [counts.get(outcome, 0) for outcome in range(N_OUTPUT)]
    if verbose:
        print(f'quantiles: {tallies}')

    # The small epsilon guards against division by zero.
    return np.array(tallies) / (sum(tallies) + 1e-12)


def objective_function(params, verbose=False):
    """
    Return the total variation distance between the target distribution of the current
    experiment (module-level observed_quantiles) and the distribution sampled from the
    trial circuit with parameters params.
    """
    generated_quantiles = full_model(params, verbose)
    return np.sum(np.abs(observed_quantiles - generated_quantiles)) / 2


for exp_num in range(STARTING_INIT, ENDING_INIT + 1):

    exp_init = experiment_inits[exp_num, :]

    if isCircuit:
        # Row holds the Ry initialization angles of the input state.
        initial_angles = exp_init
    else:
        # Row holds an unnormalized target probability distribution.
        target_probs = exp_init / np.sum(exp_init)
        target_state = np.sqrt(target_probs)
        initial_angles = init_angle_vals if experiment_type == "Ampl_Embed_Constr" else None
        if verbose:
            print(target_state)

    if verbose:
        print("Experiment Num: ", exp_num)
        print("Initialization/Target State: ", exp_init)

    # Output folder of this experiment; also read by DistributedGBRTMinimizer.
    DRIVE_PATH = f"{RESULTS_FOLDER}/{NUM_LAYERS}layer_{total_calls}calls_dist_layer_gbrt_{experiment_type}_TVD_state{exp_num}"
    if verbose:
        print(DRIVE_PATH)

    # Refuse to overwrite the results of an earlier run.
    try:
        mkdir(DRIVE_PATH)
    except FileExistsError:
        raise SystemExit("This experiment already exists") from None

    # Timestamp with '-', ':' and '.' replaced by '_'.
    stamp = datetime.datetime.now().isoformat().translate(str.maketrans("-:.", "___"))
    if verbose:
        print(f'stamp: {stamp}')

    with open(DRIVE_PATH + '/info.txt', "w", encoding="utf-8") as f:
        f.write(experiment_type + " using distributed method\n")
        f.write(("Initial angles: " if isCircuit else "Target State: ") + str(exp_init) + "\n")
        f.write(f"Num shots per cost func eval: {num_shots}\n")
        f.write(f"Split Method: {split_method}\n")
        f.write(stamp + "\n")

    # Build the target distribution (observed_quantiles) for this experiment.
    if isCircuit:
        qasm_circuit = QuantumCircuit.from_qasm_file(qasm_file)

        # Reference circuit = Ry state initialization followed by the QASM circuit.
        true_circuit = QuantumCircuit(NUM_QUBITS, NUM_QUBITS)
        for qubit_num in range(NUM_INIT_QUBITS):
            true_circuit.ry(initial_angles[qubit_num], qubit_num)
        for gate in qasm_circuit:
            true_circuit.append(gate)

        if verbose:
            print(true_circuit)

        # NOTE(review): an earlier revision called true_circuit.measure_all() for
        # "vqe_4.qasm", which had no measurements; presumably the QASM files used now
        # contain their own measurements -- confirm.
        job = backend.run(true_circuit, shots=num_shots)
        counts = {int(bits, 2): hits for bits, hits in job.result().get_counts().items()}
        if verbose:
            print(counts)

        tallies = [counts.get(outcome, 0) for outcome in range(N_OUTPUT)]
        if verbose:
            print(f'quantiles: {tallies}')
        obs_y = np.array(tallies) / sum(tallies)
        if verbose:
            print(f'obs_y: {obs_y}')
        observed_quantiles = obs_y
    else:
        observed_quantiles = np.abs(target_state) ** 2

    # Run the independent trials for this experiment.
    # NOTE(review): every trial passes random_state=0, so the optimizer seed is the
    # same across trials -- confirm this is intended.
    for run_idx in range(num_runs_per_init):
        if verbose:
            print(f'Experiment {run_idx}')
        start_time = time.time()
        optimizer = DistributedGBRTMinimizer(init_param_space, n_threads, run_idx)
        optimizer.run_optimization(n_calls=total_calls, n_initial_calls=initial_calls, random_state=0, n_jobs=n_threads, split_method=split_method, verbose=verbose)
        time_elapsed = time.time() - start_time
        if verbose:
            print(f'Experiment {run_idx}, time elapsed: {time_elapsed}')
        np.save(f'{DRIVE_PATH}/num_qubits_{NUM_QUBITS}_paramsperlayer_{PARAMS_PER_LAYER}_numlayers_{NUM_LAYERS}_split_{split_method}_distributed_n_calls_{total_calls}_{run_idx}_time_elapsed.npy', np.array([time_elapsed]))

