import numpy as np
from mbi.marginal_loss import LinearMeasurement

class MeasurementWrapper:
    """Converter between lists of mbi.LinearMeasurement and numpy vectors.

    The vector layout is the concatenation of every measurement's
    ``noisy_measurement`` in list order. ``measurement_lengths`` records how
    many entries each measurement contributes, so a vector with that layout
    can be split back into measurements carrying the original metadata.
    """

    def __init__(self, measurements: list[LinearMeasurement]):
        """Initialize to convert between the given list and a numpy vector.

        Args:
            measurements (list[LinearMeasurement]): Measurements defining the
                vector layout. The list is stored by reference, but the
                per-measurement sizes are captured now; mutating the list
                afterwards will desynchronize ``measurement_lengths``.
        """
        self.measurements = measurements
        # Number of vector entries each measurement occupies (``.size`` counts all elements).
        self.measurement_lengths = [m.noisy_measurement.size for m in measurements]

    def get_concatenated_values(self) -> np.ndarray:
        """Return all noisy measurements concatenated along axis 0.

        Returns:
            np.ndarray: The concatenated noisy measurements, in list order.
            An empty float array is returned when there are no measurements.
        """
        if not self.measurements:
            # np.concatenate raises ValueError on an empty sequence; an empty
            # layout corresponds to a length-0 vector instead.
            return np.zeros(0)
        return np.concatenate([m.noisy_measurement for m in self.measurements], axis=0)

    def convert_vector_to_measurements(self, vector: np.ndarray) -> list[LinearMeasurement]:
        """Split ``vector`` back into measurements using the stored layout.

        Each output measurement takes its ``noisy_measurement`` from the
        corresponding consecutive slice of ``vector`` and copies ``clique``,
        ``stddev`` and ``query`` from the original measurement at that index.

        Args:
            vector (np.ndarray): Values laid out as produced by
                ``get_concatenated_values``.

        Returns:
            list[LinearMeasurement]: One new measurement per stored
            measurement, in the same order.

        Raises:
            ValueError: If ``vector``'s length along axis 0 does not equal the
                total number of entries in the stored layout.
        """
        expected = sum(self.measurement_lengths)
        # Guard against silently truncating (short vector) or dropping
        # trailing entries (long vector) when slicing below.
        if np.shape(vector)[:1] != (expected,):
            raise ValueError(
                f"Expected a vector of length {expected} along axis 0, "
                f"got shape {np.shape(vector)}."
            )

        measurements = []
        start = 0
        for length, m in zip(self.measurement_lengths, self.measurements):
            end = start + length
            measurements.append(LinearMeasurement(
                # Basic slicing of a numpy array yields a view into ``vector``;
                # whether LinearMeasurement copies it is not visible here.
                noisy_measurement=vector[start:end],
                clique=m.clique,
                stddev=m.stddev,
                query=m.query
            ))
            start = end
        return measurements