import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
import xgboost as xgb
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score, f1_score
import argparse
import os, sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from filtering import grid_search_filter

def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Argument strings to parse (without the program name). ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is the
            original behaviour.

    Returns:
        argparse.Namespace with ``df_train``, ``df_syn``, ``df_test``,
        ``filter`` and ``output_file`` attributes.
    """
    parser = argparse.ArgumentParser(description='XGBoost model training with data filtering')
    parser.add_argument('--df_train', type=str, default='./data/sampled_30.csv',
                        help='Path to training data CSV file')
    parser.add_argument('--df_syn', type=str, default='./data/synthetic_data.csv',
                        help='Path to synthetic data CSV file')
    parser.add_argument('--df_test', type=str, default='./data/test.csv',
                        help='Path to test data CSV file')

    # Previously ``--filter`` was ``store_true`` with ``default=True``, so it
    # could never be disabled. Keep ``--filter`` for backward compatibility and
    # add ``--no_filter``; both write to ``args.filter`` and default to True.
    filter_group = parser.add_mutually_exclusive_group()
    filter_group.add_argument('--filter', dest='filter', action='store_true',
                              help='Apply data filtering (default)')
    filter_group.add_argument('--no_filter', dest='filter', action='store_false',
                              help='Disable data filtering')
    parser.set_defaults(filter=True)

    parser.add_argument('--output_file', type=str, default='./data/filtered.csv',
                        help='Path to save filtered data')

    return parser.parse_args(argv)

def _encode_categoricals(df_test, others, columns):
    """Label-encode ``columns`` in place, fitting each encoder on ``df_test``.

    For every column, values are cast to ``str`` and a ``LabelEncoder`` is fit on
    the test-set values. The same class -> code mapping is then applied to each
    frame in ``others``. Values that do not occur in the test set are encoded
    as -1 (as before). The number of such values is printed so the collapse is
    no longer silent.

    Args:
        df_test: Frame whose values define the encoding; modified in place.
        others: Mapping of display name -> frame encoded with the test-fitted
            mapping; each frame is modified in place.
        columns: Names of the categorical columns to encode.
    """
    # NOTE(review): the encoder is fit on the test set only, so categories seen
    # only in training/synthetic data all share code -1. Fitting on the union of
    # all frames would avoid that; confirm intent before changing the encoding.
    for col in columns:
        le = LabelEncoder()
        df_test[col] = le.fit_transform(df_test[col].astype(str))
        # A dict lookup yields the same codes as ``le.transform`` for known
        # classes, but in one vectorised ``map`` instead of one call per row.
        code_of = {cls: code for code, cls in enumerate(le.classes_)}
        for name, frame in others.items():
            codes = frame[col].astype(str).map(code_of)
            n_unseen = int(codes.isna().sum())
            if n_unseen:
                print(f"Warning: {n_unseen} {name} value(s) in '{col}' not present "
                      f"in the test set; encoded as -1")
            frame[col] = codes.fillna(-1).astype("int64")


def _fit_and_evaluate(X_train, y_train, X_test, y_test):
    """Grid-search an XGBoost classifier on the training data and report test metrics.

    Args:
        X_train, y_train: Features/labels used for 5-fold grid search and refit.
        X_test, y_test: Held-out features/labels used only for the final scores.

    Returns:
        Tuple ``(f1_macro, accuracy)`` measured on the test set.
    """
    # NOTE(review): 'multi:softmax' with num_class=2 is used for a two-class
    # target; 'binary:logistic' is the conventional choice. Kept as-is to
    # preserve existing results.
    xgb_model = xgb.XGBClassifier(objective='multi:softmax', num_class=2)

    param_grid = {
        'max_depth': [2, 3, 4, 5],
        'learning_rate': [0.1, 0.01],
        'n_estimators': [50, 100, 150, 200]
    }

    grid_search = GridSearchCV(estimator=xgb_model, param_grid=param_grid, cv=5, scoring='f1_macro')
    grid_search.fit(X_train, y_train)

    y_pred = grid_search.best_estimator_.predict(X_test)

    accuracy = accuracy_score(y_test, y_pred)
    f1_macro = f1_score(y_test, y_pred, average='macro')
    print(f'F1 score on test set: {f1_macro * 100:.2f}% -- Accuracy on test set: {accuracy * 100:.2f}%')
    return f1_macro, accuracy


def main():
    """Train and evaluate an XGBoost classifier on (optionally filtered) synthetic data.

    Steps:
      1. Load the real training, synthetic and test CSVs named on the command line.
      2. Align all frames to the test-set columns (``Diagnosis`` first) and
         label-encode the categorical columns.
      3. If filtering is enabled, select synthetic rows with
         ``grid_search_filter`` and save them to ``--output_file``.
      4. Grid-search an XGBoost model on the synthetic rows and print macro-F1
         and accuracy on the test set.

    Raises:
        ValueError: if the test data has no ``Diagnosis`` column.
        RuntimeError: if filtering selects no rows.
    """
    args = parse_args()

    # Load data files
    df_train = pd.read_csv(args.df_train)
    df_syn = pd.read_csv(args.df_syn)
    df_test = pd.read_csv(args.df_test)

    # ``df`` is the synthetic pool the model trains on (possibly after filtering).
    # A fresh RangeIndex keeps the positional ``iloc`` selection below unambiguous.
    df = df_syn.reset_index(drop=True)

    # DoctorInCharge is not used as a feature; tolerate files that lack it.
    df_test = df_test.drop(columns="DoctorInCharge", errors="ignore")

    print("Original length: ", len(df))

    print("Filter enabled:", args.filter)

    label_columns = ["Gender", "Ethnicity", "EducationLevel", "Smoking", "FamilyHistoryAlzheimers", "CardiovascularDisease",
                "Diabetes", "Depression", "HeadInjury", "Hypertension", "MemoryComplaints", "BehavioralProblems", "Confusion",
                "Disorientation", "PersonalityChanges", "DifficultyCompletingTasks", "Forgetfulness"]

    if 'Diagnosis' not in df_test.columns:
        raise ValueError(f"Test data {args.df_test!r} has no 'Diagnosis' column")

    # Put the target first and restrict every frame to the test-set columns.
    # ``.copy()`` gives independent frames so the in-place encoding below does
    # not write through a view (SettingWithCopyWarning).
    cols = ['Diagnosis'] + [col for col in df_test.columns if col != 'Diagnosis']
    df = df[cols].copy()
    df_train = df_train[cols].copy()
    df_test = df_test[cols].copy()

    categorical_features = label_columns
    categorical_set = set(categorical_features)
    all_features = cols[1:]
    numerical_features = [feat for feat in all_features if feat not in categorical_set]

    _encode_categoricals(df_test, {"synthetic": df, "training": df_train}, label_columns)

    X_train_orig = df_train.drop(columns='Diagnosis')
    y_train_orig = df_train['Diagnosis']
    X_train = df.drop(columns='Diagnosis')
    y_train = df['Diagnosis']

    X_test = df_test.drop(columns='Diagnosis')
    y_test = df_test['Diagnosis']

    # Apply filtering if enabled
    if args.filter:
        # The small real training set is also passed as the filter's evaluation
        # set, presumably so the held-out test set is not used for selection.
        result = grid_search_filter(
            X_small=X_train_orig, y_small=y_train_orig,
            X_large=X_train, y_large=y_train,
            X_test=X_train_orig, y_test=y_train_orig,
            numerical_features=numerical_features,
            categorical_features=categorical_features,
            block_sizes=[20, 25, 30, 35, 40, 45, 50, 55, 60],
            n_iterations=10,
            curate=True
        )

        # Assumes "idxs" are row positions into X_large (df has a RangeIndex,
        # so positions and labels coincide) -- confirm against grid_search_filter.
        selected = result["idxs"]
        print("Selected samples after filtering:", len(selected))
        if len(selected) == 0:
            raise RuntimeError("Filtering selected no synthetic samples; nothing to train on")
        df = df.iloc[selected]

        out_dir = os.path.dirname(args.output_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        df.to_csv(args.output_file, index=False)
        print(f"Filtered data saved to: {args.output_file}")

        X_train = df.drop(columns='Diagnosis')
        y_train = df['Diagnosis']

    print("Training data length:", len(X_train))

    _fit_and_evaluate(X_train, y_train, X_test, y_test)

# Run the training/evaluation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()