import math
import os
import re
import sys
from collections import defaultdict

import matplotlib.pyplot as plt

def parse_epoch_layer_trends(filepath):
    """
    Parse a log file and return per-epoch NSA averages for every layer.

    Only lines containing ``Epoch <n>, NSA: <v0>, <v1>, ...`` are used; the
    i-th comma-separated value on such a line is attributed to layer i.
    A value is treated as missing (excluded from the average) when it is
    ``nan``, exactly ``0.0``, or greater than 1. When the same epoch appears
    on several lines, each layer's valid values for that epoch are averaged.

    Args:
        filepath: Path to the log file.

    Returns:
        A tuple ``(epochs, nsa_per_layer)``:
        ``epochs`` is the sorted list of epoch numbers found in the log;
        ``nsa_per_layer`` maps every layer index ``0..max_layer`` to a list
        aligned with ``epochs``. Each entry is the mean of that layer's valid
        values in that epoch, or ``None`` when the layer had no valid value
        there (rendered as a gap by matplotlib). A log with no matching lines
        yields ``([], {})``.
    """
    number = r"(?:\d*\.?\d+|[Nn]a[Nn])"
    # Capture only a well-formed comma-separated list of numbers / nan, so
    # trailing commas, spaces or punctuation never yield unparsable tokens.
    pattern = re.compile(rf"Epoch (\d+), NSA: ({number}(?:\s*,\s*{number})*)")

    # epoch -> layer -> list of valid NSA values seen for that (epoch, layer)
    samples = defaultdict(lambda: defaultdict(list))
    max_layer = -1

    with open(filepath, 'r') as f:
        for line in f:
            match = pattern.search(line)
            if match is None:
                continue
            epoch = int(match.group(1))
            # Touch the epoch so it is recorded even if none of its values are valid.
            per_layer = samples[epoch]
            values = [float(tok.strip()) for tok in match.group(2).split(',')]
            # Layers count toward the layer range even when their value is invalid.
            max_layer = max(max_layer, len(values) - 1)
            for layer, value in enumerate(values):
                # nan and 0.0 are treated as missing; values above 1 are discarded.
                if math.isnan(value) or value == 0.0 or value > 1:
                    continue
                per_layer[layer].append(value)

    epochs = sorted(samples)
    nsa_per_layer = {}
    for layer in range(max_layer + 1):
        averages = []
        for epoch in epochs:
            # .get() avoids inserting empty lists into the defaultdict.
            valid = samples[epoch].get(layer)
            averages.append(sum(valid) / len(valid) if valid else None)
        nsa_per_layer[layer] = averages

    return epochs, nsa_per_layer


def plot_layer_trends_across_epochs(filepaths, mode="lnsa", save_as="epoch_trends"):
    """
    Plot per-layer trends over epochs, one subplot per log file.

    Each file is parsed with ``parse_epoch_layer_trends`` and drawn in its own
    subplot titled ``Category <k>`` (k = 1-based file position); all subplots
    share the y axis. The highest-indexed layer of every file is not plotted.
    The figure is saved to ``<save_as>_<mode>_trends.png`` and then shown.

    NOTE(review): ``mode`` only changes the y label, the title and the output
    file name -- the plotted values are always the NSA averages returned by
    ``parse_epoch_layer_trends``. Confirm this is intended for ``"lnsa"``.

    Args:
        filepaths: Sequence of log-file paths, one subplot each.
        mode: ``"nsa"`` or ``"lnsa"``; selects the metric name used in labels.
        save_as: Prefix of the saved image file name.

    Raises:
        ValueError: If ``mode`` is not ``"nsa"``/``"lnsa"`` or ``filepaths``
            is empty.
    """
    if mode not in ("nsa", "lnsa"):
        raise ValueError(f"mode must be 'nsa' or 'lnsa', got {mode!r}")
    filepaths = list(filepaths)
    if not filepaths:
        raise ValueError("filepaths must contain at least one log file")

    metric = 'LNSA' if mode == 'lnsa' else 'NSA'
    n_files = len(filepaths)
    # squeeze=False keeps a 2-D axes array even for a single file; the width
    # scales with the file count (18x5 for the usual three files).
    fig, axes = plt.subplots(
        1, n_files, figsize=(6 * n_files, 5), sharey=True, squeeze=False
    )

    for i, (ax, filepath) in enumerate(zip(axes[0], filepaths)):
        epochs, trends = parse_epoch_layer_trends(filepath)
        # Skip the highest-index layer without mutating the parsed result.
        # NOTE(review): the reason this layer is excluded is not visible in
        # this file -- confirm before changing.
        layers_to_plot = sorted(trends)[:-1]
        for layer in layers_to_plot:
            ax.plot(epochs, trends[layer], label=f'Layer {layer+1}')

        ax.set_title(f'Category {i + 1}')
        ax.set_xlabel("Epoch")
        if i == 0:
            ax.set_ylabel(metric)
        ax.grid(True)
        # Avoid matplotlib's "no artists with labels" warning on empty axes.
        if layers_to_plot:
            ax.legend()

    fig.suptitle(f"{metric} Trends per Layer over Epochs")
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig(f"{save_as}_{mode}_trends.png")
    plt.show()

if __name__ == "__main__":
    # Needs three log files, a mode and an output prefix: argv[1]..argv[5].
    # (The previous `< 5` check let a missing <save_as> crash on argv[5].)
    if len(sys.argv) < 6 or sys.argv[4] not in ("nsa", "lnsa"):
        print(
            "Usage: python script.py <file1.txt> <file2.txt> <file3.txt> <nsa|lnsa> <save_as>",
            file=sys.stderr,
        )
        sys.exit(1)

    files = sys.argv[1:4]
    mode = sys.argv[4]
    save_as = sys.argv[5]
    plot_layer_trends_across_epochs(files, mode, save_as)
