import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pandas as pd
from kaggle.api.kaggle_api_extended import KaggleApi
import zipfile
from utils import check_data

################################################################################
# (1) DATA LOADING
################################################################################
print("=" * 50)
print("STEP 1: DATA LOADING")
print("=" * 50)

# Set paths. __file__ is made absolute first: when it is a bare relative file
# name, dirname() alone returns '' and every path below (including the Kaggle
# download directory) would silently depend on the current working directory.
base_dir = os.path.dirname(os.path.abspath(__file__))
RAW_PATH = os.path.join(base_dir, 'raw.csv')
OUT_PATH = os.path.join(base_dir, 'quasar.csv')

# Authenticate with Kaggle API
api = KaggleApi()
api.authenticate()

# Download dataset ZIP (don't unzip automatically)
dataset_name = 'fedesoriano/stellar-classification-dataset-sdss17'
api.dataset_download_files(dataset_name, path=base_dir, unzip=False)

# Unzip manually and store the CSV as 'raw.csv'
zip_path = os.path.join(base_dir, dataset_name.split('/')[-1] + '.zip')
csv_found = False
with zipfile.ZipFile(zip_path, 'r') as archive:
    for member in archive.namelist():
        if member.endswith('.csv'):
            # extract() returns the path it actually wrote, so the move does
            # not have to re-derive it from the member name.
            extracted_path = archive.extract(member, base_dir)
            # os.replace overwrites an existing raw.csv on every platform;
            # os.rename raises on Windows when the target already exists.
            os.replace(extracted_path, RAW_PATH)
            csv_found = True

if not csv_found:
    # Fail here (archive left in place for inspection) instead of letting
    # read_csv pick up a stale raw.csv or fail with an unrelated message.
    raise FileNotFoundError(f"No .csv file found in downloaded archive: {zip_path}")

os.remove(zip_path)

# Load and process CSV
df = pd.read_csv(RAW_PATH)

print("STEP 1 COMPLETED: Data loaded and raw file saved")
print("=" * 50)

################################################################################
# (2) FORMAT
################################################################################
print("STEP 2: FORMAT", "=" * 50, sep="\n")

# Binary target: rows whose 'class' is 'QSO' (quasar) are the anomaly (1),
# every other row is normal (0).
is_quasar = df['class'] == 'QSO'
df['label'] = is_quasar.astype(int)

# Identification / metadata columns to discard so that only the physical
# features (plus the new label) remain. 'class' goes too, now that it has
# been encoded into 'label'.
columns_to_drop = [
    'obj_ID',
    'run_ID',
    'rerun_ID',
    'cam_col',
    'field_ID',
    'spec_obj_ID',
    'plate',
    'MJD',
    'fiber_ID',
    'class',
]
# errors='ignore' skips any listed column that is absent from the frame.
df = df.drop(columns=columns_to_drop, errors='ignore')

print("STEP 2 COMPLETED: Data formatted and cleaned", "=" * 50, sep="\n")

################################################################################
# (3) VALIDATION
################################################################################
print("STEP 3: VALIDATION")
print("=" * 50)

# check_data comes from the project's utils module; this script treats a
# None return value as "validation failed" (presumably it otherwise returns
# the validated frame -- confirm against utils.check_data).
df = check_data(df)

if df is None:
    print("ERROR: Data validation failed!")
    # Exit with a non-zero status so callers (shell scripts, CI) can detect
    # the failure. The bare exit() helper is injected by the site module, is
    # not guaranteed to exist, and would have reported success (status 0).
    sys.exit(1)

print("STEP 3 COMPLETED: Data validation passed")
print("=" * 50)

################################################################################
# (4) POSTPROCESSING & SAVE
################################################################################
print("STEP 4: POSTPROCESSING & SAVE", "=" * 50, sep="\n")

# Write the processed frame to OUT_PATH, leaving out the pandas row index.
df.to_csv(OUT_PATH, index=False)

print("STEP 4 COMPLETED: Final data saved", "=" * 50, sep="\n")
print("ALL PREPROCESSING STEPS COMPLETED!", "=" * 50, sep="\n")