# datasets_with_permutation.py
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from utils.timefeatures import time_features
import warnings
warnings.filterwarnings('ignore')

import numpy as np
from scipy.cluster.hierarchy import linkage, optimal_leaf_ordering, leaves_list
from scipy.spatial.distance import squareform
from sklearn.decomposition import PCA

_eps = 1e-12

def _prepare_train_stack(train_data):
    """ Collapse and z-score columns robustly.
        train_data: [T, C] or [N, T, C] -> returns [T_total, C]
    """
    if isinstance(train_data, list):
        train_data = np.concatenate(train_data, axis=0)
    if train_data.ndim == 3:
        train_stack = train_data.reshape(-1, train_data.shape[-1])
    else:
        train_stack = train_data.copy()
    # center + scale per-variable (avoid zero-std)
    mu = train_stack.mean(axis=0, keepdims=True)
    sigma = train_stack.std(axis=0, ddof=0, keepdims=True) + _eps
    X = (train_stack - mu) / sigma
    return X

def compute_corr_matrix_from_train(train_data):
    """Return the ``(C, C)`` Pearson correlation matrix of the training data.

    ``train_data`` is anything ``_prepare_train_stack`` accepts (``[T, C]``,
    ``[N, T, C]``, a list of arrays, ...). NaN entries, which ``np.corrcoef``
    produces for constant columns, are replaced with 0. The result is always
    2-D, including the single-variable case where ``np.corrcoef`` alone would
    return a scalar.
    """
    X = _prepare_train_stack(train_data)
    # rowvar=False -> variables are the columns of X
    corr = np.atleast_2d(np.corrcoef(X, rowvar=False))
    return np.nan_to_num(corr, nan=0.0)

# --- existing spectral / degree kept (slightly robustified) ---
def spectral_seriation(sim_mat):
    """Order variables by the Fiedler vector of the |similarity| graph Laplacian.

    Parameters
    ----------
    sim_mat : array-like, shape (C, C)
        Similarity matrix; absolute values are used and the diagonal ignored.
        The input is not modified.

    Returns
    -------
    np.ndarray of int
        A permutation of ``range(C)``. Returns the identity for C < 2, and
        falls back to decreasing-degree order if the eigendecomposition fails.
        Because an eigenvector's sign is arbitrary, the result may come out
        reversed.
    """
    A = np.abs(np.asarray(sim_mat, dtype=float))  # new array: input untouched
    np.fill_diagonal(A, 0.0)
    n = A.shape[0]
    if n < 2:
        return np.arange(n, dtype=int)
    deg = A.sum(axis=1)
    laplacian = np.diag(deg) - A
    try:
        _, vecs = np.linalg.eigh(laplacian)
    except np.linalg.LinAlgError:
        return np.argsort(-deg, kind='stable').astype(int)
    # eigh returns eigenvalues in ascending order; column 1 is the Fiedler vector
    fiedler = vecs[:, 1]
    return np.argsort(fiedler, kind='stable').astype(int)

def degree_order(sim_mat):
    """Order variables by decreasing weighted degree in the |similarity| graph.

    The degree of a variable is the sum of its absolute similarities to the
    *other* variables. The diagonal is ignored, as in the other seriation
    helpers here, so a variable's self-similarity does not inflate its rank.
    Ties keep their original index order (stable sort). The input is not
    modified.

    Returns
    -------
    np.ndarray of int
        A permutation of ``range(C)``.
    """
    A = np.abs(np.asarray(sim_mat, dtype=float))  # copy; safe to edit
    np.fill_diagonal(A, 0.0)
    score = A.sum(axis=1)
    return np.argsort(-score, kind='stable').astype(int)

