

from __future__ import annotations
import os
os.environ['TENSORLY_BACKEND'] = 'numpy'
import sys
import json
import time
import math
import pickle
import argparse
import logging
from logging.handlers import RotatingFileHandler
from datetime import datetime
from dataclasses import dataclass
from typing import Tuple, Optional, List, Dict, Any
import torch

import numpy as np
import pandas as pd
from PIL import Image
import tensorly as tl
# tl.set_backend('numpy')
# print(tl.get_backend())

                                            
# from ..cp_regression_cuda import CPRegressor
# from ..cp_basisexpansion_regression import CPBasisRegressor 
                                                   
# from ..tucker_regression_2 import TuckerRegressor  # optional
                                                                     
                                                                         
                                                                             
                                                                                        
                                                                             
                                                                                         
from ..supervised_pca_regression_perp import SpectralBlockPerpRGDRegressorFast
# from ..full_regression_cuda import FullRegressor
      
#     from ..tt_regression import TTRegressorIHT,TTRegressorRGD  # optional Tensor-Train regressor
                   
#     TTRegressorIHT = None
# from ..HOPLS_cuda import HOPLSRegressor
# from ..NPLS_cuda import NPLSRegressor
# from ..tt_regression import TTRegressorIHT,TTRegressorRGD
from ...basis_vector import dft_basis,dft_basis_high2low,eye_basis
from ...metrics.regression import RMSE
from ... import backend as T
from tensorly import unfold, fold
from tensorly.tenalg import multi_mode_dot, tensordot
import string
ALPHABET = string.ascii_lowercase
                            
# Logging
                            

                                                                                              
#     os.makedirs(log_dir, exist_ok=True)
#     if run_name is None:
#         run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_run")
#     run_dir = os.path.join(log_dir, run_name)
#     os.makedirs(run_dir, exist_ok=True)

#     logger = logging.getLogger(run_name)
#     logger.setLevel(level)
#     logger.propagate = False

#     ch = logging.StreamHandler(sys.stdout)
#     ch.setLevel(level)
#     ch.setFormatter(logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s", "%H:%M:%S"))
#     logger.addHandler(ch)

#     fh = RotatingFileHandler(os.path.join(run_dir, "log.txt"), maxBytes=5_000_000, backupCount=3, encoding="utf-8")
#     fh.setLevel(level)
#     fh.setFormatter(logging.Formatter("%(asctime)s | %(name)s | %(levelname)s | %(message)s", "%Y-%m-%d %H:%M:%S"))
#     logger.addHandler(fh)

#     logger.info("Logger initialized. Run dir: %s", run_dir)
#     return logger, run_dir
                           

def analyze_saved_predictions(run_dir: str, adapter: DatasetAdapter, logger: logging.Logger):
\
\
       
    preds_path = os.path.join(run_dir, "predictions_test.npz")
    if not os.path.exists(preds_path):
        logger.warning(f"File not found: {preds_path}. Skipping saved prediction analysis.")
        return

    try:
        preds = np.load(preds_path)
        Y_true_orig_saved = preds["y_true"]
        Y_pred_orig_saved = preds["y_pred"]
    except Exception as e:
        logger.error(f"Failed to load or access data from {preds_path}. Error: {e}")
        return
        
    attr_names = adapter.attr_names()
    
    logger.info("\n" + "="*20 + " Saved Prediction Metrics (Y_te_orig) " + "="*20)
    
          
    rpe_saved = relative_prediction_error(Y_true_orig_saved, Y_pred_orig_saved)
    r_flat_saved = pearson_r_flat(Y_true_orig_saved, Y_pred_orig_saved)
    logger.info(f"FULL Saved Pred (Overall):\t RPE = {rpe_saved:.6f},\t r_flat = {r_flat_saved:.6f}")
    
                                               
    if Y_true_orig_saved.ndim == 2:
        per_dim_saved_results = []
        for d, attr_name in enumerate(attr_names):
                                          
            if d >= Y_true_orig_saved.shape[1]:
                logger.warning(f"Attribute names mismatch Y dim: {len(attr_names)} vs {Y_true_orig_saved.shape[1]}. Stopping per-dim analysis.")
                break
                
            r_dim = pearson_r_flat(Y_true_orig_saved[:, d], Y_pred_orig_saved[:, d])
            rpe_dim = relative_prediction_error(Y_true_orig_saved[:, d], Y_pred_orig_saved[:, d])
            per_dim_saved_results.append(f" {attr_name}: r={r_dim:.4f}, RPE={rpe_dim:.4f}")
        logger.info(f"FULL Saved Pred (Per-dim):\n  {' | '.join(per_dim_saved_results)}")
    else:
        logger.warning(f"Saved Y shape ({Y_true_orig_saved.shape}) is not 2D. Skipping per-dim analysis.")

    logger.info("="*64 + "\n")

def analyze_model_components(run_dir: str):
\
\
\
       
             
    logger, _ = setup_logger(log_dir=os.path.join(run_dir, "analysis_logs"), run_name="component_analysis")
    logger.info(f"--- Starting Component Analysis for run: {run_dir} ---")

                    
    try:
        prev_cfg, fixed_indices, preload_norm, estimator = load_previous_run(run_dir)
        logger.info("Successfully loaded model, config, and data artifacts.")
    except Exception as e:
        logger.error(f"Failed to load previous run from {run_dir}. Error: {e}")
        return

    if not isinstance(estimator, SpectralBlockPerpRGDRegressorFast):
        logger.error(f"Analysis is only supported for 'pcaregperp' models. Found {type(estimator)}.")
        return

                  
    adapter = build_adapter(argparse.Namespace(**prev_cfg), logger)
    X, Y, _ = adapter.load()
    _, te_idx = fixed_indices
    X_te, Y_te_orig = X[te_idx], Y[te_idx]                     

    mean_X, y_mean, y_std = preload_norm
    X_te_c = X_te - mean_X
                          
    Y_te_z = (Y_te_orig - y_mean) / np.where(y_std < 1e-8, 1e-8, y_std)
    logger.info(f"Test data prepared: X_test shape={X_te.shape}, Y_test shape={Y_te_orig.shape}")

                  
                                                  
    X_te_T = torch.as_tensor(X_te_c, device=estimator.device, dtype=estimator.dtype)
    
    with torch.no_grad():
                        
        Xc_T = estimator._build_Xcore_tensor(X_te_T)
    
    Xc_np = Xc_T.cpu().numpy()
    logger.info(f"Core tensor Xc built with shape: {Xc_np.shape}")

                          
    num_samples = Xc_np.shape[0]
    num_components = Xc_np.shape[-1]
    
    logger.info("\n" + "="*20 + " Per-Component Regression Results " + "="*20)

    for k in range(num_components):
                          
        # Xc_np shape: (N, r1, ..., rn, K_eff)
        features_k = Xc_np[..., k]
        features_k_flat = features_k.reshape(num_samples, -1)

                   
        A = np.hstack([features_k_flat, np.ones((num_samples, 1))])

        try:
                                  
                                                       
            W_k, _, _, _ = np.linalg.lstsq(A, Y_te_z, rcond=None)

                       
            Y_pred_z_k = A @ W_k

                       
            Y_pred_orig_k = Y_pred_z_k * y_std + y_mean

                             
            rpe_k = relative_prediction_error(Y_te_orig, Y_pred_orig_k)
            r_flat_k = pearson_r_flat(Y_te_orig, Y_pred_orig_k)
            
            combo_k = estimator.combos_[k] if k < len(estimator.combos_) else "N/A"

            logger.info(f"Component {k+1:02d}/{num_components} (Combo: {combo_k}):\t RPE = {rpe_k:.6f},\t r_flat = {r_flat_k:.6f}")

        except np.linalg.LinAlgError as e:
            logger.warning(f"Component {k+1:02d}/{num_components}: Could not solve linear regression. Error: {e}")

    logger.info("="*64 + "\n")
    analyze_saved_predictions(run_dir, adapter, logger)

