import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np


def read_csv_files(folder_path, has_extra_losses=True):
    """Load every ``layers_*.csv`` file in ``folder_path`` and group the runs by layer count.

    The file name is split on ``_`` after removing the ``.csv`` suffix. The
    token at index 1 is parsed as the number of layers and the token at
    index 3 as the run id; the token at index 2 is ignored. The CSV files are
    read without a header row and given fixed column names.

    Args:
        folder_path: Directory to scan (non-recursively).
        has_extra_losses: If True, files are expected to carry ``loss_4`` and
            ``loss_5`` columns in addition to ``loss_1``..``loss_3``.

    Returns:
        Dict ``num_layers -> [(run_id, DataFrame), ...]``. Keys are in
        ascending layer order and each run list is sorted by ``run_id``, so the
        result does not depend on the (arbitrary) ``os.listdir`` order.
    """
    suffix = ".csv"
    if has_extra_losses:
        columns = ["epoch", "train_loss", "valid_loss", "loss_1", "loss_2", "loss_3", "loss_4", "loss_5"]
    else:
        columns = ["epoch", "train_loss", "valid_loss", "loss_1", "loss_2", "loss_3"]

    data = {}
    for filename in os.listdir(folder_path):
        if not (filename.startswith("layers_") and filename.endswith(suffix)):
            continue

        # Strip only the trailing extension; str.replace would drop every ".csv".
        parts = filename[: -len(suffix)].split("_")
        num_layers = int(parts[1])
        run_id = int(parts[3])

        file_path = os.path.join(folder_path, filename)
        df = pd.read_csv(file_path, header=None, names=columns)
        data.setdefault(num_layers, []).append((run_id, df))

    # os.listdir returns entries in arbitrary order; sort so reports are reproducible.
    # Sort on run_id only so DataFrames are never compared.
    for runs in data.values():
        runs.sort(key=lambda item: item[0])
    return {num_layers: data[num_layers] for num_layers in sorted(data)}


def compute_averages(data, has_extra_losses=True, last_n=50):
    """Average each metric over the last ``last_n`` epochs of every run, then across runs.

    For each layer count, every run's metrics are first averaged over its final
    ``last_n`` rows (all rows if the run is shorter). Those per-run means are
    then averaged with ``np.mean``.

    Args:
        data: Dict ``num_layers -> [(run_id, DataFrame), ...]`` as returned by
            ``read_csv_files``.
        has_extra_losses: If True, ``loss_4`` and ``loss_5`` are averaged too.
        last_n: Number of trailing epochs per run to average over (default 50).

    Returns:
        Dict ``num_layers -> {metric: mean}``. The metric keys are
        ``valid_loss`` and ``loss_1``..``loss_3``, plus ``loss_4``/``loss_5``
        when ``has_extra_losses``. Key order within each dict is fixed.
    """
    metrics = ["valid_loss", "loss_1", "loss_2", "loss_3"]
    if has_extra_losses:
        metrics += ["loss_4", "loss_5"]

    averages = {}
    for num_layers, runs in data.items():
        # Slice each run's tail once and reuse it for every metric.
        tails = [run.tail(last_n) for _, run in runs]
        averages[num_layers] = {
            metric: np.mean([tail[metric].mean() for tail in tails])
            for metric in metrics
        }

    return averages


def find_first_epoch_below_threshold(data, threshold=0.5):
    """Report when each partial loss first drops below ``threshold`` and check their order.

    For every run, this prints the first ``epoch`` at which each of
    ``loss_1``..``loss_5`` is strictly below ``threshold``. Only columns that
    are present in the run are checked. The collected epochs, taken in
    loss-index order, must be non-decreasing; equal epochs are accepted.
    Losses that never cross the threshold are reported but left out of that
    check.

    Args:
        data: Dict ``num_layers -> [(run_id, DataFrame), ...]``. Each DataFrame
            needs an ``epoch`` column plus any of the ``loss_*`` columns.
        threshold: Value a loss must fall strictly below.

    Raises:
        SystemExit: With status 1 as soon as a run's breakthrough epochs decrease.
    """
    for num_layers, runs in data.items():
        for run_id, run in runs:
            epoch_breakthroughs = []

            print(f"Run {run_id} (Layers: {num_layers}):")
            for loss_col in ["loss_1", "loss_2", "loss_3", "loss_4", "loss_5"]:
                if loss_col not in run.columns:
                    continue
                below_threshold_epochs = run.loc[run[loss_col] < threshold, "epoch"]
                if below_threshold_epochs.empty:
                    print(f"  {loss_col} never goes below {threshold}")
                    continue
                first_epoch = below_threshold_epochs.iloc[0]
                print(f"  {loss_col} goes below {threshold} at epoch {first_epoch}")
                epoch_breakthroughs.append(first_epoch)

            # Only a strict decrease between consecutive breakthroughs is a failure.
            if any(later < earlier for earlier, later in zip(epoch_breakthroughs, epoch_breakthroughs[1:])):
                print("WARNING: Breakthrough epochs are not increasing!", epoch_breakthroughs)
                # The site-module exit() exits with status 0 (success) and is missing
                # under `python -S`; report the failure with a non-zero status instead.
                raise SystemExit(1)

            print("All breakthrough epochs are increasing!\n")