# --- hierarchical + optimal leaf ordering (robust) ---
def hierarchical_seriation(sim_mat):
    """Order variables by average-linkage clustering with optimal leaf ordering.

    Similarities are turned into distances ``d = 1 - |s|``, with ``|s|``
    clipped to [0, 1] so rounding cannot create negative distances. The
    distances are clustered with average linkage, and the dendrogram leaves
    are reordered with ``optimal_leaf_ordering`` so that neighbouring leaves
    are as close as possible.

    Parameters
    ----------
    sim_mat : array-like, shape (C, C)
        Similarity matrix (e.g. a correlation matrix). Not modified.

    Returns
    -------
    np.ndarray of int
        A permutation of ``range(C)``. Returns the identity for C < 3, where
        every ordering is equivalent up to reversal. Falls back to
        ``degree_order`` if SciPy rejects the distances (e.g. NaN input).
    """
    A = np.clip(np.abs(np.asarray(sim_mat, dtype=float)), 0.0, 1.0)
    C = A.shape[0]
    if C < 3:
        return np.arange(C, dtype=int)
    dist = 1.0 - A
    dist = (dist + dist.T) / 2.0  # enforce exact symmetry
    # A valid distance matrix needs a zero diagonal. Previously the diagonal
    # was 1 - 0 = 1, so squareform always raised and the identity was returned.
    np.fill_diagonal(dist, 0.0)
    condensed = squareform(dist, checks=False)
    try:
        Z = linkage(condensed, method='average')
        Z_opt = optimal_leaf_ordering(Z, condensed)
    except ValueError:
        # e.g. non-finite distances coming from NaN similarities
        A_off = A.copy()
        np.fill_diagonal(A_off, 0.0)
        return degree_order(A_off)
    return np.asarray(leaves_list(Z_opt), dtype=int)

# --- PCA ordering: sort by first principal component coordinate (or by abs) ---
def pca_order(train_data, by_abs=True):
    """Rank variables by their loading on the first principal component.

    Parameters
    ----------
    train_data : array-like
        Anything ``_prepare_train_stack`` accepts; it is z-scored per column
        before PCA.
    by_abs : bool
        If True, order by decreasing absolute loading, so the sign of the
        component does not matter. Otherwise, order by increasing signed
        loading.

    Returns
    -------
    np.ndarray of int
        A permutation of ``range(C)``.
    """
    standardized = _prepare_train_stack(train_data)
    loadings = PCA(n_components=1).fit(standardized).components_[0]
    sort_keys = -np.abs(loadings) if by_abs else loadings
    return np.argsort(sort_keys).astype(int)

# --- greedy nearest-neighbor ordering on abs-correlation ---
def greedy_nearest(sim_mat, start=None):
    """Build an ordering by repeatedly stepping to the most similar unused variable.

    Parameters
    ----------
    sim_mat : array-like, shape (C, C)
        Similarity matrix. Absolute values are used, the diagonal is ignored
        and NaN entries are treated as 0. The input is not modified.
    start : int, optional
        First variable of the ordering. Defaults to the variable with the
        highest weighted degree (sum of |similarity| to the others).

    Returns
    -------
    np.ndarray of int
        A permutation of ``range(C)``.

    Raises
    ------
    ValueError
        If ``start`` is outside ``[0, C)``.

    When several candidates are equally similar to the current variable (for
    example, all similarities are 0), the one with the highest degree wins;
    any remaining tie goes to the lowest index.
    """
    A = np.abs(np.nan_to_num(np.asarray(sim_mat, dtype=float), nan=0.0))
    np.fill_diagonal(A, 0.0)
    C = A.shape[0]
    if C <= 1:
        return np.arange(C, dtype=int)
    degrees = A.sum(axis=1)
    if start is None:
        start = int(np.argmax(degrees))
    else:
        start = int(start)
        if not 0 <= start < C:
            raise ValueError(f"start must be in [0, {C}), got {start}")
    order = [start]
    used = np.zeros(C, dtype=bool)
    used[start] = True
    for _ in range(C - 1):
        candidates = np.flatnonzero(~used)
        scores = A[order[-1], candidates]
        # degree tie-breaker among the equally-best candidates
        tied = candidates[scores == scores.max()]
        nxt = int(tied[np.argmax(degrees[tied])])
        order.append(nxt)
        used[nxt] = True
    return np.asarray(order, dtype=int)

# --- random permutation (reproducible by seed) ---
def random_permutation(C, seed=0):
    """Return a random ordering of ``range(C)`` that is reproducible from ``seed``.

    Uses the legacy ``np.random.RandomState`` generator, so a given seed
    always yields the same permutation.
    """
    generator = np.random.RandomState(seed)
    shuffled = generator.permutation(C)
    return np.asarray(shuffled, dtype=int)

