#!/usr/bin/env python3
"""
Aggregate SNN grid-search results across time steps t into a summary res.mat
Place one summary res.mat in each parent directory of the t_{}/ folders (i.e. in bs_{}/)

Behavior & requirements (from user):
- Do NOT modify any original res.mat files inside t_{}/
- Minimize try/except; when error occurs, print a clear message
- Verify that K (length of LASFR flattened) is the same across all t for a given combination
- Output file name: res.mat saved into the bs_... directory (parent of t_ dirs)

Assumptions:
- Individual t_{t}/res.mat files contain keys 'snn_acc' (can convert to float)
  and 'LASFR' (when flattened is length K)
- T_VALUES ordering is important; per-t results are stacked in the order of the T_VALUES list

This script tries to be robust to MATLAB mat variations (squeezed arrays, nested types)

Usage: edit ROOT (or pass the root directory as the first command-line argument)
and optionally the value lists below, then run.
"""

import os
import sys
from typing import Optional, Tuple, List

import numpy as np
from scipy.io import loadmat, savemat


# --- Configuration (edit here if needed) ---
# Default root of the results tree; the first command-line argument, when given,
# is used instead.
ROOT = "./snn_conversion/AEC/results"

# grid values (use the same values you described)
# Each value is interpolated verbatim into a directory name (f"lr_{lr}", f"bs_{bs}",
# f"t_{t}"), so it must format exactly like the directory on disk
# (for example 1e-5 formats as "lr_1e-05").
LRS = [1e-3, 1e-4, 1e-5]
BSS = [32, 64, 128]
# The summary's snn_acc entries and LASFR columns follow the order of this list.
T_VALUES = [2, 4, 8, 16, 32, 64, 128]  # must be ordered small->large

# NOTE: architectures and datasets are always discovered by scanning the
# subdirectories under ROOT; the example lists below are not read by this script.
# ARCHS = ['MLP', 'VGG-16']  # example
# DSETS = ['CIFAR10']


# --- Helpers ---

def safe_loadmat(path: str) -> Optional[dict]:
    """Read a MATLAB .mat file, reporting problems instead of raising.

    The file is loaded with ``squeeze_me=True`` and ``struct_as_record=False``.

    Args:
        path: Location of the .mat file.

    Returns:
        The dict produced by ``loadmat``, or None when ``path`` is not an
        existing regular file or when loading fails. In both failure cases a
        one-line message is printed.
    """
    if not os.path.isfile(path):
        print(f"MISSING: {path}")
        return None

    contents = None
    try:
        contents = loadmat(path, squeeze_me=True, struct_as_record=False)
    except Exception as err:
        # Only the load itself is guarded; the failure is reported and the
        # caller decides how to proceed.
        print(f"ERROR loading .mat: {path} -> {err}")
    return contents


def extract_snn_acc(mat: dict) -> Optional[float]:
    """Return the 'snn_acc' entry of a loaded .mat dict as a Python float.

    Args:
        mat: Dict returned by ``loadmat``.

    Returns:
        The accuracy as a float, or None when the key is absent, when the
        value does not hold exactly one element, or when that element cannot
        be converted to float. A one-line message is printed for every None.
    """
    if 'snn_acc' not in mat:
        print("ERROR: key 'snn_acc' not in mat")
        return None
    val = mat['snn_acc']
    try:
        arr = np.array(val)
        # Reject empty or multi-valued input explicitly: silently picking one
        # element would hide a malformed result file.
        if arr.size != 1:
            print(f"ERROR: snn_acc must hold exactly one value, got {arr.size} (shape={arr.shape})")
            return None
        # squeeze turns any (1, 1, ...) shape into a 0-d array before conversion
        return float(np.squeeze(arr))
    except (TypeError, ValueError) as e:
        print(f"ERROR converting snn_acc to float: {e} (value={type(val)})")
        return None


def extract_lasfr_flat(mat: dict) -> Optional[np.ndarray]:
    """Return the 'LASFR' entry of a loaded .mat dict as a 1-D array.

    Args:
        mat: Dict returned by ``loadmat``.

    Returns:
        A 1-D array: empty when the stored value has no elements, length 1
        when it is a scalar, otherwise the squeezed value flattened with
        ``ravel``. None (after printing a message) when the key is absent or
        the conversion raises.
    """
    if 'LASFR' not in mat:
        print("ERROR: key 'LASFR' not in mat")
        return None

    val = mat['LASFR']
    try:
        raw = np.array(val)
        if raw.size == 0:
            # nothing stored: hand back a plain empty array
            return np.array([])
        squeezed = np.squeeze(raw)
        if squeezed.ndim:
            # any remaining dimensionality collapses to one axis
            return squeezed.ravel()
        # a 0-d value is wrapped so the caller still receives a length-1 vector
        return np.array([squeezed]).ravel()
    except Exception as e:
        print(f"ERROR extracting LASFR: {e} (type={type(val)})")
        return None


# --- Main routine ---

