from eos_line_search.run import *
from eos_line_search.plot import *
from eos_line_search.data import *
from haven import haven_utils as hu
import torch
import cProfile
import pstats

import numpy as np
from eos_line_search import (
    plotting_offline,
    workshop_plotting,
)
import os


def _list_subdirs(parent):
    """Return the names of the immediate subdirectories of *parent*, sorted.

    ``os.listdir`` yields entries in arbitrary order; sorting makes the
    traversal (and hence the order runs are loaded and plotted) deterministic.
    Plain files are ignored.
    """
    return sorted(
        name
        for name in os.listdir(parent)
        if os.path.isdir(os.path.join(parent, name))
    )


def _load_runs(optimizer_path):
    """Load the ``"run"`` entry of every regular file directly in *optimizer_path*.

    Loading is best effort: a file that cannot be loaded (or lacks a ``"run"``
    key) is reported and skipped so one bad file does not abort the pass.
    Returns the loaded runs in sorted file-name order.
    """
    runs = []
    for file_name in sorted(os.listdir(optimizer_path)):
        pkl_path = os.path.join(optimizer_path, file_name)
        if not os.path.isfile(pkl_path):
            continue
        # NOTE(review): this unpickles whatever sits in the experiment tree;
        # only point the script at results you produced yourself.
        try:
            run = hu.load_pkl(pkl_path)["run"]
        except Exception as e:  # deliberate best effort: report and move on
            print(f"Failed to load {pkl_path}: {e}")
            continue
        runs.append(run)
        print(f"Loaded: {pkl_path}")
    return runs


def plot_experiment(path, *, optimizers=("PoNoS",)):
    """Load saved runs under ``<path>/experiments`` and plot them per model.

    The expected layout is
    ``<path>/experiments/<dataset>/<model>/full/<optimizer>/<pickle files>``.
    For every dataset/model pair, the runs of all requested optimizers are
    collected and passed to ``plotting_offline.plot_assmpt`` and
    ``plotting_offline.plot_assmpt_per_it``.

    Args:
        path: Root directory that contains the ``experiments`` folder.
        optimizers: Names of the optimizer subdirectories to read under each
            model's ``full`` directory. Defaults to ``("PoNoS",)``.

    Raises:
        FileNotFoundError: If ``<path>/experiments`` does not exist.
    """
    experiment_directory = os.path.join(path, "experiments")

    for dataset in _list_subdirs(experiment_directory):
        dataset_path = os.path.join(experiment_directory, dataset)

        for model in _list_subdirs(dataset_path):
            # Only the "full" batch-size subdirectory is read.
            batch_size_path = os.path.join(dataset_path, model, "full")

            runs_list = []
            for optimizer in optimizers:
                optimizer_path = os.path.join(batch_size_path, optimizer)
                # Not every model has results for every optimizer; skip
                # instead of letting os.listdir abort the whole pass.
                if not os.path.isdir(optimizer_path):
                    print(f"Skipping missing directory: {optimizer_path}")
                    continue
                runs_list.extend(_load_runs(optimizer_path))

            if not runs_list:
                print(f"No runs found for {dataset}/{model}; skipping plots")
                continue

            print("Plot Experiment")
            # NOTE(review): every dataset/model pair passes the same root
            # `path` to the plotters; confirm they do not overwrite each
            # other's output files.
            plotting_offline.plot_assmpt(runs_list, path)
            plotting_offline.plot_assmpt_per_it(runs_list, path)


if __name__ == "__main__":
    # Script entry point: treat the current working directory as the root
    # that contains the "experiments" folder.
    root_dir = os.getcwd()
    plot_experiment(root_dir)
