# -*- coding: utf-8 -*-
from typing import List, Tuple, Dict, Optional
import numpy as np
import torch
import tensorly as tl
from tensorly.decomposition import tensor_train
from tensorly.tt_tensor import TTTensor, tt_to_tensor
from tensorly.tenalg import khatri_rao
from tensorly.cp_tensor import cp_to_tensor
from tensorly.tenalg import kronecker, multi_mode_dot
from tensorly.tucker_tensor import tucker_to_tensor

# Project-local helpers (relative imports)
from ..basis_vector import dft_basis
from ..tensor_check import check_tensor_stats 

                                
def mode_n_product_batch(X: torch.Tensor, M: torch.Tensor, mode: int) -> torch.Tensor:
    """Contract axis ``mode`` of a batched tensor with the first axis of ``M``.

    ``X`` has shape (B, d_1, ..., d_n) and ``M`` has shape (d_mode, d_mode').
    The contracted axis is replaced *in place* by M's second axis, so the
    result has shape (B, d_1, ..., d_mode', ..., d_n).
    """
    axis = mode + 1  # skip the leading batch axis
    contracted = torch.tensordot(X, M, dims=([axis], [0]))
    # tensordot appends M's surviving axis at the end; move it back to `axis`.
    last = contracted.ndim - 1
    if last == axis:
        return contracted
    order = [d for d in range(contracted.ndim) if d != last]
    order.insert(axis, last)
    return contracted.permute(order)


def riemannian_grad_on_orthogonal(A: torch.Tensor, G: torch.Tensor) -> torch.Tensor:
    """Project a Euclidean gradient onto the tangent space of the Stiefel manifold at A.

    Returns ``G - A @ sym(A^T G)`` where ``sym(S) = (S + S^T) / 2``; ``A`` is
    expected to have orthonormal columns.
    """
    cross = A.T @ G
    return G - A @ ((cross + cross.T) / 2)


def qr_retraction(A: torch.Tensor, step: torch.Tensor) -> torch.Tensor:
    """Retract ``A - step`` onto the Stiefel manifold with a sign-fixed reduced QR.

    The Q factor's columns are flipped so that the corresponding diagonal of R
    is non-negative, making the retraction unique (zero diagonals keep sign +1).
    """
    Q, R = torch.linalg.qr(A - step, mode='reduced')
    signs = torch.sign(torch.diag(R))
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)
    return Q @ torch.diag(signs)


def _safe_check(name: str, T: torch.Tensor):
    try:
        check_tensor_stats(T, name)
    except Exception:
        pass


def _clip_inplace(t: torch.Tensor, max_norm: float = 1e3):
    nrm = torch.linalg.norm(t)
    if torch.isfinite(nrm) and nrm > max_norm:
        t.mul_(max_norm / (nrm + 1e-12))


                                              
