# Script entry point. Everything lives under the __main__ guard because the
# sweep below starts worker processes with the "spawn" start method, and
# spawned children re-import this module; without the guard they would
# re-run the whole sweep.
if __name__ == "__main__":
  print("Running qfl_main_test.py")

  from utils_qfl_doc import *
  # NOTE, layers, 4:50 PM 8/6: continue here.

  # import torch.multiprocessing as mp
  import multiprocessing as mp

  import os, random

  from pennylane import numpy as np

  import torch

  def set_seed(seed: int, deterministic: bool = False):
    """Seed every RNG used by this script so a run can be repeated.

    Args:
      seed: Value applied to Python's `random`, the global NumPy RNG (through
        the `np` alias imported in this file) and PyTorch.
      deterministic: When True, additionally request deterministic PyTorch
        kernels and turn off cuDNN autotuning. This can slow training down
        and makes PyTorch raise on ops that have no deterministic variant.
    """
    # Hash randomization is fixed at interpreter start-up, so this only
    # affects processes launched after this call (e.g. spawned workers).
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)  # seeds NumPy's *global* RNG

    # Seeds the PyTorch RNG on all devices (CPU and any accelerators).
    torch.manual_seed(seed)

    if deterministic:
      # cuBLAS needs a fixed workspace config for deterministic results on
      # CUDA >= 10.2; without it, deterministic mode raises at the first
      # affected CUDA op. setdefault keeps any value the user already chose.
      os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
      torch.use_deterministic_algorithms(True)
      torch.backends.cudnn.benchmark = False
      torch.backends.cudnn.deterministic = True

  
  # Define the client configuration of interest.
  # DONE: TOMODIFY, layers: client config to specify what layers optimization I'd like to do.
  # DONE: TOMODIFY, depthFL: change comm rounds for smaller clients to just be 0.
  # DONE: TOMODIFY, depthFL: commented out smaller clients; not necessary atm.

  # Datasets to sweep over. Allowed entries: "mnist" and/or "Fashion-MNIST".
  dataset_types = ["mnist", "Fashion-MNIST"]

  # Ansatz families to sweep over (V-shape is the one used in analysis in the
  # paper). Each family becomes a dict keyed by client type 10..14, where
  # client type 10 + k lists the family name k + 1 times (one per block).
  ansatz_types_list = [
    {10 + extra_blocks: [family] * (extra_blocks + 1) for extra_blocks in range(5)}
    for family in ("reversed_staircase", "v_shape", "revstair_vshape")
  ]

  # Uncomment to sweep the families in the opposite order:
  # ansatz_types_list.reverse()

  # Binary classification tasks, as pairs of class labels in string form.
  # Must line up with dataset_types: class_types_list[i] holds the class pairs
  # that are run for dataset_types[i].
  class_types_list = [
    [["4", "9"], ["3", "4"], ["0", "1"]],
    [["2", "4"], ["5", "8"], ["1", "9"]],
  ]

  # Every configuration is repeated once per seed, for reproducibility.
  random_seeds = [12, 30, 50, 70, 400]


  # The start method is process-global state and the multiprocessing docs say
  # set_start_method() should not be called more than once per program, so it
  # is configured here, before the sweep, rather than once per run.
  mp.set_start_method("spawn", force=True)
  mp_ctx = mp.get_context("spawn")
  max_workers = os.cpu_count()

  # Sweep order: dataset -> ansatz family -> class pair -> seed. Every
  # combination is one call to run_qfl_experiments_parallel_multiprocess with
  # its own log folder.
  for dataset_idx, dataset_type_exper in enumerate(dataset_types):
    for qubits_and_layer_types_block_params in ansatz_types_list:
      dataset_classes = class_types_list[dataset_idx]
      print(f"dataset_classes: {dataset_classes}")
      for classes_exper in dataset_classes:
        for random_state in random_seeds:

          set_seed(random_state)

          # This is the overall configuration for the types and numbers of each client.
          # The key is the client type, which must be at least the number of qubits (in our
          # experiments, the smallest is 10). Client type 10 should have the smallest number
          # of layers. For clients with more layers, they should have monotonically increasing client types.
          # "percentage_data" specifies how much of the data should be split for clients of this type. Note that
          # this is ignored if initial data configurations are already provided below.
          # "num_clients" is the number of clients that exist for this type.
          # "local_epochs" is the number of local epochs that clients of this type run.
          # "communication_rounds" can be set to 0, and set to the desired number for the largest size client.
          # Rebuilt for every run so each call receives a fresh dict.
          client_config_exper_parallel = {
            10: {"percentage_data": 0.20, "num_clients": 1, "local_epochs": 1, "communication_rounds": 0},
            11: {"percentage_data": 0.20, "num_clients": 1, "local_epochs": 1, "communication_rounds": 0},
            12: {"percentage_data": 0.20, "num_clients": 1, "local_epochs": 1, "communication_rounds": 0},
            13: {"percentage_data": 0.20, "num_clients": 1, "local_epochs": 1, "communication_rounds": 0},
            14: {"percentage_data": 0.20, "num_clients": 1, "local_epochs": 1, "communication_rounds": 1000},
          }
          cli_type_counts_str = count_numcli_clitype(client_config_exper_parallel)
          print(f"cli_type_counts_str: {cli_type_counts_str}")

          # qubits_layers_list specifies, as a list of tuples, the (client_type, num_layers) where num_layers
          # is the number of layers all clients of client_type have.
          qubits_layers_list = [(10, 2), (11, 4), (12, 6), (13, 8), (14, 10)]
          qubits_layers_str = count_qubits_layers(qubits_layers_list)
          print(f"qubits_layers_str: {qubits_layers_str}")

          # Blocks handed to the runner for each client type: client type
          # 10 + k gets k + 1 blocks, each written as (10, 2).
          qubits_and_layers_to_add_block_params = {
            10: [(10, 2)],
            11: [(10, 2), (10, 2)],
            12: [(10, 2), (10, 2), (10, 2)],
            13: [(10, 2), (10, 2), (10, 2), (10, 2)],
            14: [(10, 2), (10, 2), (10, 2), (10, 2), (10, 2)],
          }

          local_epochs = max(cfg["local_epochs"] for cfg in client_config_exper_parallel.values())

          # Specifies where to store the logs (also the root for the input pickles).
          IMG_PATH = "."
          print(f"dataset_type_exper: {dataset_type_exper}")

          # Specifies the total amount of data to use.
          n_samples_exper = 3640
          # Specifies the total number of clients.
          n_clients_tot = 5
          # Specifies the amount of data that should be allocated for each client.
          datapoints_per_cli = 128
          n_train_samples_exper = n_clients_tot * datapoints_per_cli
          print(f"n_train_samples_exper: {n_train_samples_exper}")
          # Whatever is not used for training is used for testing.
          n_test_samples_exper = n_samples_exper - n_train_samples_exper
          test_frac = n_test_samples_exper / n_samples_exper

          num_total_rounds_glob = max(cfg["communication_rounds"] for cfg in client_config_exper_parallel.values())
          testacc_rd_cutoff = 101
          print(f"testacc_rd_cutoff: {testacc_rd_cutoff}")

          # Specifies the local batch size for each client.
          local_batch_size = 32
          # Specifies the optimizer type used by the clients. Should just be "adam".
          optim_type = "adam"

          lr_gen = 0.004
          # Specifies the learning rate for the client's classifier, in Adam.
          lr_disc = 0.001
          lr_disc_decay = 1.0
          # Specifies whether or not Adam optimizer state persists across rounds.
          cont_optim_state = False

          compute_fid = False

          # Classification with variational layers, no QCNN.
          is_qcnn = False
          generative = False

          amp_embed = False
          # Specifies whether or not clients should do shared PCA over their data.
          shared_pca = True
          # Specifies whether or not clients perform local PCA.
          local_pca = True

          # Data-skew settings. They are passed to the runner and also appear in
          # the names of the input pickles, so they are defined once here.
          feature_skew = 0.0
          label_skew = None

          # Should either be "", "posneg", or "random"; "" for depthFL.
          alt_zeros_init = ""

          # Note, depthFL: should either be "", "multirun", "ancilla_endmeas", "cheating", or "tunnel_down"
          # "" is for measuring one qubit at the end. The remaining ones are for computing layerwise loss.
          # "multirun" corresponds to "Layerwise" in the paper, "ancilla_endmeas" corresponds to "Ancilla", and
          # "tunnel_down" corresponds to "Funnel".
          multiclassifier_type = "multirun"

          train_models_parallel = True
          heirarchical_train = False

          # morepers is "aggshared" or "mocked_bcast" for my experiments
          # "aggshared" means parameters should be aggregated. "mocked_bcast" means parameters are not
          # aggregated (which is what is used for standalone training)
          morepers = "aggshared"

          # Every client type in one entry uses the same ansatz name, so the
          # first name of the first client type identifies the family.
          ansatz_type = next(iter(qubits_and_layer_types_block_params.values()))[0]
          print(f"ansatz_type: {ansatz_type}")

          classes_tag = "_".join(classes_exper)

          log_data_folder = (
            f"{IMG_PATH}/data_logs_{dataset_type_exper}_classes_{classes_tag}"
            f"_n_samples_{n_samples_exper}_n_train_{n_train_samples_exper}"
            f"_qfl_gen_{num_total_rounds_glob}rounds_le_{local_epochs}"
            f"_bs_{local_batch_size}_opt_{optim_type}_mpd_mlt_lg_{lr_gen}_ld_{lr_disc}"
            f"_dqdm_ba_sp_qcnn_{is_qcnn}_ba_sm_ae_{amp_embed}_dr_nce"
            f"_mc_{multiclassifier_type}_mp_{train_models_parallel}_tclip_{morepers}"
            f"_{random_state}_{ansatz_type}_ldd_{lr_disc_decay}_2_10l_nos"
          )

          # Both input pickles share this file-name tail.
          pkl_suffix = (
            f"n_samples_{n_samples_exper}_dataset_type_{dataset_type_exper}"
            f"_classes_{classes_tag}_train_models_parallel_{train_models_parallel}"
            f"_feature_skew_{feature_skew}_label_skew_{label_skew}"
            f"_local_pca_{local_pca}_shared_pca_{shared_pca}"
            f"_gen_{generative}_qcnn_{is_qcnn}_{random_state}.pkl"
          )
          data_logs_path = f"{IMG_PATH}/quorus_data/data_logs_{pkl_suffix}"
          init_params_path = (
            f"{IMG_PATH}/initial_configs_{n_clients_tot}cli_{datapoints_per_cli}tr_"
            f"{n_test_samples_exper}test_{qubits_layers_str}/client_params_dict_{pkl_suffix}"
          )

          # Temporary override: CPU for everything except FID calculation.
          device = "cpu"
          print(f"device: {device}")

          # exist_ok avoids the check-then-create race and makes re-runs safe.
          os.makedirs(log_data_folder, exist_ok=True)

          print(classes_tag)
          print(f"log_data_folder: {log_data_folder}")

          # From here on, everything this process prints goes to per-run log
          # files inside log_data_folder ("w" mode: a re-run overwrites them).
          with open(f"{log_data_folder}/main_stdout.txt", "w") as fout, open(f"{log_data_folder}/main_stderr.txt", "w") as ferr:
            with contextlib.redirect_stdout(fout), contextlib.redirect_stderr(ferr):

              print(f"qubits_and_layer_types_block_params: {qubits_and_layer_types_block_params}")
              print(f"classes_exper: {classes_exper}")
              print(f"random_state: {random_state}")

              print(f"max_workers: {max_workers}")

              gen_betas = (0.5, 0.9)

              # These are the betas used in the Adam optimizer for the classifier.
              disc_betas = (0.9, 0.99)

              # SECURITY: pickle.load executes arbitrary code embedded in the
              # file. Only point these paths at pickles produced by this project.
              with open(data_logs_path, "rb") as file:
                data_logs_prev = pickle.load(file)
              print(f"data_logs_prev.keys(): {data_logs_prev.keys()}")
              print(f"data_logs_prev['clients_data_dict'].keys(): {data_logs_prev['clients_data_dict'].keys()}")
              print(f"data_logs_prev: {data_logs_prev}")

              with open(init_params_path, "rb") as file:
                initial_supp_params = pickle.load(file)
              print(f"initial_supp_params.keys(): {initial_supp_params.keys()}")
              print(f"initial_supp_params: {initial_supp_params}")

              resc_invpca = False

              # TOMODIFY, layers: maybe define these above so I can specify them in the log folder name?
              pennylane_interface = "torch"

              # depthFL: opt_layers is None.
              opt_layers = None

              print(f"random_state: {random_state}")

              # Run the QFL workflow.
              # NOTE(review): the original author noted they did not want to
              # mask grads, but mask_grads=True is passed here -- confirm intent.
              data_logs = run_qfl_experiments_parallel_multiprocess(
                client_config_exper_parallel,
                classes=classes_exper,
                n_samples=n_samples_exper,
                dataset_type=dataset_type_exper,
                agg_strategy="fedavg_circ",
                test_frac=test_frac,
                val_frac=0.0,
                random_state=random_state,
                pool_in=True,
                local_batch_size=local_batch_size,
                local_lr=0.01,
                shots=1024,
                debug=True,
                save_pkl=True,
                mask_grads=True,
                init_client_data_dict=data_logs_prev,
                qubits_and_layers_to_add_block_params=qubits_and_layers_to_add_block_params,
                train_models_parallel=train_models_parallel,
                same_init=True,
                feature_skew=feature_skew,
                label_skew=label_skew,
                local_pca=local_pca,
                do_lda=False,
                feat_sel_type="top",
                amp_embed=amp_embed,
                feat_ordering="same",
                morepers=morepers,
                custom_debug=True,
                shared_pca=shared_pca,
                heirarchical_train=heirarchical_train,
                generative=generative,
                use_torch=True,
                fed_pca_mocked=True,
                lr_gen=lr_gen,
                lr_disc=lr_disc,
                noise_func=generate_latent_noise,
                criterion_func=nn.BCELoss,
                targ_data_folder_prefix="testing_gen_imgs",
                gen_data_folder_prefix="qgan_gen_imgs",
                device=device,
                fid_batch_size=None,
                max_workers=max_workers,
                mp_ctx=mp_ctx,
                log_data_folder=log_data_folder,
                initial_supp_params=initial_supp_params,
                optim_type=optim_type,
                gen_betas=gen_betas,
                disc_betas=disc_betas,
                resc_invpca=resc_invpca,
                compute_fid=compute_fid,
                is_qcnn=is_qcnn,
                pennylane_interface=pennylane_interface,
                opt_layers=opt_layers,
                alt_zeros_init=alt_zeros_init,
                multiclassifier_type=multiclassifier_type,
                testacc_rd_cutoff=testacc_rd_cutoff,
                qubits_and_layer_types_block_params=qubits_and_layer_types_block_params,
                lr_disc_decay=lr_disc_decay,
                cont_optim_state=cont_optim_state,
              )
              print("running after the main() function.")

              print(f"data_logs.keys(): {data_logs.keys()}")

              print(f"data_logs['clients_data_dict']: {data_logs['clients_data_dict']}")

              # Persist the run's result next to its logs.
              with open(f"{log_data_folder}/result_datalogs.pkl", "wb") as file:
                pickle.dump(data_logs, file)