import os
import pandas as pd
import numpy as np

# Column labels applied to the header-less data files by `_read_one`:
# engine id, cycle counter, setting_1..setting_3, then sensor_1..sensor_22.
# NOTE(review): the sensor range ends at 22 -- confirm this matches the
# number of columns actually present in the data files.
COLUMNS = [
    'engine',
    'cycle',
    *(f'setting_{n}' for n in range(1, 4)),
    *(f'sensor_{n}' for n in range(1, 23)),
]

def _read_one(file_path: str) -> pd.DataFrame:
    """Parse one header-less, whitespace-delimited data file.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A DataFrame whose columns are labelled with the module-level
        ``COLUMNS`` list.
    """
    return pd.read_csv(
        file_path,
        names=COLUMNS,
        header=None,
        sep=r'\s+',  # any run of whitespace separates two fields
    )

def _add_rul(df: pd.DataFrame) -> pd.DataFrame:
    """Attach a remaining-useful-life column to a run-to-failure frame.

    For every row, ``RUL`` is the largest ``cycle`` recorded for that row's
    ``engine`` minus the row's own ``cycle``.

    Args:
        df: Frame with at least ``engine`` and ``cycle`` columns. It is
            modified in place (a ``RUL`` column is added or overwritten).

    Returns:
        The same ``df`` object, now carrying the ``RUL`` column.
    """
    cycles = df['cycle']
    # Broadcast each engine's last recorded cycle back onto its rows.
    last_cycle = cycles.groupby(df['engine']).transform('max')
    df['RUL'] = last_cycle.sub(cycles)
    return df

def load_cmapss(subset='FD001') -> tuple:
    """Load the training split of a CMAPSS subset as features and RUL labels.

    A fixed list of candidate directories (relative paths) is probed in order
    for ``train_<subset>.txt``; the first directory containing it is used.

    Args:
        subset: Subset identifier used to build the file name,
            e.g. ``'FD001'`` -> ``train_FD001.txt``.

    Returns:
        A 3-tuple ``(X, y, feature_names)``:
        * ``X`` -- float32 DataFrame of the setting/sensor columns, with
          every column holding at most one distinct value removed;
        * ``y`` -- float32 array of shape ``(n_rows, 1)`` with the RUL;
        * ``feature_names`` -- list of the column labels kept in ``X``.

    Raises:
        FileNotFoundError: If no candidate directory contains the file.
    """
    train_name = f'train_{subset}.txt'

    # Candidate dataset locations, probed in this order.
    candidate_dirs = (
        os.path.join('raw_data', 'CMaps'),
        os.path.join('..', '..', 'raw_data', 'CMaps'),
        os.path.join('..', '..', 'DANCEST_model', 'data', 'CMAPSSData'),
        os.path.join('..', 'data', 'CMAPSSData'),
        os.path.join('data', 'CMAPSSData'),
    )
    base = next(
        (d for d in candidate_dirs
         if os.path.exists(os.path.join(d, train_name))),
        None,
    )
    if base is None:
        raise FileNotFoundError(f"Could not find CMAPSS dataset files for {subset}")

    print(f"Loading CMAPSS dataset from: {base}")

    frame = _add_rul(_read_one(os.path.join(base, train_name)))

    # Labels come from the RUL column; everything except the identifiers
    # and the label itself is a feature.
    y = frame['RUL'].astype('float32').values.reshape(-1, 1)
    X = frame.drop(['engine', 'cycle'], axis=1).drop('RUL', axis=1).astype('float32')

    # Columns with a single distinct value (or none at all) carry no signal.
    constant_columns = [name for name in X.columns if X[name].nunique() <= 1]
    if constant_columns:
        print(f"Dropping {len(constant_columns)} constant columns: {constant_columns}")
        X = X.drop(constant_columns, axis=1)

    return X, y, X.columns.tolist()