# --- wrapper that exposes unified API ---
def compute_permutation_from_train(train_data, method='spectral', seed=0):
    """Compute a variable permutation from training data.

    Parameters
    ----------
    train_data : array-like
        ``[T, C]`` or ``[N, T, C]`` (anything ``_prepare_train_stack`` accepts).
    method : str
        Case-insensitive; one of ``'spectral'``, ``'degree'``,
        ``'hierarchical'``, ``'greedy'`` (these use the column correlation
        matrix), ``'pca'`` or ``'random'``.
    seed : int
        Seed for the ``'random'`` method.

    Returns
    -------
    np.ndarray of int
        A permutation of ``range(C)``. Returns the identity when C <= 1.

    Raises
    ------
    ValueError
        For an unknown ``method``. This is checked before any data is processed.
    """
    method = str(method).lower()
    corr_methods = {
        'spectral': spectral_seriation,
        'degree': degree_order,
        'hierarchical': hierarchical_seriation,
        'greedy': greedy_nearest,
    }
    if method not in corr_methods and method not in ('pca', 'random'):
        raise ValueError(f"unknown permute method: {method}")

    X = _prepare_train_stack(train_data)
    C = X.shape[1]
    if C <= 1:
        return np.arange(C, dtype=int)

    if method == 'random':
        return random_permutation(C, seed=seed)
    if method == 'pca':
        return pca_order(X, by_abs=True)

    # The O(T * C^2) correlation is only needed by the graph-based methods.
    corr = np.nan_to_num(np.corrcoef(X, rowvar=False), nan=0.0)
    return np.asarray(corr_methods[method](corr), dtype=int)

def apply_permutation_array(arr, perm):
    """Reorder the last (variable) axis of a numpy array or torch tensor.

    Parameters
    ----------
    arr : np.ndarray, torch.Tensor or other
        Data whose last axis indexes variables. 1-D arrays/tensors and
        non-array objects are returned unchanged.
    perm : sequence of int or None
        Index order for the last axis. ``None`` means "no permutation".

    Returns
    -------
    Same kind as ``arr``
        A reordered copy, or ``arr`` itself when nothing is permuted.
    """
    if perm is None:
        return arr
    if isinstance(arr, np.ndarray):
        if arr.ndim == 1:
            return arr  # a lone vector is left as-is
        return arr[..., perm]
    if isinstance(arr, torch.Tensor):
        index = torch.as_tensor(perm, dtype=torch.long, device=arr.device)
        return arr if arr.dim() == 1 else arr.index_select(dim=-1, index=index)
    return arr

# ----------------- Dataset classes (modified) -----------------

