from eos_line_search.run import *
from eos_line_search.plot import *
from eos_line_search.data import *
from haven import haven_utils as hu
import torch
import pickle
import io
import sys

import numpy as np
import os


class CPUUnpickler(pickle.Unpickler):
    """Unpickler that redirects CUDA tensor storages to CPU.

    Used on hosts without CUDA so that result pickles written on a GPU
    machine can still be loaded.
    """

    # Storage class names redirected from any ``torch.cuda*`` module to the
    # CPU class of the same name on ``torch``. Checked in this order with a
    # substring match on the pickled class name.
    _CPU_STORAGE_NAMES = (
        "FloatStorage",
        "LongStorage",
        "IntStorage",
        "ByteStorage",
        "DoubleStorage",
        "HalfStorage",
    )

    def find_class(self, module, name):
        # Tensors pickled through torch's storage reducer are re-loaded with
        # map_location="cpu" so no CUDA device is needed.
        # NOTE(review): newer torch releases default torch.load to
        # weights_only=True; confirm this still loads with the installed version.
        if module == "torch.storage" and name == "_load_from_bytes":
            return lambda b: torch.load(io.BytesIO(b), map_location="cpu")
        if module.startswith("torch.cuda"):
            for storage_name in self._CPU_STORAGE_NAMES:
                if storage_name in name:
                    return getattr(torch, storage_name)
        return super().find_class(module, name)


def _load_result(pkl_path):
    """Load one pickled result file.

    Uses ``hu.load_pkl`` when CUDA is available, otherwise ``CPUUnpickler``.

    Returns:
        The unpickled object, or ``None`` if loading failed (the failure is
        reported on stderr).
    """
    # SECURITY: unpickling can execute arbitrary code; only point this
    # script at result directories you trust.
    try:
        if torch.cuda.is_available():
            return hu.load_pkl(pkl_path)
        with open(pkl_path, "rb") as f:
            return CPUUnpickler(f).load()
    except Exception as e:
        # Best effort: one unreadable file must not abort the whole listing.
        print(f"Failed to load {pkl_path}: {e}", file=sys.stderr)
        return None


if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <results_dir>")

path = sys.argv[1]
# Sorted so the printed summary order is reproducible (os.listdir is unordered).
pickle_files = sorted(
    f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))
)

# Load each pickle file and print a one-line summary of its run.
for pkl_file in pickle_files:
    pkl_path = os.path.join(path, pkl_file)
    result = _load_result(pkl_path)
    if result is None:
        # Skip instead of falling through to a stale/undefined `result`.
        continue
    run = result["run"]
    print(run.run_id, run.optimizer, run.epochs)
