import os
import pickle
import torch
import numpy as np
from sklearn.preprocessing import LabelEncoder
from aeon.datasets import load_classification

def _load_raw_dataset(dataset_name, data_dir, logger):
    """Return ``(X, y)`` for *dataset_name*, using an on-disk pickle cache.

    On a cache miss the dataset is fetched with aeon's ``load_classification``
    and written to ``<data_dir>/<dataset_name>.pkl`` for the next call.
    """
    save_path = os.path.join(data_dir, f"{dataset_name}.pkl")
    if os.path.exists(save_path):
        # NOTE(review): pickle.load can execute arbitrary code. This is only
        # acceptable because the cache is meant to be written by this
        # function; never point data_dir at untrusted files.
        with open(save_path, 'rb') as f:
            data = pickle.load(f)
        return data['X'], data['y']

    X, y, _ = load_classification(dataset_name, return_metadata=True)
    cache_dir = os.path.dirname(save_path)
    # os.makedirs('') raises FileNotFoundError; that happens when data_dir
    # is '' (i.e. cache in the current working directory).
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    with open(save_path, 'wb') as f:
        pickle.dump({'X': X, 'y': y}, f)
    logger.info("Data was downloaded and saved.")
    return X, y


def _encode_labels(y, logger):
    """Map class labels onto contiguous integer ids ``0..n_classes-1``.

    Non-numeric labels go through sklearn's ``LabelEncoder``. Numeric labels
    are remapped with ``np.unique(..., return_inverse=True)``, which yields
    the same sorted mapping. A plain ``astype(int)`` would leave labels such
    as ``{1, 2}`` or ``{-1, 1}`` outside the ``[0, n_classes)`` range, and
    would merge distinct float labels that truncate to the same integer.
    Labels that are already ``0..n_classes-1`` come out unchanged.
    """
    y = np.asarray(y)
    if not np.issubdtype(y.dtype, np.number):
        logger.info("Non-numeric labels detected, encoding with LabelEncoder.")
        return LabelEncoder().fit_transform(y)

    classes, encoded = np.unique(y, return_inverse=True)
    if not np.array_equal(classes, np.arange(len(classes))):
        logger.info(f"Numeric labels {classes.tolist()} remapped to 0..{len(classes) - 1}.")
    # Some NumPy versions flatten the inverse; restore the input's shape.
    return encoded.reshape(y.shape)


def _build_timestamps(time_scaling, seq_len, dtype, logger):
    """Build the shared 1-D timestamp vector of length *seq_len*.

    ``time_scaling`` may be:
      * ``'uni'`` (case-insensitive): ``linspace(0, 1, seq_len)``;
      * ``'no'`` (case-insensitive): raw steps ``0, 1, 2, ...``;
      * a number or numeric string ``s``: evenly spaced values from 0 to ``s``.

    Raises:
        ValueError: if ``time_scaling`` is none of the above.
    """
    t_original = torch.arange(0, seq_len, dtype=dtype)

    if isinstance(time_scaling, str):
        mode = time_scaling.lower()
        if mode == 'uni':
            logger.info("Time scaling set to 'uni'. Force creating timestamps linspace(0, 1).")
            return torch.linspace(0, 1, seq_len, dtype=dtype)
        if mode == 'no':
            logger.info("Time scaling set to 'no'. Using regular timestamps (0, 1, 2...).")
            return t_original
        source = " (from str)"
        error_msg = f"Invalid time_scaling_factor: {time_scaling}. Must be 'uni', 'no', or a number."
    else:
        source = ""
        error_msg = f"Invalid time_scaling_factor: {time_scaling}. Must be a number, 'uni' or 'no'."

    try:
        ts_val = float(time_scaling)
    except (TypeError, ValueError) as err:
        # TypeError covers non-numeric objects such as None.
        raise ValueError(error_msg) from err

    # A length-1 series has the single timestamp 0; guard the divisor so it
    # does not raise ZeroDivisionError.
    norm_const = float(max(seq_len - 1, 1))
    timestamps = (ts_val / norm_const) * t_original
    logger.info(f"Time scaling factor{source}: {ts_val}. Normalized timestamps.")
    return timestamps


def _standardize_in_place(train_X, other_splits, start_feat_idx):
    """Z-score channels ``start_feat_idx:`` of every split using train statistics.

    Tensors are indexed as (samples, seq_len, channels) and modified in place.
    """
    for feature_idx in range(start_feat_idx, train_X.shape[2]):
        channel = train_X[:, :, feature_idx]
        mean = channel.mean()
        std = channel.std() + 1e-8  # epsilon avoids division by zero for constant channels
        for split in (train_X, *other_splits):
            split[:, :, feature_idx] = (split[:, :, feature_idx] - mean) / std