class Dataset_ETT_hour(Dataset):
    """ETT hourly dataset with a fixed 12/4/4-month train/val/test split.

    Variables (columns) can optionally be reordered, either with an explicit
    ``permute_order`` or with a ``permute_method`` computed on the training
    rows via ``compute_permutation_from_train``. The permutation is applied
    before scaling, so ``data_x``/``data_y`` and the fitted scaler are in the
    permuted column order. ``inverse_transform`` maps results back to the
    CSV's column order.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h',
                 permute_method=None, permute_order=None, permute_seed=0):
        # size [seq_len, label_len, pred_len]
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]

        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path

        # permutation options
        self.permute_method = permute_method
        self.permute_order = permute_order
        self.permute_seed = permute_seed

        self.__read_data__()

    def _resolve_permutation(self, n_vars, train_values):
        """Return the column permutation to apply, or None for the identity.

        An explicit ``permute_order`` takes precedence and must be a
        permutation of ``range(n_vars)``. Otherwise, when ``permute_method``
        is set and there is more than one variable, the order is computed
        from ``train_values`` (training rows only, to avoid leakage).
        """
        if self.permute_order is not None:
            perm = np.asarray(self.permute_order, dtype=int)
            if perm.ndim != 1 or perm.shape[0] != n_vars:
                raise ValueError("permute_order length must equal number of variables (columns)")
            if not np.array_equal(np.sort(perm), np.arange(n_vars)):
                raise ValueError("permute_order must contain every column index exactly once")
            return perm
        if self.permute_method is not None and n_vars > 1:
            perm = compute_permutation_from_train(train_values, method=self.permute_method,
                                                  seed=self.permute_seed)
            return np.asarray(perm, dtype=int)
        return None

    def __read_data__(self):
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))

        border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len]
        border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            df_data = df_raw[df_raw.columns[1:]]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]
        else:
            raise ValueError(f"unknown features mode: {self.features}")

        # --- compute / apply permutation BEFORE scaling ---
        # NOTE(review): with features='MS' a permutation can move the target
        # column away from the last position; confirm downstream code does not
        # assume the target is the last column.
        perm = self._resolve_permutation(df_data.shape[1], df_data.values[border1s[0]:border2s[0]])
        if perm is not None:
            df_data = df_data.iloc[:, perm]
        self._perm = perm

        if self.scale:
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2].copy()
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            dates = df_stamp['date'].dt
            # columns: month, day, weekday, hour
            data_stamp = np.stack(
                [dates.month, dates.day, dates.weekday, dates.hour], axis=1
            ).astype(np.int64)
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)
        else:
            raise ValueError(f"unknown timeenc: {self.timeenc}")

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Undo scaling and the column permutation.

        ``data`` is a 2-D ``[n, C]`` numpy array or torch tensor in this
        dataset's (possibly permuted) column order. The result is a numpy
        array in the CSV's original column order.
        """
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        # The scaler was fit on the permuted columns, so inverse-scale first
        # and only then restore the original column order.
        restored = self.scaler.inverse_transform(data) if self.scale else np.asarray(data)
        if self._perm is not None:
            restored = restored[..., np.argsort(self._perm)]
        return restored


class Dataset_ETT_minute(Dataset):
    """ETT 15-minute dataset (4 rows per hour) with a fixed 12/4/4-month split.

    Variables (columns) can optionally be reordered, either with an explicit
    ``permute_order`` or with a ``permute_method`` computed on the training
    rows. The permutation is applied before scaling, so the stored data and
    the fitted scaler are in the permuted column order. ``inverse_transform``
    maps results back to the CSV's column order.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTm1.csv',
                 target='OT', scale=True, timeenc=0, freq='t',
                 permute_method=None, permute_order=None, permute_seed=0):
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]

        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path

        self.permute_method = permute_method
        self.permute_order = permute_order
        self.permute_seed = permute_seed

        self.__read_data__()

    def _resolve_permutation(self, n_vars, train_values):
        """Return the column permutation to apply, or None for the identity.

        An explicit ``permute_order`` takes precedence and must be a
        permutation of ``range(n_vars)``. Otherwise, when ``permute_method``
        is set and there is more than one variable, the order is computed
        from ``train_values`` (training rows only).
        """
        if self.permute_order is not None:
            perm = np.asarray(self.permute_order, dtype=int)
            if perm.ndim != 1 or perm.shape[0] != n_vars:
                raise ValueError("permute_order length must equal number of variables")
            if not np.array_equal(np.sort(perm), np.arange(n_vars)):
                raise ValueError("permute_order must contain every column index exactly once")
            return perm
        if self.permute_method is not None and n_vars > 1:
            perm = compute_permutation_from_train(train_values, method=self.permute_method,
                                                  seed=self.permute_seed)
            return np.asarray(perm, dtype=int)
        return None

    def __read_data__(self):
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))

        border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len]
        border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            df_data = df_raw[df_raw.columns[1:]]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]
        else:
            raise ValueError(f"unknown features mode: {self.features}")

        # permutation (before scaling, computed from training rows only)
        # NOTE(review): with features='MS' a permutation can move the target
        # column away from the last position; confirm downstream code does not
        # assume the target is the last column.
        perm = self._resolve_permutation(df_data.shape[1], df_data.values[border1s[0]:border2s[0]])
        if perm is not None:
            df_data = df_data.iloc[:, perm]
        self._perm = perm

        if self.scale:
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2].copy()
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            dates = df_stamp['date'].dt
            # columns: month, day, weekday, hour, 15-minute slot within the hour
            data_stamp = np.stack(
                [dates.month, dates.day, dates.weekday, dates.hour, dates.minute // 15], axis=1
            ).astype(np.int64)
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)
        else:
            raise ValueError(f"unknown timeenc: {self.timeenc}")

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Undo scaling and the column permutation.

        ``data`` is a 2-D ``[n, C]`` numpy array or torch tensor in this
        dataset's (possibly permuted) column order. The result is a numpy
        array in the CSV's original column order.
        """
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()  # torch tensors have no .copy()
        # The scaler was fit on the permuted columns: inverse-scale first,
        # then restore the original column order.
        restored = self.scaler.inverse_transform(data) if self.scale else np.asarray(data)
        if self._perm is not None:
            restored = restored[..., np.argsort(self._perm)]
        return restored