# Example function to load test data separately
def load_test_data(subset='FD001', *, feature_names=None):
    """Load the test split of a CMAPSS subset with a per-row RUL label.

    The test file is located by probing a fixed list of candidate
    directories (relative paths) for ``test_<subset>.txt``. The matching
    ``RUL_<subset>.txt`` file is expected to hold one value per engine, in
    ascending engine-id order: the remaining useful life at that engine's
    last recorded cycle. Each row's label is therefore::

        RUL(row) = final_RUL(engine) + last_cycle(engine) - cycle(row)

    Args:
        subset: Subset identifier used to build the file names,
            e.g. ``'FD001'``.
        feature_names: Optional iterable of column labels (for instance the
            ``feature_names`` returned by ``load_cmapss``). When given,
            exactly these columns are returned, in this order, so the test
            features line up with the training features. When ``None``
            (default), columns holding at most one distinct value *in the
            test data* are dropped instead, which may not match the columns
            kept for the training data.

    Returns:
        A 2-tuple ``(X_test, y_test)``: a float32 DataFrame of features and
        a float32 array of shape ``(n_rows, 1)`` with the RUL labels.

    Raises:
        FileNotFoundError: If the test file (or the RUL file) is not found.
        ValueError: If the number of RUL values differs from the number of
            engines in the test file.
        KeyError: If ``feature_names`` names a column that is not present.
    """
    # Check multiple possible locations for the CMAPSS dataset
    possible_paths = [
        os.path.join('raw_data', 'CMaps'),
        os.path.join('..', '..', 'raw_data', 'CMaps'),
        os.path.join('..', '..', 'DANCEST_model', 'data', 'CMAPSSData'),
        os.path.join('..', 'data', 'CMAPSSData'),
        os.path.join('data', 'CMAPSSData'),
    ]
    test_name = f'test_{subset}.txt'
    base = next(
        (path for path in possible_paths
         if os.path.exists(os.path.join(path, test_name))),
        None,
    )
    if base is None:
        raise FileNotFoundError(f"Could not find CMAPSS dataset files for {subset}")

    # Load test data
    test_df = _read_one(os.path.join(base, test_name))

    # Load true RUL values
    rul_fp = os.path.join(base, f'RUL_{subset}.txt')
    if not os.path.exists(rul_fp):
        # Fall back to an alternative file name built from the subset suffix.
        rul_fp = os.path.join(base, f'RUL_FD00{subset[2:]}.txt')

    # Whitespace-delimited like the data files; keep the single value column
    # as a 1-D Series rather than a 2-D array.
    final_rul = pd.read_csv(rul_fp, sep=r'\s+', header=None).iloc[:, 0]

    # Last recorded cycle per engine; groupby sorts the index by engine id,
    # which is the order the RUL file rows are matched against.
    last_cycle = test_df.groupby('engine')['cycle'].max()
    if len(final_rul) != len(last_cycle):
        raise ValueError(
            f"RUL file {rul_fp} has {len(final_rul)} values but the test "
            f"data contains {len(last_cycle)} engines"
        )
    rul_at_end = pd.Series(final_rul.to_numpy(), index=last_cycle.index)

    # Per-row RUL: use each engine's OWN last cycle, not the maximum cycle
    # over the whole file, otherwise shorter traces get inflated labels.
    engine = test_df['engine']
    rul = engine.map(rul_at_end) + engine.map(last_cycle) - test_df['cycle']

    # Feature engineering - keep only settings+sensor
    X_test = test_df.drop(['engine', 'cycle'], axis=1).astype('float32')
    y_test = rul.astype('float32').values.reshape(-1, 1)

    if feature_names is not None:
        # Align with an externally chosen feature set (e.g. from training).
        X_test = X_test[list(feature_names)]
    else:
        # Drop columns that are constant (or entirely empty) in the test data.
        constant_columns = [
            col for col in X_test.columns if X_test[col].nunique() <= 1
        ]
        if constant_columns:
            X_test = X_test.drop(constant_columns, axis=1)

    return X_test, y_test