import numpy as np
import pandas as pd
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
import os
import random

# Pin the global RNG state of both Python's `random` module and NumPy's legacy
# global generator so any code drawing from them is repeatable across runs.
# (train_test_split below uses its own explicit random_state.)
seed = 2025
np.random.seed(seed)
random.seed(seed)


def scaling_split(dataset, TASKS, DataPath, test_size=0.3, random_state=42):
    """Split every task into train/test, scale features jointly, and save the splits.

    Each task is loaded once via ``load_task_dataframe``. Its rows are split
    into train/test partitions, stratified on the label for 'Landmine' and
    'Chemical'. The feature scaler is ``StandardScaler`` for 'Landmine' and
    ``QuantileTransformer(n_quantiles=30)`` otherwise. A single scaler is
    fitted on the pooled *training* rows of all tasks. It is then applied to
    both partitions of every task, so test rows never influence the scaling.
    Labels are returned unscaled.

    The arrays are also written as ``{task_id}_{X_train|X_test|y_train|y_test}.npy``
    under ``<DataPath>/Task_Splits``.

    Args:
        dataset: Dataset name passed to ``load_task_dataframe``.
        TASKS: Iterable of task identifiers.
        DataPath: Directory holding the task CSV files; splits are saved below it.
        test_size: Fraction of each task's rows held out for testing.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        dict mapping task_id -> {'X_train', 'X_test', 'y_train', 'y_test'} arrays.

    Raises:
        ValueError: If ``TASKS`` is empty, or (from ``load_task_dataframe``)
            if ``dataset`` is unsupported.
    """
    # Split the raw data first. train_test_split's row assignment depends only
    # on the number of rows, the labels (when stratifying) and random_state --
    # not on feature values -- so the partitions match splitting after scaling.
    raw_splits = {}
    for task_id in TASKS:
        df, target_cols = load_task_dataframe(dataset, DataPath, task_id)
        X = df.drop(columns=target_cols).values
        y = df[target_cols[0]].values
        stratify = y if dataset in ['Landmine', 'Chemical'] else None
        raw_splits[task_id] = train_test_split(
            X, y, test_size=test_size, random_state=random_state, stratify=stratify)

    if not raw_splits:
        raise ValueError("TASKS must contain at least one task")

    if dataset == 'Landmine':
        scaler = preprocessing.StandardScaler()
    else:
        scaler = preprocessing.QuantileTransformer(n_quantiles=30)
    # Fit on training rows only so test-set statistics do not leak into the
    # features the model is trained on.
    scaler.fit(np.vstack([parts[0] for parts in raw_splits.values()]))

    save_dir = os.path.join(DataPath, "Task_Splits")
    os.makedirs(save_dir, exist_ok=True)

    all_splits = {}
    for task_id, (X_train, X_test, y_train, y_test) in raw_splits.items():
        split = {
            'X_train': scaler.transform(X_train),
            'X_test': scaler.transform(X_test),
            'y_train': y_train,
            'y_test': y_test,
        }
        all_splits[task_id] = split
        for name, array in split.items():
            np.save(os.path.join(save_dir, f"{task_id}_{name}.npy"), array)

    return all_splits

def load_task_dataframe(dataset, DataPath, task_id):
    """Read one task's CSV for ``dataset`` and return it with its label column.

    Args:
        dataset: One of 'School', 'Landmine' or 'Chemical'.
        DataPath: Directory prefix; the file name is appended after a '/'.
        task_id: Task identifier interpolated into the file name.

    Returns:
        (df, target_col): the dataframe with every row containing a NaN removed,
        and a one-element list naming the label column.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == 'Landmine':
        frame = pd.read_csv(f"{DataPath}/LandmineData_{task_id}.csv")
        return frame.dropna(), ['Labels']

    if dataset == 'Chemical':
        frame = pd.read_csv(f"{DataPath}/{task_id}_Molecule_Data.csv")
        # Negative values in label column '181' are mapped to 0 before NaN rows are dropped.
        frame.loc[frame['181'] < 0, '181'] = 0
        return frame.dropna(), ['181']

    if dataset == 'School':
        # Feature columns kept for the School data, label 'ExamScore' last.
        keep = [
            '1985', '1986', '1987',
            'ESWI', 'African', 'Arab', 'Bangladeshi', 'Caribbean', 'Greek', 'Indian', 'Pakistani', 'SE_Asian',
            'Turkish', 'Other', 'VR_Band', 'Gender', 'FSM', 'VR_BAND_Student', 'School_Gender',
            'Maintained', 'Church', 'Roman_Cath', 'ExamScore',
        ]
        frame = pd.read_csv(f"{DataPath}/{task_id}_School_Data.csv")[keep]
        return frame.dropna(), ['ExamScore']

    raise ValueError(f"Unsupported dataset: {dataset}")


# Main loop over datasets. Guarded so that importing this module (e.g. to reuse
# scaling_split or load_task_dataframe) does not trigger the full preprocessing
# run, which reads every task CSV and writes .npy splits to disk.
if __name__ == "__main__":
    for dataset in ['School', 'Landmine', 'Chemical']:
        # NOTE(review): relative to the current working directory, not this file.
        DataPath = f"../Dataset/{dataset.upper()}/"
        info_path = f'{DataPath}Task_Information_{dataset}.csv'
        Task_InfoData = pd.read_csv(info_path)
        # The Chemical task list names its task column 'Molecule'; the others use 'Task_Name'.
        task_column = 'Molecule' if dataset == 'Chemical' else 'Task_Name'
        TASKS = list(Task_InfoData[task_column])

        print(f'Processing dataset = {dataset}, Total Tasks = {len(TASKS)}')
        scaling_split(dataset, TASKS, DataPath)