class SpectralBlockPerpRGDRegressorFast:
    """
    顺序主成分（components）版本：在原有 B/A/W 框架上，使 B 可训练，并按 p=1..K 的顺序逐组学习，
    保证同模不同组 B 的两两正交；每个阶段与 A（所有已有组）和 W 联合优化。

    兼容原版大部分超参与 W 的低秩结构（'full'|'cp'|'tucker'|'tt'）。

    新增关键参数：
    - learn_B: 是否训练 B（若 False 则沿用原版 DFT block，提供兼容路径）
    - sequential: 是否启用“按组件顺序”的外层循环（强烈建议 True）
    - step_size_B: B 的步长（Riemannian on Stiefel, 带对先前子空间的正交投影）
    - reg_B_l2:   B 的 L2 正则（作用在欧氏梯度上，再做黎曼投影）
    - init_W_zero: 是否将 W 初始化为全 0（full/部分参数化可用；TT/CP/Tucker 建议保留随机）

    训练入口：
    - fit_sequential(X, y): 实现题述三步到 K 的完整流程；
      * 第 p 组时，固定 {B_{i,<p}}，仅更新 {B_{i,p}}、所有 {A_{i,≤p}} 与 W。
      * 提供早停/容忍度，与原版一致。

    预测/评分接口保持不变。
    """

    def __init__(
        self,
        block_sizes: List[int],
        K: int,
        weight_type: str = 'full',           # 'full' | 'cp' | 'tucker' | 'tt'
        ranks: Optional[Dict[str, int]] = None,
        reg_W: float = 1e-3,
        step_size_A: float = 5e-4,
        step_size_W: Optional[float] = 5e-4,
        step_size_G: Optional[float] = None,
        warmup_A: int = 10,
        n_iter_max: int = 300,
        tol: float = 1e-6,
        basis_fn=dft_basis,
        device: Optional[str] = None,
        dtype: torch.dtype = torch.float64,
        verbose: int = 1,
        logger: Optional[object] = None,
        random_state: Optional[int] = None,
        early_stopping: bool = True,
        patience: int = 50,
        min_delta: float = 0.0,
        headinit_multi: float = 1e-2,
        # --- solver / structure options ---
        if_rgd: bool = True,
        tt_init: str = 'random',             # 'random' | 'spectral'
        forward_flat: bool = False,
        B_solver: str = 'perp_rgd',          # 'perp_rgd' | 'rgd' | 'gd' | 'als'
        tt_solver: str = 'rgd',              # 'rgd' | 'projected' | 'iht'
        A_solver: str = 'rgd',               # 'rgd' | 'gd' | 'als'
        W_full_solver: str = 'gd',           # 'gd' | 'ls' | 'als' | 'rgd'
        W_cp_solver: str = 'gd',             # 'gd' | 'als'
        n_jobs: int = 0,                     # stored as self.n_jobs; its consumer is outside this chunk
        ridge_eps: float = 1e-8,             # stored as self.ridge_eps; presumably a diagonal ridge for ALS solves
        W_tucker_solver: str = 'rgd',        # 'als' | 'rgd'
        tucker_ranks: Optional[List[int]] = None,
        als_reg_A: float = 1e-6,
        perB: bool = True,                   # True: A keyed by (combo k, mode i); False: by (mode i, component p)
        init_A_method: str = 'eye',          # 'eye' | 'hosvd_hooi' | 'random'
        init_A_hooi_iters: int = 5,
        init_B_method: str = 'random',       # 'dft' | 'random'
        basis_dims: Optional[object] = None, # None | int | List[Optional[int]]
        reg_A_l1: float = 0.0,
        reg_A_l2: float = 0.0,
        # --- sequential / learnable-B options ---
        learn_B: bool = True,
        sequential: bool = True,
        step_size_B: Optional[float] = None, # defaults to step_size_A
        reg_B_l2: float = 0.0,
        init_W_zero: bool = False,
        B_diag: bool = False,
        gdals_iters: int = 1,
        if_all_learn_A: bool = True,
        joint_B_ortho: str = 'joint_stiefel',  # 'joint_stiefel' | 'block_only' | 'perp_seq'
        joint_iters: Optional[int] = None,
        combo_mode: str = 'diag',             # 'diag' | 'bfs' | 'diag_skew' | 'diag_skew_parallel'
        diag_skew_parallel_dist: float = 1.0,
        use_nan_to_num: bool = True,
        print_ab_stats: bool = False,
        print_grad_stats: bool = False,
        optimizer: str = 'gd',                # 'gd' | 'adam'
        adam_betas: Tuple[float, float] = (0.9, 0.999),
        adam_eps: float = 1e-8,
        fixed_diag_max_k: bool = False,
    ):
        """Store and validate hyper-parameters; fitted state starts empty.

        ``K`` may be a single int (the same number of components for every
        mode; numpy integer scalars are accepted too) or a list/tuple with one
        count per mode. ``self.K`` is the maximum of the per-mode counts.
        Solver names are lower-cased and validated with ``assert``. Solver
        banners are printed only when ``verbose`` is truthy.
        """
        # --- device / randomness ---
        self.device = 'cpu' if device is None else device
        self.dtype = dtype
        self.random_state = random_state
        if random_state is not None:
            # Generator is created on self.device so it can be passed to torch.randn(device=self.device).
            self.rng = torch.Generator(device=self.device)
            self.rng.manual_seed(random_state)
        else:
            self.rng = None

        # --- shapes and per-mode component counts ---
        self.block_sizes = list(block_sizes)
        n_modes = len(self.block_sizes)
        if isinstance(K, (int, np.integer)):
            self.K_list_ = [int(K)] * n_modes
        else:
            assert isinstance(K, (list, tuple)) and len(K) == n_modes,\
                "K must be an int or a list of ints with length equal to the number of modes."
            self.K_list_ = [int(k_i) for k_i in K]

        # self.K is the largest per-mode component count (K_max); the mode
        # attaining it is the "base" mode used by the diag_skew combos.
        self.K = int(np.max(self.K_list_))
        self.K_max_ = self.K
        self.K_base_mode_ = int(np.argmax(self.K_list_))

        self.weight_type = weight_type.lower()
        self.ranks = ranks or {}
        self.reg_W = float(reg_W)
        self.step_size_A = float(step_size_A)
        self.step_size_W = None if step_size_W is None else float(step_size_W)
        # step_size_G falls back to step_size_W; it is None only when both are None.
        self.step_size_G = None if step_size_G is None and step_size_W is None else float(step_size_W if step_size_G is None else step_size_G)
        self.warmup_A = int(warmup_A)
        self.n_iter_max = int(n_iter_max)
        self.tol = float(tol)
        self.basis_fn = basis_fn
        self.verbose = verbose
        self.logger = logger
        self.headinit_multi = headinit_multi

        # --- sequential / learnable-B options ---
        self.learn_B = bool(learn_B)
        self.sequential = bool(sequential)
        self.step_size_B = float(step_size_A if step_size_B is None else step_size_B)
        self.reg_B_l2 = float(reg_B_l2)
        self.init_W_zero = bool(init_W_zero)

        # --- fitted state (filled during training) ---
        self.blocks_: Dict[int, List[torch.Tensor]] = {}
        self.combos_: List[Tuple[int, ...]] = []
        self.A_: Dict[Tuple[int, int], torch.Tensor] = {}      # if perB=False: (i, p), if perB=True: (k, i) where k is combo_index
        self.B_: Dict[Tuple[int, int], torch.Tensor] = {}      # (i, p)
        self.W_full_: Optional[torch.Tensor] = None
        self.W_cp_: Optional[Dict[str, torch.Tensor]] = None
        self.W_tucker_: Optional[Dict[str, torch.Tensor]] = None
        self.W_tt_: Optional[List[torch.Tensor]] = None

        self.in_shape_: Optional[Tuple[int, ...]] = None
        self.out_shape_: Optional[Tuple[int, ...]] = None
        self.losses_: List[float] = []
        self.metrics_: List[dict] = []
        self.history_: dict = {}

        # --- early stopping ---
        self.early_stopping = early_stopping
        self.patience = patience
        self.best_loss = float('inf')
        self.best_epoch = 0
        self.no_improve_count = 0
        self.min_delta = float(min_delta)

        # --- solver selection ---
        self.if_rgd = if_rgd
        self.tt_init = tt_init.lower()
        assert self.tt_init in ('spectral', 'random')
        self.forward_flat = bool(forward_flat)
        self.tt_solver = tt_solver.lower()
        assert self.tt_solver in ('rgd', 'projected', 'iht')

        self.A_solver = A_solver.lower(); assert self.A_solver in ('rgd', 'gd', 'als')
        self.W_full_solver = W_full_solver.lower(); assert self.W_full_solver in ('gd', 'ls',"als","rgd")
        self.W_cp_solver = W_cp_solver.lower(); assert self.W_cp_solver in ('gd','als')
        self.n_jobs = int(n_jobs)
        self.ridge_eps = float(ridge_eps)
        self.W_tucker_solver = W_tucker_solver.lower(); assert self.W_tucker_solver in ('als','rgd')
        self.tucker_ranks = tucker_ranks

        self.als_reg_A = float(als_reg_A)
        self.reg_A_l2 = float(reg_A_l2)
        self.reg_A_l1 = float(reg_A_l1)
        self.perB = bool(perB)

        self.init_A_method = init_A_method.lower(); assert self.init_A_method in ('eye', 'hosvd_hooi', 'random')
        self.init_A_hooi_iters = int(init_A_hooi_iters)
        self.init_B_method = init_B_method.lower(); assert self.init_B_method in ('dft', 'random')
        # Solver banners respect `verbose`, like every other message of this class.
        if self.verbose:
            if self.A_solver == 'rgd':
                print('A update: Riemannian GD (orthogonal)')
            elif self.A_solver == 'gd':
                print('A update: Euclidean GD')
            else:
                print('A update: ALS (least squares, unconstrained)')

            if self.weight_type == 'full':
                print(f"W(full) update: {self.W_full_solver.upper()}")
        self.B_solver = B_solver.lower(); assert self.B_solver in ('perp_rgd', 'rgd', 'gd', 'als')
        # Print the B banner once (it used to be emitted twice).
        if self.verbose:
            print(f"B update: {self.B_solver.upper()}")
            if self.B_solver == 'als':
                print("B update: ALS-like Orthogonal Projection (SVD-based)")
        self.B_diag = B_diag
        self.gdals_iters = gdals_iters
        self.if_all_learn_A = if_all_learn_A

        self.joint_B_ortho = joint_B_ortho.lower()
        assert self.joint_B_ortho in ('joint_stiefel', 'block_only', 'perp_seq')
        self.joint_iters = int(joint_iters) if joint_iters is not None else None
        self.combo_mode = combo_mode.lower()
        assert self.combo_mode in ('diag', 'bfs', 'diag_skew', 'diag_skew_parallel')
        self.diag_skew_parallel_dist = float(diag_skew_parallel_dist)

        self.use_nan_to_num = bool(use_nan_to_num)
        self.print_ab_stats = bool(print_ab_stats)
        self.basis_dims = basis_dims
        self.print_grad_stats = print_grad_stats

        self.optimizer = optimizer.lower()
        assert self.optimizer in ('gd', 'adam'), "Optimizer must be 'gd' or 'adam'"
        self.adam_betas = adam_betas
        self.adam_eps = adam_eps
        self.fixed_diag_max_k = bool(fixed_diag_max_k)

        # Per-parameter Adam state (first moment, second moment, step count), keyed by parameter name.
        self.adam_m_ = {}
        self.adam_v_ = {}
        self.adam_t_ = {}

        if self.verbose:
            self._log(f"Optimizer selected: {self.optimizer.upper()}")
            

                                               
    def _log(self, msg: str):
        """Send ``msg`` to ``self.logger.info`` if a logger is set; otherwise print it when verbose.

        If the logger raises, the message falls back to the verbose print path.
        """
        logger = self.logger
        if logger is not None:
            try:
                logger.info(msg)
            except Exception:
                pass  # fall through to the print fallback below
            else:
                return
        if self.verbose:
            print(msg)

    def _log_grad_stats(self, grad_tensor: torch.Tensor, name: str, context: str):
        """Log max/min/mean/median of a gradient tensor when ``print_grad_stats`` is enabled.

        Does nothing when the flag is off or the tensor is None; emits a warning
        line instead of statistics for empty or non-finite tensors.
        """
        if grad_tensor is None or not self.print_grad_stats:
            return

        is_empty = grad_tensor.numel() == 0
        if is_empty or not torch.all(torch.isfinite(grad_tensor)):
            self._log(f"[{context}] Grad Stats for {name}: Tensor is empty or contains non-finite values.")
            return

        g_max = torch.max(grad_tensor).item()
        g_min = torch.min(grad_tensor).item()
        g_mean = torch.mean(grad_tensor).item()
        g_median = torch.median(grad_tensor).item()
        self._log(f"[{context}] Grad Stats for {name}: "
                  f"Max={g_max:.6g}, Min={g_min:.6g}, "
                  f"Mean={g_mean:.6g}, Median={g_median:.6g}")
    def _print_AB_stats(self):
        """Print diagnostic statistics for W, every A and every B.

        For each matrix M it reports max/min entries, per-column norms and the
        max/min of ``M.T @ M - I`` (deviation from orthonormal columns).
        A-matrices are looked up with the key layout used by ``self.A_``:
        ``(i, p)`` (mode, component) when ``perB`` is False and ``(k, i)``
        (combo index, mode) when ``perB`` is True. When ``self.K > 1``, the
        column-concatenated B blocks of each mode are also checked, which
        exposes non-orthogonality between different components.
        """

        def _fmt_norms(t: torch.Tensor) -> str:
            vals = torch.linalg.norm(t, dim=0).detach().cpu().numpy()
            return "[" + ", ".join(f"{v:.6g}" for v in vals) + "]"

        def _report(tag: str, sym: str, M: Optional[torch.Tensor]) -> None:
            # Value range, column norms and orthonormality residual of one matrix.
            if M is None:
                print(f"[{tag}] (missing)")
                return
            eye = torch.eye(M.shape[1], device=M.device, dtype=M.dtype)
            resid = M.T @ M - eye
            m_max = float(M.max().item()); m_min = float(M.min().item())
            r_max = float(resid.max().item()); r_min = float(resid.min().item())
            print(f"[{tag}] max={m_max:.6g}  min={m_min:.6g}  ||cols||={_fmt_norms(M)}, shape = {M.shape}")
            print(f"[{tag} | {sym}.T@{sym} - I] max={r_max:.6g}  min={r_min:.6g}")

        print("\n" + "="*20 + " Final Parameter Statistics " + "="*20)

        # Full predictor weight (best effort: the weight may not be available yet).
        try:
            W_full = self._predictor_weight_full()
            w_max = float(W_full.max().item())
            w_min = float(W_full.min().item())
            print(f"[W full] max={w_max:.6g}  min={w_min:.6g}  shape={tuple(W_full.shape)}")
        except Exception as e:
            print(f"[W full] Could not compute stats. Error: {e}")

        n_modes = len(self.block_sizes)
        for i in range(n_modes):
            for p in range(1, self.K_list_[i] + 1):
                if not self.perB:
                    _report(f"A i={i} p={p}", "A", self.A_.get((i, p), None))
                _report(f"B i={i} p={p}", "B", self.B_.get((i, p), None))

        if self.perB:
            # perB=True stores one A per (combo index k, mode i), not per (i, p).
            for k, combo in enumerate(self.combos_):
                for i in range(n_modes):
                    _report(f"A k={k} combo={combo} i={i}", "A", self.A_.get((k, i), None))

        # Orthogonality across all components of the same mode.
        if self.K > 1:
            print("\n--- Inter-Block Orthogonality Check for B ---")
            for i in range(n_modes):
                try:
                    B_cat = self._concat_B_mode(i)
                    I_cat = torch.eye(B_cat.shape[1], device=B_cat.device, dtype=B_cat.dtype)
                    B_cat_I = B_cat.T @ B_cat - I_cat
                    b_cat_max_I = float(B_cat_I.max().item())
                    b_cat_min_I = float(B_cat_I.min().item())
                    # Off-diagonal blocks near zero mean B_{i,p} are mutually orthogonal.
                    print(f"[B_cat i={i} | B_cat.T@B_cat - I] max={b_cat_max_I:.6g}  min={b_cat_min_I:.6g}  shape={tuple(B_cat.shape)}")
                except Exception as e:
                    print(f"[B_cat i={i}] Could not perform inter-block check. Error: {e}")

        print("="*64 + "\n")

    def _get_adam_step(self, param_key: str, grad: torch.Tensor, step_size: float) -> torch.Tensor:
        """Compute the update to apply for one parameter.

        With ``optimizer='gd'`` this is simply ``step_size * grad``. With
        ``optimizer='adam'`` it is the bias-corrected Adam step
        ``step_size * m_hat / (sqrt(v_hat) + eps)``. The moment estimates and
        step counter are kept per ``param_key`` in ``self.adam_m_``,
        ``self.adam_v_`` and ``self.adam_t_``.

        If the stored moments no longer match ``grad`` (shape, dtype or
        device — e.g. a key reused after the parameter was resized), the state
        for that key is restarted from zero instead of being combined with the
        incompatible gradient.

        Args:
            param_key (str): unique parameter identifier (e.g. 'W_full', 'A_0_1').
            grad (torch.Tensor): current gradient of that parameter.
            step_size (float): base learning rate.

        Returns:
            torch.Tensor: the update (same shape as ``grad``).
        """
        if self.optimizer != 'adam':
            return step_size * grad

        beta1, beta2 = self.adam_betas
        eps = self.adam_eps

        m = self.adam_m_.get(param_key)
        v = self.adam_v_.get(param_key)
        t = self.adam_t_.get(param_key, 0)

        def _compatible(state: Optional[torch.Tensor]) -> bool:
            return (state is not None and state.shape == grad.shape
                    and state.dtype == grad.dtype and state.device == grad.device)

        if not (_compatible(m) and _compatible(v)):
            # Fresh (or reset) state: zeros allocated only when actually needed.
            m = torch.zeros_like(grad)
            v = torch.zeros_like(grad)
            t = 0
        t += 1

        # Exponential moving averages of the gradient and squared gradient.
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * grad.pow(2)

        # Bias correction for the zero initialisation of m and v.
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)

        update = step_size * m_hat / (v_hat.sqrt() + eps)

        self.adam_m_[param_key] = m
        self.adam_v_[param_key] = v
        self.adam_t_[param_key] = t

        return update
                                                  
    def _rand_orth(self, rows: int, cols: int) -> torch.Tensor:
        """Draw a random (rows, cols) matrix with orthonormal columns.

        Gaussian noise (from ``self.rng`` when seeded) is orthonormalised by a
        reduced QR whose column signs are fixed so that diag(R) is non-negative.
        """
        noise = torch.randn(rows, cols, device=self.device, dtype=self.dtype, generator=self.rng)
        Q, R = torch.linalg.qr(noise, mode='reduced')
        signs = torch.sign(torch.diag(R))
        signs[signs == 0] = 1.0
        return Q @ torch.diag(signs)

    def _rand_orth_in_complement(self, rows: int, cols: int, Q_prev: Optional[torch.Tensor]) -> torch.Tensor:
        """Draw a random (rows, cols) orthonormal matrix in the orthogonal complement of span(Q_prev).

        Falls back to ``_rand_orth`` when ``Q_prev`` is None or empty. ``Q_prev``
        is assumed to have orthonormal columns. The seeded generator is used on
        every device (as in ``_rand_orth``) so ``random_state`` is reproducible
        on GPU too.

        NOTE(review): the complement must have dimension >= ``cols``
        (i.e. rows - Q_prev.shape[1] >= cols); otherwise the projected noise is
        rank-deficient and the result cannot be fully orthogonal to Q_prev.
        """
        if Q_prev is None or Q_prev.numel() == 0:
            return self._rand_orth(rows, cols)
        M = torch.randn(rows, cols, device=self.device, dtype=self.dtype, generator=self.rng)
        # Project out span(Q_prev) twice: a single classical Gram-Schmidt pass can
        # leave residual components in span(Q_prev) due to rounding; a second pass
        # restores orthogonality to working precision.
        M = M - Q_prev @ (Q_prev.T @ M)
        M = M - Q_prev @ (Q_prev.T @ M)
        Q, R = torch.linalg.qr(M, mode='reduced')
        s = torch.sign(torch.diag(R)); s[s == 0] = 1.0
        return Q @ torch.diag(s)

    def _project_to_complement(self, B: torch.Tensor, Q_prev: Optional[torch.Tensor]) -> torch.Tensor:
        """Return ``B - Q_prev @ (Q_prev.T @ B)``, or ``B`` itself when Q_prev is None/empty."""
        has_basis = Q_prev is not None and Q_prev.numel() != 0
        return B - Q_prev @ (Q_prev.T @ B) if has_basis else B

    def _concat_B_mode(self, i: int, order_pp: Optional[List[int]] = None) -> torch.Tensor:
        """Concatenate mode i's B blocks column-wise.

        Uses components ``order_pp`` (default 1..K_list_[i]) in that order, giving
        a (d_i, len(order_pp) * b_i) matrix. Raises KeyError for a missing block.
        """
        components = range(1, self.K_list_[i] + 1) if order_pp is None else order_pp
        return torch.cat([self.B_[(i, p)] for p in components], dim=1)

    def _split_B_mode(self, i: int, B_cat: torch.Tensor, order_pp: Optional[List[int]] = None):
        """Inverse of ``_concat_B_mode``: slice B_cat into b_i-wide column blocks and store them.

        Block number j (0-based) of ``B_cat`` is written, as a contiguous copy,
        to ``self.B_[(i, order_pp[j])]``; ``order_pp`` defaults to 1..K_list_[i].
        """
        components = list(range(1, self.K_list_[i] + 1)) if order_pp is None else order_pp
        width = int(self.block_sizes[i])
        for j, p in enumerate(components):
            start = j * width
            self.B_[(i, p)] = B_cat[:, start:start + width].contiguous()

                                              
    def _prepare_shapes_and_postdims(self, X: torch.Tensor, basis_dims) -> List[int]:
        """Resolve per-mode post dims r_i and return the input dims (d_1, ..., d_n).

        ``basis_dims`` may be None (r_i = b_i), an int (the same r for every mode)
        or a per-mode list whose None entries mean r_i = b_i. Each r_i is clamped
        to [1, b_i] and stored in ``self.post_dims_``.
        """
        n_modes = X.ndim - 1
        assert n_modes == len(self.block_sizes)
        sizes = [int(b) for b in self.block_sizes]

        if basis_dims is None:
            post_dims = list(sizes)
        elif isinstance(basis_dims, int):
            requested = max(1, int(basis_dims))
            post_dims = [min(b, requested) for b in sizes]
        else:
            assert isinstance(basis_dims, (list, tuple)) and len(basis_dims) == n_modes
            post_dims = [min(b, max(1, b if r is None else int(r))) for b, r in zip(sizes, basis_dims)]

        assert all(r <= b for r, b in zip(post_dims, sizes))
        self.post_dims_ = post_dims
        return list(X.shape[1:])

    def _make_diag_combos(self, n_modes: int):
        """Build diagonal combos (p, p, ..., p) for p = 1..K and their per-mode index groups.

        ``self._k_groups[i][p]`` holds, as a long tensor, the combo index (p - 1)
        that uses component p on mode i.
        """
        components = range(1, self.K + 1)
        self.combos_ = [(p,) * n_modes for p in components]
        self._k_groups = {
            i: {p: torch.as_tensor([p - 1], device=self.device, dtype=torch.long) for p in components}
            for i in range(n_modes)
        }

    def _make_bfs_layer(self, p: int, n_modes: int):
        """Return all combos with every k_i in {1..p} and max(k_i) == p, in lexicographic order."""
        from itertools import product
        return [combo for combo in product(range(1, p + 1), repeat=n_modes) if max(combo) == p]

    def _make_bfs_combos_up_to(self, p: int, n_modes: int):
        """Return every combo with each k_i in {1..p} (p ** n_modes tuples, lexicographic order)."""
        from itertools import product
        levels = range(1, p + 1)
        return list(product(levels, repeat=n_modes))

    def _make_diag_skew_combos(self, n_modes: int):
        """Build 'diag_skew' combos by proportionally mapping the base mode's index to the other modes.

        The base mode is the one with the largest K in the mapping list
        (``K_max_possible_list_`` if ``fixed_diag_max_k`` else ``K_list_``). For
        each base index p, mode i gets (p - 1) * (K_i - 1) / (K_base - 1) + 1,
        rounded half-down via ceil(x - 0.5) and clamped to >= 1.
        """
        if self.fixed_diag_max_k:
            k_map = self.K_max_possible_list_
            self._log(f"Using fixed diagonal based on max possible K: {k_map}")
        else:
            k_map = self.K_list_

        assert len(k_map) == n_modes, "K list for mapping length mismatch."

        base = int(np.argmax(k_map))
        k_base = k_map[base]

        def _mapped_index(p_base, k_target):
            if k_base == 1:
                position = 1.0
            else:
                position = (p_base - 1) * (k_target - 1) / (k_base - 1) + 1
            # Round half down: ceil(x - 0.5), evaluated in self.dtype.
            rounded = torch.ceil(torch.tensor(position - 0.5, dtype=self.dtype)).int().item()
            return max(1, rounded)

        self.combos_ = [
            tuple(p_base if i == base else _mapped_index(p_base, k_map[i]) for i in range(n_modes))
            for p_base in range(1, k_base + 1)
        ]

    def _make_diag_skew_parallel_combos(self, n_modes: int):
        """Select combos close to the line through (1, ..., 1) and (K_1, ..., K_n).

        Every grid point (p_1, ..., p_n) with 1 <= p_i <= K_list_[i] is kept when
        its perpendicular distance to that line is <= ``diag_skew_parallel_dist``.
        If the line degenerates to a point (all K_i == 1) only (1, ..., 1) is kept.
        """
        from itertools import product

        start = torch.ones(n_modes, device=self.device, dtype=self.dtype)
        end = torch.tensor(self.K_list_, device=self.device, dtype=self.dtype)
        direction = end - start
        dir_sq = torch.dot(direction, direction)
        if dir_sq < 1e-12:
            self.combos_ = [tuple(map(int, start.tolist()))]
            self._log(f"diag_skew_parallel: Line degenerated to a point. Selected combo: {self.combos_}")
            return

        def _distance_to_line(combo):
            point = torch.tensor(combo, device=self.device, dtype=self.dtype)
            offset = point - start
            t = torch.dot(offset, direction) / dir_sq
            foot = start + t * direction  # orthogonal projection of point onto the line
            return torch.linalg.norm(point - foot)

        grid = product(*[range(1, k + 1) for k in self.K_list_])
        self.combos_ = [combo for combo in grid if _distance_to_line(combo) <= self.diag_skew_parallel_dist]
        self._log(f"diag_skew_parallel: Selected {len(self.combos_)} combos with distance <= {self.diag_skew_parallel_dist}")

                                                    
    def _build_Xcore_tensor(self, X: torch.Tensor, p_max: Optional[int] = None) -> torch.Tensor:
        """Build the stacked core tensor Xc, one slice per combo in ``self.combos_``.

        For combo k, the core is X multiplied along every mode i by
        ``B_[(i, combo[i])] @ A``, where A is ``A_[(i, combo[i])]`` when
        ``perB`` is False and ``A_[(k, i)]`` when ``perB`` is True (the same key
        layout as ``_grad_M_for``). A combo's core is all zeros when it is not
        active yet — its stage (the base-mode index for 'diag_skew', max(combo)
        otherwise) exceeds ``p_max`` — or when any required B/A is not flagged
        in ``B_exists_`` / ``A_exists_``.

        Args:
            X: batched input, assumed (N, d_1, ..., d_n).
            p_max: largest active component index; defaults to ``self.K``.

        Returns:
            Tensor of shape (N, r_1, ..., r_n, len(combos_)) with r_i = post_dims_[i].
        """
        if p_max is None:
            p_max = self.K  # all components active

        n = X.ndim - 1
        zero_core_shape = (X.shape[0],) + tuple(self.post_dims_)

        def _zero_core() -> torch.Tensor:
            return torch.zeros(zero_core_shape, device=self.device, dtype=self.dtype)

        cores = []
        for k, combo in enumerate(self.combos_):
            # Activation stage of the combo.
            if self.combo_mode == 'diag_skew':
                stage_p = combo[self.K_base_mode_]
            else:
                stage_p = max(combo)
            if stage_p > p_max:
                cores.append(_zero_core())
                continue

            # The A key depends on perB; check exactly the key that is read below.
            a_keys = [(k, i) if self.perB else (i, combo[i]) for i in range(n)]
            ready = all(
                self.B_exists_.get((i, combo[i]), False) and self.A_exists_.get(a_keys[i], False)
                for i in range(n)
            )
            if not ready:
                cores.append(_zero_core())
                continue

            T = X
            for i in range(n):
                BA = self.B_[(i, combo[i])] @ self.A_[a_keys[i]]  # (d_i, r_i)
                T = mode_n_product_batch(T, BA, i)
            cores.append(T)

        return torch.stack(cores, dim=n + 1)  # (N, r1,...,rn, K_eff)

    def _ensure_W_capacity_full(self, Xc_shape: Tuple[int, ...]):
        """Make ``W_full_`` match the input layout of the current core tensor Xc.

        ``W_full_`` has shape (r_1, ..., r_n, K_eff, o_1, ..., o_m), taken from
        ``Xc_shape[1:n+2]`` and ``self.out_shape_``. On first use it is created
        (zeros if ``init_W_zero``, else Gaussian noise scaled by
        ``headinit_multi``). Later calls may only grow the combination axis
        K_eff: new slices are appended with the same rule, existing weights kept.

        Raises:
            AssertionError: if the post dims or output shape changed, or K_eff shrank.
        """
        n = len(self.in_shape_)  # number of input modes
        desired_shape = list(Xc_shape[1:n + 2]) + list(self.out_shape_)  # (r1..rn, K_eff, o1..om)

        def _fresh(shape) -> torch.Tensor:
            # The seeded generator is used on every device (as in _rand_orth), so
            # random_state also makes non-CPU initialisation reproducible.
            if self.init_W_zero:
                return torch.zeros(shape, device=self.device, dtype=self.dtype)
            return torch.randn(shape, device=self.device, dtype=self.dtype, generator=self.rng) * self.headinit_multi

        if self.W_full_ is None:
            in_prod = int(np.prod(desired_shape[:n + 1]))
            out_prod = int(np.prod(self.out_shape_))
            self.W_full_ = _fresh((in_prod, out_prod)).reshape(desired_shape)
            return

        cur = list(self.W_full_.shape)
        if cur == desired_shape:
            return

        # Only the combination axis (index n) may change, and only by growing.
        assert cur[:n] == desired_shape[:n], f"post_dims 不能在训练中改变: {cur[:n]} vs {desired_shape[:n]}"
        assert cur[n+1:] == desired_shape[n+1:], f"输出形状不能改变: {cur[n+1:]} vs {desired_shape[n+1:]}"

        cur_K = cur[n]
        new_K = desired_shape[n]
        assert new_K >= cur_K, "不支持缩小组合维"

        if new_K == cur_K:
            return

        pad_shape = cur.copy()
        pad_shape[n] = new_K - cur_K
        self.W_full_ = torch.cat([self.W_full_, _fresh(pad_shape)], dim=n)


                                                                   
    def _unfold_batch_mode(self, T: torch.Tensor, mode: int) -> torch.Tensor:
        """Batched mode-``mode`` unfolding: (N, ..., d_mode, ...) -> (N, d_mode, prod(other dims))."""
        moved = T.movedim(mode + 1, -2).contiguous()
        batch, rows = moved.shape[0], moved.shape[-2]
        return moved.reshape(batch, rows, -1).contiguous()

    def _delta_all(self, Xc: torch.Tensor, Y: torch.Tensor):
        """Back-project the prediction residual through W onto the core space.

        Computes ``Yhat = _predict_given_W(Xc)`` and the residual ``Yhat - Y``,
        contracts the residual's output axes with the output axes of the full
        predictor weight, and divides by the number of output entries.

        Returns:
            (D_all, Yhat): D_all has the layout of Xc without... i.e. (N, r_1..r_n, K_eff)
            assuming W_full is (r_1..r_n, K_eff, o_1..o_m) — TODO confirm; Yhat is the prediction.
        """
        with torch.no_grad():
            Yhat = self._predict_given_W(Xc)
        residual = Yhat - Y
        with torch.no_grad():
            W_full = self._predictor_weight_full()

        n_in = Xc.ndim - 2
        n_out = len(self.out_shape_)
        residual_axes = list(range(1, 1 + n_out))
        weight_axes = list(range(n_in + 1, n_in + 1 + n_out))
        D_all = torch.tensordot(residual, W_full, dims=(residual_axes, weight_axes))
        scale = max(1, int(np.prod(self.out_shape_)))
        return D_all / scale, Yhat

    def _grad_M_for_combo(self, X: torch.Tensor, D_all: torch.Tensor, k: int, i: int) -> torch.Tensor:
        """Gradient dL/dM_{k,i} for a single combo k and mode i (used only when perB=True).

        X is multiplied along every mode m != i by ``B_[(m, combo[m])] @ A_[(k, m)]``;
        the result is unfolded along mode i and contracted with the mode-i unfolding
        of combo k's slice of D_all, then averaged over the batch. Returns a
        (d_i, r_i) zero matrix if any needed B/A is missing.
        """
        n = X.ndim - 1
        combo = self.combos_[k]

        # Residual slice belonging to combo k, keeping a trailing singleton axis.
        D_k = D_all.select(-1, k).unsqueeze(-1)  # (N, r1..rn, 1)

        T_c = X
        for m in range(n):
            if m == i:
                continue
            pp_m = combo[m]
            if not (self.B_exists_.get((m, pp_m), False) and self.A_exists_.get((k, m), False)):
                d_i, r_i = self.in_shape_[i], self.post_dims_[i]
                return torch.zeros(d_i, r_i, device=self.device, dtype=self.dtype)
            T_c = mode_n_product_batch(T_c, self.B_[(m, pp_m)] @ self.A_[(k, m)], m)

        V_c = self._unfold_batch_mode(T_c, i)  # (N, d_i, rest_r)
        U_c = self._unfold_batch_mode(D_k, i)  # (N, r_i, rest_r)
        grad = torch.einsum('sdr, skr -> dk', V_c, U_c)
        return grad / max(1, X.shape[0])

    def _grad_M_for(self, X: torch.Tensor, D_all: torch.Tensor, p: int, i: int) -> torch.Tensor:
        """Gradient w.r.t. the mode-i factor M = B_{i,p} A for component p, summed over its combos.

        For each combo index listed in ``self._k_groups[i][p]``, X is multiplied
        along every other mode m by ``B_[(m, combo[m])] @ A`` (A keyed by
        (combo, m) when ``perB`` else by (m, combo[m])); combos with any missing
        B/A are skipped. The mode-i unfoldings of that product and of the combo's
        D_all slice are contracted and accumulated, then divided by the batch size.

        Returns:
            (d_i, r_i) tensor; zeros when no combo is listed for (i, p).
        """
        n = X.ndim - 1
        d_i = int(self.in_shape_[i])
        r_i = int(self.post_dims_[i])

        combo_indices = self._k_groups[i].get(p)
        if combo_indices is None or combo_indices.numel() == 0:
            return torch.zeros(d_i, r_i, device=self.device, dtype=self.dtype)

        def _a_key(k: int, m: int, p_m: int):
            # A is keyed by (combo, mode) when perB, else by (mode, component).
            return (k, m) if self.perB else (m, p_m)

        GiM_total = torch.zeros(d_i, r_i, device=self.device, dtype=self.dtype)
        other_modes = [m for m in range(n) if m != i]

        for combo_idx in combo_indices:
            k = combo_idx.item()
            combo = self.combos_[k]

            T_c = X
            usable = True
            for m in other_modes:
                p_m = combo[m]
                a_key = _a_key(k, m, p_m)
                if not (self.B_exists_.get((m, p_m), False) and self.A_exists_.get(a_key, False)):
                    usable = False
                    break
                T_c = mode_n_product_batch(T_c, self.B_[(m, p_m)] @ self.A_[a_key], m)
            if not usable:
                continue

            V_c = self._unfold_batch_mode(T_c, i)                 # (N, d_i, rest_r)
            D_c = D_all.index_select(-1, combo_idx.view(1))       # (N, r1..rn, 1)
            U_c = self._unfold_batch_mode(D_c, i)                 # (N, r_i, rest_r)
            GiM_total += torch.einsum('sdr, skr -> dk', V_c, U_c)

        return GiM_total / max(1, X.shape[0])

    def _grad_B_mode_all(self, X: torch.Tensor, D_all: torch.Tensor, i: int) -> torch.Tensor:
        """Euclidean gradients w.r.t. every ``B_[(i, p)]`` of mode ``i``, concatenated by columns.

        Groups ``p = 1..K`` with ``p > K_list_[i]`` are skipped. With ``perB``,
        each group's gradient sums ``_grad_M_for_combo(...) @ A_[(k, i)].T``
        over the combos in ``_k_groups[i][p]``. Otherwise it is
        ``_grad_M_for(...) @ A_[(i, p)].T``. An optional L2 term
        ``reg_B_l2 * B`` is added and each block is norm-clipped in place to 1e3.

        Returns:
            Tensor ``(d_i, sum of the kept b_i)``.
        """
        blocks = []
        for p in range(1, self.K + 1):
            if p > self.K_list_[i]:
                continue

            if self.perB:
                d_i, b_i = self.B_[(i, p)].shape
                grad = torch.zeros(d_i, b_i, device=self.device, dtype=self.dtype)
                members = self._k_groups[i].get(p)
                if members is not None and members.numel() > 0:
                    for k in (t.item() for t in members):
                        grad += self._grad_M_for_combo(X, D_all, k, i) @ self.A_[(k, i)].T
            else:
                grad = self._grad_M_for(X, D_all, p, i) @ self.A_[(i, p)].T

            if self.reg_B_l2 > 0:
                grad = grad + self.reg_B_l2 * self.B_[(i, p)]
            _clip_inplace(grad, 1e3)
            blocks.append(grad)

        return torch.cat(blocks, dim=1)


                                                       
    def _predictor_weight_full(self) -> torch.Tensor:
        if self.weight_type == 'cp':
            return self._cp_reconstruct_full()
        if self.weight_type == 'tucker':
            return self._tucker_reconstruct_full()
        if self.weight_type == 'tt':
            return self._tt_reconstruct_full_from_cores()
        return self.W_full_

    def _predict_given_W(self, Xc: torch.Tensor) -> torch.Tensor:
        if self.weight_type == 'full':
            W_full = self.W_full_
        else:
            W_full = self._predictor_weight_full()
        n = Xc.ndim - 2
        if self.forward_flat:
            in_prod  = int(np.prod(Xc.shape[1:n+2]))
            out_prod = int(np.prod(self.out_shape_))
            Xc_flat  = Xc.reshape(Xc.shape[0], in_prod)
            W_mat    = W_full.reshape(in_prod, out_prod)
            Y_flat   = Xc_flat @ W_mat
            return Y_flat.reshape(Xc.shape[0], *self.out_shape_)
        else:
            Y = torch.tensordot(
                Xc, W_full,
                dims=(list(range(1, n + 2)), list(range(0, n + 1)))
            )
            return Y

                                                          
    def _residual_and_grad_Wfull(self, Xc: torch.Tensor, Y: torch.Tensor, context: str = "W-update"):
        Yhat = self._predict_given_W(Xc)
        R = Yhat - Y
        if self.weight_type == 'full':
            W_full = self.W_full_
        else:
            W_full = self._predictor_weight_full()
        N = Xc.shape[0]
        O = int(np.prod(self.out_shape_))
        scale = 1.0 / float(max(1, N * O))
        if self.forward_flat:
            n = Xc.ndim - 2
            in_prod  = int(np.prod(Xc.shape[1:n+2]))
            out_prod = O
            Xc_flat  = Xc.reshape(N, in_prod)
            R_flat   = R.reshape(N, out_prod)
            G_mat    = Xc_flat.T @ R_flat
            G_mat    = scale * G_mat
            G_W_full = G_mat.reshape(W_full.shape)
        else:
            G_W_full = torch.tensordot(Xc, R, dims=([0],[0]))
            G_W_full = scale * G_W_full
        
        self._log_grad_stats(G_W_full, "Grad(W_full)", context)
        
        if self.reg_W > 0 and self.weight_type in ('full','cp','tucker'):
            G_W_full = G_W_full + self.reg_W * W_full
        return R, G_W_full

    # def _step_W_full_rgd(self, G_W_full: torch.Tensor, step: float):
    # # self.W_full_ = torch.nan_to_num(self.W_full_ - step * G_W_full.to(self.W_full_.dtype))
    # updated_W = self.W_full_ - step * G_W_full.to(self.W_full_.dtype)
    # if self.use_nan_to_num:
    # updated_W = torch.nan_to_num(updated_W)
    # self.W_full_ = updated_W
    def _step_W_full_rgd(self, G_W_full: torch.Tensor, step: float):
        # 1. parameter key
        param_key = 'W_full'
        # 2. updatestep size (Adam GD)
        update_step = self._get_adam_step(param_key, G_W_full, step)
        # 3. update
        updated_W = self.W_full_ - update_step.to(self.W_full_.dtype)
        if self.use_nan_to_num:
            updated_W = torch.nan_to_num(updated_W)
        self.W_full_ = updated_W

    def _step_W_full_ls(self, Xc: torch.Tensor, Y: torch.Tensor):
        n = Xc.ndim - 2
        in_prod  = int(np.prod(Xc.shape[1:n+2]))
        out_prod = int(np.prod(self.out_shape_))
        X_flat   = Xc.reshape(Xc.shape[0], in_prod)
        Y_flat   = Y.reshape(Y.shape[0], out_prod)
        XT = X_flat.T
        gram = XT @ X_flat
        if self.reg_W > 0:
            gram = gram + self.reg_W * torch.eye(gram.shape[0], device=self.device, dtype=self.dtype)
        RHS = XT @ Y_flat
        try:
            L = torch.linalg.cholesky(gram)
            W_mat = torch.cholesky_solve(RHS, L)
        except Exception:
            W_mat = torch.linalg.solve(gram, RHS)
        self.W_full_ = W_mat.reshape(list(Xc.shape[1:n+2]) + list(self.out_shape_))

                                            
    def _cp_reconstruct_full(self) -> torch.Tensor:
        """Densify the CP weight ``(lambda, factors)`` held in ``self.W_cp_``.

        The tensorly backend is set to PyTorch first. The result is cast to
        ``self.device`` / ``self.dtype``.
        """
        factor_list = self.W_cp_['factors']
        weights = self.W_cp_['lambda']
        tl.set_backend('pytorch')
        dense = cp_to_tensor((weights, factor_list))
        return torch.as_tensor(dense, device=self.device, dtype=self.dtype)



    def _tucker_reconstruct_full(self) -> torch.Tensor:
        """Densify the Tucker weight ``(G, Us)`` held in ``self.W_tucker_``.

        The tensorly backend is set to PyTorch first. The result is cast to
        ``self.device`` / ``self.dtype``.
        """
        factor_list = self.W_tucker_['Us']
        core = self.W_tucker_['G']
        tl.set_backend('pytorch')
        dense = tucker_to_tensor((core, factor_list))
        return torch.as_tensor(dense, device=self.device, dtype=self.dtype)


    def _tt_reconstruct_full_from_cores(self) -> torch.Tensor:
        """Densify the TT weight from its core list ``self.W_tt_``.

        The core list must be a non-empty list. Unlike the CP/Tucker
        reconstructors, the ``tt_to_tensor`` result is returned without an
        explicit device/dtype cast.
        """
        cores = self.W_tt_
        assert isinstance(cores, list) and len(cores)>0
        tl.set_backend('pytorch')
        tt_weight = TTTensor(cores)
        return tt_to_tensor(tt_weight)

    def _cp_step_als(self, Xc: torch.Tensor, Y: torch.Tensor):
        """Run one full CP-ALS sweep, with a closed-form ridge solve per factor.

        Factor order: the first ``p = Xc.ndim - 1`` factors belong to the
        non-batch modes of ``Xc`` (including the K axis). The remaining factors
        belong to the output modes. CP weight shape = (r1,...,rn,K, o1,...,om).

        The CP weights ``lambda`` are folded into each sub-problem's design
        matrix. This keeps every solved factor consistent with the model
        ``W = sum_r lambda_r * U_0[:, r] o U_1[:, r] o ...`` that
        ``cp_to_tensor`` reconstructs. With all-ones ``lambda`` the update is
        unchanged. ``lambda`` itself is not modified here.

        The ridge strength is ``reg_W + ridge_eps`` (``ridge_eps`` defaults to 1e-8).
        """
        assert self.weight_type == 'cp', "CP-ALS 仅在 weight_type='cp' 时可用"

        device = self.device
        dtype = self.dtype
        eps = getattr(self, 'ridge_eps', 1e-8)  # numerical-stability ridge
        reg = float(getattr(self, 'reg_W', 0.0))
        R = int(self.W_cp_['R'])
        add_ridge = reg > 0 or eps > 0

        factors: List[torch.Tensor] = self.W_cp_['factors']
        lam: torch.Tensor = self.W_cp_.get('lambda', torch.ones(R, device=device, dtype=dtype))
        lam_row = lam.reshape(1, R)                  # broadcasts over Khatri-Rao rows
        p = Xc.ndim - 1                              # number of input modes (incl. K)
        in_dims = list(Xc.shape[1:])                 # [r1,...,rn,K]
        out_dims = list(self.out_shape_)             # [o1,...,om]
        N = int(Xc.shape[0])
        in_prod = int(np.prod(in_dims)) if in_dims else 1
        out_prod = int(np.prod(out_dims)) if out_dims else 1
        n_factors = len(factors)

        for i in range(n_factors):
            # lambda-weighted Khatri-Rao product of all other (current) factors.
            # Rows follow the C-order of the remaining modes: inputs first, then outputs.
            KR_all = khatri_rao([factors[j] for j in range(n_factors) if j != i]) * lam_row

            if i < p:
                # ---- input-side factor U_i: (I_i, R) ----
                I_i = int(in_dims[i])
                other_in_prod = int(in_prod // I_i) if I_i > 0 else 1

                # Mode-(1+i) unfolding of Xc: rows (n, x_i), cols = other input modes.
                perm = [0, 1 + i] + [ax for ax in range(1, 1 + p) if ax != 1 + i]
                X_unf = Xc.permute(perm).contiguous().reshape(N * I_i, other_in_prod)

                # Product has rows (n, x_i) and cols (j, r). Reorder to rows
                # (n, j), matching vec(Y), and cols (x_i, r), matching vec(U_i).
                phi = X_unf @ KR_all.reshape(other_in_prod, out_prod * R)
                phi = phi.reshape(N, I_i, out_prod, R).permute(0, 2, 1, 3).reshape(N * out_prod, I_i * R)
                y_vec = Y.reshape(-1)

                # Ridge regression: (Phi^T Phi + a I) vec(U_i) = Phi^T y
                AtA = phi.T @ phi
                if add_ridge:
                    AtA = AtA + (reg + eps) * torch.eye(I_i * R, device=device, dtype=dtype)
                Atb = phi.T @ y_vec
                sol = torch.linalg.solve(AtA, Atb.unsqueeze(1)).reshape(I_i, R)
            else:
                # ---- output-side factor U_i: (O_i, R) ----
                axis = i - p
                O_i = int(out_dims[axis]) if len(out_dims) > 0 else 1
                other_out_prod = int(out_prod // O_i) if O_i > 0 else 1

                Xv = Xc.reshape(N, in_prod)
                # Rows (n, j_other), cols r.
                phi = (Xv @ KR_all.reshape(in_prod, other_out_prod * R)).reshape(N * other_out_prod, R)
                # Targets with output axis last: rows (n, j_other), cols o_i.
                y_unf = Y.movedim(1 + axis, -1).reshape(-1, O_i)

                # Ridge regression: (Phi^T Phi + a I) U_i^T = Phi^T Y_unf
                AtA = phi.T @ phi
                if add_ridge:
                    AtA = AtA + (reg + eps) * torch.eye(R, device=device, dtype=dtype)
                AtB = phi.T @ y_unf                       # (R, O_i)
                sol = torch.linalg.solve(AtA, AtB).T      # (O_i, R)

            if self.use_nan_to_num:
                sol = torch.nan_to_num(sol)
            factors[i] = sol

        # lambda is kept as-is (no column normalisation / re-absorption here).
        self.W_cp_['factors'] = factors
        self.W_cp_['lambda'] = lam

    def _cp_step_gd(self, Xc: torch.Tensor, Y: torch.Tensor):
        """Take one Adam-scaled gradient step on all CP parameters.

        The full-tensor gradient ``G_W = dL/dW`` comes from
        ``_residual_and_grad_Wfull``. It is mapped to each factor by MTTKRP,
        ``dL/dU_k = unfold_k(G_W) @ (KR_{j != k} U_j * lambda)``, and to the
        weights by ``dL/dlambda = KR(U_0, ..., U_last)^T vec(G_W)``. Here
        tensorly's ``khatri_rao`` lets the first matrix vary slowest, matching
        the C-order ``vec``.

        All gradients are evaluated at the current parameters, the same point
        where ``G_W`` was computed, before any parameter is updated. The
        optional ``reg_W`` L2 term is added per factor and to lambda.
        """
        _, G_W_full = self._residual_and_grad_Wfull(Xc, Y)  # shape = (r1,...,rn,K,o1,...,om)
        factors: List[torch.Tensor] = self.W_cp_['factors']
        lam: torch.Tensor = self.W_cp_['lambda']
        step_size = self.step_size_W or 1e-3
        n_factors = len(factors)
        current = list(factors)   # snapshot: every gradient uses these values

        def _unfold(T: torch.Tensor, k: int) -> torch.Tensor:
            # Mode-k unfolding, columns in C-order of the remaining modes.
            perm = [k] + [d for d in range(T.ndim) if d != k]
            return T.permute(perm).reshape(T.shape[k], -1)

        # ---- gradients (all at the current point) ----
        grads_U = []
        for k in range(n_factors):
            KR = khatri_rao([current[j] for j in range(n_factors) if j != k]) * lam.unsqueeze(0)
            grad_Uk = _unfold(G_W_full, k) @ KR              # (dim_k, R)
            if self.reg_W > 0:
                grad_Uk = grad_Uk + self.reg_W * current[k]
            grads_U.append(grad_Uk)

        KR_all = khatri_rao(current)                          # (prod_all_modes, R), C-order rows
        grad_lam = KR_all.T @ G_W_full.reshape(-1)            # (R,)
        if self.reg_W > 0:
            grad_lam = grad_lam + self.reg_W * lam

        # ---- updates ----
        for k, grad_Uk in enumerate(grads_U):
            update_step = self._get_adam_step(f'W_cp_factor_{k}', grad_Uk, step_size)
            updated_factor = current[k] - update_step
            if self.use_nan_to_num:
                updated_factor = torch.nan_to_num(updated_factor)
            factors[k] = updated_factor

        update_step = self._get_adam_step('W_cp_lambda', grad_lam, step_size)
        updated_lambda = lam - update_step
        if self.use_nan_to_num:
            updated_lambda = torch.nan_to_num(updated_lambda)
        self.W_cp_['lambda'] = updated_lambda
        self.W_cp_['factors'] = factors

    def _tucker_step_als(self, Xc: torch.Tensor, Y: torch.Tensor):
        """Run one full Tucker-ALS sweep; every update is a closed-form ridge solve.

        Update order: the input-side factors ``Us[0..p-1]``, then the
        output-side factors ``Us[p..]``, then the core ``G``. The ridge
        strength is ``reg_W``; no ridge is added when it is 0.

        Index conventions are C-order throughout, consistent with
        ``tucker_to_tensor``. ``W = G x_0 Us[0] x_1 Us[1] ...`` with
        ``Us[k]`` of shape ``(dim_k, rank_k)``. The first ``p = Xc.ndim - 1``
        modes of ``W`` are the non-batch modes of ``Xc``; the rest are the
        modes of ``self.out_shape_``.

        Args:
            Xc: design tensor, batch axis first.
            Y:  targets, assumed ``(Xc.shape[0], *self.out_shape_)``.
        """
        device, dtype = self.device, self.dtype
        reg = float(getattr(self, 'reg_W', 0.0))

        def _eye(size: int) -> torch.Tensor:
            return torch.eye(size, device=device, dtype=dtype)

        def _kron_list(mats: List[torch.Tensor]) -> torch.Tensor:
            # Kronecker product in list order (first matrix varies slowest), so
            # rows/cols follow the C-order of the corresponding modes. [] -> 1x1 one.
            if len(mats) == 0:
                return torch.ones((1, 1), device=device, dtype=dtype)
            out = mats[0].contiguous()
            for M in mats[1:]:
                out = torch.kron(out, M.contiguous())
            return out.contiguous()

        def _kron_last_dim(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
            # A:(M, R), B:(N, S) -> (M, N, R*S), last index ordered (r, s)
            return (A[:, None, :, None] * B[None, :, None, :]).reshape(
                A.shape[0], B.shape[0], A.shape[1] * B.shape[1])

        p = Xc.ndim - 1                                  # number of input modes (incl. K)
        Us: List[torch.Tensor] = self.W_tucker_['Us']    # [U_0,...,U_{p+q-1}]
        G: torch.Tensor = self.W_tucker_['G']

        # Keep the core shape in sync with the actual factor ranks. The overlapping
        # block of the old core is kept; new entries start at zero.
        ranks_real = [U.shape[1] for U in Us]
        if tuple(G.shape) != tuple(ranks_real):
            G_new = torch.zeros(ranks_real, device=device, dtype=dtype)
            sl = tuple(slice(0, min(a, b)) for a, b in zip(G.shape, ranks_real))
            G_new[sl] = G[sl]
            G = G_new
            self.W_tucker_['G'] = G
        self.W_tucker_['ranks'] = ranks_real

        in_dims = list(Xc.shape[1:])                     # [r1,...,rn,K]
        out_dims = list(self.out_shape_)                 # [o1,...,om]
        N = int(Xc.shape[0])
        in_prod = int(np.prod(in_dims)) if in_dims else 1
        out_prod = int(np.prod(out_dims)) if out_dims else 1
        K = len(out_dims)
        y_vec = Y.reshape(-1)                            # (N*out_prod,), rows ordered (n, j)

        # ---- 1) input-side factors -------------------------------------------------
        for i in range(p):
            I_i = in_dims[i]
            R_i = Us[i].shape[1]
            other_in_prod = in_prod // I_i

            # Mode-(1+i) unfolding of Xc: rows (n, x_i), cols = other input modes.
            perm = [0, 1 + i] + [ax for ax in range(1, 1 + p) if ax != 1 + i]
            X_unf = Xc.permute(perm).contiguous().reshape(N * I_i, other_in_prod)

            K_in_except = _kron_list([Us[j] for j in range(p) if j != i])   # (other_in_prod, R_except)
            K_out_all = _kron_list([Us[p + j] for j in range(K)])          # (out_prod, S_all)
            R_except = K_in_except.shape[1]
            S_all = K_out_all.shape[1]
            A = _kron_last_dim(K_in_except, K_out_all)  # (other_in_prod, out_prod, R_except*S_all)

            # unfold(G, i) is either (R_i, R_except*S_all) or its transpose.
            Gi = self._unfold_k(G, i)
            if Gi.shape[0] == R_i and Gi.shape[1] == R_except * S_all:
                G_i = Gi.T
            elif Gi.shape[1] == R_i and Gi.shape[0] == R_except * S_all:
                G_i = Gi
            else:
                raise RuntimeError(f"unfold(G,{i}) unexpected shape {Gi.shape}, expect either (R_i,{R_except*S_all}) or ({R_except*S_all},R_i)")

            Phi_3d = torch.tensordot(A, G_i, dims=([2], [0]))  # (other_in_prod, out_prod, R_i)

            # The product has rows (n, x_i) and cols (j, r). Reshape with that
            # memory layout, then reorder to rows (n, j), matching y_vec, and
            # cols (x_i, r), matching vec(U_i).
            Phi = (X_unf @ Phi_3d.reshape(other_in_prod, out_prod * R_i)).reshape(N, I_i, out_prod, R_i)
            Phi_flat = Phi.permute(0, 2, 1, 3).reshape(N * out_prod, I_i * R_i)

            # Ridge regression: (Phi^T Phi + a I) vec(U_i) = Phi^T y
            AtA = Phi_flat.T @ Phi_flat
            if reg > 0:
                AtA = AtA + reg * _eye(I_i * R_i)
            Atb = Phi_flat.T @ y_vec
            sol = torch.linalg.solve(AtA, Atb.unsqueeze(1)).reshape(I_i, R_i)
            if self.use_nan_to_num:
                sol = torch.nan_to_num(sol)
            Us[i] = sol

        # ---- 2) output-side factors ------------------------------------------------
        Z = Xc.reshape(N, in_prod) @ _kron_list(list(Us[:p]))      # (N, R_in)
        R_in = int(np.prod([Us[j].shape[1] for j in range(p)])) if p > 0 else 1
        S_dims = [Us[p + j].shape[1] for j in range(K)]
        R_out = int(np.prod(S_dims)) if K > 0 else 1
        # Core projected by the input side, one code tensor per sample: (N, S_1..S_K)
        S_full = (Z @ G.reshape(R_in, R_out)).reshape((N,) + tuple(S_dims))

        for j in range(K):
            Jj, Sj = out_dims[j], S_dims[j]
            H_j = _kron_list([Us[p + q] for q in range(K) if q != j])   # (J_other, S_other)
            J_other = H_j.shape[0]

            # Targets with output mode j first: (Jj, N*J_other), cols ordered (n, j_other).
            Y_j = Y.movedim(1 + j, 0).reshape(Jj, N * J_other)

            # Codes with mode j second: rows (n, s_j), cols s_other.
            S_flat = S_full.movedim(1 + j, 1).reshape(N * Sj, -1)
            M_big = S_flat @ H_j.T                                      # (N*Sj, J_other)
            # Reorder to (Sj, N*J_other) so the columns line up with Y_j.
            M_j = M_big.reshape(N, Sj, J_other).permute(1, 0, 2).reshape(Sj, N * J_other)

            # Ridge regression: V_j in R^{Jj x Sj}, Y_j ~ V_j @ M_j
            Gram = M_j @ M_j.T
            if reg > 0:
                Gram = Gram + reg * _eye(Sj)
            RHS = Y_j @ M_j.T                                           # (Jj, Sj)
            V_j = torch.linalg.solve(Gram, RHS.T).T                     # (Jj, Sj)
            if self.use_nan_to_num:
                V_j = torch.nan_to_num(V_j)
            Us[p + j] = V_j

        # ---- 3) core ---------------------------------------------------------------
        Kx = _kron_list(list(Us[:p]))                                   # (in_prod, R_in)
        Ky = _kron_list(list(Us[p:]))                                   # (out_prod, R_out)
        kron_combined = _kron_last_dim(Kx, Ky)                          # (in_prod, out_prod, R_in*R_out)
        phi_3d = torch.tensordot(Xc.reshape(N, in_prod), kron_combined, dims=([1], [0]))
        phi_flat = phi_3d.reshape(-1, R_in * R_out)                     # rows (n, j)

        AtA = phi_flat.T @ phi_flat
        if reg > 0:
            AtA = AtA + reg * _eye(R_in * R_out)
        Atb = phi_flat.T @ y_vec
        G = torch.linalg.solve(AtA, Atb).reshape(G.shape)
        if self.use_nan_to_num:
            G = torch.nan_to_num(G)

        self.W_tucker_['Us'] = Us
        self.W_tucker_['G'] = G

    def _tucker_step_rgd(self, Xc: torch.Tensor, Y: torch.Tensor):
        """Riemannian GD step for Tucker weights.

        Each ``U_k`` is moved on the Stiefel manifold: the Euclidean gradient is
        projected to the tangent space, scaled by Adam, and then QR-retracted.
        The core ``G`` takes a plain (Adam-scaled) Euclidean step.

        Gradients, with ``G_W = dL/dW`` from ``_residual_and_grad_Wfull``:
            dL/dG   = G_W x_k U_k^T   (all modes)
            dL/dU_k = unfold_k(G_W) @ unfold_k(G x_{j != k} U_j)^T
        All gradients are evaluated at the current ``(G, Us)``, i.e. the point
        where ``G_W`` was computed, before anything is updated.

        No extra L2 term is applied to G/Us here. ``reg_W`` already enters
        through ``G_W``.
        """
        _, G_W_full = self._residual_and_grad_Wfull(Xc, Y)  # shape == W_full.shape

        Us: List[torch.Tensor] = self.W_tucker_['Us']
        G: torch.Tensor = self.W_tucker_['G']
        stepU = (self.step_size_W or 1e-3)
        stepG = (self.step_size_G or self.step_size_W or 1e-3)
        n_modes = len(Us)
        Us_cur = list(Us)   # snapshot: all gradients use these factors

        # ---- gradients at the current point ----
        grad_G = multi_mode_dot(G_W_full, [U.T for U in Us_cur], modes=list(range(n_modes)))

        grads_U = []
        for k in range(n_modes):
            others = [j for j in range(n_modes) if j != k]
            # G_bar = G x_{j != k} U_j, so that unfold_k(W) = U_k @ unfold_k(G_bar)
            G_bar = multi_mode_dot(G, [Us_cur[j] for j in others], modes=others)
            GWk = self._unfold_k(G_W_full, k)          # (D_k, prod_{j != k} D_j)
            Gbark = self._unfold_k(G_bar, k).T         # (prod_{j != k} D_j, r_k)
            grads_U.append(GWk @ Gbark)                # (D_k, r_k)

        # ---- core: Euclidean step ----
        update_step_G = self._get_adam_step('W_tucker_G', grad_G, stepG)
        updated_G = G - update_step_G
        if self.use_nan_to_num:
            updated_G = torch.nan_to_num(updated_G)

        # ---- factors: tangent projection + QR retraction ----
        for k in range(n_modes):
            gradR = riemannian_grad_on_orthogonal(Us_cur[k], grads_U[k])
            update_step_Uk = self._get_adam_step(f'W_tucker_U_{k}', gradR, stepU)
            Uk_new = qr_retraction(Us_cur[k], update_step_Uk)
            if self.use_nan_to_num:
                Uk_new = torch.nan_to_num(Uk_new)
            self.W_tucker_['Us'][k] = Uk_new

        self.W_tucker_['G'] = updated_G

    def _tt_step_iht(self, Xc: torch.Tensor, Y: torch.Tensor):
        """One TT-IHT step: ``W <- TT-SVD_rank(W - mu * (1/N) X^T (X W - Y))``.

        The step size ``mu`` is ``step_size_W`` when set. Otherwise it is a
        cached default ``1 / (sigma_max(X)^2 / N)``, estimated once by power
        iteration and stored in ``self._tt_mu_``. The tensorly backend is
        switched to PyTorch before any tensorly call. The numeric update runs
        under ``torch.no_grad()``: the new cores are created as fresh leaves
        anyway, so no autograd graph is needed through the reconstruction or
        the TT-SVD. The result is left-orthonormalized.
        """
        def _power_top_sv(A: torch.Tensor, n_iter: int = 50, tol: float = 1e-6) -> float:
            """Estimate the top singular value of A by power iteration (dtype/device of A)."""
            _, n_cols = A.shape
            g = torch.Generator(device=A.device)
            g.manual_seed(0)
            v = torch.randn(n_cols, device=A.device, dtype=A.dtype, generator=g)
            eps = torch.tensor(1e-12, device=A.device, dtype=A.dtype)
            v = v / (v.norm() + eps)

            sv_prev = A.new_tensor(0.0)
            sv = sv_prev
            for _ in range(n_iter):
                Av = A @ v
                u = Av / (Av.norm() + eps)
                Atu = A.T @ u
                v = Atu / (Atu.norm() + eps)
                sv = Av.norm()
                if torch.abs(sv - sv_prev) / (sv_prev + eps) < tol:
                    break
                sv_prev = sv
            return float(sv)

        N = Xc.shape[0]
        in_prod = int(np.prod(Xc.shape[1:]))
        out_prod = self._tt_out_prod
        weight_shape = self._tt_weight_shape
        ranks_full = self._tt_ranks_full
        tl.set_backend('pytorch')

        with torch.no_grad():
            X_mat = Xc.reshape(N, in_prod)
            Y_mat = Y.reshape(N, out_prod)

            if self._tt_mu_ is None:
                smax = _power_top_sv(X_mat)
                self._tt_mu_ = 1.0 / ((smax ** 2) / float(N) + 1e-12)
            mu = float(self.step_size_W) if (self.step_size_W is not None) else float(self._tt_mu_)

            W_mat = tt_to_tensor(TTTensor(self.W_tt_)).reshape(in_prod, out_prod)
            # Gradient of 0.5/N * ||X W - Y||^2
            G = (X_mat.T @ (X_mat @ W_mat - Y_mat)) / float(N)

            # Gradient step followed by the TT-SVD rank projection
            W_tmp = (W_mat - mu * G).reshape(weight_shape)
            tt_new = tensor_train(tl.tensor(W_tmp), rank=ranks_full)
            new_cores = [torch.as_tensor(c, device=self.device, dtype=self.dtype).detach().clone()
                         for c in tt_new.factors]

        self.W_tt_ = [c.requires_grad_(True) for c in new_cores]
        self.W_tt_ = self._tt_left_orthonormalize_cores(self.W_tt_)

    def _tt_left_unfold(self, core: torch.Tensor):
        r_prev, d, r_next = core.shape
        return core.reshape(r_prev * d, r_next)

    def _tt_left_fold(self, mat: torch.Tensor, r_prev: int, d: int, r_next: int):
        return mat.reshape(r_prev, d, r_next)

    def _tt_left_orthonormalize_cores(self, cores: List[torch.Tensor]):
        new_cores = []
        for i, G in enumerate(cores):
            if i == len(cores) - 1:
                new_cores.append(G.detach().clone().requires_grad_(True)); continue
            r_prev, d, r_next = G.shape
            U = self._tt_left_unfold(G)
            Q, R = torch.linalg.qr(U, mode='reduced')
            Q = Q[:, :r_next]
            G_new = self._tt_left_fold(Q, r_prev, d, r_next)
            new_cores.append(G_new.detach().clone().requires_grad_(True))
        return new_cores

    def _tt_step_rgd(self, Xc: torch.Tensor, Y: torch.Tensor):
        """Riemannian GD step for TT weights.

        The first ``d-1`` cores are treated as left-orthonormal (Stiefel)
        matrices: their autograd gradient is projected onto the tangent space
        before the step. The last core takes a plain Euclidean step. The loss is
        ``0.5/N * ||X W - Y||^2``. The step size is ``step_size_W`` (default
        1e-3). Cores without a gradient are left untouched. All cores are
        re-left-orthonormalized at the end.
        """
        n_samples = Xc.shape[0]
        in_prod = int(np.prod(Xc.shape[1:]))
        out_prod = self._tt_out_prod
        lr = 1e-3 if self.step_size_W is None else float(self.step_size_W)

        X_mat = Xc.reshape(n_samples, in_prod)
        Y_mat = Y.reshape(n_samples, out_prod)

        # Clear gradients left over from a previous backward pass.
        for core in self.W_tt_:
            if core.grad is not None:
                core.grad.zero_()

        W_mat = tt_to_tensor(TTTensor(self.W_tt_)).reshape(in_prod, out_prod)
        resid = X_mat @ W_mat - Y_mat
        objective = 0.5 * (resid * resid).sum() / float(n_samples)
        objective.backward()

        last_idx = len(self.W_tt_) - 1
        with torch.no_grad():
            for idx, core in enumerate(self.W_tt_):
                egrad = core.grad
                if egrad is None:
                    continue
                if idx == last_idx:
                    stepped = core - lr * egrad
                else:
                    r_prev, d, r_next = core.shape
                    tangent = riemannian_grad_on_orthogonal(
                        self._tt_left_unfold(core), self._tt_left_unfold(egrad))
                    stepped = core - lr * self._tt_left_fold(tangent, r_prev, d, r_next)
                self.W_tt_[idx] = stepped.detach().clone().requires_grad_(True)

            self.W_tt_ = self._tt_left_orthonormalize_cores(self.W_tt_)

    def _rebuild_k_groups_from_combos(self):
        """基于当前 self.combos_ 建立 self._k_groups[i][p] = LongTensor(索引集合)"""
        n_modes = len(self.block_sizes)
        self._k_groups = {}
        for i in range(n_modes):
            groups_i = {}
            for p in range(1, self.K + 1):
                idxs = [idx for idx, combo in enumerate(self.combos_) if combo[i] == p]
                groups_i[p] = torch.as_tensor(idxs, device=self.device, dtype=torch.long)
            self._k_groups[i] = groups_i

                                                 
    def fit(self, X, y, X_test=None, y_test=None, y_mean=None, y_std=None):
        tl.set_backend('pytorch')
        self.adam_m_.clear()
        self.adam_v_.clear()
        self.adam_t_.clear()
        X = torch.as_tensor(np.array(X), device=self.device, dtype=self.dtype)
        Y = torch.as_tensor(np.array(y), device=self.device, dtype=self.dtype)

                                  
        do_unnormalize = (y_mean is not None) and (y_std is not None)
        if (y_mean is not None) != (y_std is not None):
            raise ValueError("y_mean 和 y_std 必须同时提供或同时不提供。")
        if do_unnormalize:
            # self._log(f"y_mean={y_mean} and y_std={y_std} provided. Scale-dependent metrics (e.g., RPE) will be reported in the original data scale.")
            # numpy
            y_mean = np.array(y_mean)
            y_std = np.array(y_std)
            if np.any(y_std == 0):
                self._log("Warning: y_std contains zero values. De-normalization might be unstable.")
                y_std[y_std == 0] = 1e-12 # zero
        self._X_orig_ = X; self._Y_orig_ = Y
        assert X.shape[0] == Y.shape[0]
        self.in_shape_ = tuple(X.shape[1:]); self.out_shape_ = tuple(Y.shape[1:])

                              
        has_val_data = (X_test is not None) and (y_test is not None)
        if (X_test is not None) != (y_test is not None):
            raise ValueError("X_test 和 y_test 必须同时提供或同时不提供。")
        
        X_val, Y_val = None, None
        if has_val_data:
            self._log("Validation data provided. Early stopping and metrics will be based on the validation set.")
            X_val = torch.as_tensor(np.array(X_test), device=self.device, dtype=self.dtype)
            Y_val = torch.as_tensor(np.array(y_test), device=self.device, dtype=self.dtype)
            assert X_val.shape[0] == Y_val.shape[0]
            # initcol
            self.val_losses_ = []
            self.val_metrics_ = []

        # dimension r_i
        d_list = self._prepare_shapes_and_postdims(X, getattr(self, 'basis_dims', None))
        n_modes = len(d_list)
        self.K_max_possible_list_ = [int(d) // int(b) for d, b in zip(d_list, self.block_sizes)]

        # , input K_list_
        original_K_list = self.K_list_.copy()
        new_K_list = []
        was_capped = False
        for i in range(n_modes):
            k_user = original_K_list[i]
            k_max_possible = self.K_max_possible_list_[i]
            if k_user > k_max_possible:
                self._log(f"Warning: For mode {i}, input K={k_user} is larger than the maximum possible {k_max_possible} (dim={d_list[i]}, block_size={self.block_sizes[i]}). Capping K to {k_max_possible}.")
                new_K_list.append(k_max_possible)
                was_capped = True
            else:
                new_K_list.append(k_user)

        # K_list_ ，update
        if was_capped:
            self.K_list_ = new_K_list
            # K_max_ mode
            self.K = int(np.max(self.K_list_))
            self.K_max_ = self.K
            self.K_base_mode_ = int(np.argmax(self.K_list_))
            self._log(f"K configuration updated to: {self.K_list_}")
            

       
        if self.init_B_method == 'dft':
            self.dft_bases_ = {}
            for i in range(n_modes):
                d_i = int(d_list[i])
                # mode
                b_i = int(self.block_sizes[i])
                total_basis_needed = self.K_list_[i] * b_i
                if total_basis_needed > d_i:
                    raise ValueError(
                        f"对于模态 {i}, DFT 基不足。需要 {total_basis_needed}个 (K={self.K_list_[i]} * b_i={b_i}), "
                        f"但维度 d_i={d_i} 最多只能提供 {d_i} 个。"
                    )
                              
                full_basis = dft_basis(d_i, rate_choose=1.0)                          
                self.dft_bases_[i] = torch.as_tensor(full_basis, device=self.device, dtype=self.dtype)

                   
        # self._make_diag_combos(n_modes)
        self._log("Pre-calculating final combinations and initializing W to full size.")
        if self.combo_mode == 'diag':
            self._make_diag_combos(n_modes)
        elif self.combo_mode == 'diag_skew':
            self._make_diag_skew_combos(n_modes)
        elif self.combo_mode == 'diag_skew_parallel':
            self._make_diag_skew_parallel_combos(n_modes)
        else:  # 'bfs'
            self.combos_ = self._make_bfs_combos_up_to(self.K, n_modes)
        self._rebuild_k_groups_from_combos()

        # init
        if getattr(self, 'losses_', None) is None: self.losses_ = []
        if getattr(self, 'metrics_', None) is None: self.metrics_ = []

        # init W（zero）
                                    
        def _ensure_A_for(p:int):
            for i in range(n_modes):
                                            
                if p > self.K_list_[i]:
                    # componentpmodeiK，
                    continue
                                          
                if (i,p) not in self.A_:
                    b = int(self.block_sizes[i]); r = int(self.post_dims_[i])
                    if self.init_A_method == 'eye':
                        self.A_[(i,p)] = torch.eye(b, device=self.device, dtype=self.dtype)[:, :r].contiguous()
                    else:
                        self.A_[(i,p)] = self._rand_orth(b, r)

        def _ensure_B_for(p:int):
            for i in range(n_modes):
                                            
                if p > self.K_list_[i]:
                    # componentpmodeiK，
                    continue
                                          
                if (i,p) in self.B_:  # ，
                    continue

                d_i = int(d_list[i])
                b_i = int(self.block_sizes[i])

                if self.init_B_method == 'dft':
                                        
                    start_idx = (p - 1) * b_i
                    end_idx = p * b_i
                    
                               
                    # dft_basis return (num_basis, n)， (d_i, d_i)
                    # b_i ，row
                    dft_block = self.dft_bases_[i][start_idx:end_idx, :]                   
                    
                    # B matrixshape (d_i, b_i)，transpose
                    self.B_[(i,p)] = dft_block.T.contiguous()

                else: # 'random' row ()
                                     
                    Qprev = None
                    if p > 1:
                                     
                        Qprev = torch.cat([self.B_[(i,pp)] for pp in range(1,p)], dim=1)
                    
                                     
                    self.B_[(i,p)] = self._rand_orth_in_complement(d_i, b_i, Qprev)
        # p=1 B/A，init W
        # _ensure_B_for(1); _ensure_A_for(1)
        # for pp in range(1, self.K + 1):
        # # B/A ，randominit
        # _ensure_B_for(pp)
        # _ensure_A_for(pp)
        for pp in range(1, self.K + 1):
            _ensure_B_for(pp)

        # perB init A
        if self.perB:
            self._log("Initializing A per combo (perB=True)")
            # A key (combo_index, mode_index)
            for k, combo in enumerate(self.combos_):
                for i in range(n_modes):
                    b = int(self.block_sizes[i]); r = int(self.post_dims_[i])
                    if self.init_A_method == 'eye':
                        self.A_[(k, i)] = torch.eye(b, device=self.device, dtype=self.dtype)[:, :r].contiguous()
                    else:
                        self.A_[(k, i)] = self._rand_orth(b, r)
        else:
            self._log("Initializing A per component (perB=False)")
            # A key (mode_index, p)
            for pp in range(1, self.K + 1):
                 _ensure_A_for(pp)
        
                               
        self.A_exists_ = {key: True for key in self.A_.keys()}
        self.B_exists_ = {key: True for key in self.B_.keys()}

        Xc_tmp = self._build_Xcore_tensor(X, p_max=self.K)
        self._log(f"[fit] Final Xc shape will be {tuple(Xc_tmp.shape)}. Initializing W accordingly.")



        if self.weight_type == 'full':
            in_prod  = int(np.prod(Xc_tmp.shape[1:n_modes+2]))
            out_prod = int(np.prod(self.out_shape_))
            if self.init_W_zero:
                self.W_full_ = torch.zeros(in_prod, out_prod, device=self.device, dtype=self.dtype).reshape(list(Xc_tmp.shape[1:n_modes+2])+list(self.out_shape_))
            else:
                W = torch.randn(in_prod, out_prod, device=self.device, dtype=self.dtype, generator=(self.rng if (self.rng is not None and self.device=='cpu') else None)) * self.headinit_multi
                self.W_full_ = W.reshape(list(Xc_tmp.shape[1:n_modes+2]) + list(self.out_shape_))
        elif self.weight_type == 'cp':
            modes = list(Xc_tmp.shape[1:n_modes+2]) + list(self.out_shape_)  # inputmode + outputmode
            R = int(self.ranks.get('cp', 10))
            facs = []
            gensrc = (self.rng if (self.rng is not None and self.device=='cpu') else None)
            for d in modes:
                M = torch.randn(d, R, device=self.device, dtype=self.dtype, generator=gensrc)
                facs.append(M)  # orthogonal，ALS/GD
            lam = torch.ones(R, device=self.device, dtype=self.dtype)
            self.W_cp_ = {'R': R, 'factors': facs, 'lambda': lam}
        elif self.weight_type == 'tucker':
            modes = list(Xc_tmp.shape[1:n_modes+2]) + list(self.out_shape_)   # inputmode(K) + outputmode
            # mode； r
            if self.tucker_ranks is not None:
                r_list = list(self.tucker_ranks)
                assert len(r_list) == len(modes), "tucker_ranks 长度需等于权重张量维数"
            else:
                r = int(self.ranks.get('tucker', 8))
                r_list = [r] * len(modes)

            Us = []
            # for d, rk in zip(modes, r_list):
            # Q, _ = torch.linalg.qr(torch.randn(d, rk, device=self.device, dtype=self.dtype), mode='reduced')
            # Us.append(Q) # colorthogonal， RGD

            # G = torch.randn(r_list, device=self.device, dtype=self.dtype) * self.headinit_multi
            for d, rk in zip(modes, r_list):
                       
                M = torch.randn(d, rk, device=self.device, dtype=self.dtype, generator=gensrc)
                Q, _ = torch.linalg.qr(M, mode='reduced')
                Us.append(Q)
            G = torch.randn(r_list, device=self.device, dtype=self.dtype, generator=gensrc) * self.headinit_multi
            self.W_tucker_ = {'Us': Us, 'G': G, 'ranks': r_list}
        else:  # 'tt'
            # W “shape”：Xc_tmp inputmode(K) + outputmode
            weight_shape = list(Xc_tmp.shape[1:n_modes+2]) + list(self.out_shape_)
            self._tt_weight_shape = tuple(weight_shape)
            self._tt_in_prod  = int(np.prod(weight_shape[:len(Xc_tmp.shape[1:n_modes+2])]))
            self._tt_out_prod = int(np.prod(self.out_shape_))

            # TT full ranks: [1, r1, ..., r_{d-1}, 1]
            def _tt_full_ranks(tt_rank, order):
                if isinstance(tt_rank, int):
                    return [1] + [tt_rank] * (order - 1) + [1]
                ranks = list(tt_rank)
                if len(ranks) == order - 1:  # internal ranks
                    return [1] + ranks + [1]
                assert len(ranks) == order + 1 and ranks[0] == 1 and ranks[-1] == 1
                return ranks

            tt_rank = int(self.ranks.get('tt', 8)) if not isinstance(self.ranks.get('tt', 8), (list,tuple)) else self.ranks['tt']
            ranks_full = _tt_full_ranks(tt_rank, order=len(weight_shape))
            self._tt_ranks_full = ranks_full

            N = Xc_tmp.shape[0]
            X_mat = Xc_tmp.reshape(N, -1)
            Y_mat = torch.as_tensor(self._Y_orig_, device=self.device, dtype=self.dtype).reshape(N, -1)
            with torch.no_grad():
                W0_mat = (X_mat.T @ Y_mat) / float(N)
                W0 = W0_mat.reshape(weight_shape)
            tl.set_backend('pytorch')
            tt = tensor_train(tl.tensor(W0), rank=ranks_full)


            self.W_tt_ = [torch.as_tensor(c, device=self.device, dtype=self.dtype).detach().clone().requires_grad_(True)
                        for c in tt.factors]
            self.W_tt_ = self._tt_left_orthonormalize_cores(self.W_tt_)  # 5
            self._tt_mu_ = None  

                                            
        if not self.sequential:
            for pp in range(1, self.K + 1):
                _ensure_B_for(pp)

            if self.perB:
                self._log("Initializing A per combo (perB=True)")
                for k, combo in enumerate(self.combos_):
                    for i in range(n_modes):
                        b = int(self.block_sizes[i]); r = int(self.post_dims_[i])
                        if self.init_A_method == 'eye':
                            self.A_[(k, i)] = torch.eye(b, device=self.device, dtype=self.dtype)[:, :r].contiguous()
                        else:
                            self.A_[(k, i)] = self._rand_orth(b, r)
            else:
                self._log("Initializing A per component (perB=False)")

                for pp in range(1, self.K + 1):
                    _ensure_A_for(pp)
            self._log("=== Joint (end-to-end) optimization ===")
            best_loss_joint = float('inf'); best_state_joint = None
            noimp_joint = 0; prev_joint = float('inf')

            inner_AB = int(self.joint_iters) if (self.joint_iters is not None) else int(max(1, self.gdals_iters))

            for it in range(1, self.n_iter_max + 1):


                Xc = self._build_Xcore_tensor(X)

                if self.weight_type == 'tt':
                    if self.tt_solver == 'iht':
                        self._tt_step_iht(Xc, Y)
                    elif self.tt_solver == 'rgd':
                        self._tt_step_rgd(Xc, Y)
                    else:
                        _, G_W_full = self._residual_and_grad_Wfull(Xc, Y)
                        W_full = self._tt_reconstruct_full_from_cores()
                        # W_full = torch.nan_to_num(W_full - (self.step_size_W or 1e-3) * G_W_full.to(W_full.dtype))
                        updated_W_full = W_full - (self.step_size_W or 1e-3) * G_W_full.to(W_full.dtype)
                        if self.use_nan_to_num:
                            updated_W_full = torch.nan_to_num(updated_W_full)
                        W_full = updated_W_full
                        tt = tensor_train(tl.tensor(W_full), rank=self._tt_ranks_full)
                        self.W_tt_ = [torch.as_tensor(core, device=self.device, dtype=self.dtype).detach().clone().requires_grad_(True)
                                    for core in tt.factors]
                elif self.weight_type == 'full':
                    if self.W_full_solver == 'ls':
                        self._step_W_full_ls(Xc, Y)
                    else:
                        _, G_W_full = self._residual_and_grad_Wfull(Xc, Y, context=f"joint it={it}")
                        self._step_W_full_rgd(G_W_full, (self.step_size_W or 1e-3))
                elif self.weight_type == 'cp':
                    if self.W_cp_solver == 'als':
                        self._cp_step_als(Xc, Y)
                    else:
                        self._cp_step_gd(Xc, Y)
                else:  # tucker
                    if self.W_tucker_solver == 'als':
                        self._tucker_step_als(Xc, Y)
                    else:
                        self._tucker_step_rgd(Xc, Y)

                                       
                best_inner = float('inf'); noimp_inner = 0; prev_inner = float('inf')
                patience_inner = min(self.patience, max(2, inner_AB))

                for inner in range(1, inner_AB + 1):
                    Xc = self._build_Xcore_tensor(X)
                    D_all, Yhat = self._delta_all(Xc, Y)
                                     
                    if self.perB:
                        for k, combo in enumerate(self.combos_):
                            for i in range(n_modes):
                                GiM_k = self._grad_M_for_combo(X, D_all, k, i)

                                pp_i = combo[i]
                                Bij = self.B_[(i, pp_i)]
                                grad_A = Bij.T @ GiM_k
                                
            
                                if self.reg_A_l2 > 0: grad_A = grad_A + self.reg_A_l2 * self.A_[(k, i)]
                                if self.reg_A_l1 > 0: grad_A = grad_A + self.reg_A_l1 * torch.sign(self.A_[(k, i)])
                                _clip_inplace(grad_A, 1e3)

                                param_key_A = f'A_{k}_{i}'
                                if self.A_solver == 'rgd':
                                    Aold = self.A_[(k, i)]
                                    gradR = riemannian_grad_on_orthogonal(Aold, grad_A)
                                    update_step = self._get_adam_step(param_key_A, gradR, self.step_size_A)
                                    self.A_[(k, i)] = qr_retraction(Aold, update_step)
                                else: # 'gd' or 'als'
                                    update_step = self._get_adam_step(param_key_A, grad_A, self.step_size_A)
                                    updated_A = self.A_[(k, i)] - update_step
                                    if self.use_nan_to_num: updated_A = torch.nan_to_num(updated_A)
                                    self.A_[(k, i)] = updated_A
                    else: # perB=False
                        for i in range(n_modes):
                            for pp in range(1, self.K_list_[i] + 1):
                                GiM = self._grad_M_for(X, D_all, pp, i)
                                Bij = self.B_[(i, pp)]
                                grad_A = Bij.T @ GiM
                                
                                if self.reg_A_l2 > 0: grad_A = grad_A + self.reg_A_l2 * self.A_[(i, pp)]
                                if self.reg_A_l1 > 0: grad_A = grad_A + self.reg_A_l1 * torch.sign(self.A_[(i, pp)])
                                self._log_grad_stats(grad_A, f"Grad_E(A i={i},p={pp})", f"joint it={it}, inner={inner}")
                                _clip_inplace(grad_A, 1e3)
                                
                                param_key_A = f'A_{i}_{pp}'
                                if self.A_solver == 'rgd':
                                    Aold = self.A_[(i, pp)]
                                    gradR = riemannian_grad_on_orthogonal(Aold, grad_A)
                                    self._log_grad_stats(gradR, f"Grad_R(A i={i},p={pp})", f"joint it={it}, inner={inner}")
                                    update_step = self._get_adam_step(param_key_A, gradR, self.step_size_A)
                                    self.A_[(i, pp)] = qr_retraction(Aold, update_step)
                                else: # 'gd' or 'als'
                                    update_step = self._get_adam_step(param_key_A, grad_A, self.step_size_A)
                                    updated_A = self.A_[(i, pp)] - update_step
                                    if self.use_nan_to_num: updated_A = torch.nan_to_num(updated_A)
                                    self.A_[(i, pp)] = updated_A

                                      
                    if self.learn_B:
                        if self.joint_B_ortho == 'joint_stiefel':
                            for i in range(n_modes):
                                G_cat = self._grad_B_mode_all(X, D_all, i)
                                self._log_grad_stats(G_cat, f"Grad_E(B_cat i={i})", f"joint it={it}, inner={inner}")
                                B_cat = self._concat_B_mode(i)
                                param_key_B_cat = f'B_cat_{i}'

                                if self.B_solver == 'als':
                                    # ALS-like Orthogonal Projection via SVD
                                    update_step = self._get_adam_step(param_key_B_cat, G_cat, self.step_size_B)
                                    B_target = B_cat - update_step
                                    U, _, Vh = torch.linalg.svd(B_target, full_matrices=False)
                                    B_new = U @ Vh
                                else: # RGD
                                    # Riemannian projection on Stiefel & QR retraction
                                    gradR = riemannian_grad_on_orthogonal(B_cat, G_cat)
                                    self._log_grad_stats(gradR, f"Grad_R(B_cat i={i})", f"joint it={it}, inner={inner}")
                                    update_step = self._get_adam_step(param_key_B_cat, gradR, self.step_size_B)
                                    B_new = qr_retraction(B_cat, update_step)

                                if self.use_nan_to_num:
                                    B_new = torch.nan_to_num(B_new)
                                self._split_B_mode(i, B_new)
                        elif self.joint_B_ortho == 'block_only':
                            for i in range(n_modes):
                                for pp in range(1, self.K_list_[i] + 1):
                                    GiM = self._grad_M_for(X, D_all, pp, i)
                                    Aij = self.A_[(i, pp)]
                                    G_E = GiM @ Aij.T
                                    if self.reg_B_l2 > 0:
                                        G_E = G_E + self.reg_B_l2 * self.B_[(i, pp)]
                                    _clip_inplace(G_E, 1e3)
                                    B = self.B_[(i, pp)]
                                    param_key_B = f'B_{i}_{pp}' 
                                    if self.B_solver == 'gd':
                                        update_step = self._get_adam_step(param_key_B, G_E, self.step_size_B)
                                        B_new = B - update_step
                                    else: # 'rgd'
                                        gradR = riemannian_grad_on_orthogonal(B, G_E)
                                        update_step = self._get_adam_step(param_key_B, gradR, self.step_size_B)
                                        B_new = qr_retraction(B, update_step)
                                    
                                    if self.use_nan_to_num:
                                        B_new = torch.nan_to_num(B_new)
                                    self.B_[(i, pp)] = B_new

                        else:
                            raise ValueError("joint_B_ortho='perp_seq' 仅在 sequential=True 时可用；端到端请使用 'joint_stiefel' 或 'block_only'。")

                    Xc_eval = self._build_Xcore_tensor(X)
                    with torch.no_grad():
                        Yhat_inner = self._predict_given_W(Xc_eval)
                        cur_inner = 0.5 * torch.mean((Yhat_inner - Y) ** 2).item()

                    rel_inner = abs(prev_inner - cur_inner) / (prev_inner + 1e-12)
                    prev_inner = cur_inner
                    if cur_inner < best_inner - self.min_delta:
                        best_inner = cur_inner; noimp_inner = 0
                    else:
                        noimp_inner += 1
                    if rel_inner < self.tol or noimp_inner >= patience_inner:
                        break

                X_eval = X_val if has_val_data else X
                Y_eval = Y_val if has_val_data else Y
                log_prefix = "val" if has_val_data else "train"

                Xc_eval = self._build_Xcore_tensor(X_eval)
                with torch.no_grad():
                    Yhat = self._predict_given_W(Xc_eval)
                    cur = 0.5 * torch.mean((Yhat - Y_eval) ** 2).item()
                

                if has_val_data:
                    self.val_losses_.append(cur)
                    with torch.no_grad():
                        Xc_train = self._build_Xcore_tensor(X)
                        Yhat_train = self._predict_given_W(Xc_train)
                        train_loss = 0.5 * torch.mean((Yhat_train - Y) ** 2).item()
                        self.losses_.append(train_loss) 
                else:
                    self.losses_.append(cur)



                yp_2d = Yhat.detach().cpu().numpy()
                yt_2d = Y_eval.detach().cpu().numpy()
                yp_flat = yp_2d.ravel()
                yt_flat = yt_2d.ravel()
                yt_m = yt_flat - yt_flat.mean(); yp_m = yp_flat - yp_flat.mean()
                denom = np.linalg.norm(yt_m) * np.linalg.norm(yp_m)
                pearson_r_flat = float(np.dot(yt_m, yp_m) / denom) if denom > 1e-12 else 0.0

                if do_unnormalize:
                    yp_unnorm_2d = yp_2d * y_std + y_mean
                    yt_unnorm_2d = yt_2d * y_std + y_mean
                    rpe = float(np.linalg.norm(yp_unnorm_2d - yt_unnorm_2d) / (np.linalg.norm(yt_unnorm_2d) + 1e-12))
                else:
                    rpe = float(np.linalg.norm(yp_2d - yt_2d) / (np.linalg.norm(yt_2d) + 1e-12))

                metrics_to_store = {'pearson_r_flat': pearson_r_flat, 'rpe': rpe}
                if has_val_data:
                    self.val_metrics_.append(metrics_to_store)
                else:
                    self.metrics_.append(metrics_to_store)

                if cur < best_loss_joint - self.min_delta:
                    best_loss_joint = cur; noimp_joint = 0
                    snap = {
                        'B': {(i, pp): self.B_[(i, pp)].clone() for i in range(n_modes) for pp in range(1, self.K_list_[i]+1)},
                        'A': {key: val.clone() for key, val in self.A_.items()},
                        'W': self._snapshot_W(),
                    }
                    best_state_joint = snap
                else:
                    noimp_joint += 1

                if self.verbose:
                    log_msg = f"[joint] it {it}/{self.n_iter_max} - {log_prefix}_loss={cur:.6f} - {log_prefix}_r_flat={pearson_r_flat:.6f} - {log_prefix}_rpe={rpe:.6f}"
                    if has_val_data:
                        log_msg += f" (train_loss={train_loss:.6f})"
                    self._log(log_msg)

                rel = abs(prev_joint - cur) / (prev_joint + 1e-12)
                if rel < self.tol:
                    self._log(f"[joint] Converged at it={it} (rel_change<{self.tol}).")
                    break
                prev_joint = cur
                if self.early_stopping and noimp_joint >= self.patience:
                    self._log(f"[joint] Early-stopped at it={it}; best_loss={best_loss_joint:.6f}")
                    break

            if best_state_joint is not None:
                for k, v in best_state_joint['B'].items(): self.B_[k] = v.clone()
                for k, v in best_state_joint['A'].items(): self.A_[k] = v.clone()
                self._restore_W(best_state_joint['W'])

            self.history_ = {
                'loss': self.losses_,
                'metrics': self.metrics_,
                'n_iter': len(self.losses_) - 1,
                'weight_type': self.weight_type,
                'K': self.K,
                'block_sizes': self.block_sizes,
                'step_size_A': self.step_size_A,
                'step_size_B': self.step_size_B,
                'step_size_W': self.step_size_W,
                'warmup_A': self.warmup_A,
                'best_epoch': None,
                'early_stopping': self.early_stopping,
                'patience': self.patience,
                'min_delta': self.min_delta,
                'A_solver': self.A_solver,
                'W_full_solver': self.W_full_solver,
                'perB': self.perB,
                'sequential': self.sequential,
                'joint_B_ortho': self.joint_B_ortho,       
                'basis_dims': list(self.post_dims_),
            }
            if has_val_data:
                self.history_['val_loss'] = self.val_losses_
                self.history_['val_metrics'] = self.val_metrics_
                
            if getattr(self, 'print_ab_stats', False):
                self._print_AB_stats()

            return self

                                        
        self.losses_.clear(); self.metrics_.clear()
        best_global = None
        for p in range(1, self.K_max_ + 1):
            self._log(f"=== Component {p}/{self.K}: init/opt ===")
            best_loss_p = float('inf'); best_state_p = None
            noimp = 0
            prev = float('inf')

                
            for it in range(1, self.n_iter_max+1):
                # Xc = self._build_Xcore_tensor(X)
                Xc = self._build_Xcore_tensor(X, p_max=p)

                                          
                if self.weight_type == 'tt':
                    if self.tt_solver == 'iht':
                        self._tt_step_iht(Xc, Y)   
                    elif self.tt_solver == 'rgd':
                        self._tt_step_rgd(Xc, Y)    
                    else:
                        _, G_W_full = self._residual_and_grad_Wfull(Xc, Y)
                        W_full = self._tt_reconstruct_full_from_cores()
                        # W_full = torch.nan_to_num(W_full - (self.step_size_W or 1e-3) * G_W_full.to(W_full.dtype))
                        updated_W_full = W_full - (self.step_size_W or 1e-3) * G_W_full.to(W_full.dtype)
                        if self.use_nan_to_num:
                            updated_W_full = torch.nan_to_num(updated_W_full)
                        W_full = updated_W_full
                        tt = tensor_train(tl.tensor(W_full), rank=self._tt_ranks_full)
                        self.W_tt_ = [torch.as_tensor(core, device=self.device, dtype=self.dtype).detach().clone().requires_grad_(True)
                                    for core in tt.factors]
                elif self.weight_type == 'full':
                    if self.W_full_solver == 'ls':
                        self._step_W_full_ls(Xc, Y)
                    else:
                        _, G_W_full = self._residual_and_grad_Wfull(Xc, Y, context=f"p={p}, it={it}")
                        self._step_W_full_rgd(G_W_full, (self.step_size_W or 1e-3))
                elif self.weight_type == 'cp':
                    if self.W_cp_solver == 'als':
                        self._cp_step_als(Xc, Y)     
                    else:
                        self._cp_step_gd(Xc, Y)  
                elif self.weight_type == 'tucker': 
                    if self.W_tucker_solver == 'als':
                        self._tucker_step_als(Xc, Y)    
                    else:
                        self._tucker_step_rgd(Xc, Y)    
                
                                          
                do_gdals = (
                    (self.weight_type == 'full'   and self.W_full_solver   == 'ls')   or
                    (self.weight_type == 'cp'     and self.W_cp_solver     == 'als')  or
                    (self.weight_type == 'tucker' and self.W_tucker_solver == 'als')
                )
                inner_max = int(self.gdals_iters) if do_gdals else 1

                best_inner = float('inf')
                noimp_inner = 0
                patience_inner = min(self.patience, max(2, inner_max))

                prev_inner = float('inf')
                for inner in range(1, inner_max + 1):
                                             
                    Xc = self._build_Xcore_tensor(X)
                    D_all, Yhat = self._delta_all(Xc, Y)

                                             
                    if self.perB:
                        for k, combo in enumerate(self.combos_):
                            is_active = True
                            if self.combo_mode == 'diag_skew':
                                if combo[self.K_base_mode_] > p: is_active = False
                            elif max(combo) > p:
                                is_active = False
                            if not is_active: continue
                            if (not self.if_all_learn_A) and (max(combo) < p):
                                continue

                            for i in range(n_modes):
                                # 1. D_k
                                D_k = D_all.select(-1, k).unsqueeze(-1)
                                
                                # 2. V_{k,i}
                                T_c = X
                                for m in range(n_modes):
                                    if m == i: continue
                                    pp_m = combo[m]
                                    M_m = self.B_[(m, pp_m)] @ self.A_[(k, m)]
                                    T_c = mode_n_product_batch(T_c, M_m, m)
                                
                                # 3. ∂L/∂M_{k,i}
                                V_c = self._unfold_batch_mode(T_c, i)
                                U_c = self._unfold_batch_mode(D_k, i)
                                GiM_k = torch.einsum('sdr, skr -> dk', V_c, U_c) / max(1, X.shape[0])
                                pp_i = combo[i]
                                Bij = self.B_[(i, pp_i)]
                                grad_A = Bij.T @ GiM_k

                                if self.reg_A_l2 > 0: grad_A = grad_A + self.reg_A_l2 * self.A_[(k, i)]
                                if self.reg_A_l1 > 0: grad_A = grad_A + self.reg_A_l1 * torch.sign(self.A_[(k, i)])
                                _clip_inplace(grad_A, 1e3)
                                param_key_A = f'A_{k}_{i}'
                                if self.A_solver == 'rgd':
                                    Aold = self.A_[(k, i)]
                                    gradR = riemannian_grad_on_orthogonal(Aold, grad_A)
                                    update_step = self._get_adam_step(param_key_A, gradR, self.step_size_A)
                                    self.A_[(k, i)] = qr_retraction(Aold, update_step)
                                else:
                                    update_step = self._get_adam_step(param_key_A, grad_A, self.step_size_A)
                                    updated_A = self.A_[(k, i)] - update_step
                                    if self.use_nan_to_num: updated_A = torch.nan_to_num(updated_A)
                                    self.A_[(k, i)] = updated_A
                    else: 
                        for pp in range(1, p + 1):
                            if (not self.if_all_learn_A) and (pp < p):
                                continue
                            for i in range(n_modes):
                                if pp > self.K_list_[i]: continue
                                GiM = self._grad_M_for(X, D_all, pp, i)
                                Bij = self.B_[(i, pp)]
                                grad_A = Bij.T @ GiM
                                if self.reg_A_l2 > 0: grad_A = grad_A + self.reg_A_l2 * self.A_[(i, pp)]
                                if self.reg_A_l1 > 0: grad_A = grad_A + self.reg_A_l1 * torch.sign(self.A_[(i, pp)])
                                _clip_inplace(grad_A, 1e3)
                                param_key_A = f'A_{i}_{pp}'
                                if self.A_solver == 'rgd':
                                    Aold = self.A_[(i, pp)]
                                    gradR = riemannian_grad_on_orthogonal(Aold, grad_A)
                                    update_step = self._get_adam_step(param_key_A, gradR, self.step_size_A)
                                    self.A_[(i, pp)] = qr_retraction(Aold, update_step)
                                else:
                                    update_step = self._get_adam_step(param_key_A, grad_A, self.step_size_A)
                                    updated_A = self.A_[(i, pp)] - update_step
                                    if self.use_nan_to_num: updated_A = torch.nan_to_num(updated_A)
                                    self.A_[(i, pp)] = updated_A

                                             
                    if self.learn_B:
                        for i in range(n_modes):
                            if p > self.K_list_[i]:
                                continue
                            # ∂L/∂B = (∂L/∂M) @ A^T
                            if self.perB:
                                d_i, b_i = self.B_[(i, p)].shape
                                G_E_total = torch.zeros(d_i, b_i, device=self.device, dtype=self.dtype)

                                combo_indices = self._k_groups[i].get(p)
                                if combo_indices is not None and combo_indices.numel() > 0:
                                    for k_tensor in combo_indices:
                                        k = k_tensor.item()
                                        GiM_k = self._grad_M_for_combo(X, D_all, k, i)
                                        A_ki = self.A_[(k, i)]

                                        G_E_total += GiM_k @ A_ki.T
                                G_E = G_E_total
                            else:

                                GiM = self._grad_M_for(X, D_all, p, i)
                                Aij = self.A_[(i, p)]
                                G_E = GiM @ Aij.T 
                            if self.reg_B_l2 > 0:
                                G_E = G_E + self.reg_B_l2 * self.B_[(i, p)]
                            self._log_grad_stats(G_E, f"Grad_E(B i={i},p={p})", f"p={p}, it={it}, inner={inner}")
                            _clip_inplace(G_E, 1e3)

                            B = self.B_[(i, p)]
                            param_key_B = f'B_{i}_{p}' # key

                            if self.B_solver == 'gd':
                                update_step = self._get_adam_step(param_key_B, G_E, self.step_size_B)
                                B_new = B - update_step
                            elif self.B_solver == 'rgd':
                                gradR = riemannian_grad_on_orthogonal(B, G_E)
                                self._log_grad_stats(gradR, f"Grad_R(B i={i},p={p})", f"p={p}, it={it}, inner={inner}")
                                update_step = self._get_adam_step(param_key_B, gradR, self.step_size_B)
                                B_new = qr_retraction(B, update_step)
                            elif self.B_solver == 'als':
                                Qprev = torch.cat([self.B_[(i,pp)] for pp in range(1,p)], dim=1) if p>1 else None
                                def P_perp(Z: torch.Tensor) -> torch.Tensor:
                                    return Z if (Qprev is None or Qprev.numel()==0) else Z - Qprev @ (Qprev.T @ Z)
                                
                                update_step = self._get_adam_step(param_key_B, G_E, self.step_size_B)
                                B_target = B - update_step

                                B_proj = P_perp(B_target)
                                

                                U, _, Vh = torch.linalg.svd(B_proj, full_matrices=False)
                                B_new = U @ Vh
                            else: # 'perp_rgd'
                                Qprev = torch.cat([self.B_[(i,pp)] for pp in range(1,p)], dim=1) if p>1 else None
                                def P_perp(Z: torch.Tensor) -> torch.Tensor:
                                    return Z if (Qprev is None or Qprev.numel()==0) else Z - Qprev @ (Qprev.T @ Z)
                                sym = 0.5 * (B.T @ G_E + G_E.T @ B)

                            if self.B_solver != 'perp_rgd': 
                                if self.use_nan_to_num:
                                    B_new = torch.nan_to_num(B_new)
                                self.B_[(i, p)] = B_new


                    Xc_eval = self._build_Xcore_tensor(X)
                    with torch.no_grad():
                        Yhat_inner = self._predict_given_W(Xc_eval)
                        cur_inner = 0.5 * torch.mean((Yhat_inner - Y) ** 2).item()

                    
                    rel_inner = abs(prev_inner - cur_inner) / (prev_inner + 1e-12)
                    prev_inner = cur_inner


                    if cur_inner < best_inner - self.min_delta:
                        best_inner = cur_inner
                        noimp_inner = 0
                    else:
                        noimp_inner += 1

                    if rel_inner < self.tol or noimp_inner >= patience_inner:
                        break

                X_eval = X_val if has_val_data else X
                Y_eval = Y_val if has_val_data else Y
                log_prefix = "val" if has_val_data else "train"

                Xc_eval = self._build_Xcore_tensor(X_eval, p_max=p)
                with torch.no_grad():
                    Yhat = self._predict_given_W(Xc_eval)
                    cur = 0.5 * torch.mean((Yhat - Y_eval) ** 2).item()

                if has_val_data:
                    self.val_losses_.append(cur)
                    with torch.no_grad():
                        Xc_train = self._build_Xcore_tensor(X, p_max=p)
                        Yhat_train = self._predict_given_W(Xc_train)
                        train_loss = 0.5 * torch.mean((Yhat_train - Y) ** 2).item()
                        self.losses_.append(train_loss)
                else:
                    self.losses_.append(cur)
                yp_2d = Yhat.detach().cpu().numpy()
                yt_2d = Y_eval.detach().cpu().numpy()

                                                                   
                yp_flat = yp_2d.ravel()
                yt_flat = yt_2d.ravel()
                yt_m = yt_flat - yt_flat.mean(); yp_m = yp_flat - yp_flat.mean()
                denom = np.linalg.norm(yt_m) * np.linalg.norm(yp_m)
                pearson_r_flat = float(np.dot(yt_m, yp_m) / denom) if denom > 1e-12 else 0.0

                                                                           
                if do_unnormalize:
                    yp_unnorm_2d = yp_2d * y_std + y_mean
                    yt_unnorm_2d = yt_2d * y_std + y_mean
                    rpe = float(np.linalg.norm(yp_unnorm_2d - yt_unnorm_2d) / (np.linalg.norm(yt_unnorm_2d) + 1e-12))
                else:
                    rpe = float(np.linalg.norm(yp_2d - yt_2d) / (np.linalg.norm(yt_2d) + 1e-12))

                metrics_to_store = {'pearson_r_flat': pearson_r_flat, 'rpe': rpe}
                if has_val_data:
                    self.val_metrics_.append(metrics_to_store)
                else:
                    self.metrics_.append(metrics_to_store)
                


                if cur < best_loss_p - self.min_delta:
                    best_loss_p = cur; noimp = 0

                    A_snap = {}
                    if self.perB:
                        for k, combo in enumerate(self.combos_):
                            is_active = True
                            if self.combo_mode == 'diag_skew':
                                if combo[self.K_base_mode_] > p: is_active = False
                            elif max(combo) > p:
                                is_active = False
                            if is_active:
                                for i in range(n_modes):
                                    A_snap[(k,i)] = self.A_[(k,i)].clone()
                    else: 
                        for i in range(n_modes):
                            for pp in range(1,p+1):
                                if pp <= self.K_list_[i]:
                                    A_snap[(i,pp)] = self.A_[(i,pp)].clone()

                    snap = {
                        'B': {(i,pp): self.B_[(i,pp)].clone() for i in range(n_modes) for pp in range(1,p+1) if pp <= self.K_list_[i]},
                        'A': A_snap,
                        'W': self._snapshot_W(),
                    }
                    best_state_p = snap
                else:
                    noimp += 1


                if self.verbose:
                    log_msg = f"[p={p}] it {it}/{self.n_iter_max} - {log_prefix}_loss={cur:.6f} - {log_prefix}_r_flat={pearson_r_flat:.6f} - {log_prefix}_rpe={rpe:.6f}"
                    if has_val_data:
                        log_msg += f" (train_loss={train_loss:.6f})"
                    self._log(log_msg)
                # early
                rel = abs(prev-cur)/(prev+1e-12)
                if rel < self.tol:
                    self._log(f"[p={p}] Converged at it={it} (rel_change<{self.tol}).")
                    break
                prev = cur
                if self.early_stopping and noimp >= self.patience:
                    self._log(f"[p={p}] Early-stopped at it={it}; best_loss={best_loss_p:.6f}")
                    break


            if best_state_p is not None:
                for k,v in best_state_p['B'].items(): self.B_[k] = v.clone()
                for k,v in best_state_p['A'].items(): self.A_[k] = v.clone()
                self._restore_W(best_state_p['W'])


            best_global = best_state_p


        self.history_ = {
            'loss': self.losses_,
            'metrics': self.metrics_,
            'n_iter': len(self.losses_)-1,
            'weight_type': self.weight_type,
            'K': self.K,
            'block_sizes': self.block_sizes,
            'step_size_A': self.step_size_A,
            'step_size_B': self.step_size_B,
            'step_size_W': self.step_size_W,
            'warmup_A': self.warmup_A,
            'best_epoch': None,
            'early_stopping': self.early_stopping,
            'patience': self.patience,
            'min_delta': self.min_delta,
            'A_solver': self.A_solver,
            'W_full_solver': self.W_full_solver,
            'perB': self.perB,
            'sequential': self.sequential,
            'basis_dims': list(self.post_dims_),
        }
        if has_val_data:
                self.history_['val_loss'] = self.val_losses_
                self.history_['val_metrics'] = self.val_metrics_
        if getattr(self, 'print_ab_stats', False):
            self._print_AB_stats()

        return self

                            
    def _snapshot_W(self):
        if self.weight_type == 'full':
            return {'type':'full', 'W_full_': self.W_full_.clone()}
        elif self.weight_type == 'cp':
            return {'type':'cp', 'W_cp_': {'R': self.W_cp_['R'],
                                           'lambda': self.W_cp_['lambda'].clone(),
                                           'factors': [u.clone() for u in self.W_cp_['factors']]}}
        elif self.weight_type == 'tucker':
            return {'type':'tucker',
                    'W_tucker_': {
                        'Us': [U.clone() for U in self.W_tucker_['Us']],
                        'G': self.W_tucker_['G'].clone(),
                        'ranks': self.W_tucker_['ranks'],   # ←
                    }}
        else:
            return {'type':'tt', 'W_tt_': [c.detach().clone() for c in self.W_tt_]}

    def _restore_W(self, snap):
        if snap['type']=='full':
            self.W_full_ = snap['W_full_'].clone()
        elif snap['type']=='cp':
            cp = snap['W_cp_']; self.W_cp_ = {'R': cp['R'], 'lambda': cp['lambda'].clone(), 'factors': [u.clone() for u in cp['factors']]}
        elif snap['type']=='tucker':
            tk = snap['W_tucker_']
            self.W_tucker_ = {
                'Us': [U.clone() for U in tk['Us']],
                'G': tk['G'].clone(),
                'ranks': tk['ranks'],                     # ←
            }
        else:
            self.W_tt_ = [c.detach().clone().requires_grad_(True) for c in snap['W_tt_']]

                                                 
    def predict(self, X):
        X = torch.as_tensor(np.array(X), device=self.device, dtype=self.dtype)
        Xc = self._build_Xcore_tensor(X)
        with torch.no_grad():
            Yhat = self._predict_given_W(Xc)
        return Yhat.cpu().numpy()

 

    def score(self, X, y, y_mean=None, y_std=None):
        y_pred = self.predict(X)
        yt_orig = np.array(y)
        yp_norm = np.array(y_pred)


        yt_flat = yt_orig.ravel()
        yp_flat = yp_norm.ravel()
        yt_m = yt_flat - yt_flat.mean(); yp_m = yp_flat - yp_flat.mean()
        denom = np.linalg.norm(yt_m) * np.linalg.norm(yp_m)
        pearson_r_flat = float(np.dot(yt_m, yp_m) / denom) if denom > 1e-12 else 0.0

        # De-normalize predictions if parameters are provided
        do_unnormalize = (y_mean is not None) and (y_std is not None)
        if do_unnormalize:
            y_mean_arr = np.array(y_mean)
            y_std_arr = np.array(y_std)
            # y_pred is normalized, so we un-normalize it to match the original scale of y
            yp_unnorm = yp_norm * y_std_arr + y_mean_arr
            # Now both yp_unnorm and yt_orig are in the original data scale
            rpe = float(np.linalg.norm(yp_unnorm - yt_orig) / (np.linalg.norm(yt_orig) + 1e-12))
            mse = float(np.mean((yp_unnorm - yt_orig)**2))
        else:
            # Assume y is already normalized to the same scale as y_pred
            rpe = float(np.linalg.norm(yp_norm - yt_orig) / (np.linalg.norm(yt_orig) + 1e-12))
            mse = float(np.mean((yp_norm - yt_orig)**2))
            
        return {"pearson_r_flat": pearson_r_flat, "rpe": rpe}, mse

                                                          
    def _unfold_k(self, T: torch.Tensor, k: int) -> torch.Tensor:
        dims = list(range(T.ndim)); perm = [k] + [d for d in dims if d!=k]
        return T.permute(perm).reshape(T.shape[k], -1)

    def _khatri_rao_all(self, factors: List[torch.Tensor], skip: Optional[int] = None) -> torch.Tensor:
        mats = [factors[i] for i in range(len(factors)) if i != skip]
        out = mats[0]
        for M in mats[1:]:
            out = torch.einsum('ir,jr->ijr', out, M).reshape(out.shape[0]*M.shape[0], out.shape[1])
        return out