def _collect_t_results(bs_dir: str) -> Optional[Tuple[List[float], List[np.ndarray]]]:
    """Read snn_acc and flattened LASFR from each t_{t}/res.mat under bs_dir.

    Returns (accs, lasfr_list), both ordered like T_VALUES, or None (after
    printing the reason) as soon as one file is missing or unusable.
    """
    accs: List[float] = []
    lasfr_list: List[np.ndarray] = []

    for t in T_VALUES:
        res_path = os.path.join(bs_dir, f"t_{t}", 'res.mat')

        md = safe_loadmat(res_path)
        if md is None:
            print(f"Skipping combination due to missing/invalid file: {res_path}")
            return None

        acc = extract_snn_acc(md)
        if acc is None:
            print(f"Skipping combination due to snn_acc issue at: {res_path}")
            return None

        lasfr = extract_lasfr_flat(md)
        if lasfr is None:
            print(f"Skipping combination due to LASFR issue at: {res_path}")
            return None

        accs.append(acc)
        lasfr_list.append(lasfr)

    return accs, lasfr_list


def _write_summary(bs_dir: str, label: str) -> None:
    """Build bs_dir/res.mat from the t_{t}/res.mat files below bs_dir.

    ``label`` identifies the combination in printed messages. Nothing is
    written when any per-t file is unusable or when the flattened LASFR
    lengths differ across t.
    """
    collected = _collect_t_results(bs_dir)
    if collected is None:
        print(f"NOT SAVED for {label} due to previous errors.\n")
        return
    accs, lasfr_list = collected

    # verify K (flattened LASFR length) is the same for every t
    Ks = [arr.size for arr in lasfr_list]
    if len(set(Ks)) != 1:
        print(f"ERROR: inconsistent K sizes across t for {label}: K sizes = {Ks}")
        return

    # stack LASFR as shape (K, len(T_VALUES)): one column per time step
    try:
        lasfr_stack = np.stack(lasfr_list, axis=1)
    except (TypeError, ValueError) as e:
        print(f"ERROR stacking LASFR for {label}: {e}")
        return

    snn_acc_array = np.array(accs, dtype=float)

    # The summary lives in bs_dir itself (parent of the t_ dirs); the
    # t_{}/res.mat files are only read, never written.
    summary_path = os.path.join(bs_dir, 'res.mat')
    try:
        savemat(summary_path, {'snn_acc': snn_acc_array, 'LASFR': lasfr_stack})
    except Exception as e:
        print(f"ERROR saving summary to {summary_path}: {e}")
    else:
        print(f"SAVED summary: {summary_path}  (LASFR shape={lasfr_stack.shape}, snn_acc.shape={snn_acc_array.shape})")


def aggregate_for_root(root: str) -> None:
    """Write one summary res.mat into every bs_{bs} directory found under root.

    The tree is walked as root/<arch>/<dataset>/<sparsity dir>/lr_{lr}/bs_{bs},
    where architectures and datasets are discovered by listing directories and
    lr/bs come from LRS and BSS. The sparsity directory is ``s_{s}`` for the
    'MLP' architecture and ``conv_{s}/s_0.0`` for every other one. Missing
    directories are reported and skipped; a combination with any unusable
    t_{t}/res.mat is reported and left without a summary.

    Args:
        root: Top-level results directory.
    """
    if not os.path.isdir(root):
        print(f"Root path does not exist: {root}")
        return

    # sorted so that the processing (and log) order is reproducible
    archs = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
    if not archs:
        print(f"No architectures found under root: {root}")
        return

    for arch in archs:
        arch_path = os.path.join(root, arch)
        datasets = sorted(d for d in os.listdir(arch_path) if os.path.isdir(os.path.join(arch_path, d)))
        if not datasets:
            print(f"No datasets found under {arch_path}")
            continue

        sparsities = (0.0, 0.99) if arch == 'MLP' else (0.0, 0.5)

        for ds in datasets:
            for cur_sparsity in sparsities:
                if arch == 'MLP':
                    sparse_dir = os.path.join(arch_path, ds, f"s_{cur_sparsity}")
                else:
                    sparse_dir = os.path.join(arch_path, ds, f"conv_{cur_sparsity}", "s_0.0")

                if not os.path.isdir(sparse_dir):
                    # not an error, but tell the user what was skipped
                    print(f"SKIP (no sparse dir): {sparse_dir}")
                    continue

                for lr in LRS:
                    lr_dir = os.path.join(sparse_dir, f"lr_{lr}")
                    if not os.path.isdir(lr_dir):
                        print(f"SKIP (no lr dir): {lr_dir}")
                        continue

                    for bs in BSS:
                        bs_dir = os.path.join(lr_dir, f"bs_{bs}")
                        if not os.path.isdir(bs_dir):
                            print(f"SKIP (no bs dir): {bs_dir}")
                            continue

                        # The path relative to root includes the sparsity
                        # directory, so messages for the two sparsities of one
                        # arch/dataset/lr/bs combination can be told apart.
                        _write_summary(bs_dir, os.path.relpath(bs_dir, root))


if __name__ == '__main__':
    # The first command-line argument, when present, replaces the default ROOT.
    aggregate_for_root(sys.argv[1] if len(sys.argv) > 1 else ROOT)