def setup_logger(log_dir: str = "logs", run_name: Optional[str] = None, 
                 dataset: Optional[str] = None, method: Optional[str] = None,
                 level=logging.INFO):
    os.makedirs(log_dir, exist_ok=True)
    
                                            
    if run_name is None:
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
                                 
        # name_parts = [timestamp]
        name_parts = []
        if dataset:
            name_parts.append(dataset)
        if method:
            name_parts.append(method)
        name_parts.append("run")
        name_parts.append(timestamp)
        run_name = "_".join(name_parts)

    run_dir = os.path.join(log_dir, run_name)
    os.makedirs(run_dir, exist_ok=True)

    logger = logging.getLogger(run_name)
    logger.setLevel(level)
    logger.propagate = False
    if not logger.handlers:           
        ch = logging.StreamHandler(sys.stdout)
        ch.setLevel(level)
        ch.setFormatter(logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s", "%H:%M:%S"))
        logger.addHandler(ch)

        fh = RotatingFileHandler(os.path.join(run_dir, "log.txt"), maxBytes=5_000_000, backupCount=3, encoding="utf-8")
        fh.setLevel(level)
        fh.setFormatter(logging.Formatter("%(asctime)s | %(name)s | %(levelname)s | %(message)s", "%Y-%m-%d %H:%M:%S"))
        logger.addHandler(fh)

    logger.info("Logger initialized. Run dir: %s", run_dir)
    return logger, run_dir
                           

def analyze_pls_components(run_dir: str):
\
\
\
\
       
             
    logger, _ = setup_logger(log_dir=os.path.join(run_dir, "analysis_logs"), run_name="pls_component_analysis")
    logger.info(f"--- Starting PLS Component Analysis for run: {run_dir} ---")

                    
    try:
        prev_cfg, fixed_indices, preload_norm, estimator = load_previous_run(run_dir)
        logger.info("Successfully loaded model, config, and data artifacts.")
    except Exception as e:
        logger.error(f"Failed to load previous run from {run_dir}. Error: {e}")
        return

    method = prev_cfg.get("method")
    if method not in ["hopls", "npls"]:
        logger.error(f"Analysis is only supported for 'hopls' or 'npls' models. Found {method}.")
        return

                         
    adapter = build_adapter(argparse.Namespace(**prev_cfg), logger)
    X, Y, _ = adapter.load()
    tr_idx, te_idx = fixed_indices
    X_te, Y_te_orig = X[te_idx], Y[te_idx]                     

                                                      
                                                 
    mean_X, y_mean, y_std = preload_norm
    X_te_c = X_te - mean_X
                                             
    eps = 1e-8
    y_std_safe = np.where(y_std < eps, eps, y_std)
    Y_te_z = (Y_te_orig - y_mean) / y_std_safe          
    
    logger.info(f"Test data prepared: X_test shape={X_te.shape}, Y_test shape={Y_te_orig.shape}")

                           
    # HOPLS: T_ = (N_tr, R), NPLS: T_ = (N_tr, n_components)
                                                         
    
                             
    tl.set_backend('pytorch')
                        
    X_te_T = tl.tensor(X_te_c, device=estimator.device, dtype=torch.float32)
    if X_te_T.ndim > 1:
        modes_X = list(range(1, X_te_T.ndim))
    else:
        modes_X = [] 
    if method == "hopls":
                                                                            
        
                                                     
        T_te_list = []
        for r in range(estimator.n_components_):
                                 
            # Z_r_te shape: (N_test, L_1, L_2, ...)
            Z_r_te = multi_mode_dot(X_te_T, [p.T for p in estimator.P_[r]], modes=modes_X)
            
                                           
            # Z1_r_te shape: (N_test) x (Product of L_dims)
            Z1_r_te = unfold(Z_r_te, mode=0)
            
                                          
                                                         
            U_te, _, _ = torch.linalg.svd(Z1_r_te, full_matrices=False)
            t_r_te = U_te[:, 0]
            t_r_te /= (tl.norm(t_r_te) + 1e-12)
            
            T_te_list.append(tl.reshape(t_r_te, (-1,)))
        T_te_T = tl.stack(T_te_list, axis=1) # (N_te, R)
    elif method == "npls":
                                              
        x_indices = ALPHABET[:X_te_T.ndim]                     
        weight_indices = x_indices[1:]
        t_einsum_str = f'{x_indices},{",".join(weight_indices)}->{x_indices[0]}'
        
        T_te_list = []
        for r in range(estimator.n_components):
            W_r = [estimator.W_[i][:, r] for i in range(len(estimator.W_))] # (I_i,)
            t_r_te = tl.einsum(t_einsum_str, X_te_T, *W_r)
            t_r_te = tl.reshape(t_r_te, (-1, 1))
            T_te_list.append(t_r_te)
        T_te_T = tl.concatenate(T_te_list, axis=1) # (N_te, R)

    Xc_np = T_te_T.cpu().numpy()
    num_samples = Xc_np.shape[0]
    num_components = Xc_np.shape[-1]
    logger.info(f"Score tensor (Xc/T) extracted with shape: {Xc_np.shape}")

                                    
    
                  
    attr_names = adapter.attr_names()
    
    logger.info("\n" + "="*20 + f" Per-Component PLS Regression Results ({method.upper()})" + "="*20)

                         
    T_all = Xc_np # (N_te, R)
    A_all = np.hstack([T_all, np.ones((num_samples, 1))])
    W_all, _, _, _ = np.linalg.lstsq(A_all, Y_te_z, rcond=None)
    Y_pred_z_all = A_all @ W_all
    Y_pred_orig_all = Y_pred_z_all * y_std + y_mean
    rpe_all = relative_prediction_error(Y_te_orig, Y_pred_orig_all)
    r_flat_all = pearson_r_flat(Y_te_orig, Y_pred_orig_all)
    logger.info(f"ALL Components ({num_components} Ranks):\t RPE = {rpe_all:.6f},\t r_flat = {r_flat_all:.6f}")

    
                         
    for k in range(num_components):
                                 
        features_k = Xc_np[:, k].reshape(num_samples, 1)

                   
        A = np.hstack([features_k, np.ones((num_samples, 1))])

        try:
                                  
                                                         
            W_k, _, _, _ = np.linalg.lstsq(A, Y_te_z, rcond=None)

                       
            Y_pred_z_k = A @ W_k

                       
            Y_pred_orig_k = Y_pred_z_k * y_std + y_mean
            Y_true_orig = Y_te_orig

                             
            rpe_k = relative_prediction_error(Y_true_orig, Y_pred_orig_k)
            r_flat_k = pearson_r_flat(Y_true_orig, Y_pred_orig_k)
            
                                      
            per_dim_results = []
            for d, attr_name in enumerate(attr_names):
                r_dim = pearson_r_flat(Y_true_orig[:, d], Y_pred_orig_k[:, d])
                rpe_dim = relative_prediction_error(Y_true_orig[:, d], Y_pred_orig_k[:, d])
                per_dim_results.append(f" {attr_name}: r={r_dim:.4f}, RPE={rpe_dim:.4f}")
            
            logger.info(f"Component {k+1:02d}/{num_components}:\t RPE = {rpe_k:.6f},\t r_flat = {r_flat_k:.6f}")
            logger.info(f"  > Per-dim: {' | '.join(per_dim_results)}")

        except np.linalg.LinAlgError as e:
            logger.warning(f"Component {k+1:02d}/{num_components}: Could not solve linear regression. Error: {e}")

    logger.info("="*64 + "\n")
    analyze_saved_predictions(run_dir, adapter, logger)
                                                
#     preds_path = os.path.join(run_dir, "predictions_test.npz")
#     if os.path.exists(preds_path):
#         preds = np.load(preds_path)
#         Y_true_orig_saved = preds["y_true"]
#         Y_pred_orig_saved = preds["y_pred"]
        
#         logger.info("\n" + "="*20 + " Saved Prediction Metrics (Y_te_orig) " + "="*20)
        
                
#         rpe_saved = relative_prediction_error(Y_true_orig_saved, Y_pred_orig_saved)
#         r_flat_saved = pearson_r_flat(Y_true_orig_saved, Y_pred_orig_saved)
#         logger.info(f"FULL Saved Pred (Overall):\t RPE = {rpe_saved:.6f},\t r_flat = {r_flat_saved:.6f}")
        
                 
#         per_dim_saved_results = []
#         for d, attr_name in enumerate(attr_names):
#             r_dim = pearson_r_flat(Y_true_orig_saved[:, d], Y_pred_orig_saved[:, d])
#             rpe_dim = relative_prediction_error(Y_true_orig_saved[:, d], Y_pred_orig_saved[:, d])
#             per_dim_saved_results.append(f" {attr_name}: r={r_dim:.4f}, RPE={rpe_dim:.4f}")
#         logger.info(f"FULL Saved Pred (Per-dim): {' | '.join(per_dim_saved_results)}")
        
#         logger.info("="*64 + "\n")
#     else:
#         logger.warning(f"File not found: {preds_path}. Skipping saved prediction analysis.")
                              
# Utilities
                            

def train_test_split_indices(n: int, test_ratio=0.2, seed=1234):
    rng = np.random.RandomState(seed)
    idx = np.arange(n)
    rng.shuffle(idx)
    n_test = int(round(n * test_ratio))
    return idx[n_test:], idx[:n_test]  # train_idx, test_idx

def mean_center(X_train: np.ndarray, X_test: np.ndarray):
    mean = X_train.mean(axis=0, keepdims=True)
    return X_train - mean, X_test - mean, mean

def standardize_targets(y_train, y_test, eps=1e-8):
    mean = y_train.mean(axis=0, keepdims=True)
    std = y_train.std(axis=0, keepdims=True)
    std = np.where(std < eps, eps, std)
    return (y_train - mean) / std, (y_test - mean) / std, mean, std

def pearson_r_flat(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-12) -> float:
    yt = y_true.reshape(-1).astype(np.float64)
    yp = y_pred.reshape(-1).astype(np.float64)
    yt_m = yt - yt.mean()
    yp_m = yp - yp.mean()
    denom = (np.linalg.norm(yt_m) * np.linalg.norm(yp_m))
    if denom < eps:
        return 0.0
    return float(np.dot(yt_m, yp_m) / denom)

def relative_prediction_error(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-12) -> float:
    num = np.linalg.norm(y_pred - y_true)
    den = np.linalg.norm(y_true)
    if den < eps:
        return float("inf")
    return float(num / den)

def _make_head_ranks(wtype: str, r: Optional[int]):
    if r is None:
        return None
    key = {"cp": "cp", "tucker": "tucker", "tt": "tt"}.get(wtype)
    return {key: int(r)} if key else None

                            
# Dataset adapters
                            

class DatasetAdapter:
                                                                    
    name: str
    logger: logging.Logger

    def load(self) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
        raise NotImplementedError

    def split(self, N: int) -> Tuple[np.ndarray, np.ndarray]:
        raise NotImplementedError

    def attr_names(self) -> List[str]:
        return []

                                      
                            
# =SyntheticDataAdapter
                            
@dataclass
class SyntheticDataAdapter(DatasetAdapter):
    name: str = "synthetic"
    N: int = 5000
    dims: Tuple[int, ...] = (60, 60, 15)
    D: Tuple[int, ...] = (73,)         
    test_ratio: float = 0.2
    seed: int = 1234
    logger: logging.Logger = logging.getLogger("dummy")

    def load(self) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
        self.logger.info("Generating synthetic data")

                       
        rng = np.random.RandomState(self.seed)

                        
        X = rng.randn(self.N, *self.dims).astype(np.float32)

                          
                                          
        W_dims = self.dims + self.D
        W = rng.randn(*W_dims).astype(np.float32)

                        
                          
                                    
        X_flat = X.reshape(self.N, -1)
                                             
        W_flat = W.reshape(np.prod(self.dims), np.prod(self.D))

                
        Y = X_flat @ W_flat

                     
        Y = Y.reshape(self.N, *self.D)
        
        self.logger.info("Generated synthetic data -> X=%s, Y=%s", X.shape, Y.shape)
        return X, Y, {}

    def split(self, N: int) -> Tuple[np.ndarray, np.ndarray]:
        # Uses the common split function
        return train_test_split_indices(N, self.test_ratio, self.seed)

    def attr_names(self) -> List[str]:
        # Return a generic list of attribute names
        # For multi-dimensional Y, we flatten the names
        num_attrs = np.prod(self.D)
        return [f"Y_attr_{i}" for i in range(num_attrs)]

                            
# =EcogPersonAdapter
                            
@dataclass
class EcogPersonAdapter(DatasetAdapter):
\
\
\
       
    X_npy: str
    Y_npy: str
    name: str = "ecog_person"
    train_idx_npy: Optional[str] = None
    val_idx_npy: Optional[str] = None
    test_ratio: float = 0.2
    seed: int = 1234
    logger: logging.Logger = logging.getLogger("dummy")

    def __post_init__(self):
                   
        self._attr_names = ["Thumb", "Index", "Middle", "Ring", "Pinky"]

    def load(self) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
        self.logger.info("Loading ECoG (Person) from preprocessed .npy files")
        assert os.path.exists(self.X_npy), f"X_npy not found: {self.X_npy}"
        assert os.path.exists(self.Y_npy), f"Y_npy not found: {self.Y_npy}"
        
        X = np.load(self.X_npy)                                    
        Y = np.load(self.Y_npy)                
        
        assert X.shape[0] == Y.shape[0], f"样本数不一致: X={X.shape}, Y={Y.shape}"

                     
        if X.ndim != 4:
            self.logger.warning(f"期望 X 是 4D 张量 (N, C, F, T)，但实际是 {X.ndim}D。请检查预处理步骤。")
        else:
            self.logger.info(f"成功加载 4D ECoG 张量，形状: {X.shape}")
        
        if Y.ndim == 1:
            self.logger.info(f"检测到 Y 是一维数组 (shape: {Y.shape}), 将其重塑为二维。")
            Y = Y.reshape(-1, 1)
        X = X.astype(np.float32)
        Y = Y.astype(np.float32)

        self.logger.info("Loaded ECoG Person -> X=%s, Y=%s", X.shape, Y.shape)
                              
        return X, Y, {}

    def split(self, N: int) -> Tuple[np.ndarray, np.ndarray]:
                       
        if self.train_idx_npy and self.val_idx_npy and os.path.exists(self.train_idx_npy) and os.path.exists(self.val_idx_npy):
            try:
                                  
                tr_idx = np.load(self.train_idx_npy).astype(np.int64)
                te_idx = np.load(self.val_idx_npy).astype(np.int64)
                
                            
                assert tr_idx.max() < N and te_idx.max() < N, "索引超出样本范围"
                
                self.logger.info("使用固定的训练/验证集索引: train=%d, test=%d", len(tr_idx), len(te_idx))
                return tr_idx, te_idx
            except Exception as e:
                self.logger.warning(f"从文件加载索引失败 ({e})，将回退到随机划分。")
        
                                 
        self.logger.info(f"进行随机划分，测试集比例: {self.test_ratio}")
        return train_test_split_indices(N, self.test_ratio, self.seed)

    def attr_names(self) -> List[str]:
        return list(self._attr_names)

@dataclass
class LFWAdapter(DatasetAdapter):
    images_root: str
    attr_csv: str
    target_hw: Tuple[int, int] = (90, 90)
    test_ratio: float = 0.2
    seed: int = 1234
    logger: logging.Logger = logging.getLogger("dummy")
               
    train_idx_npy: Optional[str] = None
    val_idx_npy: Optional[str]   = None                         

    def _img_path(self, person, imagenum):
        fname = f"{person}_{int(imagenum):04d}.jpg"
        return os.path.join(self.images_root, person, fname)

    def _read_and_resize(self, image_path, size=(90, 90)):
        with Image.open(image_path) as im:
            im = im.convert("RGB")
            im = im.resize(size, resample=Image.BILINEAR)
            arr = np.asarray(im, dtype=np.float32) / 255.0
        return arr

    def load(self):
        self.logger.info("Loading LFW: root=%s, attr_csv=%s", self.images_root, self.attr_csv)
        df = pd.read_csv(self.attr_csv)
        assert {"person", "imagenum"}.issubset(df.columns), "属性 CSV 缺少 person/imagenum 列"
        attr_cols = [c for c in df.columns if c not in ("person", "imagenum")]
        if len(attr_cols) != 73:
            self.logger.warning("属性维度不是 73，而是 %d，仍按表格列推进。", len(attr_cols))

        X_list, Y_list, meta = [], [], []
        missing = 0
        for _, row in df.iterrows():
            person = str(row["person"]); imagenum = int(row["imagenum"])
            path = self._img_path(person, imagenum)
            if not os.path.exists(path):
                missing += 1; continue
            try:
                img = self._read_and_resize(path, size=self.target_hw)
            except Exception:
                missing += 1; continue
            attrs = row[attr_cols].astype(np.float32).to_numpy()
            X_list.append(img); Y_list.append(attrs); meta.append((person, imagenum, path))

        if missing > 0:
            self.logger.info("有 %d 条记录的图像缺失或读取失败，已跳过。", missing)

        X = np.stack(X_list, axis=0).astype(np.float32)   # (N, H, W, 3)
        Y = np.stack(Y_list, axis=0).astype(np.float32)   # (N, D)
        self._attr_names = attr_cols
        self._meta = np.array(meta, dtype=object)
        self.logger.info("Loaded LFW -> X=%s, Y=%s", X.shape, Y.shape)
        return X, Y, {"meta": self._meta}

    def _load_indices(self, path: str, N: int) -> np.ndarray:
        arr = np.load(path)
        if arr.dtype == bool:
            assert arr.shape[0] == N, f"mask 长度 {arr.shape[0]} != N {N}"
            idx = np.where(arr)[0]
        else:
            idx = arr.astype(np.int64)
        return idx

    def split(self, N: int):
                           
        if self.train_idx_npy and self.val_idx_npy and os.path.exists(self.train_idx_npy) and os.path.exists(self.val_idx_npy):
            try:
                tr = self._load_indices(self.train_idx_npy, N)
                te = self._load_indices(self.val_idx_npy,   N)
                                  
                inter = np.intersect1d(tr, te)
                if inter.size > 0:
                    self.logger.warning("train 与 test 有 %d 条重叠索引，将按文件使用。", inter.size)
                self.logger.info("使用固定划分：train=%d, test=%d (来自 %s / %s)",
                                 tr.size, te.size, os.path.basename(self.train_idx_npy), os.path.basename(self.val_idx_npy))
                return tr, te
            except Exception as e:
                self.logger.warning("读取固定划分失败(%s)，回退随机划分。", e)

                
        return train_test_split_indices(N, self.test_ratio, self.seed)

    def attr_names(self) -> List[str]:
        return list(self._attr_names)

                                      

@dataclass
class EcogMonkeyAdapter(DatasetAdapter):
    X_npy: str
    Y_npy: str
    mask_npy: Optional[str] = None  # True=artifact (chewing)
    train_idx_npy: Optional[str] = None
    val_idx_npy: Optional[str] = None
    ecog_reshape: Optional[Tuple[int, int, int]] = (64, 10, 10)  # reshape 6400 -> (64,10,10)
    test_ratio: float = 0.2
    seed: int = 1234
    logger: logging.Logger = logging.getLogger("dummy")

    def load(self):
        self.logger.info("Loading ECoG (Monkey) from .npy")
        assert os.path.exists(self.X_npy), f"X_npy not found: {self.X_npy}"
        assert os.path.exists(self.Y_npy), f"Y_npy not found: {self.Y_npy}"
        X = np.load(self.X_npy)  # (N, 6400) or (N, C,F,L)
        Y = np.load(self.Y_npy)  # (N, 3)
        assert X.shape[0] == Y.shape[0], f"样本数不一致: X={X.shape}, Y={Y.shape}"

        # Optional mask (True = artifact) -> filter out
        meta: Dict[str, Any] = {}
        if self.mask_npy and os.path.exists(self.mask_npy):
            mask = np.load(self.mask_npy).astype(bool)
            if mask.shape[0] == X.shape[0]:
                keep = ~mask
                self.logger.info("Chewing mask loaded. Removing %d/%d samples.", int(mask.sum()), mask.shape[0])
                X, Y = X[keep], Y[keep]
                meta["mask_kept"] = keep
            else:
                self.logger.warning("Mask length %d != N %d. Ignoring mask.", mask.shape[0], X.shape[0])

        # Reshape if X is flat and shape provided
        if self.ecog_reshape and X.ndim == 2:
            C, F, L = self.ecog_reshape
            expected = C * F * L
            if X.shape[1] != expected:
                raise ValueError(f"ECoG X second dim expected {expected} for reshape {self.ecog_reshape}, got {X.shape[1]}")
            X = X.reshape(X.shape[0], C, F, L).astype(np.float32)
            self.logger.info("Reshaped ECoG X -> %s", X.shape)
        else:
            X = X.astype(np.float32)

        Y = Y.astype(np.float32)
        self._attr_names = ["RWRI_X", "RWRI_Y", "RWRI_Z"]
        self.logger.info("Loaded ECoG -> X=%s, Y=%s", X.shape, Y.shape)
        return X, Y, meta

    def split(self, N: int):
        # Prefer fixed indices if provided
        if self.train_idx_npy and self.val_idx_npy and os.path.exists(self.train_idx_npy) and os.path.exists(self.val_idx_npy):
            tr_mask = np.load(self.train_idx_npy).astype(bool)
            va_mask = np.load(self.val_idx_npy).astype(bool)
            if tr_mask.shape[0] != N or va_mask.shape[0] != N:
                self.logger.warning("train/val mask length mismatch (N=%d). Falling back to random split.", N)
            else:
                tr = np.where(tr_mask)[0]
                te = np.where(va_mask)[0]
                self.logger.info("Using provided train/val indices: train=%d, val(test)=%d", tr.size, te.size)
                return tr, te
        # Else random split
        return train_test_split_indices(N, self.test_ratio, self.seed)

    def attr_names(self) -> List[str]:
        return list(self._attr_names)

                            
# Adapter factory
                            

def build_adapter(args, logger: logging.Logger) -> DatasetAdapter:
    if args.dataset == "lfw":
        assert args.images_root and args.attr_csv, "LFW 需要 --images_root 与 --attr_csv"
        return LFWAdapter(
            images_root=args.images_root,
            attr_csv=args.attr_csv,
            target_hw=tuple(args.target_hw),
            test_ratio=args.test_ratio,
            seed=args.seed,
            logger=logger,
            train_idx_npy=args.train_idx_npy,
            val_idx_npy=args.val_idx_npy
        )
    elif args.dataset == "ecog_monkey":
        assert args.X_npy and args.Y_npy, "ECOG (Monkey) 需要 --X_npy 与 --Y_npy"
        reshape = tuple(args.ecog_reshape) if args.ecog_reshape else None
        return EcogMonkeyAdapter(X_npy=args.X_npy, Y_npy=args.Y_npy, mask_npy=args.mask_npy, train_idx_npy=args.train_idx_npy, val_idx_npy=args.val_idx_npy, ecog_reshape=reshape, test_ratio=args.test_ratio, seed=args.seed, logger=logger)
    elif args.dataset == "ecog_person":
        assert args.X_npy and args.Y_npy, "ecog_person 数据集需要 --X_npy 与 --Y_npy"
        return EcogPersonAdapter(
            X_npy=args.X_npy, 
            Y_npy=args.Y_npy, 
            train_idx_npy=args.train_idx_npy, 
            val_idx_npy=args.val_idx_npy, 
            test_ratio=args.test_ratio, 
            seed=args.seed, 
            logger=logger
        )
    elif args.dataset == "synthetic_data":
        return SyntheticDataAdapter(
            N=args.synthetic_N,
            dims=tuple(args.synthetic_dims),
            D=tuple(args.synthetic_D),
            test_ratio=args.test_ratio,
            seed=args.seed,
            logger=logger,
        )
    else:
        raise ValueError(f"未知数据集: {args.dataset}")

                            
# Core runner (model-agnostic; delegates data to adapter)
                            

def run_tensor_regression(
    adapter: DatasetAdapter,
    logger: logging.Logger,
    method: str = "cp",
    rank: int = 10,
    n_iter_max: int = 300,
    tol: float = 1e-7,
    reg_W: float = 0.0,
    seed: int = 1234,
    verbose: bool = True,
    tucker_core: Optional[List[int]] = None,
    tt_ranks: Optional[List[int]] = None,
    pca_block: Optional[List[int]] = None,
    pca_basis_dims: Optional[List[int]] = None,
    K_pca: Any = 1,
    rate_choose: float = 1.0,
    step_size_A: float = 5e-4,
    step_size_W: float = 5e-4,
    warmup_A: int = 10,
                
    resume_from: Optional[str] = None,
    fixed_indices: Optional[Tuple[np.ndarray, np.ndarray]] = None,
    preload_norm: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None,  # (mean_X, y_mean, y_std)
    resume_estimator: Any = None,
    extra_iters: Optional[int] = None,

    device: Optional[str] = None,
    basis_order: str = "dft",
    pcahead_weight_type: str = "cp",
    pcahead_rank: Optional[int] = None,
    no_standardize: bool = False,
    ttstep_size: float = None,
    pcaheadinit_multi:float = 1e-2,
    pca_A_solver: str = 'gd',
    pca_W_full_solver: str = 'gd',
    pca_tt_init: str = 'spectral',
    pca_forward_flat:bool = False,
    pca_tt_solver: str = "rgd",
    pca_perk : bool = False,
    pca_combo_strategy: str = 'bfs',
    pca_B_diag: bool = False,
    pcagdals_iters: int = 1,
    pca_B_solver: str = 'perp_rgd',
    pca_if_noall_learn_A: bool = False,
    pca_combo_mode: str = 'diag',
    pca_not_use_nan_to_num: bool = False,
    pca_print_ab_stats: bool = False,
    pca_optimizer: str = 'gd',
    pca_init_B_method: str = 'random',
    pca_init_A_method: str = 'eye',
    pca_adam_betas: Tuple[float, float] = (0.9, 0.999),
    adam_eps: float = 1e-8,
    fixed_diag_max_k: bool = False,
    diag_skew_parallel_dist: float = 1.0,
    vali_test: bool = False,
    hop_r: int = 10,
    hop_xrank: Optional[List[int]] = None,
    hop_yrank: Optional[List[int]] = None,
    npls_r: int = 10,
    npls_iter: int = 100,
):

    logger.info("Method=%s | rank=%d | iters=%d | tol=%g | reg_W=%g | seed=%d", method, rank, n_iter_max, tol, reg_W, seed)
    basis_fn = dft_basis if basis_order == "dft" else dft_basis_high2low
    # 1) Load dataset via adapter
    X, Y, meta = adapter.load()
    N = X.shape[0]
    D = Y.shape[1]
    logger.info("Dataset=%s -> N=%d, X=%s, Y=%s", adapter.__class__.__name__, N, X.shape, Y.shape)

    # 2) Split (adapter-specific policy)

    if resume_from is not None and fixed_indices is not None:
        tr_idx, te_idx = fixed_indices
    else:
        tr_idx, te_idx = adapter.split(N)

    X_tr, X_te = X[tr_idx], X[te_idx]
    Y_tr, Y_te = Y[tr_idx], Y[te_idx]
    logger.info("Split -> train=%d, test=%d", X_tr.shape[0], X_te.shape[0])

    # 3) Preprocess (common): center X by train-mean; z-score Y by train stats
    if resume_from is None and preload_norm is None:
        if getattr(args, "no_standardize", False):
                   
            X_tr_c, X_te_c = X_tr.astype(np.float32), X_te.astype(np.float32)
            mean_X = np.zeros_like(X_tr[0:1], dtype=np.float32)
            Y_tr_z, Y_te_z = Y_tr.astype(np.float32), Y_te.astype(np.float32)
            y_mean = np.zeros_like(Y_tr[0:1], dtype=np.float32)
            y_std  = np.ones_like(Y_tr[0:1], dtype=np.float32)
            logger.info("⚠️ 标准化已禁用：保持原始尺度")
        else:
                   
            X_tr_c, X_te_c, mean_X = mean_center(X_tr, X_te)
            Y_tr_z, Y_te_z, y_mean, y_std = standardize_targets(Y_tr, Y_te)

    else:
        mean_X, y_mean, y_std = preload_norm
        X_tr_c = X_tr - mean_X
        X_te_c = X_te - mean_X
                             
        eps = 1e-8
        y_std_safe = np.where(y_std < eps, eps, y_std)
        Y_tr_z = (Y_tr - y_mean) / y_std_safe
        Y_te_z = (Y_te - y_mean) / y_std_safe
    logger.info("Preprocessing done. mean_X=%s, y_mean=%s, y_std(min/mean/max)=(%.4f, %.4f, %.4f)",
                mean_X.shape, y_mean.shape, float(y_std.min()), float(y_std.mean()), float(y_std.max()))

    # 4) Tensor backend conversion
    X_tr_T = T.tensor(X_tr_c)
    Y_tr_T = T.tensor(Y_tr_z)
    X_te_T = T.tensor(X_te_c)
    Y_te_T = T.tensor(Y_te_z)

    # 5) Build estimator
    
    mode_dims = list(X_tr_c.shape[1:])  # exclude sample dim
    mode_dims_y = list(Y_tr_z.shape[1:])  # exclude sample dim
    if device:
        if 'cuda' in device:
            tl.set_backend('pytorch')
    if resume_from is not None and resume_estimator is not None:
        estimator = resume_estimator
        logger.info("Resuming training on existing estimator...")
    else:
        if method == "cp":
            estimator = CPRegressor(
                weight_rank=rank,
                tol=tol,
                reg_W=reg_W,
                n_iter_max=n_iter_max,
                verbose=verbose,
                n_jobs=64,
                random_state=seed,
                logger=logger,
                device = device,
            )
        elif method == "full":
            estimator = FullRegressor(
                tol=tol,
                reg_W=reg_W,
                n_iter_max=n_iter_max,
                verbose=verbose,
                random_state=seed,
                logger=logger,
                device = device,
            )
        elif method == "tucker":
            if TuckerRegressor is None:
                raise ImportError("未找到 TuckerRegressor（请确保存在 tucker_regression.py 并可被导入）")
            # Determine core ranks per mode
            if tucker_core is None or len(tucker_core) == 0:
                raise ImportError("未找到 tucker_core")
            else:
                if len(tucker_core) != len(mode_dims)+len(mode_dims_y):
                    raise ValueError(f"--tucker_core 长度应为 {len(mode_dims)+len(mode_dims_y)}，收到 {len(tucker_core)}")
                core = tuple(int(x) for x in tucker_core)
                logger.info("Tucker core provided -> %s", core)
            estimator = TuckerRegressor(
                weight_ranks=core,
                tol=tol,
                reg_W=reg_W,
                n_iter_max=n_iter_max,
                verbose=verbose,
                random_state=seed,
                logger=logger,
                n_jobs=64,
                device = device,
            )
        elif method == "ttiht":
            if TTRegressorIHT is None:
                raise ImportError("未找到 TTRegressor（请确保存在 tt_regression.py 并可被导入）")
            # TT ranks length must be (#modes+1)
            m = len(mode_dims)+len(mode_dims_y )
            if tt_ranks is None or len(tt_ranks) == 0:
                # default [1, rank, rank, ..., rank, 1]
                tt_ranks_use = [1] + [rank]*m + [1]
                logger.info("TT ranks auto -> %s (rank=%d)", tt_ranks_use, rank)
            else:
                if len(tt_ranks) != m + 1:
                    raise ValueError(f"--tt_ranks 长度应为 {m+1}（#modes+1），收到 {len(tt_ranks)}")
                tt_ranks_use = list(int(x) for x in tt_ranks)
                logger.info("TT ranks provided -> %s", tt_ranks_use)
            estimator = TTRegressorIHT(
                tt_rank=tt_ranks_use,
                tol=tol,
                step_size = ttstep_size,
                n_iter_max=n_iter_max,
                verbose=verbose,
                random_state=seed,
                logger=logger,
                device = device,
            )
        elif method == "ttrgd":
            if TTRegressorRGD is None:
                raise ImportError("未找到 TTRegressor（请确保存在 tt_regression.py 并可被导入）")
            # TT ranks length must be (#modes+1)
            m = len(mode_dims)+len(mode_dims_y )
            if tt_ranks is None or len(tt_ranks) == 0:
                # default [1, rank, rank, ..., rank, 1]
                tt_ranks_use = [1] + [rank]*m + [1]
                logger.info("TT ranks auto -> %s (rank=%d)", tt_ranks_use, rank)
            else:
                if len(tt_ranks) != m + 1:
                    raise ValueError(f"--tt_ranks 长度应为 {m+1}（#modes+1），收到 {len(tt_ranks)}")
                tt_ranks_use = list(int(x) for x in tt_ranks)
                logger.info("TT ranks provided -> %s", tt_ranks_use)
            estimator = TTRegressorRGD(
                tt_rank=tt_ranks_use,
                tol=tol,
                n_iter_max=n_iter_max,
                step_size = ttstep_size,
                verbose=verbose,
                random_state=seed,
                logger=logger,
                device = device,
            )
        elif method == "pcareg":
            logger.info(f"PCAReg: auto-generated pca_block={pca_block}, K_pca={K_pca}")
            estimator = SpectralBlockRGDRegressor(
                block_sizes = pca_block,
                K = K_pca,
                reg_W=reg_W,
                n_iter_max=n_iter_max,
                tol = tol,
                basis_fn=basis_fn,
                verbose=verbose,
                random_state=seed,
                weight_type=pcahead_weight_type,
                step_size_A = step_size_A, 
                step_size_W = step_size_W,
                warmup_A=warmup_A,
                device = device,
                logger=logger,)
            
        elif method == "cpbasis":
            logger.info("Using CPBasisRegressor (DFT basis), rate_choose=%.3f", rate_choose)
            estimator = CPBasisRegressor(
                weight_rank=rank,
                tol=tol,
                reg_W=reg_W,
                n_iter_max=n_iter_max,
                verbose=verbose,
                n_jobs=64,
                random_state=seed,
                logger=logger,
                use_basis=True,
                rate_choose=rate_choose,
                basis_fn=basis_fn,
            )
        elif method == "cptvn":
            estimator = CPRegressorTVN(
                rank=rank,
                n_iter_max=n_iter_max,
                tol=tol,
                random_state=seed,
                verbose=verbose,
                ridge=getattr(args, "tvn_ridge", 1e-8),
                logger=logger,
            )
        elif method == "pcaregfast":
            logger.info(f"PCAReg: auto-generated pca_block={pca_block}, K_pca={K_pca},warmup_A = {warmup_A}")
            estimator = SpectralBlockRGDRegressorFast(
                block_sizes = pca_block,
                basis_dims = pca_basis_dims,
                K = K_pca,
                reg_W=reg_W,
                n_iter_max=n_iter_max,
                tol = tol,
                basis_fn=basis_fn,
                verbose=verbose,
                random_state=seed,
                weight_type =pcahead_weight_type,
                ranks=_make_head_ranks(pcahead_weight_type, pcahead_rank), 
                step_size_A = step_size_A, 
                step_size_W = step_size_W,
                warmup_A=warmup_A,
                device = device,
                logger=logger,
                headinit_multi = pcaheadinit_multi,
                A_solver = pca_A_solver,
                W_full_solver = pca_W_full_solver,
                tt_init=pca_tt_init,
                forward_flat = pca_forward_flat,
                tt_solver = pca_tt_solver,
                perk = pca_perk,
                combo_strategy = pca_combo_strategy,
                use_nan_to_num = (not pca_not_use_nan_to_num)
                
                )
        elif method == "pcaregperp":
            logger.info(f"PCAReg: auto-generated pca_block={pca_block}, K_pca={K_pca}")
            estimator = SpectralBlockPerpRGDRegressorFast(
                block_sizes=pca_block,
                basis_dims=pca_basis_dims,
                K=K_pca,
                reg_W=reg_W,
                n_iter_max=n_iter_max,
                tol=tol,
                basis_fn=basis_fn,
                verbose=verbose,
                random_state=seed,
                weight_type=pcahead_weight_type,
                ranks=_make_head_ranks(pcahead_weight_type, pcahead_rank),
                step_size_A=step_size_A,
                step_size_W=step_size_W,
                warmup_A=warmup_A,
                device=device,
                logger=logger,
                headinit_multi=pcaheadinit_multi,
                A_solver=pca_A_solver,
                W_full_solver=pca_W_full_solver,
                tt_init=pca_tt_init,
                forward_flat=pca_forward_flat,
                tt_solver=pca_tt_solver,
                perB=pca_perk,
                                              
                learn_B=(not args.no_pca_learn_B),
                sequential=(not args.pca_not_sequential) if hasattr(args, "pca_not_sequential") else True,
                step_size_B=(args.pca_step_size_B if args.pca_step_size_B is not None else step_size_A),
                reg_B_l2=args.pca_reg_B_l2,
                init_W_zero=args.pca_init_W_zero,
                B_diag = pca_B_diag,
                gdals_iters = pcagdals_iters,
                B_solver = pca_B_solver,
                if_all_learn_A = not pca_if_noall_learn_A,
                W_cp_solver = pca_W_full_solver if pcahead_weight_type=="cp" else 'gd',
                W_tucker_solver = pca_W_full_solver if pcahead_weight_type=="tucker" else 'rgd',
                joint_B_ortho=args.joint_B_ortho,
                combo_mode = pca_combo_mode,
                print_ab_stats = pca_print_ab_stats,
                optimizer = pca_optimizer,
                init_B_method = pca_init_B_method,
                init_A_method = pca_init_A_method,
                adam_betas=pca_adam_betas,
                adam_eps=adam_eps,
                fixed_diag_max_k=fixed_diag_max_k,
                diag_skew_parallel_dist =diag_skew_parallel_dist
                

            )

        elif method == "pcaregfast_autostep":
            logger.info(f"PCARegAutoStep: auto-generated pca_block={pca_block}, K_pca={K_pca},warmup_A = {warmup_A}")
            estimator = SpectralBlockRGDRegressorFastAutoStep(
                block_sizes = pca_block,
                K = K_pca,
                reg_W=reg_W,
                n_iter_max=n_iter_max,
                tol = tol,
                basis_fn=basis_fn,
                verbose=verbose,
                random_state=seed,
                weight_type=pcahead_weight_type,
                ranks=_make_head_ranks(pcahead_weight_type, pcahead_rank), 
                warmup_A=warmup_A,
                device = device,
                logger=logger,)
        elif method == "pcaregfast_perk":
            logger.info(f"PCAReg (Per-K A): pca_block={pca_block}, K_pca={K_pca}")
            estimator = SpectralBlockRGDRegressorFastPerK(
                block_sizes=pca_block,
                K=K_pca,
                reg_W=reg_W,
                n_iter_max=n_iter_max,
                tol=tol,
                basis_fn=basis_fn,
                verbose=verbose,
                random_state=seed,
                weight_type=pcahead_weight_type,
                ranks=_make_head_ranks(pcahead_weight_type, pcahead_rank), 
                step_size_A=step_size_A,
                step_size_W=step_size_W,
                warmup_A=warmup_A,
                device=device,
                logger=logger,
                A_solver = pca_A_solver,
                W_full_solver = pca_W_full_solver,
                tt_init=pca_tt_init,
                
            )
        elif method == "pcaregfast_cyc":
            logger.info(f"PCAReg (Cyclic combos): pca_block={pca_block}, K_pca={K_pca}, warmup_A={warmup_A}")
            estimator = SpectralBlockRGDRegressorCyclic(
                block_sizes=pca_block,
                K=K_pca,
                reg_W=reg_W,
                n_iter_max=n_iter_max,
                tol=tol,
                basis_fn=(dft_basis if basis_order == "dft" else dft_basis_high2low),
                verbose=verbose,
                random_state=seed,
                weight_type=pcahead_weight_type,              # 'cp' | 'full' | 'tucker' | 'tt'
                ranks=_make_head_ranks(pcahead_weight_type, pcahead_rank),
                step_size_A=step_size_A,
                step_size_W=step_size_W,
                warmup_A=warmup_A,
                device=device,
                logger=logger,
                A_solver = pca_A_solver,
                W_full_solver = pca_W_full_solver,
                tt_init=pca_tt_init
            )

        elif method == "pcaregfast_noortho":
            logger.info(f"PCAReg: auto-generated pca_block={pca_block}, K_pca={K_pca},warmup_A = {warmup_A}")
            estimator = SpectralBlockRGDRegressorFast_Noortho(
                block_sizes = pca_block,
                K = K_pca,
                reg_W=reg_W,
                n_iter_max=n_iter_max,
                tol = tol,
                basis_fn=basis_fn,
                verbose=verbose,
                random_state=seed,
                weight_type =pcahead_weight_type,
                ranks=_make_head_ranks(pcahead_weight_type, pcahead_rank), 
                step_size_A = step_size_A, 
                step_size_W = step_size_W,
                warmup_A=warmup_A,
                device = device,
                logger=logger,
                headinit_multi = pcaheadinit_multi)

        elif method == "hopls":
            logger.info(f"Using HOPLSRegressor: R={hop_r}, X-ranks(L_dims)={hop_xrank}, Y-ranks(K_dims)={hop_yrank}")
            estimator = HOPLSRegressor(
                R=hop_r, 
                L_dims=hop_xrank, 
                K_dims=hop_yrank,
                n_iter_max=n_iter_max, 
                verbose=verbose,
                tol=tol,
                random_state=seed,
                logger=logger
            )
        elif method == "npls":
            logger.info(f"Using NPLSRegressor: R={hop_r}, X-ranks(L_dims)={hop_xrank}, Y-ranks(K_dims)={hop_yrank}")
            estimator = NPLSRegressor(
                n_components = npls_r,
                als_max_iter = npls_iter,
                n_iter_max=n_iter_max, 
                verbose=verbose,
                tol=tol,
                random_state=seed,
                logger=logger
            )
        else:
            raise ValueError(f"未知方法: {method}")

    if not hasattr(estimator, "logger"):
        setattr(estimator, "logger", logger)

    # 6) Fit
    start_time = time.time()
    logger.info("Fitting %s regressor...", method.upper())
    if resume_from is not None and hasattr(estimator, 'warmup_A'):
        estimator.warmup_A = 0
        logger.info('warmup A = 0')

                                                             
    #     prev_iters = int(getattr(estimator, "history_", {}).get("n_iter", 0))
    #     added = int(extra_iters)
    #     for attr in ("n_iter_max", "max_iter", "n_epochs"):
    #         if hasattr(estimator, attr):
    #             old = int(getattr(estimator, attr))
    #             setattr(estimator, attr, old + added)
    #             logger.info(f"Resuming: {attr} {old} -> {old + added} (+{added})")
    #             break
    if resume_from is not None and extra_iters is not None:
        added = int(extra_iters)

                  
        prev_iters = int(getattr(estimator, "history_", {}).get("n_iter", 0))

        for attr in ("n_iter_max", "max_iter", "n_epochs"):
            if hasattr(estimator, attr):
                                           
                old = prev_iters if prev_iters > 0 else int(getattr(estimator, attr))
                
                setattr(estimator, attr, old + added)
                logger.info(
                    f"Resuming: {attr} {old} -> {old + added} (+{added})"
                )
                break

    if vali_test:
        if no_standardize:
            estimator.fit(X_tr_T, Y_tr_T,X_te_T, Y_te_T)
        else:
            estimator.fit(X_tr_T, Y_tr_T,X_te_T, Y_te_T, y_mean,y_std)
    else:
        estimator.fit(X_tr_T, Y_tr_T)
    elapsed = time.time() - start_time
    logger.info("Fit done in %.3f s", elapsed)

    # 7) Predict + evaluate in z-space, then de-standardize for reporting
    Y_pred_T = estimator.predict(X_te_T)
    tl.set_backend('numpy')
    # test_rmse = RMSE(Y_te_T, Y_pred_T)
    test_rmse = 0

    to_np = getattr(T, "to_numpy", lambda z: z)
    try:
        Y_pred = np.array(to_np(Y_pred_T))
    except Exception:
        Y_pred = np.array(Y_pred_T)

    if no_standardize:
        Y_pred_orig = Y_pred
        Y_true_orig = Y_te
    else:
        Y_pred_orig = Y_pred * y_std + y_mean
        Y_true_orig = Y_te

    # 8) Final metrics on ORIGINAL scale
    r_flat = pearson_r_flat(Y_true_orig, Y_pred_orig)
    rpe = relative_prediction_error(Y_true_orig, Y_pred_orig)
    logger.info("[RESULT] Test RMSE (z-space): %.6f | r(flat)=%.6f | RPE=%.6f", float(test_rmse), r_flat, rpe)

               
    results_dict = {
        "estimator": estimator,
        "method": method,
        "mean_X": mean_X,
        "y_mean": y_mean,
        "y_std": y_std,
        "attr_names": adapter.attr_names(),
        "meta_test": meta.get("meta") if isinstance(meta, dict) else None,
        "y_test_true": Y_true_orig,
        "y_test_pred": Y_pred_orig,
        "rmse_z": float(test_rmse),
        "r_flat": float(r_flat),
        "rpe": float(rpe),
        "train_idx": tr_idx,
        "test_idx": te_idx,
    }

                                
                               
                                
    if getattr(adapter, 'name', None) == 'ecog_person':
        logger.info("--- [ECoG Person Per-Finger Correlation] ---")
        per_finger_corr = {}
        finger_names = adapter.attr_names() # ["Thumb", "Index", "Middle", "Ring", "Pinky"]
        
        correlations = []
        for i, name in enumerate(finger_names):
                            
            finger_true = Y_te_z[:, i]
            finger_pred = Y_pred[:, i]
                                              
            valid_indices = ~np.isnan(finger_true)
            if np.any(valid_indices):
                corr = pearson_r_flat(finger_true[valid_indices], finger_pred[valid_indices])
                per_finger_corr[f"{name}_corr"] = corr
                logger.info(f"  - {name}: {corr:.6f}")
                                                
                if i != 3: # 0-indexed, so Ring is index 3
                    correlations.append(corr)
            else:
                 logger.warning(f"  - {name}: No valid ground truth labels found for evaluation.")

        if correlations:
            avg_corr = np.mean(correlations)
            logger.info(f"  - Avg Correlation (excluding Ring): {avg_corr:.6f}")
            results_dict["avg_corr_no_ring"] = avg_corr
        
        logger.info("---------------------------------------------")

                      
        results_dict["per_finger_corr"] = per_finger_corr
        
    return results_dict
    
                                                              
def load_previous_run(run_dir: str):
    cfg_path = os.path.join(run_dir, "config.json")
    art_path = os.path.join(run_dir, "artifacts.npz")
    model_path = os.path.join(run_dir, "model.pkl")
    assert os.path.exists(cfg_path) and os.path.exists(art_path) and os.path.exists(model_path),\
        f"resume_from={run_dir} 缺少必要文件 (config.json/artifacts.npz/model.pkl)"

    with open(cfg_path, "r", encoding="utf-8") as f:
        prev_cfg = json.load(f)
    arts = np.load(art_path, allow_pickle=True)
    with open(model_path, "rb") as f:
        estimator = pickle.load(f)

    fixed_train_idx = arts["train_idx"]
    fixed_test_idx  = arts["test_idx"]
    mean_X = arts["mean_X"]
    y_mean = arts["y_mean"]
    y_std  = arts["y_std"]

    return prev_cfg, (fixed_train_idx, fixed_test_idx), (mean_X, y_mean, y_std), estimator

def merged_cfg_with_overrides(prev_cfg: dict, args, defaults, override_keys: List[str]) -> dict:
\
\
\
       
    merged = {**prev_cfg}
    for k in override_keys:
        if hasattr(args, k) and hasattr(defaults, k):
            cur = getattr(args, k)
            dft = getattr(defaults, k)
                                          
            if isinstance(cur, list): cur_cmp = tuple(cur)
            else: cur_cmp = cur
            if isinstance(dft, list): dft_cmp = tuple(dft)
            else: dft_cmp = dft
            if cur_cmp != dft_cmp:         
                merged[k] = cur
    return merged

def _explicit_override(args, defaults, key):
                                           
    if not hasattr(args, key) or not hasattr(defaults, key):
        return False
    cur = getattr(args, key)
    dft = getattr(defaults, key)
                
    if isinstance(cur, list): cur = tuple(cur)
    if isinstance(dft, list): dft = tuple(dft)
    return cur != dft

def apply_overrides_to_estimator(estimator, args, defaults, logger):
\
\
\
       
    # key -> (attribute_name_on_estimator, human_readable)
    mapping = {
        "n_iter_max": ("n_iter_max", "max iters"),
        "tol":        ("tol",        "tolerance"),
        "reg_W":      ("reg_W",      "weight regularization"),
        "step_size_A":("step_size_A","step size A"),                  
        "step_size_W":("step_size_W","step size W"),                  
        "tvn_ridge":  ("ridge",      "TVN ridge (linear solve)"),  # CPRegressorTVN
        "verbose":    ("verbose",    "verbose"),
        "device":     ("device",     "compute device"),  # NEW: allow overriding device on resume
        "pca_step_size_B": ("step_size_B", "step size B"),
        "pca_reg_B_l2":    ("reg_B_l2",   "B L2 regularization"),
        "pca_init_W_zero": ("init_W_zero", "init W to zero"),
        "pcagdals_iters":  ("gdals_iters", "GD-ALS inner iterations"),
        "joint_B_ortho":   ("joint_B_ortho", "Joint B ortho strategy"),
        "pca_optimizer":   ("optimizer", "Optimizer for A/B/W"),
        "pca_print_ab_stats": ("print_ab_stats", "Print A/B stats"),
    }

    for k, (attr, desc) in mapping.items():
        if _explicit_override(args, defaults, k) and hasattr(estimator, attr):
            val = getattr(args, k)
            setattr(estimator, attr, val)
            logger.info(f"[resume override] Set {desc} ({attr}) -> {val}")
    
                    
    if _explicit_override(args, defaults, "no_pca_learn_B") and hasattr(estimator, "learn_B"):
        estimator.learn_B = not args.no_pca_learn_B
        logger.info(f"[resume override] Set learn_B -> {estimator.learn_B}")

    if _explicit_override(args, defaults, "pca_not_sequential") and hasattr(estimator, "sequential"):
        estimator.sequential = not args.pca_not_sequential
        logger.info(f"[resume override] Set sequential -> {estimator.sequential}")

                            
# Saving artifacts
                            
class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super(NpEncoder, self).default(obj)
def save_run_artifacts(run_dir: str, cfg: dict, results: dict, logger: logging.Logger, save_mode: str = "all"):
    cfg_path = os.path.join(run_dir, "config.json")
    metrics_path = os.path.join(run_dir, "metrics.json")
    preds_path = os.path.join(run_dir, "predictions_test.npz")
    npz_path = os.path.join(run_dir, "artifacts.npz")
    model_path = os.path.join(run_dir, "model.pkl")
    try:
        with open(cfg_path, "w", encoding="utf-8") as f:
            json.dump(cfg, f, indent=2, ensure_ascii=False)
        with open(metrics_path, "w", encoding="utf-8") as f:
            metrics_data = {
                "rmse_z": results.get("rmse_z"),
                "r_flat": results.get("r_flat"),
                "rpe": results.get("rpe"),
                "method": results.get("method"),
            }
                             
            if "per_finger_corr" in results:
                metrics_data["per_finger_corr"] = results["per_finger_corr"]
                metrics_data["avg_corr_no_ring"] = results.get("avg_corr_no_ring")
            
            json.dump(metrics_data, f, indent=2, ensure_ascii=False, cls=NpEncoder)
                                                      
        if save_mode == "minimal":
            logger.info("Save mode = minimal -> 仅保存 config.json 与 metrics.json（log.txt 由 logger 已写）。跳过 preds/artifacts/model。")
            return
        np.savez_compressed(
            preds_path,
            y_true=results["y_test_true"],
            y_pred=results["y_test_pred"],
            test_idx=results["test_idx"],
            attr_names=np.array(results.get("attr_names", []), dtype=object),
        )
        np.savez_compressed(
            npz_path,
            mean_X=results["mean_X"],
            y_mean=results["y_mean"],
            y_std=results["y_std"],
            train_idx=results["train_idx"],
            test_idx=results["test_idx"],
            attr_names=np.array(results.get("attr_names", []), dtype=object),
        )
        try:
            with open(model_path, "wb") as f:
                pickle.dump(results["estimator"], f)
            logger.info("Model saved: %s", model_path)
        except Exception as e:
            logger.error("Model pickling failed: %s", e)
        logger.info("Saved artifacts: %s, %s, %s, %s", cfg_path, metrics_path, preds_path, npz_path)
    except Exception as e:
        logger.error("Failed to save artifacts: %s", e)

                            
# CLI
                            

def build_parser():
    p = argparse.ArgumentParser(description="Tensor Regression (CP/Tucker/TT) — modular dataset adapters")
    # p.add_argument("--dataset", type=str, required=True, choices=["lfw", "ecog_monkey", "ecog_person"], help="Dataset type")
    p.add_argument("--dataset", type=str, choices=["lfw", "ecog_monkey", "ecog_person","synthetic_data","synthetic_data_lowrank"], help="Dataset type")
    p.add_argument("--method", type=str, default="cp", choices=["full","cp", "tucker", "ttiht","ttrgd","pcareg","cpbasis","cptvn", "tuckertvn","pcaregfast","pcaregfast_perk","pcaregfast_autostep","pcaregfast_cyc","pcaregfast_noortho","pcaregperp","hopls","npls"], help="Regression method")

    p.add_argument("--no_standardize", action="store_true")
    # LFW
    p.add_argument("--images_root", type=str, default=None)
    p.add_argument("--attr_csv", type=str, default=None)
    p.add_argument("--target_hw", type=int, nargs=2, default=[90, 90])

    # ECoG
    p.add_argument("--X_npy", type=str, default=None, help="Path to X .npy for ECoG datasets")
    p.add_argument("--Y_npy", type=str, default=None, help="Path to Y .npy for ECoG datasets")
    p.add_argument("--mask_npy", type=str, default=None, help="Optional mask .npy (True=artifact) to be removed")
    p.add_argument("--ecog_reshape", type=int, nargs=3, default=[64, 10, 10], help="Optional reshape for flat ECoG features, e.g., 64x10x10")

    p.add_argument("--train_idx_npy", type=str, default=None, help="Optional train index .npy (bool or int indices)")
    p.add_argument("--val_idx_npy", type=str, default=None, help="Optional val/test index .npy (bool or int indices)")
    # Synthetic Data
    p.add_argument("--synthetic_N", type=int, default=5000, help="Number of samples for synthetic data")
    p.add_argument("--synthetic_dims", type=int, nargs='*', default=[60, 60, 15], help="Tensor dimensions for synthetic data")
    p.add_argument("--synthetic_D", type=int, nargs='*', default=[73], help="Output dimension for synthetic data")

    # Training
    p.add_argument("--test_ratio", type=float, default=0.2)
    p.add_argument("--rank", type=int, default=10, help="Generic rank hyperparam (CP rank / default Tucker per-mode cap)")
    p.add_argument("--tucker_core", type=int, nargs='*', default=None, help="Tucker core ranks per mode (exclude sample dim); default auto")
    p.add_argument("--tt_ranks", type=int, nargs='*', default=None, help="TT ranks of length (#modes+1); default auto")
    p.add_argument("--pca_blocks", type=int, nargs='*', default=None)
    p.add_argument("--pca_basis_dims", type=int, nargs='*', default=None)
    p.add_argument("--K_pca", type=int, nargs='*', default=None, help="K for PCA regressor. Single value is broadcasted, multiple values are used per mode.")
    p.add_argument("--rate_choose", type=float, default=1.0, help="CP basis keep ratio in dft_basis (0,1]")
    p.add_argument("--tvn_ridge", type=float, default=1e-8, help="TVN linear solves ridge")
    p.add_argument("--step_size_A", type=float, default=5e-4)
    p.add_argument("--step_size_W", type=float, default=None)
    p.add_argument("--n_iter_max", type=int, default=300)
    p.add_argument("--tol", type=float, default=1e-7)
    p.add_argument("--reg_W", type=float, default=0.0)
    p.add_argument("--warmup_A", type=int, default=10, help="Warm-up steps: only update W for the first warmup_A iters")
    p.add_argument("--seed", type=int, default=1234)
    p.add_argument("--verbose", action="store_true")
    p.add_argument("--device", type=str, default=None, help="Compute device: e.g., 'cpu', 'cuda', or 'mps'")  # NEW: device CLI
    p.add_argument("--pcaheadinit_multi", type=float, default=1e-2)
    # basis choice
    p.add_argument("--basis_order", type=str, choices=["dft", "dfth2l","eye"], default="dft",
               help="Choose DFT basis: 'dft' (low→high) or 'dfth2l' (high→low)")
    p.add_argument("--pcahead_weight_type", type=str, default="cp",
               choices=["cp", "full", "tucker", "tt"],
               help="Head weight parameterization for PCAReg* models.")
    p.add_argument("--pcahead_rank", type=int, default=8,
                help="Rank for PCAReg* head when pcahead_weight_type in {cp,tucker,tt}: CP->R, Tucker->r (same per mode), TT->single rank")
    p.add_argument("--pca_A_solver",  type=str, default="gd")
    p.add_argument("--pca_W_full_solver",  type=str, default="gd")
    
    p.add_argument("--pca_tt_init", type=str, default='spectral')
    p.add_argument("--pca_forward_flat", action="store_true")
    p.add_argument("--pca_perk", action="store_true")
    p.add_argument("--pca_combo_strategy", type=str, default='bfs')

    p.add_argument("--pca_tt_solver", type=str, default="rgd") 
    p.add_argument("--pca_step_size_B", type=float, default=None,
                help="RGD step size for B; default to step_size_A if None")
    p.add_argument("--pca_reg_B_l2", type=float, default=0.0,
                help="L2 on B in Euclidean grad")
                                           
    p.add_argument("--no_pca_learn_B", action="store_true",
                help="Disable learning B (default: learn B)")
    p.add_argument("--pca_not_sequential", action="store_true",
                help="Force sequential component learning (default: class default)")
    p.add_argument("--pca_init_W_zero", action="store_true",
                help="Init W to zeros (only for FULL/compatible heads)")

    # Logging
    p.add_argument("--log_dir", type=str, default="logs_run_Finger")
    p.add_argument("--run_name", type=str, default=None)
    p.add_argument("--ttstep_size",type=float)
    # resume
    p.add_argument("--resume_from", type=str, default=None,
               help="Path to a previous run_dir to resume from (must contain artifacts.npz/config.json)")
    p.add_argument("--extra_iters", type=int, default=100,
                help="Additional iterations to run when resuming")
    p.add_argument("--resume_same_dir", action="store_true",
                help="If set, keep writing artifacts into the same run_dir; otherwise create a new sub-run")
    p.add_argument("--pca_B_diag", action="store_true")
    p.add_argument("--pcagdals_iters", type=int, default=1)
    p.add_argument("--pca_not_use_nan_to_num", action="store_true")
    
    p.add_argument(
        "--save_mode",
        type=str,
        choices=["all", "minimal"],
        default="all",
        help="保存模式：all=全部文件；minimal=仅 log.txt、config.json、metrics.json"
    )
    p.add_argument("--pca_B_solver",type=str, default='perp_rgd')
    p.add_argument("--pca_if_noall_learn_A", action="store_true")
    p.add_argument("--joint_B_ortho",type=str, default='perp_seq')
    p.add_argument("--pca_combo_mode",type=str, default='diag')
    p.add_argument("--pca_print_ab_stats",action="store_true")
    p.add_argument("--pca_optimizer",type=str, default='gd')
    p.add_argument("--pca_init_B_method",type=str, default='random')
    p.add_argument("--pca_init_A_method",type=str, default='eye')

    p.add_argument("--pca_adam_betas", type=float, nargs=2, default=[0.9, 0.999], help="Adam optimizer beta1 and beta2 for PCAReg*.")
    p.add_argument("--pca_adam_eps", type=float, default=1e-8, help="Adam optimizer epsilon for PCAReg*.")
    p.add_argument("--pca_fixed_diag_max_k", action="store_true", help="For pcaregperp with diag_skew, fix diagonal based on max possible K.")
    p.add_argument("--pca_diag_skew_parallel_dist", type=float, default=1.0)
    p.add_argument("--vali_test", action="store_true")

    # HOPLS
    p.add_argument("--hop_r", type=int, default=10, help="R for HOPLS regressor.")
    p.add_argument("--hop_xrank", type=int, nargs='*', default=None, help="L_dims (X ranks) for HOPLS regressor.")
    p.add_argument("--hop_yrank", type=int, nargs='*', default=None, help="K_dims (Y ranks) for HOPLS regressor.")
     # NPLS
    p.add_argument("--npls_r", type=int, default=10, help="R for HOPLS regressor.")
    p.add_argument("--npls_iter", type=int, default=100, help="R for HOPLS regressor.")
    p.add_argument("--analyze_components", type=str, default=None,
                   help="Path to a finished pcaregperp run directory. If provided, this script will load the "
                        "model and perform per-component regression analysis on the test set instead of training.")
    return p

def parse_args():
    parser = build_parser()
    return parser.parse_args(), parser

if __name__ == "__main__":
    args, parser = parse_args()
    defaults = parser.parse_args([])           
    if args.analyze_components:
        if args.method in ["hopls", "npls"]:
            analyze_pls_components(args.analyze_components)
        else:
            analyze_model_components(args.analyze_components)
                                 
        sys.exit(0)
    if args.K_pca is None:
        k_pca_val = 1
    elif len(args.K_pca) == 1:
        k_pca_val = args.K_pca[0]
    else:
        k_pca_val = args.K_pca
                                    
    if args.resume_from:
        base_run_dir = args.resume_from

                                   
                                          
        prev_cfg, fixed_indices, preload_norm, loaded_estimator = load_previous_run(base_run_dir)

                 
        logger, run_dir = setup_logger(
            log_dir=(os.path.dirname(base_run_dir) if args.resume_same_dir else args.log_dir),
            run_name=(os.path.basename(base_run_dir) if args.resume_same_dir else None),
            dataset=prev_cfg.get("dataset"),
            method=prev_cfg.get("method"),
            level=logging.INFO
        )
        logger.info(f"Resuming from run directory: {base_run_dir}")

                                        
                                                   
        adapter_cfg_dict = {
            "dataset": prev_cfg.get("dataset"),
            "images_root": prev_cfg.get("images_root"),
            "attr_csv": prev_cfg.get("attr_csv"),
            "target_hw": prev_cfg.get("target_hw", (90, 90)),
            "X_npy": prev_cfg.get("X_npy"),
            "Y_npy": prev_cfg.get("Y_npy"),
            "mask_npy": prev_cfg.get("mask_npy"),
            "train_idx_npy": prev_cfg.get("train_idx_npy"),
            "val_idx_npy": prev_cfg.get("val_idx_npy"),
            "ecog_reshape": prev_cfg.get("ecog_reshape", (64, 10, 10)),
            "synthetic_N": prev_cfg.get("synthetic_N", 5000),
            "synthetic_dims": prev_cfg.get("synthetic_dims", [60, 60, 15]),
            "synthetic_D": prev_cfg.get("synthetic_D", [73]),
            "test_ratio": prev_cfg.get("test_ratio", 0.2),
            "seed": prev_cfg.get("seed", 1234),
        }
        adapter = build_adapter(argparse.Namespace(**adapter_cfg_dict), logger)

                                        
        apply_overrides_to_estimator(loaded_estimator, args, defaults, logger)

                   
                                               
        k_pca_val_resume = prev_cfg.get("K_pca", 1)
        
        results = run_tensor_regression(
            adapter=adapter,
            logger=logger,
                                 
            resume_from=base_run_dir,
            fixed_indices=fixed_indices,
            preload_norm=preload_norm,
            resume_estimator=loaded_estimator,
            extra_iters=args.extra_iters,
                                          
            method=prev_cfg.get("method", "pcaregperp"),
            rank=prev_cfg.get("rank", 10),
            n_iter_max=prev_cfg.get("n_iter_max", 300),
            tol=prev_cfg.get("tol", 1e-7),
            reg_W=prev_cfg.get("reg_W", 0.0),
            seed=prev_cfg.get("seed", 1234),
            verbose=prev_cfg.get("verbose", True),
            tucker_core=prev_cfg.get("tucker_core"),
            tt_ranks=prev_cfg.get("tt_ranks"),
            pca_block=prev_cfg.get("pca_blocks"),
            pca_basis_dims=prev_cfg.get("pca_basis_dims"),
            K_pca=k_pca_val_resume,
            step_size_A=prev_cfg.get("step_size_A", 5e-4),
            step_size_W=prev_cfg.get("step_size_W"),
            warmup_A=prev_cfg.get("warmup_A", 10),
            device=prev_cfg.get("device"),
            basis_order=prev_cfg.get("basis_order", "dft"),
            pcahead_weight_type=prev_cfg.get("pcahead_weight_type", "cp"),
            pcahead_rank=prev_cfg.get("pcahead_rank"),
            no_standardize=prev_cfg.get("no_standardize", False),
            pcaheadinit_multi=prev_cfg.get("pcaheadinit_multi", 1e-2),
            pca_A_solver=prev_cfg.get("pca_A_solver", "rgd"),
            pca_W_full_solver=prev_cfg.get("pca_W_full_solver", "gd"),
            pca_tt_init=prev_cfg.get("pca_tt_init", "random"),
            pca_forward_flat=prev_cfg.get("pca_forward_flat", False),
            pca_tt_solver=prev_cfg.get("pca_tt_solver", "rgd"),
            pca_perk=prev_cfg.get("pca_perk", False),
                                              
            pca_combo_mode=prev_cfg.get("pca_combo_mode", "diag"),
            pca_step_size_B=prev_cfg.get("pca_step_size_B"),
            pca_reg_B_l2=prev_cfg.get("pca_reg_B_l2", 0.0),
            no_pca_learn_B=prev_cfg.get("no_pca_learn_B", False),              
            pca_not_sequential=prev_cfg.get("pca_not_sequential", False),              
            pca_init_W_zero=prev_cfg.get("pca_init_W_zero", False),
            pca_B_diag=prev_cfg.get("pca_B_diag", False),
            pcagdals_iters=prev_cfg.get("pcagdals_iters", 1),
            pca_B_solver=prev_cfg.get("pca_B_solver", "perp_rgd"),
            pca_if_noall_learn_A=prev_cfg.get("pca_if_noall_learn_A", False),
            joint_B_ortho=prev_cfg.get("joint_B_ortho", "perp_seq"),
            pca_not_use_nan_to_num=prev_cfg.get("pca_not_use_nan_to_num", False),
            pca_print_ab_stats=prev_cfg.get("pca_print_ab_stats", False),
            pca_optimizer=prev_cfg.get("pca_optimizer", "gd"),
            pca_init_B_method=prev_cfg.get("pca_init_B_method", "random"),
            pca_init_A_method=prev_cfg.get("pca_init_A_method", "eye"),
            pca_adam_betas=prev_cfg.get("pca_adam_betas", [0.9, 0.999]),
            adam_eps=prev_cfg.get("pca_adam_eps", 1e-8),                         
            fixed_diag_max_k=prev_cfg.get("pca_fixed_diag_max_k", False),
            diag_skew_parallel_dist=prev_cfg.get("pca_diag_skew_parallel_dist", 1.0),
            vali_test=prev_cfg.get("vali_test", False),
            hop_r=prev_cfg.get("hop_r", 10),
            hop_xrank=prev_cfg.get("hop_xrank"),
            hop_yrank=prev_cfg.get("hop_yrank"),
            npls_r=prev_cfg.get("npls_r", 10),
            npls_iter=prev_cfg.get("npls_iter", 100),
        )

                 
        new_cfg = {**prev_cfg, "resumed_from": base_run_dir, "added_iters": args.extra_iters}
                        
        for k in ["n_iter_max", "tol", "reg_W", "step_size_A", "step_size_W", "pca_step_size_B", "pca_reg_B_l2"]:
             if _explicit_override(args, defaults, k):
                 new_cfg[f"resume_override::{k}"] = getattr(args, k)

        save_run_artifacts(
            run_dir,              
            new_cfg,
            results,
            logger,
            save_mode=args.save_mode
        )
        logger.info("Resume OK. Metrics -> r(flat)=%.6f | RPE=%.6f | RMSE(z)=%.6f",
                    results["r_flat"], results["rpe"], results["rmse_z"])

    else:
                         
        logger, run_dir = setup_logger(
            log_dir=args.log_dir, run_name=args.run_name, dataset=args.dataset, method=args.method, level=logging.INFO
        )
        cfg = {
            "dataset": args.dataset,
            "method": args.method,
            "images_root": args.images_root,
            "attr_csv": args.attr_csv,
            "target_hw": args.target_hw,
            "X_npy": args.X_npy,
            "Y_npy": args.Y_npy,
            "mask_npy": args.mask_npy,
            "train_idx_npy": args.train_idx_npy,
            "val_idx_npy": args.val_idx_npy,
            "ecog_reshape": args.ecog_reshape,
            "test_ratio": args.test_ratio,
            "rank": args.rank,
            "tucker_core": args.tucker_core,
            "tt_ranks": args.tt_ranks,
            "pca_blocks": args.pca_blocks,
            "pca_basis_dims":args.pca_basis_dims,
            "K_pca": k_pca_val,
            "n_iter_max": args.n_iter_max,
            "rate_choose": args.rate_choose,
            "tol": args.tol,
            "reg_W": args.reg_W,
            "seed": args.seed,
            "verbose": args.verbose,
            "device": args.device, 
            "warmup_A":args.warmup_A,
            "basis_order": args.basis_order,
            "pcahead_weight_type": args.pcahead_weight_type,
            "no_standardize": args.no_standardize,
            "ttstep_size":args.ttstep_size,
            "step_size_A":args.step_size_A,
            "step_size_W":args.step_size_W,
            "pcahead_rank": args.pcahead_rank,
            "pcaheadinit_multi":args.pcaheadinit_multi,
            "pca_A_solver":args.pca_A_solver,
            "pca_W_full_solver":args.pca_W_full_solver,
            "pca_tt_init":args.pca_tt_init,
            "pca_forward_flat": args.pca_forward_flat,
            "pca_tt_solver" : args.pca_tt_solver,
            "pca_perk":args.pca_perk,
            "pca_combo_strategy":args.pca_combo_strategy,
            "pca_step_size_B": args.pca_step_size_B,
            "pca_reg_B_l2": args.pca_reg_B_l2,
            "no_pca_learn_B": args.no_pca_learn_B,
            "pca_not_sequential": args.pca_not_sequential,
            "pca_init_W_zero": args.pca_init_W_zero,
            "pca_B_diag":args.pca_B_diag,
            "pcagdals_iters":args.pcagdals_iters,
            "pca_B_solver" : args.pca_B_solver,
            "pca_if_noall_learn_A":args.pca_if_noall_learn_A,
            "joint_B_ortho":args.joint_B_ortho,
            "pca_combo_mode":args.pca_combo_mode,
            "pca_not_use_nan_to_num":args.pca_not_use_nan_to_num,
            "pca_print_ab_stats":args.pca_print_ab_stats,
            "pca_optimizer":args.pca_optimizer,
            "pca_init_B_method":args.pca_init_B_method,
            "pca_init_A_method":args.pca_init_A_method,
            "pca_adam_betas": args.pca_adam_betas,
            "pca_adam_eps": args.pca_adam_eps,
            "pca_fixed_diag_max_k": args.pca_fixed_diag_max_k,
            "pca_diag_skew_parallel_dist":args.pca_diag_skew_parallel_dist,
            "vali_test":args.vali_test,
            "hop_r": args.hop_r,
            "hop_xrank": args.hop_xrank,
            "hop_yrank": args.hop_yrank,
            "npls_r":args.npls_r,
            "npls_iter":args.npls_iter,
        }

        adapter = build_adapter(args, logger)
        results = run_tensor_regression(
            adapter=adapter,
            logger=logger,
            method=args.method,
            rank=args.rank,
            n_iter_max=args.n_iter_max,
            tol=args.tol,
            reg_W=args.reg_W,
            seed=args.seed,
            verbose=args.verbose,
            tucker_core=args.tucker_core,
            tt_ranks=args.tt_ranks,
            pca_block=args.pca_blocks,
            pca_basis_dims = args.pca_basis_dims,
            K_pca=k_pca_val,
            rate_choose=args.rate_choose,
            step_size_A=args.step_size_A,
            step_size_W=args.step_size_W,
            warmup_A = args.warmup_A,
            device=args.device,
            basis_order=args.basis_order,
            pcahead_weight_type=args.pcahead_weight_type,
            pcahead_rank=args.pcahead_rank,
            no_standardize=args.no_standardize,
            ttstep_size = args.ttstep_size,
            pcaheadinit_multi = args.pcaheadinit_multi,
            pca_A_solver = args.pca_A_solver,
            pca_W_full_solver = args.pca_W_full_solver,
            pca_tt_init = args.pca_tt_init,
            pca_forward_flat = args.pca_forward_flat,
            pca_tt_solver = args.pca_tt_solver,
            pca_perk = args.pca_perk,
            pca_combo_strategy = args.pca_combo_strategy,
            pca_B_diag = args.pca_B_diag,
            pcagdals_iters = args.pcagdals_iters,
            pca_B_solver = args.pca_B_solver,
            pca_if_noall_learn_A = args.pca_if_noall_learn_A,
            pca_combo_mode = args.pca_combo_mode,
            pca_not_use_nan_to_num = args.pca_not_use_nan_to_num,
            pca_print_ab_stats = args.pca_print_ab_stats,
            pca_optimizer = args.pca_optimizer,
            pca_init_B_method = args.pca_init_B_method,
            pca_init_A_method = args.pca_init_A_method,
            pca_adam_betas=args.pca_adam_betas,
            adam_eps=args.pca_adam_eps,
            fixed_diag_max_k=args.pca_fixed_diag_max_k,
            diag_skew_parallel_dist = args.pca_diag_skew_parallel_dist,
            vali_test = args.vali_test,
            hop_r = args.hop_r,
            hop_xrank= args.hop_xrank,
            hop_yrank=  args.hop_yrank,
            npls_r = args.npls_r,
            npls_iter = args.npls_iter,
        )
                
        save_run_artifacts(run_dir, cfg, results, logger, save_mode=args.save_mode)
        logger.info("Done. Metrics -> r(flat)=%.6f | RPE=%.6f | RMSE(z)=%.6f",
                    results["r_flat"], results["rpe"], results["rmse_z"])
