import pandas as pd
import matplotlib.pyplot as plt
import os


# Per-head description table, keyed by "<layer>.<head>" strings (the key is
# split on "." and both parts are converted to int further down).
#
# Each value is either:
#   * a plain string tag ("uniform" / "identity") -- such entries are skipped
#     when the plot titles are built (presumably they name a fixed attention
#     pattern; TODO confirm against the code that produced the CSV), or
#   * a list of 2-tuples that the title loop unpacks as (dst, src).  Each node
#     label is a letter followed by "1" or "2", rendered in the titles as
#     "<letter> (first)" / "<letter> (second)".
#
# Insertion order matters: titles are generated in this order, one per edge,
# and are later matched to CSV data columns by position.
flow = {
    "1.1": "uniform",
    "2.1": "uniform",
    "3.1": "uniform",
    "4.1": "uniform",
    "5.1": "uniform",
    "6.1": "uniform",
    "7.1": "uniform",
    "8.1": "uniform",
    "9.1": "uniform",
    "10.1": "uniform",
    "11.1": "uniform",
    "12.1": "identity",
    "13.1": "identity",
    "14.1": [("E1", "D1"), ("D2", "C1"), ("C2", "B1"), ("B2", "A1")],
    "15.1": "uniform",
    "16.1": "uniform",
    "17.1": "uniform",
    "18.1": "uniform",
    "19.1": "uniform",
    "20.1": "uniform",
    "21.1": [("D2", "E1"), ("C2", "D2"), ("A2", "B2"), ("B2", "C2")],
    "22.1": [("C2", "D2"), ("A2", "B2"), ("B2", "C2")],
    "23.1": [("A2", "B2")],
    "24.1": [("A2", "C2")],
}

# Build one human-readable title per explicit (dst, src) edge in `flow`,
# preserving the order in which the edges are listed.
titles = []
for head_key, pattern in flow.items():
    # The "<layer>.<head>" key is parsed for every entry, edge list or not.
    layer_text, head_text = head_key.split(".")
    layer_no, head_no = int(layer_text), int(head_text)
    if not isinstance(pattern, list):
        # String-tagged heads have no explicit edges and get no title.
        continue
    for dst_node, src_node in pattern:
        # A node label is "<letter><digit>"; digit "1" reads as "first",
        # anything else as "second".
        src_text = f"{src_node[0]} ({'first' if src_node[1] == '1' else 'second'})"
        dst_text = f"{dst_node[0]} ({'first' if dst_node[1] == '1' else 'second'})"
        titles.append(f"Layer {layer_no}, Head {head_no}, {src_text} → {dst_text}")

threshold = 0.5  # y level whose first crossing is reported for each curve


def find_first_x_above_threshold(x, y):
    """Return the x position at which ``y`` first rises above ``threshold``.

    The samples are scanned in order.  If the very first sample is already
    above the module-level ``threshold``, its x value is returned as is.
    Otherwise the first pair of consecutive samples with
    ``y[i - 1] <= threshold < y[i]`` is located and the crossing point is
    linearly interpolated between them.

    Args:
        x: Sequence of x values (list, pandas Series, ...).
        y: Sequence of y values, sample-for-sample aligned with ``x``.

    Returns:
        The (interpolated) x value of the first crossing, or ``None`` when
        the inputs are empty or ``y`` never exceeds ``threshold``.
    """
    # Materialise as lists so access is strictly positional: indexing a
    # pandas Series with ``y[i]`` is label-based and breaks (or silently
    # reads the wrong rows) when the index is not 0..n-1.
    xs = list(x)
    ys = list(y)
    if not xs or not ys:
        return None

    # Already above the threshold at the first sample: there is no upward
    # crossing to interpolate, the first x is the answer.
    if ys[0] > threshold:
        return xs[0]

    for (x0, y0), (x1, y1) in zip(zip(xs, ys), zip(xs[1:], ys[1:])):
        if y0 <= threshold < y1:
            # y1 - y0 > 0 is guaranteed by the condition above, so dividing
            # by it is always safe -- unlike dividing by (x1 - x0), which
            # fails when two samples share the same x.
            return x0 + (threshold - y0) * (x1 - x0) / (y1 - y0)
    return None


# Function to read CSV and plot each column
def plot_columns_from_csv(csv_file, output_dir):
    """Plot every data column of a header-less CSV against its first column.

    The first CSV column supplies the x axis ("Epoch"); each remaining
    column is drawn as its own figure ("Attention") and saved to
    ``<output_dir>/plot_y<i>.pdf``, where ``i`` is the 1-based index of the
    data column.  For every column the x value at which the curve first
    exceeds the module-level ``threshold`` is also printed, rounded to the
    nearest integer.

    Figure titles come from the module-level ``titles`` list, matched to the
    data columns by position.  Columns without a matching title are labelled
    ``y<i>`` instead of raising ``IndexError``.

    Args:
        csv_file: Path of the CSV file to read (parsed with ``header=None``).
        output_dir: Directory for the PDF files; created if missing.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    df = pd.read_csv(csv_file, header=None)
    x = df.iloc[:, 0]

    for i in range(1, df.shape[1]):
        y = df.iloc[:, i]
        # Fall back to a generic label when the CSV has more data columns
        # than there are prepared titles.
        label = titles[i - 1] if i <= len(titles) else f"y{i}"

        x_above_05 = find_first_x_above_threshold(x, y)
        if x_above_05 is not None:
            x_above_05 = int(x_above_05 + 0.5)  # Round to the nearest integer
            print(f"For {label}, the first Epoch where Attention > {threshold} is: {x_above_05}")
        else:
            print(f"For y{i}, no values are greater than {threshold}")

        plt.figure()
        try:
            plt.plot(x, y)
            plt.xlabel("Epoch")
            plt.ylabel("Attention")
            plt.title(label)
            plt.tight_layout()

            # Each column goes to its own image file.
            output_file = os.path.join(output_dir, f"plot_y{i}.pdf")
            plt.savefig(output_file)
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close()


# Script body: runs unconditionally whenever this module is executed or imported.
csv_file = "attn_ablat/model_D/results.csv"  # CSV to read (parsed without a header row)
output_dir = "attn_ablat/model_D/plots"  # Folder that receives one PDF per data column
plot_columns_from_csv(csv_file=csv_file, output_dir=output_dir)
