import pickle
import pandas as pd
import os
import glob
import numpy as np

from pathlib import Path
from collections import defaultdict


class ResultsLoader:
    """Load one pickled results file and expose its sections via accessors.

    The pickle must contain a dict. The sections read are ``logs``,
    ``net_setup_params``, ``training_params``, ``problem_cfg`` (stored as
    ``problem_params``), ``n_processes``, ``run_times``, ``runtime_args`` and
    ``timings``; each missing section defaults to ``{}``.
    """

    def __init__(self, path, file_name="data.pkl"):
        """Read ``path``/``file_name`` immediately.

        Args:
            path: Directory holding the results file (``str`` or ``Path``).
            file_name: Name of the pickle inside ``path``.
        """
        self.path = path
        self.file_name = file_name
        self.metrics_file = os.path.join(self.path, self.file_name)
        self._load_data()

    def _load_data(self):
        """Unpickle ``self.metrics_file`` and populate the section attributes."""
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # results files that come from a trusted source.
        with open(self.metrics_file, "rb") as f:
            data = pickle.load(f)

        self.logs = data.get("logs", {})
        self._convert_logs_to_arrays()

        self.net_setup_params = data.get("net_setup_params", {})
        self.training_params  = data.get("training_params", {})
        self.problem_params   = data.get("problem_cfg", {})
        self.n_processes      = data.get("n_processes", {})
        self.run_times        = data.get("run_times", {})
        self.runtime_args     = data.get("runtime_args", {})
        self.timings          = data.get("timings", {})

    def _convert_logs_to_arrays(self):
        """Best-effort in-place conversion of each ``logs[level][name]`` to an ndarray.

        Levels that are not dicts, and series numpy cannot convert (e.g. ragged
        nested lists), are left unchanged. Failures are isolated per entry so
        that one bad series does not stop conversion of the others.
        """
        if not isinstance(self.logs, dict):
            return
        for level_logs in self.logs.values():
            if not isinstance(level_logs, dict):
                continue
            for name, series in list(level_logs.items()):
                try:
                    level_logs[name] = np.array(series)
                except (TypeError, ValueError):
                    # Keep the raw value; accessors still return it unchanged.
                    continue

    # Accessors
    def get_gating_params(self, level):
        """Return ``net_setup_params['gating'][level]``, or ``{}`` if absent."""
        return self.net_setup_params.get("gating", {}).get(level, {})

    def get_training_params(self):
        """Return the ``training_params`` section."""
        return self.training_params

    def get_poly_params(self, level):
        """Return ``net_setup_params['poly'][level]``, or ``{}`` if absent."""
        return self.net_setup_params.get("poly", {}).get(level, {})

    def get_num_partitions(self, level):
        """Return ``num_partitions`` of the gating config for ``level`` (None if unset)."""
        return self.get_gating_params(level).get("num_partitions")

    def get_error_norms(self, level):
        """Return the ``l2_error`` log of ``level`` (None if missing).

        Raises:
            KeyError: if ``level`` is not present in the logs.
        """
        return self.logs[level].get("l2_error")

    def get_loss(self, level):
        """Return the ``loss`` log of ``level`` (None if missing).

        Raises:
            KeyError: if ``level`` is not present in the logs.
        """
        return self.logs[level].get("loss")

    def get_coarse_partitions(self):
        """Return ``num_partitions`` of the ``"coarse"`` gating level."""
        return self.get_gating_params("coarse").get("num_partitions")

    def get_fine_partitions(self):
        """Return coarse partitions times the ``"fine"`` level's ``num_partitions``."""
        return self.get_coarse_partitions() * self.get_gating_params("fine").get("num_partitions")

    def get_total_partitions(self):
        """Return coarse partitions plus (aggregated) fine partitions.

        Equivalent to ``coarse * (1 + fine_num_partitions)``. ``get_fine_partitions``
        already includes the coarse factor, so it is added, not multiplied again.
        """
        return self.get_coarse_partitions() + self.get_fine_partitions()

    def get_num_nodes(self):
        """Return ``n_processes['nodes']``, or None if unset."""
        return self.n_processes.get("nodes", None)

    def get_run_times(self):
        """Return the ``run_times`` section."""
        return self.run_times

    def get_timings(self, level):
        """Return ``logs[level]['timings']`` as a DataFrame with float values.

        Each element of the timings log is expected to be a mapping; its keys
        become columns and its values are cast with ``float``.
        """
        rows = [{key: float(value) for key, value in entry.items()}
                for entry in self.logs[level]["timings"]]
        return pd.DataFrame(rows)

class ResultsBatchProcessor:
    """Eagerly load every results sub-directory of a base directory.

    Each immediate sub-directory of ``base_results_path`` that contains the
    results file is loaded with ``ResultsLoader`` and stored in
    ``self.processed`` under the sub-directory's name.
    """

    def __init__(self, base_results_path, file_name="data.pkl"):
        """
        Args:
            base_results_path: Directory whose sub-directories hold results.
            file_name: Results pickle looked up in each sub-directory; it is
                forwarded to ``ResultsLoader`` so both agree on the name.
        """
        self.base_results_path = Path(base_results_path)
        self.file_name = file_name
        self.processed = {}
        self.process_all()

    def process_all(self):
        """Load every sub-directory (in sorted order) that has the results file.

        Sub-directories without the file are reported and skipped; non-directory
        entries are ignored. Errors from ``Path.iterdir`` (e.g. a missing base
        directory) and from ``ResultsLoader`` propagate to the caller.
        """
        print(f"Collecting data from: {self.base_results_path}")
        for subdir in sorted(self.base_results_path.iterdir()):
            if not subdir.is_dir():
                continue
            if not (subdir / self.file_name).exists():
                print(f"Skipped (no {self.file_name}): {subdir.name}")
                continue
            key = subdir.name
            self.processed[key] = ResultsLoader(subdir, file_name=self.file_name)
            print(f"Processed: {key}")

        print("\n")

    def get_loader(self, key):
        """Return the loader stored under ``key``, or None."""
        return self.processed.get(key, None)

    def get_loaders(self):
        """Return the ``{subdir_name: loader}`` mapping (not a copy)."""
        return self.processed

    def group_by_solver(self):
        """Group loaders by the solver name encoded in their run name.

        The solver is the text after the last ``"slv_"`` in the key; runs whose
        name lacks ``"slv_"`` are omitted. Returns ``{solver: [loader, ...]}``
        with loaders in processing order.
        """
        grouped = defaultdict(list)
        for name, loader in self.processed.items():
            if "slv_" in name:
                grouped[name.rpartition("slv_")[2]].append(loader)
        return dict(grouped)