def _log_inspection(train_X, train_y, timestamps, add_time, logger):
    """Log a short post-normalization snapshot of the training data."""
    logger.info("\n" + "="*40)
    logger.info("      DATA INSPECTION (After Norm)      ")
    logger.info("="*40)

    if add_time:
        ts_snippet = train_X[0, :10, 0]
        ts_min = train_X[:, :, 0].min().item()
        ts_max = train_X[:, :, 0].max().item()
        feat_snippet = train_X[0, :5, 1:]
        all_feats = train_X[:, :, 1:]
        logger.info(f"TIMESTAMPS (Sample 0, first 10 steps, embedded):\n{ts_snippet}")
        logger.info(f"TIMESTAMPS RANGE (Embedded): Min={ts_min:.4f}, Max={ts_max:.4f}")
    else:
        ts_snippet = timestamps[:10]
        ts_min = timestamps.min().item()
        ts_max = timestamps.max().item()
        feat_snippet = train_X[0, :5, :]
        all_feats = train_X
        logger.info(f"TIMESTAMPS (Separate Tensor, first 10 steps):\n{ts_snippet}")
        logger.info(f"TIMESTAMPS RANGE (Separate): Min={ts_min:.4f}, Max={ts_max:.4f}")

    logger.info(f"FEATURES (Sample 0, first 5 steps):\n{feat_snippet}")
    logger.info(f"FEATURES STATS (Global Train): Mean={all_feats.mean():.4f}, Std={all_feats.std():.4f}")

    logger.info(f"LABELS (First 15): {train_y[:15].tolist()}")
    logger.info("="*40 + "\n")


def get_data(cfg, logger):
    """
    Load, preprocess, shuffle, split and standardize the dataset.

    Args:
        cfg: mapping with the keys
            'dataset_name'        -- aeon classification dataset name;
            'data_dir'            -- directory holding the ``<name>.pkl`` cache;
            'time_scaling_factor' -- 'uni', 'no', or a number (see
                                     ``_build_timestamps``);
            'add_time' (optional, default 'yes') -- when 'yes'
                (case-insensitive), timestamps are prepended as channel 0 of X.
        logger: object with an ``info`` method, used for progress logging.

    Returns:
        train_X, val_X, test_X: float32 tensors of shape
            (n_split, seq_len, n_channels [+1 if add_time]). Feature channels
            are z-scored with training-set mean/std; the time channel is not.
        train_y, val_y, test_y: int64 label tensors with ids 0..n_classes-1.
        timestamps: 1-D float32 tensor of length seq_len.

    Samples are shuffled with ``torch.randperm`` (global torch RNG) and split
    60/20/20 into train/val/test.

    Raises:
        ValueError: if ``time_scaling_factor`` is not 'uni', 'no' or numeric.
    """
    dataset_name = cfg['dataset_name']
    logger.info(f"--- Loading dataset: {dataset_name} ---")
    X, y = _load_raw_dataset(dataset_name, cfg['data_dir'], logger)

    y = _encode_labels(y, logger)
    n_classes = len(np.unique(y))
    logger.info(f"Dataset '{dataset_name}' loaded. Shape: {X.shape}, Classes: {n_classes}")

    X = torch.tensor(X, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.int64)

    # X is indexed as (samples, channels, seq_len) at this point.
    seq_len = X.shape[2]
    timestamps = _build_timestamps(cfg['time_scaling_factor'], seq_len, X.dtype, logger)

    add_time = str(cfg.get('add_time', 'yes')).lower() == 'yes'

    X = X.permute(0, 2, 1)  # -> (samples, seq_len, channels)
    if add_time:
        # Share the single timestamp vector across all samples as channel 0.
        time_channel = timestamps.reshape(1, -1, 1).expand(X.size(0), -1, 1)
        X = torch.cat([time_channel, X], dim=2)
        logger.info("Time channel appended to X.")
    else:
        logger.info("Time channel NOT appended to X (kept separate).")

    perm = torch.randperm(X.size(0))
    X, y = X[perm], y[perm]
    logger.info(f"Data prepared. Final X shape: {X.shape}")

    train_frac, val_frac = 0.6, 0.2
    num_samples = X.shape[0]
    train_end = int(num_samples * train_frac)
    val_end = int(num_samples * (train_frac + val_frac))

    train_X, val_X, test_X = X[:train_end], X[train_end:val_end], X[val_end:]
    train_y, val_y, test_y = y[:train_end], y[train_end:val_end], y[val_end:]
    logger.info(f"Train/Val/Test split: {len(train_X)}/{len(val_X)}/{len(test_X)}")

    logger.info("Standardizing features based on training set statistics.")
    # With add_time on, channel 0 is time and must stay unscaled.
    _standardize_in_place(train_X, (val_X, test_X), 1 if add_time else 0)

    _log_inspection(train_X, train_y, timestamps, add_time, logger)

    return train_X, val_X, test_X, train_y, val_y, test_y, timestamps