def validate_non_ic_final_loss(data, threshold=0.7):
    """Check that each run's last recorded validation loss is at least ``threshold``.

    Only the final row of ``valid_loss`` is inspected, not the whole curve. A
    pass line is printed for every run that passes.

    Args:
        data: Dict ``num_layers -> [(run_id, DataFrame), ...]``. Each DataFrame
            needs a ``valid_loss`` column.
        threshold: Minimum acceptable final validation loss (default 0.7).

    Raises:
        SystemExit: With status 1 on the first run whose final loss is below ``threshold``.
    """
    for num_layers, runs in data.items():
        for run_id, run in runs:
            final_valid_loss = run["valid_loss"].iloc[-1]
            if final_valid_loss < threshold:
                print(f"WARNING: Run {run_id} (Layers: {num_layers}) has final validation loss {final_valid_loss} below {threshold}")
                # The site-module exit() exits with status 0 (success) and is missing
                # under `python -S`; report the failure with a non-zero status instead.
                raise SystemExit(1)
            print(f"Run {run_id} (Layers: {num_layers}) passed with final validation loss {final_valid_loss} >= {threshold}")


def plot_validation_loss(averages_ic, averages_non_ic, output_file):
    """Plot averaged validation loss against layer count for IC and non-IC runs and save it.

    Args:
        averages_ic: Dict ``num_layers -> {"valid_loss": ..., ...}`` for the
            IC (D=5) runs, as produced by ``compute_averages``.
        averages_non_ic: The same mapping for the non-IC (D=3) runs.
        output_file: Path passed to ``savefig``.
    """
    # (averages, marker, legend label) for each curve, in drawing order.
    series = [
        (averages_ic, "o", "IC, D=5"),
        (averages_non_ic, "s", "non-IC, D=3"),
    ]

    fig = plt.figure(figsize=(6.5, 3.25))
    try:
        for averages, marker, label in series:
            layers = sorted(averages)
            valid_losses = [averages[n]["valid_loss"] for n in layers]
            plt.plot(layers, valid_losses, marker=marker, label=label)

        plt.xlabel("Number of Layers")
        plt.ylabel("Validation Loss")
        plt.title("Final Validation Loss by Number of Layers")
        plt.ylim(0, 1.2)  # Fixed y-range for this plot only
        plt.grid(True)
        plt.legend(loc="upper right")
        plt.tight_layout()
        plt.savefig(output_file)
    finally:
        # Close even when plotting or saving fails so the figure is not leaked.
        plt.close(fig)


def plot_final_losses(averages_ic, output_file):
    """Plot the averaged partial losses ``loss_1``..``loss_5`` against layer count and save it.

    Args:
        averages_ic: Dict ``num_layers -> {"loss_1": ..., ..., "loss_5": ...}``,
            as produced by ``compute_averages(..., has_extra_losses=True)``.
            A ``KeyError`` is raised if any of the five keys is missing.
        output_file: Path passed to ``savefig``.
    """
    layers = sorted(averages_ic)
    # (metric key, marker, legend label) for each partial loss, in drawing order.
    curves = [
        ("loss_1", "o", "X$_1$ (B)"),
        ("loss_2", "s", "X$_2$ (C)"),
        ("loss_3", "x", "X$_3$ (D)"),
        ("loss_4", "d", "X$_4$ (E)"),
        ("loss_5", "*", "X$_5$ (F)"),
    ]

    fig = plt.figure(figsize=(6.5, 3.25))
    try:
        for key, marker, label in curves:
            plt.plot(layers, [averages_ic[n][key] for n in layers], marker=marker, label=label)

        plt.xlabel("Number of Layers")
        plt.ylabel("Final Loss")
        plt.title("Final Partial Loss by Number of Layers (IC)")
        plt.legend(loc="upper right")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(output_file)
    finally:
        # Close even when plotting or saving fails so the figure is not leaked.
        plt.close(fig)


def main():
    """Load both experiment folders, write the two PDF plots, then run the sanity checks."""
    # IC experiment (D=5): runs carry loss_1..loss_5.
    ic_data = read_csv_files("./mega_run/d=5_40k_steps", has_extra_losses=True)
    ic_averages = compute_averages(ic_data, has_extra_losses=True)

    # Non-IC experiment (D=3): runs carry only loss_1..loss_3.
    non_ic_data = read_csv_files("./mega_run/d=3_nonIC", has_extra_losses=False)
    non_ic_averages = compute_averages(non_ic_data, has_extra_losses=False)

    plot_validation_loss(ic_averages, non_ic_averages, "validation_loss_by_layers.pdf")
    plot_final_losses(ic_averages, "final_losses_by_layers.pdf")

    # Checks run after plotting; either one stops the script when a run fails.
    find_first_epoch_below_threshold(ic_data, threshold=0.5)
    validate_non_ic_final_loss(non_ic_data, threshold=0.7)


# Run the analysis only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
