# -*- coding: utf-8 -*-
# from google.colab import drive
# drive.mount('/content/drive')

from custom_slog import print_cust

import os
os.cpu_count()

# """# Install Dependencies"""

# import subprocess, sys
# subprocess.check_call([sys.executable, "-m", "pip", "install", "qiskit"])

# subprocess.check_call([sys.executable, "-m", "pip", "install", "qiskit_aer"])

# # !pip install qiskit-aer-gpu

# subprocess.check_call([sys.executable, "-m", "pip", "install", "pennylane"])

# # for FID calculation
# subprocess.check_call([sys.executable, "-m", "pip", "install", "pytorch-fid"])

# # for backprop for torch
# # ONLY for backwards compatibility for code that uses this
# subprocess.check_call([sys.executable, "-m", "pip", "install", "torch_pca"])

# devdep; for computation graph vis
# !pip install torchviz

# !pip show graphviz

# !pip list

# !pip freeze > requirements.txt

# import numpy as np

# Import dependencies
from pennylane import numpy as np
import matplotlib.pyplot as plt
from qiskit import QuantumCircuit
from qiskit_aer import Aer, AerSimulator
from qiskit.circuit import Parameter
from qiskit.circuit.library import RXXGate, RYYGate, RZZGate
from sklearn.decomposition import PCA
# from torch_pca import PCA
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from scipy.stats import ks_2samp

# for integration of quantum and classical params
import torch

import pennylane as qml

import pickle

"""# Data Loading

## Load MNIST Digits Helpers
"""

from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_openml
from urllib.error import HTTPError, URLError

# TOMODIFY, layers: module-level lookup tables that exist ONLY so Fashion-MNIST
# classes can be referred to by name; they are not strictly necessary.
# Position in the tuple is the numeric class id.
_FASHION_ID_TO_NAME = dict(enumerate((
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
)))
# Reverse lookup: lower-cased class name -> numeric class id.
_FASHION_NAME_TO_ID = {
    class_name.lower(): class_id
    for class_id, class_name in _FASHION_ID_TO_NAME.items()
}

def _normalize_and_flatten(X):
    """Return X as float32 divided by 255; 3-D stacks (N, H, W) become (N, H*W)."""
    scaled = X.astype(np.float32) / 255.0
    if scaled.ndim != 3:
        # Already flat (or some other rank): only the rescaling applies.
        return scaled
    n_images = scaled.shape[0]
    return scaled.reshape((n_images, -1))

def _canon_label_list(digits_to_keep, is_fashion):
    """
    Convert user-supplied class identifiers into a sorted list of unique class ids.

    Parameters:
      digits_to_keep: iterable of class identifiers. Each entry is converted with
                      str() and stripped, so ints (4) and numeric strings ("4") are
                      equivalent. When is_fashion is True, Fashion-MNIST class names
                      (case-insensitive, e.g. "sneaker") are accepted as well.
      is_fashion: whether Fashion-MNIST names should be recognised.

    Returns:
      Sorted list of distinct integer class ids, each in 0..9.

    Raises:
      ValueError: if an entry is neither a class id in 0..9 nor (for Fashion-MNIST)
                  a known class name.
    """
    hint = ("Use 0–9 or Fashion-MNIST names like 'sneaker', 'bag'."
            if is_fashion else "Use digits 0–9.")
    ids = set()
    for d in digits_to_keep:
        s = str(d).strip()
        # isdecimal() (unlike isdigit()) only accepts characters int() can parse,
        # so inputs such as "²" fall through to the descriptive error below.
        if s.isdecimal() and 0 <= int(s) <= 9:
            ids.add(int(s))
        elif is_fashion and s.lower() in _FASHION_NAME_TO_ID:
            ids.add(_FASHION_NAME_TO_ID[s.lower()])
        else:
            # Out-of-range ids are rejected too: accepting them would leave holes
            # in any 0..K-1 label remapping built from the returned list.
            raise ValueError(f"Unrecognized class '{d}'. " + hint)
    # Deduplicated; sorting gives a deterministic order.
    return sorted(ids)

"""## Load MNIST Digits Function"""

def load_mnist_digits(digits_to_keep, n_samples=2000, dataset_name="mnist"):
    """
    Load MNIST or Fashion-MNIST, filter to selected classes, and return (X, y).
    Tries OpenML first; if that errors, falls back to TensorFlow's keras.datasets.

    Parameters:
      digits_to_keep: list of class identifiers. For MNIST: digits (e.g., [4, 9]).
                      For Fashion-MNIST: digits 0–9 or names (e.g., ["sneaker", "bag"]).
      n_samples: number of examples to return (stratified). If >= available, returns all.
      dataset_name: "mnist" or "fashion-mnist" (case/underscore/dash tolerant).

    Returns:
      X: float32 array of shape (N, 784), normalized to [0, 1].
      y: int array of shape (N,), labels remapped to 0..K-1 following sorted(digits_to_keep).

    Raises:
      ValueError: if dataset_name is not recognised, or a class identifier is invalid.
      RuntimeError: if both the OpenML download and the TensorFlow fallback fail; the
                    fallback's exception is attached as __cause__ (and the OpenML
                    error remains reachable through the exception context).
    """
    ds = dataset_name.replace("_", "-").lower()
    # NOTE, layers: can later add some additional conditions to allow for CIFAR-10.
    if ds in {"mnist", "mnist-784"}:
        openml_name = "mnist_784"          # OpenML dataset name
        is_fashion = False
        tf_module_name = "tensorflow.keras.datasets.mnist"
    elif ds in {"fashion-mnist", "fashion mnist", "fashion"}:
        openml_name = "Fashion-MNIST"      # OpenML dataset name
        is_fashion = True
        tf_module_name = "tensorflow.keras.datasets.fashion_mnist"
    else:
        raise ValueError("dataset_name must be 'mnist' or 'fashion-mnist'.")

    keep_ids = _canon_label_list(digits_to_keep, is_fashion)

    # 1) Try OpenML first
    try:
        mnist = fetch_openml(openml_name, version=1, as_frame=False)
        X = _normalize_and_flatten(mnist.data)
        # y can be strings; coerce to int if possible, else map names for fashion
        try:
            y = mnist.target.astype(int)
        except ValueError:
            # NOTE: this should be unnecessary; Fashion-MNIST from OpenML has labels
            # that are string representations of digits.
            if not is_fashion:
                # MNIST should be numeric; re-raise if it's not
                raise
            y = np.array([_FASHION_NAME_TO_ID[str(lbl).lower()] for lbl in mnist.target])
    except (HTTPError, URLError, OSError, RuntimeError, ValueError) as openml_err:
        # 2) Fallback to TensorFlow's Keras loader
        try:
            print_cust(f"load_mnist_digits, USING TENSORFLOW TO LOAD IN MNIST DATA (b/c fetch_openml has errors)")
            print_cust(f"load_mnist_digits, error from fetch_openml code, _: {openml_err}")
            # lazy import to avoid TF dependency unless needed
            import importlib
            mod = importlib.import_module(tf_module_name)
            (X_train, y_train), (X_test, y_test) = mod.load_data()
            X = _normalize_and_flatten(np.concatenate([X_train, X_test], axis=0))
            y = np.concatenate([y_train, y_test], axis=0).astype(int)
        except Exception as e:
            # Chain explicitly so the traceback shows why the fallback failed.
            raise RuntimeError(
                f"Both OpenML and TensorFlow loading failed: {type(e).__name__}: {e}"
            ) from e

    # Filter to desired classes
    mask = np.isin(y, keep_ids)
    X, y = X[mask], y[mask]

    # Remap labels to 0..K-1 in sorted order of requested classes
    mapping = {label: idx for idx, label in enumerate(keep_ids)}
    y = np.array([mapping[int(lbl)] for lbl in y], dtype=int)

    # Optional stratified subsample
    if n_samples < len(y):
        # NOTE: random_state is hardcoded, so every call selects the SAME subset of
        # the full data; the clients' individual data is then sampled from that subset.
        X, _, y, _ = train_test_split(
            X, y, train_size=n_samples, stratify=y, random_state=42
        )

    return X, y

"""## Load CIFAR-10 Data"""

# from tensorflow.keras.datasets import cifar10

def load_cifar10(classes, n_samples=2000):
    """
    Load the CIFAR-10 dataset, filter by the specified classes,
    and return the filtered data and labels.

    Parameters:
      classes (list): Classes to keep, given as numeric ids (3 or "3") and/or
                      names like "cat", "dog" (case-insensitive). Duplicates are ignored.
      n_samples (int): Number of samples to return (stratified). If >= available,
                       all matching samples are returned.

    Returns:
      X (np.array): Normalized ([0, 1], float32) and flattened image data.
      y (np.array): Integer labels remapped to 0..K-1 according to sorted(selected classes).

    Raises:
      KeyError: if a class name is not a CIFAR-10 class.
      ValueError: if a numeric class id is outside 0..9.
      ImportError: if TensorFlow is not installed.
    """
    # CIFAR-10 standard class names mapping
    cifar10_classes = {
        0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer",
        5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"
    }
    name_to_id = {name.lower(): idx for idx, name in cifar10_classes.items()}

    # Resolve the requested classes before the (expensive) dataset load so that
    # bad arguments fail fast. Each entry may be a numeric id or a class name.
    try:
        selected = set()
        for c in classes:
            token = str(c).strip()
            if token.isdecimal():
                class_id = int(token)
                if class_id not in cifar10_classes:
                    raise ValueError(f"CIFAR-10 class id {class_id} is out of range 0-9.")
            else:
                class_id = name_to_id[token.lower()]
            selected.add(class_id)
    except Exception as e:
        print_cust("Error processing classes for CIFAR-10: ", e)
        raise
    # Deduplicated and sorted, so the remapped labels are contiguous 0..K-1.
    selected_indices = sorted(selected)

    # Lazy import: the module-level cifar10 import is commented out in this file,
    # and TensorFlow should only be required when CIFAR-10 is actually requested.
    from tensorflow.keras.datasets import cifar10

    # Load the data from Keras (train and test combined)
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    X = np.concatenate([X_train, X_test], axis=0)
    y = np.concatenate([y_train, y_test], axis=0).flatten()

    mask = np.isin(y, selected_indices)
    X, y = X[mask], y[mask]

    # Create a mapping similar to load_mnist_digits, ensuring ordering is consistent.
    mapping = {val: idx for idx, val in enumerate(selected_indices)}
    # Convert each label to a plain int before the dictionary lookup.
    y = np.array([mapping[int(val)] for val in y])

    # Normalize to [0, 1] and flatten the images.
    X = X.astype("float32") / 255.0
    X = X.reshape(X.shape[0], -1)

    # If needed, subsample the dataset via stratified sampling.
    if n_samples < len(y):
        X, _, y, _ = train_test_split(X, y, train_size=n_samples, stratify=y, random_state=42)
    return X, y

"""# Conv layers computation helper funcs

## Conv layers computation function
"""

def compute_conv_layers(n_qubits, n_output_qubits, generative=False, is_qcnn=True):
    """
    Given the number of input qubits and the number of output qubits for a QCNN, computes the
    number of convolutional layers in the QCNN. Assumes that, after each convolutional layer, the number
    of qubits in the circuit is halved (rounding up, so an odd qubit is kept).

    Parameters:
      n_qubits (int): The number of input qubits to the QCNN.
      n_output_qubits (int): The number of output qubits in the QCNN (these are likely measured to obtain your final result).
      generative (bool): If True, no convolutional layers are used and 0 is returned.
      is_qcnn (bool): If False, no convolutional layers are used and 0 is returned.

    Returns:
      conv_layers (int): The number of halvings that can be applied while still
        leaving at least n_output_qubits qubits.

    Raises:
      ValueError: if n_output_qubits < 1 in the QCNN case (the halving can never go
        below one qubit, so the count would be unbounded).
    """
    if generative or not is_qcnn:
        print_cust(f"compute_conv_layers, returning 0")
        return 0
    if n_output_qubits < 1:
        # Without this guard the loop below never terminates: ceil-halving settles
        # at 1 qubit, which always satisfies `temp >= n_output_qubits`.
        raise ValueError(
            f"n_output_qubits must be at least 1, got {n_output_qubits}."
        )
    conv_layers = 0
    # Start with all your qubits.
    temp = n_qubits
    # While you haven't dropped below your output size, halve the number of qubits.
    while temp >= n_output_qubits:
        print_cust(f"compute_conv_layers, temp: {temp}, n_output_qubits: {n_output_qubits}")
        if temp == 1:
            # One qubit cannot be halved further (only reachable when n_output_qubits == 1).
            break
        # Ceil-halve: when pooling with an odd number, keep the extra qubit.
        temp = (temp + 1) // 2
        if temp >= n_output_qubits:
            conv_layers += 1
    return conv_layers

# TOADD: an argument for the function above, for generative QFL, to just return 0.
# DONE: TOMODIFY, layers: an additional argument that specifies if we are operating in the non-conv layers case, in which case, we return 0.

"""## Num qubits computation function"""

def compute_reduced_qubits(n_qubits, conv_layers):
  """
  Return how many qubits remain after `conv_layers` successive halvings of
  `n_qubits`, where an odd count keeps its extra qubit (rounds up).
  """
  remaining = n_qubits
  for _ in range(conv_layers):
    half, leftover = divmod(remaining, 2)
    # A non-zero remainder means the count was odd: keep the unpaired qubit.
    remaining = half + 1 if leftover != 0 else half
  return remaining

"""# Angle encoding circuit"""

# Marked "not used" by the original author; kept in case code outside this
# section still references it.
NUM_BASE_QUBITS = 4

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

"""## Random Sketching Class"""

# This class adheres to the "fit_transform(X)" and "transform(X)" interface, for compatibility with Numpy type objects.
# This class implements random sketching, given an input sketch matrix.
# TODO: generate the sketch matrix in the class, for better separation of concerns.
class SketchTransformer:
    """
    Random-sketching projector exposing `fit_transform` / `transform`.

    The sketch matrix is supplied by the caller and never modified; there is
    nothing to fit, so both public methods perform the same projection.
    """

    def __init__(self, sketch_mat: np.ndarray):
        """
        Initialize with a fixed sketch matrix.

        Parameters
        ----------
        sketch_mat : ndarray of shape (d, k)
            The random sketching matrix to be used for all transforms.
        """
        self.sketch_mat = sketch_mat

    def _project(self, X: np.ndarray) -> np.ndarray:
        """Validate the feature dimension of X and return X @ sketch_mat."""
        _, in_dim = X.shape
        expected_dim, _ = self.sketch_mat.shape
        if in_dim == expected_dim:
            return X @ self.sketch_mat
        raise ValueError(f"Input data has dimension {in_dim}, but sketch_mat expects {expected_dim}.")

    def fit_transform(self, X: np.ndarray) -> np.ndarray:
        """
        Sketch X with the stored matrix (no fitting takes place).

        Parameters
        ----------
        X : ndarray of shape (n_samples, d)
            The data to be sketched.

        Returns
        -------
        X_sketch : ndarray of shape (n_samples, k)
            The sketched data, X @ sketch_mat.

        Raises
        ------
        ValueError
            If X's feature dimension differs from sketch_mat's first dimension.
        """
        return self._project(X)

    def transform(self, X: np.ndarray) -> np.ndarray:
        """
        Sketch X with the stored matrix.

        Parameters
        ----------
        X : ndarray of shape (n_samples, d)
            The data to be sketched.

        Returns
        -------
        X_sketch : ndarray of shape (n_samples, k)
            The sketched data, X @ sketch_mat.

        Raises
        ------
        ValueError
            If X's feature dimension differs from sketch_mat's first dimension.
        """
        return self._project(X)

def angle_encode_data(X, n_components,
                      y=None,
                      do_pca=False, ret_pca=False,
                      do_lda=False, ret_lda=False,
                      sketch_mat=None, custom_debug=False,
                      ret_orig_data=False):
    """
    Optionally reduce X (PCA, or random sketching when do_lda=True), then min-max
    scale each resulting column to [0, π].

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
    n_components : int
      Number of PCA components to keep (only used when do_pca=True).
    y : array-like, shape (n_samples,), optional
      Currently unused (the LDA path was replaced by random sketching).
    do_pca : bool, default=False
      If True, perform PCA.
    ret_pca : bool, default=False
      If True and do_pca=True, return (X_angle, pca, X_pca).
    do_lda : bool, default=False
      If True, project with the supplied sketch matrix (random sketching; the
      flag name is historical).
    ret_lda : bool, default=False
      If True and do_lda=True, return (X_angle, sketch_transformer, X_sketched).
    sketch_mat : np.ndarray, default=None
      Required when do_lda=True; the matrix used for random sketching.
    custom_debug, ret_orig_data :
      Currently unused; kept for interface compatibility.

    Returns
    -------
    X_angle : ndarray, same shape as the (reduced) data
      Angle-encoded data in [0, π]. Always a floating dtype, even when the
      input is integer-valued and no reduction is applied.
    [model, raw] : only if (do_pca and ret_pca) or (do_lda and ret_lda)
      The fitted/used transformer and the raw projected data.

    Raises
    ------
    ValueError
      If both do_pca and do_lda are set, or do_lda is set without sketch_mat.
    """
    # 1) reduction
    print_cust(f"angle_encode_data, X.shape: {X.shape}")
    if do_pca and do_lda:
        raise ValueError("Choose exactly one of do_pca or do_lda")
    if do_pca:
        pca = PCA(n_components=n_components, random_state=42)
        X_proj = pca.fit_transform(X)
        print_cust(f"angle_encode_data, pca.explained_variance_ratio_: {pca.explained_variance_ratio_}")
        print_cust(f"angle_encode_data, pca.explained_variance_: {pca.explained_variance_}")
        model, raw = pca, X_proj

    elif do_lda:
        # Historical name: this branch performs random sketching, not LDA.
        if sketch_mat is None:
            raise ValueError("sketch_mat is required when do_lda=True")
        sketch_transformer = SketchTransformer(sketch_mat)
        X_proj = sketch_transformer.fit_transform(X)
        model, raw = sketch_transformer, X_proj

    else:
        X_proj = X
        print_cust("angle_encode_data, no reduction applied")
        model, raw = None, X_proj

    # 2) scale each component to [0, π]
    X_angle = np.zeros_like(X_proj)
    if not np.issubdtype(X_angle.dtype, np.inexact):
        # zeros_like inherits the input dtype; for integer/bool data the scaled
        # angles would otherwise be truncated to whole numbers on assignment.
        X_angle = X_angle.astype(np.float64)
    for i in range(X_proj.shape[1]):
        comp = X_proj[:, i]
        lo, hi = comp.min(), comp.max()
        # The small epsilon avoids division by zero for constant columns.
        X_angle[:, i] = ((comp - lo) / (hi - lo + 1e-8)) * np.pi

    # 3) return
    if do_pca and ret_pca:
        return X_angle, model, raw
    if do_lda and ret_lda:
        return X_angle, model, raw
    return X_angle

def pennylane_angleencode(qubits, inputs, reps=1, generative=False, axis="Y", is_reuploading=False):
    """
    Perform angle encoding or amplitude encoding based on the size of the inputs
    with respect to the number of qubits.

    If the feature dimension fits on the given qubits, angle encoding is used;
    otherwise amplitude encoding is used, unless is_reuploading is True, in which
    case nothing is queued.

    Parameters
    ----------
    qubits : Sequence[QubitIdentifier]
        The qubits/wires to encode the data onto.
    inputs : 1D or 2D array
        The inputs to encode. For 2D inputs the second dimension is the feature
        dimension.
    reps : int
        The number of times to encode the data (for angle encoding specifically).
    generative : bool
        If True, an RX encoding is applied before the main rotation.
    axis : str
        Rotation axis passed to AngleEmbedding for 2D inputs.
        NOTE(review): the 1D path always uses RY regardless of `axis`; confirm
        whether that is intended.
    is_reuploading : bool
        If True, skip amplitude encoding when the inputs do not fit on the qubits.

    Returns
    -------
    None (modifies qml to add operations to a quantum circuit)

    Raises
    ------
    ValueError
        If inputs is neither 1D nor 2D.
    """
    print_cust(f"pennylane_angleencode, axis: {axis}")
    print_cust(f"pennylane_angleencode, generative: {generative}")
    n_dims = len(inputs.shape)
    if n_dims == 1:
        input_dim = len(inputs)
    elif n_dims == 2:
        input_dim = inputs.shape[1]
    else:
        raise ValueError(
            f"pennylane_angleencode expects 1D or 2D inputs, got shape {tuple(inputs.shape)}."
        )
    print_cust(f"pennylane_angleencode, input_dim: {input_dim}")
    # NOTE, layers: changed the logic here. only if can fit the data on qubits, do I do angle encoding.
    if input_dim <= len(qubits):
        print_cust(f"pennylane_angleencode, doing angle encoding")
        for _ in range(reps):
            if n_dims == 1:
                # Iterate over the features, not the qubits: with fewer features
                # than qubits only the leading qubits are encoded (this matches
                # what AngleEmbedding does in the 2D path).
                for feature_idx in range(input_dim):
                    qubit = qubits[feature_idx]
                    if generative:
                        qml.RX(inputs[feature_idx], wires=qubit)
                    qml.RY(inputs[feature_idx], wires=qubit)
            else:
                print_cust(f"pennylane_angleencode, running qml.templates.AngleEmbedding")
                if generative:
                    qml.templates.AngleEmbedding(inputs, wires=qubits, rotation="X")
                qml.templates.AngleEmbedding(inputs, wires=qubits, rotation=axis)
    elif not is_reuploading:
        print_cust(f"pennylane_angleencode, doing amplitude encoding")
        qml.AmplitudeEmbedding(
            features=inputs,
            wires=qubits,
            pad_with=0,
            normalize=True
        )

# TOADD: custom angle encoding, for PCA QGAN.

import functools
import operator

def conv_layer(params, wires, encoded_angles=None, reversed_indices=None, pool_in=False):
    """
    Queue one convolutional layer: two-qubit blocks on the even pairs, then on the odd pairs.

    Parameters:
      params: a pair (even_params, odd_params) where:
         - even_params has shape (num_even_pairs, 12) and drives the pairs
           (wires[0], wires[1]), (wires[2], wires[3]), ...
         - odd_params has shape (num_odd_pairs, 12) and drives the pairs
           (wires[1], wires[2]), (wires[3], wires[4]), ...
      wires: list of qubit indices.
      encoded_angles: optional per-position angles; when given, an RY(-angle) is
        inserted on selected positions (used in identity initialization).
      reversed_indices: positions (indices into `wires`) that receive the inverse
        RY when encoded_angles is given.
      pool_in: not used.

    Each two-qubit block is, in order:
      1. Rot on both qubits (parameter columns 0-2 and 3-5).
      2. Optional inverse-encoding RY gates (see above).
      3. CNOT(first -> second).
      4. Rot on both qubits (parameter columns 6-8 and 9-11).
      5. CNOT(second -> first).

    For odd pairs the inverse RY is only considered for the second qubit, and only
    when that qubit is the last wire. With an odd number of wires a zero-angle Rot
    is finally queued on the last wire.

    Returns:
      None (modifies qml to add operations)
    """
    wire_count = len(wires)
    undo_encoding = encoded_angles is not None

    def rotate(weights, row, first_col, wire):
        # One general rotation using three consecutive parameter columns.
        qml.Rot(weights[row, first_col], weights[row, first_col + 1],
                weights[row, first_col + 2], wires=wire)

    # Even pairs: positions (0,1), (2,3), ...
    for row in range(wire_count // 2):
        left = 2 * row
        right = left + 1
        rotate(params[0], row, 0, wires[left])
        rotate(params[0], row, 3, wires[right])

        if undo_encoding:
            # Membership is tested on the position within `wires`.
            for pos in (left, right):
                if pos in reversed_indices:
                    qml.RY(-1 * encoded_angles[pos], wires=wires[pos])

        qml.CNOT(wires=[wires[left], wires[right]])
        rotate(params[0], row, 6, wires[left])
        rotate(params[0], row, 9, wires[right])
        qml.CNOT(wires=[wires[right], wires[left]])

    # Odd pairs: positions (1,2), (3,4), ...
    last_pos = wire_count - 1
    for row in range((wire_count - 1) // 2):
        left = 2 * row + 1
        right = left + 1
        rotate(params[1], row, 0, wires[left])
        rotate(params[1], row, 3, wires[right])
        if undo_encoding and right == last_pos and right in reversed_indices:
            qml.RY(-1 * encoded_angles[right], wires=wires[right])
        qml.CNOT(wires=[wires[left], wires[right]])
        rotate(params[1], row, 6, wires[left])
        rotate(params[1], row, 9, wires[right])
        qml.CNOT(wires=[wires[right], wires[left]])

    # Odd wire count: queue an identity (all-zero) rotation on the last wire.
    if wire_count % 2 == 1:
        qml.Rot(0.0, 0.0, 0.0, wires=wires[-1])

def pool_layer_with_measurement(params, wires, pool_in=False):
    """
    Pooling layer using mid-circuit measurements.

    For each consecutive pair of wires, one qubit is measured and an RY(π) is
    applied to the other qubit conditioned on the measurement outcome; the
    unmeasured qubit is retained.

    Parameters:
      params: a np.ndarray of parameters, documented as shape (number_of_pairs, 1)
        with number_of_pairs = floor(len(wires)/2). Currently unused: the
        conditional rotation angle is fixed to π.
      wires: list of qubit indices.
      pool_in: if True, pairs whose first position lies in the first half of
        `wires` measure their first qubit and keep the second, while the
        remaining pairs measure the second and keep the first. If False, every
        pair measures its second qubit and keeps its first.

    Returns:
      A list of the retained wires (plus the unpaired last wire when len(wires) is odd).

    Note:
      qml.cond returns a callable that queues the conditional operation only when
      it is invoked. Earlier code built that callable without calling it, so the
      conditional rotation was never added to the circuit.
    """
    pooled_wires = []
    num_pairs = len(wires) // 2
    for i in range(0, num_pairs * 2, 2):
        if pool_in and i < num_pairs:
            measured, kept = wires[i], wires[i + 1]
        else:
            measured, kept = wires[i + 1], wires[i]
        m = qml.measure(measured)
        # Invoke the conditional so that RY(π) is actually queued on `kept`.
        qml.cond(m, qml.RY)(np.pi, wires=kept)
        pooled_wires.append(kept)
    if len(wires) % 2 == 1:
        # The unpaired last wire passes through untouched.
        pooled_wires.append(wires[-1])
    return pooled_wires

def final_pool(wires, param):
    """
    Final pooling: measures the second qubit and conditionally rotates the first.

    Parameters:
      wires: list of qubit indices; wires[1] is measured, wires[0] is kept.
      param: a float intended as the pooling rotation angle. Currently unused:
        the conditional rotation angle is fixed to π.

    Returns:
      The first wire (wires[0]).

    Note:
      qml.cond returns a callable that queues the conditional operation only when
      it is invoked. Earlier code built that callable without calling it, so the
      conditional rotation was never added to the circuit.
    """
    m = qml.measure(wires[1])
    # Invoke the conditional so that RY(π) is actually queued on wires[0].
    qml.cond(m, qml.RY)(np.pi, wires=wires[0])
    return wires[0]

"""## Block circuit"""

# REPURPOSE: repurpose block_variational_circuit as the layers in my QGAN.
# NOTE: in theory a single rotation parameter per qubit would suffice, but three are kept because the rest of the code already accounts for them.
# loss.backward().... and stuff

def block_variational_circuit(params, layers, wires, generative=False, id_init_circ=False, is_hea=False, circuit_type="staircase",
                              reupload_func=None, inv_reupload_func=None, offset_idx=None):
    """
    Queue a hardware-efficient variational block: per layer, single-qubit Rot gates
    followed by a CNOT entangling pattern selected by `circuit_type`.

    Parameters:
      params: Array indexed as params[layer, qubit, 0..2]. Only the first
        params.shape[1] entries of `wires` are acted on.
      layers: Number of layers in the block (should equal params.shape[0]).
      wires: List of qubit indices.
      generative, is_hea: when either is True, the wrap-around ("circular") CNOT
        of the staircase-style patterns is omitted where the code guards it.
      id_init_circ: when True, odd layers of "staircase", "brick", "x_shape" and
        "reversed_staircase" emit their CNOTs in a different (reversed) order than
        even layers.
      circuit_type: one of "staircase", "brick", "x_shape", "reversed_staircase",
        "v_shape", "revstair_vshape". Any other value queues no CNOTs.
      reupload_func: optional zero-argument callable invoked at the start of every layer.
      inv_reupload_func: optional zero-argument callable invoked at the end of every
        layer that is not skipped by the "every third layer" hack below.
      offset_idx: integer added to the layer index for "revstair_vshape" parity
        selection; must not be None for that circuit type.

    Returns:
      None (modifies qml by adding circuit operations)
    """
    print_cust(f"block_variational_circuit, params.shape: {params.shape}")
    print_cust(f"block_variational_circuit, layers: {layers}")
    print_cust(f"block_variational_circuit, circuit_type: {circuit_type}")
    print_cust(f"block_variational_circuit, reupload_func: {reupload_func}")
    print_cust(f"block_variational_circuit, inv_reupload_func: {inv_reupload_func}")
    print_cust(f"block_variational_circuit, offset_idx: {offset_idx}")
    print_cust(f"block_variational_circuit, wires: {wires}")
    for layer in range(layers):
        print_cust(f"block_variational_circuit, layer: {layer}")
        if reupload_func is not None:
          print_cust(f"block_variational_circuit, applying reupload_func()")
          reupload_func()
        # Apply single-qubit rotations.
        # NOTE, layers: in general, block params is NOT assumed to have the same number of qubits as the
        # number of wires.
        for i in range(params.shape[1]):
            wire = wires[i]
            qml.Rot(params[layer, i, 0], params[layer, i, 1], params[layer, i, 2], wires=wire)
        # Apply the staircase of CNOTs.
        # NOTE, layers: in general, block params is NOT assumed to have the same number of qubits as the
        # number of wires.
        num_qubits = params.shape[1]
        n = num_qubits
        # if generative or is_hea:
        #   n = n - 1
        # TOMODIFY, layers: get rid of this below hack....
        # TOMODIFY, layers, HACK: generally this is NOT what I want, BUT I have it here to allow for final tuning in ID init.
        # When the block has a multiple of 3 layers, every third layer is rotation-only.
        # NOTE(review): `continue` also skips the inv_reupload_func() call at the bottom
        # of the loop for these layers, while reupload_func() above has already run.
        if ((layer + 1) % 3 == 0) and (layers % 3 == 0):
          print_cust(f"block_variational_circuit, layer: {layer}, layers: {layers}, skipping CNOT applications")
          continue
        print_cust(f'block_variational_circuit, n: {n}, num_qubits: {num_qubits}')
        if circuit_type == "staircase":
          if not id_init_circ or (layer % 2 == 0):
            print_cust(f"block_variational_circuit, staircase, not id_init_circ, even layer, layer: {layer}")
            # CNOT chain i -> i+1; the last one (i == n-1) wraps around to wires[0]
            # and is only applied when neither generative nor is_hea is set.
            for i in range(n):
                control = wires[i]
                target = wires[(i+1) % num_qubits]
                if i == (n - 1):
                  if not generative and not is_hea:
                    qml.CNOT(wires=[control, target])
                else:
                  qml.CNOT(wires=[control, target])
          elif id_init_circ and (layer % 2 == 1):
            print_cust(f"block_variational_circuit, staircase, id_init_circ, odd layer, layer: {layer}")
            # Same CNOTs as the even layer, emitted in reverse order.
            for i in range(n - 1, -1, -1):
              control = wires[i]
              target = wires[(i+1) % num_qubits]
              if i == (n - 1):
                if not generative and not is_hea:
                  qml.CNOT(wires=[control, target])
              else:
                qml.CNOT(wires=[control, target])
        elif circuit_type == "brick":
          if not id_init_circ or (layer % 2 == 0):
            print_cust(f"block_variational_circuit, brick, not id_init_circ, even layer, layer: {layer}")
            # First the pairs starting at even positions, then those starting at odd positions.
            for i in range(0, n - 1, 2):
              qml.CNOT(wires=[wires[i], wires[(i + 1)]])
            for i in range(1, n - 1, 2):
              qml.CNOT(wires=[wires[i], wires[(i + 1)]])
          elif id_init_circ and (layer % 2 == 1):
            print_cust(f"block_variational_circuit, brick, id_init_circ, odd layer, layer: {layer}")
            # Same CNOTs as the even layer, emitted in reverse order.
            for i in reversed(range(1, n - 1, 2)):
              qml.CNOT(wires=[wires[i], wires[(i + 1)]])
            for i in reversed(range(0, n - 1, 2)):
              qml.CNOT(wires=[wires[i], wires[(i + 1)]])
        elif circuit_type == "x_shape":
          if not id_init_circ or (layer % 2 == 0):
            print_cust(f"block_variational_circuit, x_shape, not id_init_circ, even layer, layer: {layer}")
            # For each i, one CNOT at position i and one at the mirrored position
            # n-2-i (skipped when both coincide).
            for i in range(n - 1):
                control = wires[i]
                target = wires[(i + 1)]
                qml.CNOT(wires=[control, target])
                ending_control_idx = n - 2 - i
                if ending_control_idx != i:
                  control = wires[ending_control_idx]
                  target = wires[(ending_control_idx + 1)]
                  qml.CNOT(wires=[control, target])
          elif id_init_circ and (layer % 2 == 1):
            print_cust(f"block_variational_circuit, x_shape, id_init_circ, odd layer, layer: {layer}")
            # Same loop body as the even layer, with i running in reverse.
            for i in reversed(range(n - 1)):
                control = wires[i]
                target = wires[(i + 1)]
                qml.CNOT(wires=[control, target])
                ending_control_idx = n - 2 - i
                if ending_control_idx != i:
                  control = wires[ending_control_idx]
                  target = wires[(ending_control_idx + 1)]
                  qml.CNOT(wires=[control, target])
        elif circuit_type == "reversed_staircase":
          if not id_init_circ or (layer % 2 == 0):
            print_cust(f"block_variational_circuit, reversed_staircase, not id_init_circ, even layer, layer: {layer}")
            # CNOT chain i -> i-1, from the last position down to 0 (position 0 wraps to n-1).
            for i in range(n - 1, -1, -1):
                control = wires[i]
                target = wires[(i-1) % num_qubits]
                # TOMODIFY, depthFL: don't condition based on target QUBIT identifier, condition based on index
                # NOTE(review): `target` is a wire label while (n - 1) is a position; the
                # two only coincide when wires are labelled 0..n-1 in order.
                if target == (n - 1):
                  if not generative and not is_hea:
                    qml.CNOT(wires=[control, target])
                else:
                  qml.CNOT(wires=[control, target])
          elif id_init_circ and (layer % 2 == 1):
            print_cust(f"block_variational_circuit, reversed_staircase, id_init_circ, odd layer, layer: {layer}")
            # NOTE(review): this loop covers i = 0..n-2 only, so the CNOT controlled by
            # position n-1 that the even layer emits has no counterpart here; confirm intent.
            for i in range(n - 1):
              control = wires[i]
              target = wires[(i-1) % num_qubits]
              # TOMODIFY, depthFL: don't condition based on target QUBIT identifier, condition based on index
              if target == (n - 1):
                if not generative and not is_hea:
                  qml.CNOT(wires=[control, target])
              else:
                qml.CNOT(wires=[control, target])
        elif circuit_type == "v_shape":
          print_cust(f"block_variational_circuit, v_shape, layer, layer: {layer}")
          print_cust(f"block_variational_circuit, n: {n}, num_qubits: {num_qubits}")
          # TODO, depthFL: implement ID-init circ for v-shape circuit, too.
          # TOMODIFY, depthFL: this should go from i to n
          # Downward pass: CNOT i -> i+1 for i = 0..n-2.
          for i in range(n-1):
              print_cust(f"block_variational_circuit, v_shape, first loop, i: {i}")
              control = wires[i]
              target = wires[(i+1) % num_qubits]
              # With range(n-1), i never equals n-1, so this guarded branch is never taken.
              if i == (n - 1):
                if not generative and not is_hea:
                  qml.CNOT(wires=[control, target])
              else:
                qml.CNOT(wires=[control, target])
          # Upward pass: CNOT i -> i-1 for i = n-1..0; the i == 0 wrap-around to the
          # last position is only applied when neither generative nor is_hea is set.
          for i in range(n - 1, -1, -1):
            print_cust(f"block_variational_circuit, v_shape, second loop, i: {i}")
            control = wires[i]
            target = wires[(i-1) % num_qubits]
            if i == 0:
              if not generative and not is_hea:
                qml.CNOT(wires=[control, target])
            else:
              qml.CNOT(wires=[control, target])
        elif circuit_type == "revstair_vshape":
          print_cust(f"block_variational_circuit, layer: {layer}")
          # if layers == 1:
          # Shift the layer index by offset_idx so parity is taken on the shifted index.
          # This rebinds the loop variable for the rest of this iteration only (the
          # range iterator is unaffected) and fails if offset_idx is None.
          layer = offset_idx + layer
          print_cust(f"block_variational_circuit, after layer counts check, layer: {layer}")
          if (layer % 2 == 0):
            print_cust(f"block_variational_circuit, revstair_vshape, even layer, layer: {layer}")
            # Even (shifted) layer: reversed-staircase pattern.
            for i in range(n - 1, -1, -1):
                control = wires[i]
                target = wires[(i-1) % num_qubits]
                # TOMODIFY, depthFL: don't condition based on target QUBIT identifier, condition based on index
                if target == (n - 1):
                  if not generative and not is_hea:
                    qml.CNOT(wires=[control, target])
                else:
                  qml.CNOT(wires=[control, target])
          elif (layer % 2 == 1):
            print_cust(f"block_variational_circuit, revstair_vshape, odd layer, layer: {layer}")
            # Odd (shifted) layer: a downward pass followed by an upward pass.
            for i in range(n-1):
                control = wires[i]
                target = wires[(i+1) % num_qubits]
                # With range(n-1), i never equals n-1, so this guarded branch is never taken.
                if i == (n - 1):
                  if not generative and not is_hea:
                    qml.CNOT(wires=[control, target])
                else:
                  qml.CNOT(wires=[control, target])
            for i in range(n - 1, -1, -1):
              control = wires[i]
              target = wires[(i-1) % num_qubits]
              # TOMODIFY, depthFL: don't condition based on target QUBIT identifier, condition based on index
              if target == (n - 1):
                if not generative and not is_hea:
                  qml.CNOT(wires=[control, target])
              else:
                qml.CNOT(wires=[control, target])
        # Runs once per layer, except for layers skipped by the `continue` above.
        if inv_reupload_func is not None:
           print_cust(f"block_variational_circuit, applying inv_reupload_func")
           inv_reupload_func()

# TOADD: a generative statement above, to indicate whether or not we want circular entanglement.

# DONE (implicitly, no change, just not generative): TOMODIFY, layers: have circular entanglement for classification.

"""### Helper for unpermuting inputs"""

def unpermute_inputs(inputs, indices):
  """Scatter ``inputs`` back to the positions named by ``indices``.

  The element at position ``k`` of ``inputs`` is written to slot
  ``indices[k]`` of the result, i.e. ``result[indices[k]] == inputs[k]``.

  Parameters:
    inputs: Indexable sequence of values (its length sets the result length).
    indices: Iterable of destination slots, one per consumed element of
      ``inputs``.

  Returns:
    A new list of ``len(inputs)`` entries; slots that are never targeted
    keep the placeholder value 0.
  """
  restored = [0] * len(inputs)
  for source_pos, dest_pos in enumerate(indices):
    # Move the value found at source_pos to the slot it came from.
    restored[dest_pos] = inputs[source_pos]
  return restored

# test_inputs = [4, 0, 5, 1, 2, 6, 3, 7]
# test_indices = list(test_inputs)

# print_cust(unpermute_inputs(test_inputs, test_indices))

"""## QCNN Circuit"""

from functools import partial

def _qcnn_trailing_features(inputs, n_wires, context):
    """Return the last ``n_wires`` features of ``inputs`` for (re-)encoding.

    Parameters:
      inputs: Array/tensor with a ``shape`` attribute; either 1-D (a single
        sample) or 2-D (the last axis is sliced, so presumably
        (batch, features) -- TODO confirm against callers).
      n_wires: Number of trailing features to keep.
      context: Short tag inserted into the log message to identify the caller.

    Returns:
      The slice of ``inputs`` holding the trailing ``n_wires`` features.

    Raises:
      ValueError: If ``inputs`` is neither 1-D nor 2-D.
    """
    n_dims = len(inputs.shape)
    if n_dims == 2:
        print_cust(f"QCNN_circuit_dynamic, {context}, truncating input dim for inputs.shape == 2")
        return inputs[:, -(n_wires):]
    if n_dims == 1:
        print_cust(f"QCNN_circuit_dynamic, {context}, truncating input dim for inputs.shape == 1")
        return inputs[-(n_wires):]
    # Previously this case fell through and surfaced later as an UnboundLocalError.
    raise ValueError(f"QCNN_circuit_dynamic, {context}: inputs must be 1-D or 2-D, got shape {inputs.shape}.")


def QCNN_circuit_dynamic(inputs, conv_params_tuple, pool_params_tuple, final_pool_param, final_params, n_qubits, n_classes=2, expansion_data=[], block_params_list=[], num_ancillas=0, layer_types_list=[], cheating=False, tunn_down=False):
    """
    Dynamic QCNN circuit with optional variational blocks.

    The circuit is built in three stages:
      1. One (block, conv, pool) stage per entry of ``conv_params_tuple``.
      2. Every block that was not consumed in stage 1 is applied to the
         remaining wires (with data re-uploading).
      3. A final rotation (one wire left) or a final conv layer.

    Parameters:
      inputs: Input data for the circuit. It is assumed to be permuted so that
        it fits linearly on the qubits. Must expose ``shape`` and be 1-D or 2-D
        whenever a stage-2 block is applied.
      conv_params_tuple: Per-layer convolution parameters; its length is the
        number of conv/pool layers.
      pool_params_tuple: Per-layer pooling parameters, indexed like
        ``conv_params_tuple``.
      final_pool_param: Not used by this function (kept for interface
        compatibility).
      final_params: Parameters of the final layer. Indexed as
        ``final_params[0, 0..2]`` when a single wire is left, otherwise passed
        to ``conv_layer``.
      n_qubits: Total number of qubits, INCLUDING the ancillas.
      n_classes: Integer representing the number of classes to classify for
        (it is assumed that n_output_qubits = int(np.ceil(np.log2(n_classes)))).
      expansion_data: Per-layer list of (encoded_angle_indices, reversed_indices)
        used to keep an identity initialization when the circuit expands.
        Layers with no entry (list too short, or an empty list entry) are built
        without expansion data.
      block_params_list: List of block parameter arrays, each of shape
        (num_layers_bp, num_qubits_bp, 3). While the circuit is built, a block
        is applied whenever the current number of wires equals num_qubits_bp.
      num_ancillas: Number of ancilla qubits. They occupy the wires after the
        data register; after each stage-2 block a CNOT copies ``wires[0]`` onto
        ancilla ``n_qubits - num_ancillas + bp_idx``.
      layer_types_list: ``circuit_type`` to use for each stage-2 block, indexed
        by the block's position. If empty, "reversed_staircase" is used for
        every block.
      cheating: If True, a ``qml.Snapshot`` of the wire-0 probabilities is
        recorded after each stage-2 block.
      tunn_down: If True, every remaining block is applied regardless of its
        qubit count and the first wire is dropped after each block.

    Returns:
      * ``num_ancillas == 0`` and not ``tunn_down``: ``qml.probs`` on the first
        ``n_output_qubits`` remaining wires.
      * ``num_ancillas == 0`` and ``tunn_down``: a list with one ``qml.probs``
        per remaining block, on qubits ``0 .. len(blocks) - 1``.
      * ``num_ancillas > 0``: a list with one ``qml.probs`` per used ancilla.

    Raises:
      ValueError: If fewer wires than ``n_output_qubits`` are left after
        pooling, if expansion data is supplied but the inputs could not be
        un-permuted, or if ``inputs`` has an unsupported number of dimensions.

    NOTE: the list defaults are never mutated in this function, so sharing them
    between calls is harmless.
    """
    # Data embedding
    # assume that the input data is permuted correctly to fit linearly on the qubits.
    print_cust(f"QCNN_circuit_dynamic, n_qubits: {n_qubits}, num_ancillas: {num_ancillas}")

    print_cust(f"QCNN_circuit_dynamic, cheating: {cheating}")

    print_cust(f"QCNN_circuit_dynamic, tunn_down: {tunn_down}")

    num_tot_qubits = n_qubits

    # From here on, n_qubits counts only the data register; ancillas come after it.
    n_qubits = n_qubits - num_ancillas

    print_cust(f"QCNN_circuit_dynamic, n_qubits: {n_qubits}, num_tot_qubits: {num_tot_qubits}")

    unpermuted_inputs = None

    num_layers = len(conv_params_tuple)

    print_cust(f"QCNN_circuit_dynamic, num_layers: {num_layers}")

    if num_layers > 0:
        print_cust(f"QCNN_circuit_dynamic, applying pennylane_angleencode at the start because num_layers > 0")
        pennylane_angleencode(range(0, n_qubits), inputs)
    wires = list(range(n_qubits))

    # TOMODIFY, layers: quick thing to tell the circuit whether or not I want HEA layers.
    is_hea = True
    # TOMODIFY, layers, HACK: hack to prespecify the types of layers as I want; please INJECT THIS IN.
    # layer_types_list = ["staircase", "staircase", "staircase", "staircase", "staircase"]

    if len(layer_types_list) == 0:
        # TOMODIFY, layers: some placeholder, just in case the layer types is not supplied.
        print_cust(f"QCNN_circuit_dynamic, len(layer_types_list) == 0, supplying some placeholder layer_types_list")
        layer_types_list = ["reversed_staircase" for _ in range(len(block_params_list))]
        print_cust(f"QCNN_circuit_dynamic, placeholder layer_types_list: {layer_types_list}")

    # Track the blocks that have not been applied yet. The name is rebound to a
    # fresh list on every layer, so the caller's list is never modified.
    block_params_remaining = block_params_list

    # Each layer contains: a block layer (if supplied), a convolutional layer, and a pooling layer.
    for layer in range(num_layers):
        print_cust(f"QCNN_circuit_dynamic, applying layer, layer: {layer}")
        # Apply every block whose qubit count matches the current wire count;
        # keep the others for later stages.
        new_block_params_remaining = []
        for bp in block_params_remaining:
            # bp.shape is (num_layers_bp, num_qubits_bp, 3)
            if bp.shape[1] == len(wires):
                num_layers_bp = bp.shape[0]
                block_variational_circuit(bp, num_layers_bp, wires)
            else:
                new_block_params_remaining.append(bp)
        block_params_remaining = new_block_params_remaining

        conv_params = conv_params_tuple[layer]
        pool_params = pool_params_tuple[layer]
        # The default expansion_data is empty; treat a missing per-layer entry
        # as "no expansion data" instead of raising IndexError.
        expansion_layer_data = expansion_data[layer] if layer < len(expansion_data) else []
        encoded_angles = None
        reversed_indices = None
        # NOTE: the encoded_angles here are INDICES of the input features, not the ACTUAL data values.
        # reversed_indices are the QUBIT LABELS/INDICES.
        if expansion_layer_data != []:
            encoded_angle_indices = expansion_layer_data[0]
            # Relies on the first layer's index list covering every input.
            # NOTE(review): for 2-D inputs len(inputs) is the size of the first
            # axis, not the feature count -- confirm this is intended.
            if len(encoded_angle_indices) == len(inputs):
                unpermuted_inputs = unpermute_inputs(inputs, encoded_angle_indices)
            if unpermuted_inputs is None:
                # Previously surfaced as "'NoneType' object is not subscriptable".
                raise ValueError(
                    f"QCNN_circuit_dynamic, layer {layer}: expansion data was supplied but the inputs "
                    f"were never un-permuted (len(encoded_angle_indices): {len(encoded_angle_indices)}, "
                    f"len(inputs): {len(inputs)})."
                )
            encoded_angles = [unpermuted_inputs[feat_idx] for feat_idx in encoded_angle_indices]
            reversed_indices = expansion_layer_data[1]

        # Apply a convolutional and pooling layer.
        conv_layer(conv_params, wires, encoded_angles=encoded_angles, reversed_indices=reversed_indices, pool_in=True)
        wires = pool_layer_with_measurement(pool_params, wires, pool_in=True)

    # Compute the number of output qubits needed.
    n_output_qubits = int(np.ceil(np.log2(n_classes)))
    if len(wires) < n_output_qubits:
        raise ValueError(f"Not enough wires left after pooling for the requested number of classes, len(wires): {len(wires)}, n_output_qubits: {n_output_qubits}.")

    # Apply another block layer on the last qubits.
    # NOTE: num_layers_applied is wrt to this specific variational block.
    num_layers_applied = 0
    for bp_idx, bp in enumerate(block_params_remaining):
        # TOMODIFY, layers: quick hack to see if ID init even helps.
        id_init_circ = False
        reupload_data = True
        print_cust(f"QCNN_circuit_dynamic, bp_idx: {bp_idx}, id_init_circ: {id_init_circ}")
        # bp.shape is (num_layers_bp, num_qubits_bp, 3)
        # NOTE, layers: if I'd like to have the smaller qubit count classifiers, I'd need to get rid of this below condition. (breaks
        # backwards compatibility, but whatever for now.)
        print_cust(f"QCNN_circuit_dynamic, len(wires): {len(wires)}")
        if not (bp.shape[1] == len(wires) or tunn_down):
            # Block does not fit the remaining wires; skip it.
            continue

        # The encoding axis used to be selected by the parity of bp_idx, but
        # both branches chose "Y", so it is a constant for now.
        axis = "Y"
        print_cust(f"QCNN_circuit_dynamic, bp_idx: {bp_idx}, axis: {axis}")
        layer_type = layer_types_list[bp_idx]
        print_cust(f"QCNN_circuit_dynamic, layer_type: {layer_type}")
        if reupload_data:
            # HACKY, depthFL: using the size of the inputs
            reuploaded_inputs = _qcnn_trailing_features(inputs, len(wires), "reupload_data")
            reupload_func = partial(pennylane_angleencode, qubits=wires, inputs=reuploaded_inputs, axis=axis, is_reuploading=True)
        else:
            reupload_func = None
        if id_init_circ:
            # Same encoding with negated angles, handed to the block as inv_reupload_func.
            reuploaded_inputs = _qcnn_trailing_features(inputs, len(wires), "id_init_circ")
            inv_reupload_func = partial(pennylane_angleencode, qubits=wires, inputs=-1 * reuploaded_inputs, axis=axis, is_reuploading=True)
        else:
            inv_reupload_func = None
        print_cust(f"QCNN_circuit_dynamic, reupload_func: {reupload_func}, inv_reupload_func: {inv_reupload_func}")
        num_layers_bp = bp.shape[0]
        # TOMODIFY, layers, HACK: prespecify the order of the layer types.
        print_cust(f"QCNN_circuit_dynamic, bp_idx: {bp_idx}, num_layers_applied: {num_layers_applied}")
        block_variational_circuit(bp, num_layers_bp, wires, id_init_circ=id_init_circ, is_hea=is_hea, circuit_type=layer_type, reupload_func=reupload_func, inv_reupload_func=inv_reupload_func, offset_idx=num_layers_applied)
        num_layers_applied += num_layers_bp
        if num_ancillas > 0:
            print_cust(f"QCNN_circuit_dynamic, using ancillas for getting statistics")
            # TOMODIFY, depthFL: have a better way of extracting midcirc statistics; make the below line more
            # configurable.
            qml.CNOT(wires=[wires[0], n_qubits + bp_idx])
        if cheating:
            print_cust(f"QCNN_circuit_dynamic, bp_idx: {bp_idx}, applying cheating measurement on q0")
            qml.Snapshot(f"p0_bp{bp_idx}", measurement=qml.probs(wires=[0]))
        if tunn_down:
            print_cust(f"QCNN_circuit_dynamic, doing tunneling down accumulation")
            # Drop the first wire so the next block acts on one qubit fewer.
            wires = wires[1:]

    # Instead of performing final pooling to a single wire, apply an extra convolution.
    # Here, we assume final_params is formatted like a conv layer
    # that acts on all the remaining wires.
    if len(wires) == 1:
        qml.Rot(final_params[0, 0], final_params[0, 1], final_params[0, 2], wires=wires[0])
    else:
        # Only apply the final conv layer when the circuit has conv layers at all.
        if num_layers > 0:
            conv_layer(final_params, wires, pool_in=True)

    # Return the probability distribution for the first n_output_qubits.
    if num_ancillas == 0:
        if tunn_down:
            print_cust(f"QCNN_circuit_dynamic, returning list of probs from tunneling down")
            ret_probs_list_tunn = [qml.probs(wires=qubit_idx) for qubit_idx in range(0, len(block_params_remaining))]
            return ret_probs_list_tunn
        else:
            return qml.probs(wires=wires[:n_output_qubits])
    else:
        print_cust(f"QCNN_circuit_dynamic, returning ancilla statistics")
        n_ancillas_used = len(block_params_remaining)
        print_cust(f"QCNN_circuit_dynamic, n_ancillas_used: {n_ancillas_used}")
        return [qml.probs(wires=ancilla_qubit_idx) for ancilla_qubit_idx in range(n_qubits, n_qubits + n_ancillas_used)]

# TOADD, layers: a function that takes inputs, n_qubits, and block_params_list, and returns the probs.
# For speed, it should use PennyLane's built-in angle encoding.
# NOTE: a separate function is probably not needed; this can instead be a
# TOMODIFY on the circuit above (specifically: use PennyLane's angle encoding for speed).
# DONE: TOMODIFY, layers: apply the final conv layer only if the number of conv layers is NOT 0.

"""## QGAN Circuit"""

def QGAN_circuit(noise, n_qubits, block_params_list=[]):
  """Generator circuit: encode ``noise``, apply the blocks, measure Z on every qubit.

  Parameters:
    noise: Latent input exposing ``shape``. A 1-D input goes through
      ``pennylane_angleencode`` (generative mode); a 2-D input goes through
      PennyLane's ``AngleEmbedding`` twice (X rotations, then Y rotations).
      Any other rank gets no encoding at all.
    n_qubits: Number of qubits that are encoded and measured.
    block_params_list: Block parameter arrays indexed as
      (num_layers, num_qubits, ...). They are applied in ascending order of
      qubit count, each on wires ``0 .. num_qubits - 1``. The list itself is
      not modified.

  Returns:
    A list with one ``qml.expval(qml.Z(...))`` per qubit ``0 .. n_qubits - 1``.
  """
  # Work on a sorted copy; list.sort is stable, like sorted().
  ordered_blocks = list(block_params_list)
  ordered_blocks.sort(key=lambda blk: blk.shape[1])
  for blk in ordered_blocks:
    print_cust(f"QGAN_circuit, block_params.shape: {blk.shape}")

  noise_rank = len(noise.shape)
  encode_wires = range(0, n_qubits)
  if noise_rank == 1:
    pennylane_angleencode(encode_wires, noise, generative=True)
  elif noise_rank == 2:
    print_cust(f"QGAN_circuit, running pennylane.angleencode")
    for rotation_axis in ("X", "Y"):
      qml.templates.AngleEmbedding(noise, wires=encode_wires, rotation=rotation_axis)

  for blk in ordered_blocks:
    depth, width = blk.shape[0], blk.shape[1]
    block_variational_circuit(blk, depth, list(range(width)), generative=True)

  measurements = []
  for measured_qubit in range(n_qubits):
    measurements.append(qml.expval(qml.Z(wires=measured_qubit)))
  return measurements

"""### QGAN Expectations Helper"""

def get_expectations(noise, qnode, n_qubits_func, block_params_list):
    """Run ``qnode`` on ``noise`` and stack its outputs into one tensor.

    Parameters:
      noise: First positional argument forwarded to ``qnode``.
      qnode: Callable invoked as ``qnode(noise, block_params_list)``; it must
        return a sequence of torch tensors accepted by ``torch.stack``.
      n_qubits_func: Not used by this function (kept for interface
        compatibility).
      block_params_list: Second positional argument forwarded to ``qnode``.

    Returns:
      ``torch.stack`` of the qnode outputs along a new leading dimension.
    """
    expvals = qnode(noise, block_params_list)

    # The sum is only computed for the debug log below.
    expvals_total = sum(expvals)
    print_cust(f"get_expectations, type(ret_expvals_sum): {type(expvals_total)}")
    print_cust(f"get_expectations, ret_expvals_sum: {expvals_total}")

    stacked_expvals = torch.stack(expvals)
    print_cust(f"get_expectations, type(torch_ret_expvals): {type(stacked_expvals)}")
    print_cust(f"get_expectations, torch_ret_expvals.shape: {stacked_expvals.shape}")
    return stacked_expvals

# TOADD: a QGAN circuit that returns Z expectations on individual qubits.

"""## Dynamic parameter initialization

### Definition of PCARescaler
"""

import torch.nn.functional as F

import os, sys, resource, time

class PCARescaler:
    """Maps generator outputs to PCA components and PCA components to images.

    Holds a PCA object together with the per-component extrema
    (``pca_mins`` / ``pca_maxs``) and the extrema used to min-max normalise
    reconstructed images (``inv_pca_min`` / ``inv_pca_max``).
    """

    def __init__(self, pca_obj, pca_mins, pca_maxs, inv_pca_mins, inv_pca_maxs, device):
        # NOTE: the inverse-PCA extrema are stored under singular attribute
        # names (inv_pca_min / inv_pca_max).
        self.pca_obj = pca_obj
        self.pca_mins, self.pca_maxs = pca_mins, pca_maxs
        self.inv_pca_min, self.inv_pca_max = inv_pca_mins, inv_pca_maxs
        self.device = device

    def __str__(self):
        return f"""
        PCARescaler:
        self.pca_obj: {self.pca_obj}
        self.pca_mins: {self.pca_mins}
        self.pca_maxs: {self.pca_maxs}
        self.inv_pca_min: {self.inv_pca_min}
        self.inv_pca_max: {self.inv_pca_max}
        self.device: {self.device}
        """
    __repr__ = __str__

    def rescale_pca_comps(self, gen_data):
        """Rescale ``gen_data`` from [0, 1] to the stored PCA component ranges.

        It is assumed that ``gen_data`` lies between 0 and 1 and is 2-D, with
        the components on axis 1. Only the first ``gen_data.shape[1]`` stored
        extrema are used.

        Returns:
          ``gen_data * (maxs - mins) + mins``.
        """
        print_cust(f"PCARescaler, gen_data.shape: {gen_data.shape}")
        n_components_used = gen_data.shape[1]

        upper, lower = self.pca_maxs, self.pca_mins
        # The extrema may have been stored as the (values, indices) result of
        # torch.max / torch.min; only the values are needed.
        if isinstance(upper, torch.return_types.max):
            upper = upper.values
        if isinstance(lower, torch.return_types.min):
            lower = lower.values
        upper = upper[:n_components_used]
        lower = lower[:n_components_used]

        print_cust(f"PCARescaler, pca_maxs: {upper}, pca_mins: {lower}")

        component_span = upper - lower
        pca_comps = gen_data * component_span + lower
        print_cust(f"PCARescaler, rescale_pca_comps, pca_comps: {pca_comps}")
        print_cust(f"PCARescaler, type(pca_comps): {type(pca_comps)}")
        column_mins = torch.min(pca_comps, dim=0)
        column_maxs = torch.max(pca_comps, dim=0)
        print_cust(f"PCARescaler, rescale_pca_comps, torch.min(pca_comps, dim=0): {column_mins}, torch.max(pca_comps, dim=0): {column_maxs}")
        print_cust(f"PCARescaler, about to return pca_comps")
        return pca_comps

    def invpca_to_img(self, gen_pca_comps, resc_invpca=True):
        """Invert the PCA on ``gen_pca_comps`` and optionally min-max rescale.

        Inputs with fewer than ``self.pca_obj.n_components`` entries on the
        last axis are zero-padded on the right first.

        Parameters:
          gen_pca_comps: Torch tensor of PCA components.
          resc_invpca: If True, normalise the reconstruction with
            ``inv_pca_min`` / ``inv_pca_max``; otherwise return it unchanged.

        Returns:
          The reconstructed images as a torch tensor on ``self.device``.
        """
        # NOTE(review): this reads the PCA object's `n_components` attribute as
        # an integer width -- confirm the PCA is always built with an int.
        pca_dim = self.pca_obj.n_components
        print_cust(f"PCARescaler, invpca_to_img, resc_invpca: {resc_invpca}")
        print_cust(f"PCARescaler, invpca_to_img, pca_dim: {pca_dim}")

        n_given = gen_pca_comps.size(-1)
        if n_given < pca_dim:
            # (left, right) padding of the last dimension only.
            gen_pca_comps = F.pad(gen_pca_comps, (0, pca_dim - n_given), mode="constant", value=0.0)
            print_cust(f"PCARescaler, invpca_to_img, gen_pca_comps.shape: {gen_pca_comps.shape}")
            print_cust(f"PCARescaler, invpca_to_img, after padding, gen_pca_comps: {gen_pca_comps}")

        # The PCA object works on NumPy arrays, so leave the autograd graph here.
        comps_numpy = gen_pca_comps.detach().cpu().numpy()
        reconstr_numpy = self.pca_obj.inverse_transform(comps_numpy)
        gen_imgs_reconstr = torch.from_numpy(reconstr_numpy).to(self.device)

        print_cust(f"PCARescaler, invpca_to_img, gen_imgs_reconstr: {gen_imgs_reconstr}")

        print_cust(f"PCARescaler, invpca_to_img, gen_imgs_reconstr.min(): {gen_imgs_reconstr.min()}, gen_imgs_reconstr.max(): {gen_imgs_reconstr.max()}")

        if not resc_invpca:
            print_cust(f"PCARescaler, invpca_to_img, leaving generated images alone")
            result = gen_imgs_reconstr
        else:
            print_cust(f"PCARescaler, invpca_to_img, rescaling generated images")
            result = (gen_imgs_reconstr - self.inv_pca_min) / (self.inv_pca_max - self.inv_pca_min)

        print_cust(f"PCARescaler, invpca_to_img, after min-max, gen_imgs_reconstr_rescaled.min(): {result.min()}, gen_imgs_reconstr_rescaled.max(): {result.max()}")

        return result

    def get_data_components(self):
        """Return the constructor arguments, in constructor order, as a tuple."""
        return (
            self.pca_obj,
            self.pca_mins,
            self.pca_maxs,
            self.inv_pca_min,
            self.inv_pca_max,
            self.device,
        )

"""### Definition of PCADiscriminator"""

import torch.nn as nn

# Maybe change model architecture??? (if it's not working well)
class PCADiscriminator(nn.Module):
    """Fully connected discriminator ending in a sigmoid.

    Layout: Linear(input_size -> input_size * scale_factor) -> ReLU ->
    Linear(-> input_size) -> ReLU -> Linear(-> 1) -> Sigmoid.
    """

    def __init__(self, input_size, scale_factor):
        super().__init__()
        hidden_size = input_size * scale_factor
        stack = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
            nn.ReLU(),
            nn.Linear(input_size, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stack)
        self.input_size = input_size
        self.scale_factor = scale_factor

    def __str__(self):
        return f"""
        PCADiscriminator
        self.net: {self.net}
        self.input_size: {self.input_size}
        self.state_dict(): {self.state_dict()}
        self.scale_factor: {self.scale_factor}
        """
    __repr__ = __str__

    def forward(self, x, alpha=1.0, generator_weights=None):
        """
        Score ``x`` with the network.

        If ``x.shape[-1] < self.input_size``, zeros are padded on the *right*
        so that the last dimension equals ``self.input_size`` before the
        tensor is passed through ``self.net``. ``alpha`` and
        ``generator_weights`` are accepted but not used.
        """
        n_missing = self.input_size - x.size(-1)
        if n_missing > 0:
            # (left, right) padding of the last dimension only.
            x = F.pad(x, (0, n_missing), mode="constant", value=0.0)
            print_cust(f"PCADiscriminator, forward, x.shape: {x.shape}")
            print_cust(f"PCADiscriminator, forward, after padding, x: {x}")
        return self.net(x)

    def get_data_components(self):
        """Return the constructor arguments as ``(input_size, scale_factor)``."""
        return (self.input_size, self.scale_factor)

"""### Definition of PatchQuantumGenerator"""

class PatchQuantumGenerator(nn.Module):
    """Quantum generator class for the patch method"""

    def __init__(self, n_generators, q_delta, n_qubits_gen, n_a_qubits_gen, qnode, qubit_depth_dict, device, img_size, gen_pca=True, pca_rescaler=None):
        """
        Args:
            n_generators (int): Number of sub-generators to be used in the patch method.
            q_delta (float): Spread of the random distribution for parameter initialisation.
            n_qubits_gen: Number of generator qubits (default for ``forward``).
            n_a_qubits_gen: Stored as-is; not used inside this class.
            qnode: Callable handed to ``get_expectations`` in ``forward``.
            qubit_depth_dict: Maps a qubit count to a depth; one parameter
                tensor of shape (depth, qubit count, 3) is created per entry,
                in ascending order of qubit count.
            device: Device on which ``forward`` allocates its buffers.
            img_size: Stored as-is; not used inside this class.
            gen_pca (bool): If True, ``forward`` rescales the expectation
                values to PCA components with ``pca_rescaler``.
            pca_rescaler: Object providing ``rescale_pca_comps`` and
                ``get_data_components``.
        """

        super().__init__()

        # NOTE: getting rid of patch for now; can adapt code for improvement later.
        # Parameters are always freshly drawn here; previously trained values
        # are injected afterwards through initialize_existing_parameters.
        params_qnode = [
            nn.Parameter(q_delta * torch.rand(qubit_depth_dict[n_qubits_layer], n_qubits_layer, 3), requires_grad=True)
            for n_qubits_layer in sorted(qubit_depth_dict)
        ]

        print_cust(f"PatchQuantumGenerator, params_qnode: {params_qnode}")

        self.q_params = nn.ParameterList(params_qnode)

        self.n_generators = n_generators
        self.q_delta = q_delta
        self.n_qubits_gen = n_qubits_gen
        self.n_a_qubits_gen = n_a_qubits_gen
        self.qnode = qnode
        self.qubit_depth_dict = qubit_depth_dict
        self.device = device
        self.img_size = img_size
        self.gen_pca = gen_pca
        self.pca_rescaler = pca_rescaler

    def __str__(self):
        return f"""
        self.q_params: {self.q_params}
        self.n_generators: {self.n_generators}
        self.q_delta: {self.q_delta}
        self.n_qubits_gen: {self.n_qubits_gen}
        self.n_a_qubits_gen: {self.n_a_qubits_gen}
        self.qnode: {self.qnode}
        self.qubit_depth_dict: {self.qubit_depth_dict}
        self.device: {self.device}
        self.img_size: {self.img_size}
        self.gen_pca: {self.gen_pca}
        self.pca_rescaler: {self.pca_rescaler}
        self.state_dict(): {self.state_dict()}
        """
    __repr__ = __str__

    def initialize_existing_parameters(self, existing_params):
        """Overwrite matching parameter tensors with previously trained ones.

        Each tensor in ``existing_params`` replaces the first current
        parameter that has the same size on axis 1 (the qubit count); tensors
        without a match are ignored. ``None`` is a no-op. It is assumed that
        the existing params are NOT larger than the current param sizes.
        """
        if existing_params is None:
            return

        # Map each qubit count to the FIRST parameter slot that has it.
        slot_for_qubits = {}
        for slot, current in enumerate(self.q_params):
            slot_for_qubits.setdefault(current.shape[1], slot)

        for existing_params_tens in existing_params:
            print_cust(f"initialize_existing_parameters, type(existing_params_tens): {type(existing_params_tens)}")
            slot = slot_for_qubits.get(existing_params_tens.shape[1])
            if slot is not None:
                self.q_params[slot] = nn.Parameter(existing_params_tens, requires_grad=True)

        print_cust(f"PatchQuantumGenerator, initialize_existing_parameters, self.q_params: {self.q_params}")
        print_cust(f"PatchQuantumGenerator, initialize_existing_parameters, self.state_dict(): {self.state_dict()}")

    def forward(self, x, n_qubits_gen_forward=None, alpha=1.0):
        """Generate one output row per element of the batch ``x``.

        Parameters:
          x: Batch of latent inputs; ``x.size(0)`` is the batch size.
          n_qubits_gen_forward: Forwarded to ``get_expectations``; defaults to
            ``self.n_qubits_gen``.
          alpha: Accepted but not used.

        Returns:
          With ``gen_pca``: the expectation values mapped to [0, 1] and
          rescaled by ``pca_rescaler.rescale_pca_comps``; otherwise the
          concatenated (here: empty) patch tensor.
        """
        if n_qubits_gen_forward is None:
            n_qubits_gen_forward = self.n_qubits_gen
        # Size of each sub-generator output
        patch_size = self.n_qubits_gen * self.n_generators
        print_cust(f"PatchQuantumGenerator, forward, len(self.q_params): {len(self.q_params)}")

        # Zero-width buffer that 'catches' the batch of images. x.size(0) is the batch size.
        images = torch.Tensor(x.size(0), 0).to(self.device)
        print_cust(f"PatchQuantumGenerator, forward, images.shape: {images.shape}")

        # TODO: adapt this code for multiple subgenerators
        # Zero-row buffer for the patches of the single sub-generator.
        patches = torch.Tensor(0, patch_size).to(self.device)
        print_cust(f"PatchQuantumGenerator, forward, patches.shape: {patches.shape}")

        if self.gen_pca:
            # Expectation values lie in [-1, 1]; shift and scale them to [0, 1].
            q_out = get_expectations(x, self.qnode, n_qubits_gen_forward, self.q_params).float()
            q_out = (q_out + 1.0) / 2.0
            print_cust(f"PatchQuantumGenerator, forward, q_out: {q_out}")
            # The stacked expectations come back transposed relative to the batch.
            patches = q_out.T

        # Each batch of patches is concatenated with each other to create a batch of images
        print_cust(f"PatchQuantumGenerator, patches.shape: {patches.shape}")
        print_cust(f"PatchQuantumGenerator, patches.device: {patches.device}")
        print_cust(f"PatchQuantumGenerator, forward, patches: {patches}")
        images = torch.cat([images, patches], dim=1)

        print_cust(f"PatchQuantumGenerator, forward, images.shape: {images.shape}")
        if not self.gen_pca:
            return images

        print_cust(f"PatchQuantumGenerator, is gen_pca, so rescaling expectations to PCA components")
        reconstr_images = self.pca_rescaler.rescale_pca_comps(images)
        print_cust(f"PatchQuantumGenerator, reconstr_images.min(): {reconstr_images.min()}, reconstr_images.max(): {reconstr_images.max()}")
        # The absolute values and their sum are only computed for the debug log.
        magnitudes = reconstr_images.abs()
        print_cust(f"PatchQuantumGenerator, reconstr_images_abs.min(): {magnitudes.min()}, reconstr_images_abs.max(): {magnitudes.max()}")
        magnitude_total = magnitudes.sum()
        print_cust(f"PatchQuantumGenerator, reconstr_images_sum: {magnitude_total}")
        return reconstr_images

    def get_data_components(self):
        """Return the constructor arguments in constructor order.

        The qnode slot is ``None`` and the rescaler is replaced by its own
        ``get_data_components()`` tuple.
        """
        rescaler_components = self.pca_rescaler.get_data_components()
        return (
            self.n_generators,
            self.q_delta,
            self.n_qubits_gen,
            self.n_a_qubits_gen,
            None,
            self.qubit_depth_dict,
            self.device,
            self.img_size,
            self.gen_pca,
            rescaler_components,
        )

"""### Definition of VariationalQuantumClassifier"""

import copy

class VariationalQuantumClassifier(nn.Module):
  """Torch module wrapping a qnode whose trainable tensors sit in slot 5 of its argument tuple."""

  def __init__(self, qnode, device="cpu"):
    super().__init__()
    self.qnode = qnode
    self.device = device
    # Both are filled in later by initialize_existing_parameters.
    self.q_params = None
    self.all_params = None

  def __str__(self):
    return f"""
    VariationalQuantumClassifier,
    self.q_params: {self.q_params}
    self.all_params: {self.all_params}
    self.qnode: {self.qnode}
    self.device: {self.device}
    """

  __repr__ = __str__

  def initialize_existing_parameters(self, existing_params):
    """Install ``existing_params`` as the arguments handed to the qnode.

    A deep copy of ``existing_params`` is stored as the tuple
    ``self.all_params``, except for slot 5: the tensors found in
    ``existing_params[5]`` are wrapped (not copied) as trainable
    ``nn.Parameter`` objects, registered as ``self.q_params`` and placed in
    slot 5 of the tuple.
    """
    print_cust(f"VariationalQuantumClassifier, initialize_existing_parameters, existing_params: {existing_params}")
    self.all_params = copy.deepcopy(list(existing_params))
    # TODO, layers: assert that each entry is a PyTorch tensor?
    trainable = [
      nn.Parameter(param_tens, requires_grad=True)
      for param_tens in existing_params[5]
    ]
    self.q_params = nn.ParameterList(trainable)
    self.all_params[5] = self.q_params
    self.all_params = tuple(self.all_params)

  def forward(self, x):
    """Return ``self.qnode(x, self.all_params)`` converted to float."""
    # The whole batch is handed to the qnode in one call.
    tens_probs = self.qnode(x, self.all_params).float()
    print_cust(f"VariationalQuantumClassifier, tens_probs.shape: {tens_probs.shape}")
    print_cust(f"VariationalQuantumClassifier, stacked_tens_probs: {tens_probs}")
    print_cust(f"VariationalQuantumClassifier, stacked_tens_probs.shape: {tens_probs.shape}")
    print_cust(f"VariationalQuantumClassifier, type(stacked_tens_probs): {type(tens_probs)}")
    print_cust(f"VariationalQuantumClassifier, stacked_tens_probs.dtype: {tens_probs.dtype}")
    return tens_probs

# DONE: TOADD, layers: pytorch class for the quantum classifier.
# DONE: NOTE, layers: make sure you have grad enabled for the created tensor in the quantum classifier.

"""## Factory Functions for Objects/Models

### PCARescaler Factory
"""

def build_pca_rescaler(data_comps):
  """Construct a PCARescaler from ``data_comps``.

  ``data_comps`` is unpacked positionally into the PCARescaler constructor.
  """
  print_cust(f"build_pca_rescaler, called")
  rescaler = PCARescaler(*data_comps)
  return rescaler

"""### PCADiscriminator Factory"""

def build_pca_discriminator(data_comps, state_dict):
  """Construct a PCADiscriminator and restore its weights.

  Parameters:
    data_comps: Iterable unpacked positionally into the PCADiscriminator constructor.
    state_dict: State dict passed to ``load_state_dict`` on the new instance.

  Returns:
    The PCADiscriminator with ``state_dict`` loaded.
  """
  print_cust(f"build_pca_discriminator: called")
  discriminator = PCADiscriminator(*data_comps)
  discriminator.load_state_dict(state_dict)
  return discriminator

"""### PatchQuantumGenerator Factory"""

def build_patchquantumgen(data_comps, state_dict, qnode_builder):
  """Construct a PatchQuantumGenerator and restore its weights.

  Parameters:
    data_comps: Sequence of positional constructor arguments. Element 2 is
      passed to ``qnode_builder``; element 4 is replaced by the qnode that
      call returns; the last element holds the PCARescaler constructor
      arguments and is replaced by a freshly built PCARescaler.
    state_dict: State dict loaded into the new generator.
    qnode_builder: Callable that takes ``data_comps[2]`` and returns the qnode.

  Returns:
    The PatchQuantumGenerator with ``state_dict`` loaded.
  """
  print_cust(f"build_patchquantumgen: called")
  gen_args = list(data_comps)
  gen_args[4] = qnode_builder(data_comps[2])
  # TODO: can call build_pca_rescaler here instead, for reusability
  print_cust(f"build_patchquantumgen, type(data_comps_list[-1]): {type(gen_args[-1])}")
  print_cust(f"build_patchquantumgen, type(data_comps_list[-1][0]): {type(gen_args[-1][0])}")
  rescaler_args = list(gen_args[-1])
  # Deep-copy the first rescaler argument so the new rescaler does not share it with the caller.
  rescaler_args[0] = copy.deepcopy(rescaler_args[0])
  gen_args[-1] = PCARescaler(*rescaler_args)

  generator = PatchQuantumGenerator(*gen_args)
  generator.load_state_dict(state_dict)
  return generator

"""### VariationalQuantumClassifier Factory"""

# this has a slightly different signature; doesn't use state_dict explicitly.
def build_variationalquantumclassifier(data_comps, existing_params, qnode_builder):
  """Construct a VariationalQuantumClassifier from existing parameters.

  Parameters:
    data_comps: Mapping of keyword arguments. The ``"device"`` entry is given
      to the classifier constructor; every other entry is forwarded as a
      keyword argument to ``qnode_builder``. The mapping itself is not
      modified (a deep copy is edited instead).
    existing_params: Parameters handed to
      ``initialize_existing_parameters`` on the new classifier.
    qnode_builder: Callable returning the qnode used by the classifier.

  Returns:
    The initialized VariationalQuantumClassifier.
  """
  print_cust(f"build_variationalquantumclassifier: called")
  target_device = data_comps["device"]
  qnode_kwargs = copy.deepcopy(data_comps)
  # The qnode builder does not receive the device; only the model does.
  del qnode_kwargs["device"]
  print_cust(f"build_variationalquantumclassifier, data_comps_qnode: {qnode_kwargs}")
  classifier = VariationalQuantumClassifier(qnode_builder(**qnode_kwargs), target_device)
  classifier.initialize_existing_parameters(existing_params)

  return classifier

# DONE (BUT, make sure that this is called correctly): TOADD, layers: factory function for creating the QuantumClassifier.

"""## Helper for Alternate Zeros Init for Dynamic QCNN Params"""

from typing import Tuple, Optional, Literal

def make_paired_layers(
    shape: Tuple[int, int, int],
    backend: Literal["numpy", "torch"] = "numpy",
    *,
    seed: Optional[int] = None,
    dtype=None,                 # np.dtype or torch.dtype (optional)
    device: Optional[str] = None,  # torch device string (e.g., "cuda") if backend="torch"
    requires_grad: bool = False     # only used for torch
):
    """
    Generate an array/tensor of shape (num_layers, num_qubits, 3) whose layers
    cancel within each group.

    Behavior:
      - If num_layers is a positive multiple of 3: build triads
        (params, 0, -params) per group of 3 layers.
      - Otherwise: build pairs (params, -params) per group of 2 layers; when
        num_layers is odd, one extra independent random layer is appended.

    Notes:
      - When num_layers is divisible by both 2 and 3, triads take precedence.
      - Random values come from the backend's global RNG
        (``np.random.randn`` / ``torch.randn``).
      - ``seed`` is accepted for interface compatibility but is currently
        ignored (the seeding code is disabled); seed the global RNG yourself.

    Parameters:
      shape: (num_layers, num_qubits, 3).
      backend: "numpy" or "torch".
      seed: Unused, see Notes.
      dtype: Optional dtype of the result (numpy or torch dtype, per backend).
      device: Torch device for the result (torch backend only).
      requires_grad: Torch backend only. When True the returned tensor is a
        *leaf* tensor with ``requires_grad=True``, so it can be deep-copied,
        handed to an optimizer, and accumulates ``.grad`` itself.

    Returns:
      numpy array or torch tensor of shape (num_layers, num_qubits, 3).

    Raises:
      ValueError: if shape[-1] != 3, if num_layers or num_qubits is negative,
        or if backend is not "numpy"/"torch".
    """
    L, Q, C = shape
    print_cust(f"make_paired_layers, shape: {shape}")
    if C != 3:
        raise ValueError(f"shape[-1] must be 3, got {C}")
    if L < 0 or Q < 0:
        raise ValueError("num_layers and num_qubits must be non-negative")
    if backend not in ("numpy", "torch"):
        raise ValueError('backend must be either "numpy" or "torch"')

    # Triads take precedence over pairs; L == 0 falls through to the pair path
    # and yields an empty result.
    use_triads = (L % 3 == 0 and L > 0)
    group_size = 3 if use_triads else 2
    n_groups = L // group_size
    # Only the pair path can leave one layer over.
    needs_extra_layer = (not use_triads) and (L % 2) == 1

    if backend == "numpy":
        if n_groups > 0:
            base = np.random.randn(n_groups, 1, Q, 3)                      # (g,1,Q,3)
            if dtype is not None:
                base = base.astype(dtype)
            members = [base, np.zeros_like(base), -base] if use_triads else [base, -base]
            grouped = np.concatenate(members, axis=1)                      # (g,size,Q,3)
            stacked = grouped.reshape(group_size * n_groups, Q, 3)         # (g*size,Q,3)
        else:
            stacked = np.empty((0, Q, 3), dtype=dtype if dtype is not None else float)

        if needs_extra_layer:
            extra = np.random.randn(1, Q, 3)
            if dtype is not None:
                extra = extra.astype(dtype)
            return np.concatenate([stacked, extra], axis=0)
        return stacked

    # backend == "torch"
    final_dtype = dtype if dtype is not None else torch.get_default_dtype()

    # Everything is assembled WITHOUT autograd tracking. Building from a
    # grad-requiring base would make the result a non-leaf tensor (output of
    # cat/reshape), which cannot be deep-copied or optimized directly.
    if n_groups > 0:
        base = torch.randn((n_groups, 1, Q, 3), dtype=final_dtype, device=device)  # (g,1,Q,3)
        members = [base, torch.zeros_like(base), -base] if use_triads else [base, -base]
        grouped = torch.cat(members, dim=1)                                # (g,size,Q,3)
        stacked = grouped.reshape(group_size * n_groups, Q, 3)             # (g*size,Q,3)
    else:
        stacked = torch.empty((0, Q, 3), dtype=final_dtype, device=device)

    if needs_extra_layer:
        extra = torch.randn((1, Q, 3), dtype=final_dtype, device=device)
        stacked = torch.cat([stacked, extra], dim=0)

    if requires_grad:
        # Mark the finished tensor itself as the trainable leaf.
        stacked.requires_grad_()
    return stacked

"""## Dynamic QCNN params initialization"""

def init_dynamic_qcnn_params(n_qubits, conv_layers, debug=False, zeros_init=False,
                             qubits_and_layers_to_add_block_params=[], generative=False, use_torch=False, alt_zeros_init=""):
    """
    Initialize dynamic QCNN parameters for a QCNN with a given number of convolution layers.

    Each of the ``conv_layers`` layers gets a pair of convolution parameter arrays
    (even pairs: shape (q // 2, 12); odd pairs: shape ((q - 1) // 2, 12)) and one
    pooling array of shape (q // 2, 1), where q is the number of qubits still
    active at that layer. Pooling keeps q // 2 qubits when q is even and
    q // 2 + 1 when q is odd.

    After all layers, ``final_params`` is built from the remaining qubit count:
      - more than one qubit left and ``generative`` is False:
           a tuple (even_params, odd_params) shaped like a convolution layer;
      - otherwise (one qubit left, or ``generative`` is True):
           a single-qubit rotation array of shape (1, 3).

    Parameters:
      n_qubits: Integer representing the number of qubits in the circuit.
      conv_layers: Integer representing the number of convolution (and pooling) layers.
      debug: If True, fill arrays with distinct deterministic values
        (layer_index * 1e6 + running position) instead of random ones; block
        parameters become zeros.
      zeros_init: If True, convolution, final and block parameters are zeros.
        Pooling parameters, final_pool_param and bias_param stay random unless
        ``debug`` is set.
      qubits_and_layers_to_add_block_params: List of (n_qubits, n_layers), one entry
        per block variational circuit; each yields an array of shape (n_layers, n_qubits, 3).
      generative: If True, always use the single-qubit (1, 3) final_params.
      use_torch: If True, block parameters are torch tensors instead of numpy arrays.
      alt_zeros_init: Only consulted when ``zeros_init`` is True. "posneg" builds block
        parameters with make_paired_layers; "random" draws them from a standard normal;
        any other value leaves them as zeros.

    Returns:
      A tuple containing:
        (conv_params_list, pool_params_list, final_pool_param, final_params, bias_param, block_params_list)
      where the first two entries are tuples and block_params_list is a list.
    """

    def _indexed_table(rows, cols, offset):
        # Debug fill: entry (r, c) holds offset + (row-major position + 1).
        table = np.empty((rows, cols))
        for r in range(rows):
            for c in range(cols):
                table[r, c] = offset + (r * cols + c + 1)
        return table

    def _conv_table(rows, offset):
        # One (rows, 12) convolution parameter array in the requested init mode.
        if debug:
            return _indexed_table(rows, 12, offset)
        if zeros_init:
            return np.zeros((rows, 12))
        return np.random.randn(rows, 12)

    conv_acc = []
    pool_acc = []
    remaining = n_qubits

    # One convolution pair + one pooling array per layer.
    for layer_idx in range(conv_layers):
        n_even = remaining // 2
        n_odd = (remaining - 1) // 2
        layer_tag = layer_idx * 1e6

        even_table = _conv_table(n_even, layer_tag)
        odd_table = _conv_table(n_odd, layer_tag)
        conv_acc.append((even_table, odd_table))

        if debug:
            pool_acc.append(_indexed_table(n_even, 1, layer_tag))
        else:
            pool_acc.append(np.random.randn(n_even, 1))

        # Pooling halves the register; an unpaired last qubit survives.
        remaining = remaining // 2 + remaining % 2

    final_tag = conv_layers * 1e6

    # Final layer: single-qubit rotation, or one more convolution-shaped pair.
    if generative or remaining <= 1:
        if debug:
            final_params = np.array([[final_tag + 1, final_tag + 2, final_tag + 3]])
        elif zeros_init:
            final_params = np.zeros((1, 3))
        else:
            final_params = np.random.randn(1, 3)
    else:
        final_even = _conv_table(remaining // 2, final_tag)
        final_odd = _conv_table((remaining - 1) // 2, final_tag)
        final_params = (final_even, final_odd)

    if debug:
        final_pool_param = np.array([final_tag + 1])
        bias_param = np.array([final_tag + 1])
    else:
        final_pool_param = np.random.randn(1)
        bias_param = np.random.randn(1)

    # Block variational-circuit parameters, one array/tensor per requested block.
    block_params_list = []
    for num_qubits_bp, num_layers_bp in qubits_and_layers_to_add_block_params:
        shape = (num_layers_bp, num_qubits_bp, 3)
        if use_torch:
            if zeros_init and alt_zeros_init == "posneg":
                print_cust(f"init_dynamic_qcnn_params, posneg torch initialization")
                block = make_paired_layers(shape, "torch", requires_grad=True)
            elif zeros_init and alt_zeros_init == "random":
                print_cust(f"init_dynamic_qcnn_params, random torch initialization")
                block = torch.randn(*shape)
            elif zeros_init or debug:
                block = torch.zeros(shape)
            else:
                block = torch.randn(*shape)
        else:
            if zeros_init and alt_zeros_init == "posneg":
                print_cust(f"init_dynamic_qcnn_params, posneg numpy initialization")
                block = make_paired_layers(shape, "numpy")
            elif zeros_init and alt_zeros_init == "random":
                print_cust(f"init_dynamic_qcnn_params, random numpy initialization")
                block = np.random.randn(*shape)
            elif zeros_init or debug:
                block = np.zeros(shape)
            else:
                block = np.random.randn(*shape)
        block_params_list.append(block)

    return (tuple(conv_acc), tuple(pool_acc), final_pool_param,
            final_params, bias_param, block_params_list)

# REPURPOSE: init_dynamic_qcnn_params to have zero of all the other parameters, EXCEPT the block_params_list. later -- can add bias_param.
# for the above, it need not be called
# DONE (nothing was changed; whether final_params is 0 or not, it should not be called if I specify that no conv layer means no additional final conv layer): TOMODIFY, layers: IF conv layers is 0, then ALWAYS make final_params 0.
# Note: the above DOES change the functionality for QCNN's, but I think that that is OK.

"""# Angle encoding evaluation funcs

## Pennylane loss function
"""

def pennylane_loss_fn(params, input_angles, y, qnode):
    """
    Negative log-likelihood of the true class for a single sample.

    Parameters:
      params: Parameters to feed into the quantum model.
      input_angles: Input data for one sample.
      y: True class label; converted with int() and used as an index into the
         model output.
      qnode: Callable invoked as qnode(input_angles, params); its output is
         indexed by class.

    Returns:
      -log(p_true + 1e-7), where p_true is the model output at index int(y).
    """
    class_probs = qnode(input_angles, params)
    true_class_prob = class_probs[int(y)]
    # The small offset keeps the log finite when the predicted probability is 0.
    return -qml.math.log(true_class_prob + 1e-7)

# REPURPOSE: pennylane_loss_fn should just take the discriminator model (params should be stored in it), real input to discriminator, and fake input to disc (detached).
# Careful -- REAL and FAKE input should be DETACHED. Discriminator should NOT know how the data was generated.
# Detaching should be done in the function itself.

# NOTE: I am not repurposing this; will just compute the loss in the training loop itself. going to change the training loop.

def compute_loss_angle_param_batch(params, X_angles, y, layers=1, shots=1024, batch_size=None, qnode=None):
    """
    Average per-sample loss over a dataset.

    Parameters:
      params: Parameters to feed into the quantum model.
      X_angles: Iterable of per-sample inputs.
      y: Iterable of true labels, paired with X_angles by position.
      layers: Unused in this function.
      shots: Unused in this function.
      batch_size: Unused in this function.
      qnode: Callable forwarded to pennylane_loss_fn for every sample.

    Returns:
      The mean of pennylane_loss_fn over all (sample, label) pairs.
    """
    per_sample_losses = []
    for sample, label in zip(X_angles, y):
        per_sample_losses.append(pennylane_loss_fn(params, sample, label, qnode))
    stacked_losses = qml.math.stack(per_sample_losses)
    return qml.math.mean(stacked_losses)

"""## Accuracy computation function"""

def compute_avg_acc_angle_param_batch(params, X_angles, y, layers=1, shots=1024, batch_size=None, qnode=None):
    """
    Average classification accuracy over a dataset.

    The prediction for each sample is the index of the largest entry in the
    qnode output; accuracy is the fraction of predictions equal to the label.

    Parameters:
      params: Parameters to feed into the quantum model.
      X_angles: Iterable of per-sample inputs; the qnode is called once per sample.
      y: True labels, compared elementwise against the predictions.
      layers: Unused in this function.
      shots: Unused in this function.
      batch_size: Unused in this function.
      qnode: Callable invoked as qnode(input_angles, params), returning a score vector.

    Returns:
      The mean of (prediction == label) over the dataset.
    """
    predicted = np.array([np.argmax(qnode(sample, params)) for sample in X_angles])
    is_correct = predicted == y
    return np.mean(is_correct)

# NOT CHANGED (not called where relevant): TOMODIFY, layers: above function to allow for pennylane angleencode or sequential angleencode.

"""## Standard deviation of accuracy function"""

def compute_std_acc_angle_param_batch(params, X_angles, y, layers=1, shots=1024, batch_size=None, qnode=None):
    """
    Standard deviation of the per-class accuracies of a multiclass classifier.

    For every distinct label in ``y`` the accuracy restricted to the samples of
    that class is computed; the function returns ``np.std`` of those values.

    Parameters:
        params: Parameters to be passed to the qnode.
        X_angles: Iterable of per-sample inputs; the qnode is called once per sample.
        y: True labels. Indexed with an index array, so it must support
           array-style indexing (e.g. a NumPy array).
        layers: Unused in this function.
        shots: Unused in this function.
        batch_size: Unused in this function.
        qnode: Callable invoked as qnode(input_angles, params), returning a score vector.

    Returns:
        std_acc: The standard deviation of the per-class accuracies.
    """
    # Predicted class = index of the largest score for each sample.
    predicted = np.array([np.argmax(qnode(sample, params)) for sample in X_angles])

    per_class_accuracy = []
    for cls in np.unique(y):
        members = np.where(y == cls)[0]
        # A class with no matching samples contributes an accuracy of 0.0.
        per_class_accuracy.append(
            np.mean(predicted[members] == y[members]) if len(members) > 0 else 0.0
        )

    return np.std(per_class_accuracy)

# NOT CHANGED (not called where relevant): TOMODIFY, layers: above function to allow for pennylane angleencode or sequential angleencode.

"""## Top K accuracy"""

def compute_top_k_acc_angle_param_batch(params, X_angles, y, layers=1, shots=1024, batch_size=None, qnode=None):
    """
    Top-k accuracies for every k from 1 to the number of classes.

    For each sample the classes are ranked by descending score; the sample
    counts as a top-k hit when its true label is among the first k ranked
    classes. The number of classes is the length of the qnode output for the
    first sample, so ``X_angles`` must be non-empty.

    Parameters:
        params: Parameters to be passed to the qnode.
        X_angles: Indexable collection of per-sample inputs.
        y: True labels, indexed by sample position.
        layers: Unused in this function.
        shots: Unused in this function.
        batch_size: Unused in this function.
        qnode: Callable invoked as qnode(input_angles, params), returning a score vector.

    Returns:
        top_k_accuracies: A NumPy array of shape (N,) where the element at index k-1
                          is the top-k accuracy (with k from 1 to N).
    """
    # The first sample is evaluated once up front only to learn the class count.
    num_classes = len(qnode(X_angles[0], params))
    hit_counts = np.zeros(num_classes)
    n_samples = len(X_angles)

    for sample_idx, sample in enumerate(X_angles):
        scores = qnode(sample, params)
        # Negating the scores makes argsort rank from highest to lowest.
        ranking = np.argsort(-np.array(scores))
        target = y[sample_idx]

        for cutoff in range(num_classes):
            if target in ranking[:cutoff + 1]:
                hit_counts[cutoff] += 1

    return hit_counts / n_samples

# NOT CHANGED (not called where relevant): TOMODIFY, layers: above function to allow for pennylane angleencode or sequential angleencode.

"""## Aggregated metrics function"""

def _metrics_from_sample_probs(sample_probs, y, num_classes, total_samples,
                               math_int, container_creator, float_dtype, epsilon=1e-7):
    """Compute (avg_acc, std_acc, top_k_accuracies, avg_loss) for one (n_samples, n_classes) probability batch."""
    top_k_counts = math_int.zeros(num_classes)
    preds = []
    losses = []

    for i, probs in enumerate(sample_probs):
        preds.append(math_int.argmax(probs))
        # Negating the probabilities makes argsort rank from highest to lowest.
        sorted_indices = math_int.argsort(-container_creator(probs))
        true_label = y[i]

        # Per-sample cross entropy; epsilon keeps the log finite at probability 0.
        losses.append(-qml.math.log(probs[int(true_label)] + epsilon))

        # A sample is a top-k hit when its label is among the k highest-ranked classes.
        for k in range(1, num_classes + 1):
            if true_label in sorted_indices[:k]:
                top_k_counts[k - 1] += 1

    top_k_accuracies = top_k_counts / total_samples

    preds = container_creator(preds)
    # y is wrapped because torch.unique / torch.where only accept tensors.
    labels = container_creator(y)

    class_accuracies = []
    for cls in math_int.unique(labels):
        indices = math_int.where(labels == cls)[0]
        if len(indices) > 0:
            class_accuracies.append(
                math_int.mean(preds[indices] == labels[indices], dtype=float_dtype))
        else:
            # A class with no matching samples contributes an accuracy of 0.0.
            class_accuracies.append(0.0)

    # NOTE: torch.std applies Bessel's correction by default while np.std does
    # not, so the two backends report different values for the same accuracies.
    std_acc = math_int.std(container_creator(class_accuracies))
    avg_acc = math_int.mean(preds == labels, dtype=float_dtype)
    avg_loss = qml.math.mean(qml.math.stack(losses))
    return avg_acc, std_acc, top_k_accuracies, avg_loss


def compute_metrics_angle_param_batch(params, X_angles, y, layers=1, shots=1024, batch_size=None, qnode=None, math_int=np):
    """
    Compute average accuracy, standard deviation of per-class accuracies, top-k
    accuracies and average cross-entropy loss from one batched qnode evaluation.

    The qnode is called twice: once on ``X_angles[0]`` to determine the number
    of classes, and once on the whole of ``X_angles`` to obtain the
    probabilities all metrics are derived from. ``X_angles`` must be non-empty.

    If the batched output has three dimensions it is treated as
    (n_classifiers, n_samples, n_classes): the headline metrics use the mean
    over classifiers, and the same metrics are additionally computed for each
    classifier on its own.

    Parameters:
        params: Parameters to be passed to the qnode.
        X_angles: Array-like, inputs to the qnode function.
        y: Array-like, true labels.
        layers: Unused in this function.
        shots: Unused in this function.
        batch_size: Unused in this function.
        qnode: Callable invoked as qnode(inputs, params), returning probabilities.
        math_int: Numeric backend module, either the module-level ``np`` or ``torch``.

    Returns:
        For a 2-D batched output:
            (avg_acc, std_acc, top_k_accuracies, avg_loss)
        For a 3-D batched output:
            (avg_acc, std_acc, top_k_accuracies, avg_loss,
             avg_acc_classifiers, std_acc_classifiers,
             top_k_accuracies_classifiers, avg_loss_classifiers, gen_all_probs)
        where the ``*_classifiers`` entries are lists with one item per
        classifier and ``gen_all_probs`` is the raw 3-D qnode output.

    Raises:
        ValueError: if ``math_int`` is neither ``np`` nor ``torch``.
    """
    if math_int is np:
        container_creator = np.array
        float_dtype = np.float64
    elif math_int is torch:
        container_creator = torch.tensor
        float_dtype = torch.float32
    else:
        raise ValueError(
            f"compute_metrics_angle_param_batch, unsupported math_int: {math_int!r}; expected np or torch")
    print_cust(f"compute_metrics_angle_param_batch, math_int: {math_int}")
    print_cust(f"compute_metrics_angle_param_batch, container_creator: {container_creator}")

    # Determine the number of classes from a single-sample call (assumes X_angles is non-empty).
    first_probs = qnode(X_angles[0], params)
    if len(first_probs.shape) == 2:
        # then assumed to be of shape (n_classifiers, prob_dim)
        print_cust(f"compute_metrics_angle_param_batch, len(first_probs.shape) == 2, first_probs.shape: {first_probs.shape}")
        num_classes = first_probs.shape[1]
    else:
        # should be of shape prob_dim
        print_cust(f"compute_metrics_angle_param_batch, len(first_probs.shape) != 2, first_probs.shape: {first_probs.shape}")
        num_classes = len(first_probs)

    total_samples = len(X_angles)

    all_probs = qnode(X_angles, params)
    gen_all_probs = None
    if len(all_probs.shape) == 3:
        print_cust(f"compute_metrics_angle_param_batch, len(all_probs.shape) == 3, all_probs.shape: {all_probs.shape}")
        # Shape (n_classifiers, input_dim, prob_dim): keep the raw output for the
        # per-classifier pass and ensemble by averaging over classifiers.
        gen_all_probs = all_probs
        all_probs = all_probs.mean(axis=0)
    print_cust(f"compute_metrics_angle_param_batch, all_probs.shape: {all_probs.shape}")

    avg_acc, std_acc, top_k_accuracies, avg_loss = _metrics_from_sample_probs(
        all_probs, y, num_classes, total_samples, math_int, container_creator, float_dtype)

    if gen_all_probs is None:
        return avg_acc, std_acc, top_k_accuracies, avg_loss

    avg_acc_classifiers = []
    std_acc_classifiers = []
    top_k_accuracies_classifiers = []
    avg_loss_classifiers = []
    print_cust(f"compute_metrics_angle_param_batch, gen_all_probs is not None")
    for classifier_probs in gen_all_probs:
        print_cust(f"compute_metrics_angle_param_batch, all_probs.shape: {classifier_probs.shape}")
        (avg_acc_subclassif, std_acc_subclassif,
         top_k_accuracies_subclassif, avg_loss_subclassif) = _metrics_from_sample_probs(
            classifier_probs, y, num_classes, total_samples, math_int, container_creator, float_dtype)

        avg_acc_classifiers.append(avg_acc_subclassif)
        std_acc_classifiers.append(std_acc_subclassif)
        top_k_accuracies_classifiers.append(top_k_accuracies_subclassif)
        avg_loss_classifiers.append(avg_loss_subclassif)

    print_cust(f"compute_metrics_angle_param_batch, gen_all_probs is not None, returning")
    return avg_acc, std_acc, top_k_accuracies, avg_loss, avg_acc_classifiers, std_acc_classifiers, top_k_accuracies_classifiers, avg_loss_classifiers, gen_all_probs

# TOMODIFY, layers: above function to allow for pennylane angleencode or sequential angleencode.
# DONE: TOMODIFY, layers: make sure above function plays nicely with pytorch tensors.

"""### Latent Noise Generation Function"""

def generate_latent_noise(latent_dim, n_qubits, device, min=0.0, max=(np.pi / 2), n_qubits_small=0):
    """Draw uniform noise of shape (latent_dim, n_qubits) scaled to the range [min, max).

    Samples come from ``torch.rand`` on ``device`` and are mapped affinely from
    [0, 1) onto the requested range. ``n_qubits_small`` is not used here.
    """
    span = max - min
    unit_noise = torch.rand(latent_dim, n_qubits, device=device)
    return (unit_noise * span) + min

"""### Generate Images from Generator Function"""

def generate_images_generator(generator_obj, many_samples_noise, pca_rescaler, ret_pca_feats=False, resc_invpca=True):
    """Generate square images from latent noise via the generator and inverse PCA.

    Parameters:
        generator_obj: Callable mapping the noise batch to PCA features.
        many_samples_noise: Noise batch passed to ``generator_obj``.
        pca_rescaler: Object whose ``invpca_to_img`` turns PCA features into
            flattened images (one row per sample).
        ret_pca_feats: If True, also return the PCA features.
        resc_invpca: Forwarded to ``pca_rescaler.invpca_to_img``.

    Returns:
        The images viewed as (n_samples, side, side) with
        side = isqrt(flattened length); when ``ret_pca_feats`` is True, a tuple
        (images, pca_features).
    """
    print_cust(f"generate_images_generator, resc_invpca: {resc_invpca}")
    pca_feats = generator_obj(many_samples_noise)
    flat_imgs = pca_rescaler.invpca_to_img(pca_feats, resc_invpca=resc_invpca)

    # Each flattened row is reinterpreted as a square image.
    n_imgs = flat_imgs.shape[0]
    side = math.isqrt(flat_imgs.shape[1])
    square_imgs = flat_imgs.view(n_imgs, side, side)

    if ret_pca_feats:
        return square_imgs, pca_feats
    return square_imgs

"""### Save Tensors to Folder Function"""

from PIL import Image

def save_tensors_to_folder(img_tensors, folder_name, input_prefix, img_ext="png"):
    """Write every image of a batch tensor to `folder_name` as an individual image file.

    Files are named ``{input_prefix}_{index}.{img_ext}``, where index is the position
    of the image along the first dimension of `img_tensors`.

    Parameters:
      img_tensors: torch tensor whose first dimension indexes the images. Tensors that
        are not uint8 must hold values in [0, 1]; they are scaled by 255 and rounded.
        uint8 tensors are written unchanged.
      folder_name: output directory; created if it does not exist.
      input_prefix: file-name prefix.
      img_ext: file extension handed to PIL through the file name.

    Raises:
      AssertionError: if a non-uint8 tensor contains values outside [0, 1].
    """
    img_tensors = img_tensors.detach()
    # exist_ok avoids the separate exists-then-create check.
    os.makedirs(folder_name, exist_ok=True)
    print_cust(f"save_tensors_to_folder, img_tensors.dtype: {img_tensors.dtype}")
    print_cust(f"save_tensors_to_folder, img_tensors.max(): {img_tensors.max()}")
    print_cust(f"save_tensors_to_folder, img_tensors.min(): {img_tensors.min()}")
    if img_tensors.dtype != torch.uint8:
        # The [0, 1] contract only makes sense for data that still has to be rescaled;
        # uint8 input is already in the 0-255 pixel range and is passed through.
        assert img_tensors.max() <= 1.0, f"save_tensors_to_folder, img_tensors.max() > 1.0: {img_tensors.max()}"
        assert img_tensors.min() >= 0.0, f"save_tensors_to_folder, img_tensors.min() < 0.0: {img_tensors.min()}"
        print_cust(f'save_tensors_to_folder: renormalizing data')
        img_tensors = (img_tensors * 255).round()
    # Tensor.numpy() only works for CPU tensors; .cpu() is a no-op when already there.
    img_arrays = img_tensors.cpu().numpy().astype(np.uint8)
    for img_idx, img_arr in enumerate(img_arrays):
        img_obj = Image.fromarray(img_arr)
        img_obj.save(f"{folder_name}/{input_prefix}_{img_idx}.{img_ext}")

# TOADD: have the function not compute any accuracy metric; only the loss metric. (consider writing my own function for repurposing.)

"""### FID Calculation Function"""

from pytorch_fid.fid_score import calculate_fid_given_paths

def compute_fid_to_data(generator, noise_func, targ_data_folder_name, gen_data_folder_name, n_samples_gen, device, fid_batch_size=1000, resc_invpca=True):
  """Generate images with `generator`, save them to disk and compute the FID against a target folder.

  Parameters:
    generator: object providing `n_qubits_gen` and `pca_rescaler` attributes; it is
      passed to `generate_images_generator` as the generator.
    noise_func: called as `noise_func(n_samples_gen, generator.n_qubits_gen, device)`
      to produce the generator input.
    targ_data_folder_name: folder holding the reference images.
    gen_data_folder_name: folder the generated images are written to (prefix "img").
    n_samples_gen: number of noise samples requested from `noise_func`.
    device: device handed to `noise_func`. The FID computation itself always uses
      "cuda:0" when available, otherwise "cpu" (see the override below).
    fid_batch_size: batch size passed to `calculate_fid_given_paths`.
    resc_invpca: forwarded to `generate_images_generator`.

  Returns:
    The value returned by `calculate_fid_given_paths`.

  NOTE(review): images already present in `gen_data_folder_name` are not removed
  here, so files left by an earlier call with a larger `n_samples_gen` would
  presumably be included in the FID - confirm against the callers.
  """
  print_cust(f"compute_fid_to_data, resc_invpca: {resc_invpca}")
  print_cust(f"compute_fid_to_data, fid_batch_size: {fid_batch_size}")

  print_cust(f"compute_fid_to_data, n_samples_noise: {n_samples_gen}")
  n_qubits_noise = generator.n_qubits_gen
  noise_samples = noise_func(n_samples_gen, n_qubits_noise, device)
  pca_rescaler = generator.pca_rescaler
  generator_gen_imgs = generate_images_generator(generator, noise_samples, pca_rescaler, resc_invpca=resc_invpca)

  print_cust(f"compute_fid_to_data, generator_gen_imgs: {generator_gen_imgs}")

  # A maximum of exactly 1.0 is a valid pixel value (save_tensors_to_folder accepts
  # max <= 1.0), so only values strictly above 1.0 count as out of range.
  if generator_gen_imgs.max() > 1.0 or generator_gen_imgs.min() < 0.0:
    print_cust(f"compute_fid_to_data, generator_gen_imgs exceeds range of pixel values. generator_gen_imgs.max(): {generator_gen_imgs.max()}, generator_gen_imgs.min(): {generator_gen_imgs.min()}")
    # maybe print out the number of pixels that needed to be clamped?
    generator_gen_imgs = torch.clamp(generator_gen_imgs, max=1.0, min=0.0)
    print_cust(f"compute_fid_to_data, after clamping, generator_gen_imgs.max(): {generator_gen_imgs.max()}, generator_gen_imgs.min(): {generator_gen_imgs.min()}")

  save_tensors_to_folder(generator_gen_imgs, gen_data_folder_name, "img")

  # temporary override. the main time bottleneck appears to be FID calculation.
  # I suppose if it is still taking a long time, then reduce the testing data size.
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  # 2048 is the `dims` argument (feature dimensionality used for the FID statistics).
  fid_score = calculate_fid_given_paths([gen_data_folder_name, targ_data_folder_name], fid_batch_size, device, 2048)

  return fid_score

from pennylane.optimize import AdamOptimizer

def tuples_allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
    """
    Compare two (possibly nested) list/tuple structures leaf by leaf with np.allclose.

    Parameters:
      a: a list, tuple, array or scalar; lists/tuples may nest further lists, tuples or arrays
      b: same kind of object as `a`
      rtol: float, relative tolerance forwarded to np.allclose
      atol: float, absolute tolerance forwarded to np.allclose
      equal_nan: boolean, whether NaN's compare as equal

    Returns:
      True when both structures have matching lengths at every list/tuple level and
      every pair of leaves is close; False otherwise.
    """
    containers = (list, tuple)

    # Unless both sides are list/tuple, hand the pair to numpy as a leaf comparison.
    if not (isinstance(a, containers) and isinstance(b, containers)):
        return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

    if len(a) != len(b):
        return False

    for left, right in zip(a, b):
        if not tuples_allclose(left, right, rtol=rtol, atol=atol, equal_nan=equal_nan):
            return False
    return True

def flatten_params(params):
    """
    Concatenate a sequence of array-likes into one 1D array, remembering each shape.

    Parameters:
      params: an iterable of arrays (or anything np.array accepts)

    Returns:
        flat_params: a 1D array holding all entries, in the order of `params`
        shapes: a list with the shape of each entry of `params`
    """
    # Convert first so that shape and flatten() are taken from the same array object.
    arrays = [np.array(p) for p in params]
    shapes = [arr.shape for arr in arrays]
    flat_params = np.concatenate([arr.flatten() for arr in arrays])
    return flat_params, shapes

def unflatten_params(flat_params, shapes):
    """
    Cut a flat parameter array into consecutive pieces and reshape each piece.

    Parameters:
      flat_params: a 1D array containing all parameters
      shapes: a list of shapes, one per piece, in the order the pieces were flattened

    Returns:
      params: a list of arrays, the i-th one having shape shapes[i]
    """
    rebuilt = []
    start = 0
    for shape in shapes:
        # Number of entries belonging to this piece.
        stop = start + int(np.prod(shape))
        rebuilt.append(flat_params[start:stop].reshape(shape))
        start = stop
    return rebuilt

def flatten_grad(grad, shapes):
    """
    Concatenate a sequence of gradient arrays into one 1D array.

    Parameters:
      grad: an iterable of gradient arrays, in the same order as the parameters
      shapes: a list of parameter shapes; only its length matters here, because the
        gradients are paired with it and any surplus gradients are ignored

    Returns:
        flat_grad: a 1D array containing all (paired) gradients
    """
    pieces = [np.array(g).flatten() for g, _ in zip(grad, shapes)]
    return np.concatenate(pieces)

def flatten_params_recursive(params):
    """
    Flatten an arbitrarily nested list/tuple structure of parameters.

    Parameters:
       params: a leaf (array-like or scalar) or a list/tuple nesting further such items

    Returns:
      flat_params: a 1D array with all leaf entries in depth-first, left-to-right order
      shapes: a list with the shape of each leaf, in the same order
    """
    def iter_leaves(node):
        # Lists and tuples are containers; everything else is converted to an array leaf.
        if isinstance(node, (list, tuple)):
            for child in node:
                yield from iter_leaves(child)
        else:
            yield np.array(node)

    leaf_arrays = list(iter_leaves(params))
    shapes = [leaf.shape for leaf in leaf_arrays]
    flat_params = np.concatenate([leaf.flatten() for leaf in leaf_arrays])
    return flat_params, shapes

def unflatten_params_recursive(flat_params, shapes, structure):
    """
    Rebuilds the nested parameter structure from the flat_params using the recorded shapes and the original structure.

    Parameters:
      flat_params: a 1D array containing all parameters
      shapes: a list of leaf shapes, in depth-first order (as produced by
        flatten_params_recursive); scalar leaves have shape ()
      structure: a nested list/tuple structure whose layout is copied; only the nesting
        and the list/tuple container types are used, the leaf values are ignored

    Returns:
      A structure with the same list/tuple nesting as `structure`, whose leaves are the
      reshaped pieces of `flat_params`.
    """
    flat_elems = []
    pointer = 0
    for shape in shapes:
        # np.prod(()) is the float 1.0, which cannot be used as a slice bound;
        # casting to int makes scalar (shape ()) leaves work as well.
        size = int(np.prod(shape))
        flat_elems.append(flat_params[pointer:pointer+size].reshape(shape))
        pointer += size

    it = iter(flat_elems)

    def rebuild(struct):
        if isinstance(struct, (list, tuple)):
            # Preserve the container type (list stays list, tuple stays tuple).
            return type(struct)(rebuild(sub) for sub in struct)
        else:
            # Leaf position: consume the next reshaped piece.
            return next(it)

    return rebuild(structure)

def flatten_grad_recursive(grad):
    """
    Flatten an arbitrarily nested list/tuple structure of gradients into one 1D array.

    Parameters:
      grad: a leaf (array-like or scalar) or a list/tuple nesting further such items

    Returns:
      flat_grad: a 1D array with all leaf entries in depth-first, left-to-right order
    """
    pieces = []
    pending = [grad]
    while pending:
        node = pending.pop()
        if isinstance(node, (list, tuple)):
            # Push children reversed so they are popped in left-to-right order.
            pending.extend(reversed(node))
        else:
            pieces.append(np.array(node).flatten())
    return np.concatenate(pieces)

"""## Gradient Norm Computation Func"""

def grad_l2_norm(g):
    """
    Compute the global L2 (Euclidean) norm of a nested gradient structure.

    Parameters
    ----------
    g : array-like | list | tuple
        Either a single leaf, or a list/tuple nesting further lists, tuples or leaves.

    Returns
    -------
    Scalar norm of all leaves taken together, as returned by the qml.math functions
    used for the computation.
    """
    # Leaf: plain array / tensor, the norm is taken directly.
    if not isinstance(g, (list, tuple)):
        return qml.math.linalg.norm(g)

    # Container: combine the squared norms of the children.
    child_sq_norms = []
    for child in g:
        child_norm = grad_l2_norm(child)
        child_sq_norms.append(qml.math.square(child_norm))
    total_sq = qml.math.sum(qml.math.stack(child_sq_norms))
    return qml.math.sqrt(total_sq)

"""## Tree to list helper"""

def tree_to_list(tree):
    """Flatten a nested list/tuple tree into a flat list of leaves.

    Parameters:
      tree: a leaf, or a list/tuple nesting further lists, tuples or leaves.

    Returns:
      leaves: list of the leaves in depth-first, left-to-right order.
      rebuild_tree: function taking an iterable of new leaves (same order, at least as
        many as `leaves`) and returning them arranged in the same nesting as `tree`,
        with lists rebuilt as lists and tuples rebuilt as tuples.
    """
    leaves = []

    def walk(node):
        # Mirror the container type so that rebuild_tree can tell tuples from lists.
        if isinstance(node, tuple):
            return tuple([walk(child) for child in node])
        if isinstance(node, list):
            return [walk(child) for child in node]
        leaves.append(node)
        return len(leaves) - 1                  # index placeholder

    structure = walk(tree)

    def rebuild_tree(new_leaves):
        it = iter(new_leaves)

        def fill(node):
            # Inner list comprehensions (rather than generator expressions) keep the
            # exception raised for too few leaves the same for lists and tuples.
            if isinstance(node, tuple):
                return tuple([fill(child) for child in node])
            if isinstance(node, list):
                return [fill(child) for child in node]
            return next(it)                     # placeholder -> grab next leaf

        return fill(structure)

    return leaves, rebuild_tree

"""## Core training function"""

import copy

def train_epochs_angle_param_adam(params, X_angles, y, X_val, y_val, n_epochs=5, layers=1, shots=1024, batch_size=32,
                                  lr=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8,
                                  grad_func=None, trainable_mask=None, qnode=None,
                                  patience=3):

    """
    Trains for up to n_epochs epochs over X_angles in shuffled minibatches with Adam,
    stopping early when the validation loss stops improving.

    For every minibatch the gradient of compute_loss_angle_param_batch with respect to
    `params` is computed, optionally multiplied by `trainable_mask`, applied with a
    PennyLane AdamOptimizer, and the loss is then re-evaluated on the same minibatch
    with the updated parameters for logging.

    Parameters:
      params: a (possibly nested) list/tuple of arrays; flattened with tree_to_list.
      X_angles: Array-like indexable by an index array; X_angles.shape[0] is the number of samples.
      y: Array-like, true labels, indexed the same way as X_angles.
      X_val: Array-like, validation inputs used for early termination.
      y_val: Array-like, validation labels used for early termination.
      n_epochs: Integer, maximum number of epochs.
      layers: forwarded to compute_loss_angle_param_batch.
      shots: forwarded to compute_loss_angle_param_batch.
      batch_size: Integer, number of samples per minibatch (the last one may be smaller).
      lr: Float, step size passed to AdamOptimizer.
      beta1: not used by this function (the optimizer is created with its default).
      beta2: not used by this function (the optimizer is created with its default).
      epsilon: not used by this function (the optimizer is created with its default).
      grad_func: not used by this function.
      trainable_mask: optional structure with the same layout as params; each gradient
        leaf is multiplied elementwise by the matching mask leaf, so zero entries
        remove that part of the gradient before the Adam step.
      qnode: forwarded to compute_loss_angle_param_batch.
      patience: Integer, number of consecutive epochs without an improvement of the
        validation loss after which training stops.

    Returns:
      params: the parameters after the last update step (not the ones with the best
        validation loss), rebuilt into the nesting recorded by tree_to_list
      minibatch_losses: a list with the post-update loss of each minibatch
      validation_losses: a list with the validation loss of each completed epoch
    """
    # Initialize containers to store metrics as well as the optimizer
    n_samples = X_angles.shape[0]
    minibatch_losses = []
    validation_losses = []
    opt = AdamOptimizer(stepsize=lr)
    best_val_loss = float('inf')
    patience_count = 0
    # print_cust(f"train_epochs_angle_param_adam, params[0].shape: {params[0].shape}")

    # Record the original nested structure as a template
    # NOTE: structure_template is not referenced again in this function.
    structure_template = params

    print_cust(f"train_epochs_angle_param_adam, n_samples: {n_samples}")
    # For each epoch, train on randomly sampled minibatches.
    for epoch in range(n_epochs):
        print_cust(f"[Adam SGD] Epoch {epoch+1}/{n_epochs}")
        indices = np.arange(n_samples)
        np.random.shuffle(indices)
        for start in range(0, n_samples, batch_size):
            end = min(start + batch_size, n_samples)
            batch_indices = indices[start:end]
            X_batch = X_angles[batch_indices]
            y_batch = y[batch_indices]
            # Snapshot used only for the "parameters did not change" diagnostic below.
            old_params = copy.deepcopy(params)

            # Compute gradient on the mini-batch
            grads = qml.grad(compute_loss_angle_param_batch, argnum=0)(
                params, X_batch, y_batch, layers=layers, shots=shots, batch_size=len(X_batch), qnode=qnode
            )

            # --------- flatten the nested structures into flat lists of leaves ----------
            flat_p,  rebuild = tree_to_list(params)
            flat_g, _        = tree_to_list(grads)       # identical structure ⇒ same rebuild

            # ensure differentiable type
            flat_p = [np.asarray(p, requires_grad=True) for p in flat_p]
            # print_cust(f"train_epochs_angle_param_adam, grads after grad calculation: {grads}")
            # grads = np.array(grads).flatten()
            # If a trainable mask is provided, apply it elementwise
            # if trainable_mask is not None:
            #     def apply_mask(grad, mask):
            #         """Element‑wise multiply `grad` by `mask`, preserving the nested structure.

            #         Both grad and mask must have *exactly* the same tree layout.
            #         Mask entries can be 0/1, booleans, or any broadcast‑compatible scalars.
            #         """
            #         # Same container (list vs tuple) is preserved with `type(grad)(...)`
            #         if isinstance(grad, (list, tuple)):
            #             if len(grad) != len(mask):
            #                 raise ValueError("Mask and gradient have different lengths")
            #             return type(grad)(apply_mask(g, m) for g, m in zip(grad, mask))
            #         else:
            #             # Convert to array once so broadcasting works and the dtypes match
            #             return grad * np.asarray(mask)
            #     grads = apply_mask(grads, trainable_mask)
            #     # print_cust(f"train_epochs_angle_param_adam, masked grads: {grads}")
            # (optional) mask: multiply each gradient leaf by the matching mask leaf
            if trainable_mask is not None:
                flat_mask, _ = tree_to_list(trainable_mask)
                # print_cust(f"train_epochs_angle_param_adam, flat_mask: {flat_mask}")
                # print_cust(f"train_epochs_angle_param_adam, in trainable_mask conditional, flat_g: {flat_g}")
                flat_g = [g * m for g, m in zip(flat_g, flat_mask)]
                # print_cust(f"train_epochs_angle_param_adam, in trainable_mask conditional, after applying mask, flat_g: {flat_g}")


            # print_cust(f"train_epochs_angle_param_adam, flat_g: {flat_g}")

            # print_cust(f"train_epochs_angle_param_adam, old_params: {old_params}")
            # if trainable_mask is not None:
            #   print_cust(f"train_epochs_angle_param_adam, flat_p: {flat_p}")

            # --------- Adam step ----------
            flat_p = opt.apply_grad(flat_g, flat_p)
            # if trainable_mask is not None:
            #   print_cust(f"train_epochs_angle_param_adam, flat_p after grad application: {flat_p}")

            # --------- put tensors back ----------
            params  = rebuild(flat_p)

            # print_cust(f"train_epochs_angle_param_adam, params after grad application: {params}")
            # The logged norm is that of the (possibly masked) gradient leaves.
            print_cust(f"train_epochs_angle_param_adam, grad_l2_norm(grads): {grad_l2_norm(flat_g)}")

            # Flatten parameters and gradients recursively
            # flat_params, shapes = flatten_params_recursive(params)
            # flat_grad = flatten_grad_recursive(grads)

            # print_cust(f"train_epochs_angle_param_adam, np.linalg.norm(flat_grad): {np.linalg.norm(flat_grad)}")
            # print_cust(f"train_epochs_angle_param_adam, flat_params.shape: {flat_params.shape}")
            # print_cust(f"train_epochs_angle_param_adam, flat_grad.shape: {flat_grad.shape}")
            # print_cust(f"train_epochs_angle_param_adam, shapes: {shapes}")

            # Convert flat_params to a Pennylane array with requires_grad=True
            # flat_params = np.array(flat_params, requires_grad=True)

            # Update parameters using the Adam optimizer
            # updated_flat_params = opt.apply_grad(flat_grad, flat_params)
            # updated_flat_params = np.array(updated_flat_params, requires_grad=True)

            # Reconstruct the nested structure of parameters
            # params = unflatten_params_recursive(updated_flat_params, shapes, structure_template)

            # Diagnostic only: report when the update left the parameters numerically unchanged.
            if tuples_allclose(old_params, params):
              print_cust(f"train_epochs_angle_param_adam, old params and params are basically the same after gradient application")
            # print_cust(f"train_epochs_angle_param_adam, grads after grad application: {grads}")
            # print_cust(f"train_epochs_angle_param_adam, params after grad application: {params}")
            # Loss on the same minibatch, evaluated with the updated parameters.
            loss_batch = compute_loss_angle_param_batch(params, X_batch, y_batch, layers=layers, shots=shots, batch_size=len(X_batch), qnode=qnode)
            # print_cust(f"train_epochs_angle_param_adam, params after loss computation: {params}")
            minibatch_losses.append(loss_batch)
            print_cust(f"  Mini-batch loss: {loss_batch:.4f}")

        # At the end of the epoch, compute validation loss:
        # (batch_size is not passed here, so compute_loss_angle_param_batch uses its own default)
        val_loss = compute_loss_angle_param_batch(params, X_val, y_val, layers=layers, shots=shots, qnode=qnode)
        validation_losses.append(val_loss)
        print_cust(f"Epoch {epoch+1} validation loss: {val_loss:.4f}")

        # Early stopping: if the validation loss doesn't improve for 'patience' consecutive epochs, stop training.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= patience:
                print_cust("Early stopping triggered.")
                break

    return params, minibatch_losses, validation_losses

"""## Helper Function for Computing Param Diffs"""

import torch
from typing import Iterable, Dict, Any, Optional

def l2_state_dict_difference(
    state_dict_a: Dict[str, torch.Tensor],
    state_dict_b: Dict[str, torch.Tensor],
    *,
    keys: Optional[Iterable[str]] = None,
    strict: bool = True,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    include_buffers: bool = True
) -> torch.Tensor:
    """
    Compute the L2 norm of the difference between two PyTorch state dicts.

    Parameters
    ----------
    state_dict_a, state_dict_b :
        The two state dicts (e.g. model.state_dict()) whose tensors are compared.
    keys :
        Optional iterable of keys to restrict the comparison. If None, the keys
        present in both dicts are used, in the iteration order of `state_dict_a`.
    strict :
        If True, raise when a requested key is missing from either dict or when
        the two tensors for a key differ in shape. If False, skip such keys.
    device :
        Device to perform the accumulation on. If None, the device of the first
        compared tensor of `state_dict_a` is used and all others are moved to it.
    dtype :
        Dtype tensors are cast to before differencing. If None, the dtype of the
        first compared tensor of `state_dict_a` is used when it is a floating
        point or complex dtype; otherwise torch's default dtype is used, so that
        an integer tensor encountered first cannot truncate later float
        differences.
    include_buffers :
        If False, skip keys that look like non-parameter buffers (heuristic:
        keys containing 'running_' or '_tracked').

    Returns
    -------
    torch.Tensor
        A scalar tensor: sqrt( sum_i || tensor_i^A - tensor_i^B ||_2^2 ).
        A zero scalar is returned when no key could be compared.

    Raises
    ------
    KeyError
        If `strict` is True and a key is missing from one of the dicts.
    ValueError
        If `strict` is True and the tensors for a key differ in shape.

    Notes
    -----
    The squared differences are accumulated key by key, so no concatenated
    tensor of all parameters is materialized.
    """
    if keys is None:
        # A list in state_dict_a order (instead of a set) keeps the summation order and
        # the lazily chosen device/dtype deterministic from run to run.
        keys = [k for k in state_dict_a if k in state_dict_b]
    else:
        keys = list(keys)

    # Optionally filter out common buffer keys
    if not include_buffers:
        def is_buffer(k: str) -> bool:
            # Heuristic; tailor as needed
            return ('running_' in k) or ('_tracked' in k)
        keys = [k for k in keys if not is_buffer(k)]

    total_ssd = None
    target_device = device
    target_dtype = dtype

    for k in keys:
        if k not in state_dict_a or k not in state_dict_b:
            if strict:
                raise KeyError(f"Key '{k}' missing from one of the state dicts.")
            else:
                continue

        ta = state_dict_a[k]
        tb = state_dict_b[k]
        if ta.shape != tb.shape:
            if strict:
                raise ValueError(f"Shape mismatch for key '{k}': {ta.shape} vs {tb.shape}")
            else:
                continue

        # Decide accumulation device/dtype lazily
        if target_device is None:
            target_device = ta.device
        if target_dtype is None:
            if ta.dtype.is_floating_point or ta.dtype.is_complex:
                target_dtype = ta.dtype
            else:
                # Integer/bool tensors (e.g. step counters) must not dictate the dtype:
                # casting float tensors to an integer dtype would truncate their
                # differences towards zero.
                target_dtype = torch.get_default_dtype()

        diff = (ta.to(device=target_device, dtype=target_dtype) -
                tb.to(device=target_device, dtype=target_dtype))
        ssd = diff.pow(2).sum()

        total_ssd = ssd if total_ssd is None else total_ssd + ssd

    if total_ssd is None:
        # No comparable keys
        return torch.tensor(0.0, device=device or 'cpu', dtype=dtype or torch.float32)
    return total_ssd.sqrt()

"""## Core Training Functions, QGAN

### Core Training Batch Function
"""

# from torchviz import make_dot

# from itertools import chain

# import uuid

# import time

def train_step(generator, discriminator, optD, optG, criterion, results, real_labels, fake_labels, counter, fixed_noise, real_data, fake_data, compressed_img_size, alpha=1.0, disc_img_size=None, pca_disc=True):
    """Run one GAN iteration: a discriminator update followed by a generator update.

    The discriminator is evaluated on `real_data` and on detached `fake_data`, both
    losses are back-propagated and `optD` is stepped. Then the generator gradients are
    cleared, the discriminator is evaluated on the non-detached `fake_data`, the loss
    against `real_labels` is back-propagated and `optG` is stepped. Parameter and
    gradient diagnostics are logged throughout via print_cust.

    Parameters:
      generator: torch module; must expose `q_params` (its first entry is used for a
        diagnostic) and be callable as `generator(fixed_noise, alpha)`.
      discriminator: torch module callable as `discriminator(data, alpha)`.
      optD, optG: optimizers stepped for the discriminator / generator update.
      criterion: loss called as `criterion(discriminator_output, labels)`.
      results: list; sample outputs for `fixed_noise` are appended to it whenever
        (counter + 1) is a multiple of 50.
      real_labels, fake_labels: label tensors for the two discriminator losses.
      counter: iteration index, used for logging and for the sampling schedule.
      fixed_noise: noise batch used to produce the tracked sample outputs.
      real_data: batch of real samples.
      fake_data: batch of generated samples; presumably still attached to the
        generator's autograd graph, since the generator update relies on it.
      compressed_img_size: side length used to reshape samples when `pca_disc` is False
        (the reshape is hard-coded to 8 samples with 1 channel).
      alpha: second positional argument passed to both models.
      disc_img_size: only logged here.
      pca_disc: when True the sample outputs are stored without reshaping.

    Returns:
      (errD, errG, disc_param_grad_norms, generator_param_grad_norms): the detached
      discriminator and generator losses, and the lists of per-parameter gradient
      norms recorded right after the respective backward passes.
    """
    print_cust(f"train_step, disc_img_size: {disc_img_size}")

    # ---------------- Training the discriminator ----------------
    discriminator.zero_grad()

    # TODO: add assertions here about gradient magnitudes; should all be zero (disc at least, should)
    # Generator gradients are not cleared at this point, so they may still be non-zero.
    _log_param_grad_norms(generator, "generator")
    _log_param_grad_norms(discriminator, "discriminator")

    print_cust(f"train_step, real_data.shape: {real_data.shape}")
    print_cust(f"train_step, fake_data.shape: {fake_data.shape}")

    # State-dict snapshots are taken only for the "did the parameters change" diagnostics.
    disc_beforeinf_statedict = copy.deepcopy(discriminator.state_dict())
    gen_beforeinf_statedict = copy.deepcopy(generator.state_dict())

    outD_real = discriminator(real_data.detach(), alpha).view(-1)

    # detach() keeps the discriminator losses from back-propagating into whatever
    # produced fake_data; only the discriminator should receive gradients here.
    outD_fake = discriminator(fake_data.detach(), alpha).view(-1)

    disc_afterinf_statedict = copy.deepcopy(discriminator.state_dict())
    gen_afterinf_statedict = copy.deepcopy(generator.state_dict())

    print_cust(f"train_step, real_data.shape: {real_data.shape}")
    print_cust(f"train_step, fake_data.shape: {fake_data.shape}")
    print_cust(f"train_step, outD_real.shape: {outD_real.shape}")
    print_cust(f"train_step, outD_real: {outD_real}")
    print_cust(f"train_step, outD_fake: {outD_fake}")
    print_cust(f"train_step, real_labels.shape: {real_labels.shape}")
    print_cust(f"train_step, real_labels: {real_labels}")
    print_cust(f"train_step, fake_labels.shape: {fake_labels.shape}")
    print_cust(f"train_step, fake_labels: {fake_labels}")

    # could do some assertions here, but can do it later
    print_cust(f"train_step, change in disc params after inf: {l2_state_dict_difference(disc_beforeinf_statedict, disc_afterinf_statedict)}")
    print_cust(f"train_step, change in gen params after inf: {l2_state_dict_difference(gen_beforeinf_statedict, gen_afterinf_statedict)}")

    errD_real = criterion(outD_real, real_labels)
    errD_fake = criterion(outD_fake, fake_labels)

    # Propagate gradients
    errD_real.backward()
    errD_fake.backward()

    print_cust(f"train_step, after discriminator .backward()")

    # Generator grads were not cleared above, which is why they can be non-zero here.
    _log_param_grad_norms(generator, "generator")
    disc_param_grad_norms = _log_param_grad_norms(discriminator, "discriminator")

    print_cust(f"train_step, disc_param_grad_norms: {disc_param_grad_norms}")

    errD = errD_real + errD_fake
    orig_generator_params = generator.q_params[0].clone().detach()
    optD.step()
    # torch.linalg.norm works on any device; a NumPy norm would require a CPU tensor.
    print_cust(f"train_step, change in generator parameters: {torch.linalg.norm(generator.q_params[0].clone().detach() - orig_generator_params)}")
    disc_afterdiscstep_statedict = copy.deepcopy(discriminator.state_dict())
    gen_afterdiscstep_statedict = copy.deepcopy(generator.state_dict())
    print_cust(f"train_step, change in disc params after disc step: {l2_state_dict_difference(disc_afterinf_statedict, disc_afterdiscstep_statedict)}")
    print_cust(f"train_step, change in gen params after disc step: {l2_state_dict_difference(gen_afterinf_statedict, gen_afterdiscstep_statedict)}")

    # ---------------- Training the generator ----------------
    generator.zero_grad()
    # fake_data is NOT detached here so that errG can back-propagate through it.
    outD_fake = discriminator(fake_data, alpha).view(-1)
    print_cust(f"train_step, generator loss, outD_fake: {outD_fake}")
    errG = criterion(outD_fake, real_labels)

    errG.backward()

    print_cust(f"train_step, generator gradient norms")
    generator_param_grad_norms = _log_param_grad_norms(generator, "generator")
    print_cust(f"train_step, generator_param_grad_norms: {generator_param_grad_norms}")

    print_cust(f"train_step, discriminator gradient norms")
    _log_param_grad_norms(discriminator, "discriminator")

    optG.step()
    disc_aftergenstep_statedict = copy.deepcopy(discriminator.state_dict())
    gen_aftergenstep_statedict = copy.deepcopy(generator.state_dict())

    print_cust(f"train_step, change in disc params after gen step: {l2_state_dict_difference(disc_afterdiscstep_statedict, disc_aftergenstep_statedict)}")
    print_cust(f"train_step, change in gen params after gen step: {l2_state_dict_difference(gen_afterdiscstep_statedict, gen_aftergenstep_statedict)}")
    # TODO: verify that the discriminator parameters do not change here.

    # Show loss values
    print_cust(f'train_step, Iteration: {counter}, Discriminator Loss: {errD:0.3f}, Generator Loss: {errG:0.3f}')
    if (counter + 1) % 10 == 0:
        # TODO: add some kind of adapter to indicate that I need to transform the size here.
        if not pca_disc:
            test_images = generator(fixed_noise, alpha).view(8,1,compressed_img_size,compressed_img_size).cpu().detach()
        else:
            test_images = generator(fixed_noise, alpha).cpu().detach()
        # Save images every 50 iterations
        if (counter + 1) % 50 == 0:
            results.append(test_images)

    # verify that testimgs is not anything wild?

    return errD.detach(), errG.detach(), disc_param_grad_norms, generator_param_grad_norms


def _log_param_grad_norms(model, label):
    """Log requires_grad and the gradient norm of every parameter of `model`.

    `label` (e.g. "generator" or "discriminator") is interpolated into each log line.
    Returns the list of gradient norms of the parameters that currently have a gradient.
    """
    grad_norms = []
    for name, p in model.named_parameters():
        print_cust(f"train_step, {label} param name: {name}, p.requires_grad: {p.requires_grad}")
        if p.grad is None:
            print_cust(f"train_step, {label} param name: {name}, no grad")
        else:
            grad_norm = p.grad.norm()
            print_cust(f"train_step, {label} param name: {name}, p.grad.norm(): {grad_norm}")
            grad_norms.append(grad_norm)
    return grad_norms

"""### Core Training Epochs Function"""

# from re import A
# TODO: add disc_img_size
def train_models(n_qubits, batch_size, generator, discriminator, optD, optG, noise_func, criterion, train_data, device, image_size, compressed_img_size, max_num_epochs, n_qubits_small=0, gen_pcas=True, disc_img_size=None, pca_disc=True):
    """
    Run mini-batch adversarial training of a generator against a discriminator.

    Every epoch shuffles the row indices of train_data, walks them in chunks of
    batch_size (a trailing chunk smaller than batch_size is skipped) and
    delegates the actual optimisation to train_step.

    NOTE: this function mutates the inputs (models/optimizers are stepped inside train_step).

    Parameters:
      n_qubits: forwarded to noise_func as its second positional argument
      batch_size: number of rows per training step; also the length of the label tensors
      generator: callable invoked as generator(noise, alpha=alpha); forwarded to train_step
      discriminator: forwarded to train_step
      optD: discriminator optimizer, forwarded to train_step
      optG: generator optimizer, forwarded to train_step
      noise_func: callable invoked as noise_func(n, n_qubits, device) for the fixed noise and as
        noise_func(batch_size, n_qubits, device, n_qubits_small=...) for every batch
      criterion: loss callable, forwarded to train_step
      train_data: tensor indexed as train_data[indices, :]; rows are the training samples
      device: torch device for the label tensors and each batch
      image_size: when pca_disc is False each batch is reshaped to (-1, image_size * image_size)
      compressed_img_size: forwarded to train_step
      max_num_epochs: number of passes over train_data
      n_qubits_small: forwarded to noise_func for the per-batch noise
      gen_pcas: only logged here
      disc_img_size: forwarded to train_step
      pca_disc: selects whether batches are flattened (see image_size); forwarded to train_step

    Returns:
      results: the list handed to train_step for collecting generated samples
      log_metrics: dict with keys "disc_grad_norms", "gen_grad_norms", "disc_loss", "gen_loss",
        each a list holding one entry per executed training step
    """
    print_cust(f"train_models, train_data.shape: {train_data.shape}")
    n_train_data = train_data.shape[0]

    print_cust(f"train_models, gen_pcas: {gen_pcas}, pca_disc: {pca_disc}")
    print_cust(f"train_models, disc_img_size: {disc_img_size}")
    real_labels = torch.full((batch_size,), 1.0, dtype=torch.float, device=device)
    fake_labels = torch.full((batch_size,), 0.0, dtype=torch.float, device=device)

    # Fixed noise allows us to visually track the generated images throughout training
    # TODO: make the fixed noise dim a variable
    # NOTE(review): unlike the per-batch noise below, n_qubits_small is not forwarded here -- confirm intended.
    fixed_noise = noise_func(8, n_qubits, device)

    # Collect images for plotting later
    results = []

    log_metrics = {
        "disc_grad_norms": [],
        "gen_grad_norms": [],
        "disc_loss": [],
        "gen_loss": [],
    }

    print_cust(f"train_models, log_metrics: {log_metrics}")

    counter = 0

    print_cust(f"train_models, n_train_data: {n_train_data}")
    print_cust(f"train_models, max_num_epochs: {max_num_epochs}")

    # FOR NOW alpha is pinned to 1.0. It used to be assigned only when gen_pcas
    # was true, which left it unbound (UnboundLocalError) for gen_pcas=False.
    # TODO: restore a schedule such as counter / num_iter once num_iter is injected.
    alpha = 1.0

    for epoch_num in range(max_num_epochs):

        print_cust(f"train_models, [SGD] Epoch {epoch_num+1}/{max_num_epochs}")
        indices = np.arange(n_train_data)
        np.random.shuffle(indices)
        for start in range(0, n_train_data, batch_size):

            end = min(start + batch_size, n_train_data)
            batch_indices = indices[start:end]
            # The label tensors have a fixed length of batch_size, so a short
            # trailing batch cannot be used and is dropped.
            if len(batch_indices) < batch_size:
                print_cust(f"train_models, skipping batch because too small. len(batch_indices): {len(batch_indices)}")
                continue
            X_batch = train_data[batch_indices, :]
            print_cust(f"train_models, batch_indices: {batch_indices}")
            print_cust(f"train_models, X_batch.shape: {X_batch.shape}")

            # Data for training the discriminator
            # TODO: add an image processor adapter here.
            print_cust(f"train_models, alpha: {alpha}")

            if not pca_disc:
                X_batch = X_batch.reshape(-1, image_size * image_size)
            real_data = X_batch.to(device)

            # Fresh latent noise for this batch; its distribution is defined by noise_func.
            noise = noise_func(batch_size, n_qubits, device, n_qubits_small=n_qubits_small)

            fake_data = generator(noise, alpha=alpha)

            print_cust(f"X_batch.shape: {X_batch.shape}")
            # Move to host memory before handing the tensor to NumPy; a tensor
            # living on a non-CPU device cannot be converted implicitly.
            print_cust(f"np.linalg.norm(real_data[0]): {np.linalg.norm(real_data[0].detach().cpu().numpy())}")
            print_cust(f"real_data[0].min(): {real_data[0].min()}, real_data[0].max(): {real_data[0].max()}")
            print_cust(f"real_data[0]: {real_data[0]}")

            # NOTE: technically not 'semantically' correct to increment counter here, but functionally, it is the same (FOR NOW)
            print_cust(f"train_models, before counter {counter}, generator: {generator}, discriminator: {discriminator}")
            errD, errG, disc_param_grad_norms, generator_param_grad_norms = train_step(generator, discriminator, optD, optG, criterion, results, real_labels, fake_labels, counter, fixed_noise, real_data, fake_data, compressed_img_size, alpha=alpha, disc_img_size=disc_img_size, pca_disc=pca_disc)
            print_cust(f"train_models, errD: {errD}, errG: {errG}, disc_param_grad_norms: {disc_param_grad_norms}, generator_param_grad_norms: {generator_param_grad_norms}")
            print_cust(f"train_models, after counter {counter}, generator: {generator}, discriminator: {discriminator}")
            counter += 1

            print_cust(f"train_models, counter: {counter}")

            log_metrics["disc_loss"].append(errD)
            log_metrics["gen_loss"].append(errG)
            log_metrics["disc_grad_norms"].append(disc_param_grad_norms)
            log_metrics["gen_grad_norms"].append(generator_param_grad_norms)

    print_cust(f"train_models, about to return, log_metrics: {log_metrics}")
    return results, log_metrics

"""### Custom KLLoss Function Module"""

class KLLoss(nn.Module):
    """KL divergence loss for self distillation.

    Works directly on probabilities: no softmax is applied to either input.
    The loss is sum(target * (log(target) - log(pred))) reduced over dim 1,
    summed over the remaining entries and divided by the size of dim 0.
    """

    def __init__(self):
        super().__init__()
        self.temperature = 1
        # Smallest probability allowed inside a logarithm.
        self.eps = 10 ** (-7)

    def forward(self, pred, label):
        """KL loss forward.

        Parameters:
          pred: tensor of predicted probabilities (at least 2-D; dim 0 is the batch)
          label: tensor of target probabilities with the same shape; treated as a constant

        Returns:
          loss: scalar tensor; gradients flow through pred only
        """
        # Guard against log(0) = -inf: an exact-zero prediction would otherwise
        # make the loss infinite. Values >= eps are untouched, so results for
        # strictly positive predictions are unchanged. (Entries clamped up to
        # eps receive zero gradient from the clamp.)
        predict = pred.clamp_min(self.eps).log()
        # The target is shifted by eps for the same reason and detached so it
        # acts as a fixed teacher signal.
        target = (label + self.eps).detach().clone()

        loss = (
            self.temperature
            * self.temperature
            * ((target * (target.log() - predict)).sum(1).sum() / target.size()[0])
        )
        print_cust(f"KLLoss, forward, loss: {loss}")
        return loss

"""### Core Training Batch Function, Variational Classifier"""

def train_step_classifier(classifier_model, opt_classifier, criterion, real_labels, counter, real_data, loss_type="depthfl"):
    """
    Perform one optimisation step of a classifier on a single batch.

    The model output is interpreted by its rank:
      * 2-D: column 1 is taken and flattened, then scored with criterion.
      * 1-D: scored with criterion directly.
      * otherwise: treated as several stacked outputs indexed along dim 0 and
        combined according to loss_type.

    Parameters:
      classifier_model: module called as classifier_model(real_data); must provide zero_grad,
        named_parameters, state_dict and (for stacked outputs) a .device attribute
      opt_classifier: optimizer whose step() is called after backward()
      criterion: loss callable invoked as criterion(predictions, real_labels)
      real_labels: target tensor for the batch
      counter: iteration index, used for logging only
      real_data: input batch (detached before the forward pass)
      loss_type: only read for stacked outputs.
        "standalone": criterion on column 1 of the last output only.
        "depthfl": criterion on column 1 of every output, plus for every ordered pair of distinct
        outputs the KLLoss of one against the (detached) other, divided by (number of outputs - 1).

    Returns:
      errD: the detached loss tensor
      classifier_param_grad_norms: list of gradient norms, ordered by parameter name, for the
        parameters that received a gradient

    Raises:
      ValueError: if the model returns stacked outputs and loss_type is not one of the above.
    """
    print_cust(f"train_step_classifier, opt_classifier: {opt_classifier}")

    print_cust(f"train_step_classifier, loss_type: {loss_type}")

    classifier_model.zero_grad()

    # TODO: add assertions here about gradient magnitudes; should all be zero (disc at least, should)
    for name, p in classifier_model.named_parameters():
        print_cust(f"train_step_classifier, classifier_model param name: {name}, p.requires_grad: {p.requires_grad}")
        if p.grad is None:
            print_cust(f"train_step_classifier, classifier_model param name: {name}, no grad")
        else:
            print_cust(f"train_step_classifier, classifier_model param name: {name}, p.grad.norm(): {p.grad.norm()}")

    print_cust(f"train_step_classifier, real_data.shape: {real_data.shape}")

    # Snapshots of the weights are taken purely for the "did anything change" diagnostics below.
    classifier_beforeinf_statedict = copy.deepcopy(classifier_model.state_dict())

    # TODO: detach real_data too???
    # NOTE, layers: taking that first column; the probability of observing the state to be 1.
    outD_real = classifier_model(real_data.detach())
    if len(outD_real.shape) == 2:
        print_cust(f"train_step_classifier, outD_real.shape: {outD_real.shape}")
        outD_real = outD_real[:, 1].view(-1)
    print_cust(f"train_step_classifier, after running through classifier_model, outD_real.shape: {outD_real.shape}")
    # otherwise, outD_real is the same shape as I wanted (n_classifiers, input_dim, n_classes).

    classifier_afterinf_statedict = copy.deepcopy(classifier_model.state_dict())

    print_cust(f"train_step_classifier, real_data.shape: {real_data.shape}")
    print_cust(f"train_step_classifier, outD_real.shape: {outD_real.shape}")
    print_cust(f"train_step_classifier, outD_real: {outD_real}")
    print_cust(f"train_step_classifier, real_labels.shape: {real_labels.shape}")
    print_cust(f"train_step_classifier, real_labels: {real_labels}")

    # could do some assertions here, but can do it later
    print_cust(f"train_step_classifier, change in classifier params after inf: {l2_state_dict_difference(classifier_beforeinf_statedict, classifier_afterinf_statedict)}")

    if len(outD_real.shape) == 1:
        errD_real = criterion(outD_real, real_labels)
    else:
        # TOMODIFY, DepthFL: make this on cuda() for GPU acceleration?
        criterion_kl = KLLoss()
        print_cust(f"train_step_classifier, len(outD_real.shape) != 1, outD_real.shape: {outD_real.shape}")
        errD_real = torch.zeros(1).to(classifier_model.device)
        if loss_type == "standalone":
            errD_real += criterion(outD_real[-1][:, 1], real_labels)
        elif loss_type == "depthfl":
            for one_output_idx in range(outD_real.shape[0]):
                print_cust(f"train_step_classifier, one_output_idx: {one_output_idx}")
                one_output_preds = outD_real[one_output_idx]
                print_cust(f"train_step_classifier, one_output_preds.shape: {one_output_preds.shape}")
                errD_real += criterion(one_output_preds[:, 1], real_labels)
                for second_output_idx in range(outD_real.shape[0]):
                    print_cust(f"train_step_classifier, second_output_idx: {second_output_idx}")
                    if second_output_idx == one_output_idx:
                        continue
                    second_output_preds = outD_real[second_output_idx]
                    errD_real += (criterion_kl(one_output_preds, second_output_preds.detach()) / (outD_real.shape[0] - 1))
        else:
            # Without this guard the loss stays a constant zero tensor and
            # backward() fails with an unrelated-looking autograd error.
            raise ValueError(
                f"train_step_classifier, unknown loss_type: {loss_type!r}; expected 'standalone' or 'depthfl'"
            )

    print_cust(f"train_step_classifier, errD_real: {errD_real}")

    # NOTE, depthFL: can print out gradients at this step, to verify that they are still None.

    # Propagate gradients
    errD_real.backward()

    print_cust(f"train_step_classifier, after classifier .backward()")

    classifier_param_grad_norms = []

    # Sorted by name so the returned list has a stable order across steps.
    for name, p in sorted(classifier_model.named_parameters(), key=lambda kv: kv[0]):
        print_cust(f"train_step_classifier, classifier_model param name: {name}, p.requires_grad: {p.requires_grad}")
        if p.grad is None:
            print_cust(f"train_step_classifier, classifier_model param name: {name}, no grad")
        else:
            print_cust(f"train_step_classifier, classifier_model param name: {name}, p.grad.norm(): {p.grad.norm()}")
            # Stored detached so the logged norms can never keep an autograd graph alive.
            classifier_param_grad_norms.append(p.grad.norm().detach())

    print_cust(f"train_step_classifier, classifier_param_grad_norms: {classifier_param_grad_norms}")

    errD = errD_real
    opt_classifier.step()
    classifier_afterdiscstep_statedict = copy.deepcopy(classifier_model.state_dict())
    print_cust(f"train_step_classifier, change in classifier params after classifier step: {l2_state_dict_difference(classifier_afterinf_statedict, classifier_afterdiscstep_statedict)}")

    # Show loss values
    print_cust(f'train_step_classifier, Iteration: {counter}, Classifier Loss: {errD.item():0.3f}')

    return errD.detach(), classifier_param_grad_norms

"""### Core Training Epochs Function, Variational Classifier"""

# TODO, depthFL, 8/19, 7:04 PM: continue here

def train_models_classifier(batch_size, classifier_model, opt_classifier, criterion, train_data, train_labels, device, max_num_epochs, loss_type="depthfl"):
    """
    Run mini-batch training of a classifier for a fixed number of epochs.

    Every epoch shuffles the row indices of train_data, walks them in chunks of
    batch_size (a trailing chunk smaller than batch_size is skipped) and calls
    train_step_classifier on each chunk.

    NOTE: this function mutates the inputs (the model/optimizer are stepped inside train_step_classifier).

    Parameters:
      batch_size: number of rows per training step
      classifier_model: forwarded to train_step_classifier
      opt_classifier: forwarded to train_step_classifier
      criterion: forwarded to train_step_classifier
      train_data: tensor indexed as train_data[indices, :]; rows are the training samples
      train_labels: tensor indexed as train_labels[indices]
      device: torch device each batch and its labels are moved to
      max_num_epochs: number of passes over train_data
      loss_type: forwarded to train_step_classifier

    Returns:
      log_metrics: dict with keys "disc_grad_norms", "gen_grad_norms", "disc_loss", "gen_loss";
        only "disc_loss" and "disc_grad_norms" are filled (one entry per executed step), the
        "gen_*" lists stay empty
    """
    print_cust(f"train_models_classifier, loss_type: {loss_type}")
    print_cust(f"train_models_classifier, train_data.shape: {train_data.shape}")

    print_cust(f"train_models_classifier, opt_classifier: {opt_classifier}")
    n_train_data = train_data.shape[0]

    print_cust(f"train_models_classifier, train_labels.shape: {train_labels.shape}")

    # NOTE(review): a size mismatch is only logged; consider raising once callers are confirmed clean.
    if train_data.shape[0] != train_labels.shape[0]:
        print_cust(f"train_models_classifier, amount of training data does not equal amount of training labels, train_data.shape[0]: {train_data.shape[0]}, train_labels.shape[0]: {train_labels.shape[0]}")

    print_cust(f"train_models_classifier, train_data.min(): {train_data.min()}, train_data.max(): {train_data.max()}")
    print_cust(f"train_models_classifier, train_labels.min(): {train_labels.min()}, train_labels.max(): {train_labels.max()}")

    # Same key layout as the generative training loop's metrics.
    log_metrics = {
        "disc_grad_norms": [],
        "gen_grad_norms": [],
        "disc_loss": [],
        "gen_loss": [],
    }

    print_cust(f"train_models_classifier, log_metrics: {log_metrics}")

    counter = 0

    print_cust(f"train_models_classifier, n_train_data: {n_train_data}")
    print_cust(f"train_models_classifier, max_num_epochs: {max_num_epochs}")

    for epoch_num in range(max_num_epochs):

        print_cust(f"train_models_classifier, [SGD] Epoch {epoch_num+1}/{max_num_epochs}")
        indices = np.arange(n_train_data)
        np.random.shuffle(indices)
        for start in range(0, n_train_data, batch_size):

            end = min(start + batch_size, n_train_data)
            batch_indices = indices[start:end]
            if len(batch_indices) < batch_size:
                print_cust(f"train_models_classifier, skipping batch because too small. len(batch_indices): {len(batch_indices)}")
                continue
            X_batch = train_data[batch_indices, :]
            y_batch = train_labels[batch_indices]
            print_cust(f"train_models_classifier, batch_indices: {batch_indices}")
            print_cust(f"train_models_classifier, X_batch.shape: {X_batch.shape}")
            print_cust(f"train_models_classifier, y_batch.shape: {y_batch.shape}")
            print_cust(f"train_models_classifier, y_batch: {y_batch}")

            # TODO: add an image processor adapter here.
            real_data = X_batch.to(device)

            real_labels = y_batch.to(device)

            print_cust(f"train_models_classifier, X_batch.shape: {X_batch.shape}")
            # Move to host memory before handing the tensor to NumPy; a tensor
            # living on a non-CPU device cannot be converted implicitly.
            print_cust(f"train_models_classifier, np.linalg.norm(real_data[0]): {np.linalg.norm(real_data[0].detach().cpu().numpy())}")
            print_cust(f"train_models_classifier, real_data[0].min(): {real_data[0].min()}, real_data[0].max(): {real_data[0].max()}")
            print_cust(f"train_models_classifier, real_data[0]: {real_data[0]}")

            # NOTE: technically not 'semantically' correct to increment counter here, but functionally, it is the same (FOR NOW)
            print_cust(f"train_models_classifier, before counter {counter}, classifier_model: {classifier_model}")
            errD, disc_param_grad_norms = train_step_classifier(classifier_model, opt_classifier, criterion, real_labels, counter, real_data, loss_type=loss_type)
            print_cust(f"train_models_classifier, errD: {errD}, disc_param_grad_norms: {disc_param_grad_norms}")
            print_cust(f"train_models_classifier, after counter {counter}, classifier_model: {classifier_model}")
            counter += 1

            print_cust(f"train_models_classifier, counter: {counter}")

            log_metrics["disc_loss"].append(errD)
            log_metrics["disc_grad_norms"].append(disc_param_grad_norms)

    print_cust(f"train_models_classifier, len(log_metrics['disc_loss']): {len(log_metrics['disc_loss'])}")
    print_cust(f"train_models_classifier, len(log_metrics['disc_grad_norms']): {len(log_metrics['disc_grad_norms'])}")
    print_cust(f"train_models_classifier, about to return, log_metrics: {log_metrics}")
    return log_metrics

# DONE (both train_step_classifier and train_models_classifier): TOADD, layers: a completely new/different training loop for training a quantum classifier. Not sure what the arguments should be, but model it after the generative one,
# and note that we are using PyTorch here.
# think about arguments for train_models and train_step.
# NOTE, layers: REMEMBER to ONLY optimize ONE layer at a time. -- (this should be handled at optimizer creation.)

"""### Other Helpers"""

# TODO, DepthFL: continue reading code here, 8/26, 12:40 PM
# REPURPOSE: inject the optimizer in (for generator and discriminator), inject a noise function, inject the models for generator and discriminator.

def tuple_to_numpy(tensor_tuple):
    """
    Converts a tuple of PyTorch tensors (which may require grad) into a tuple of numpy arrays.

    Torch tensors are detached from the autograd graph and moved to host memory
    before conversion, so tensors that require grad or live on another device
    are accepted. Elements that are not torch tensors are passed to np.array
    unchanged. Every returned array is an independent copy.

    Args:
        tensor_tuple (tuple): A tuple containing PyTorch tensors.

    Returns:
        tuple: A tuple containing the corresponding numpy arrays.
    """
    numpy_arrays = []
    for t in tensor_tuple:
        if isinstance(t, torch.Tensor):
            # np.array() on a tensor that requires grad raises, and a tensor on
            # a non-CPU device cannot be converted at all.
            t = t.detach().cpu().numpy()
        numpy_arrays.append(np.array(t).copy())
    return tuple(numpy_arrays)

import random

"""## QCNN QNode Creation"""

from functools import partial

def qcnn_template(input_angles, params, n_qubits, expansion_data, n_classes, num_ancillas=0, layer_types_list=None, cheating=False, tunn_down=False):
    """
    Circuit template that unpacks params and delegates to QCNN_circuit_dynamic.

    Parameters:
      input_angles: forwarded as the first argument of QCNN_circuit_dynamic
      params: indexable container; elements 0-3 are forwarded positionally and element 5 is
        forwarded as block_params_list (element 4 is not read here)
      n_qubits: forwarded positionally
      expansion_data, n_classes, num_ancillas, cheating, tunn_down: forwarded as keyword arguments
      layer_types_list: forwarded as a keyword argument; None (the default) is treated as an empty list

    Returns:
      whatever QCNN_circuit_dynamic returns
    """
    # Avoid a shared mutable default argument: each call gets its own list.
    if layer_types_list is None:
        layer_types_list = []
    # TOMODIFY, depthFL: add num_ancillas to what this function calls.
    print_cust(f"qcnn_template, input_angles: {input_angles}, params: {params}, n_qubits: {n_qubits}, expansion_data: {expansion_data}, n_classes: {n_classes}, num_ancillas: {num_ancillas}, layer_types_list: {layer_types_list}, cheating: {cheating}, tunn_down: {tunn_down}")
    return QCNN_circuit_dynamic(input_angles, params[0], params[1], params[2], params[3],
                                n_qubits, expansion_data=expansion_data, block_params_list=params[5], n_classes=n_classes, num_ancillas=num_ancillas, layer_types_list=layer_types_list, cheating=cheating, tunn_down=tunn_down)

def create_qnode_qcnn(n_data, conv_layers, expansion_data, n_classes=2, pennylane_interface="autograd", num_ancillas=0, layer_types_list=None, cheating=False, tunn_down=False):
    """
    Creates a QNode of the specified size.

    The QNode wraps qcnn_template on a "default.qubit" device and is called as
    qnode(input_angles, params); every other template argument is bound here.

    Parameters:
      n_data: an integer representing the number of qubits (device wires) for this QNode
      conv_layers: an integer representing the number of convolutional layers in this QNode
        (accepted for interface compatibility; not read by this function)
      expansion_data: a list of (encoded_angles, reversed_indices) that is used for identity initialization
      n_classes: an integer representing the number of output classes for this QCNN
      pennylane_interface: interface name handed to qml.qnode
      num_ancillas: bound into the circuit template
      layer_types_list: bound into the circuit template; None (the default) is treated as an empty list
      cheating: bound into the circuit template
      tunn_down: bound into the circuit template

    Returns:
      qnode: a QNode object satisfying the above constraints.
    """
    # Avoid a shared mutable default argument: each QNode gets its own list.
    if layer_types_list is None:
        layer_types_list = []
    n_qubits = n_data
    dev = qml.device("default.qubit", wires=n_qubits)
    print_cust(f"create_qnode_qcnn, pennylane_interface: {pennylane_interface}, num_ancillas: {num_ancillas}, cheating: {cheating}, tunn_down: {tunn_down}")
    circuit = partial(qcnn_template, n_qubits=n_qubits, expansion_data=expansion_data, n_classes=n_classes, num_ancillas=num_ancillas, layer_types_list=layer_types_list, cheating=cheating, tunn_down=tunn_down)

    return qml.qnode(dev, interface=pennylane_interface)(circuit)

# DONE: TOMODIFY, layers: add argument for pennylane interface, and inject it here.
# DONE: TOMODIFY, layers: may have to do that partial functools call as done in the QGAN case below.

"""## Multi-Evaluation QCNN QNode Creation"""

def run_multiprob_qnode(input_angles, params, qnode):
    """
    Evaluate qnode once per prefix of the block parameters and stack the results.

    For k = 1 .. len(params[5]) the qnode is called with a copy of params whose
    element 5 is replaced by the first k block parameters.

    Parameters:
      input_angles: forwarded unchanged to every qnode call
      params: indexable sequence whose element 5 is the list of block parameters
      qnode: callable invoked as qnode(input_angles, params_tuple)

    Returns:
      output_probs_tensors: torch.stack of the per-prefix qnode outputs (prefix index on dim 0)
    """
    block_params_list = params[5]
    output_probs_list = []
    for block_param_idx in range(1, len(block_params_list) + 1):
        # A shallow copy is enough: only slot 5 is replaced, and the caller's
        # sequence is left untouched. A deep copy here would duplicate every
        # tensor on every evaluation and hand the qnode copies that are
        # disconnected from the caller's tensors for autograd purposes.
        params_copy = list(params)
        params_copy[5] = block_params_list[:block_param_idx]
        params_copy = tuple(params_copy)
        print_cust(f"run_multiprob_qnode, block_param_idx: {block_param_idx}, params_copy: {params_copy}")
        output_probs_list.append(qnode(input_angles, params_copy))
    output_probs_tensors = torch.stack(output_probs_list)
    print_cust(f"run_multiprob_qnode, output_probs_tensors: {output_probs_tensors}, output_probs_tensors.shape: {output_probs_tensors.shape}")
    return output_probs_tensors

def create_qnode_qcnn_multieval(n_data, conv_layers, expansion_data, n_classes=2, pennylane_interface="autograd", layer_types_list=None):
    """
    Creates a multi-evaluation QNode of the specified size.

    Parameters:
      n_data: an integer representing the number of qubits for this QNode
      conv_layers: an integer representing the number of convolutional layers in this QNode
      expansion_data: a list of (encoded_angles, reversed_indices) that is used for identity initialization
      n_classes: an integer representing the number of output classes for this QCNN
      pennylane_interface: interface name forwarded to create_qnode_qcnn
      layer_types_list: forwarded to create_qnode_qcnn; None (the default) is treated as an empty list

    Returns:
      multiprob_fn: run_multiprob_qnode with the QNode bound, called as multiprob_fn(input_angles, params)
    """
    # Avoid a shared mutable default argument.
    if layer_types_list is None:
        layer_types_list = []
    qcnn_qnode = create_qnode_qcnn(n_data, conv_layers, expansion_data, n_classes=n_classes, pennylane_interface=pennylane_interface, layer_types_list=layer_types_list)
    # return a partial to run_multiprob_qnode
    multiprob_fn = partial(run_multiprob_qnode, qnode=qcnn_qnode)
    return multiprob_fn

# DONE: TOMODIFY, depthFL: add a builder for a wrapper class that contains a qnode that, instead of running the params through the qnode ONCE, instead runs it through the qnode
# multiple times based on the number of layers in params.
# TOMODIFY, depthFL: later, add an argument to specify which layers to run through the qnode for extensibility; for now, just run ALL of the layers
# again in the qnode, for the ensemble cost evaluation.

# TOMODIFY, depthFL: trying ancillas to get statistics of the circuit in parallel.

"""## Single-Evaluation, Ancilla QCNN QNode Creation"""

def run_singleeval_qnode(input_angles, params, qnode):
    """
    Evaluate qnode once and stack its outputs into a single tensor.

    Parameters:
      input_angles: forwarded to the qnode
      params: forwarded to the qnode
      qnode: callable invoked as qnode(input_angles, params); assumed to return a sequence of
        torch tensors (one per measured output, since the circuit should be using ancillas)

    Returns:
      the torch.stack of the qnode outputs
    """
    print_cust(f"run_singleeval_qnode, params: {params}")
    stacked_probs = torch.stack(qnode(input_angles, params))
    print_cust(f"run_singleeval_qnode, output_probs_tensors: {stacked_probs}, output_probs_tensors.shape: {stacked_probs.shape}")
    return stacked_probs

def create_qnode_qcnn_singleeval(n_data, conv_layers, expansion_data, n_classes=2, pennylane_interface="autograd", block_layers=5, layer_types_list=None):
    """
    Creates a single-evaluation, ancilla-based QNode runner of the specified size.

    The underlying QNode is built with n_data + block_layers qubits and
    num_ancillas=block_layers.

    Parameters:
      n_data: an integer representing the number of data qubits for this QNode
      conv_layers: an integer representing the number of convolutional layers in this QNode
      expansion_data: a list of (encoded_angles, reversed_indices) that is used for identity initialization
      n_classes: an integer representing the number of output classes for this QCNN
      pennylane_interface: interface name forwarded to create_qnode_qcnn
      block_layers: number of extra qubits added to n_data, also passed as num_ancillas
      layer_types_list: forwarded to create_qnode_qcnn; None (the default) is treated as an empty list

    Returns:
      multiprob_fn: run_singleeval_qnode with the QNode bound, called as multiprob_fn(input_angles, params)
    """
    # Avoid a shared mutable default argument.
    if layer_types_list is None:
        layer_types_list = []
    # TOMODIFY, DepthFL, HACK: inject the number of additional ancillas that I want my qnode to have.
    # ^ currently, is hardcoded as a default arg to see if it works.
    print_cust(f"create_qnode_qcnn_singleeval, n_data: {n_data}, block_layers: {block_layers}")
    # TOMODIFY, DepthFL: for now, assuming (or maybe not? it's just named this way) that number of
    # block layers = number of ancillas in the circuit.
    qcnn_qnode = create_qnode_qcnn(n_data + block_layers, conv_layers, expansion_data, n_classes=n_classes, pennylane_interface=pennylane_interface, num_ancillas=block_layers, layer_types_list=layer_types_list)
    # return a partial to run_singleeval_qnode
    multiprob_fn = partial(run_singleeval_qnode, qnode=qcnn_qnode)
    return multiprob_fn

"""## Single-Evaluation, Cheating QCNN QNode Creation"""

def run_multiprob_qnode_cheating(input_angles, params, qnode):
    """
    Run qnode once under qml.snapshots and stack one snapshot per block parameter.

    Parameters:
      input_angles: forwarded to the qnode
      params: forwarded to the qnode; element 5 is the list of block parameters
      qnode: QNode whose snapshots contain the keys "p0_bp0", "p0_bp1", ... (one per block parameter)

    Returns:
      the torch.stack of the "p0_bp{i}" snapshot entries, in block-parameter order
    """
    print_cust(f"run_multiprob_qnode_cheating, params: {params}")
    for block_param_idx, block_param in enumerate(params[5]):
        print_cust(f"run_multiprob_qnode_cheating, block_param_idx: {block_param_idx}, block_param: {block_param}")
    snapshots = qml.snapshots(qnode)(input_angles, params)
    # TODO: postprocess differently, to get the cheating statistics.
    per_block_probs = [snapshots[f"p0_bp{idx}"] for idx in range(len(params[5]))]
    stacked_probs = torch.stack(per_block_probs)
    print_cust(f"run_multiprob_qnode_cheating, output_probs_tensors: {stacked_probs}, output_probs_tensors.shape: {stacked_probs.shape}")
    return stacked_probs

def create_qnode_qcnn_multieval_cheating(n_data, conv_layers, expansion_data, n_classes=2, pennylane_interface="autograd", layer_types_list=None):
    """
    Creates a snapshot-based ("cheating") multi-output QNode runner of the specified size.

    The underlying QNode is built with cheating=True.

    Parameters:
      n_data: an integer representing the number of qubits for this QNode
      conv_layers: an integer representing the number of convolutional layers in this QNode
      expansion_data: a list of (encoded_angles, reversed_indices) that is used for identity initialization
      n_classes: an integer representing the number of output classes for this QCNN
      pennylane_interface: interface name forwarded to create_qnode_qcnn
      layer_types_list: forwarded to create_qnode_qcnn; None (the default) is treated as an empty list

    Returns:
      multiprob_fn: run_multiprob_qnode_cheating with the QNode bound, called as multiprob_fn(input_angles, params)
    """
    # Avoid a shared mutable default argument.
    if layer_types_list is None:
        layer_types_list = []
    qcnn_qnode = create_qnode_qcnn(n_data, conv_layers, expansion_data, n_classes=n_classes, pennylane_interface=pennylane_interface, layer_types_list=layer_types_list, cheating=True)
    # return a partial to run_multiprob_qnode_cheating
    multiprob_fn = partial(run_multiprob_qnode_cheating, qnode=qcnn_qnode)
    return multiprob_fn

"""## Single-Evaluation, Tunneling-Down Circuit Design"""
def run_singleeval_tunneldown_qnode(input_angles, params, qnode):
    """
    Evaluate the tunneling-down qnode once and stack its outputs into a single tensor.

    Parameters:
      input_angles: forwarded to the qnode
      params: forwarded to the qnode
      qnode: callable invoked as qnode(input_angles, params); assumed to return a sequence of
        torch tensors

    Returns:
      the torch.stack of the qnode outputs
    """
    print_cust(f"run_singleeval_tunneldown_qnode, params: {params}")
    stacked_probs = torch.stack(qnode(input_angles, params))
    print_cust(f"run_singleeval_tunneldown_qnode, output_probs_tensors: {stacked_probs}, output_probs_tensors.shape: {stacked_probs.shape}")
    return stacked_probs

def create_qnode_qcnn_singleeval_tunneldown(n_data, conv_layers, expansion_data, n_classes=2, pennylane_interface="autograd", layer_types_list=None):
    """
    Creates a single-evaluation, tunneling-down QNode runner of the specified size.

    The underlying QNode is built with tunn_down=True and no extra ancilla qubits.

    Parameters:
      n_data: an integer representing the number of qubits for this QNode
      conv_layers: an integer representing the number of convolutional layers in this QNode
      expansion_data: a list of (encoded_angles, reversed_indices) that is used for identity initialization
      n_classes: an integer representing the number of output classes for this QCNN
      pennylane_interface: interface name forwarded to create_qnode_qcnn
      layer_types_list: forwarded to create_qnode_qcnn; None (the default) is treated as an empty list

    Returns:
      multiprob_fn: run_singleeval_tunneldown_qnode with the QNode bound, called as
        multiprob_fn(input_angles, params)
    """
    # Avoid a shared mutable default argument.
    if layer_types_list is None:
        layer_types_list = []
    # The previous message ended in a dangling ", block_layers" with no value;
    # this builder has no such parameter.
    print_cust(f"create_qnode_qcnn_singleeval_tunneldown, n_data: {n_data}")
    qcnn_qnode = create_qnode_qcnn(n_data, conv_layers, expansion_data, n_classes=n_classes, pennylane_interface=pennylane_interface, layer_types_list=layer_types_list, tunn_down=True)
    # return a partial to run_singleeval_tunneldown_qnode
    multiprob_fn = partial(run_singleeval_tunneldown_qnode, qnode=qcnn_qnode)
    return multiprob_fn

"""## QGAN QNode Creation

### QNode QGAN, Global Function
"""

def qgan_template(input_noise, params, n_qubits):
    """
    Circuit template for the QGAN: logs its arguments and delegates to QGAN_circuit.

    Not meant to be exposed externally; it should only be used in the context of
    PatchQuantumGenerator.

    Parameters:
      input_noise: forwarded as the first argument of QGAN_circuit
      params: a list of tensors representing the parameters to be used in the QGAN
      n_qubits: forwarded as the second argument of QGAN_circuit

    Returns:
      whatever QGAN_circuit returns
    """
    print_cust(f"qgan_template, input_noise: {input_noise}, n_qubits: {n_qubits}, params: {params}")
    # Note the argument order expected by QGAN_circuit: (noise, n_qubits, params).
    circuit_output = QGAN_circuit(input_noise, n_qubits, params)
    return circuit_output

def create_qnode_qgan(n_data):
    """
    Creates a torch-interface QGAN QNode of the specified size.

    The QNode wraps qgan_template with n_qubits pre-bound to n_data, so the resulting
    callable only takes (input_noise, params).

    Parameters:
      n_data: an integer representing the number of qubits (device wires) for this QNode

    Returns:
      qnode: a QNode on a "default.qubit" device with n_data wires, using the "torch" interface.
    """
    simulator = qml.device("default.qubit", wires=n_data)
    # Bind the qubit count up front; the remaining arguments are supplied when the QNode is evaluated.
    bound_template = partial(qgan_template, n_qubits=n_data)
    as_torch_qnode = qml.qnode(simulator, interface="torch")
    return as_torch_qnode(bound_template)

# TOADD: a function with the same inputs, but really, I just need input_angles (noise), n_qubits, and block_params_list.
# TOADD: the interface should be TORCH, *not* autograd.

# DONE (same as above): TOADD, layers: a function for creating a qnode for the quantum classification case, as well as the template for this case.

"""## Feature List Expansion Data Function"""

def create_feat_list_expansion_data(n_qubits, conv_layers, expand=False, pool_in=False, min_qubits_noexpand=NUM_BASE_QUBITS, train_models_parallel=False, feat_sel_type="top"):
    """
    Build the feature ordering and the per-layer expansion data for a model of the given size.

    The layers are visited from the most reduced one (index conv_layers) up to layer 0. The layer
    whose width equals min_qubits_noexpand seeds the feature list; every wider layer interleaves
    newly selected features into the current list and records where they were inserted.

    Parameters:
      n_qubits: Integer representing the number of qubits in the model.
      conv_layers: Integer representing the number of convolutional layers in the model.
      expand: Boolean representing whether or not this model is an expansion of a smaller model
      (expansion data is only recorded when this compares equal to True).
      pool_in: Boolean indicating whether or not the qubits should be measured/compressed inwards
      (currently not read by this function).
      min_qubits_noexpand: Integer representing the minimum number of qubits for which no expansion should take place, across all input models.
      train_models_parallel: Boolean indicating whether or not the models are being trained in parallel (in this case, all expansion data is blanked out).
      feat_sel_type: String representing the order in which the features should be chosen: "top" takes
      consecutive indices, "toplow" takes equally from the front and the back of the not-yet-used features.

    Returns:
      feature_list: List of integers representing the ordering of the features on the qubits
      expansion_data: List with one entry per layer (index 0 = layer 0); each entry is either [] or
      [feature_list_after_this_layer, insertion_indices].
    """
    cur_feature_list = []
    expansion_data = []
    # Features not yet assigned to a qubit; only consumed when feat_sel_type is "toplow".
    feat_pool = list(range(compute_reduced_qubits(n_qubits, 0)))

    # TODO: check/do this only for pool_in
    for layer in range(conv_layers, -1, -1):
        layer_width = compute_reduced_qubits(n_qubits, layer)

        # Layers at or below the base width never carry expansion data.
        if layer_width <= min_qubits_noexpand:
            if layer_width == min_qubits_noexpand:
                # Base layer: seed the feature list.
                if feat_sel_type == "top":
                    cur_feature_list = list(range(layer_width))
                elif feat_sel_type == "toplow":
                    n_each_end = layer_width // 2
                    cur_feature_list = feat_pool[:n_each_end] + feat_pool[-n_each_end:]
                    feat_pool = sorted(set(feat_pool) - set(cur_feature_list))
            expansion_data.append([])
            continue

        # Wider layer: interleave new features between the existing ones.
        print_cust(f"experiment_dynamic_QCNN, cur_feature_list: {cur_feature_list}")
        n_existing = len(cur_feature_list)
        split_at = n_existing // 2
        # First half of the slots use even positions, second half the following odd positions.
        insertion_indices = [
            2 * pos if pos < split_at else 2 * pos + 1
            for pos in range(n_existing)
        ]
        print_cust(f"experiment_dynamic_QCNN, insertion_indices: {insertion_indices}")

        if feat_sel_type == "top":
            new_feat_idxs = list(range(n_existing, layer_width))
        elif feat_sel_type == "toplow":
            n_each_end = (layer_width - n_existing) // 2
            new_feat_idxs = feat_pool[:n_each_end] + feat_pool[-n_each_end:]
            feat_pool = sorted(set(feat_pool) - set(new_feat_idxs))
        print_cust(f"experiment_dynamic_QCNN, new_feat_idxs: {new_feat_idxs}")

        # Insert one at a time; later insertion indices already account for earlier insertions.
        for slot, new_feat in enumerate(new_feat_idxs):
            insertion_idx = insertion_indices[slot]
            print_cust(f"experiment_dynamic_QCNN, new_feat: {new_feat}")
            print_cust(f"experiment_dynamic_QCNN, insertion_idx: {insertion_idx}")
            cur_feature_list.insert(insertion_idx, new_feat)
            print_cust(f"experiment_dynamic_QCNN, cur_feature_list: {cur_feature_list}")

        # Only record expansion data when an identity initialization is wanted.
        # (Kept as an explicit comparison with True so non-bool arguments behave as before.)
        layer_entry = [list(cur_feature_list), insertion_indices] if expand == True else []
        expansion_data.append(layer_entry)

    feature_list = list(cur_feature_list)

    # Parallel training has no identity initialization, so every layer gets an empty entry.
    if train_models_parallel == True:
        expansion_data = [[] for _ in range(conv_layers, -1, -1)]

    # Entries were collected from the deepest layer upward; flip so index 0 is layer 0.
    expansion_data.reverse()

    print_cust(f"create_feat_list_expansion_data, feature_list: {feature_list}")
    print_cust(f"create_feat_list_expansion_data, expansion_data: {expansion_data}")

    return feature_list, expansion_data

# NOTE: feature_list and expansion_data are trivial for the generative case, so
# create_feat_list_expansion_data does not need to be called there; the inputs can be generated directly.

"""# QFL Functions"""

import math

import copy

"""## Load dataset function"""

from sklearn.datasets import make_classification

from sklearn.preprocessing import LabelEncoder

def load_dataset(dataset_type="mnist", classes=["4", "9"], n_samples=1000, num_feats=20, keep_orig_imgs=False, do_lda=False, custom_debug=False):
    """
    Load a dataset and return its (optionally reduced / angle-encoded) features together with integer labels.

    Parameters:
      dataset_type (str): Can be:
           "mnist"          -- load MNIST digits (requested from the loader as "mnist_784"), restricted to `classes`
           "Fashion-MNIST"  -- load Fashion-MNIST through the same loader, restricted to `classes`
           "cifar10"        -- load CIFAR-10 via load_cifar10, restricted to `classes`
           "synthetic"      -- generate synthetic binary classification data
           "pima"           -- load the Pima Indians Diabetes dataset from OpenML.
           "higgs"          -- load the HIGGS dataset from OpenML.
           "covertype"      -- load the Covertype dataset from OpenML (binarized: encoded label 1 vs. the rest).
           "breast_cancer"  -- load the breast cancer dataset bundled with scikit-learn.
      classes (list): Classes forwarded to the MNIST / Fashion-MNIST / CIFAR-10 loaders. (Ignored for other types.)
      n_samples (int): Number of samples to load or generate. For the OpenML / scikit-learn datasets a
        stratified subsample of this size is drawn only when the dataset is larger than n_samples.
      num_feats (int): Number of features (or reduced components) passed to angle_encode_data as n_components.
      keep_orig_imgs (Boolean): If True, the raw features are returned without calling angle_encode_data.
        Only honoured for "mnist", "Fashion-MNIST" and "breast_cancer".
      do_lda (Boolean): Forwarded to angle_encode_data as do_lda; PCA (do_pca) is requested exactly when this is False.
      custom_debug (Boolean): Enables extra logging, and for MNIST-style data dumps the raw X / y to
        'load_mnist_digits_X.pkl' / 'load_mnist_digits_y.pkl' in the working directory.

    Returns:
      X_angles (np.array): The output of angle_encode_data (presumably shape (n_samples, num_feats) with
        angle-scaled values -- see that function), or the raw features when keep_orig_imgs applies.
      y (np.array): Integer labels.

    Raises:
      ValueError: if dataset_type is not one of the supported names.
    """
    # PCA and the do_lda reduction are mutually exclusive: PCA is used exactly when do_lda is off.
    do_pca = not do_lda

    if custom_debug:
      print_cust(f"load_dataset, do_pca: {do_pca}, do_lda: {do_lda}")

    if dataset_type == "mnist" or dataset_type == "Fashion-MNIST":
        print_cust(f"load_dataset, is {dataset_type}")
        # The loader expects the OpenML dataset name for MNIST.
        if dataset_type == "mnist":
          dataset_type = "mnist_784"
        if custom_debug:
          print_cust(f"load_dataset, dataset_type: {dataset_type}")
        # Assumes a function load_mnist_digits exists.
        X, y = load_mnist_digits(classes, n_samples=n_samples, dataset_name=dataset_type)
        if custom_debug:
          # Persist the raw loader output for offline inspection.
          with open('load_mnist_digits_X.pkl', 'wb') as f:
            pickle.dump(X, f)
          with open('load_mnist_digits_y.pkl', 'wb') as f:
            pickle.dump(y, f)
        if keep_orig_imgs:
          X_angles = X
        else:
          X_angles = angle_encode_data(X, y=y, n_components=num_feats, do_pca=do_pca, do_lda=do_lda, custom_debug=custom_debug)
        return X_angles, y.astype(int)

    elif dataset_type == "cifar10":
        print_cust("load_dataset, is cifar10")
        X, y = load_cifar10(classes, n_samples=n_samples)
        # CIFAR-10 images are higher-dimensional; applying PCA is recommended.
        # NOTE(review): keep_orig_imgs and custom_debug are not honoured in this branch -- confirm that is intended.
        X_angles = angle_encode_data(X, y=y, n_components=num_feats, do_pca=do_pca, do_lda=do_lda)
        return X_angles, y.astype(int)

    elif dataset_type == "synthetic":
        print_cust("load_dataset, is synthetic")
        # Generate synthetic binary classification data (all features informative, fixed seed).
        X, y = make_classification(n_samples=n_samples,
                                   n_features=num_feats,
                                   n_informative=num_feats,
                                   n_redundant=0,
                                   n_repeated=0,
                                   n_classes=2,
                                   random_state=42)
        y = y.astype(int)
        X_angles = angle_encode_data(X, y=y, n_components=num_feats, do_pca=do_pca, do_lda=do_lda)
        return X_angles, y

    elif dataset_type == "pima":
        print_cust("load_dataset, is pima")
        # Load the Pima Indians Diabetes dataset from OpenML.
        pima = fetch_openml('diabetes', version=1, as_frame=False)
        X = pima.data
        print_cust(f"load_dataset, X.shape: {X.shape}")
        # Any label other than the two known strings falls back to class 0.
        mapping = {"tested_negative": 0, "tested_positive": 1}
        y = np.array([mapping.get(label, 0) for label in pima.target])
        if n_samples < len(y):
            X, _, y, _ = train_test_split(X, y, train_size=n_samples, stratify=y, random_state=42)
        X_angles = angle_encode_data(X, y=y, n_components=num_feats, do_pca=do_pca, do_lda=do_lda)
        return X_angles, y

    elif dataset_type == "higgs":
        print_cust("load_dataset, is higgs")
        # Load the HIGGS dataset from OpenML.
        higgs = fetch_openml('HIGGS', version=1, as_frame=False)
        X = higgs.data
        # The target is already numeric; convert to int if necessary.
        y = higgs.target.astype(int)
        if n_samples < len(y):
            X, _, y, _ = train_test_split(X, y, train_size=n_samples, stratify=y, random_state=42)
        X_angles = angle_encode_data(X, y=y, n_components=num_feats, do_pca=do_pca, do_lda=do_lda)
        return X_angles, y

    elif dataset_type == "covertype":
        print_cust("load_dataset, is covertype")
        # Load the Covertype dataset from OpenML.
        covertype = fetch_openml('covertype', version=1, as_frame=False)
        X = covertype.data
        # Convert the string labels to integers using LabelEncoder.
        le = LabelEncoder()
        y_encoded = le.fit_transform(covertype.target)
        print_cust("Covertype classes:", le.classes_)
        # Binarize: designate the label that is encoded as 1 as the positive class.
        y_binary = (y_encoded == 1).astype(int)
        if n_samples < len(y_binary):
            X, _, y_binary, _ = train_test_split(X, y_binary, train_size=n_samples, stratify=y_binary, random_state=42)
        X_angles = angle_encode_data(X, y=y_binary, n_components=num_feats, do_pca=do_pca, do_lda=do_lda)
        return X_angles, y_binary

    elif dataset_type == "breast_cancer":
        print_cust("load_dataset, is breast_cancer")

        # scikit-learn bundles this dataset locally, so no download is needed.
        from sklearn.datasets import load_breast_cancer
        data = load_breast_cancer()
        X, y = data.data, data.target
        print_cust(f"load_dataset, X.shape: {X.shape}")
        print_cust(f"load_dataset, np.unique(y): {np.unique(y)}")

        # Optional down-sampling if the caller asked for fewer rows
        if n_samples < len(y):
            print_cust(f"load_dataset, n_samples < len(y)")

            X, _, y, _ = train_test_split(X, y,
                                          train_size=n_samples,
                                          stratify=y,
                                          random_state=42)

            print_cust(f"load_dataset, X.shape: {X.shape}, y.shape: {y.shape}")

        # Either pass the raw features through or reduce/encode them to `num_feats`
        if keep_orig_imgs:
            # keep_orig_imgs means "no reduction"; we pass X straight through
            print_cust(f"load_dataset, keep_orig_imgs, so passing X_angles = X")
            X_angles = X
        else:
            X_angles = angle_encode_data(X, y=y,
                                         n_components=num_feats,
                                         do_pca=do_pca,
                                         do_lda=do_lda,
                                         custom_debug=custom_debug)
        print_cust(f"load_dataset, X_angles: {X_angles}, y: {y}, y.astype(int): {y.astype(int)}")
        return X_angles, y.astype(int)

    # ------------------------------------------------------------
    # Unknown keyword guard
    # ------------------------------------------------------------
    else:
        # Echo the offending value and list every supported name (including Fashion-MNIST).
        raise ValueError(
            f"Unknown dataset_type {dataset_type!r}. "
            "Choose from 'mnist', 'Fashion-MNIST', 'cifar10', 'synthetic', 'pima', "
            "'higgs', 'covertype', or 'breast_cancer'."
        )

"""## Split data federated function"""

def split_data_federated(X, y, client_config, test_frac, val_frac=0.2,
                         feature_skew=0.0, label_skew=None, random_state=42, local_pca=False, do_lda=False, feat_sel_type="top", amp_embed=False, feat_ordering="same",
                         shared_pca=False, fed_pca_mocked=False):
    """
    Splits data for federated learning, inducing controllable label-skew
    that goes from uniform (label_skew=1) to linear (s=0.5) to extreme (s→0).

    Parameters:
      X: The input data to split among clients.
      y: The input labels to split among clients.
      client_config: The client configuration dictionary, keyed by client type (an integer size); each value
        provides "percentage_data" (fraction of the training pool for that type) and "num_clients".
      test_frac: Float representing the fraction of the input data used for testing.
      val_frac: Float representing the fraction of each clients' data used for validation.
      feature_skew: Float representing the strength of the skew that each client faces in terms of feature values
        (samples are ordered by a blend of the normalized first feature and uniform noise).
      label_skew: None or float in [0,1].
        - None => each type's pool is simply split evenly between its clients.
        - 1.0 => each client's labels are sampled with uniform class weights.
        - 0.5 => class weights decrease linearly with class rank.
        - 0.0 => extreme skew: all weight on the lowest class; once that class runs out, the
          remaining quota is filled uniformly from the other samples.
        - Values in between smoothly interpolate via exponent θ = (1−s)/s.
      random_state: Integer seed. NOTE: this seeds the *global* NumPy RNG (np.random.seed), so it affects
        code that runs after this function as well. Pass None to leave the global state untouched.
      local_pca: Boolean indicating whether or not PCA should be performed locally.
      do_lda: Boolean indicating whether or not random sketching should be performed.
      feat_sel_type: String representing the choice of features that the client will take later.
      amp_embed: Boolean indicating whether or not the data will be amplitude encoded.
      feat_ordering: String representing the ordering of features to sample in the case where data is amplitude encoded (TODO: use the same logic for sampling amplitude encoded or
      angle encoded data, or refactor so that this is true)
      shared_pca: Boolean; when True every client keeps its original (un-reduced) data and the component
        count passed to any local reduction is the largest client size.
      fed_pca_mocked: Boolean; when True the local PCA / sketching step is skipped even if local_pca is set
        (a [None, None] PCA entry is still appended when local_pca is True).

    Returns:
      clients_data: a dictionary mapping integers representing client types to a list of data for each client, where the i-th element of the list is the data for
      client i, and each client has a list of data in the form [[X_train, y_train], [X_val, y_val], [pca_obj, pca_reduced_data]]
      with the third element only being present if local_pca is True
      (X_test, y_test): a tuple of testing data and labels

    Raises:
      ValueError: if the percentage_data values of all client types sum to more than 1.
    """
    print_cust(f"split_data_federated, fed_pca_mocked: {fed_pca_mocked}")
    print_cust(f"split_data_federated, client_config: {client_config}")
    # TODO: test this function. i keep getting weird results where the acc for each client differs a lot ...

    X = np.array(X)
    y = np.array(y)
    if random_state is not None:
        np.random.seed(random_state)

    # 1) global train/test split
    n = len(X)
    perm = np.random.permutation(n)
    tsize = int(test_frac * n)
    test_idx, train_idx = perm[:tsize], perm[tsize:]
    X_test,  y_test  = X[test_idx],  y[test_idx]
    X_train, y_train = X[train_idx], y[train_idx]
    n_train = len(X_train)

    # Positions into X_train / y_train; `pointer` marks how much of the pool has been handed out.
    rel_idx = np.arange(n_train)
    pointer = 0
    clients_data = {}

    max_cli_size = max(client_config.keys())

    if do_lda:
      # One shared random sketching matrix, sized for the largest client type.
      sketch_mat = np.random.normal(loc=0.0, scale=1/np.sqrt(X.shape[1]), size=(X.shape[1], max_cli_size))

    # sanity (with a small tolerance so fractions that sum to 1 only up to float rounding are accepted)
    total_pct = sum(cfg["percentage_data"] for cfg in client_config.values())
    if total_pct > 1.0 + 1e-9:
        raise ValueError("Sum of percentage_data > 1")

    # 2) per-client-type allocation
    for ctype, cfg in client_config.items():
        pct, n_clients = cfg["percentage_data"], cfg["num_clients"]

        # carve off this type's pool
        alloc_n = int(pct * n_train)
        alloc_n = min(alloc_n, n_train - pointer)
        alloc_idx = rel_idx[pointer : pointer + alloc_n]
        pointer += alloc_n

        # --- feature skew: order the pool by a blend of the min-max-normalized first feature and noise ---
        feats = X_train[alloc_idx, 0].astype(float)
        if feats.size:
            lo, hi = feats.min(), feats.max()
            norm_feat = (feats - lo)/(hi - lo) if hi > lo else np.zeros_like(feats)
        else:
            norm_feat = feats
        rand_comp = np.random.rand(len(alloc_idx))
        scores = feature_skew*norm_feat + (1-feature_skew)*rand_comp
        alloc_idx = alloc_idx[np.argsort(scores)]

        # --- label skew (robust, disjoint sampling) ---
        if label_skew is None:
            client_chunks = np.array_split(alloc_idx, n_clients)
        else:
            labels_sorted = np.array(sorted(np.unique(y_train[alloc_idx])))
            C             = len(labels_sorted)

            # build the rank-power weights w_i = (C−i)^θ
            s     = float(label_skew)
            theta = (1.0 - s)/s if s>0 else np.inf
            ranks = np.arange(C, 0, -1, dtype=float)    # [C, C-1, ...,1]
            w     = (ranks**theta) if np.isfinite(theta) else np.zeros_like(ranks)
            # An infinite exponent, or a finite one whose powers overflowed, both mean
            # "all mass on the first class". Skip when the pool is empty (no w[0] to set).
            extreme_skew = (not np.isfinite(theta)) or (not bool(np.all(np.isfinite(w))))
            if extreme_skew and C > 0:
                w = np.zeros_like(ranks)
                w[0] = 1.0   # all mass on first class
            p_local = w / w.sum()

            # we'll sample from this array, removing as we go
            remaining = np.array(alloc_idx, dtype=int)
            # precompute a map from label → slot in p_local
            lbl2idx = {lbl:i for i,lbl in enumerate(labels_sorted)}

            client_chunks = []
            for client_id in range(n_clients):
                n_rem        = len(remaining)
                if n_rem == 0:
                    client_chunks.append(np.array([],dtype=int))
                    continue

                clients_left = n_clients - client_id
                take_n       = int(math.ceil(n_rem / clients_left))
                n_take       = min(take_n, n_rem)

                # build per-sample weights from the class weights
                classes_rem = y_train[remaining]
                q           = np.array([p_local[lbl2idx[l]] for l in classes_rem])
                has_weight  = q > 0
                n_weighted  = int(has_weight.sum())

                if n_weighted >= n_take:
                    # Normal case: weighted sampling without replacement.
                    q      /= q.sum()
                    sel_idx = np.random.choice(
                        n_rem,
                        size=n_take,
                        replace=False,
                        p=q
                    )
                else:
                    # Weighted sampling without replacement cannot draw more items than there are
                    # non-zero weights (np.random.choice raises). Take every weighted sample and
                    # fill the rest of this client's quota uniformly from the zero-weight samples.
                    weighted_pos = np.where(has_weight)[0]
                    other_pos    = np.where(np.logical_not(has_weight))[0]
                    fill_pos     = np.random.choice(
                        other_pos,
                        size=n_take - n_weighted,
                        replace=False
                    )
                    sel_idx = np.concatenate([weighted_pos, fill_pos])

                sel      = remaining[sel_idx]
                client_chunks.append(sel)

                # remove them
                remaining = np.delete(remaining, sel_idx)

        # 3) split each client chunk into train/val
        clients_data[ctype] = []
        for chunk in client_chunks:
            m = len(chunk)
            v = int(val_frac * m)
            client_data_chunk = X_train[chunk]
            # The first v samples of the chunk are validation, the rest training.
            chunk_val_idx   = chunk[:v]
            chunk_train_idx = chunk[v:]
            n_tot_comps = ctype
            # Have ALL components so that we have the option to sample others (as opposed to just the top components).
            if feat_sel_type != "top" or shared_pca:
              n_tot_comps = max_cli_size
            # Perform local PCA (and local random sketching) if specified.
            if local_pca and not fed_pca_mocked:
              if do_lda:
                # TODO: validation data prob should not be LDA'd together w/ training, but it's alright for now.. doesn't affect training for single client epoch case
                # NOTE(review): the sketch is sliced to ctype columns even when n_tot_comps is max_cli_size -- confirm intended.
                client_data_chunk, pca, client_data_chunk_pca = angle_encode_data(client_data_chunk, n_tot_comps, y=y_train[chunk], do_lda=True, ret_lda=True, sketch_mat=sketch_mat[:, :ctype])
              else:
                client_data_chunk, pca, client_data_chunk_pca = angle_encode_data(client_data_chunk, n_tot_comps, do_pca=True, ret_pca=True)
            else:
              client_data_chunk, pca, client_data_chunk_pca = client_data_chunk, None, None

            # If the data will be amplitude encoded, and the features are sampled in order of highest variance,
            # reorder the columns by descending variance (computed over all of X) and add a small
            # constant to every entry so that the data can be subsequently normalized.
            if amp_embed and feat_ordering == "highest_var":
              variances = X.var(axis=0, ddof=0)
              order = np.argsort(variances)[::-1]
              client_data_chunk = client_data_chunk[:, order] + 1e-3

            # Store the original data for this particular client.
            if shared_pca:
              client_data_chunk = X_train[chunk]

            X_train_data = client_data_chunk[v:]
            X_val_data = client_data_chunk[:v]
            client_data_lst = [[X_train_data, y_train[chunk_train_idx]],
                 [X_val_data,   y_train[chunk_val_idx]]]
            if local_pca:
              client_data_lst.append([pca, client_data_chunk_pca])
            clients_data[ctype].append(
                client_data_lst
            )

    return clients_data, (X_test, y_test)

"""## Initialize client params function"""

def initialize_client_params(clients_config_arg, model_size=NUM_BASE_QUBITS, cur_client_params_dict=None, debug=False, qubits_and_layers_to_add_block_params=[],
                             train_models_parallel=False, n_output_qubits=1, generative=False, use_torch=False, qnode_func=None, device=None, pca_info=None, is_qcnn=True, alt_zeros_init=""):
    """
    Initializes the parameter dictionary storing the parameters for each client.

    Parameters:
      clients_config_arg: The client configuration dictionary (keyed by client type, each value providing "num_clients")
      used to determine the size and types of parameters that each client will have.
      model_size: Integer representing the shared model size (if that exists) for the subnet initialization case.
      Used as every client's model size unless train_models_parallel is set.
      cur_client_params_dict: A dictionary representing a client parameters dictionary, used for previous parameter initialization.
      When given (and not training in parallel), new parameters are zero-initialized and then overwritten with the previous values.
      debug: A Boolean indicating whether or not the parameters should be initialized in debug mode.
      qubits_and_layers_to_add_block_params: A dictionary mapping client sizes to a list of (n_qubits, n_layers) specifying the block parameters
      that this client contains. (The default is an empty list, which behaves like an empty mapping here; it is never mutated by this function.)
      train_models_parallel: A Boolean indicating whether or not the models are trained in parallel (and thus, whether or not we should employ subnet initialization).
      n_output_qubits: An integer specifying the number of output qubits for these models.
      generative: A Boolean; when True a generator/discriminator pair is built for each client and stored in place of the block parameter list.
      use_torch: Forwarded to init_dynamic_qcnn_params.
      qnode_func: Callable invoked as qnode_func(client_model_size) to build the generator's QNode (generative mode only).
      device: Passed through to PatchQuantumGenerator (generative mode only).
      pca_info: Deep-copied and unpacked as the positional arguments of PCARescaler (generative mode only).
      is_qcnn: Forwarded to compute_conv_layers.
      alt_zeros_init: Forwarded to init_dynamic_qcnn_params.

    Returns:
      client_params_dict, a dictionary mapping client types (integers, representing the numbers of qubits that each client contains) to a list of parameters for each client
      (the i-th element of the list are the parameters for the i-th client). Each entry is the tuple
      (conv_params_tuple, pool_params_tuple, final_pool_param, final_params, bias_param, block_params_list);
      in generative mode block_params_list is [generator, discriminator].
    """
    # NOTE: for train_models_parallel, this function only needs to be called ONCE. There is no notion of 'expansion'; so the
    # subsequent checks for 'expansion' in this function should NOT be strictly necessary.
    client_params_dict = {}
    # The discriminator (generative mode) is always sized for the largest client type.
    max_client_size = max(clients_config_arg.keys())
    # Iterate over each client type
    for client_type, cfg in clients_config_arg.items():
        # If the client type is smaller than the model size in expansion, or the number of clients is 0, then don't create
        # parameters for models of this type
        if (client_type < model_size and not train_models_parallel) or cfg["num_clients"] == 0:
            continue
        num_clients = cfg["num_clients"]
        client_params_dict[client_type] = []
        # For each client that is to be created of this type, initialize the parameters for this client.
        for client_idx in range(num_clients):
            # NOCHANGE: TOMODIFY, depthFL: inject some argument here to indicate that client_model_size should be dynamically found based on the largest qubits needed for its block params,
            # if not qcnn.
            # ^ client_model_size is never manifested for the non-QCNN case, so is OK to use client_model_size
            # as an alias to diff't client types in terms of number of layers.
            # Parallel training gives every client its own size; otherwise all clients share model_size.
            if train_models_parallel:
              client_model_size = client_type
            else:
              client_model_size = model_size
            # Having previous parameters means we are expanding an already-trained model.
            is_expansion = (cur_client_params_dict is not None)
            # Look up the block config for this model size; in generative mode, additionally keep only the
            # (num_qubits, num_layers) tuples with num_qubits <= client_model_size, sorted by num_qubits.
            qubits_and_layers_to_add_block_params_client = []
            if client_model_size in qubits_and_layers_to_add_block_params:
              qubits_and_layers_to_add_block_params_client = qubits_and_layers_to_add_block_params[client_model_size]
              if generative:
                print_cust(f"initialize_client_params, generative, filtering qubits_and_layers_to_add_block_params_client")
                qubits_and_layers_to_add_block_params_client_filt = []
                for qubits_and_layers_tuple in sorted(qubits_and_layers_to_add_block_params_client, key=lambda x:x[0]):
                  if qubits_and_layers_tuple[0] <= client_model_size:
                    qubits_and_layers_to_add_block_params_client_filt.append(qubits_and_layers_tuple)
                print_cust(f"initialize_client_params, generative, qubits_and_layers_to_add_block_params_client_filt: {qubits_and_layers_to_add_block_params_client_filt}")
                qubits_and_layers_to_add_block_params_client = qubits_and_layers_to_add_block_params_client_filt


            print_cust(f"initialize_client_params, client_idx: {client_idx}, client_model_size: {client_model_size}, n_output_qubits: {n_output_qubits}")
            print_cust(f"initialize_client_params, generative: {generative}")
            conv_layers = compute_conv_layers(client_model_size, n_output_qubits, generative=generative, is_qcnn=is_qcnn)

            # Fresh parameters; zero-initialized when expanding so the previous values can be copied in below.
            conv_params_tuple, pool_params_tuple, final_pool_param, final_params, bias_param, block_params_list = \
                init_dynamic_qcnn_params(client_model_size, conv_layers, debug=debug, zeros_init=is_expansion,
                                         qubits_and_layers_to_add_block_params=qubits_and_layers_to_add_block_params_client, generative=generative, use_torch=use_torch, alt_zeros_init=alt_zeros_init)

            if generative:
              # Build a fresh generator/discriminator pair for this client.
              client_discriminator = PCADiscriminator(max_client_size, 4)
              client_qgan_qnode = qnode_func(client_model_size)
              # Map each block's qubit count to its depth.
              client_qubits_depth_dict = {}
              for qubit_depth_tuple in qubits_and_layers_to_add_block_params_client:
                client_qubits_depth_dict[qubit_depth_tuple[0]] = qubit_depth_tuple[1]
              # Deep copy so each client's rescaler gets its own copy of the PCA info.
              pca_info_used = copy.deepcopy(pca_info)
              client_pcarescaler = PCARescaler(*pca_info_used)
              client_generator = PatchQuantumGenerator(1, 1.0, client_model_size, 0, client_qgan_qnode, client_qubits_depth_dict, device, 0, True, client_pcarescaler)
              client_gan_models = [client_generator, client_discriminator]

            # TODO: continue here. initialize previous disc/generator.
            # If we want to pre-initialize the parameters for this client, do so.
            if cur_client_params_dict is not None and not train_models_parallel:
                # Work on a deep copy so the caller's previous parameters are not aliased.
                init_trained_params = copy.deepcopy(cur_client_params_dict[client_type][client_idx])
                # Keep the newly created first conv/pool entry and replace everything after it with the previous values.
                new_conv_params = list(conv_params_tuple)
                new_conv_params[1:] = list(init_trained_params[0])
                conv_params_tuple = tuple(new_conv_params)
                new_pool_params_tuple = list(pool_params_tuple)
                new_pool_params_tuple[1:] = list(init_trained_params[1])
                pool_params_tuple = tuple(new_pool_params_tuple)
                final_pool_param = init_trained_params[2]
                final_params = init_trained_params[3]
                bias_param = init_trained_params[4]
                init_block_params_list = init_trained_params[5]
                if not generative:
                  # Copy a previous block over the new one at the same index only when the shapes match exactly.
                  for init_block_param_idx, block_params in enumerate(init_block_params_list):
                    block_params_shape = block_params.shape
                    print_cust(f"initialize_client_params, block_params_shape: {block_params_shape}")
                    # NOTE(review): the two values below are computed but not used afterwards.
                    num_layers_blockparams = block_params_shape[0]
                    num_qubits_blockparams = block_params_shape[1]
                    existing_block_params = block_params_list[init_block_param_idx]
                    if existing_block_params.shape == block_params_shape:
                      block_params_list[init_block_param_idx] = block_params
                  print_cust(f"initialize_client_params, block_params_list: {block_params_list}")
                else:
                  # first time the function is called, we assume we get a block params list. but that is effectively ignored and replaced with our gen and disc.
                  # next time this function is called, I can assume that I'll get passed in a generator and discriminator model. so assume that's the format,
                  # if cur_client_params_dict is not None.
                  print_cust(f"initialize_client_params, generative, init_block_params_list: {init_block_params_list}")
                  existing_generator, existing_discriminator = init_block_params_list[0], init_block_params_list[1]
                  # Reuse the previous discriminator as-is; rebuild the generator (note the 0.0 second argument,
                  # versus 1.0 for a fresh generator above) and load the previous generator's q_params into it.
                  client_discriminator = existing_discriminator
                  client_generator = PatchQuantumGenerator(1, 0.0, client_model_size, 0, client_qgan_qnode, client_qubits_depth_dict, device, 0, True, client_pcarescaler)
                  client_generator.initialize_existing_parameters(existing_generator.q_params)
                  client_gan_models = [client_generator, client_discriminator]

            if generative:
              # In generative mode the "block params" slot carries the [generator, discriminator] pair.
              block_params_list = client_gan_models
              print_cust(f"initialize_client_params, generative, block_params_list: {block_params_list}")

            client_params = (conv_params_tuple, pool_params_tuple, final_pool_param, final_params, bias_param, block_params_list)
            client_params_dict[client_type].append(client_params)
    return client_params_dict

# REPURPOSE: for client_params_dict, I need nothing else except the block

# NOT NEEDED (?): TOMODIFY, layers: do something similar above, but for quantumclassifier.
# DONE: TOMODIFY, layers: specifically, when initializing the block params for the larger client, modify it in-place: modify the previous params ONLY to have the block params. -- can modify this for the qcnn case, too.
# NOTE, 8/1, 5:19 PM: continue working from here.
# NO CHANGE: TOMODIFY, layers: (kind of ?): do ID init for the new block params. should already do this for the expansion anyways!
# DONE: TOMODIFY, layers: have an argument that says whether or not I need to have conv layers, and inject that in compute_conv_layers.

# DONE (slightly diff, see below): TOMODIFY, layers: change the initialize block params logic to do an exact match, AND start in the SAME order as provided block params as well as current created block params.
# ^ (slightly different, though: instead of finding the first matches, I find the EXACT index matches. this likely makes more sense.)

# DONE: TOMODIFY, LAYERS: did the above block params initialization from existing params in a torch.no_grad() context.

"""## Initialize Client Optimizers"""

def initialize_client_optimizers(client_params_dict, lr_gen=0.004, lr_disc=0.001, gen_betas=(0.5, 0.9), disc_betas=(0.5, 0.9), optim_type="sgd", existing_optims_dict=None, generative=False, is_qcnn=False, opt_layers=[-1]):
  """
  Create the optimizer container for every client in client_params_dict.

  Each client's container has one slot per element of its parameter tuple;
  only slot 5 (the block-parameter / GAN-model slot) is ever populated here.

  Args:
    client_params_dict: dict mapping client type -> list of per-client parameter
        tuples. Only index 5 of each tuple is read: in generative mode as
        [generator, discriminator], otherwise as an indexable collection of
        block-parameter tensors.
    lr_gen: learning rate of the generator optimizer (generative mode only).
    lr_disc: learning rate of the discriminator optimizer; also used for the
        block-parameter optimizer in the non-generative branch.
    gen_betas: Adam betas for the generator optimizer.
    disc_betas: Adam betas for the discriminator / block-parameter optimizer.
    optim_type: "sgd" or "adam".
    existing_optims_dict: optional dict with the same layout as the return
        value. Only consulted in generative mode, where the existing
        discriminator optimizer (slot [5][1]) replaces the freshly built one.
    generative: when True, build [generator_opt, discriminator_opt] per client.
    is_qcnn: when True (and generative is False) no optimizer is built and the
        client's container holds only None.
    opt_layers: indices into the client's block parameters selecting what the
        optimizer covers; None selects every block parameter. The default list
        is only read, never mutated.

  Returns:
    dict mapping client type -> list (one entry per client) of containers:
      * generative: list of lists, with slot 5 == [generator_opt, discriminator_opt];
      * non-generative, non-QCNN: list of None, with slot 5 == the optimizer;
      * QCNN: list of None.

  Raises:
    ValueError: if an optimizer has to be built and optim_type is neither
        "sgd" nor "adam".
  """
  print_cust(f"initialize_client_optimizers, generative: {generative}, is_qcnn: {is_qcnn}, opt_layers: {opt_layers}")

  def _make_optimizer(params, lr, betas):
    # Single place that maps optim_type to a torch optimizer, so an unsupported
    # name fails here with a clear message instead of surfacing later as an
    # unbound local variable.
    if optim_type == "sgd":
      return torch.optim.SGD(params, lr=lr)
    if optim_type == "adam":
      return torch.optim.Adam(params, lr=lr, betas=betas)
    raise ValueError(f"Unsupported optim_type: {optim_type!r}; expected 'sgd' or 'adam'.")

  client_optimizers_dict = {}
  for client_type, client_params_list in client_params_dict.items():
    client_optimizers_dict[client_type] = []
    for client_idx, client_params in enumerate(client_params_list):
      if generative:
        curr_cli_gen = client_params[5][0]
        curr_cli_disc = client_params[5][1]
        # The generator optimizer covers only the last entry of q_params.
        curr_cli_gen_opt = _make_optimizer(nn.ParameterList([curr_cli_gen.q_params[-1]]), lr_gen, gen_betas)
        curr_cli_disc_opt = _make_optimizer(curr_cli_disc.parameters(), lr_disc, disc_betas)
        if existing_optims_dict is not None:
          # Reuse the discriminator optimizer that was already stored for this client.
          existing_cli_disc_opt = existing_optims_dict[client_type][client_idx][5][1]
          print_cust(f"initialize_client_optimizers, existing_cli_disc_opt: {existing_cli_disc_opt}")
          curr_cli_disc_opt = existing_cli_disc_opt
        cli_optims = [[] for _ in range(len(client_params))]
        print_cust(f"initialize_client_optimizers, cli_optims: {cli_optims}")
        cli_optims[5].append(curr_cli_gen_opt)
        cli_optims[5].append(curr_cli_disc_opt)
        print_cust(f"initialize_client_optimizers, cli_optims: {cli_optims}")
      elif not is_qcnn:
        curr_cli_blockparams = client_params[5]
        if opt_layers is None:
          list_params = [nn.Parameter(block_param, requires_grad=True) for block_param in curr_cli_blockparams]
        else:
          list_params = [nn.Parameter(curr_cli_blockparams[layer_idx], requires_grad=True) for layer_idx in opt_layers]
        # NOTE: these optimizers wrap freshly created nn.Parameter objects, so they
        # technically point at the WRONG parameters. According to the original
        # author's note, train_client recreates the optimizers, so these mainly
        # record the optimizer 'architecture' for the model params.
        curr_cli_blockparams_opt = _make_optimizer(nn.ParameterList(list_params), lr_disc, disc_betas)
        # TOMODIFY, layers: reusing an entry of existing_optims_dict here would need
        # better expansion logic (computing differences to add a param group).
        cli_optims = [None for _ in range(len(client_params))]
        print_cust(f"initialize_client_optimizers, cli_optims: {cli_optims}")
        cli_optims[5] = curr_cli_blockparams_opt
        print_cust(f"initialize_client_optimizers, cli_optims: {cli_optims}")
      else:
        # QCNN optimizers are not implemented; keep an all-None placeholder.
        cli_optims = [None for _ in range(len(client_params))]

      client_optimizers_dict[client_type].append(cli_optims)
  print_cust(f"initialize_client_optimizers, client_optimizers_dict: {client_optimizers_dict}")
  return client_optimizers_dict

# DONE: TOMODIFY, layers: do something similar, but for the quantumclassifier. (just based on the type of block params.)
# DONE: TOMODIFY, layers: for each client, optimizers are just in a list, not in a nested list.
# DONE: TOMODIFY, layers: inject an argument here, for QCNN, mask_grads, that specifies what layer(s) we should optimize over.

# DONE (kind of; can select layers with point 3 above): TOMODIFY, layers: change the optimizer to point only to the placeholder last layer tensor (wrapped in an nn.Parameter).

"""## Federated averaging function

### Arithmetic mean
"""

def federated_averaging(client_params_dict, clients_data_dict):
    """
    Aggregate client parameters with a sample-count-weighted arithmetic mean.

    Args:
      client_params_dict: dict mapping client type -> list of parameter tuples
          (conv_params, pool_params, final_pool_param, final_params, bias_param,
          block_params_list). conv_params[layer] is indexed as (even, odd).
      clients_data_dict: dict mapping client type -> list of client data tuples
          whose first element is the training split (X_train, ...). The weight
          of a client is len(X_train).

    Returns:
      Tuple (conv_params, pool_params, final_pool_param, final_params,
      bias_param, block_params) where conv_params and pool_params are tuples
      and block_params is a list. Shapes follow the first available client.

    Raises:
      ValueError: if no client parameters are provided, or if the accumulated
          weight is zero (the average would otherwise be a silent 0/0 = NaN).
    """
    # The first available client is the template for all accumulator shapes.
    first_client_params = next(
        (clients[0] for clients in client_params_dict.values() if len(clients) > 0),
        None,
    )
    if first_client_params is None:
        raise ValueError("No client parameters provided.")
    (conv_params_first, pool_params_first, final_pool_param_first,
     final_params_first, bias_param_first, block_params_first) = first_client_params

    conv_layers = len(conv_params_first)
    aggregated_conv_params = [
        (np.zeros(conv_params_first[layer][0].shape), np.zeros(conv_params_first[layer][1].shape))
        for layer in range(conv_layers)
    ]
    aggregated_pool_params = [np.zeros(pool.shape) for pool in pool_params_first]
    aggregated_final_pool_param = np.zeros(final_pool_param_first.shape)
    aggregated_final_params = np.zeros(final_params_first.shape)
    aggregated_bias_param = np.zeros(bias_param_first.shape)
    aggregated_block_params = [np.zeros(bp.shape) for bp in block_params_first]

    # Accumulate weight * parameter for every client.
    total_weight = 0.0
    for client_type, clients_params in client_params_dict.items():
        clients_data = clients_data_dict[client_type]
        for params, data in zip(clients_params, clients_data):
            train_data, _ = data
            X_train, _ = train_data
            weight = len(X_train)
            total_weight += weight
            conv_params, pool_params, final_pool_param, final_params, bias_param, block_params_list = params
            for layer in range(conv_layers):
                even_sum, odd_sum = aggregated_conv_params[layer]
                aggregated_conv_params[layer] = (
                    even_sum + weight * conv_params[layer][0],
                    odd_sum + weight * conv_params[layer][1],
                )
            for layer in range(len(pool_params)):
                aggregated_pool_params[layer] += weight * pool_params[layer]
            aggregated_final_pool_param += weight * final_pool_param
            aggregated_final_params += weight * final_params
            aggregated_bias_param += weight * bias_param
            for i, bp in enumerate(block_params_list):
                aggregated_block_params[i] += weight * bp

    if total_weight == 0:
        # Dividing by zero would hand NaN parameters to every client.
        raise ValueError("Total client weight is zero; cannot compute a weighted average.")

    # Normalise the weighted sums.
    aggregated_conv_params = [
        (even_sum / total_weight, odd_sum / total_weight)
        for even_sum, odd_sum in aggregated_conv_params
    ]
    aggregated_pool_params = [pp / total_weight for pp in aggregated_pool_params]
    aggregated_final_pool_param /= total_weight
    aggregated_final_params /= total_weight
    aggregated_bias_param /= total_weight
    aggregated_block_params = [bp / total_weight for bp in aggregated_block_params]

    return (
        tuple(aggregated_conv_params),
        tuple(aggregated_pool_params),
        aggregated_final_pool_param,
        aggregated_final_params,
        aggregated_bias_param,
        aggregated_block_params,
    )

"""### Quaternion mean"""

from scipy.spatial.transform import Rotation as R

def euler_to_quat(phi, theta, omega):
    """
    Turn the three angles of a Rot-style gate into a quaternion.

    The angles are passed to scipy with the sequence string 'zyz' in the
    order [omega, theta, phi]; quat_to_euler applies the matching reverse
    mapping, so the two functions round-trip.

    NOTE(review): the intent is to represent
    qml.Rot(phi, theta, omega) = R_z(omega) R_y(theta) R_z(phi). scipy treats a
    lowercase sequence as extrinsic rotations, so confirm that this angle order
    yields that exact matrix product if the quaternion is ever used directly.

    Returns:
      The quaternion as produced by Rotation.as_quat(), i.e. in
      [x, y, z, w] (scalar-last) order.
    """
    rotation = R.from_euler('zyz', [omega, theta, phi])
    return rotation.as_quat()

def quat_to_euler(q):
    """
    Recover (phi, theta, omega) from a quaternion.

    The quaternion is decomposed by scipy with the 'zyz' sequence, which
    yields the angles as [omega, theta, phi]; they are returned reordered as
    (phi, theta, omega), the argument order expected by euler_to_quat.
    """
    omega, theta, phi = R.from_quat(q).as_euler('zyz')
    return phi, theta, omega

def federated_averaging_quat(client_params_dict, clients_data_dict):
    """
    Federated averaging that treats rotation gates as quaternions.

    Every angle triple found in the convolution rows, in the final rotation and
    in the block parameter arrays is converted to a quaternion with
    euler_to_quat, sign-aligned with the running sum, accumulated with the
    client's weight (its number of training samples), and finally normalised
    and converted back with quat_to_euler. Pooling angles and the final pooling
    angle are averaged as weighted unit phasors; the bias is a plain weighted
    arithmetic mean.

    Layout read by this function:
      - conv_params[layer] -> (even, odd); each row is read as 4 consecutive
        angle triples (columns 0..11).
      - pool_params[layer][i, 0] -> one angle per row.
      - final_params[0] -> one angle triple; final_pool_param[0] -> one angle.
      - each block parameter array bp[i, j, :] -> one angle triple.

    Returns:
      (conv_layers, pool_layers, final_pool_param, final_params, bias,
      block_params), with conv_layers / pool_layers as tuples and block_params
      as a list; shapes follow the first available client.

    Raises:
      ValueError: if client_params_dict contains no client.
    """
    # The first available client provides the shapes of all accumulators.
    template = None
    for clients in client_params_dict.values():
        if len(clients) > 0:
            template = clients[0]
            break
    if template is None:
        raise ValueError("No client parameters provided.")
    ref_conv, ref_pool, _ref_final_pool, _ref_final, ref_bias, ref_blocks = template
    num_layers = len(ref_conv)

    def _aligned(running_sum, quat):
        # q and -q describe the same rotation; flip the newcomer when it points
        # away from the running sum so contributions do not cancel out.
        if np.linalg.norm(running_sum) > 0 and np.dot(running_sum, quat) < 0:
            return -quat
        return quat

    def _sum_to_angles(quat_sum):
        # Normalise an accumulated quaternion and convert it back to angles.
        return quat_to_euler(quat_sum / np.linalg.norm(quat_sum))

    # --- Accumulators ---
    # Convolution: one quaternion (4 numbers) per rotation, 4 rotations per row.
    accum_conv_even = [np.zeros((ref_conv[layer][0].shape[0], 4, 4)) for layer in range(num_layers)]
    accum_conv_odd = [np.zeros((ref_conv[layer][1].shape[0], 4, 4)) for layer in range(num_layers)]
    # Pooling angles are accumulated as complex phasors.
    accum_pool = [np.zeros(ref_pool[layer].shape, dtype=np.complex128) for layer in range(num_layers)]
    accum_final = np.zeros(4, dtype=np.float64)
    accum_final_pool = 0 + 0j
    accum_bias = np.zeros(ref_bias.shape, dtype=np.float64)
    # Block parameters: one quaternion per (layer, qubit) entry.
    accum_block = [np.zeros((bp.shape[0], bp.shape[1], 4)) for bp in ref_blocks]

    total_weight = 0.0

    # --- Weighted accumulation over all clients ---
    for client_type, clients_params in client_params_dict.items():
        clients_data = clients_data_dict[client_type]
        for params, data in zip(clients_params, clients_data):
            train_data, _ = data
            X_train, _ = train_data
            weight = len(X_train)
            total_weight += weight
            conv_params, pool_params, final_pool_param, final_params, bias_param, block_params_list = params

            # Convolution rotations: even rows first, then odd rows.
            for layer in range(num_layers):
                for parity, sums in ((0, accum_conv_even), (1, accum_conv_odd)):
                    rows = conv_params[layer][parity]
                    for i in range(rows.shape[0]):
                        for r in range(4):
                            angles = rows[i, r*3:(r+1)*3]
                            q = _aligned(sums[layer][i, r], euler_to_quat(angles[0], angles[1], angles[2]))
                            sums[layer][i, r] += weight * q

            # Pooling angles.
            for layer in range(num_layers):
                layer_angles = pool_params[layer]
                for i in range(layer_angles.shape[0]):
                    accum_pool[layer][i, 0] += weight * np.exp(1j * layer_angles[i, 0])

            # Final rotation (phi, theta, omega).
            final_angles = final_params[0]
            q = _aligned(accum_final, euler_to_quat(final_angles[0], final_angles[1], final_angles[2]))
            accum_final += weight * q

            # Final pooling angle.
            accum_final_pool += weight * np.exp(1j * final_pool_param[0])

            # Bias (arithmetic).
            accum_bias += weight * bias_param

            # Block parameters.
            for idx, bp in enumerate(block_params_list):
                n_rows, n_cols = bp.shape[0], bp.shape[1]
                for i in range(n_rows):
                    for j in range(n_cols):
                        angles = bp[i, j, :]
                        q = _aligned(accum_block[idx][i, j], euler_to_quat(angles[0], angles[1], angles[2]))
                        accum_block[idx][i, j] += weight * q

    # --- Convert the accumulated convolution quaternions back to angles ---
    avg_conv_layers = []
    for layer in range(num_layers):
        averaged = []
        for parity, sums in ((0, accum_conv_even), (1, accum_conv_odd)):
            target_shape = ref_conv[layer][parity].shape
            angles_out = np.zeros(target_shape)
            for i in range(target_shape[0]):
                for r in range(4):
                    angles_out[i, r*3:(r+1)*3] = _sum_to_angles(sums[layer][i, r])
            averaged.append(angles_out)
        avg_conv_layers.append(tuple(averaged))

    # --- Pooling angles: argument of the accumulated phasor ---
    avg_pool_layers = []
    for layer in range(num_layers):
        pool_shape = ref_pool[layer].shape
        pool_out = np.zeros(pool_shape)
        for i in range(pool_shape[0]):
            pool_out[i, 0] = np.angle(accum_pool[layer][i, 0])
        avg_pool_layers.append(pool_out)

    # --- Final rotation, final pooling angle and bias ---
    avg_final_params = np.array([_sum_to_angles(accum_final)])
    avg_final_pool_param = np.array([np.angle(accum_final_pool)])
    avg_bias = accum_bias / total_weight

    # --- Block parameters ---
    avg_block_params = []
    for idx, bp in enumerate(ref_blocks):
        block_out = np.zeros(bp.shape)
        for i in range(bp.shape[0]):
            for j in range(bp.shape[1]):
                block_out[i, j, :] = _sum_to_angles(accum_block[idx][i, j])
        avg_block_params.append(block_out)

    return (
        tuple(avg_conv_layers),
        tuple(avg_pool_layers),
        avg_final_pool_param,
        avg_final_params,
        avg_bias,
        avg_block_params,
    )

"""### Circular mean"""

def federated_averaging_circular(client_params_dict, clients_data_dict):
    """
    Federated averaging that treats every angle as a point on the unit circle.

    For each angle-valued parameter (convolution, pooling, final pooling, final
    rotation and block parameters) the weighted circular mean is computed
    element-wise as
         avg_angle = angle(sum(weight * exp(1j * angle))),
    where a client's weight is its number of training samples. The bias is
    averaged arithmetically with the same weights.

    Returns:
      (conv_layers, pool_layers, final_pool_param, final_params, bias,
      block_params), with conv_layers / pool_layers as tuples and block_params
      as a list; shapes follow the first available client.

    Raises:
      ValueError: if client_params_dict contains no client.
    """
    # The first available client provides the shapes of all accumulators.
    template = next(
        (clients[0] for clients in client_params_dict.values() if len(clients) > 0),
        None,
    )
    if template is None:
        raise ValueError("No client parameters provided.")
    ref_conv, ref_pool, _ref_final_pool, ref_final, ref_bias, ref_blocks = template
    n_layers = len(ref_conv)

    def _phasor_zeros(shape):
        # Complex accumulator holding the running sum of weighted unit phasors.
        return np.zeros(shape, dtype=np.complex128)

    # --- Accumulators ---
    sum_conv_even = [_phasor_zeros(ref_conv[layer][0].shape) for layer in range(n_layers)]
    sum_conv_odd = [_phasor_zeros(ref_conv[layer][1].shape) for layer in range(n_layers)]
    sum_pool = [_phasor_zeros(ref_pool[layer].shape) for layer in range(n_layers)]
    sum_final_pool = 0 + 0j
    sum_final = _phasor_zeros(ref_final.shape)
    sum_bias = np.zeros(ref_bias.shape, dtype=np.float64)
    sum_block = [_phasor_zeros(bp.shape) for bp in ref_blocks]

    total_weight = 0.0

    # --- Weighted accumulation over all clients ---
    for client_type, clients_params in client_params_dict.items():
        clients_data = clients_data_dict[client_type]
        for params, data in zip(clients_params, clients_data):
            train_data, _ = data
            X_train, _ = train_data
            weight = len(X_train)
            total_weight += weight
            conv_params, pool_params, final_pool_param, final_params, bias_param, block_params_list = params

            for layer in range(n_layers):
                even_angles = conv_params[layer][0]
                odd_angles = conv_params[layer][1]
                sum_conv_even[layer] += weight * np.exp(1j * even_angles)
                sum_conv_odd[layer] += weight * np.exp(1j * odd_angles)

            for layer in range(n_layers):
                sum_pool[layer] += weight * np.exp(1j * pool_params[layer])

            sum_final_pool += weight * np.exp(1j * final_pool_param[0])
            sum_final += weight * np.exp(1j * final_params)

            # The bias is not an angle: plain weighted sum.
            sum_bias += weight * bias_param

            for idx, bp in enumerate(block_params_list):
                sum_block[idx] += weight * np.exp(1j * bp)

    # --- The argument of each accumulated phasor is the circular mean ---
    avg_conv_layers = [
        (np.angle(sum_conv_even[layer]), np.angle(sum_conv_odd[layer]))
        for layer in range(n_layers)
    ]
    avg_pool_layers = [np.angle(sum_pool[layer]) for layer in range(n_layers)]
    avg_final_pool_param = np.array([np.angle(sum_final_pool)])
    avg_final_params = np.angle(sum_final)
    avg_bias = sum_bias / total_weight
    avg_block_params = [np.angle(block_sum) for block_sum in sum_block]

    return (
        tuple(avg_conv_layers),
        tuple(avg_pool_layers),
        avg_final_pool_param,
        avg_final_params,
        avg_bias,
        avg_block_params,
    )

"""### Circular mean parallel"""

from collections import defaultdict

"""#### Block parameters aggregation function"""

def aggregate_block_params_general(client_params_dict, clients_data_dict, math_int=np):
    """
    Aggregate block (HEA ansatz) parameters across heterogeneous clients.

    Each client's block parameters (params[5]) are a list of arrays indexed as
    (depth, num_qubits, 3). If the first entry of params[5] is a
    PatchQuantumGenerator, its q_params list is used instead. Arrays are grouped
    by the pair (num_qubits, position in the list), so the ORDER of a client's
    block arrays matters.

    Args:
      client_params_dict: dict mapping client type -> list of parameter tuples
          (conv_params, pool_params, final_pool_param, final_params, bias_param,
          block_params_list).
      clients_data_dict: dict mapping client type -> list of client data tuples;
          the weight of a client is len(data[0][0]) (its training samples).
      math_int: array module used for zeros / exp / angle. Torch dtypes
          (complex64 / float32) are used when it is the torch module, numpy
          dtypes (complex128 / float64) otherwise.

    For each group:
      - common_depth = min(depths), extra_count = max(depths) - common_depth.
      - A client's common part is its LAST common_depth rows along axis 0
        (bp[depth - common_depth : depth]).
      - A client's extra part is its FIRST (depth - common_depth) rows; it is
        written into the shared extra region starting at
            offset = extra_count - (depth - common_depth).
      - Both parts are combined by weighted circular averaging:
            angle(sum(weight * exp(1j * angle)) / sum(weight)).

    Returns:
      dict mapping (num_qubits, bp_idx) -> (aggregated_common, aggregated_extra),
      where aggregated_common has shape (common_depth, num_qubits, 3) and
      aggregated_extra has shape (extra_count, num_qubits, 3), or is None when
      the group has no extra layers.
    """
    print_cust(f"aggregate_block_params_general, math_int: {math_int}")

    # Group block arrays by (num_qubits, list position); record weight, depth, array.
    block_groups = defaultdict(list)
    for ct, clients in client_params_dict.items():
        print_cust(f"aggregate_block_params_general, ct: {ct}")
        for params, data in zip(clients, clients_data_dict[ct]):
            weight = len(data[0][0])  # number of training samples for this client
            print_cust(f"aggregate_block_params_general, weight: {weight}")
            block_list = params[5]
            if block_list is None or len(block_list) == 0:
                continue
            if isinstance(block_list[0], PatchQuantumGenerator):
                block_list = block_list[0].q_params
            # NOTE: keying on the list position imposes an ordering on the block
            # params, but it distinguishes arrays that have the SAME number of
            # qubits yet belong to different layers.
            for bp_idx, bp in enumerate(block_list):
                block_groups[(bp.shape[1], bp_idx)].append((weight, bp.shape[0], bp))

    # Pick accumulator dtypes once. Only torch needs torch dtypes; every other
    # array module (pennylane numpy, plain numpy) accepts the numpy ones.
    uses_torch = math_int is torch
    complex_dtype = torch.complex64 if uses_torch else np.complex128
    real_dtype = torch.float32 if uses_torch else np.float64

    aggregated_block = {}
    for key_tuple, group in block_groups.items():
        key = key_tuple[0]  # number of qubits of this group
        depths = [item[1] for item in group]
        common_depth = min(depths)
        extra_count = max(depths) - common_depth

        # Phasor sum and per-layer weight for the common part.
        common_accum = math_int.zeros((common_depth, key, 3), dtype=complex_dtype)
        common_weight = math_int.zeros(common_depth, dtype=real_dtype)

        # Same for the extra part, if any client is deeper than common_depth.
        if extra_count > 0:
            extra_accum = math_int.zeros((extra_count, key, 3), dtype=complex_dtype)
            extra_weight = math_int.zeros(extra_count, dtype=real_dtype)
        else:
            extra_accum = None
            extra_weight = None

        for weight, depth, bp in group:
            print_cust(f"aggregate_block_params_general, weight (in group loop): {weight}")
            # Common part: the last common_depth rows of this client's array.
            common_part = bp[depth - common_depth : depth, :, :]
            common_accum += weight * math_int.exp(1j * common_part)
            common_weight += weight  # scalar addition broadcasts to each common layer

            client_extra = depth - common_depth
            if client_extra > 0 and extra_count > 0:
                offset = extra_count - client_extra
                # NOTE, layers: the offset grows as an index, yet the FIRST
                # client_extra rows of the client's block are taken here.
                extra_part = bp[0 : client_extra, :, :]
                extra_accum[offset : offset + client_extra] += weight * math_int.exp(1j * extra_part)
                extra_weight[offset : offset + client_extra] += weight

        print_cust(f"aggregate_block_params_general, common_weight: {common_weight}")
        aggregated_common = math_int.angle(common_accum / common_weight[:, None, None])
        if extra_count > 0:
            aggregated_extra = math_int.angle(extra_accum / extra_weight[:, None, None])
        else:
            aggregated_extra = None
        aggregated_block[key_tuple] = (aggregated_common, aggregated_extra)

    print_cust(f"aggregate_block_params_general, aggregated_block: {aggregated_block}")

    # Debug dump of the autograd state of every result. requires_grad / grad_fn
    # are read with getattr because non-torch results do not provide grad_fn,
    # and logging must not crash the aggregation.
    for test_key_tuple, (test_common, test_extra) in aggregated_block.items():
        print_cust(f"aggregate_block_params_general, test_key_tuple: {test_key_tuple}")
        test_key = test_key_tuple[0]
        print_cust(f"aggregate_block_params_general, test_key: {test_key}")
        for label, value in (("test_common", test_common), ("test_extra", test_extra)):
            if value is None:
                continue
            requires_grad = getattr(value, "requires_grad", None)
            grad_fn = getattr(value, "grad_fn", None)
            print_cust(f"aggregate_block_params_general, {label}: {value}")
            print_cust(f"aggregate_block_params_general, {label}.requires_grad: {requires_grad}")
            print_cust(f"aggregate_block_params_general, {label}.grad_fn: {grad_fn}")
            if grad_fn is not None:
                print_cust(f"aggregate_block_params_general, {label}.grad_fn.next_functions: {grad_fn.next_functions}")

    return aggregated_block

# TOADD: an identifier to specify if this should be torch or numpy.

"""#### Discriminator Aggregation Functions

##### Discriminator Aggregation Helper
"""

# import copy, torch
from collections import OrderedDict

def fedavg_disc(models, weights=None):
    """
    Average the state dicts of several models into a new global model.

    Args:
      models: list of nn.Module client models after local training. All must
          expose the same state-dict keys with same-shaped tensors.
      weights: optional list of numbers (e.g. sample counts), one per model.
          They are normalised to sum to 1. None gives a plain mean.

    Returns:
      A deep copy of models[0] loaded with the averaged state dict. Every entry
      is averaged as float32 on the CPU before being loaded.

    Raises:
      ValueError: if models is empty, or if weights is given with a length
          different from models.
    """
    import copy  # local import so this helper does not depend on a file-level one

    if len(models) == 0:
        raise ValueError("fedavg_disc requires at least one model.")
    if weights is not None and len(weights) != len(models):
        # A mismatched length would otherwise broadcast or fail obscurely below.
        raise ValueError(
            f"fedavg_disc got {len(weights)} weights for {len(models)} models."
        )

    global_model = copy.deepcopy(models[0])      # keep architecture
    global_dict = OrderedDict()
    print_cust(f"fedavg_disc, global_model: {global_model}")
    print_cust(f"fedavg_disc, models: {models}")
    print_cust(f"fedavg_disc, weights: {weights}")

    # Build every client's state dict once instead of once per key.
    client_states = [m.state_dict() for m in models]

    # 1. stack same-shaped tensors from every client and reduce over clients
    for key in global_model.state_dict().keys():
        stacked = torch.stack([state[key].float().cpu() for state in client_states], dim=0)
        print_cust(f"fedavg_disc, stacked: {stacked}")
        if weights is None:                          # unweighted
            global_dict[key] = stacked.mean(dim=0)
        else:                                        # weighted
            w = torch.tensor(weights, dtype=stacked.dtype)
            print_cust(f"fedavg_disc, w: {w}")
            w = w / w.sum()                          # normalise
            print_cust(f"fedavg_disc, after summing, w: {w}")
            # Reshape w to (num_models, 1, ..., 1) so it broadcasts over the
            # parameter dimensions of the stacked tensor.
            broadcast_shape = [1] * (stacked.dim() - 1)
            global_dict[key] = (stacked * w.view(-1, *broadcast_shape)).sum(dim=0)

        print_cust(f"fedavg_disc, key: {key}, global_dict[key]: {global_dict[key]}")

    # 2. load averaged parameters and return
    global_model.load_state_dict(global_dict, strict=True)
    print_cust(f"fedavg_disc, global_model after loading state dict, global_model: {global_model}")
    return global_model

"""##### Main Disc Aggregation Function"""

def aggregate_discriminator_params(client_params_dict, clients_data_dict):
    """
    Gather every client's discriminator and average them with fedavg_disc.

    For each client type, the parameter tuples in client_params_dict are paired
    (via zip) with the data tuples in clients_data_dict. The discriminator is
    read from params[5][1] (no type check is performed) and the client's weight
    is len(data[0][0]).

    Returns the aggregated model produced by fedavg_disc.
    """
    # One (discriminator, weight) pair per client, in iteration order.
    collected = [
        (params[5][1], len(data[0][0]))
        for client_type, clients_params in client_params_dict.items()
        for params, data in zip(clients_params, clients_data_dict[client_type])
    ]
    discriminator_models = [model for model, _ in collected]
    disc_weights = [n_samples for _, n_samples in collected]

    disc_model_agg = fedavg_disc(discriminator_models, disc_weights)
    print_cust(f"aggregate_discriminator_params, disc_model_agg: {disc_model_agg}")
    return disc_model_agg

# TODO: impl this; arithmetic avg.

"""#### Parallel aggregation function"""

def federated_averaging_circular_parallel(client_params_dict, clients_data_dict, n_output_qubits=1, generative=False, is_qcnn=True):
    """
    Chunkwise ("parallel") federated averaging over heterogeneous clients.

    Angle-valued parameters are combined with a weighted circular mean: each
    client contributes weight * exp(1j * angle), and the result is the angle of
    the weighted sum divided by the total weight. The bias is combined with a
    plain weighted mean. A client's weight is len(train_data[0]).

    Clients may have fewer convolution/pooling layers than the meta model; a
    client's layers are aligned to the *last* layers of the meta model
    (offset = meta_conv_layers - number of client layers).

    Args:
      client_params_dict: dict mapping client types to lists of parameter tuples
          (conv_params, pool_params, final_pool_param, final_params, bias_param, block_params_list).
          The largest key selects the reference client whose shapes size the accumulators.
      clients_data_dict: dict mapping client types to lists of data tuples
          (train_data, other); only len(train_data[0]) is used.
      n_output_qubits: forwarded to compute_conv_layers.
      generative: if True, torch is used for the final/bias accumulators and the
          discriminators are aggregated as well.
      is_qcnn: if False, torch is used for the final/bias accumulators.

    Returns:
      (meta_avg_conv, meta_avg_pool, avg_final_pool_param, avg_final_params,
       avg_bias, aggregated_block_params), where
        - meta_avg_conv is a tuple of (even, odd) pairs, one per meta layer,
        - meta_avg_pool is a tuple with one entry per meta layer,
        - aggregated_block_params is the value returned by
          aggregate_block_params_general, or, when generative is True, the list
          [that value, aggregated discriminator].
    """
    # Conv/pool accumulators always use np; only the final/bias/block parts
    # follow the selected backend.
    math_int = torch if (generative or not is_qcnn) else np
    print_cust(f"federated_averaging_circular_parallel, math_int: {math_int}")

    # The largest client type defines the meta layout and provides reference shapes.
    largest_type = max(client_params_dict.keys())
    meta_conv_layers = compute_conv_layers(largest_type, n_output_qubits, generative=generative, is_qcnn=is_qcnn)
    reference = client_params_dict[largest_type][0]

    # --- complex accumulators for the circular means ---
    ref_conv = reference[0]
    even_sums, odd_sums = [], []
    conv_weight_totals = [0.0] * meta_conv_layers
    for layer_idx in range(meta_conv_layers):
        ref_even, ref_odd = ref_conv[layer_idx][0], ref_conv[layer_idx][1]
        even_sums.append(np.zeros(ref_even.shape, dtype=np.complex128))
        odd_sums.append(np.zeros(ref_odd.shape, dtype=np.complex128))

    ref_pool = reference[1]
    pool_weight_totals = [0.0] * meta_conv_layers
    pool_sums = [
        np.zeros(ref_pool[layer_idx].shape, dtype=np.complex128)
        for layer_idx in range(meta_conv_layers)
    ]

    final_pool_sum = 0 + 0j
    if math_int is np:
        complex_dtype, real_dtype = np.complex128, np.float64
    else:
        complex_dtype, real_dtype = torch.complex64, torch.float32
    final_params_sum = math_int.zeros(reference[3].shape, dtype=complex_dtype)
    bias_sum = math_int.zeros(reference[4].shape, dtype=real_dtype)
    final_weight_total = 0.0

    # --- accumulate every client's contribution ---
    for client_type, type_params in client_params_dict.items():
        for params, data in zip(type_params, clients_data_dict[client_type]):
            train_split, _ = data
            n_samples = len(train_split[0])

            # Convolution layers, aligned to the tail of the meta layers.
            client_conv = params[0]
            conv_shift = meta_conv_layers - len(client_conv)
            for local_idx, conv_layer in enumerate(client_conv):
                slot = local_idx + conv_shift
                even_sums[slot] += n_samples * np.exp(1j * conv_layer[0])
                odd_sums[slot] += n_samples * np.exp(1j * conv_layer[1])
                conv_weight_totals[slot] += n_samples

            # Pooling layers, aligned the same way.
            client_pool = params[1]
            pool_shift = meta_conv_layers - len(client_pool)
            for local_idx, pool_angles in enumerate(client_pool):
                slot = local_idx + pool_shift
                pool_sums[slot] += n_samples * np.exp(1j * pool_angles)
                pool_weight_totals[slot] += n_samples

            # Final pooling angle, final rotation angles, and bias.
            final_pool_sum += n_samples * math_int.exp(1j * params[2][0])
            final_params_sum += n_samples * math_int.exp(1j * params[3])
            bias_sum += n_samples * params[4]
            final_weight_total += n_samples

    # --- turn the weighted sums back into angles ---
    meta_avg_conv = tuple(
        (
            np.angle(even_sums[k] / conv_weight_totals[k]),
            np.angle(odd_sums[k] / conv_weight_totals[k]),
        )
        for k in range(meta_conv_layers)
    )
    meta_avg_pool = tuple(
        np.angle(pool_sums[k] / pool_weight_totals[k])
        for k in range(meta_conv_layers)
    )

    make_container = math_int.array if math_int is np else math_int.tensor
    avg_final_pool_param = make_container([math_int.angle(final_pool_sum / final_weight_total)])
    avg_final_params = math_int.angle(final_params_sum / final_weight_total)
    avg_bias = bias_sum / final_weight_total

    # Block parameters are aggregated without gradient tracking. In the
    # generative case the discriminators are aggregated too, and the block
    # entry becomes a two-element list that downstream code must unpack.
    with torch.no_grad():
        aggregated_block_params = aggregate_block_params_general(client_params_dict, clients_data_dict, math_int=math_int)
        if generative:
            agg_disc = aggregate_discriminator_params(client_params_dict, clients_data_dict)
            aggregated_block_params = [aggregated_block_params, agg_disc]

    return (
        meta_avg_conv,
        meta_avg_pool,
        avg_final_pool_param,
        avg_final_params,
        avg_bias,
        aggregated_block_params,
    )

# TOADD: for agg block params specifically, pass in torch, and also, do it in a no_grad() context.
# NOTE: for gen/disc agg specifically, I don't implement either of these functions. I only implement the case where all clients have the exact same model architecture.

"""#### Parallel aggregation function, shared only"""

# TODO: edit fed_avg_circ_par_shared, just like I did for fed_avg_circ_parallel.

def federated_averaging_circular_parallel_shared(client_params_dict, clients_data_dict, generative=False):
    """
    Perform federated averaging only over those parameters that are common to all clients.

    The client parameters tuple is assumed to have the structure:
      (conv_params, pool_params, final_pool_param, final_params, bias_param, block_params_list)

    where:
      - conv_params is a sequence of convolution layers; each layer is indexed as (even, odd).
      - pool_params is a sequence of pooling layer parameters.
      - final_pool_param, final_params, and bias_param are the parameters used after all convolutions.
        final_params may be a single array or a 2-element list/tuple of arrays; the layout of the
        first client of the first client type decides which branch is used for every client.
      - block_params_list holds the block (HEA ansatz) parameters (and, in the generative case,
        the discriminator that aggregate_discriminator_params reads).

    Only the last min_conv_layers convolution/pooling layers of each client are aggregated, where
    min_conv_layers is the smallest number of convolution layers found on any client. For example,
    if some clients have 2 layers and others have 3, only the final two layers are common.

    Angles are aggregated with a weighted circular average: the weighted sum of exp(1j*angle) is
    divided by the total weight and its angle is taken. The bias uses a plain weighted mean.
    A client's weight is len(train_data[0]).

    For block parameters, the function calls aggregate_block_params_general and keeps only the
    first element (the "common" part) of each (common, extra) pair it returns.

    Args:
      client_params_dict: dict mapping client types (e.g. model sizes) to lists of parameter tuples.
      clients_data_dict: dict mapping client types to the corresponding list of data tuples;
         data[0] is the training split and len(data[0][0]) is used as the weight.
      generative: if True, the block parameters are aggregated with torch and the clients'
         discriminators are aggregated as well.

    Returns:
      A tuple of aggregated parameters for the shared structure:
         (shared_conv_params, shared_pool_params, shared_final_pool_param,
          shared_final_params, shared_bias, shared_block_params)
      where:
         - shared_conv_params is a tuple (length = min_conv_layers) of (even, odd) parameters.
         - shared_pool_params is a tuple (length = min_conv_layers).
         - shared_final_pool_param, shared_final_params, and shared_bias are the aggregated final parameters.
         - shared_block_params is a dict mapping each key returned by aggregate_block_params_general
           to its aggregated common block parameters; when generative is True it is instead the list
           [that dict, aggregated discriminator].
    """

    print_cust(f"federated_averaging_circular_parallel_shared, generative: {generative}")

    # 1. Determine the minimum number of convolution (QCNN) layers across all clients.
    min_conv_layers = None
    for client_type in client_params_dict:
        for params in client_params_dict[client_type]:
            client_conv_layers = len(params[0])
            if min_conv_layers is None or client_conv_layers < min_conv_layers:
                min_conv_layers = client_conv_layers
    # min_conv_layers is still None only if the loops above visited no client at all.
    # Both branches fall back to 0; they differ only in the message that is logged.
    if min_conv_layers is None and not generative:
      # NOTE, layers: the logic allows for NO convolutional layers and ONLY block params.
      min_conv_layers = 0
      print_cust(f"federated_averaging_circular_parallel_shared, min_conv_layers is None and not generative, min_conv_layers: {min_conv_layers}")
    elif min_conv_layers is None and generative:
      min_conv_layers = 0
      print_cust(f"federated_averaging_circular_parallel_shared, min_conv_layers is None and generative, min_conv_layers: {min_conv_layers}")

    # NOTE, layers (and should probably change): min_conv_layers == 0 is used as a "hack" for saying
    # we are in the non-QCNN regime, and thus doing model training in PyTorch (although these two
    # things should not need to be coupled with one another).
    # NOTE, layers (and should probably change): the interface module, dtypes and container
    # constructor chosen here should be injected rather than hard-coded.
    # All non-block parameters below are accumulated with the module-level `np`.
    math_int = np
    complex_dtype = np.complex128
    float_dtype = np.float64
    container_func = np.array

    print_cust(f"federated_averaging_circular_parallel_shared, math_int: {math_int}, complex_dtype: {complex_dtype}, float_dtype: {float_dtype}")

    print_cust(f"federated_averaging_circular_parallel_shared, min_conv_layers: {min_conv_layers}")

    # 2. Initialize accumulators for the shared QCNN layers (even and odd parts) and pooling layers.
    shared_conv_even_accum = [None] * min_conv_layers
    shared_conv_odd_accum = [None] * min_conv_layers
    total_weight_conv = [0.0] * min_conv_layers

    shared_pool_accum = [None] * min_conv_layers
    total_weight_pool = [0.0] * min_conv_layers

    # Use the first client that has exactly min_conv_layers layers as the shape reference for
    # the zero-initialised complex accumulators.
    ref_found = False
    for client_type in client_params_dict:
        for params in client_params_dict[client_type]:
            client_conv = params[0]
            if len(client_conv) == min_conv_layers:
                for i in range(min_conv_layers):
                    shared_conv_even_accum[i] = math_int.zeros(client_conv[i][0].shape, dtype=complex_dtype)
                    shared_conv_odd_accum[i] = math_int.zeros(client_conv[i][1].shape, dtype=complex_dtype)
                    shared_pool_accum[i] = math_int.zeros(params[1][i].shape, dtype=complex_dtype)
                ref_found = True
                break
        if ref_found:
            break

    # 3. Initialize accumulators for final pooling, final rotation, and bias.
    # any_client is the first parameter tuple of the first client type; its final_params and
    # bias entries provide the accumulator shapes (and decide the list/tuple vs. array branch).
    any_client = next(iter(next(iter(client_params_dict.values()))))
    accum_final_pool = 0 + 0j
    if isinstance(any_client[3], (list, tuple)):
      accum_final_params = [math_int.zeros(any_client[3][0].shape, dtype=complex_dtype), math_int.zeros(any_client[3][1].shape, dtype=complex_dtype)]
    else:
      accum_final_params = math_int.zeros(any_client[3].shape, dtype=complex_dtype)
    accum_bias = math_int.zeros(any_client[4].shape, dtype=float_dtype)
    total_weight_final = 0.0
    print_cust(f"federated_averaging_circular_parallel_shared, accum_final_params: {accum_final_params}")
    if isinstance(accum_final_params, (list, tuple)):
      print_cust(f"federated_averaging_circular_parallel_shared, accum_final_params[0]: {accum_final_params[0]}, accum_final_params[1]: {accum_final_params[1]}, type(accum_final_params[0]): {type(accum_final_params[0])}, type(accum_final_params[1]): {type(accum_final_params[1])}")

    # 4. Loop over all clients and accumulate only the shared parts.
    # For each client we extract the last min_conv_layers layers.
    for client_type in client_params_dict:
        print_cust(f"federated_averaging_circular_parallel_shared, client_type: {client_type}")
        for params, data in zip(client_params_dict[client_type], clients_data_dict[client_type]):
            # Use the number of training samples as the weight.
            train_data = data[0]
            weight = len(train_data[0])
            print_cust(f"federated_averaging_circular_parallel_shared, weight: {weight}")
            client_conv = params[0]
            client_pool = params[1]
            client_layers = len(client_conv)
            # Skip the client's leading layers that are not shared by every client.
            offset_conv = client_layers - min_conv_layers
            for i in range(min_conv_layers):
                conv_layer = client_conv[offset_conv + i]
                even_angles = conv_layer[0]
                odd_angles = conv_layer[1]
                shared_conv_even_accum[i] += weight * math_int.exp(1j * even_angles)
                shared_conv_odd_accum[i] += weight * math_int.exp(1j * odd_angles)
                total_weight_conv[i] += weight
            client_pool_layers = len(client_pool)
            offset_pool = client_pool_layers - min_conv_layers
            for i in range(min_conv_layers):
                pool_layer = client_pool[offset_pool + i]
                shared_pool_accum[i] += weight * math_int.exp(1j * pool_layer)
                total_weight_pool[i] += weight

            # Accumulate final pooling, rotation parameters, and bias.
            final_pool_param = params[2]
            accum_final_pool += weight * math_int.exp(1j * final_pool_param[0])
            final_params = params[3]
            print_cust(f"federated_averaging_circular_parallel_shared, any_client[3]: {any_client[3]}")
            # The branch is chosen from the reference client's layout, not from this client's.
            if isinstance(any_client[3], (list, tuple)):
              print_cust(f"federated_averaging_circular_parallel_shared, final_params: {final_params}")
              print_cust(f"federated_averaging_circular_parallel_shared, type(final_params[0]): {type(final_params[0])}, type(final_params[1]): {type(final_params[1])}")
              accum_final_params[0] += weight * math_int.exp(1j * final_params[0])
              accum_final_params[1] += weight * math_int.exp(1j * final_params[1])
              print_cust(f"federated_averaging_circular_parallel_shared, type(accum_final_params[0]): {type(accum_final_params[0])}, type(accum_final_params[1]): {type(accum_final_params[1])}")
            else:
              accum_final_params += weight * math_int.exp(1j * final_params)
            bias_param = params[4]
            accum_bias += weight * bias_param
            total_weight_final += weight

    # 5. Compute the weighted circular averages.
    print_cust(f"federated_averaging_circular_parallel_shared, total_weight_conv: {total_weight_conv}, total_weight_final: {total_weight_final}")
    shared_conv_avg = []
    for i in range(min_conv_layers):
        avg_even = math_int.angle(shared_conv_even_accum[i] / total_weight_conv[i])
        avg_odd = math_int.angle(shared_conv_odd_accum[i] / total_weight_conv[i])
        shared_conv_avg.append((avg_even, avg_odd))
    shared_conv_avg = tuple(shared_conv_avg)

    shared_pool_avg = []
    for i in range(min_conv_layers):
        avg_pool = math_int.angle(shared_pool_accum[i] / total_weight_pool[i])
        shared_pool_avg.append(avg_pool)
    shared_pool_avg = tuple(shared_pool_avg)

    shared_final_pool_param = container_func([math_int.angle(accum_final_pool / total_weight_final)])
    if isinstance(any_client[3], (list, tuple)):
      print_cust(f"federated_averaging_circular_parallel_shared, type(accum_final_params[0]): {type(accum_final_params[0])}, type(accum_final_params[1]): {type(accum_final_params[1])}, type(total_weight_final): {type(total_weight_final)}")
      shared_final_params = (math_int.angle(accum_final_params[0] / total_weight_final), math_int.angle(accum_final_params[1] / total_weight_final))
    else:
      shared_final_params = math_int.angle(accum_final_params / total_weight_final)
    # The bias is not an angle, so it uses a plain weighted mean.
    shared_bias = accum_bias / total_weight_final

    # 6. Aggregate the block parameters.
    # NOTE: in the generative case, aggregate_block_params_general is expected to aggregate the
    # generator params only (per the original author's note; verify against that function).
    # The interface switches to torch here for the generative / zero-conv-layer regime.
    # NOTE(review): after this point only math_int is read; the reassigned complex_dtype,
    # float_dtype and container_func are not used again in this function.
    if generative or min_conv_layers == 0:
      math_int = torch
      complex_dtype = torch.complex64
      float_dtype = torch.float32
      container_func = torch.tensor

    print_cust(f"federated_averaging_circular_parallel_shared, right before block params agg, math_int: {math_int}")

    with torch.no_grad():
      aggregated_block_params = aggregate_block_params_general(client_params_dict, clients_data_dict, math_int=math_int)
      # Keep only the common part of each (common, extra) pair.
      shared_block_params = {}
      for block_type, (aggregated_common, aggregated_extra) in aggregated_block_params.items():
          shared_block_params[block_type] = aggregated_common
      if generative:
        # Assumed, if generative, that (1) block_params contains a discriminator, and (2) there exists at least one discriminator.
        # ALSO, this assumes that if a client is in the training, then its disc params need to be aggregated.
        agg_disc = aggregate_discriminator_params(client_params_dict, clients_data_dict)
        # In the generative case the block entry becomes [block dict, discriminator], so downstream
        # code has to handle a different structure depending on `generative`.
        shared_block_params = [shared_block_params, agg_disc]

    print_cust(f"federated_averaging_circular_parallel_shared, shared_block_params: {shared_block_params}")

    # 7. Return the aggregated shared parameters in the same 6-tuple structure.
    aggregated_shared_params = (
        shared_conv_avg,
        shared_pool_avg,
        shared_final_pool_param,
        shared_final_params,
        shared_bias,
        shared_block_params
    )
    return aggregated_shared_params

# REPURPOSE: add a line to the above function -- if generative, it is both expected and OK to have zero convolutional layers (min_conv_layers = 0).

# DONE: TOMODIFY, layers: it is perfectly OK to have 0 conv layers.
# DONE: TOMODIFY, layers: add a line, if zero conv layers, set math_int to be torch.
# DONE: TOMODIFY, layers: log math_int when I set/change it.

# map the above to model objects

"""## Broadcast parameters function

#### Helper Function, Convert Block Params to Generator
"""

def convert_agg_block_params_generator(aggregated_block_params):
    """
    Replace the generator block-parameter dict in aggregated_block_params with a model object.

    aggregated_block_params is expected to be a mutable two-element sequence
    [generator_block_params_dict, discriminator]. The blocks of the dict are
    collected in sorted key order (for a tuple value only its first element is
    used), a PatchQuantumGenerator is built from their shapes and initialised
    with them, and it is written back to index 0.

    NOTE: the argument is mutated in place and also returned. The second
    element is asserted to be a PCADiscriminator.
    """
    generator_block_params_dict = aggregated_block_params[0]

    # Collect one aggregated block per block type, in a deterministic order.
    block_list = []
    if isinstance(generator_block_params_dict, dict):
        for block_type in sorted(generator_block_params_dict):
            entry = generator_block_params_dict[block_type]
            # A tuple is treated as (common, extra); only the first element is kept.
            block_list.append(entry[0] if isinstance(entry, tuple) else entry)

    # debug
    for block in block_list:
        print_cust(f"convert_agg_block_params_generator, type(block): {type(block)}")
        print_cust(f"convert_agg_block_params_generator, block.shape: {block.shape}")

    # shape[1] is used as the qubit count and shape[0] as the depth of a block.
    qubit_counts = [block.shape[1] for block in block_list]
    n_qubits_gen_agg = max(qubit_counts)

    # NOTE: create_qnode_qgan / PatchQuantumGenerator are module-level names; a fresh
    # generator is built here rather than copying weights into an existing one.
    qgan_qnode_agg = create_qnode_qgan(n_qubits_gen_agg)
    # Each block's qubit count is assumed to be unique; a duplicate overwrites the earlier entry.
    qubits_depth_dict_agg = {block.shape[1]: block.shape[0] for block in block_list}
    target_device = "cpu"
    aggregated_generator = PatchQuantumGenerator(
        1, 1.0, n_qubits_gen_agg, 0, qgan_qnode_agg, qubits_depth_dict_agg, target_device, 0, True, None
    )
    aggregated_generator.initialize_existing_parameters(block_list)

    aggregated_block_params[0] = aggregated_generator
    assert isinstance(aggregated_block_params[1], PCADiscriminator), f"broadcast_param_updates, when agg_block_params is NOT a dict, type(aggregated_block_params[1]): {type(aggregated_block_params[1])}"

    return aggregated_block_params

"""### Broadcast parameters function, all aggregate"""

def broadcast_param_updates(client_params_dict, aggregated_params, math_int=np):
    """
    Update the clients' parameters in client_params_dict using aggregated_params.

    Args:
      client_params_dict: dict mapping client types (e.g. model sizes) to lists of parameter tuples
          (conv_params, pool_params, final_pool_param, final_params, bias_param, block_params_list).
      aggregated_params: 6-tuple
          (meta_avg_conv, meta_avg_pool, avg_final_pool_param, avg_final_params, avg_bias,
           aggregated_block_params) where
          - meta_avg_conv: sequence of length meta_conv_layers; each element is a tuple (even, odd)
          - meta_avg_pool: sequence of length meta_conv_layers; each element is an array
          - avg_final_pool_param, avg_bias: arrays, copied per client with np.array
          - avg_final_params: an array (copied with np.array) or a tuple of arrays (re-wrapped
            in a new tuple; the arrays themselves are shared)
          - aggregated_block_params: either a dict mapping (num_qubits, bp_idx) to a tuple
            (aggregated_common, aggregated_extra), or a non-dict [generator block dict,
            discriminator] sequence, which is first converted IN PLACE by
            convert_agg_block_params_generator into [generator model, discriminator model].
      math_int: retained for backward compatibility. How a block is copied is decided from
          the block's own type (torch tensors use .detach().clone(), everything else .copy()),
          so a mismatched or unexpected math_int no longer breaks the update.

    Convolution and pooling layers: a client with fewer layers than the meta model receives the
    LAST layers of the meta parameters (offset = meta_conv_layers - number of client layers).
    NOTE(review): these layers are handed out by reference, so all clients share the same
    array objects for them - confirm nothing downstream mutates them in place.

    Block parameters (non-generator case): for each block bp of shape
    (client_depth, num_qubits, 3) at position bp_idx with an entry under (num_qubits, bp_idx):
      - the last common_depth layers of bp are replaced by aggregated_common;
      - if bp has extra layers (client_depth > common_depth) and aggregated_extra is not None,
        the first client_extra layers are replaced by the last client_extra layers of
        aggregated_extra (offset = meta_extra - client_extra).
    Blocks without an entry are kept unchanged. A client with an empty block_params_list is
    supported and simply keeps an empty list.

    Generator case (block_params_list[0] is a PatchQuantumGenerator): the state dicts of the
    aggregated generator and discriminator are loaded into the client's existing models.

    The function updates client_params_dict in place and returns it.
    """

    meta_avg_conv, meta_avg_pool, avg_final_pool_param, avg_final_params, avg_bias, aggregated_block_params = aggregated_params
    meta_conv_layers = len(meta_avg_conv)
    if not isinstance(aggregated_block_params, dict):
        aggregated_block_params = convert_agg_block_params_generator(aggregated_block_params)
        print_cust(f"broadcast_param_updates, aggregated_block_params: {aggregated_block_params}")

    # Iterate over each client type and update its clients.
    for client_type in client_params_dict:
        for i, params in enumerate(client_params_dict[client_type]):
            conv_params, pool_params, final_pool_param, final_params, bias_param, block_params_list = params

            # --- Update convolution parameters ---
            # Client layer l is replaced with meta layer l + offset_conv.
            client_conv_layers = len(conv_params)
            offset_conv = meta_conv_layers - client_conv_layers
            new_conv_params = [meta_avg_conv[l + offset_conv] for l in range(client_conv_layers)]

            # --- Update pooling parameters ---
            client_pool_layers = len(pool_params)
            offset_pool = meta_conv_layers - client_pool_layers
            new_pool_params = [meta_avg_pool[l + offset_pool] for l in range(client_pool_layers)]

            # --- Update final pooling, final rotation, and bias ---
            # np.array(...) gives every client its own copy of these values.
            new_final_pool_param = np.array(avg_final_pool_param)
            if isinstance(avg_final_params, tuple):
                new_final_params = tuple(avg_final_params)
            else:
                new_final_params = np.array(avg_final_params)
            new_bias_param = np.array(avg_bias)

            # --- Update block (HEA ansatz) parameters ---
            # An empty block list cannot be the generator case; checking the length first
            # avoids indexing into an empty list.
            is_generator_case = len(block_params_list) > 0 and isinstance(block_params_list[0], PatchQuantumGenerator)
            if not is_generator_case:
                new_block_params_list = []
                with torch.no_grad():
                    for bp_idx, bp in enumerate(block_params_list):
                        # bp has shape (client_depth, num_qubits, 3)
                        client_depth, num_qubits, _ = bp.shape
                        block_key = (num_qubits, bp_idx)
                        if block_key not in aggregated_block_params:
                            # No aggregated update for this block; leave unchanged.
                            print_cust(f"broadcast_param_updates, no update for bp bp_idx: {bp_idx}, client_depth: {client_depth}, num_qubits: {num_qubits}")
                            new_block_params_list.append(bp)
                            continue

                        aggregated_common, aggregated_extra = aggregated_block_params[block_key]
                        # Minimum depth shared among all clients for this block.
                        common_depth = aggregated_common.shape[0]

                        # Work on a copy so the client's original block is not modified.
                        # torch tensors have no .copy(); they need .detach().clone().
                        if isinstance(bp, torch.Tensor):
                            print_cust("broadcast_param_updates, applying .detach().clone() to bp")
                            new_bp = bp.detach().clone()
                        else:
                            new_bp = bp.copy()

                        # Update common part: last common_depth layers of bp.
                        # NOTE: in general the FIRST common layers would be the natural target, but
                        # because clients of the same type share the same layer structure it
                        # does not matter here.
                        new_bp[client_depth - common_depth : client_depth, :, :] = aggregated_common

                        # Update extra part if bp has extra layers and aggregated_extra is available.
                        client_extra = client_depth - common_depth
                        if client_extra > 0 and (aggregated_extra is not None):
                            meta_extra = aggregated_extra.shape[0]
                            # Align on the end of the extra region: offset = meta_extra - client_extra.
                            offset_extra = meta_extra - client_extra
                            new_bp[0:client_extra, :, :] = aggregated_extra[offset_extra : offset_extra + client_extra]
                        new_block_params_list.append(new_bp)
            else:
                # This is the generator case: load the aggregated weights into the
                # client's existing generator and discriminator models.
                existing_generator = block_params_list[0]
                existing_discriminator = block_params_list[1]
                with torch.no_grad():
                    existing_generator.load_state_dict(aggregated_block_params[0].state_dict())
                    print_cust(f"broadcast_param_updates, existing_generator: {existing_generator}")
                    existing_discriminator.load_state_dict(aggregated_block_params[1].state_dict())
                print_cust(f"broadcast_param_updates, existing_discriminator: {existing_discriminator}")
                new_block_params_list = block_params_list

            # Update the client parameter tuple.
            client_params_dict[client_type][i] = (tuple(new_conv_params),
                                                    tuple(new_pool_params),
                                                    new_final_pool_param,
                                                    new_final_params,
                                                    new_bias_param,
                                                    new_block_params_list)
    return client_params_dict

# REPURPOSE: use model objects in the above code.

# DONE: TOMODIFY, layers: when setting the block parameters in the above code, do it in a torch.no_grad() context.
# DONE: TOMODIFY, layers: copy() doesn't fly in the pytorch tensor regime; I need to do .detach().clone(). but make this work with numpy (configurable)

"""### Broadcast parameters function, aggregate shared only"""

def broadcast_param_updates_shared(client_params_dict, aggregated_params, math_int=np):
    """
    Update each client's parameters in client_params_dict using only the shared (common) aggregated parameters.

    Args:
      client_params_dict: dict mapping client types to lists of parameter tuples. Each parameter tuple has the form:
          (conv_params, pool_params, final_pool_param, final_params, bias_param, block_params_list)
          where:
            conv_params: tuple of convolution parameters (each element is a tuple (even, odd))
            pool_params: tuple of pooling parameters (arrays)
            final_pool_param: array for final pooling
            final_params: array for final rotation
            bias_param: array (or scalar) for bias
            block_params_list: either a list of block (HEA ansatz) parameter arrays, or
                (generator case) a list whose first entry is a PatchQuantumGenerator and
                whose second entry is the companion module (logged as the discriminator).
      aggregated_params: aggregated parameters with the structure
          (shared_conv_avg, shared_pool_avg, shared_final_pool_param,
           shared_final_params, shared_bias, shared_block_params).
          If shared_block_params is not a dict it is first passed through
          convert_agg_block_params_generator. For array blocks it is looked up with
          keys (block_type, bp_idx), where block_type is bp.shape[1]; in the generator
          case entries 0 and 1 must expose state_dict().
      math_int: numeric backend of the block arrays. `np` copies with .copy(),
          `torch` copies with .detach().clone().

    Returns:
      The same client_params_dict object, with each client's tuple replaced by one in which
      only the shared (common) portions were overwritten by the aggregated ones.

    Raises:
      ValueError: if a block array has to be updated and math_int is neither np nor torch.
    """

    # Unpack aggregated shared parameters.
    (shared_conv_avg, shared_pool_avg, shared_final_pool_param,
     shared_final_params, shared_bias, shared_block_params) = aggregated_params

    if not isinstance(shared_block_params, dict):
        shared_block_params = convert_agg_block_params_generator(shared_block_params)
        print_cust(f"broadcast_param_updates_shared, shared_block_params: {shared_block_params}")

    # Number of shared convolution and pooling layers (these are the "common" layers to update).
    num_shared_conv = len(shared_conv_avg)
    num_shared_pool = len(shared_pool_avg)

    print_cust(f"broadcast_param_updates_shared, num_shared_conv: {num_shared_conv}, num_shared_pool: {num_shared_pool}")

    # Loop over each client type and each client.
    for client_type in client_params_dict:
        print_cust(f"broadcast_param_updates_shared, client_type: {client_type}")
        for i, params in enumerate(client_params_dict[client_type]):
            print_cust(f"broadcast_param_updates_shared, i: {i}")
            conv_params, pool_params, _final_pool_param, _final_params, _bias_param, block_params_list = params

            # --- Update QCNN convolution parameters ---
            # Only the LAST num_shared_conv layers are replaced; the aggregated (even, odd)
            # tuples are stored by reference, so every client ends up holding the same objects.
            offset_conv = len(conv_params) - num_shared_conv
            print_cust(f"broadcast_param_updates_shared, offset_conv: {offset_conv}")
            new_conv_params = list(conv_params)  # convert to list for modification
            for j in range(num_shared_conv):
                new_conv_params[offset_conv + j] = shared_conv_avg[j]

            # --- Update pooling parameters (same tail-replacement scheme) ---
            offset_pool = len(pool_params) - num_shared_pool
            print_cust(f"broadcast_param_updates_shared, offset_pool: {offset_pool}")
            new_pool_params = list(pool_params)
            for j in range(num_shared_pool):
                new_pool_params[offset_pool + j] = shared_pool_avg[j]

            # --- Update final pooling, rotation, and bias parameters ---
            # These parameters are assumed to be shared among all clients. They are rebuilt
            # inside the loop so that each client receives its own array object.
            new_final_pool_param = np.array(shared_final_pool_param)
            if isinstance(shared_final_params, tuple):
                new_final_params = tuple(shared_final_params)
            else:
                new_final_params = np.array(shared_final_params)
            new_bias_param = np.array(shared_bias) if isinstance(shared_bias, np.ndarray) else shared_bias

            # --- Update block (HEA ansatz) parameters ---
            # An empty block list is treated as the plain-array case (nothing to update)
            # instead of failing on block_params_list[0].
            is_generator_case = (
                len(block_params_list) > 0
                and isinstance(block_params_list[0], PatchQuantumGenerator)
            )
            if not is_generator_case:
                new_block_params_list = []
                with torch.no_grad():
                    for bp_idx, bp in enumerate(block_params_list):
                        # Determine the block type by number of qubits (assumed bp.shape = (depth, n_qubits, 3)).
                        block_type = bp.shape[1]
                        if (block_type, bp_idx) in shared_block_params:
                            aggregated_common = shared_block_params[(block_type, bp_idx)]  # shape: (common_depth, block_type, 3)
                            common_depth = aggregated_common.shape[0]
                            print_cust(f"broadcast_param_updates_shared, common_depth: {common_depth}")
                            client_depth = bp.shape[0]
                            # Work on a copy so the client's previous array is left untouched.
                            if math_int is np:
                                bp_updated = bp.copy()
                            elif math_int is torch:
                                print_cust(f"broadcast_param_updates_shared, applying .detach().clone() to bp")
                                bp_updated = bp.detach().clone()
                            else:
                                # Previously this fell through with bp_updated unbound (or stale
                                # from the previous block); fail loudly instead.
                                raise ValueError(
                                    f"broadcast_param_updates_shared, unsupported math_int: {math_int!r}"
                                )
                            # Update the last 'common_depth' layers.
                            bp_updated[client_depth - common_depth : client_depth, :, :] = aggregated_common
                            new_block_params_list.append(bp_updated)
                        else:
                            # If no update for this block type is provided, leave it unchanged.
                            print_cust(f"broadcast_param_updates_shared, no update for bp_idx: {bp_idx}, client_type: {client_type}, i: {i}")
                            new_block_params_list.append(bp)
            else:
                # This is the generator case: the existing modules are updated in place
                # through load_state_dict and the original list object is reused.
                existing_generator = block_params_list[0]
                existing_discriminator = block_params_list[1]
                with torch.no_grad():
                    existing_generator.load_state_dict(shared_block_params[0].state_dict())
                    print_cust(f"broadcast_param_updates_shared, existing_generator: {existing_generator}")
                    existing_discriminator.load_state_dict(shared_block_params[1].state_dict())
                print_cust(f"broadcast_param_updates_shared, existing_discriminator: {existing_discriminator}")
                new_block_params_list = block_params_list

            # Assemble the new parameter tuple and store it in place of the old one.
            client_params_dict[client_type][i] = (
                tuple(new_conv_params),
                tuple(new_pool_params),
                new_final_pool_param,
                new_final_params,
                new_bias_param,
                new_block_params_list,
            )

    return client_params_dict

# DONE: TOMODIFY, layers: same thing. do the block parameters aggregation in a torch.no_grad() context.
# DONE: TOMODIFY, layers: copy() doesn't work on PyTorch tensors; use .detach().clone() there, but keep the NumPy path working too (configurable).

"""### Broadcast parameters function, aggregated shared only, no convolutional layer aggregation"""

def broadcast_param_updates_shared_noconv(client_params_dict, aggregated_params):
    """
    Update each client's parameters with the shared aggregated ones, leaving convolution parameters alone.

    Every client keeps its own conv_params. The last len(shared_pool_avg) pooling layers,
    the final pooling parameter, the final rotation parameters, the bias and the trailing
    layers of matching block arrays are overwritten by the aggregated values.

    Args:
      client_params_dict: dict mapping client types to lists of parameter tuples. Each parameter tuple has the form:
          (conv_params, pool_params, final_pool_param, final_params, bias_param, block_params_list)
          where:
            conv_params: tuple of convolution parameters (each element is a tuple (even, odd))
            pool_params: tuple of pooling parameters (arrays)
            final_pool_param: array for final pooling
            final_params: array for final rotation
            bias_param: array (or scalar) for bias
            block_params_list: list of block (HEA ansatz) parameter arrays.
      aggregated_params: 6-tuple
          (shared_conv_avg, shared_pool_avg, shared_final_pool_param,
           shared_final_params, shared_bias, shared_block_params).
          shared_conv_avg is ignored (it may be None or empty). shared_block_params is
          looked up by block type, i.e. by bp.shape[1].

    Returns:
      The same client_params_dict object, with each client's tuple replaced by the updated one.
    """
    # NOTE, layers: I shouldn't need to use this function for the layer-heterogeneous case, so I'm not going to
    # modify it for now.

    # The conv slot is unpacked only to keep the 6-tuple layout; it is never read, so the
    # function no longer requires it to be a sized container.
    (_shared_conv_avg, shared_pool_avg, shared_final_pool_param,
     shared_final_params, shared_bias, shared_block_params) = aggregated_params

    # Number of shared pooling layers (these are the "common" layers to update).
    num_shared_pool = len(shared_pool_avg)

    # Loop over each client type and each client.
    for client_type in client_params_dict:
        for i, params in enumerate(client_params_dict[client_type]):
            conv_params, pool_params, _final_pool_param, _final_params, _bias_param, block_params_list = params

            # --- Pooling parameters: replace only the last num_shared_pool layers ---
            new_pool_params = list(pool_params)
            offset_pool = len(new_pool_params) - num_shared_pool
            for j in range(num_shared_pool):
                new_pool_params[offset_pool + j] = shared_pool_avg[j]

            # --- Final pooling, rotation, and bias parameters ---
            # These parameters are assumed to be shared among all clients; they are rebuilt
            # per client so that each client owns its own array object.
            new_final_pool_param = np.array(shared_final_pool_param)
            if isinstance(shared_final_params, tuple):
                new_final_params = tuple(shared_final_params)
            else:
                new_final_params = np.array(shared_final_params)
            new_bias_param = np.array(shared_bias) if isinstance(shared_bias, np.ndarray) else shared_bias

            # --- Block (HEA ansatz) parameters ---
            new_block_params_list = []
            for bp in block_params_list:
                # Determine the block type by number of qubits (assumed bp.shape = (depth, n_qubits, 3)).
                block_type = bp.shape[1]
                if block_type in shared_block_params:
                    aggregated_common = shared_block_params[block_type]  # shape: (common_depth, block_type, 3)
                    common_depth = aggregated_common.shape[0]
                    client_depth = bp.shape[0]
                    # Copy first so the client's previous array is not modified.
                    bp_updated = bp.copy()
                    # Update the last 'common_depth' layers.
                    bp_updated[client_depth - common_depth : client_depth, :, :] = aggregated_common
                    new_block_params_list.append(bp_updated)
                else:
                    # If no update for this block type is provided, leave it unchanged.
                    new_block_params_list.append(bp)

            # Assemble the new parameter tuple and store it in place of the old one.
            client_params_dict[client_type][i] = (
                tuple(conv_params),
                tuple(new_pool_params),
                new_final_pool_param,
                new_final_params,
                new_bias_param,
                new_block_params_list,
            )

    return client_params_dict

"""### Broadcast parameters function, aggregated shared only, no convolutional layer aggregation, no final params aggregation"""

def broadcast_param_updates_shared_noconv_nofinal(client_params_dict, aggregated_params):
    """
    Update each client's parameters with the shared aggregated ones, leaving convolution and final rotation parameters alone.

    Every client keeps its own conv_params and final_params. The last len(shared_pool_avg)
    pooling layers, the final pooling parameter, the bias and the trailing layers of
    matching block arrays are overwritten by the aggregated values.

    Args:
      client_params_dict: dict mapping client types to lists of parameter tuples. Each parameter tuple has the form:
          (conv_params, pool_params, final_pool_param, final_params, bias_param, block_params_list)
          where:
            conv_params: tuple of convolution parameters (each element is a tuple (even, odd))
            pool_params: tuple of pooling parameters (arrays)
            final_pool_param: array for final pooling
            final_params: array for final rotation
            bias_param: array (or scalar) for bias
            block_params_list: list of block (HEA ansatz) parameter arrays.
      aggregated_params: 6-tuple
          (shared_conv_avg, shared_pool_avg, shared_final_pool_param,
           shared_final_params, shared_bias, shared_block_params).
          shared_conv_avg and shared_final_params are ignored (they may be None).
          shared_block_params is looked up by block type, i.e. by bp.shape[1].

    Returns:
      The same client_params_dict object, with each client's tuple replaced by the updated one.
    """
    # NOTE, layers: I shouldn't need to use this function for the layer-heterogeneous case, so I'm not going to
    # modify it for now.

    # The conv and final-rotation slots are unpacked only to keep the 6-tuple layout; they
    # are never read, so the function no longer requires the conv slot to be a sized container.
    (_shared_conv_avg, shared_pool_avg, shared_final_pool_param,
     _shared_final_params, shared_bias, shared_block_params) = aggregated_params

    # Number of shared pooling layers (these are the "common" layers to update).
    num_shared_pool = len(shared_pool_avg)

    # Loop over each client type and each client.
    for client_type in client_params_dict:
        for i, params in enumerate(client_params_dict[client_type]):
            conv_params, pool_params, _final_pool_param, final_params, _bias_param, block_params_list = params

            # --- Pooling parameters: replace only the last num_shared_pool layers ---
            new_pool_params = list(pool_params)
            offset_pool = len(new_pool_params) - num_shared_pool
            for j in range(num_shared_pool):
                new_pool_params[offset_pool + j] = shared_pool_avg[j]

            # --- Final pooling and bias parameters ---
            # These parameters are assumed to be shared among all clients; they are rebuilt
            # per client so that each client owns its own array object.
            new_final_pool_param = np.array(shared_final_pool_param)
            new_bias_param = np.array(shared_bias) if isinstance(shared_bias, np.ndarray) else shared_bias

            # --- Block (HEA ansatz) parameters ---
            new_block_params_list = []
            for bp in block_params_list:
                # Determine the block type by number of qubits (assumed bp.shape = (depth, n_qubits, 3)).
                block_type = bp.shape[1]
                if block_type in shared_block_params:
                    aggregated_common = shared_block_params[block_type]  # shape: (common_depth, block_type, 3)
                    common_depth = aggregated_common.shape[0]
                    client_depth = bp.shape[0]
                    # Copy first so the client's previous array is not modified.
                    bp_updated = bp.copy()
                    # Update the last 'common_depth' layers.
                    bp_updated[client_depth - common_depth : client_depth, :, :] = aggregated_common
                    new_block_params_list.append(bp_updated)
                else:
                    # If no update for this block type is provided, leave it unchanged.
                    new_block_params_list.append(bp)

            # Assemble the new parameter tuple (the client's own final_params object is
            # carried over as-is) and store it in place of the old one.
            client_params_dict[client_type][i] = (
                tuple(conv_params),
                tuple(new_pool_params),
                new_final_pool_param,
                final_params,
                new_bias_param,
                new_block_params_list,
            )

    return client_params_dict

"""## Gradient masking function"""

def mask_gradients(params):
    """
    Build a 0/1 gradient mask with the same structure as a client's parameter tuple.

    The mask is 1 for the first convolution layer (even and odd parts), the first pooling
    layer, and every block array whose second dimension equals n_first (twice the number
    of rows in the first layer's even convolution parameters). It is 0 everywhere else,
    including the final pooling, final rotation and bias parameters.

    Parameters:
      params: tuple (conv_params, pool_params, final_pool_param, final_params,
              bias_param, block_params_list) for one client.

    Returns:
      A tuple (conv_mask, pool_mask, final_pool_mask, final_params_mask, bias_mask, block_mask)
      where conv_mask and pool_mask are tuples, block_mask is a list, and every array has
      the shape of the parameter it masks.
    """
    conv_layers, pool_layers, final_pool, final_rot, bias, block_list = params

    def _mask(shape, keep):
        # Ones let the gradient through, zeros suppress it.
        return np.ones(shape) if keep else np.zeros(shape)

    # Convolution: only layer 0 is kept.
    conv_mask = []
    for idx, layer in enumerate(conv_layers):
        even_shape = layer[0].shape
        odd_shape = layer[1].shape
        is_first = idx == 0
        conv_mask.append((_mask(even_shape, is_first), _mask(odd_shape, is_first)))

    # Pooling: only layer 0 is kept.
    pool_mask = [_mask(pool.shape, idx == 0) for idx, pool in enumerate(pool_layers)]

    # Final pooling, final rotation and bias are always masked out.
    final_pool_mask = np.zeros(final_pool.shape)
    final_rot_mask = np.zeros(final_rot.shape)
    bias_mask = np.zeros(bias.shape)

    # Qubit count of the first convolution layer, approximated as 2 * (number of even pairs).
    n_first = conv_layers[0][0].shape[0] * 2

    # Blocks (shape (num_layers, num_qubits, 3)) are kept only when they span n_first qubits.
    block_mask = [_mask(bp.shape, bp.shape[1] == n_first) for bp in block_list]

    return (
        tuple(conv_mask),
        tuple(pool_mask),
        final_pool_mask,
        final_rot_mask,
        bias_mask,
        block_mask,
    )

"""## Meta-params initialization"""

def generate_meta_params(client_params_dict, clients_data_dict, math_int=np):
    """
    Build the meta parameter tuple covering the parameters shared across clients.

    The convolutional, pooling, final pooling, final rotation and bias entries are taken
    unchanged from the first client of the largest model type (the maximum key of
    client_params_dict). The block entry is whatever aggregate_block_params_general
    returns for all clients; it is iterated here as a mapping from block type to an
    (aggregated_common, aggregated_extra) pair.

    Parameters:
      client_params_dict: dict mapping client types (e.g., model sizes) to lists of parameter tuples.
          Each tuple is (conv_params, pool_params, final_pool_param, final_params, bias_param, block_params_list).
      clients_data_dict: dict mapping client types to lists of client data tuples;
          forwarded as-is to aggregate_block_params_general.
      math_int: numeric backend module, forwarded to aggregate_block_params_general.

    Returns:
        A tuple (meta_conv, meta_pool, meta_final_pool, meta_final, meta_bias, meta_block).
        When the reference client's first block entry is a PatchQuantumGenerator,
        meta_block is the two-element list [aggregated_block_dict, reference_client_blocks[1]]
        (the second entry is presumably the discriminator; this path assumes a
        homogeneous architecture among all clients).
    """
    print_cust(f"generate_meta_params, math_int: {math_int}")

    # Reference client: first client of the largest model type (assumes comparable keys).
    largest_type = max(client_params_dict)
    reference_params = client_params_dict[largest_type][0]
    (meta_conv, meta_pool, meta_final_pool,
     meta_final, meta_bias, meta_block_orig) = reference_params

    # Aggregate block parameters from all clients.
    meta_block = aggregate_block_params_general(client_params_dict, clients_data_dict, math_int=math_int)
    print_cust(f"generate_meta_params, meta_block: {meta_block}")
    for block_type, block_pair in meta_block.items():
        # NOTE: now, block_type is a tuple. This loop only logs the entry types.
        aggregated_common, aggregated_extra = block_pair
        print_cust(f"generate_meta_params, type(block_type): {type(block_type)}, type(aggregated_common): {type(aggregated_common)}, type(aggregated_extra): {type(aggregated_extra)}")

    if isinstance(meta_block_orig[0], PatchQuantumGenerator):
        print_cust(f"generate_meta_params, meta_block_orig[0] is a PatchQuantumGenerator")
        # Placeholder for now; assumes homogeneous architecture among all clients:
        # the reference client's second block entry is carried along unchanged.
        meta_block = [meta_block, meta_block_orig[1]]
        print_cust(f"generate_meta_params, generator case, meta_block: {meta_block}")

    return (meta_conv, meta_pool, meta_final_pool, meta_final, meta_bias, meta_block)

"""## Random meta params initialization"""

# TODO: continue here; generate_meta_params_random. See how it is used in the overall training code.
def generate_meta_params_random(meta_params, math_int=np):
    """
    Return a randomly re-initialised copy of meta_params with the same structure and shapes.

    meta_params has the structure
      (meta_conv, meta_pool, meta_final_pool, meta_final, meta_bias, meta_block)
    where:
      - meta_conv is an iterable of (even_params, odd_params) pairs
      - meta_pool is an iterable of pooling parameter arrays
      - meta_final_pool and meta_bias are arrays
      - meta_final is either an array or a list/tuple whose first two entries are arrays
      - meta_block is either a dict mapping block type to (aggregated_common, aggregated_extra),
        with aggregated_extra possibly None, or a sequence [such_dict, second_entry]
        (generator case); second_entry is passed through unchanged.

    Sampling uses randn. Convolution, pooling, final and bias values always come from
    np.random; block values come from np.random when math_int == np and from torch otherwise.

    Returns:
      (random_meta_conv, random_meta_pool, random_meta_final_pool,
       random_meta_final, random_meta_bias, random_meta_block) mirroring the input layout.
    """
    meta_conv, meta_pool, meta_final_pool, meta_final, meta_bias, meta_block = meta_params

    def _randn_like(sampler, reference):
        # Standard-normal sample with the shape of `reference`.
        return sampler.randn(*reference.shape)

    numpy_sampler = np.random

    # Draw order matters for reproducibility under a fixed seed:
    # conv (even, odd per layer) -> pool -> final pool -> final -> bias -> blocks.
    random_meta_conv = tuple(
        (_randn_like(numpy_sampler, even), _randn_like(numpy_sampler, odd))
        for even, odd in meta_conv
    )
    random_meta_pool = tuple(_randn_like(numpy_sampler, pool) for pool in meta_pool)
    random_meta_final_pool = _randn_like(numpy_sampler, meta_final_pool)

    if isinstance(meta_final, (list, tuple)):
        random_meta_final = (
            _randn_like(numpy_sampler, meta_final[0]),
            _randn_like(numpy_sampler, meta_final[1]),
        )
    else:
        random_meta_final = _randn_like(numpy_sampler, meta_final)

    random_meta_bias = _randn_like(numpy_sampler, meta_bias)

    # Block parameters follow the requested backend.
    block_sampler = np.random if math_int == np else torch

    # In the generator case the dict sits in slot 0 of a two-element container.
    is_wrapped = not isinstance(meta_block, dict)
    if is_wrapped:
        meta_block_dict = meta_block[0]
        assert isinstance(meta_block_dict, dict), f"generate_meta_params_random, meta_block_dict is not a dict, type(meta_block_dict): {type(meta_block_dict)}"
    else:
        meta_block_dict = meta_block

    random_meta_block = {
        block_type: (
            _randn_like(block_sampler, common),
            _randn_like(block_sampler, extra) if extra is not None else None,
        )
        for block_type, (common, extra) in meta_block_dict.items()
    }

    if is_wrapped:
        print_cust(f"generate_meta_params_random, meta_block is not a dict")
        # Re-wrap, carrying the second entry over untouched.
        random_meta_block = [random_meta_block, meta_block[1]]

    print_cust(f"generate_meta_params_random, random_meta_block: {random_meta_block}")

    return (random_meta_conv, random_meta_pool, random_meta_final_pool,
            random_meta_final, random_meta_bias, random_meta_block)

"""## Amplitude Embedding Data Reordering Function"""

def reorder_amplitude_data(data: np.ndarray, qubit_order: list[int]) -> np.ndarray:
    """
    Reorder (and, if needed, zero-pad or truncate) amplitude-embedded data
    for a new qubit ordering.

    The columns are first resized to 2**k_new (k_new = len(qubit_order)) and then
    permuted: output column j is read from input column perm[j], where bit i of
    perm[j] equals bit qubit_order[i] of j (bit 0 is the least significant bit).

    Args:
        data (np.ndarray): 2-D array of shape (n_samples, N_old). N_old is expected to be
                           a power of two, but this is not enforced.
        qubit_order (list[int]): New ordering of qubit indices of length k_new;
                                 qubit_order[i] is the qubit index to place at
                                 position i. Every entry must be a valid index into
                                 the k_new bit positions, since it is applied after
                                 the resize.

    Returns:
        np.ndarray: Array of shape (n_samples, 2**k_new), with amplitudes:
            - zero-padded on the right if N_old < 2**k_new,
            - truncated to the first 2**k_new columns if N_old > 2**k_new,
            - then permuted as described above.
    """
    # NOTE, layers, ampembed: be careful about how this function is called. make sure that the inputs to the
    # calling function here fit; ideally, I don't want to call this function.
    print_cust(f"reorder_amplitude_data, data.shape: {data.shape}")
    print_cust(f"reorder_amplitude_data, qubit_order: {qubit_order}")
    n_samples, old_size = data.shape

    # --- compute new dimensions ---
    new_nq = len(qubit_order)
    new_size = 1 << new_nq

    # --- truncate or pad as needed ---
    if old_size > new_size:
        # too many features -> truncate
        data = data[:, :new_size]
    elif old_size < new_size:
        # too few features -> pad with zeros
        padded = np.zeros((n_samples, new_size), dtype=data.dtype)
        padded[:, :old_size] = data
        data = padded
    # now data.shape[1] == new_size

    print_cust(f"reorder_amplitude_data, data.shape: {data.shape}")

    # --- build the bit-permutation map for all column indices at once ---
    # (previously a Python loop over 2**k_new indices that rebuilt these helper arrays
    # on every iteration)
    bit_positions = np.arange(new_nq)
    column_indices = np.arange(new_size)
    # bits[j, i] is bit i of column index j; shape (new_size, new_nq).
    bits = (column_indices[:, None] >> bit_positions) & 1
    # Permute the bit positions, then reassemble each row into an integer index.
    perm = np.dot(bits[:, qubit_order], 1 << bit_positions)

    # --- apply permutation across amplitudes ---
    return data[:, perm]

"""## Convert Data to Tensors"""

def convert_data_to_lib(clients_data_dict, math_int=np):
    """
    Convert every client's train/validation features and labels to float32 torch tensors, in place.

    Args:
      clients_data_dict: dict mapping client types to lists of client data entries.
          Each entry is indexed as entry[0] = [X_train, y_train] and
          entry[1] = [X_val, y_val]; both inner containers must support item
          assignment because they are overwritten in place.
      math_int: target backend. Conversion only happens when this is `torch`;
          for any other value the data is left untouched.

    Returns:
      The same clients_data_dict object.
    """
    use_torch = math_int == torch
    for clients_data in clients_data_dict.values():
        for data in clients_data:
            # Split 0 is the training pair, split 1 the validation pair.
            for split_idx in (0, 1):
                features = data[split_idx][0]
                labels = data[split_idx][1]
                # NOTE: arrays coming from numpy are float64; they are down-cast to float32 here.
                if use_torch:
                    data[split_idx][0] = torch.tensor(features, dtype=torch.float32)
                    data[split_idx][1] = torch.tensor(labels, dtype=torch.float32)
    return clients_data_dict

# TODO, layers: continue here, 08/04, 12:38 PM.
# DONE: TOMODIFY, layers: convert the labels to torch tensors, too.

"""## Federated PCA Function"""

# TODO, DepthFL: continue here, 8/19 12:12 PM

def perform_federated_pca_mocked(clients_data_dict, max_size_clients, random_seed=42, math_int=np, device=None, generative=False):
  """
  Fit one PCA on the pooled train+validation features of every client and project each client's data with it.

  The data of all clients is concatenated and a single sklearn PCA is fit on the result (the commented-out block at the
  top of the body sketches the covariance-averaging variant that this pooled fit stands in for).

  Note: this function MUTATES clients_data_dict. For every client entry:
    entry[0][0] / entry[1][0] are replaced by the PCA-reduced training / validation features
      (additionally min-max rescaled per component to roughly [0, pi] when generative is False)
    entry[2][0] is set to the fitted PCA object
    entry[2][1] is set to the PCA-reduced (not rescaled) train+validation features

  Parameters:
    clients_data_dict, a dictionary mapping client types to lists of per-client entries, where entry[0][0] and entry[1][0]
      are the 2-D training and validation feature matrices and entry[2] is a two-slot container supporting item assignment
      the concatenated features have .detach().cpu().numpy() called on them, so in practice they must be torch tensors
    max_size_clients, an integer used as the number of PCA components
    random_seed, an integer passed to PCA as random_state
    math_int, the array module used for empty/concatenate/max/min/zeros_like/pi
      NOTE(review): because of the .detach() calls and torch.from_numpy, only math_int=torch looks usable despite the np default -- confirm
    device, the device passed to .to(...) for every tensor created from PCA output
    generative, a Boolean; when False the reduced features are rescaled to angles, when True they are left as raw PCA scores

  Returns:
    a tuple (pca_global, shared_min_comps_test, shared_max_comps_test, inv_pca_min, inv_pca_max), where
      pca_global is the fitted PCA object,
      shared_min_comps_test / shared_max_comps_test are math_int.min / math_int.max of the pooled reduced data over axis 0
        NOTE(review): with torch this reduction presumably yields a (values, indices) pair rather than a plain tensor -- confirm against callers
      inv_pca_min / inv_pca_max are the global min / max of the inverse-transformed pooled data

  Raises:
    StopIteration if clients_data_dict is empty or its first client list is empty
  """
  # Reference only (disabled): aggregation of per-client covariances instead of pooling the raw data.
  # local_covariances = []
  # local_weights = []
  # pca_list = []
  # for client_type, client_data in clients_data_dict.items():
  #   for client_data_indiv in client_data:
  #     cli_weight = client_data_indiv[0][0].shape[0] + client_data_indiv[1][0].shape[0]
  #     cli_pca = client_data_indiv[2][0]
  #     cli_cov = cli_pca.get_covariance()
  #     local_covariances.append(cli_cov)
  #     local_weights.append(cli_weight)
  #     pca_list.append(cli_pca)
  # total_weights = sum(local_weights)
  # local_weights = [loc_weight / total_weights for loc_weight in local_weights]
  # # assumes that all cov matrices are the same shape
  # assert len(local_covariances) == len(local_weights) == len(pca_list), "Local covariances and local weights length are not the same"
  # global_cov = sum(local_weights[i] * local_covariances[i] for i in range(len(local_covariances)))
  # e_vals, e_vecs = np.linalg.eigh(global_cov)
  # # TODO: again, improve this for the homogenous setting
  # idx = np.argsort(e_vals)[::-1][:max_size_clients]
  # U_global = e_vecs[:, idx]

  # pca_global = PCA(n_components=max_size_clients)
  # pca_global.components_ = U_global.T
  # pca_global.explained_variance_ = e_vals[idx]
  # pca_global.explained_variance_ratio_ = e_vals[idx] / e_vals.sum()
  # pca_global.mean_ = sum(local_weights[i] * pca_list[i].mean_ for i in range(len(pca_list)))
  # pca_global.n_samples_ = sum(pca_list[i].n_samples_ for i in range(len(pca_list)))
  # pca_global.n_features_in_ = U_global.shape[0]

  print_cust(f"perform_federated_pca_mocked, generative: {generative}")

  # Peek at the first client of the first client type to learn the feature dimension.
  any_client_data = next(iter(next(iter(clients_data_dict.values()))))

  print_cust(f"perform_federated_pca_mocked, len(any_client_data): {len(any_client_data)}")

  data_dim = any_client_data[0][0].shape[1]

  print_cust(f"perform_federated_pca_mocked, data_dim: {data_dim}")

  agg_client_data = math_int.empty((0, data_dim))

  print_cust(f"perform_federated_pca_mocked, agg_client_data.shape: {agg_client_data.shape}")

  # Collect every client's train+val block and concatenate ONCE; growing the matrix pairwise inside the
  # loop re-copies all previously pooled rows on every iteration. The empty seed stays first in the list
  # so the result matches what the pairwise version produced.
  pooled_chunks = [agg_client_data]
  for client_data in clients_data_dict.values():
    for client_data_indiv in client_data:
      all_client_data = math_int.concatenate((client_data_indiv[0][0], client_data_indiv[1][0]), axis=0)
      print_cust(f"perform_federated_pca_mocked, all_client_data.shape: {all_client_data.shape}")
      pooled_chunks.append(all_client_data)
  agg_client_data = math_int.concatenate(pooled_chunks, axis=0)

  print_cust(f"perform_federated_pca_mocked, max_size_clients: {max_size_clients}")
  # NOTE: PCA is dependent on the seed given.
  pca_global = PCA(n_components=max_size_clients, random_state=random_seed)
  agg_client_data_numpy = agg_client_data.detach().cpu().numpy()
  glob_data_numpy = pca_global.fit_transform(agg_client_data_numpy)
  glob_data = torch.from_numpy(glob_data_numpy).to(device)
  print_cust(f"perform_federated_pca_mocked, glob_data.shape: {glob_data.shape}")
  shared_max_comps_test = math_int.max(glob_data, axis=0)
  shared_min_comps_test = math_int.min(glob_data, axis=0)
  print_cust(f"perform_federated_pca_mocked, shared_max_comps_test: {shared_max_comps_test}, shared_min_comps_test: {shared_min_comps_test}")

  # Range of the data after a PCA round trip (project, then reconstruct).
  glob_data_invtransform_numpy = pca_global.inverse_transform(glob_data_numpy)
  glob_data_invtransform = torch.from_numpy(glob_data_invtransform_numpy).to(device)

  inv_pca_max = glob_data_invtransform.max()
  inv_pca_min = glob_data_invtransform.min()

  print_cust(f"perform_federated_pca_mocked, inv_pca_max: {inv_pca_max}, inv_pca_min: {inv_pca_min}")

  # transform each clients' data, tracking the per-component min/max across ALL clients
  shared_max_comps = [-float('inf') for _ in range(max_size_clients)]
  shared_min_comps = [float('inf') for _ in range(max_size_clients)]

  for client_data in clients_data_dict.values():
    for client_data_indiv in client_data:
      cli_train_data = client_data_indiv[0][0]
      cli_val_data = client_data_indiv[1][0]
      n_cli_train = cli_train_data.shape[0]
      cli_train_val_data = math_int.concatenate((cli_train_data, cli_val_data), axis=0)
      cli_train_val_data_numpy = cli_train_val_data.detach().cpu().numpy()
      cli_train_val_data_reduced_numpy = pca_global.transform(cli_train_val_data_numpy)
      cli_train_val_data_reduced = torch.from_numpy(cli_train_val_data_reduced_numpy).to(device)
      n_components = cli_train_val_data_reduced.shape[1]
      # The tracking lists are indexed by component below, so they must be at least as long as the
      # number of reduced components (the comparison used to be inverted).
      assert len(shared_max_comps) >= n_components, "Maximum components list doesn't cover all components for client data"
      for i in range(n_components):
          comp = cli_train_val_data_reduced[:, i]
          lo, hi = comp.min(), comp.max()
          if lo < shared_min_comps[i]:
            shared_min_comps[i] = lo
          if hi > shared_max_comps[i]:
            shared_max_comps[i] = hi

      client_data_indiv[2][1] = cli_train_val_data_reduced
      client_data_indiv[0][0] = cli_train_val_data_reduced[:n_cli_train]
      client_data_indiv[1][0] = cli_train_val_data_reduced[n_cli_train:]
      client_data_indiv[2][0] = pca_global

  print_cust(f"perform_federated_pca_mocked, shared_max_comps: {shared_max_comps}, shared_min_comps: {shared_min_comps}")

  if not generative:
    print_cust(f"perform_federated_pca_mocked, rescaling training data to be from 0 to pi")
    for client_data in clients_data_dict.values():
      for client_data_indiv in client_data:
        cli_train_val_data_reduced = client_data_indiv[2][1]

        # 2) scale each component to [0, π] using the min/max shared across all clients
        # NOTE, layers: I might get dtype issues here, in which case I'd need to change to torch.float32 explicitly.
        X_angle_cli = math_int.zeros_like(cli_train_val_data_reduced)
        for i in range(cli_train_val_data_reduced.shape[1]):
            comp = cli_train_val_data_reduced[:, i]
            lo, hi = shared_min_comps[i], shared_max_comps[i]
            # the 1e-8 keeps the division finite for a constant component
            X_angle_cli[:, i] = ( (comp - lo) / (hi - lo + 1e-8) ) * math_int.pi

        # ADDED, layers: n_cli_train, as it should be specified for this particular client_data_indiv.
        # [0][0] already holds the reduced training rows at this point, so its length is the train count.
        n_cli_train = client_data_indiv[0][0].shape[0]

        cli_train_data_reduced = X_angle_cli[:n_cli_train]
        cli_val_data_reduced = X_angle_cli[n_cli_train:]

        client_data_indiv[0][0] = cli_train_data_reduced
        client_data_indiv[1][0] = cli_val_data_reduced

        # client_data_indiv[2][0] = pca_global
        # client_data_indiv[2][1] = cli_train_val_data_reduced
        # NOTE, layers: shouldn't need to change [2][1]; i.e., the PCA reduced but non-angle encoded data right?

  return pca_global, shared_min_comps_test, shared_max_comps_test, inv_pca_min, inv_pca_max

# DONE: TOMODIFY, layers: be sure to change the labels to be pytorch tensors too.

"""## QFL Experiment Function"""

import pickle

def run_qfl_experiments(clients_config_arg, classes=["4", "9"], n_samples=1000, dataset_type="mnist", agg_strategy="fedavg", test_frac=0.2, val_frac=0.1, random_state=42, pool_in=True,
                        local_batch_size=32, local_lr=0.01, shots=1024, debug=False, init_client_data_dict=None, save_pkl=False, mask_grads=False, qubits_and_layers_to_add_block_params=[],
                        train_models_parallel=False):
  """
  Function that runs the sequential QFL workflow (one model size after another, smallest first) and returns a log of the training.

  For every model size (the sorted keys of clients_config_arg), each communication round trains every client locally,
  aggregates the client parameters, broadcasts the aggregate back to all clients and evaluates it on the test split.
  After each model size, the client parameters are re-initialized for a model twice that size.

  Parameters:
    clients_config_arg, a dictionary mapping client sizes (ints) to configuration dictionaries; this function reads the keys
      "num_clients", "communication_rounds" and "local_epochs" of each configuration
    classes, a list of strings naming the classes; passed to load_dataset, joined with "_" for the pickle file name, and its
      length is used as the number of classes of the qnode
    n_samples, dataset_type, passed to load_dataset
    agg_strategy, a string selecting the aggregation scheme
      should be either "fedavg", "fedavg_quat", or "fedavg_circ"
    test_frac, val_frac, random_state, passed to split_data_federated (only used when init_client_data_dict is None)
    pool_in, passed to create_feat_list_expansion_data
    local_batch_size, local_lr, the batch size and learning rate passed to local training
    shots, passed to local training and to the test-loss computation
    debug, a Boolean; when True the per-client data sizes and the client data itself are printed
    init_client_data_dict, an optional dictionary with keys "clients_data_dict" and "testing_data" to reuse an existing split
    save_pkl, a Boolean indicating whether the initial data split should be written to a pickle file
    mask_grads, a Boolean; when True (and the current model is larger than the smallest client size) a mask from mask_gradients
      is passed to local training as trainable_mask
    qubits_and_layers_to_add_block_params, train_models_parallel, passed through to initialize_client_params
      (train_models_parallel is also passed to create_feat_list_expansion_data)

  Returns:
    data_logs, a dictionary of the following form:
      {
        "clients_data_dict": <clients_data_dict>,
        "testing_data": (X_test, y_test),
        <model size>: {
          "aggregated_params_rounds": [...one entry per round],
          "testing_loss_rounds": [...one entry per round],
          "testing_accuracy_rounds": [...one entry per round],
          <client size>: [
            {"trained_params_rounds": [...], "minibatch_losses_rounds": [...],
             "validation_losses_rounds": [...], "training_accs_rounds": [...]},
            ... for each client
          ],
          ... for each client size
        },
        ... for each model size
      }

  Raises:
    ValueError if agg_strategy is not one of the supported strategies (raised when the first aggregation is attempted)
  """
  max_size_clients = max(clients_config_arg.keys())
  min_size_clients = min(clients_config_arg.keys())
  X_angles, y = load_dataset(dataset_type=dataset_type, classes=classes, n_samples=n_samples, num_feats=max_size_clients)

  print_cust(f"run_qfl_experiments, X_angles.shape: {X_angles.shape}")

  # Reuse a previously generated split when supplied; otherwise split the freshly loaded data.
  if init_client_data_dict is not None:
    clients_data_dict = init_client_data_dict["clients_data_dict"]
    (X_test, y_test) = init_client_data_dict["testing_data"]
    print_cust(f"run_qfl_experiments, loaded in existing data")
  else:
    clients_data_dict, (X_test, y_test) = split_data_federated(X_angles, y, clients_config_arg, test_frac, val_frac=val_frac, random_state=random_state)
    print_cust(f"run_qfl_experiments, generated new data")

  # For sanity, print out the amount of data that each client has.
  if debug:
    for client_type, clients_data_list in clients_data_dict.items():
      print_cust(f"client_type: {client_type}, len(clients_data_list): {len(clients_data_list)}")
      for client_idx, client_data in enumerate(clients_data_list):
        print_cust(f"client_type: {client_type}, client_idx: {client_idx}, len(client_data): {len(client_data)}")
        (training_data, validation_data) = client_data
        print_cust(f"client_type: {client_type}, client_idx: {client_idx}, len(training_data): {len(training_data)}, len(validation_data): {len(validation_data)}")
        (X_train, y_train) = training_data
        (X_val, y_val) = validation_data
        print_cust(f"client_type: {client_type}, client_idx: {client_idx}, X_train.shape: {X_train.shape}, y_train.shape: {y_train.shape}, X_val.shape: {X_val.shape}, y_val.shape: {y_val.shape}")

  client_params_dict = initialize_client_params(clients_config_arg, model_size=min_size_clients, cur_client_params_dict=None, qubits_and_layers_to_add_block_params=qubits_and_layers_to_add_block_params, train_models_parallel=train_models_parallel)

  client_types = sorted(clients_config_arg)

  data_logs = {
    "clients_data_dict": clients_data_dict,
    "testing_data": (X_test, y_test),
  }

  if save_pkl:
    with open(f"data_logs_n_samples_{n_samples}_dataset_type_{dataset_type}_classes_{'_'.join(classes)}.pkl", "wb") as file:
      pickle.dump(data_logs, file)

  for cur_model_size in client_types:
    print_cust(f"run_qfl_experiments, cur_model_size: {cur_model_size}")

    # Round-level logs are created once per model size; per-client logs are created per client size.
    size_logs = {
      "aggregated_params_rounds": [],
      "testing_loss_rounds": [],
      "testing_accuracy_rounds": [],
    }
    # TODO: add a field for testing loss in data_logs
    for client_size, cfg in clients_config_arg.items():
      size_logs[client_size] = [{"trained_params_rounds": [],
                                 "minibatch_losses_rounds": [],
                                 "validation_losses_rounds": [],
                                 "training_accs_rounds": []}
                                for _ in range(cfg["num_clients"])]
    data_logs[cur_model_size] = size_logs

    num_rounds = clients_config_arg[cur_model_size]["communication_rounds"]
    num_local_epochs = clients_config_arg[cur_model_size]["local_epochs"]
    # TODO: probably have some kind of aggregation of metrics here. what do you want to aggregate?

    n_layers = int(math.log2(cur_model_size))
    expand = cur_model_size > min_size_clients
    feature_list, expansion_data = create_feat_list_expansion_data(cur_model_size, n_layers, expand=expand, pool_in=pool_in, min_qubits_noexpand=min_size_clients, train_models_parallel=train_models_parallel)

    # Build the qnode BEFORE any training happens. It only depends on the model size, the expansion data and
    # the number of classes, so one instance per model size is enough; previously it was first assigned after
    # the training call that consumes it, which raised UnboundLocalError on the very first client.
    qnode = create_qnode_qcnn(cur_model_size, n_layers, expansion_data, n_classes=len(classes))

    for round_num in range(num_rounds):
      print_cust(f"run_qfl_experiments, round_num: {round_num}")
      for client_type, client_params in client_params_dict.items():
        print_cust(f"run_qfl_experiments, client_type: {client_type}")
        client_data = clients_data_dict[client_type]
        client_logs = size_logs[client_type]
        for client_idx in range(len(client_params)):
          print_cust(f"run_qfl_experiments, client_idx: {client_idx}")
          client_params_indiv = client_params[client_idx]
          client_data_indiv = client_data[client_idx]

          client_train_data = client_data_indiv[0]
          client_val_data = client_data_indiv[1]

          if debug:
            print_cust(f"run_qfl_experiments, len(client_train_data): {len(client_train_data)}")
            print_cust(f"run_qfl_experiments, client_train_data: {client_train_data}")
            print_cust(f"run_qfl_experiments, len(client_val_data): {len(client_val_data)}")
            print_cust(f"run_qfl_experiments, client_val_data: {client_val_data}")

          # Only mask gradients once the model has grown past the smallest client size.
          # TODO: move this conditional up?
          grad_mask = None
          if mask_grads and cur_model_size > min_size_clients:
            grad_mask = mask_gradients(client_params_indiv)

          print_cust(f"run_qfl_experiments, grad_mask: {grad_mask}")

          # train for num_local_epochs
          trained_params, minibatch_losses, validation_losses = train_epochs_angle_param_adam(
              client_params_indiv, client_train_data[0][:, feature_list], client_train_data[1],
              client_val_data[0][:, feature_list], client_val_data[1],
              n_epochs=num_local_epochs, shots=shots, batch_size=local_batch_size, lr=local_lr,
              qnode=qnode, trainable_mask=grad_mask)

          client_params[client_idx] = trained_params

          # NOTE(review): accuracy is evaluated with a hard-coded shots=1024 rather than the `shots` argument -- confirm this is intended.
          train_acc = compute_avg_acc_angle_param_batch(trained_params, client_train_data[0][:, feature_list], client_train_data[1], layers=n_layers, shots=1024, batch_size=local_batch_size, qnode=qnode)

          per_client_log = client_logs[client_idx]
          per_client_log["training_accs_rounds"].append(train_acc)
          per_client_log["trained_params_rounds"].append(trained_params)
          per_client_log["minibatch_losses_rounds"].append(minibatch_losses)
          per_client_log["validation_losses_rounds"].append(validation_losses)

      # perform aggregation
      # NOTE: assume, during aggregation, that the parameters for all the clients have the same shape.
      print_cust(f"run_qfl_experiments, parameter aggregation, agg_strategy: {agg_strategy}")
      if agg_strategy == "fedavg":
        aggregated_params = federated_averaging(client_params_dict, clients_data_dict)
      elif agg_strategy == "fedavg_quat":
        aggregated_params = federated_averaging_quat(client_params_dict, clients_data_dict)
      elif agg_strategy == "fedavg_circ":
        aggregated_params = federated_averaging_circular(client_params_dict, clients_data_dict)
      else:
        # Without this branch an unknown strategy left aggregated_params unbound (or stale from the previous round).
        raise ValueError(f"run_qfl_experiments, unknown agg_strategy: {agg_strategy!r}; expected 'fedavg', 'fedavg_quat' or 'fedavg_circ'")

      size_logs["aggregated_params_rounds"].append(aggregated_params)

      # broadcast parameter updates
      for client_params in client_params_dict.values():
        for client_idx in range(len(client_params)):
          client_params[client_idx] = aggregated_params

      test_loss = compute_loss_angle_param_batch(aggregated_params, X_test[:, feature_list], y_test,
                                                 shots=shots, qnode=qnode)

      test_acc = compute_avg_acc_angle_param_batch(aggregated_params, X_test[:, feature_list], y_test, layers=n_layers, shots=1024, batch_size=local_batch_size, qnode=qnode)

      print_cust(f"run_qfl_experiments, test_loss: {test_loss}")

      print_cust(f"run_qfl_experiments, test_acc: {test_acc}")

      size_logs["testing_loss_rounds"].append(test_loss)

      size_logs["testing_accuracy_rounds"].append(test_acc)

    # initialize parameters for the next size of model training
    # TODO: handle different sized models
    # NOTE(review): this assumes the next client size is exactly double the current one, it also runs after the
    # LAST model size (where the result is never used), and it does not forward train_models_parallel -- confirm.
    print_cust(f"run_qfl_experiments, initializing next-sized models")
    client_params_dict = initialize_client_params(clients_config_arg, model_size=cur_model_size * 2, cur_client_params_dict=client_params_dict, qubits_and_layers_to_add_block_params=qubits_and_layers_to_add_block_params)

  return data_logs

"""## QFL Experiment Parallel Function"""

def run_qfl_experiments_parallel(clients_config_arg, classes=["4", "9"], n_samples=1000, dataset_type="mnist", agg_strategy="fedavg", test_frac=0.2, val_frac=0.1, random_state=42, pool_in=True,
                        local_batch_size=32, local_lr=0.01, shots=1024, debug=False, init_client_data_dict=None, save_pkl=False, mask_grads=False, qubits_and_layers_to_add_block_params=[],
                        train_models_parallel=False, same_init=False, feature_skew=0.0, label_skew=None, local_pca=False, do_lda=False, feat_sel_type="top", amp_embed=False, feat_ordering="same",
                                 morepers=False, custom_debug=False, shared_pca=False, heirarchical_train=False, generative=False, use_torch=False, fed_pca_mocked=True, lr_gen=0.004, lr_disc=0.001,
                                 noise_func=generate_latent_noise, criterion_func=nn.BCELoss, targ_data_folder_prefix="testing_gen_imgs", gen_data_folder_prefix="qgan_gen_imgs", device=None, fid_batch_size=None):

  """
  Function that runs the QFL workflow for a given configuration and returns a log of the training.

  Note: in general, the client parameters dictionary is of the form:

      {
        <int>: [[conv_params, pool_params, final_pool_params,
            final_params, bias_param, block_params],
            ... for each client],
        ... for each client type
      }

  Parameters:
    clients_config_arg, a dictionary mapping client types to configuration information for clients of that type. For this code, I expect clients_config_arg to be of the form:
      {
        <int>: {
            "percentage_data": <float>,
            "num_clients": <int>,
            "local_epochs": <int>,
            "communication_rounds": <int>
        },
        ... for each client type
      }
    classes, a list of strings (preferably string representations of integers) representing the classes on which to perform classification
      ex: ["4", "9"]
    n_samples, an integer representing the number of samples to use for training
    dataset_type, a string representing the dataset type for training
      should either be "mnist", "Fashion-MNIST", "cifar10", "synthetic", "pima", "higgs", or "covertype" for those respective datasets
    agg_strategy, a string representing the aggregation scheme to perform
      should be either "fedavg", "fedavg_quat", or "fedavg_circ"
    test_frac, a float representing the fraction of the total data to be used for testing
      should be between 0 and 1
    val_frac, a float representing the fraction of each client's data to be used for validation
      should be between 0 and 1
    random_state, an integer representing the random state used in the entire program (assuming that this function generates client data -- TODO: make the use of the random
    state argument more clear)
    pool_in, a Boolean representing whether or not the model should have reduce the number of qubits to the center
    local_batch_size, an integer representing the batch size each local model should use for training
    local_lr, a float representing the learning rate each model should use for training, upon each communication round
    shots, an integer representing the number of shots each model should use (currently unused)
    debug, a Boolean indicating whether or not this function should be run in debug mode (prints the data, the amount of data per client)
    init_client_data_dict, a dictionary that maps client types to lists of data for each client that is used for each client
      should be of the form
      {
        <int>: [[(X_train, y_train), (X_val, y_val), (pca_obj, pca_reduced_data)],
                ... for each client
              ],
        ... for each client type
      }
    save_pkl, a Boolean indicating whether or not the data should be saved to a pickle file
    mask_grads, a Boolean indicating whether or not the gradients should have a Boolean mask (and parameters in smaller qubit sets should not be updated)
    qubits_and_layers_to_add_block_params, a dictionary that maps client types to a list of (n_qubits, n_layers) representing the block parameters that each client has
    train_models_parallel, a Boolean indicating whether or not the models should be trained in parallel (for personalized models)
    same_init, a Boolean indicating whether or not all the clients should start with the same initial parameters
    feature_skew, a float between 0 and 1 specifying the magnitude of the feature skew (i.e., the strength of sorting the features by the first feature and how prominently that appears
    in the data for clients)
    label_skew, a float between 0 and 1 specifying the magnitude of the label skew that each client has
    do_lda, a Boolean indicating whether or not to perform random sketching
    feat_sel_type, a string representing the choice of features to make in angle encoding for each client
      should be either "top" for selecting the features with the highest variance, or "toplow" for selecting half the features with the highest variance and half the features with the
      lowest variance at each expansion
    amp_embed, a Boolean indicating whether the data should be amplitude encoded
    feat_ordering, a string representing whether the features in amplitude encoding should be taken as-is, or if it should be taken in a different order
      should be either "same" for as-is feature ordering, or "highest_var" for sorting the features in descending order in terms of highest variance
    morepers, a Boolean indicating whether or not only the the convolutional parameters should or should not be aggregated (if not, the convolutional parameters are unique to each client,
    so each client has its own personalized convolutional parameters -- see PDF drawing sent in Slack)

    Returns:
      data_logs, a dictionary of the following form:
        {
          'clients_data_dict': <client_data_dict> of the above form,
          'testing_data': (X_test, y_test),
          0:
            {
              "aggregated_params": <params>,
              <int>: {
                "local_epochs": <int>,
                "client_metrics": [
                    {
                      "trained_params": <params>,
                      "minibatch_losses": list<float>,
                      "validation_losses": list<float>,
                      "training_acc": <float>,
                      "testing_acc": <float>,
                      "testing_loss": <float>,
                      "training_acc_stdev": <float>,
                      "testing_acc_stdev": <float>,
                      "training_acc_topk": <float>,
                      "testing_acc_topk": <float>
                    }, ... (for each client)
                ]
              },
              ... for each client type
            },
          ... for each communication round
          clients_config_arg["communication_rounds"] - 1: (same format as above data logs dictionary)
        }
  """

  # TOADD, generative:
  # 1. An option to this function to indicate that we are having a generative model

  print_cust(f"run_qfl_experiments_parallel, heirarchical_train: {heirarchical_train}")
  # Find the maximum, minimum size clients, as well as the total communication rounds
  max_size_clients = max(clients_config_arg.keys())
  min_size_clients = min(clients_config_arg.keys())

  if custom_debug:
    print_cust(f"run_qfl_experiments_parallel, max_size_clients: {max_size_clients}, min_size_clients, {min_size_clients}")
    assert max_size_clients >= min_size_clients, "Maximum sized client is NOT at least as large as minimum sized client"

  num_total_rounds = max([clients_config_arg[key]["communication_rounds"] for key in clients_config_arg])

  # Note: this does impose some more constraints on the format in which the round info is passed (larger MUST be > smaller; training 4 implies training 8.) going to stick with this for now
  total_rounds_accum = 0
  client_types_to_rounds = {}
  for client_type in sorted(clients_config_arg.keys()):
    client_rounds = clients_config_arg[client_type]["communication_rounds"]
    client_types_to_rounds[client_type] = client_rounds - total_rounds_accum
    total_rounds_accum += client_rounds

  # this should be client type to rounds TO RUN for basically clients that size AND ABOVE.
  print_cust(f"run_qfl_experiments_parallel, client_types_to_rounds: {client_types_to_rounds}")

  if custom_debug:
    print_cust(f"run_qfl_experiments_parallel, num_total_rounds: {num_total_rounds}")
    assert num_total_rounds >= 0, "Number of total rounds is not at least 0"

  # The number of output qubits is enough to accommodate the number of classes.
  n_output_qubits = int(np.ceil(np.log2(len(classes))))

  if custom_debug:
    print_cust(f"run_qfl_experiments_parallel, n_output_qubits: {n_output_qubits}")
    assert n_output_qubits >= 0, "Number of output qubits is not at least 0"

  print_cust(f"run_qfl_experiments_parallel, num_total_rounds: {num_total_rounds}")

  print_cust(f"run_qfl_experiments_parallel, n_output_qubits: {n_output_qubits}")

  # Load the input dataset. If we are doing amplitude embedding or local PCA, then we do not want to dimensionality reduce the input images.
  # Note that the pixels are normalized to be between 0 and 1.
  keep_orig_imgs = (local_pca or amp_embed)

  if custom_debug:
    print_cust(f"run_qfl_experiments_parallel, keep_orig_imgs: {keep_orig_imgs}")

  X_angles, y = load_dataset(dataset_type=dataset_type, classes=classes, n_samples=n_samples, num_feats=max_size_clients, keep_orig_imgs=keep_orig_imgs, custom_debug=custom_debug)

  print_cust(f"run_qfl_experiments_parallel, X_angles.shape: {X_angles.shape}")

  print_cust(f"run_qfl_experiments_parallel, local_pca: {local_pca}, do_lda: {do_lda}")

  # Load in the previous data dictionary, if supplied.
  if init_client_data_dict is not None:
    clients_data_dict = init_client_data_dict["clients_data_dict"]
    (X_test, y_test) = init_client_data_dict["testing_data"]
    # if shared_pca:
    #   shared_max_comps = init_client_data_dict['shared_max_comps']
    #   shared_min_comps = init_client_data_dict['shared_min_comps']
    print_cust(f"run_qfl_experiments_parallel, loaded in existing data")
  else:
    clients_data_dict, (X_test, y_test) = split_data_federated(X_angles, y, clients_config_arg, test_frac, val_frac=val_frac, random_state=random_state, feature_skew=feature_skew, label_skew=label_skew, local_pca=local_pca,
                                                              do_lda=do_lda, feat_sel_type=feat_sel_type, amp_embed=amp_embed, feat_ordering=feat_ordering, shared_pca=shared_pca, fed_pca_mocked=fed_pca_mocked)
    print_cust(f"run_qfl_experiments_parallel, generated new data")

  # Create a log of the information that we have throughout training.
  data_logs = {}

  data_logs["clients_data_dict"] = clients_data_dict
  data_logs["testing_data"] = (X_test, y_test)
  # if shared_pca:
  #   data_logs['shared_max_comps'] = shared_max_comps
  #   data_logs['shared_min_comps'] = shared_min_comps

  # Store the initial set of training parameters and data in a pickle file.
  if save_pkl:
    with open(f"data_logs_n_samples_{n_samples}_dataset_type_{dataset_type}_classes_{'_'.join(classes)}_train_models_parallel_{train_models_parallel}_feature_skew_{feature_skew}_label_skew_{label_skew}_local_pca_{local_pca}_shared_pca_{shared_pca}_gen_{generative}.pkl", "wb") as file:
      pickle.dump(data_logs, file)


  if use_torch:
    math_int = torch
  else:
    math_int = np

  # Maybe right here, have a torch data conversion function ??? (if not done in perform_federated_pca_mocked)
  clients_data_dict = convert_data_to_lib(clients_data_dict, math_int=math_int)
  if math_int == torch:
    X_test = torch.tensor(X_test, dtype=torch.float32)

  if generative:
    print_cust(f"run_qfl_experiments_parallel, X_test.shape: {X_test.shape}")
    if len(X_test.shape) == 2:
      img_dim = math.isqrt(X_test.shape[1])
      assert (img_dim ** 2) == X_test.shape[1], f"run_qfl_experiments_parallel, X_test is not a perfect square, X_test.shape: {X_test.shape}"
      X_test = X_test.view(X_test.shape[0], img_dim, img_dim)
    print_cust(f"run_qfl_experiments_parallel, generative, X_test.max(): {X_test.max()}, X_test.min(): {X_test.min()}")
    save_tensors_to_folder(X_test, f"{targ_data_folder_prefix}", "img")
    n_samples_test = X_test.shape[0]


  # TODO: continue reading here, 7/9
  # also, please put this in a function......
  if shared_pca:
    # NOTE: this function MUTATES clients_data_dict
    pca_global, shared_min_comps, shared_max_comps, inv_pca_min, inv_pca_max = perform_federated_pca_mocked(clients_data_dict, max_size_clients, random_seed=random_state, math_int=math_int, device=device)
    # TODO: ... = perform_federated_pca_mocked()...

  # For sanity, print out the data as well as the amount of data that each client has.
  if debug:
    for client_type in clients_data_dict:
      clients_data_list = clients_data_dict[client_type]
      print_cust(f"client_type: {client_type}, len(clients_data_list): {len(clients_data_list)}")
      for client_idx in range(len(clients_data_list)):
        client_data = clients_data_list[client_idx]
        print_cust(f"client_type: {client_type}, client_idx: {client_idx}, len(client_data): {len(client_data)}")
        if not local_pca:
          (training_data, validation_data) = client_data
        else:
          (training_data, validation_data, pca_info) = client_data
        print_cust(f"client_type: {client_type}, client_idx: {client_idx}, len(training_data): {len(training_data)}, len(validation_data): {len(validation_data)}")
        (X_train, y_train) = training_data
        (X_val, y_val) = validation_data
        print_cust(f"client_type: {client_type}, client_idx: {client_idx}, X_train.shape: {X_train.shape}, y_train.shape: {y_train.shape}, X_val.shape: {X_val.shape}, y_val.shape: {y_val.shape}")

  # Initialize the parameters for each client.
  # qnode_func, device, pca_info
  qnode_func = None
  # device = None
  pca_info = ()
  if generative:
    qnode_func = create_qnode_qgan
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    pca_info = (pca_global, shared_min_comps, shared_max_comps, inv_pca_min, inv_pca_max)
  client_params_dict = initialize_client_params(clients_config_arg, model_size=min_size_clients, cur_client_params_dict=None, qubits_and_layers_to_add_block_params=qubits_and_layers_to_add_block_params, train_models_parallel=train_models_parallel,
                                                n_output_qubits=n_output_qubits, generative=generative, use_torch=use_torch, qnode_func=qnode_func, device=device, pca_info=pca_info)
  print_cust(f"run_qfl_experiments_parallel, first initialization of client params, client_params_dict: {client_params_dict}")

  # Make all parameters the same across clients with a new random initialization if we want clients to have the same initialization.
  if same_init:
    print_cust(f"run_qfl_experiments_parallel, same parameters initialization")
    with torch.no_grad():
      meta_params = generate_meta_params(client_params_dict, clients_data_dict, math_int=math_int)
      meta_params = generate_meta_params_random(meta_params, math_int=math_int)
      client_params_dict = broadcast_param_updates(client_params_dict, meta_params)
      print_cust(f"run_qfl_experiments_parallel, client_params_dict: {client_params_dict}")

  # client_types = sorted(list(clients_config_arg.keys()))

  # Initialize the data logging information for each communication round.

  for round_num in range(num_total_rounds):
    print_cust(f"run_qfl_experiments_parallel, round_num: {round_num}")
    # Initialize the data logging information for this communication round.
    data_logs[round_num] = {}
    data_logs[round_num]["aggregated_params"] = None
    for client_size, cfg in clients_config_arg.items():
      data_logs[round_num][client_size] = {}
      # TODO: add a field for testing loss in data_logs
      data_logs[round_num][client_size]["local_epochs"] = cfg["local_epochs"]
      data_logs[round_num][client_size]["client_metrics"] = []
      for client_idx in range(cfg["num_clients"]):
        data_logs[round_num][client_size]["client_metrics"].append({"trained_params": None,
                                                      "minibatch_losses": None,
                                                      "validation_losses": None,
                                                      "training_acc": None,
                                                      "testing_acc": None,
                                                      "testing_loss": None,
                                                      "training_acc_stdev": None,
                                                      "testing_acc_stdev": None,
                                                      "training_acc_topk": None,
                                                      "testing_acc_topk": None,
                                                      "testing_fid": None})

  rounds_elapsed = 0
  for cur_shared_model_size, model_size_rounds in client_types_to_rounds.items():

    if cur_shared_model_size > min_size_clients and not train_models_parallel:
      # two things to check:
      # 1. initialization with existing parameters - OK
      # 2. expansion from smaller to larger qubits. - easy fix in pca_rescaler
      client_params_dict = initialize_client_params(clients_config_arg, model_size=cur_shared_model_size, cur_client_params_dict=client_params_dict, qubits_and_layers_to_add_block_params=qubits_and_layers_to_add_block_params, train_models_parallel=train_models_parallel,
                                                n_output_qubits=n_output_qubits, generative=generative, use_torch=use_torch, qnode_func=qnode_func, device=device, pca_info=pca_info)
      print_cust(f"run_qfl_experiments_parallel, after expansion, client_params_dict: {client_params_dict}")
      existing_clients_datadict = list(clients_data_dict.keys())
      for existing_client_size in existing_clients_datadict:
        if cur_shared_model_size > existing_client_size:
          del clients_data_dict[existing_client_size]

      print_cust(f"run_qfl_experiments_parallel, not train_models_parallel, clients_data_dict.keys(): {clients_data_dict.keys()}")

    # For each communication round, iterate over each client and have each client train on their dataset for the specified number of epochs.
    for round_num in range(model_size_rounds):
      round_num += rounds_elapsed
      print_cust(f"run_qfl_experiments_parallel, round_num: {round_num}")

      for client_type, client_params in client_params_dict.items():
        # TODO: add a condition that, based on the current model size, skips particular client types when training is hierarchical.
        print_cust(f"run_qfl_experiments_parallel, client_type: {client_type}")
        # Get the clients' data and the number of convolutional layers for clients of this type.
        client_data = clients_data_dict[client_type]
        if train_models_parallel:
          cur_model_size = client_type
        else:
          cur_model_size = cur_shared_model_size
        print_cust(f"run_qfl_experiments_parallel, cur_model_size: {cur_model_size}")
        conv_layers = compute_conv_layers(cur_model_size, n_output_qubits, generative=generative)
        print_cust(f"run_qfl_experiments_parallel, conv_layers: {conv_layers}")
        for client_idx in range(len(client_params)):
          print_cust(f"run_qfl_experiments_parallel, client_idx: {client_idx}")
          # Get the training data and parameters for each client.
          client_params_indiv = client_params[client_idx]
          client_data_indiv = client_data[client_idx]

          client_train_data = client_data_indiv[0]
          client_val_data = client_data_indiv[1]

          if debug:
            print_cust(f"run_qfl_experiments_parallel, len(client_train_data): {len(client_train_data)}")
            print_cust(f"run_qfl_experiments_parallel, client_train_data: {client_train_data}")
            print_cust(f"run_qfl_experiments_parallel, len(client_val_data): {len(client_val_data)}")
            print_cust(f"run_qfl_experiments_parallel, client_val_data: {client_val_data}")

          grad_mask = None

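          # Gradient masking hasn't been tested for a while. It is only attempted below when mask_grads is set,
          # the model is larger than min_size_clients, and the run is not generative.
          # TODO: finish implementing (and re-test) gradient masking.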

          if mask_grads:
            if cur_model_size > min_size_clients and not generative:
              grad_mask = mask_gradients(client_params_indiv)

          print_cust(f"run_qfl_experiments_parallel, grad_mask: {grad_mask}")

          # Create the feature list indicating what features for which qubits are to be used for this client, as well as the expansion data (used for identity
          # initialization).
          expand = False
          if cur_model_size > min_size_clients:
            expand = True
          if not generative:
            feature_list, expansion_data = create_feat_list_expansion_data(cur_model_size, conv_layers, expand=expand, pool_in=pool_in, min_qubits_noexpand=min_size_clients,
                                                                          train_models_parallel=train_models_parallel, feat_sel_type=feat_sel_type)
          else:
            feature_list, expansion_data = list(range(cur_model_size)), []

          # Create the QNode for this client.
          # NOTE: for generative, there is no need to create a separate qnode; it should already be created for the generative model, upon instantiation.
          if not generative:
            qnode = create_qnode_qcnn(cur_model_size, conv_layers, expansion_data, n_classes=len(classes))

          num_local_epochs = clients_config_arg[client_type]["local_epochs"]

          if not generative:
            client_params_indiv = copy.deepcopy(client_params_indiv)

          # If amplitude embedding the data, reorder the data; otherwise, sample features based on feature_list.
          if amp_embed:
            training_data = reorder_amplitude_data(client_train_data[0], feature_list)
            validation_data = reorder_amplitude_data(client_val_data[0], feature_list)
          else:
            training_data = client_train_data[0][:, feature_list]
            validation_data = client_val_data[0][:, feature_list]

          # Train the client for the specified number of epochs.
          # TODO: update the grad_mask here.
          if not generative:
            trained_params, minibatch_losses, validation_losses = train_epochs_angle_param_adam(
          client_params_indiv, training_data, client_train_data[1], validation_data, client_val_data[1],
          n_epochs=num_local_epochs, shots=shots, batch_size=local_batch_size, lr=local_lr, qnode=qnode, trainable_mask=grad_mask)
            # Update the client's trained parameters.
            client_params[client_idx] = trained_params
          else:
            # NOTE: minibatch_losses and validation_losses could be returned for the generative case too, but that is not planned for now.
            print_cust(f"run_qfl_experiments_parallel, generative is True (QGAN training)")
            client_generator = client_params_indiv[5][0]
            print_cust(f"run_qfl_experiments_parallel, type(client_generator): {type(client_generator)}")
            client_discriminator = client_params_indiv[5][1]
            print_cust(f"run_qfl_experiments_parallel, type(client_discriminator): {type(client_discriminator)}")
            print_cust(f"run_qfl_experiments_parallel, nn.ParameterList([client_generator.q_params[-1]]: {nn.ParameterList([client_generator.q_params[-1]])}")
            client_optim_gen = torch.optim.SGD(nn.ParameterList([client_generator.q_params[-1]]), lr=lr_gen)
            client_optim_disc = torch.optim.SGD(client_discriminator.parameters(), lr=lr_disc)
            # maybe concatenate in the train with val data? not using val data otherwise
            # can use test_result_imgs for visualization, later, if I'd like.
            test_result_imgs = train_models(client_generator.n_qubits_gen, local_batch_size, client_generator, client_discriminator, client_optim_disc, client_optim_gen,
                                            noise_func=noise_func, criterion=criterion_func(), train_data=training_data, device=device, image_size=0, compressed_img_size=0, max_num_epochs=num_local_epochs,
                                            n_qubits_small=0, gen_pcas=True, disc_img_size=None, pca_disc=True)

          # Perform PCA on the testing data, if specified.
          if local_pca and not generative:
            client_pca_info = client_data_indiv[2]
            client_pca = client_pca_info[0]
            client_data_pca = client_pca_info[1]
            X_test_client_pca = client_pca.transform(X_test)
            # Scale each PCA component independently to [0, π]
            X_test_client_angle = math_int.zeros_like(X_test_client_pca)
            # assuming that the number of components is simply the client's data type
            # TODO: rescale the testing data to match the PCA scale for each client
            for i in range(client_type):
                comp = X_test_client_pca[:, i]
                if shared_pca:
                  lo, hi = shared_min_comps[i], shared_max_comps[i]
                  comp_norm = ( (comp - lo) / (hi - lo + 1e-8) )
                  # comp_norm = np.clip(comp_norm, 0, 1)
                else:
                  orig_comp = client_data_pca[:, i]
                  comp_norm = (comp - orig_comp.min()) / (orig_comp.max() - orig_comp.min() + 1e-8)
                X_test_client_angle[:, i] = comp_norm * math_int.pi
          else:
            X_test_client_angle = X_test

          # Select the features for amplitude encoding based on feat_ordering.
          if amp_embed:
            if feat_ordering == "highest_var":
              variances = X_angles.var(axis=0, ddof=0)
              order = np.argsort(variances)[::-1]
              X_test_client_angle = X_test_client_angle[:, order] + 1e-3
            testing_data = reorder_amplitude_data(X_test_client_angle, feature_list)
          else:
            if not generative:
              testing_data = X_test_client_angle[:, feature_list]

            # TODO: encapsulate this logic in a helper function


          # train_acc = compute_avg_acc_angle_param_batch(trained_params, client_train_data[0][:, feature_list], client_train_data[1], layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode)

          # test_acc = compute_avg_acc_angle_param_batch(trained_params, X_test_client_angle[:, feature_list], y_test, layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode)

          # test_loss = compute_loss_angle_param_batch(trained_params, X_test_client_angle[:, feature_list], y_test, shots=shots, qnode=qnode)

          # train_acc_stdev = compute_std_acc_angle_param_batch(trained_params, client_train_data[0][:, feature_list], client_train_data[1], layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode)

          # test_acc_stdev = compute_std_acc_angle_param_batch(trained_params, X_test_client_angle[:, feature_list], y_test, layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode)

          # train_acc_topk = compute_top_k_acc_angle_param_batch(trained_params, client_train_data[0][:, feature_list], client_train_data[1], layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode)

          # test_acc_topk = compute_top_k_acc_angle_param_batch(trained_params, X_test_client_angle[:, feature_list], y_test, layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode)

          # Compute the training and testing metrics.
          if not generative:
            train_acc, train_acc_stdev, train_acc_topk, train_loss = compute_metrics_angle_param_batch(trained_params, training_data, client_train_data[1], layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode)

            # Store the training and testing metrics.
            data_logs[round_num][cur_model_size]["client_metrics"][client_idx]["training_acc"] = train_acc
            # data_logs[round_num][cur_model_size]["client_metrics"][client_idx]["testing_acc"] = test_acc
            # data_logs[round_num][cur_model_size]["client_metrics"][client_idx]["testing_loss"] = test_loss

            data_logs[round_num][cur_model_size]["client_metrics"][client_idx]["training_acc_stdev"] = train_acc_stdev
            # data_logs[round_num][cur_model_size]["client_metrics"][client_idx]["testing_acc_stdev"] = test_acc_stdev

            data_logs[round_num][cur_model_size]["client_metrics"][client_idx]["training_acc_topk"] = train_acc_topk
            # data_logs[round_num][cur_model_size]["client_metrics"][client_idx]["testing_acc_topk"] = test_acc_topk

            data_logs[round_num][cur_model_size]["client_metrics"][client_idx]["trained_params"] = trained_params
            data_logs[round_num][cur_model_size]["client_metrics"][client_idx]["minibatch_losses"] = minibatch_losses
            data_logs[round_num][cur_model_size]["client_metrics"][client_idx]["validation_losses"] = validation_losses

      # perform aggregation
      # NOTE: assume, during aggregation, that the parameters for all the clients have the same shape.
      # Aggregate parameters based on the strategy provided.
      print_cust(f"run_qfl_experiments_parallel, round_num: {round_num}, parameter aggregation, agg_strategy: {agg_strategy}")

      print_cust(f"run_qfl_experiments_parallel, round_num: {round_num}, before aggregation, client_params_dict: {client_params_dict}")

      if agg_strategy == "fedavg":
        aggregated_params = federated_averaging(client_params_dict, clients_data_dict)
      elif agg_strategy == "fedavg_quat":
        aggregated_params = federated_averaging_quat(client_params_dict, clients_data_dict)
      elif agg_strategy == "fedavg_circ":
        aggregated_params = federated_averaging_circular_parallel_shared(client_params_dict, clients_data_dict, generative=generative)
        # print_cust(f"run_qfl_experiments, aggregated_params: {aggregated_params}")

      print_cust(f"run_qfl_experiments_parallel, round_num: {round_num}, after aggregation, aggregated_params: {aggregated_params}")

      # # broadcast parameter updates
      # for client_type, client_params in client_params_dict.items():
      #   for client_idx in range(len(client_params)):
      #     client_params[client_idx] = aggregated_params

      # client_params_dict = broadcast_param_updates_shared(client_params_dict, aggregated_params)

      # Broadcast parameters based on the personalization scheme.
      if morepers == "aggshared":
        client_params_dict = broadcast_param_updates_shared(client_params_dict, aggregated_params)
      elif morepers == "aggnoconv":
        client_params_dict = broadcast_param_updates_shared_noconv(client_params_dict, aggregated_params)
      elif morepers == "aggnoconvnofinal":
        client_params_dict = broadcast_param_updates_shared_noconv_nofinal(client_params_dict, aggregated_params)

      print_cust(f"run_qfl_experiments_parallel, round_num: {round_num}, after broadcasting, client_params_dict: {client_params_dict}")

      # Store the aggregated parameters for this round.
      # TODO: aggregated parameters don't have a significant meaning as not all of them are broadcasted to clients based on the broadcasting/personalization scheme
      data_logs[round_num]["aggregated_params"] = aggregated_params


      # TODO: recompute metrics of the SHARED model for EACH client.
      one_cli_size = list(client_params_dict.keys())[0]

      if generative:
        # NOTE: this assumes that all generators and discriminators have the same parameters.
        agg_model = client_params_dict[one_cli_size][0][5][0]
        # agg_disc = client_params_dict[one_cli_size][0][5][1]
      else:
        agg_model = client_params_dict[one_cli_size][0]


      test_acc, test_acc_stdev, test_acc_topk, test_loss, test_fid_score = (None, None, None, None, None)

      if fid_batch_size is None: # for now, for debug
        fid_batch_size = n_samples_test

      if num_total_rounds > 10:
        if (round_num + 1) % 10 == 0:
          if not generative:
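            # TODO: this is broken for the non-generative case: testing_data, conv_layers and qnode used here are whatever
            # the last iteration of the per-client loop above left behind. Process the testing data outside of the
            # per-client loop instead. Not done yet, since it is not too relevant for the generative case.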
            test_acc, test_acc_stdev, test_acc_topk, test_loss = compute_metrics_angle_param_batch(agg_model, testing_data, y_test, layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode)
          else:
            folder_id_suffix = f"round_num_{round_num}_cur_shared_model_size_{cur_shared_model_size}"
            test_fid_score = compute_fid_to_data(agg_model, noise_func, f"{targ_data_folder_prefix}",
                                                f"{gen_data_folder_prefix}/{folder_id_suffix}", n_samples_test, device, fid_batch_size=fid_batch_size)
        else:
          test_acc, test_acc_stdev, test_acc_topk, test_loss = (None, None, None, None)
      else:
        if not generative:
          test_acc, test_acc_stdev, test_acc_topk, test_loss = compute_metrics_angle_param_batch(agg_model, testing_data, y_test, layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode)
        else:
          folder_id_suffix = f"round_num_{round_num}_cur_shared_model_size_{cur_shared_model_size}"
          test_fid_score = compute_fid_to_data(agg_model, noise_func, f"{targ_data_folder_prefix}",
                                              f"{gen_data_folder_prefix}/{folder_id_suffix}", n_samples_test, device, fid_batch_size=fid_batch_size)
      if not generative:
        print_cust(f"run_qfl_experiments_parallel, round_num: {round_num}, test_acc: {test_acc}, test_loss: {test_loss}, test_acc_stdev: {test_acc_stdev}, test_acc_topk: {test_acc_topk}")
      else:
        print_cust(f"run_qfl_experiments_parallel, round_num: {round_num}, cur_model_size: {cur_model_size}, test_fid_score: {test_fid_score}")

    rounds_elapsed += model_size_rounds

  return data_logs

"""## QFL Experiment Parallel Multiprocess"""

import copy
import pickle
# import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed, ThreadPoolExecutor
from functools import partial
from typing import Dict, Tuple, Any, List

"""### QFL Experiment Train Single Client Helper Function"""

import sys, contextlib

def _train_single_client(job: Dict[str, Any]) -> Dict[str, Any]:
    """
    Worker executed in a separate process.

    Returns a dict with:
        client_type, client_idx,
        trained_params, minibatch_losses, validation_losses,
        train_acc, train_acc_stdev, train_acc_topk,
        test_acc,  test_acc_stdev,  test_acc_topk, test_loss
    """
    # --- unpack ------------------------------------------------------------
    # NOTE, layers, for client_params_indiv: it is ALL OF THE CLIENT PARAMS. list of block params is also there.
    (client_type, client_idx, client_params_indiv,
     client_train_data, client_val_data, testing_data,
     cur_model_size, min_size_clients, pool_in, feat_sel_type,
     train_models_parallel, amp_embed, shots, local_batch_size,
     local_lr, num_local_epochs, conv_layers, feat_ordering, classes,
     qnode_builder, num_total_rounds, round_num, grad_mask, generative, lr_gen, lr_disc,
     noise_func, criterion_func, log_data_folder, device, client_optims_indiv, optim_type,
     gen_betas, disc_betas, use_torch, pennylane_interface, opt_layers, layer_types_list,
     loss_type, lr_disc_decay, cont_optim_state) = (job[k] for k in
         ("client_type","client_idx","client_params_indiv",
          "client_train_data","client_val_data","testing_data",
          "cur_model_size","min_size_clients","pool_in","feat_sel_type",
          "train_models_parallel","amp_embed","shots","local_batch_size",
          "local_lr","num_local_epochs","conv_layers","feat_ordering","classes",
          "qnode_builder","num_total_rounds","round_num","grad_mask","generative",
          "lr_gen","lr_disc","noise_func","criterion_func","log_data_folder","device",
          "client_optims_indiv","optim_type","gen_betas","disc_betas","use_torch","pennylane_interface","opt_layers",
          "layer_types_list", "loss_type", "lr_disc_decay", "cont_optim_state"))

    with open(f"{log_data_folder}/cli_type_{client_type}_client_idx_{client_idx}_round_num_{round_num}_stdout.txt", "w") as fout_loc, open(f"{log_data_folder}/cli_type_{client_type}_client_idx_{client_idx}_round_num_{round_num}_stderr.txt", "w") as ferr_loc:
      with contextlib.redirect_stdout(fout_loc), contextlib.redirect_stderr(ferr_loc):

        print_cust(f"_train_single_client, opt_layers: {opt_layers}")

        print_cust(f"_train_single_client, device: {device}")

        print_cust(f"_train_single_client, qnode_builder: {qnode_builder}")

        print_cust(f"_train_single_client, lr_gen: {lr_gen}, lr_disc: {lr_disc}")

        # TODO: won't work for nongenerative case, but not going to worry about that for now.
        print_cust(f"_train_single_client, client_optims_indiv: {client_optims_indiv}")

        print_cust(f"_train_single_client, optim_type: {optim_type}")

        print_cust(f"_train_single_client, gen_betas: {gen_betas}, disc_betas: {disc_betas}")

        print_cust(f"_train_single_client, grad_mask: {grad_mask}, generative: {generative}")

        print_cust(f"_train_single_client, client_train_data: {client_train_data}")

        print_cust(f"_train_single_client, conv_layers: {conv_layers}")

        print_cust(f"_train_single_client, round_num: {round_num}, lr_disc_decay: {lr_disc_decay}")

        print_cust(f"_train_single_client, cont_optim_state: {cont_optim_state}")

        if not generative and conv_layers == 0:
          # NOTE, layers: override cur_model_size
          num_qubits_bps = []
          for block_param in client_params_indiv[5]:
            num_qubits_bps.append(block_param.shape[1])

          print_cust(f"_train_single_client, num_qubits_bps: {num_qubits_bps}")

          # NOTE, layers: this is a logical override.
          cur_model_size = max(num_qubits_bps)

          print_cust(f"_train_single_client, not generative and conv_layers == 0, cur_model_size: {cur_model_size}")

        # ----------------------------------------------------------------------
        # build feature list + QNode locally (nothing un‑pickle‑able is sent in)
        # DONE (ish): TOMODIFY, Layers: don't have any expansion data for the layers expansion (for now)
        # ^ TOMODIFY, LAYERS, HACK (ish): not called if conv_layers 0.
        if not generative and conv_layers > 0:
          print_cust(f"_train_single_client, calling create_feat_list_expansion_data")
          feature_list, expansion_data = create_feat_list_expansion_data(
              cur_model_size, conv_layers, expand=cur_model_size>min_size_clients,
              pool_in=pool_in, min_qubits_noexpand=min_size_clients,
              train_models_parallel=train_models_parallel,
              feat_sel_type=feat_sel_type)
        else:
          feature_list, expansion_data = list(range(cur_model_size)), []
        print_cust(f"_train_single_client, feature_list: {feature_list}, expansion_data: {expansion_data}")
        if not generative and conv_layers > 0:
          # DONE: TOMODIFY, layers: inject pennylane_interface here.
          qnode = qnode_builder(cur_model_size, conv_layers,
                                expansion_data, n_classes=len(classes), pennylane_interface=pennylane_interface)
        # DONE: TOMODIFY, layers: have another option for creating a quantumclassifer using the builder, *BUT*
        # note that the input now is the list of tensors representing the variational parameters.
        # DONE: TOMODIFY, layers: inject pennylane_interface here.
        elif not generative:
          print_cust(f"_train_single_client, conv_layers should be 0, conv_layers: {conv_layers}")

          # NOCHANGE: TOMODIFY, depthFL: add an arg for number of classifiers to create.
          # ^ TOMODIFY, depthFL: have some kind of argument to CONTROL the number/amount of classifiers I want to create.
          # classifier_data_comps = [cur_model_size, conv_layers, expansion_data, len(classes), pennylane_interface, layer_types_list, device]
          # DONE: BUG, depthFL: classifier_data_comps needs to include device somehow.... but doesn't work if it's now a dictionary.. probably, change build_vqc
          # function.
          classifier_data_comps = {
            "n_data": cur_model_size,
            "conv_layers": conv_layers,
            "expansion_data": expansion_data,
            "n_classes": len(classes),
            "pennylane_interface": pennylane_interface,
            "layer_types_list": layer_types_list,
            "device": device
          }
          cur_classifier_params = client_params_indiv
          print_cust(f"_train_single_client, cur_classifier_params: {cur_classifier_params}")
          cli_classifier_loc = build_variationalquantumclassifier(classifier_data_comps, cur_classifier_params, qnode_builder)
          print_cust(f"_train_single_client, cli_classifier_loc: {cli_classifier_loc}")
        else:
          (gen_state_dict, gen_metadata) = client_params_indiv[5][0]
          (disc_state_dict, disc_metadata) = client_params_indiv[5][1]
          print_cust(f"_train_single_client, gen_state_dict: {gen_state_dict}, gen_metadata: {gen_metadata}")
          print_cust(f"_train_single_client, disc_state_dict: {disc_state_dict}, disc_metadata: {disc_metadata}")
          cli_generator_loc = build_patchquantumgen(gen_metadata, gen_state_dict, qnode_builder)
          cli_disc_loc = build_pca_discriminator(disc_metadata, disc_state_dict)
          print_cust(f"_train_single_client, cli_generator_loc: {cli_generator_loc}")
          print_cust(f"_train_single_client, cli_disc_loc: {cli_disc_loc}")




        # ----------------------------------------------------------------------
        # slice / reorder data --------------------------------------------------
        if amp_embed:
            # TOMODIFY, layers, amp embed: comment this out. hopefully, don't need to do this; just do vanilla ampencode.
            # TOMODIFY, layers: what I should do is have a condition; if feature_list is just the range of qubits, then
            # don't do anything. Otherwise, call reorder_amplitude_data.
            print_cust(f"_train_single_client, amp_embed is True")
            # X_train = reorder_amplitude_data(client_train_data[0], feature_list)
            # X_val   = reorder_amplitude_data(client_val_data[0],  feature_list)
            # X_test  = reorder_amplitude_data(testing_data[0],      feature_list)
            X_train = client_train_data[0]
            X_val = client_val_data[0]
            X_test = testing_data[0]
        else:
            X_train = client_train_data[0][:, feature_list]
            X_val   = client_val_data[0][:,  feature_list]
            if not generative:
              X_test  = testing_data[0][:,     feature_list]
              # DONE (NO CHANGE): TOMODIFY, layers: X_test needs to be ordered in terms of feature_list.
              # ^ I mean, variational QNN is not generative

        # ----------------------------------------------------------------------
        # local training --------------------------------------------------------
        if not generative and conv_layers > 0:
          trained_params, minibatch_losses, validation_losses = train_epochs_angle_param_adam(
              copy.deepcopy(client_params_indiv),
              X_train, client_train_data[1],
              X_val,   client_val_data[1],
              n_epochs=num_local_epochs, shots=shots,
              batch_size=local_batch_size, lr=local_lr,
              qnode=qnode, trainable_mask=grad_mask)           # grad_mask handled outside

          # ----------------------------------------------------------------------
          # metrics ---------------------------------------------------------------
          if use_torch:
            math_int = torch
          else:
            math_int = np

          print_cust(f"_train_single_client, math_int: {math_int}")

          train_acc, train_acc_stdev, train_acc_topk, _ = \
              compute_metrics_angle_param_batch(trained_params, X_train,
                                                client_train_data[1],
                                                layers=conv_layers, shots=shots,
                                                batch_size=local_batch_size,
                                                qnode=qnode, math_int=math_int)
        elif not generative:
          # DONE: TOMODIFY, layers: call a different training function
          # NOT IMPLEMENTED (should be implemented in the main func): and use the updated compute_metrics_angle_param_batch function.
          # DONE: TOMODIFY, layers: initialize my optimizer from my optimizer state_dict, as before.
          print_cust(f"_train_single_client, conv_layers should be 0, conv_layers: {conv_layers}")
          client_classifier = cli_classifier_loc
          print_cust(f"_train_single_client, client_classifier: {client_classifier}")
          curr_cli_blockparams = client_classifier.q_params
          list_params = []
          if opt_layers is None:
            list_params = [block_param_torch for block_param_torch in curr_cli_blockparams]
          else:
            for layer_idx in opt_layers:
              layer_opt_param = curr_cli_blockparams[layer_idx]
              list_params.append(layer_opt_param)

          if optim_type == "sgd":
            client_optim_classifier = torch.optim.SGD(nn.ParameterList(list_params), lr=lr_disc * (lr_disc_decay ** round_num))
          elif optim_type == "adam":
            client_optim_classifier = torch.optim.Adam(nn.ParameterList(list_params), lr=lr_disc * (lr_disc_decay ** round_num), betas=disc_betas)

          if client_optims_indiv is not None and cont_optim_state:
            # NOTE: this is a SHALLOW copy of loading state dict from previous client optimizer.
            # HACK, depthFL: loading in state dict into client_optim_classifier only in non-SGD is a hack; figure out
            # how to dynamically change LR.
            print_cust(f"_train_single_client, loading state dict for client_optim_classifier")
            client_optim_classifier.load_state_dict(client_optims_indiv[5])

          print_cust(f"_train_single_client, client_optim_classifier: {client_optim_classifier}")

          print_cust(f"_train_single_client, X_train: {X_train}")

          # TODO, 8/4, 6:38 PM: continue here
          train_metrics_dict = train_models_classifier(local_batch_size, client_classifier, client_optim_classifier, criterion_func(), X_train, client_train_data[1], device, num_local_epochs, loss_type=loss_type)

          trained_classifier_params = client_classifier.q_params

          print_cust(f"_train_single_client, trained_classifier_params: {trained_classifier_params}")

          trained_optim_classifier = client_optim_classifier.state_dict()

          print_cust(f"_train_single_client, trained_optim_classifier: {trained_optim_classifier}")

          # MODIFY, layers: call compute_metrics_angle_param_batch to get the training accuracies.

        else:
          # NOTE: minibatch_losses and validation_losses could be returned for the generative case too, but that is not planned for now.
          print_cust(f"_train_single_client, generative is True (QGAN training)")
          client_generator = cli_generator_loc
          print_cust(f"_train_single_client, client_generator: {client_generator}")
          client_discriminator = cli_disc_loc
          print_cust(f"_train_single_client, client_discriminator: {client_discriminator}")
          # client_optim_gen = torch.optim.SGD(nn.ParameterList([client_generator.q_params[-1]]), lr=lr_gen)
          # client_optim_disc = torch.optim.SGD(client_discriminator.parameters(), lr=lr_disc)
          if optim_type == "sgd":
            client_optim_gen = torch.optim.SGD(nn.ParameterList([client_generator.q_params[-1]]), lr=lr_gen)
            client_optim_disc = torch.optim.SGD(client_discriminator.parameters(), lr=lr_disc * (lr_disc_decay ** round_num))
          elif optim_type == "adam":
            client_optim_gen = torch.optim.Adam(nn.ParameterList([client_generator.q_params[-1]]), lr=lr_gen, betas=gen_betas)
            client_optim_disc = torch.optim.Adam(client_discriminator.parameters(), lr=lr_disc * (lr_disc_decay ** round_num), betas=disc_betas)

          if client_optims_indiv is not None:
            client_optim_gen.load_state_dict(client_optims_indiv[5][0])
            client_optim_disc.load_state_dict(client_optims_indiv[5][1])

          print_cust(f"_train_single_client, client_optim_gen: {client_optim_gen}")
          print_cust(f"_train_single_client, client_optim_disc: {client_optim_disc}")
          # maybe concatenate in the train with val data? not using val data otherwise
          # can use test_result_imgs for visualization, later, if I'd like.
          print_cust(f"_train_single_client, X_train: {X_train}")
          test_result_imgs, train_metrics_dict = train_models(client_generator.n_qubits_gen, local_batch_size, client_generator, client_discriminator, client_optim_disc, client_optim_gen,
                                          noise_func=noise_func, criterion=criterion_func(), train_data=X_train, device=device, image_size=0, compressed_img_size=0, max_num_epochs=num_local_epochs,
                                          n_qubits_small=0, gen_pcas=True, disc_img_size=None, pca_disc=True)
          print_cust(f"_train_single_client, train_metrics_dict: {train_metrics_dict}")
          trained_gen_params = client_generator.state_dict()
          trained_disc_params = client_discriminator.state_dict()

          print_cust(f"_train_single_client, trained_gen_params: {trained_gen_params}")
          print_cust(f"_train_single_client, trained_disc_params: {trained_disc_params}")

          trained_optim_gen = client_optim_gen.state_dict()
          trained_optim_disc = client_optim_disc.state_dict()

          print_cust(f"_train_single_client, trained_optim_gen: {trained_optim_gen}")
          print_cust(f"_train_single_client, trained_optim_disc: {trained_optim_disc}")

        test_acc, test_acc_stdev, test_acc_topk, test_loss = (None, None, None, None)

        # test_acc, test_acc_stdev, test_acc_topk, test_loss = \
        #     compute_metrics_angle_param_batch(trained_params, X_test,
        #                                       testing_data[1],
        #                                       layers=conv_layers, shots=shots,
        #                                       batch_size=local_batch_size,
        #                                       qnode=qnode)

        # if num_total_rounds > 10:
        #   if (round_num + 1) % 10 == 0:
        #     test_acc, test_acc_stdev, test_acc_topk, test_loss = compute_metrics_angle_param_batch(trained_params, X_test, testing_data[1], layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode)
        #   else:
        #     test_acc, test_acc_stdev, test_acc_topk, test_loss = (None, None, None, None)
        # else:
        #   test_acc, test_acc_stdev, test_acc_topk, test_loss = compute_metrics_angle_param_batch(trained_params, X_test, testing_data[1], layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode)
        # ----------------------------------------------------------------------
        if not generative and conv_layers > 0:
          ret_dict = dict(client_type=client_type,
                    client_idx=client_idx,
                    trained_params=trained_params,
                    minibatch_losses=minibatch_losses,
                    validation_losses=validation_losses,
                    train_acc=train_acc,
                    test_acc=test_acc,
                    test_loss=test_loss,
                    train_acc_stdev=train_acc_stdev,
                    test_acc_stdev=test_acc_stdev,
                    train_acc_topk=train_acc_topk,
                    test_acc_topk=test_acc_topk)
        elif not generative:
          # NOTE, layers: trying to match the same interface as in the generative case.
          # TOMODIFY, layers: add additional training metrics (training accuracies) here.
          ret_dict = dict(client_type=client_type,
                          client_idx=client_idx,
                          trained_disc_params=trained_classifier_params,
                          trained_optim_disc=trained_optim_classifier,
                          train_metrics_dict=train_metrics_dict)
        else:
          ret_dict = dict(client_type=client_type,
                          client_idx=client_idx,
                          trained_gen_params=trained_gen_params,
                          trained_disc_params=trained_disc_params,
                          trained_optim_gen=trained_optim_gen,
                          trained_optim_disc=trained_optim_disc,
                          train_metrics_dict=train_metrics_dict)
        # DONE: TOMODIFY, layers: return new object, with trained layers params (send back the list of block params for consistency), and send back the
        # torch optimizer state dict. (Later -- can send back a training metrics dictionary.)
        # DONE (just did the point above): TOMODIFY, layers: for minimum viable things, send back whatever metrics are necessary.

    return ret_dict


"""### DummyExecutor Class"""
from concurrent.futures import Future, Executor
from threading import Lock
class DummyExecutor(Executor):
    """
    Synchronous drop-in for a concurrent.futures executor.

    submit() runs the callable immediately in the calling thread and returns a
    Future that is already finished, so code written against the Executor
    interface (submit / map / shutdown / context manager) can be run serially
    without spawning threads or processes.

    Parameters:
      max_workers, accepted for interface parity with the pool executors; it is
      stored on the instance but not otherwise used by this class
    """

    def __init__(self, max_workers=0):
        self._shutdown = False
        # Guards reads/writes of the _shutdown flag only.
        self._shutdownLock = Lock()
        self.max_workers = max_workers

    def submit(self, fn, *args, **kwargs):
        """
        Run fn(*args, **kwargs) inline and return a completed Future.

        Returns:
          a Future holding either the return value of fn or the exception it raised

        Raises:
          RuntimeError, if the executor has already been shut down
        """
        # Only the shutdown check happens under the lock. The lock is a plain
        # (non-reentrant) Lock, so executing fn while holding it would deadlock
        # as soon as fn itself calls submit() or shutdown() on this executor, and
        # would serialize every other thread's submit() behind the running task.
        with self._shutdownLock:
            if self._shutdown:
                raise RuntimeError('cannot schedule new futures after shutdown')

        f = Future()
        # Executor implementations are expected to mark the future as running
        # before executing the work associated with it.
        f.set_running_or_notify_cancel()
        try:
            result = fn(*args, **kwargs)
        except BaseException as e:
            # Captured on the future (same as the original behaviour); it is
            # re-raised to the caller by Future.result().
            f.set_exception(e)
        else:
            f.set_result(result)

        return f

    def shutdown(self, wait=True, *, cancel_futures=False):
        """
        Mark the executor as shut down; later submit() calls raise RuntimeError.

        wait and cancel_futures are accepted for signature compatibility with
        Executor.shutdown. Work is executed synchronously inside submit(), so
        there is never anything pending to wait for or to cancel.
        """
        with self._shutdownLock:
            self._shutdown = True

"""### QFL Experiment Parallel Multiprocess Function"""

import os

def run_qfl_experiments_parallel_multiprocess(clients_config_arg, classes=["4", "9"], n_samples=1000, dataset_type="mnist", agg_strategy="fedavg", test_frac=0.2, val_frac=0.1, random_state=42, pool_in=True,
                        local_batch_size=32, local_lr=0.01, shots=1024, debug=False, init_client_data_dict=None, save_pkl=False, mask_grads=False, qubits_and_layers_to_add_block_params=[],
                        train_models_parallel=False, same_init=False, feature_skew=0.0, label_skew=None, local_pca=False, do_lda=False, feat_sel_type="top", amp_embed=False, feat_ordering="same",
                                 morepers=False, custom_debug=False, shared_pca=False, heirarchical_train=False, generative=False, use_torch=False, fed_pca_mocked=True, lr_gen=0.004, lr_disc=0.001,
                                 noise_func=generate_latent_noise, criterion_func=nn.BCELoss, targ_data_folder_prefix="testing_gen_imgs", gen_data_folder_prefix="qgan_gen_imgs", device=None, fid_batch_size=None,
                                              max_workers=None, mp_ctx=None, log_data_folder="local_placeholder", initial_supp_params=None, optim_type="sgd", gen_betas=(0.5, 0.9), disc_betas=(0.5, 0.9),
                                              resc_invpca=True, compute_fid=True, is_qcnn=True, pennylane_interface="autograd", opt_layers=[-1], alt_zeros_init="", multiclassifier_type="", testacc_rd_cutoff=101,
                                              qubits_and_layer_types_block_params=[], loss_type="depthfl", lr_disc_decay=1.0, cont_optim_state=False):

    """
    Function that runs the QFL workflow for a given configuration and returns a log of the training.

    Note: in general, the client parameters dictionary is of the form:

        {
          <int>: [[conv_params, pool_params, final_pool_params,
              final_params, bias_param, block_params],
              ... for each client],
          ... for each client type
        }

    Parameters:
      clients_config_arg, a dictionary mapping client types to configuration information for clients of that type. For this code, I expect clients_config_arg to be of the form:
        {
          <int>: {
              "percentage_data": <float>,
              "num_clients": <int>,
              "local_epochs": <int>,
              "communication_rounds": <int>
          },
          ... for each client type
        }
      classes, a list of strings (preferably string representations of integers) representing the classes on which to perform classification
        ex: ["4", "9"]
      n_samples, an integer representing the number of samples to use for training
      dataset_type, a string representing the dataset type for training
        should either be "mnist", "Fashion-MNIST", "cifar10", "synthetic", "pima", "higgs", or "covertype" for those respective datasets
      agg_strategy, a string representing the aggregation scheme to perform
        should be either "fedavg", "fedavg_quat", or "fedavg_circ"
      test_frac, a float representing the fraction of the total data to be used for testing
        should be between 0 and 1
      val_frac, a float representing the fraction of each client's data to be used for validation
        should be between 0 and 1
      random_state, an integer representing the random state used in the entire program (assuming that this function generates client data -- TODO: make the use of the random
      state argument more clear)
      pool_in, a Boolean representing whether or not the model should reduce the number of qubits to the center
      local_batch_size, an integer representing the batch size each local model should use for training
      local_lr, a float representing the learning rate each model should use for training, upon each communication round
      shots, an integer representing the number of shots each model should use (currently unused)
      debug, a Boolean indicating whether or not this function should be run in debug mode (prints the data, the amount of data per client)
      init_client_data_dict, a dictionary that maps client types to lists of data for each client that is used for each client
        should be of the form
        {
          <int>: [[(X_train, y_train), (X_val, y_val), (pca_obj, pca_reduced_data)],
                  ... for each client
                ],
          ... for each client type
        }
      save_pkl, a Boolean indicating whether or not the data should be saved to a pickle file
      mask_grads, a Boolean indicating whether or not the gradients should have a Boolean mask (and parameters in smaller qubit sets should not be updated)
      qubits_and_layers_to_add_block_params, a dictionary that maps client types to a list of (n_qubits, n_layers) representing the block parameters that each client has
      train_models_parallel, a Boolean indicating whether or not the models should be trained in parallel (for personalized models)
      same_init, a Boolean indicating whether or not all the clients should start with the same initial parameters
      feature_skew, a float between 0 and 1 specifying the magnitude of the feature skew (i.e., the strength of sorting the features by the first feature and how prominently that appears
      in the data for clients)
      label_skew, a float between 0 and 1 specifying the magnitude of the label skew that each client has
      do_lda, a Boolean indicating whether or not to perform random sketching
      feat_sel_type, a string representing the choice of features to make in angle encoding for each client
        should be either "top" for selecting the features with the highest variance, or "toplow" for selecting half the features with the highest variance and half the features with the
        lowest variance at each expansion
      amp_embed, a Boolean indicating whether the data should be amplitude encoded
      feat_ordering, a string representing whether the features in amplitude encoding should be taken as-is, or if it should be taken in a different order
        should be either "same" for as-is feature ordering, or "highest_var" for sorting the features in descending order in terms of highest variance
      morepers, a Boolean indicating whether or not only the convolutional parameters should or should not be aggregated (if not, the convolutional parameters are unique to each client,
      so each client has its own personalized convolutional parameters -- see PDF drawing sent in Slack)

      Returns:
        data_logs, a dictionary of the following form:
          {
            'clients_data_dict': <client_data_dict> of the above form,
            'testing_data': (X_test, y_test),
            0:
              {
                "aggregated_params": <params>,
                <int>: {
                  "local_epochs": <int>,
                  "client_metrics": [
                      {
                        "trained_params": <params>,
                        "minibatch_losses": list<float>,
                        "validation_losses": list<float>,
                        "training_acc": <float>,
                        "testing_acc": <float>,
                        "testing_loss": <float>,
                        "training_acc_stdev": <float>,
                        "testing_acc_stdev": <float>,
                        "training_acc_topk": <float>,
                        "testing_acc_topk": <float>
                      }, ... (for each client)
                  ]
                },
                ... for each client type
              },
            ... for each communication round
            clients_config_arg["communication_rounds"] - 1: (same format as above data logs dictionary)
          }
    """
    print_cust(f"run_qfl_experiments_parallel_multiprocess, alt_zeros_init: {alt_zeros_init}")
    print_cust(f"run_qfl_experiments_parallel_multiprocess, opt_layers: {opt_layers}")

    print_cust(f"run_qfl_experiments_parallel_multiprocess, compute_fid: {compute_fid}")
    print_cust(f"run_qfl_experiments_parallel_multiprocess, lr_gen: {lr_gen}, lr_disc: {lr_disc}")
    print_cust(f"run_qfl_experiments_parallel_multiprocess, resc_invpca: {resc_invpca}")

    print_cust(f"run_qfl_experiments_parallel_multiprocess, lr_disc_decay: {lr_disc_decay}")

    print_cust(f"run_qfl_experiments_parallel_multiprocess, cont_optim_state: {cont_optim_state}")

    # TOADD, generative:
    # 1. An option to this function to indicate that we are having a generative model

    # TOADD, layers: an argument for the classification circuit WITHOUT QCNN part.

    # TOMODIFY, layers (NOT strictly necessary, for now): add an argument to this function specifying we are doing layer expansion

    print_cust(f"run_qfl_experiments_parallel_multiprocess, generative: {generative}")

    print_cust(f"run_qfl_experiments_parallel_multiprocess, mp_ctx: {mp_ctx}")

    print_cust(f"run_qfl_experiments_parallel_multiprocess, heirarchical_train: {heirarchical_train}")

    print_cust(f"run_qfl_experiments_parallel_multiprocess, log_data_folder: {log_data_folder}")

    print_cust(f"run_qfl_experiments_parallel_multiprocess, optim_type: {optim_type}")
    # Find the maximum, minimum size clients, as well as the total communication rounds
    max_size_clients = max(clients_config_arg.keys())
    min_size_clients = min(clients_config_arg.keys())

    if custom_debug:
      print_cust(f"run_qfl_experiments_parallel_multiprocess, max_size_clients: {max_size_clients}, min_size_clients, {min_size_clients}")
      assert max_size_clients >= min_size_clients, "Maximum sized client is NOT at least as large as minimum sized client"

    num_total_rounds = max([clients_config_arg[key]["communication_rounds"] for key in clients_config_arg])

    # Note: this does impose some more constraints on the format in which the round info is passed (larger MUST be > smaller; training 4 implies training 8.) going to stick with this for now
    total_rounds_accum = 0
    client_types_to_rounds = {}
    for client_type in sorted(clients_config_arg.keys()):
      client_rounds = clients_config_arg[client_type]["communication_rounds"]
      client_types_to_rounds[client_type] = client_rounds - total_rounds_accum
      total_rounds_accum += (client_rounds - total_rounds_accum)

    # this should be client type to rounds TO RUN for basically clients that size AND ABOVE.
    print_cust(f"run_qfl_experiments_parallel_multiprocess, client_types_to_rounds: {client_types_to_rounds}")

    if custom_debug:
      print_cust(f"run_qfl_experiments_parallel_multiprocess, num_total_rounds: {num_total_rounds}")
      assert num_total_rounds >= 0, "Number of total rounds is not at least 0"

    # The number of output qubits is enough to accommodate the number of classes.
    n_output_qubits = int(np.ceil(np.log2(len(classes))))

    if custom_debug:
      print_cust(f"run_qfl_experiments_parallel_multiprocess, n_output_qubits: {n_output_qubits}")
      assert n_output_qubits >= 0, "Number of output qubits is not at least 0"

    print_cust(f"run_qfl_experiments_parallel_multiprocess, num_total_rounds: {num_total_rounds}")

    print_cust(f"run_qfl_experiments_parallel_multiprocess, n_output_qubits: {n_output_qubits}")

    # Load the input dataset. If we are doing amplitude embedding or local PCA, then we do not want to dimensionality reduce the input images.
    # Note that the pixels are normalized to be between 0 and 1.
    keep_orig_imgs = (local_pca or amp_embed)

    if custom_debug:
      print_cust(f"run_qfl_experiments_parallel_multiprocess, keep_orig_imgs: {keep_orig_imgs}")

    X_angles, y = load_dataset(dataset_type=dataset_type, classes=classes, n_samples=n_samples, num_feats=max_size_clients, keep_orig_imgs=keep_orig_imgs, custom_debug=custom_debug)

    print_cust(f"run_qfl_experiments_parallel_multiprocess, X_angles.shape: {X_angles.shape}")

    print_cust(f"run_qfl_experiments_parallel_multiprocess, local_pca: {local_pca}, do_lda: {do_lda}")

    # Load in the previous data dictionary, if supplied.
    if init_client_data_dict is not None:
      clients_data_dict = init_client_data_dict["clients_data_dict"]
      (X_test, y_test) = init_client_data_dict["testing_data"]
      # if shared_pca:
      #   shared_max_comps = init_client_data_dict['shared_max_comps']
      #   shared_min_comps = init_client_data_dict['shared_min_comps']
      print_cust(f"run_qfl_experiments_parallel_multiprocess, loaded in existing data")
    else:
      filtered_clients_config_arg = {}
      for cli_type, cli_config_val in clients_config_arg.items():
        if cli_config_val["percentage_data"] > 0.0:
          filtered_clients_config_arg[cli_type] = cli_config_val
      print_cust(f"run_qfl_experiments_parallel_multiprocess, filtered_clients_config_arg: {filtered_clients_config_arg}")
      clients_data_dict, (X_test, y_test) = split_data_federated(X_angles, y, filtered_clients_config_arg, test_frac, val_frac=val_frac, random_state=random_state, feature_skew=feature_skew, label_skew=label_skew, local_pca=local_pca,
                                                                do_lda=do_lda, feat_sel_type=feat_sel_type, amp_embed=amp_embed, feat_ordering=feat_ordering, shared_pca=shared_pca, fed_pca_mocked=fed_pca_mocked)
      print_cust(f"run_qfl_experiments_parallel_multiprocess, generated new data")

    # Create a log of the information that we have throughout training.
    data_logs = {}

    data_logs["clients_data_dict"] = clients_data_dict
    data_logs["testing_data"] = (X_test, y_test)
    # if shared_pca:
    #   data_logs['shared_max_comps'] = shared_max_comps
    #   data_logs['shared_min_comps'] = shared_min_comps

    # Store the initial set of training parameters and data in a pickle file.
    # MODIFIED, layers: added is_qcnn to the name of the file
    if save_pkl:
      with open(f"{log_data_folder}/data_logs_n_samples_{n_samples}_dataset_type_{dataset_type}_classes_{'_'.join(classes)}_train_models_parallel_{train_models_parallel}_feature_skew_{feature_skew}_label_skew_{label_skew}_local_pca_{local_pca}_shared_pca_{shared_pca}_gen_{generative}_qcnn_{is_qcnn}_{random_state}.pkl", "wb") as file:
        pickle.dump(data_logs, file)


    if use_torch:
      math_int = torch
    else:
      math_int = np

    # Maybe right here, have a torch data conversion function ??? (if not done in perform_federated_pca_mocked)
    clients_data_dict = convert_data_to_lib(clients_data_dict, math_int=math_int)
    if math_int == torch:
      # TODO, layers: convert y_test to a PyTorch tensor as well.
      X_test = torch.tensor(X_test, dtype=torch.float32)
      y_test = torch.tensor(y_test, dtype=torch.float32)
    print_cust(f"run_qfl_experiments_parallel_multiprocess, after torch conversion, clients_data_dict: {clients_data_dict}")

    if generative:
      print_cust(f"run_qfl_experiments_parallel_multiprocess, X_test.shape: {X_test.shape}")
      if len(X_test.shape) == 2:
        img_dim = math.isqrt(X_test.shape[1])
        assert (img_dim ** 2) == X_test.shape[1], f"run_qfl_experiments_parallel_multiprocess, X_test is not a perfect square, X_test.shape: {X_test.shape}"
        X_test = X_test.view(X_test.shape[0], img_dim, img_dim)
      print_cust(f"run_qfl_experiments_parallel_multiprocess, generative, X_test.max(): {X_test.max()}, X_test.min(): {X_test.min()}")
      save_tensors_to_folder(X_test, f"{targ_data_folder_prefix}", "img")
      n_samples_test = X_test.shape[0]


    # TODO: continue reading here, 7/9
    # also, please put this in a function......
    if shared_pca:
      # NOTE: this function MUTATES clients_data_dict
      pca_global, shared_min_comps, shared_max_comps, inv_pca_min, inv_pca_max = perform_federated_pca_mocked(clients_data_dict, max_size_clients, random_seed=random_state, math_int=math_int, device=device, generative=generative)
      # TODO: ... = perform_federated_pca_mocked()...

    # For sanity, print out the data as well as the amount of data that each client has.
    if debug:
      for client_type in clients_data_dict:
        clients_data_list = clients_data_dict[client_type]
        print_cust(f"client_type: {client_type}, len(clients_data_list): {len(clients_data_list)}")
        for client_idx in range(len(clients_data_list)):
          client_data = clients_data_list[client_idx]
          print_cust(f"client_type: {client_type}, client_idx: {client_idx}, len(client_data): {len(client_data)}")
          if not local_pca:
            (training_data, validation_data) = client_data
          else:
            # NOTE, layers: for less namespace conflicts, pca_info -> pca_info_cli
            (training_data, validation_data, pca_info_cli) = client_data
          print_cust(f"client_type: {client_type}, client_idx: {client_idx}, len(training_data): {len(training_data)}, len(validation_data): {len(validation_data)}")
          (X_train, y_train) = training_data
          (X_val, y_val) = validation_data
          print_cust(f"client_type: {client_type}, client_idx: {client_idx}, X_train.shape: {X_train.shape}, y_train.shape: {y_train.shape}, X_val.shape: {X_val.shape}, y_val.shape: {y_val.shape}")

    # Initialize the parameters for each client.
    # qnode_func, device, pca_info
    qnode_func = None
    # device = None
    pca_info = ()
    if generative:
      qnode_func = create_qnode_qgan
      # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
      pca_info = (pca_global, shared_min_comps, shared_max_comps, inv_pca_min, inv_pca_max, device)
    # NOTE, layers: continue reading here. (stopped here, 7/31)
    # DONE: TOMODIFY, layers: add an argument that specifies whether or not we need conv layers.
    # DONE: TOMODIFY, layers: do this in a torch.no_grad() context.
    # NO CHANGE (I believe the change should be made in the main call): TOMODIFY, layers: for additional layers for the same number of qubits, add them to the list IN ORDER (so the order of the block params really does matter)
    # (later, if you need to keep track of the layer information, you can do so using the qubits_and_layers_to_add_block_params arg for smaller sized models, so it is OK)
    with torch.no_grad():
      client_params_dict = initialize_client_params(clients_config_arg, model_size=min_size_clients, cur_client_params_dict=None, qubits_and_layers_to_add_block_params=qubits_and_layers_to_add_block_params, train_models_parallel=train_models_parallel,
                                                    n_output_qubits=n_output_qubits, generative=generative, use_torch=use_torch, qnode_func=qnode_func, device=device, pca_info=pca_info, is_qcnn=is_qcnn, alt_zeros_init=alt_zeros_init)

    print_cust(f"run_qfl_experiments_parallel_multiprocess, first initialization of client params, client_params_dict: {client_params_dict}")

    # Make all parameters the same across clients with a new random initialization if we want clients to have the same initialization.
    if same_init:
      print_cust(f"run_qfl_experiments_parallel_multiprocess, same parameters initialization")
      with torch.no_grad():
        meta_params = generate_meta_params(client_params_dict, clients_data_dict, math_int=math_int)
        meta_params = generate_meta_params_random(meta_params, math_int=math_int)
        client_params_dict = broadcast_param_updates(client_params_dict, meta_params, math_int=math_int)
        print_cust(f"run_qfl_experiments_parallel_multiprocess, client_params_dict: {client_params_dict}")
        if initial_supp_params is not None:
          # client_params_dict = initial_supp_params
          for cli_type in client_params_dict:
            for cli_idx, cli_params in enumerate(client_params_dict[cli_type]):
              # DONE: TOMODIFY, layers: this is for the generative case; for the nongenerative case, have the code. should be pretty straightforward; just replace block params entirely
              # ^ and later, can do validation to make sure that the structure, order of the supplied block params is consistent with qubits_and_layers_to_add_block_params.
              if generative:
                cli_gen = cli_params[5][0]
                supp_cli_gen = initial_supp_params[cli_type][cli_idx][5][0]
                cli_gen.load_state_dict(supp_cli_gen.state_dict())
                cli_params[5][1].load_state_dict(initial_supp_params[cli_type][cli_idx][5][1].state_dict())
              elif not is_qcnn:
                cli_params_list = list(cli_params)
                cli_params_list[5] = initial_supp_params[cli_type][cli_idx][5]
                client_params_dict[cli_type][cli_idx] = tuple(cli_params_list)

          print_cust(f"run_qfl_experiments_parallel, supplied initial_supp_params, client_params_dict: {client_params_dict}")

    # Store the initial set of training parameters and data in a pickle file.
    print_cust(f"run_qfl_experiments_parallel, saving client_params_dict")
    # MODIFIED, layers: added is_qcnn to the name of the file
    if save_pkl:
      with open(f"{log_data_folder}/client_params_dict_n_samples_{n_samples}_dataset_type_{dataset_type}_classes_{'_'.join(classes)}_train_models_parallel_{train_models_parallel}_feature_skew_{feature_skew}_label_skew_{label_skew}_local_pca_{local_pca}_shared_pca_{shared_pca}_gen_{generative}_qcnn_{is_qcnn}_{random_state}.pkl", "wb") as file:
        pickle.dump(client_params_dict, file)

    # client_types = sorted(list(clients_config_arg.keys()))

    # Initialize the data logging information for each communication round.

    client_optims_dict = initialize_client_optimizers(client_params_dict, lr_gen, lr_disc, gen_betas, disc_betas, optim_type, existing_optims_dict=None, generative=generative, is_qcnn=is_qcnn, opt_layers=opt_layers)
    print_cust(f"run_qfl_experiments_parallel_multiprocess, client_optims_dict: {client_optims_dict}")

    for round_num in range(num_total_rounds):
      print_cust(f"run_qfl_experiments_parallel_multiprocess, round_num: {round_num}")
      # Initialize the data logging information for this communication round.
      data_logs[round_num] = {}
      data_logs[round_num]["aggregated_params"] = None
      data_logs[round_num]["test_fid_score"] = None

      data_logs[round_num]["testing_acc"] = None
      data_logs[round_num]["testing_acc_stdev"] = None
      data_logs[round_num]["testing_acc_topk"] = None
      data_logs[round_num]["testing_loss"] = None

      if multiclassifier_type != "":
        data_logs[round_num]["testing_acc_classifiers"] = None
        data_logs[round_num]["testing_acc_stdev_classifiers"] = None
        data_logs[round_num]["testing_acc_topk_classifiers"] = None
        data_logs[round_num]["testing_loss_classifiers"] = None
        data_logs[round_num]["gen_all_probs"] = None

      for client_size, cfg in clients_config_arg.items():
        data_logs[round_num][client_size] = {}
        # TODO: add a field for testing loss in data_logs
        data_logs[round_num][client_size]["local_epochs"] = cfg["local_epochs"]
        data_logs[round_num][client_size]["client_metrics"] = []
        for client_idx in range(cfg["num_clients"]):
          data_logs[round_num][client_size]["client_metrics"].append({"trained_params": None,
                                                        "minibatch_losses": None,
                                                        "validation_losses": None,
                                                        "training_acc": None,
                                                        "testing_acc": None,
                                                        "testing_loss": None,
                                                        "training_acc_stdev": None,
                                                        "testing_acc_stdev": None,
                                                        "training_acc_topk": None,
                                                        "testing_acc_topk": None,
                                                        "testing_fid": None,
                                                        "grad_norms": None,
                                                        "optimizer_state": None,
                                                        "testing_acc_classifiers": None,
                                                        "testing_acc_stdev_classifiers": None,
                                                        "testing_acc_topk_classifiers": None,
                                                        "testing_loss_classifiers": None,
                                                        "gen_all_probs": None})

    # ---------- (UNCHANGED BOILERPLATE: data loading / setup) -------------
    # ... everything down to the initialisation of `client_params_dict`
    # ---------------------------------------------------------------------

    # ---------------------------------------------------------------------
    #  Communication rounds – now PARALLEL
    # ---------------------------------------------------------------------
    # max_workers = None          # defaults to os.cpu_count()
    print_cust(f"run_qfl_experiments_parallel_multiprocess, os.cpu_count(): {os.cpu_count()}")
    if not generative:
      qnode_builder = create_qnode_qcnn      # tiny alias for pickling friendliness
      if multiclassifier_type == "multirun":
        # NOTE, depthFL: this is NOT a real qnode; it's a caller that calls a qnode MULTIPLE times.
        qnode_builder = create_qnode_qcnn_multieval
      elif multiclassifier_type == "ancilla_endmeas":
        qnode_builder = create_qnode_qcnn_singleeval
      elif multiclassifier_type == "cheating":
        qnode_builder = create_qnode_qcnn_multieval_cheating
      elif multiclassifier_type == "tunnel_down":
        qnode_builder = create_qnode_qcnn_singleeval_tunneldown
    else:
      qnode_builder = create_qnode_qgan
    # DONE: TOMODIFY, depthFL: a function that is a wrapper around qnode_builder that yields a custom QNode.
    # Q: am I using it as the qnode_builder then? so I'm just replacing it? A: yes, I think so.

    print_cust(f"run_qfl_experiments_parallel_multiprocess, qnode_builder: {qnode_builder}")

    rounds_elapsed = 0
    for cur_shared_model_size, model_size_rounds in client_types_to_rounds.items():

      if cur_shared_model_size > min_size_clients and not train_models_parallel:
        # DONE: TOMODIFY, layers: do this in a torch.no_grad() context.
        with torch.no_grad():
          client_params_dict = initialize_client_params(clients_config_arg, model_size=cur_shared_model_size, cur_client_params_dict=client_params_dict, qubits_and_layers_to_add_block_params=qubits_and_layers_to_add_block_params, train_models_parallel=train_models_parallel,
                                                    n_output_qubits=n_output_qubits, generative=generative, use_torch=use_torch, qnode_func=qnode_func, device=device, pca_info=pca_info, is_qcnn=is_qcnn, alt_zeros_init=alt_zeros_init)
        # DONE: TOMODIFY, layers: inject the mask_grads argument here to specify what params to optimize over.
        print_cust(f"run_qfl_experiments_parallel_multiprocess, after expansion reinitialization, client_params_dict: {client_params_dict}")

        client_optims_dict = initialize_client_optimizers(client_params_dict, lr_gen, lr_disc, gen_betas, disc_betas, optim_type, existing_optims_dict=client_optims_dict, generative=generative, is_qcnn=is_qcnn, opt_layers=opt_layers)

        existing_clients_datadict = list(clients_data_dict.keys())
        for existing_client_size in existing_clients_datadict:
          if cur_shared_model_size > existing_client_size:
            del clients_data_dict[existing_client_size]

        print_cust(f"run_qfl_experiments_parallel_multiprocess, not train_models_parallel, clients_data_dict.keys(): {clients_data_dict.keys()}")

      for round_num in range(model_size_rounds):
          round_num += rounds_elapsed
          print_cust(f"run_qfl_experiments_parallel_multiprocess, [round {round_num}] starting")

          # TODO: DI in the right workers and args for mp_context based on processpoolexecutor or threadpoolexecutor.
          # NOTE, depthFL: with multiple clients, trying dummyexecutor for now.
          with DummyExecutor(max_workers=max_workers) as pool:
              print_cust(f"run_qfl_experiments_parallel_multiprocess, max_workers: {max_workers}")
              futures = []

              # -------------------------------------------------------------
              # spawn one worker job per client
              # -------------------------------------------------------------
              for client_type, client_params in client_params_dict.items():
                  client_data_list = clients_data_dict[client_type]
                  # NOCHANGE: TOMODIFY, depthFL: if train_models_parallel and not is_qcnn, dynamically figure out the client size based on the max number
                  # of qubits needed to run their block params.
                  # ^ OR, can just not change it for now; this allows different client types to train for different local epochs.
                  if train_models_parallel:
                    cur_model_size   = client_type
                  else:
                    cur_model_size   = cur_shared_model_size
                  print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_model_size: {cur_model_size}")
                  # DONE: TOMODIFY, layers: inject the arg here to not do QCNN and thus have conv_layers be 0.
                  conv_layers      = compute_conv_layers(cur_model_size,
                                                        n_output_qubits, generative=generative, is_qcnn=is_qcnn)
                  print_cust(f"run_qfl_experiments_parallel_multiprocess, conv_layers: {conv_layers}")
                  num_local_epochs = clients_config_arg[cur_model_size]["local_epochs"]

                  for client_idx, client_params_indiv in enumerate(client_params):
                      client_train_data = client_data_list[client_idx][0]
                      client_val_data   = client_data_list[client_idx][1]

                      print_cust(f"run_qfl_experiments_parallel_multiprocess, client_train_data: {client_train_data}, client_val_data: {client_val_data}")

                      grad_mask = None

                      # DONE: TOMODIFY, layers: can temporarily override this to not call mask_gradients; doing that in my optimizer construction.
                      if mask_grads:
                        if cur_model_size > min_size_clients and not generative and is_qcnn:
                          print_cust(f"run_qfl_experiments_parallel_multiprocess, calling mask_gradients()")
                          grad_mask = mask_gradients(client_params_indiv)

                      # Create the feature list indicating what features for which qubits are to be used for this client, as well as the expansion data (used for identity
                      # initialization).
                      # expand = False
                      # if cur_model_size > min_size_clients:
                      #   expand = True
                      # feature_list, expansion_data = create_feat_list_expansion_data(cur_model_size, conv_layers, expand=expand, pool_in=pool_in, min_qubits_noexpand=min_size_clients,
                      #                                                               train_models_parallel=train_models_parallel, feat_sel_type=feat_sel_type)

                      # build client‑specific testing data once here
                      if local_pca and not generative:
                        # NOTE: why was client_data_indiv used here?? doesn't make that much sense.
                        # NOTE, layers: changed client_pca_info here to use index 2 instead of 1; not sure if it's right
                        client_pca_info = client_data_list[client_idx][2]
                        print_cust(f"run_qfl_experiments_parallel_multiprocess, client_pca_info: {client_pca_info}")
                        client_pca = client_pca_info[0]
                        client_data_pca = client_pca_info[1]
                        # DONE: TOMODIFY, layers: note that the PCA operates on numpy arrays, and X_test is a PyTorch tensor, so I'll need to do some data conversions (applies to this entire block of code below)
                        print_cust(f"run_qfl_experiments_parallel_multiprocess, X_test.shape: {X_test.shape}, type(X_test): {type(X_test)}")
                        if use_torch:
                          X_test_np = X_test.detach().cpu().numpy()
                        else:
                          X_test_np = X_test
                        print_cust(f"run_qfl_experiments_parallel_multiprocess, X_test_np.shape: {X_test_np.shape}, type(X_test_np): {type(X_test_np)}")
                        X_test_client_pca = client_pca.transform(X_test_np)
                        # Scale each PCA component independently to [0, π]
                        # MODIFIED, layers: changed np -> math_int
                        if use_torch:
                          # MODIFIED, layers: converted to torch after PCA transform
                          # NOTE, layers: I might get dtype issues here, in which case I'd need to change to torch.float32 explicitly.
                          X_test_client_pca = torch.from_numpy(X_test_client_pca).to(device)
                        X_test_client_angle = math_int.zeros_like(X_test_client_pca)
                        # assuming that the number of components is simply the client's data type
                        # TODO: rescale the testing data to match the PCA scale for each client
                        print_cust(f"run_qfl_experiments_parallel_multiprocess, X_test_client_pca.shape: {X_test_client_pca.shape}, type(X_test_client_pca): {type(X_test_client_pca)}")
                        print_cust(f"run_qfl_experiments_parallel_multiprocess, X_test_client_angle.shape: {X_test_client_angle.shape}, type(X_test_client_angle): {type(X_test_client_angle)}")
                        print_cust(f"run_qfl_experiments_parallel_multiprocess, shared_min_comps: {shared_min_comps}, shared_max_comps: {shared_max_comps}, type(shared_min_comps): {type(shared_min_comps)}, type(shared_max_comps): {type(shared_max_comps)}")
                        for i in range(cur_model_size):
                            comp = X_test_client_pca[:, i]
                            if shared_pca:
                              # NOTE, layers: technically I should do an isinstance check here to make sure that
                              # these are torch return types
                              lo, hi = shared_min_comps.values[i], shared_max_comps.values[i]
                              print_cust(f"run_qfl_experiments_parallel_multiprocess, comp.shape: {comp.shape}, type(comp): {type(comp)}, lo.shape: {lo.shape}, hi.shape: {hi.shape}, type(lo): {type(lo)}, type(hi): {type(hi)}")
                              comp_norm = ( (comp - lo) / (hi - lo + 1e-8) )
                              # NOTE, depthFL: need to clip comp_norm????
                              # comp_norm = np.clip(comp_norm, 0, 1)
                            else:
                              orig_comp = client_data_pca[:, i]
                              comp_norm = (comp - orig_comp.min()) / (orig_comp.max() - orig_comp.min() + 1e-8)
                            # MODIFIED, layers: changed np -> math_int
                            X_test_client_angle[:, i] = comp_norm * math_int.pi
                      else:
                        X_test_client_angle = X_test

                      # # Select the features for amplitude encoding based on feat_ordering.
                      # if amp_embed:
                      #   if feat_ordering == "highest_var":
                      #     variances = X_angles.var(axis=0, ddof=0)
                      #     order = np.argsort(variances)[::-1]
                      #     X_test_client_angle = X_test_client_angle[:, order] + 1e-3
                      #   testing_data = reorder_amplitude_data(X_test_client_angle, feature_list)
                      # else:
                      #   testing_data = X_test_client_angle[:, feature_list]

                      print_cust(f"run_qfl_experiments_parallel, X_test_client_angle.shape: {X_test_client_angle.shape}, type(X_test_client_angle): {type(X_test_client_angle)}, y_test.shape: {y_test.shape}, type(y_test): {type(y_test)}")
                      print_cust(f"run_qfl_experiments_parallel, X_test_client_angle.min(): {X_test_client_angle.min()}, X_test_client_angle.max(): {X_test_client_angle.max()}, y_test.min(): {y_test.min()}, y_test.max(): {y_test.max()}")
                      if local_pca and not generative:
                        X_test_client_angle = math_int.clip(X_test_client_angle, 0.0, math_int.pi)
                      print_cust(f"run_qfl_experiments_parallel, X_test_client_angle.shape: {X_test_client_angle.shape}, type(X_test_client_angle): {type(X_test_client_angle)}, y_test.shape: {y_test.shape}, type(y_test): {type(y_test)}")
                      print_cust(f"run_qfl_experiments_parallel, X_test_client_angle.min(): {X_test_client_angle.min()}, X_test_client_angle.max(): {X_test_client_angle.max()}, y_test.min(): {y_test.min()}, y_test.max(): {y_test.max()}")
                      testing_data = (X_test_client_angle, y_test)


                      if generative:
                        generator_metadata = client_params_indiv[5][0].get_data_components()
                        disc_metadata = client_params_indiv[5][1].get_data_components()
                        client_params_indiv_serialized = [[] for _ in range(len(client_params_indiv))]
                        print_cust(f"run_qfl_experiments_parallel_multiprocess, len(client_params_indiv_serialized): {len(client_params_indiv_serialized)}")
                        client_params_indiv_serialized[5].append((client_params_indiv[5][0].state_dict(), generator_metadata))
                        client_params_indiv_serialized[5].append((client_params_indiv[5][1].state_dict(), disc_metadata))
                        # client_params_indiv[5][0] = (client_params_indiv[5][0].state_dict(), generator_metadata)
                        # client_params_indiv[5][1] = (client_params_indiv[5][1].state_dict(), disc_metadata)
                        # print_cust(f"run_qfl_experiments_parallel_multiprocess, 'serializing' model, client_params_indiv: {client_params_indiv}")
                        print_cust(f"run_qfl_experiments_parallel_multiprocess, 'serializing' model, client_params_indiv_serialized: {client_params_indiv_serialized}")
                        client_params_indiv = client_params_indiv_serialized

                        client_optims_indiv = client_optims_dict[client_type][client_idx]

                        client_optims_indiv_serialized = [[] for _ in range(len(client_optims_indiv))]

                        client_optims_indiv_serialized[5].append(client_optims_indiv[5][0].state_dict())

                        client_optims_indiv_serialized[5].append(client_optims_indiv[5][1].state_dict())

                        # client_optims_indiv_serialized = [client_optims_indiv[5][0].state_dict(), client_optims_indiv[5][1].state_dict()]

                      # DONE: TOMODIFY, layers: for the client_optims_indiv_serialized, do something similar in the new format of the optims dictionary for
                      # the classifier case.

                      # NO CHANGE: TOMODIFY, layers: for parameters that don't have a value (None), from the generative case, set them to be None.
                      # DONE: TOMODIFY, layers: for this overall function, have an argument to specify the pennylane_interface (for layers, I want it to be torch).
                      elif not is_qcnn:
                        client_params_indiv_serialized = copy.deepcopy(client_params_indiv)
                        client_optims_indiv = client_optims_dict[client_type][client_idx]

                        client_optims_indiv_serialized = [None for _ in range(len(client_optims_indiv))]

                        client_optims_indiv_serialized[5] = client_optims_indiv[5].state_dict()

                        print_cust(f"run_qfl_experiments_parallel, not is_qcnn, client_params_indiv_serialized: {client_params_indiv_serialized}")
                        print_cust(f"run_qfl_experiments_parallel, not is_qcnn, client_optims_indiv_serialized: {client_optims_indiv_serialized}")

                      if client_type in qubits_and_layer_types_block_params:
                        layer_types_list = qubits_and_layer_types_block_params[client_type]
                      else:
                        layer_types_list = []
                      
                      print_cust(f"run_qfl_experiments_parallel_multiprocess, layer_types_list: {layer_types_list}")

                      job = dict(client_type=client_type,
                                client_idx=client_idx,
                                client_params_indiv=client_params_indiv,
                                client_train_data=client_train_data,
                                client_val_data=client_val_data,
                                testing_data=testing_data,
                                cur_model_size=cur_model_size,
                                min_size_clients=min_size_clients,
                                pool_in=pool_in,
                                feat_sel_type=feat_sel_type,
                                train_models_parallel=train_models_parallel,
                                amp_embed=amp_embed, shots=shots,
                                local_batch_size=local_batch_size,
                                local_lr=local_lr,
                                num_local_epochs=num_local_epochs,
                                conv_layers=conv_layers,
                                feat_ordering=feat_ordering,
                                classes=classes,
                                qnode_builder=qnode_builder,
                                num_total_rounds=num_total_rounds,
                                round_num=round_num,
                                grad_mask=grad_mask,
                                generative=generative,
                                lr_gen=lr_gen,
                                lr_disc=lr_disc,
                                noise_func=noise_func,
                                criterion_func=criterion_func,
                                log_data_folder=log_data_folder,
                                device=device,
                                client_optims_indiv=client_optims_indiv_serialized,
                                optim_type=optim_type,
                                gen_betas=gen_betas,
                                disc_betas=disc_betas,
                                use_torch=use_torch,
                                pennylane_interface=pennylane_interface,
                                opt_layers=opt_layers,
                                layer_types_list=layer_types_list,
                                loss_type=loss_type,
                                lr_disc_decay=lr_disc_decay,
                                cont_optim_state=cont_optim_state)
                      futures.append(pool.submit(_train_single_client, job))

              # -------------------------------------------------------------
              #  gather results as they finish
              # -------------------------------------------------------------
              for fut in as_completed(futures):
                  res = fut.result()

                  ctype, cidx = res["client_type"], res["client_idx"]
                  print_cust(f"run_qfl_experiments_parallel_multiprocess, round_num: {round_num}, ctype: {ctype}, cidx: {cidx}")

                  if not generative and is_qcnn:
                    client_params_dict[ctype][cidx] = res["trained_params"]

                    # write metrics into the log
                    dlog = data_logs[round_num][ctype]["client_metrics"][cidx]
                    dlog["trained_params"]   = res["trained_params"]
                    dlog["minibatch_losses"] = res["minibatch_losses"]
                    dlog["validation_losses"] = res["validation_losses"]
                    dlog["training_acc"]     = res["train_acc"]
                    dlog["testing_acc"]      = res["test_acc"]
                    dlog["testing_loss"]     = res["test_loss"]
                    dlog["training_acc_stdev"] = res["train_acc_stdev"]
                    dlog["testing_acc_stdev"]  = res["test_acc_stdev"]
                    dlog["training_acc_topk"]  = res["train_acc_topk"]
                    dlog["testing_acc_topk"]   = res["test_acc_topk"]

                    print_cust(f'run_qfl_experiments_parallel_multiprocess, round_num: {round_num}, cur_model_size: {cur_model_size}, client_idx: {cidx}, train_acc: {res["train_acc"]}, test_acc: {res["test_acc"]}, test_loss: {res["test_loss"]}, train_acc_stdev: {res["train_acc_stdev"]}, test_acc_stdev: {res["test_acc_stdev"]}, train_acc_topk: {res["train_acc_topk"]}, test_acc_topk: {res["test_acc_topk"]}')
                  elif not generative:
                    # DONE: TOMODIFY, layers: in addition, if optimizer state is supplied, then update the state dict for the optimizer here.
                    # (for debug, for layers, can later store the optimizer state at each round.)
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, res['trained_disc_params']: {res['trained_disc_params']}")
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, qnode_builder: {qnode_builder}")
                    cur_cli_params = client_params_dict[ctype][cidx]
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_params[5]: {cur_cli_params[5]}")
                    cur_cli_params_list = list(cur_cli_params)
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_params_list: {cur_cli_params_list}")
                    cur_cli_params_list[5] = res["trained_disc_params"]
                    cur_cli_params = tuple(cur_cli_params_list)
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, type(cur_cli_params): {type(cur_cli_params)}")
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_params[5]: {cur_cli_params[5]}")
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_params: {cur_cli_params}")
                    # MODIFIED, depthFL: update client params.
                    client_params_dict[ctype][cidx] = cur_cli_params

                    for i, p in enumerate(cur_cli_params[5]):
                        print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_params[5][{i}]:\n{p}")

                    # BUG (suspected): client_params_dict NOT being changed; the print below checks that the assignment above took effect.
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, client_params_dict[ctype][cidx]: {client_params_dict[ctype][cidx]}")

                    cur_cli_optims = client_optims_dict[ctype][cidx]
                    cur_cli_optim_classifier = cur_cli_optims[5]
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_optim_classifier: {cur_cli_optim_classifier}")
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_optim_classifier.state_dict(): {cur_cli_optim_classifier.state_dict()}")
                    # NOTE: loads in a SHALLOW copy of state dict for the classifier optimizer.
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, res['trained_optim_disc']: {res['trained_optim_disc']}")
                    cur_cli_optim_classifier.load_state_dict(res["trained_optim_disc"])
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_optim_classifer: {cur_cli_optim_classifier}")
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_optim_classifier.state_dict(): {cur_cli_optim_classifier.state_dict()}")

                    train_metrics_dict = res["train_metrics_dict"]

                    print_cust(f"run_qfl_experiments_multiprocess, train_metrics_dict: {train_metrics_dict}")

                    dlog = data_logs[round_num][ctype]["client_metrics"][cidx]

                    dlog["trained_params"] = [copy.deepcopy(res["trained_disc_params"])]
                    dlog["minibatch_losses"] = [train_metrics_dict["disc_loss"]]
                    dlog["grad_norms"] = [train_metrics_dict["disc_grad_norms"]]
                    # how in the world is this sending a shallow copy????????
                    # it is. that's ok
                    dlog["optimizer_state"] = [copy.deepcopy(res["trained_optim_disc"])]

                  else:
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, res['trained_gen_params']: {res['trained_gen_params']}, res['trained_disc_params']: {res['trained_disc_params']}")
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, qnode_builder: {qnode_builder}")
                    cur_cli_params = client_params_dict[ctype][cidx]
                    cur_cli_gen = cur_cli_params[5][0]
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_gen: {cur_cli_gen}")
                    # constructed_cli_gen = build_patchquantumgen(cur_cli_gen_metadata, res["trained_gen_params"], qnode_builder)
                    cur_cli_disc = cur_cli_params[5][1]
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_disc: {cur_cli_disc}")
                    # constructed_cli_disc = build_pca_discriminator(cur_cli_disc_metadata, res["trained_disc_params"])

                    cur_cli_gen.load_state_dict(res["trained_gen_params"])
                    cur_cli_disc.load_state_dict(res["trained_disc_params"])
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_gen: {cur_cli_gen}")
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_disc: {cur_cli_disc}")

                    cur_cli_optims = client_optims_dict[ctype][cidx]
                    cur_cli_optim_gen = cur_cli_optims[5][0]
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_optim_gen: {cur_cli_optim_gen}")
                    cur_cli_optim_gen.load_state_dict(res["trained_optim_gen"])
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_optim_gen: {cur_cli_optim_gen}")

                    cur_cli_optim_disc = cur_cli_optims[5][1]
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_optim_disc: {cur_cli_optim_disc}")
                    cur_cli_optim_disc.load_state_dict(res["trained_optim_disc"])
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_optim_disc: {cur_cli_optim_disc}")
                    # cur_cli_params[5][0] = constructed_cli_gen
                    # cur_cli_params[5][1] = constructed_cli_disc
                    # can print out client_params_dict to verify that it changes here if I want I guess
                    # print_cust(f"run_qfl_experiments_parallel_multiprocess, cur_cli_p")

                    train_metrics_dict = res["train_metrics_dict"]

                    print_cust(f"run_qfl_experiments_multiprocess, train_metrics_dict: {train_metrics_dict}")

                    dlog = data_logs[round_num][ctype]["client_metrics"][cidx]

                    print_cust(f"run_qfl_experiments_multiprocess, res['trained_optim_gen']: {res['trained_optim_gen']}")
                    print_cust(f"run_qfl_experiments_multiprocess, res['trained_optim_disc']: {res['trained_optim_disc']}")

                    test_gen_opt_param_log = res['trained_optim_gen']['state'][0]['exp_avg']
                    test_disc_opt_param_log = res['trained_optim_disc']['state'][0]['exp_avg']

                    print_cust(f"run_qfl_experiments_multiprocess, test_gen_opt_param_log: {test_gen_opt_param_log}")
                    print_cust(f"run_qfl_experiments_multiprocess, test_disc_opt_param_log: {test_disc_opt_param_log}")

                    print_cust(f"run_qfl_experiments_multiprocess, test_gen_opt_param_log.storage().is_shared(): {test_gen_opt_param_log.storage().is_shared()}")
                    print_cust(f"run_qfl_experiments_multiprocess, test_disc_opt_param_log.storage().is_shared(): {test_disc_opt_param_log.storage().is_shared()}")
                    print_cust(f"run_qfl_experiments_multiprocess, test_gen_opt_param_log.data_ptr(): {test_gen_opt_param_log.data_ptr()}")
                    print_cust(f"run_qfl_experiments_multiprocess, test_disc_opt_param_log.data_ptr(): {test_disc_opt_param_log.data_ptr()}")

                    dlog["trained_params"] = [res["trained_gen_params"], res["trained_disc_params"]]
                    dlog["minibatch_losses"] = [train_metrics_dict["gen_loss"], train_metrics_dict["disc_loss"]]
                    dlog["grad_norms"] = [train_metrics_dict["gen_grad_norms"], train_metrics_dict["disc_grad_norms"]]
                    # how in the world is this sending a shallow copy????????
                    # it is. that's ok
                    dlog["optimizer_state"] = [copy.deepcopy(res["trained_optim_gen"]), copy.deepcopy(res["trained_optim_disc"])]
                    # dlog["validation_losses"] = res["validation_losses"]
                    # dlog["training_acc"]     = res["train_acc"]
                    # dlog["testing_acc"]      = res["test_acc"]
                    # dlog["testing_loss"]     = res["test_loss"]
                    # dlog["training_acc_stdev"] = res["train_acc_stdev"]
                    # dlog["testing_acc_stdev"]  = res["test_acc_stdev"]
                    # dlog["training_acc_topk"]  = res["train_acc_topk"]
                    # dlog["testing_acc_topk"]   = res["test_acc_topk"]






          # -----------------------------------------------------------------
          #  aggregation (unchanged)
          # -----------------------------------------------------------------

          print_cust(f"run_qfl_experiments_parallel_multiprocess, before aggregation and broadcasting, round_num: {round_num}, client_params_dict: {client_params_dict}")

          print_cust(f"run_qfl_experiments_parallel_multiprocess, parameter aggregation, agg_strategy: {agg_strategy}, morepers: {morepers}")
          if agg_strategy == "fedavg":
              aggregated_params = federated_averaging(client_params_dict,
                                                      clients_data_dict)
          elif agg_strategy == "fedavg_quat":
              aggregated_params = federated_averaging_quat(client_params_dict,
                                                          clients_data_dict)
          elif agg_strategy == "fedavg_circ":
              aggregated_params = federated_averaging_circular_parallel_shared(
                                    client_params_dict, clients_data_dict, generative=generative)
          else:
              raise ValueError(f"unknown agg_strategy {agg_strategy}")

          if morepers == "aggshared":
              client_params_dict = broadcast_param_updates_shared(
                                      client_params_dict, aggregated_params, math_int=math_int)
          elif morepers == "aggnoconv":
              client_params_dict = broadcast_param_updates_shared_noconv(
                                      client_params_dict, aggregated_params)
          elif morepers == "aggnoconvnofinal":
              client_params_dict = broadcast_param_updates_shared_noconv_nofinal(
                                      client_params_dict, aggregated_params)
          else:
            print_cust(f"run_qfl_experiments_parallel, no broadcasting performed, morepers: {morepers}")

          # Store the aggregated parameters for this round.
          # TODO: aggregated parameters don't have a significant meaning as not all of them are broadcasted to clients based on the broadcasting/personalization scheme
          data_logs[round_num]["aggregated_params"] = aggregated_params

          print_cust(f"run_qfl_experiments_parallel_multiprocess, round_num: {round_num}, aggregated_params: {aggregated_params}")
          print_cust(f"run_qfl_experiments_parallel_multiprocess, round_num: {round_num}, client_params_dict: {client_params_dict}")

          # TODO: recompute metrics of the SHARED model for EACH client.
          # BUGFIX, depthFL: changed one_cli_size to be the maximum sized client.
          # NOTE: this assumes that the MAXIMUM "qubits" client has the MOST layers.
          one_cli_size = max(list(client_params_dict.keys()))

          print_cust(f"run_qfl_experiments_parallel_multiprocess, one_cli_size: {one_cli_size}")

          if generative:
            # NOTE: this assumes that all generators and discriminators have the same parameters.
            agg_model = client_params_dict[one_cli_size][0][5][0]
            # agg_disc = client_params_dict[one_cli_size][0][5][1]
          else:
            agg_model = client_params_dict[one_cli_size][0]


          test_acc, test_acc_stdev, test_acc_topk, test_loss, test_fid_score = (None, None, None, None, None)

          if fid_batch_size is None and generative:
            fid_batch_size = n_samples_test

          if not generative and not is_qcnn:
            # NOTE, layers: only doing this for the not is_qcnn case.
            # TOMODIFY, layers: change cur_shared_model_size here. and also, subset the testing data here.
            num_qubits_bps = []
            for block_param in agg_model[5]:
              num_qubits_bps.append(block_param.shape[1])

            print_cust(f"run_qfl_experiments_parallel_multiprocess, not generative and not is_qcnn, num_qubits_bps: {num_qubits_bps}")

            # NOTE, layers: this is a logical override.
            cur_shared_model_size = max(num_qubits_bps)
            if local_pca or shared_pca:
              print_cust(f"run_qfl_experiments_parallel_multiprocess, not generative and not is_qcnn, cur_shared_model_size: {cur_shared_model_size}")
              X_test_client_angle = X_test_client_angle[:, :cur_shared_model_size]
            # TOMODIFY, depthFL: change the name of this variable (for key into qubits_and_layer_types_block_params)
            largest_clisize_layertypes = max(qubits_and_layer_types_block_params)
            print_cust(f"run_qfl_experiments_parallel_multiprocess, largest_clisize_layertypes: {largest_clisize_layertypes}")
            layer_types_list_largest = qubits_and_layer_types_block_params[largest_clisize_layertypes]
            # TOMODIFY, layers: supply layer_types_list to this qnode_builder
            qnode_test = qnode_builder(cur_shared_model_size, conv_layers, [], n_classes=len(classes), pennylane_interface=pennylane_interface, layer_types_list=layer_types_list_largest)

            print_cust(f"run_qfl_experiments_parallel_multiprocess, X_test_client_angle: {X_test_client_angle}")
            print_cust(f"run_qfl_experiments_parallel, X_test_client_angle.shape: {X_test_client_angle.shape}, type(X_test_client_angle): {type(X_test_client_angle)}, y_test.shape: {y_test.shape}, type(y_test): {type(y_test)}")
            print_cust(f"run_qfl_experiments_parallel, X_test_client_angle.min(): {X_test_client_angle.min()}, X_test_client_angle.max(): {X_test_client_angle.max()}, y_test.min(): {y_test.min()}, y_test.max(): {y_test.max()}")

          # TOMODIFY, layers: set this to be > 101, so that I always compute testing accuracy for logging.
          print_cust(f"run_qfl_experiments_parallel_multiprocess, testacc_rd_cutoff: {testacc_rd_cutoff}")
          if num_total_rounds > testacc_rd_cutoff:
            if (round_num + 1) % 10 == 0:
              if not generative:
                # TODO: this is broken for the non-generative case. process the testing data outside of the loop for each client.. not doing that now tho; not too relevant for
                # generative case
                # TODO: passing in testing data for multiproc seems to be broken here for the nongenerative case. can try to fix later.
                # TOMODIFY, layers (and in general, for classification): testing_data is outside of scope here. Create a 'global' set of testing data (this assumes that
                # testing data is independent of any clients' PCA). Make it clear what testing data I'm using.
                # ^ quickly hacked this by using out of scope testing data for quick impl... but can recreate the testing data as well
                # DONE: TOMODIFY, layers: call this in a torch.no_grad() context.
                # NOTE: for is_qcnn case, I think I override the testing_acc here (check with ctrl + f "testing_acc" ?)
                with torch.no_grad():
                  if morepers != "mocked_bcast":
                    if multiclassifier_type == "":
                      test_acc, test_acc_stdev, test_acc_topk, test_loss = compute_metrics_angle_param_batch(agg_model, X_test_client_angle, y_test, layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode_test, math_int=math_int)
                    else:
                      test_acc, test_acc_stdev, test_acc_topk, test_loss, avg_acc_classifiers, std_acc_classifiers, top_k_accuracies_classifiers, avg_loss_classifiers, gen_all_probs = compute_metrics_angle_param_batch(agg_model, X_test_client_angle, y_test, layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode_test, math_int=math_int)
                    data_logs[round_num]["testing_acc"] = test_acc
                    data_logs[round_num]["testing_acc_stdev"] = test_acc_stdev
                    data_logs[round_num]["testing_acc_topk"] = test_acc_topk
                    data_logs[round_num]["testing_loss"] = test_loss
                    if multiclassifier_type != "":
                      data_logs[round_num]["testing_acc_classifiers"] = avg_acc_classifiers
                      data_logs[round_num]["testing_acc_stdev_classifiers"] = std_acc_classifiers
                      data_logs[round_num]["testing_acc_topk_classifiers"] = top_k_accuracies_classifiers
                      data_logs[round_num]["testing_loss_classifiers"] = avg_loss_classifiers
                      data_logs[round_num]["gen_all_probs"] = gen_all_probs
                  else:
                    print_cust(f"run_qfl_experiments_parallel_multiprocess, morepers is mocked_bcast, morepers: {morepers}")
                    # compute the testing acc for each client.
                    for cli_type, cli_params_list in client_params_dict.items():
                      for cli_idx, cli_params in enumerate(cli_params_list):
                        print_cust(f"run_qfl_experiments_parallel_multiprocess, cli_type: {cli_type}, cli_idx: {cli_idx}")
                        print_cust(f"run_qfl_experiments_parallel_multiprocess, cli_params: {cli_params}")
                        if multiclassifier_type == "":
                          test_acc, test_acc_stdev, test_acc_topk, test_loss = compute_metrics_angle_param_batch(cli_params, X_test_client_angle, y_test, layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode_test, math_int=math_int)
                        else:
                          test_acc, test_acc_stdev, test_acc_topk, test_loss, avg_acc_classifiers, std_acc_classifiers, top_k_accuracies_classifiers, avg_loss_classifiers, gen_all_probs = compute_metrics_angle_param_batch(cli_params, X_test_client_angle, y_test, layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode_test, math_int=math_int)
                        sel_dlog = data_logs[round_num][cli_type]["client_metrics"][cli_idx]
                        sel_dlog["testing_acc"] = test_acc
                        sel_dlog["testing_acc_stdev"] = test_acc_stdev
                        sel_dlog["testing_acc_topk"] = test_acc_topk
                        sel_dlog["testing_loss"] = test_loss
                        if multiclassifier_type != "":
                          sel_dlog["testing_acc_classifiers"] = avg_acc_classifiers
                          sel_dlog["testing_acc_stdev_classifiers"] = std_acc_classifiers
                          sel_dlog["testing_acc_topk_classifiers"] = top_k_accuracies_classifiers
                          sel_dlog["testing_loss_classifiers"] = avg_loss_classifiers
                          sel_dlog["gen_all_probs"] = gen_all_probs
              else:
                if compute_fid:
                  print_cust(f"run_qfl_experiments_multiprocess, computing FID")
                  folder_id_suffix = f"round_num_{round_num}_cur_shared_model_size_{cur_shared_model_size}"
                  test_fid_score = compute_fid_to_data(agg_model, noise_func, f"{targ_data_folder_prefix}",
                                                      f"{gen_data_folder_prefix}/{folder_id_suffix}", n_samples_test, device, fid_batch_size=fid_batch_size, resc_invpca=resc_invpca)
                  data_logs[round_num]["test_fid_score"] = test_fid_score
            else:
              test_acc, test_acc_stdev, test_acc_topk, test_loss = (None, None, None, None)
          else:
            if not generative:
              # TOMODIFY, layers (and in general, for classification): testing_data is outside of scope here. Create a 'global' set of testing data (this assumes that
              # testing data is independent of any clients' PCA). Make it clear what testing data I'm using.
              # ^ quickly hacked this by using out of scope testing data for quick impl... but can recreate the testing data as well
              # DONE: TOMODIFY, layers: call this in a torch.no_grad() context.
              # NOTE: for is_qcnn case, I think I override the testing_acc here (check with ctrl + f "testing_acc" ?)
              with torch.no_grad():
                if morepers != "mocked_bcast":
                  if multiclassifier_type == "":
                    test_acc, test_acc_stdev, test_acc_topk, test_loss = compute_metrics_angle_param_batch(agg_model, X_test_client_angle, y_test, layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode_test, math_int=math_int)
                  else:
                    test_acc, test_acc_stdev, test_acc_topk, test_loss, avg_acc_classifiers, std_acc_classifiers, top_k_accuracies_classifiers, avg_loss_classifiers, gen_all_probs = compute_metrics_angle_param_batch(agg_model, X_test_client_angle, y_test, layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode_test, math_int=math_int)
                  data_logs[round_num]["testing_acc"] = test_acc
                  data_logs[round_num]["testing_acc_stdev"] = test_acc_stdev
                  data_logs[round_num]["testing_acc_topk"] = test_acc_topk
                  data_logs[round_num]["testing_loss"] = test_loss
                  if multiclassifier_type != "":
                    data_logs[round_num]["testing_acc_classifiers"] = avg_acc_classifiers
                    data_logs[round_num]["testing_acc_stdev_classifiers"] = std_acc_classifiers
                    data_logs[round_num]["testing_acc_topk_classifiers"] = top_k_accuracies_classifiers
                    data_logs[round_num]["testing_loss_classifiers"] = avg_loss_classifiers
                    data_logs[round_num]["gen_all_probs"] = gen_all_probs
                else:
                  print_cust(f"run_qfl_experiments_parallel_multiprocess, morepers is mocked_bcast, morepers: {morepers}")
                  # compute the testing acc for each client.
                  for cli_type, cli_params_list in client_params_dict.items():
                    for cli_idx, cli_params in enumerate(cli_params_list):
                      print_cust(f"run_qfl_experiments_parallel_multiprocess, cli_type: {cli_type}, cli_idx: {cli_idx}")
                      print_cust(f"run_qfl_experiments_parallel_multiprocess, cli_params: {cli_params}")
                      if multiclassifier_type == "":
                        test_acc, test_acc_stdev, test_acc_topk, test_loss = compute_metrics_angle_param_batch(cli_params, X_test_client_angle, y_test, layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode_test, math_int=math_int)
                      else:
                        test_acc, test_acc_stdev, test_acc_topk, test_loss, avg_acc_classifiers, std_acc_classifiers, top_k_accuracies_classifiers, avg_loss_classifiers, gen_all_probs = compute_metrics_angle_param_batch(cli_params, X_test_client_angle, y_test, layers=conv_layers, shots=shots, batch_size=local_batch_size, qnode=qnode_test, math_int=math_int)
                      sel_dlog = data_logs[round_num][cli_type]["client_metrics"][cli_idx]
                      sel_dlog["testing_acc"] = test_acc
                      sel_dlog["testing_acc_stdev"] = test_acc_stdev
                      sel_dlog["testing_acc_topk"] = test_acc_topk
                      sel_dlog["testing_loss"] = test_loss
                      if multiclassifier_type != "":
                        sel_dlog["testing_acc_classifiers"] = avg_acc_classifiers
                        sel_dlog["testing_acc_stdev_classifiers"] = std_acc_classifiers
                        sel_dlog["testing_acc_topk_classifiers"] = top_k_accuracies_classifiers
                        sel_dlog["testing_loss_classifiers"] = avg_loss_classifiers
                        sel_dlog["gen_all_probs"] = gen_all_probs
            else:
              if compute_fid:
                print_cust(f"run_qfl_experiments_multiprocess, computing FID")
                folder_id_suffix = f"round_num_{round_num}_cur_shared_model_size_{cur_shared_model_size}"
                test_fid_score = compute_fid_to_data(agg_model, noise_func, f"{targ_data_folder_prefix}",
                                                    f"{gen_data_folder_prefix}/{folder_id_suffix}", n_samples_test, device, fid_batch_size=fid_batch_size, resc_invpca=resc_invpca)
                data_logs[round_num]["test_fid_score"] = test_fid_score
          if not generative:
            print_cust(f"run_qfl_experiments_parallel_multiprocess, round_num: {round_num}, cur_model_size: {cur_model_size}, test_acc: {test_acc}, test_loss: {test_loss}, test_acc_stdev: {test_acc_stdev}, test_acc_topk: {test_acc_topk}")
          else:
            print_cust(f"run_qfl_experiments_parallel_multiprocess, round_num: {round_num}, cur_model_size: {cur_model_size}, test_fid_score: {test_fid_score}")

          print_cust(f"run_qfl_experiments_parallel_multiprocess, [round {round_num}] done\n")
      rounds_elapsed += model_size_rounds

    return data_logs

"""## QFL Experiment Parallel, Non-personalized"""

def run_qfl_experiments_parallel_nonpers(clients_config_arg, classes=["4", "9"], n_samples=1000, dataset_type="mnist", agg_strategy="fedavg", test_frac=0.2, val_frac=0.1, random_state=42, pool_in=True,
                        local_batch_size=32, local_lr=0.01, shots=1024, debug=False, init_client_data_dict=None, save_pkl=False, mask_grads=False, qubits_and_layers_to_add_block_params=[],
                        train_models_parallel=False, same_init=False, feature_skew=0.0):
  """Run the non-personalized federated experiment and return its logs.

  Each communication round does the following:
    1. every client trains a deep copy of its current parameters with
       `train_epochs_angle_param_adam` on its own training/validation split;
    2. the trained parameters are aggregated with the function selected by
       `agg_strategy`;
    3. the aggregate is handed to `broadcast_param_updates`;
    4. every client's (post-broadcast) parameters are scored on the test split.

  The number of rounds is the maximum "communication_rounds" over all client
  types; every client type takes part in every round.

  Args:
    clients_config_arg: dict keyed by client size. The key is used as the model
      size passed to `create_qnode_qcnn`, so it must support `max`/`min` and
      `math.log2`. Each value is a dict from which "communication_rounds",
      "local_epochs" and "num_clients" are read here.
    classes: class labels forwarded to `load_dataset`; `len(classes)` is the
      number of classes of the qnode. Joined with "_" for the pickle filename,
      so the entries must be strings when `save_pkl` is set.
    n_samples: forwarded to `load_dataset`.
    dataset_type: forwarded to `load_dataset`.
    agg_strategy: one of "fedavg", "fedavg_quat" or "fedavg_circ".
    test_frac, val_frac, random_state, feature_skew: forwarded to
      `split_data_federated`; unused when `init_client_data_dict` is given.
    pool_in: forwarded to `create_feat_list_expansion_data`.
    local_batch_size: batch size for local training and accuracy evaluation.
    local_lr: learning rate for local training.
    shots: forwarded to local training and to the test-loss computation.
      NOTE(review): both accuracy computations below use a hard-coded
      shots=1024 instead of this value -- confirm whether that is intended.
    debug: when true, log the structure and contents of the client data.
    init_client_data_dict: optional dict with keys "clients_data_dict" and
      "testing_data"; when given, these replace a fresh federated split.
    save_pkl: when true, pickle the initial logs (data splits only, written
      before any training) to a file in the current working directory.
    mask_grads: currently unused; gradient masking is disabled (see below).
    qubits_and_layers_to_add_block_params: forwarded to
      `initialize_client_params`.
    train_models_parallel: forwarded to `initialize_client_params` and
      `create_feat_list_expansion_data`; also part of the pickle filename.
    same_init: when true, one set of randomly generated meta parameters is
      broadcast to all clients before the first round.

  Returns:
    dict with keys "clients_data_dict" and "testing_data", plus one entry per
    round number. Each round entry maps a client size to a dict with
    "aggregated_params_clients", "testing_loss_clients",
    "testing_accuracy_clients", "local_epochs" and "client_metrics" (one dict
    per client holding "trained_params", "minibatch_losses",
    "validation_losses" and "training_acc").

  Raises:
    ValueError: if `agg_strategy` is not one of the supported strategies.
  """
  # NOTE(review): `classes` and `qubits_and_layers_to_add_block_params` have
  # mutable defaults. They are not mutated in this function, but they are passed
  # on to helpers -- confirm none of those mutate them in place.

  # Fail fast: previously an unknown strategy left `aggregated_params` unbound
  # and only surfaced as an UnboundLocalError after a full round of training.
  valid_agg_strategies = ("fedavg", "fedavg_quat", "fedavg_circ")
  if agg_strategy not in valid_agg_strategies:
    raise ValueError(
        f"run_qfl_experiments_parallel_nonpers, unknown agg_strategy: {agg_strategy!r}; "
        f"expected one of {valid_agg_strategies}")

  max_size_clients = max(clients_config_arg.keys())
  min_size_clients = min(clients_config_arg.keys())

  num_total_rounds = max([clients_config_arg[key]["communication_rounds"] for key in clients_config_arg])

  print_cust(f"run_qfl_experiments_parallel_nonpers, num_total_rounds: {num_total_rounds}")

  # The dataset is loaded with as many features as the largest client needs;
  # smaller clients select a subset of columns through their feature list.
  X_angles, y = load_dataset(dataset_type=dataset_type, classes=classes, n_samples=n_samples, num_feats=max_size_clients)

  print_cust(f"run_qfl_experiments_parallel, X_angles.shape: {X_angles.shape}")

  if init_client_data_dict is not None:
    # Reuse a previously generated split (e.g. from an earlier run's logs).
    clients_data_dict = init_client_data_dict["clients_data_dict"]
    (X_test, y_test) = init_client_data_dict["testing_data"]
    print_cust("run_qfl_experiments_parallel, loaded in existing data")
  else:
    clients_data_dict, (X_test, y_test) = split_data_federated(X_angles, y, clients_config_arg, test_frac, val_frac=val_frac, random_state=random_state, feature_skew=feature_skew)
    print_cust("run_qfl_experiments_parallel, generated new data")

  if debug:
    for client_type in clients_data_dict:
      clients_data_list = clients_data_dict[client_type]
      print_cust(f"client_type: {client_type}, len(clients_data_list): {len(clients_data_list)}")
      for client_idx, client_data in enumerate(clients_data_list):
        print_cust(f"client_type: {client_type}, client_idx: {client_idx}, len(client_data): {len(client_data)}")
        (training_data, validation_data) = client_data
        print_cust(f"client_type: {client_type}, client_idx: {client_idx}, len(training_data): {len(training_data)}, len(validation_data): {len(validation_data)}")
        (X_train, y_train) = training_data
        (X_val, y_val) = validation_data
        print_cust(f"client_type: {client_type}, client_idx: {client_idx}, X_train.shape: {X_train.shape}, y_train.shape: {y_train.shape}, X_val.shape: {X_val.shape}, y_val.shape: {y_val.shape}")

  client_params_dict = initialize_client_params(clients_config_arg, model_size=min_size_clients, cur_client_params_dict=None, qubits_and_layers_to_add_block_params=qubits_and_layers_to_add_block_params, train_models_parallel=train_models_parallel)

  if same_init:
    # Start every client from the same randomly generated parameters.
    print_cust("run_qfl_experiments_parallel, same parameters initialization")
    meta_params = generate_meta_params(client_params_dict, clients_data_dict)
    meta_params = generate_meta_params_random(meta_params)
    client_params_dict = broadcast_param_updates(client_params_dict, meta_params)

  data_logs = {}

  data_logs["clients_data_dict"] = clients_data_dict
  data_logs["testing_data"] = (X_test, y_test)

  if save_pkl:
    # Only the data splits are in data_logs at this point; per-round results
    # are not written to this file.
    with open(f"data_logs_n_samples_{n_samples}_dataset_type_{dataset_type}_classes_{'_'.join(classes)}_train_models_parallel_{train_models_parallel}.pkl", "wb") as file:
      pickle.dump(data_logs, file)

  # Pre-allocate the per-round, per-client-type log structure so the training
  # and evaluation phases below can fill it in by index / append.
  for round_num in range(num_total_rounds):
    print_cust(f"run_qfl_experiments_parallel, round_num: {round_num}")
    data_logs[round_num] = {}
    for client_size, cfg in clients_config_arg.items():
      # TODO: add a field for testing loss in data_logs
      data_logs[round_num][client_size] = {
          "aggregated_params_clients": [],
          "testing_loss_clients": [],
          "testing_accuracy_clients": [],
          "local_epochs": cfg["local_epochs"],
          "client_metrics": [
              {"trained_params": None,
               "minibatch_losses": None,
               "validation_losses": None,
               "training_acc": None}
              for _ in range(cfg["num_clients"])
          ],
      }

  def _build_feature_list_and_qnode(model_size):
    # Build the feature-column selection and the qnode for one client size.
    # Clients larger than the smallest client are flagged for expansion.
    num_layers = int(math.log2(model_size))
    feature_list, expansion_data = create_feat_list_expansion_data(
        model_size, num_layers, expand=bool(model_size > min_size_clients), pool_in=pool_in,
        min_qubits_noexpand=min_size_clients, train_models_parallel=train_models_parallel)
    qnode = create_qnode_qcnn(model_size, num_layers, expansion_data, n_classes=len(classes))
    return feature_list, qnode

  for round_num in range(num_total_rounds):
    print_cust(f"run_qfl_experiments_parallel, round_num: {round_num}")

    # ---- Local training -------------------------------------------------
    for client_type, client_params in client_params_dict.items():
      print_cust(f"run_qfl_experiments_parallel, client_type: {client_type}")
      client_data = clients_data_dict[client_type]
      cur_model_size = client_type
      for client_idx, client_params_indiv in enumerate(client_params):
        print_cust(f"run_qfl_experiments_parallel, client_idx: {client_idx}")
        client_data_indiv = client_data[client_idx]

        client_train_data = client_data_indiv[0]
        client_val_data = client_data_indiv[1]

        if debug:
          print_cust(f"run_qfl_experiments_parallel, len(client_train_data): {len(client_train_data)}")
          print_cust(f"run_qfl_experiments_parallel, client_train_data: {client_train_data}")
          print_cust(f"run_qfl_experiments_parallel, len(client_val_data): {len(client_val_data)}")
          print_cust(f"run_qfl_experiments_parallel, client_val_data: {client_val_data}")

        # Gradient masking is disabled, so `mask_grads` has no effect and all
        # parameters are trained. The previous (disabled) logic was:
        # if mask_grads:
        #   if cur_model_size > min_size_clients:
        #     grad_mask = mask_gradients(client_params_indiv)
        grad_mask = None

        print_cust(f"run_qfl_experiments_parallel, grad_mask: {grad_mask}")

        feature_list, qnode = _build_feature_list_and_qnode(cur_model_size)

        num_local_epochs = clients_config_arg[client_type]["local_epochs"]

        # Train on a copy so the pre-training parameters are not modified
        # in place by the optimizer.
        client_params_indiv = copy.deepcopy(client_params_indiv)

        trained_params, minibatch_losses, validation_losses = train_epochs_angle_param_adam(
            client_params_indiv, client_train_data[0][:, feature_list], client_train_data[1], client_val_data[0][:, feature_list], client_val_data[1],
            n_epochs=num_local_epochs, shots=shots, batch_size=local_batch_size, lr=local_lr, qnode=qnode, trainable_mask=grad_mask)

        client_params[client_idx] = trained_params

        train_acc = compute_avg_acc_angle_param_batch(trained_params, client_train_data[0][:, feature_list], client_train_data[1], layers=int(np.log2(cur_model_size)), shots=1024, batch_size=local_batch_size, qnode=qnode)

        client_log = data_logs[round_num][cur_model_size]["client_metrics"][client_idx]
        client_log["training_acc"] = train_acc
        client_log["trained_params"] = trained_params
        client_log["minibatch_losses"] = minibatch_losses
        client_log["validation_losses"] = validation_losses

    # ---- Aggregation ----------------------------------------------------
    # NOTE: assume, during aggregation, that the parameters for all the clients have the same shape.
    print_cust(f"run_qfl_experiments_parallel, parameter aggregation, agg_strategy: {agg_strategy}")
    if agg_strategy == "fedavg":
      aggregated_params = federated_averaging(client_params_dict, clients_data_dict)
    elif agg_strategy == "fedavg_quat":
      aggregated_params = federated_averaging_quat(client_params_dict, clients_data_dict)
    else:
      # "fedavg_circ" -- the only remaining value after the up-front validation.
      aggregated_params = federated_averaging_circular_parallel(client_params_dict, clients_data_dict)

    # ---- Broadcast ------------------------------------------------------
    client_params_dict = broadcast_param_updates(client_params_dict, aggregated_params)

    # ---- Evaluation of the broadcast parameters on the test split --------
    for client_type, client_params_list in client_params_dict.items():
      for client_params in client_params_list:
        cur_model_size = client_type

        feature_list, qnode = _build_feature_list_and_qnode(cur_model_size)

        test_loss = compute_loss_angle_param_batch(client_params, X_test[:, feature_list], y_test,
                                          shots=shots, qnode=qnode)

        test_acc = compute_avg_acc_angle_param_batch(client_params, X_test[:, feature_list], y_test, layers=int(np.log2(cur_model_size)), shots=1024, batch_size=local_batch_size, qnode=qnode)

        print_cust(f"run_qfl_experiments_parallel, test_loss: {test_loss}")

        print_cust(f"run_qfl_experiments_parallel, test_acc: {test_acc}")

        type_log = data_logs[round_num][client_type]
        type_log["testing_loss_clients"].append(test_loss)
        type_log["testing_accuracy_clients"].append(test_acc)
        # Snapshot the parameters so later rounds cannot alter this log entry.
        type_log["aggregated_params_clients"].append(copy.deepcopy(client_params))


  return data_logs

"""# Run QFL Experiments"""

# TODO, depthFL: continue here, 3:00 PM, 8/19

# client_config_exper = {
#     4: {
#         "percentage_data": 0.5,
#         "num_clients": 2,
#         "local_epochs": 1,
#         "communication_rounds": 5
#     },
#     8: {
#         "percentage_data": 0.5,
#         "num_clients": 2,
#         "local_epochs": 1,
#         "communication_rounds": 5
#     }
# }

# from pennylane import numpy as np

# with open("data_logs_n_samples_569_dataset_type_breast_cancer_classes_0_1_train_models_parallel_False_feature_skew_0.0_label_skew_None_local_pca_True_shared_pca_True_min_4_max_4.pkl", "rb") as file:
#   data_logs_prev = pickle.load(file)

# data_logs_prev['clients_data_dict'][4][0][2][1].shape

# data_logs_prev['clients_data_dict'][8] = data_logs_prev['clients_data_dict'][4]

# data_logs_prev.keys()

# data_logs_prev['clients_data_dict'].keys()

# del data_logs_prev["clients_data_dict"][4]

# data_logs_prev.keys()

# data_logs_prev['clients_data_dict'].keys()

# data_logs_prev['testing_data'][1].shape

# data_logs_prev["clients_data_dict"][8] = data_logs_prev["clients_data_dict"][8][:2]

# len(data_logs_prev["clients_data_dict"][8])

# data_logs_prev['clients_data_dict'][6] = data_logs_prev['clients_data_dict'][6][:2]

# data_logs_prev.keys()

# data_logs_prev['clients_data_dict'].keys()

# data_logs_prev['clients_data_dict']

# del data_logs_prev["clients_data_dict"][6]

# data_logs_prev['clients_data_dict'].keys()

# data_logs_prev['clients_data_dict']

# data_logs_prev['testing_data']

# import warnings

# # Make *all* RuntimeWarnings raise errors
# warnings.simplefilter('error', RuntimeWarning)

# # Mount Google drive at a specified filepath, for file operations, if using the
# # notebook as a Google Colab notebook.
# from google.colab import drive
# drive.mount('/content/drive')

# import sys, contextlib









def count_numcli_clitype(config_dict):
  """Build an underscore-joined tag of "<num_clients>_<client type>" pairs.

  For every entry of `config_dict`, in the dict's iteration order, the value's
  "num_clients" field is emitted first, followed by the key itself. An empty
  dict yields an empty string.
  """
  tokens = (
      str(part)
      for client_kind, settings in config_dict.items()
      for part in (settings["num_clients"], client_kind)
  )
  return "_".join(tokens)

def count_qubits_layers(qubits_layers_list):
  """Join the first two items of every entry into one underscore-separated tag.

  Each entry of `qubits_layers_list` is indexed at positions 0 and 1; any
  further items are ignored. An empty list yields an empty string.
  """
  pieces = []
  for entry in qubits_layers_list:
    pieces.extend((str(entry[0]), str(entry[1])))
  return "_".join(pieces)