import matplotlib.pyplot as plt
import numpy as np


# Read the data from average_distance_errors.txt and plot it.

DATA_FILE = "average_distance_errors.txt"


def load_data(path=DATA_FILE):
    """Load the comma-separated error log as a 2-D array.

    The plot in ``main`` reads four columns per row: update index, raw
    average distance error, filtered (exponential moving average) error,
    and affinity shift.

    Args:
        path: File name, path, or file-like object accepted by ``np.loadtxt``.

    Returns:
        A 2-D ``numpy.ndarray`` with one row per line of the file.

    Raises:
        ValueError: If the data has fewer than four columns.
    """
    # ndmin=2 keeps a single-row file 2-D; without it loadtxt returns a
    # 1-D array and the column slicing in ``main`` raises IndexError.
    data = np.loadtxt(path, delimiter=",", ndmin=2)
    if data.shape[1] < 4:
        raise ValueError(
            f"expected at least 4 columns in {path!r}, got {data.shape[1]}"
        )
    return data


def sync_ylims(ax_a, ax_b):
    """Give two axes the same y-range: the union of their current limits.

    Args:
        ax_a: First axes (anything with ``get_ylim``/``set_ylim``).
        ax_b: Second axes.

    Returns:
        The shared ``(y_min, y_max)`` that was applied to both axes.
    """
    a_min, a_max = ax_a.get_ylim()
    b_min, b_max = ax_b.get_ylim()
    y_min, y_max = min(a_min, b_min), max(a_max, b_max)
    ax_a.set_ylim(y_min, y_max)
    ax_b.set_ylim(y_min, y_max)
    return y_min, y_max


def main(path=DATA_FILE):
    """Plot raw/filtered distance errors and affinity shift per update.

    Args:
        path: Source of the comma-separated data; see ``load_data``.
    """
    data = load_data(path)
    updates = data[:, 0]

    _, ax1 = plt.subplots()
    ax1.scatter(updates, data[:, 1], marker="o", label="raw data")
    ax1.scatter(updates, data[:, 2], marker="o", color="red", label="filtered data (exp moving average)")
    ax1.set_xlabel("updates")
    ax1.set_ylabel("average distance errors each update")
    ax1.set_title("Average distance errors and affinity shift")

    # Affinity shift goes on a secondary (right-hand) y-axis.
    ax2 = ax1.twinx()
    ax2.scatter(updates, data[:, 3], marker="o", color="green", label="affinity shift")
    ax2.set_ylabel("affinity shift")

    # Put both axes on one common y-range (union of their autoscaled limits)
    # so all three series are drawn on the same numeric scale.
    sync_ylims(ax1, ax2)

    # Each axes keeps its own legend entries; merge them into a single legend.
    lines_1, labels_1 = ax1.get_legend_handles_labels()
    lines_2, labels_2 = ax2.get_legend_handles_labels()
    ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc="upper right")

    plt.show()


if __name__ == "__main__":
    main()