class Dataset_Custom(Dataset):
    """Generic CSV dataset with a chronological 70/10/20 train/val/test split.

    Columns are reordered to ``['date', <other columns>, target]``.
    Variables can then optionally be permuted, either with an explicit
    ``permute_order`` or with a ``permute_method`` computed on the training
    rows. The permutation is applied before scaling, so the stored data and
    the fitted scaler are in the permuted column order. ``inverse_transform``
    maps results back to the reordered (date-less) column order.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h',
                 permute_method=None, permute_order=None, permute_seed=0):
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]

        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path

        self.permute_method = permute_method
        self.permute_order = permute_order
        self.permute_seed = permute_seed

        self.__read_data__()

    def _resolve_permutation(self, n_vars, train_values):
        """Return the column permutation to apply, or None for the identity.

        An explicit ``permute_order`` takes precedence and must be a
        permutation of ``range(n_vars)``. Otherwise, when ``permute_method``
        is set and there is more than one variable, the order is computed
        from ``train_values`` (training rows only) and printed.
        """
        if self.permute_order is not None:
            perm = np.asarray(self.permute_order, dtype=int)
            if perm.ndim != 1 or perm.shape[0] != n_vars:
                raise ValueError("permute_order length must equal number of variables")
            if not np.array_equal(np.sort(perm), np.arange(n_vars)):
                raise ValueError("permute_order must contain every column index exactly once")
            return perm
        if self.permute_method is not None and n_vars > 1:
            perm = compute_permutation_from_train(train_values, method=self.permute_method,
                                                  seed=self.permute_seed)
            print("perm", perm)
            return np.asarray(perm, dtype=int)
        return None

    def __read_data__(self):
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))

        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]

        num_train = int(len(df_raw) * 0.7)
        num_test = int(len(df_raw) * 0.2)
        num_vali = len(df_raw) - num_train - num_test
        border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_vali, len(df_raw)]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            df_data = df_raw[df_raw.columns[1:]]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]
        else:
            raise ValueError(f"unknown features mode: {self.features}")

        # permutation (before scaling, computed from training rows only)
        # NOTE(review): df_raw puts the target last, but with features='MS' a
        # permutation moves it; confirm downstream code does not assume the
        # target is the last column.
        perm = self._resolve_permutation(df_data.shape[1], df_data.values[border1s[0]:border2s[0]])
        if perm is not None:
            df_data = df_data.iloc[:, perm]
        self._perm = perm

        if self.scale:
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2].copy()
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            dates = df_stamp['date'].dt
            # columns: month, day, weekday, hour
            data_stamp = np.stack(
                [dates.month, dates.day, dates.weekday, dates.hour], axis=1
            ).astype(np.int64)
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)
        else:
            raise ValueError(f"unknown timeenc: {self.timeenc}")

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Undo scaling and the column permutation.

        ``data`` is a 2-D ``[n, C]`` numpy array or torch tensor in this
        dataset's (possibly permuted) column order. The result is a numpy
        array in the un-permuted column order.
        """
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        # The scaler was fit on the permuted columns: inverse-scale first,
        # then restore the original column order.
        restored = self.scaler.inverse_transform(data) if self.scale else np.asarray(data)
        if self._perm is not None:
            restored = restored[..., np.argsort(self._perm)]
        return restored


class Dataset_Pred(Dataset):
    """Inference dataset: the last ``seq_len`` rows plus ``pred_len`` future timestamps.

    There is no train split here. The scaler is fit on all rows and, when
    ``permute_method`` is used, the permutation is also computed from all
    rows.

    NOTE(review): to match the channel order a model was trained with, you
    probably want to pass the training dataset's ``_perm`` as
    ``permute_order``; a permutation recomputed here need not match it.
    """

    def __init__(self, root_path, flag='pred', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, inverse=False, timeenc=0, freq='15min', cols=None,
                 permute_method=None, permute_order=None, permute_seed=0):
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]

        assert flag in ['pred']

        self.features = features
        self.target = target
        self.scale = scale
        self.inverse = inverse
        self.timeenc = timeenc
        self.freq = freq
        self.cols = cols
        self.root_path = root_path
        self.data_path = data_path

        self.permute_method = permute_method
        self.permute_order = permute_order
        self.permute_seed = permute_seed

        self.__read_data__()

    def _resolve_permutation(self, n_vars, values):
        """Return the column permutation to apply, or None for the identity.

        An explicit ``permute_order`` takes precedence and must be a
        permutation of ``range(n_vars)``. Otherwise, when ``permute_method``
        is set and there is more than one variable, the order is computed
        from ``values``.
        """
        if self.permute_order is not None:
            perm = np.asarray(self.permute_order, dtype=int)
            if perm.ndim != 1 or perm.shape[0] != n_vars:
                raise ValueError("permute_order length must equal number of variables")
            if not np.array_equal(np.sort(perm), np.arange(n_vars)):
                raise ValueError("permute_order must contain every column index exactly once")
            return perm
        if self.permute_method is not None and n_vars > 1:
            perm = compute_permutation_from_train(values, method=self.permute_method,
                                                  seed=self.permute_seed)
            return np.asarray(perm, dtype=int)
        return None

    def __read_data__(self):
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))

        if self.cols:
            cols = self.cols.copy()
            cols.remove(self.target)
        else:
            cols = list(df_raw.columns)
            cols.remove(self.target)
            cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        border1 = len(df_raw) - self.seq_len
        border2 = len(df_raw)

        if self.features == 'M' or self.features == 'MS':
            df_data = df_raw[df_raw.columns[1:]]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]
        else:
            raise ValueError(f"unknown features mode: {self.features}")

        # Permutation: use the given permute_order, or compute one from all
        # rows (the scaler below is fit on all rows as well). Previously only
        # the first seq_len rows were used, which gives a noisy estimate.
        perm = self._resolve_permutation(df_data.shape[1], df_data.values)
        if perm is not None:
            df_data = df_data.iloc[:, perm]
        self._perm = perm

        if self.scale:
            self.scaler.fit(df_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        tmp_stamp = df_raw[['date']][border1:border2].copy()
        tmp_stamp['date'] = pd.to_datetime(tmp_stamp.date)
        pred_dates = pd.date_range(tmp_stamp.date.values[-1], periods=self.pred_len + 1, freq=self.freq)

        # observed timestamps followed by the pred_len future ones
        all_dates = pd.to_datetime(list(tmp_stamp.date.values) + list(pred_dates[1:]))
        df_stamp = pd.DataFrame({'date': all_dates})
        if self.timeenc == 0:
            dates = df_stamp['date'].dt
            # columns: month, day, weekday, hour, 15-minute slot within the hour
            data_stamp = np.stack(
                [dates.month, dates.day, dates.weekday, dates.hour, dates.minute // 15], axis=1
            ).astype(np.int64)
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)
        else:
            raise ValueError(f"unknown timeenc: {self.timeenc}")

        self.data_x = data[border1:border2]
        if self.inverse:
            self.data_y = df_data.values[border1:border2]
        else:
            self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        if self.inverse:
            seq_y = self.data_x[r_begin:r_begin + self.label_len]
        else:
            seq_y = self.data_y[r_begin:r_begin + self.label_len]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        return len(self.data_x) - self.seq_len + 1

    def inverse_transform(self, data):
        """Undo scaling and the column permutation.

        ``data`` is a 2-D ``[n, C]`` numpy array or torch tensor in this
        dataset's (possibly permuted) column order. The result is a numpy
        array in the un-permuted column order.
        """
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()  # torch tensors have no .copy()
        # The scaler was fit on the permuted columns: inverse-scale first,
        # then restore the original column order.
        restored = self.scaler.inverse_transform(data) if self.scale else np.asarray(data)
        if self._perm is not None:
            restored = restored[..., np.argsort(self._perm)]
        return restored


class Dataset_PEMS(Dataset):
    """PEMS traffic dataset loaded from an ``.npz`` archive's ``'data'`` array.

    Uses a chronological 60/20/20 train/val/test split along the first (time)
    axis. Nodes (columns) can optionally be permuted, either with an explicit
    ``permute_order`` or with a ``permute_method`` computed on the training
    split. Scaling statistics come from the (permuted) training split only,
    as in the other datasets in this module. There are no calendar features,
    so the ``*_mark`` outputs are zero tensors.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h',
                 permute_method=None, permute_order=None, permute_seed=0):
        self.seq_len = size[0]
        self.label_len = size[1]
        self.pred_len = size[2]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq
        self.root_path = root_path
        self.data_path = data_path

        self.permute_method = permute_method
        self.permute_order = permute_order
        self.permute_seed = permute_seed

        self.__read_data__()

    def _resolve_permutation(self, n_vars, train_values):
        """Return the column permutation to apply, or None for the identity.

        An explicit ``permute_order`` takes precedence and must be a
        permutation of ``range(n_vars)``. Otherwise, when ``permute_method``
        is set and there is more than one variable, the order is computed
        from ``train_values`` (the training split only).
        """
        if self.permute_order is not None:
            perm = np.asarray(self.permute_order, dtype=int)
            if perm.ndim != 1 or perm.shape[0] != n_vars:
                raise ValueError("permute_order length must equal number of variables")
            if not np.array_equal(np.sort(perm), np.arange(n_vars)):
                raise ValueError("permute_order must contain every column index exactly once")
            return perm
        if self.permute_method is not None and n_vars > 1:
            perm = compute_permutation_from_train(train_values, method=self.permute_method,
                                                  seed=self.permute_seed)
            return np.asarray(perm, dtype=int)
        return None

    def __read_data__(self):
        self.scaler = StandardScaler()
        data_file = os.path.join(self.root_path, self.data_path)
        # NOTE(review): allow_pickle=True permits arbitrary object
        # deserialisation; only load trusted files.
        loaded = np.load(data_file, allow_pickle=True)
        try:
            data = np.asarray(loaded['data'])
        finally:
            # .npz archives come back as an open NpzFile; release the handle
            if hasattr(loaded, 'close'):
                loaded.close()

        # PEMS arrays are taken to be [time, nodes, features]: keep the first
        # feature channel so that each node is one variable. Previously only a
        # trailing size-1 axis was squeezed, and multi-feature files were
        # misread as [samples, time, channels].
        if data.ndim == 3:
            data = data[:, :, 0]
        if data.ndim != 2:
            raise ValueError(
                f"expected PEMS data shaped [time, nodes] or [time, nodes, features], got {data.shape}")

        # chronological split along the time axis
        train_ratio = 0.6
        valid_ratio = 0.2
        train_data = data[:int(train_ratio * len(data))]
        valid_data = data[int(train_ratio * len(data)): int((train_ratio + valid_ratio) * len(data))]
        test_data = data[int((train_ratio + valid_ratio) * len(data)):]
        data_sel = [train_data, valid_data, test_data][self.set_type]

        # permutation on the node axis, computed from the training split
        perm = self._resolve_permutation(data.shape[1], train_data)
        if perm is not None:
            train_data = apply_permutation_array(train_data, perm)
            data_sel = apply_permutation_array(data_sel, perm)
        self._perm = perm

        if self.scale:
            # Fit on the training split only (previously the selected split
            # was used, leaking val/test statistics into normalisation).
            self.scaler.fit(train_data)
            data_sel = self.scaler.transform(data_sel)

        self.data_x = data_sel
        self.data_y = data_sel

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        # no calendar features: zero marks sized to their own sequences
        seq_x_mark = torch.zeros((seq_x.shape[0], 1))
        seq_y_mark = torch.zeros((seq_y.shape[0], 1))

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Undo scaling and the node permutation.

        ``data`` is a 2-D ``[n, C]`` numpy array or torch tensor in this
        dataset's (possibly permuted) column order. The result is a numpy
        array in the file's original node order.
        """
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        # The scaler was fit on the permuted columns: inverse-scale first,
        # then restore the original column order.
        restored = self.scaler.inverse_transform(data) if self.scale else np.asarray(data)
        if self._perm is not None:
            restored = restored[..., np.argsort(self._perm)]
        return restored


class Dataset_Solar(Dataset):
    """Solar dataset read from a headerless comma-separated text file.

    Every non-blank line is one time step and every column one variable. Uses
    a chronological 70/10/20 train/val/test split. Variables can optionally
    be permuted, either with an explicit ``permute_order`` or with a
    ``permute_method`` computed on the training rows. The permutation is
    applied before scaling. There are no calendar features, so the ``*_mark``
    outputs are zero tensors.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h',
                 permute_method=None, permute_order=None, permute_seed=0):
        self.seq_len = size[0]
        self.label_len = size[1]
        self.pred_len = size[2]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]
        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path

        self.permute_method = permute_method
        self.permute_order = permute_order
        self.permute_seed = permute_seed

        self.__read_data__()

    def _resolve_permutation(self, n_vars, train_values):
        """Return the column permutation to apply, or None for the identity.

        An explicit ``permute_order`` takes precedence and must be a
        permutation of ``range(n_vars)``. Otherwise, when ``permute_method``
        is set and there is more than one variable, the order is computed
        from ``train_values`` (training rows only).
        """
        if self.permute_order is not None:
            perm = np.asarray(self.permute_order, dtype=int)
            if perm.ndim != 1 or perm.shape[0] != n_vars:
                raise ValueError("permute_order length must equal number of variables")
            if not np.array_equal(np.sort(perm), np.arange(n_vars)):
                raise ValueError("permute_order must contain every column index exactly once")
            return perm
        if self.permute_method is not None and n_vars > 1:
            perm = compute_permutation_from_train(train_values, method=self.permute_method,
                                                  seed=self.permute_seed)
            return np.asarray(perm, dtype=int)
        return None

    def __read_data__(self):
        self.scaler = StandardScaler()
        rows = []
        with open(os.path.join(self.root_path, self.data_path), "r", encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing lines
                rows.append([float(v) for v in line.split(',')])
        if not rows:
            raise ValueError(f"no data rows found in {self.data_path}")
        raw = np.asarray(rows, dtype=float)

        num_train = int(len(raw) * 0.7)
        num_test = int(len(raw) * 0.2)
        num_valid = int(len(raw) * 0.1)
        border1s = [0, num_train - self.seq_len, len(raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_valid, len(raw)]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        # permutation (before scaling, computed from training rows only)
        perm = self._resolve_permutation(raw.shape[1], raw[border1s[0]:border2s[0]])
        if perm is not None:
            raw = apply_permutation_array(raw, perm)
        self._perm = perm

        if self.scale:
            self.scaler.fit(raw[border1s[0]:border2s[0]])
            data = self.scaler.transform(raw)
        else:
            data = raw

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        # no calendar features: zero marks sized to their own sequences
        seq_x_mark = torch.zeros((seq_x.shape[0], 1))
        seq_y_mark = torch.zeros((seq_y.shape[0], 1))

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Undo scaling and the column permutation.

        ``data`` is a 2-D ``[n, C]`` numpy array or torch tensor in this
        dataset's (possibly permuted) column order. The result is a numpy
        array in the file's original column order.
        """
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        # The scaler was fit on the permuted columns: inverse-scale first,
        # then restore the original column order.
        restored = self.scaler.inverse_transform(data) if self.scale else np.asarray(data)
        if self._perm is not None:
            restored = restored[..., np.argsort(self._perm)]
        return restored
