import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))


import argparse
import pandas as pd
import os
from pathlib import Path
from sklearn.model_selection import train_test_split
import logging

from fortress.config import Config
from scripts.utils.script_helpers import find_csv_files

               
# Configure root logging once for the script: INFO and above, with a
# timestamp and level prefix on every record.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)

def _pick_split_indices(to_split_df, split_ratio, file_name):
    """Decide which unassigned rows go to the database and which to the benchmark.

    Args:
        to_split_df: Rows of the CSV that are not yet assigned to either split.
        split_ratio: Fraction of the rows that should land in the database split.
        file_name: Name of the CSV file, used only in log messages.

    Returns:
        A ``(db_indices, bench_indices)`` pair of index labels (either side may
        be an empty list).

    Raises:
        ValueError: If ``train_test_split`` cannot produce a split even
            without stratification (e.g. one side would be empty).
    """
    if len(to_split_df) < 2:
        logger.warning(f"Not enough samples to split in {file_name} after filtering pre-assigned. Assigning all to database.")
        return to_split_df.index, []
    if split_ratio == 1.0:
        return to_split_df.index, []
    if split_ratio == 0.0:
        return [], to_split_df.index

    # Stratification is only evaluated once we know a real split will happen,
    # so the "cannot stratify" warning is never logged for the shortcuts above.
    stratify_col = None
    if 'label' in to_split_df.columns:
        labels = to_split_df['label']
        if labels.nunique() > 1 and labels.value_counts().min() >= 2:
            stratify_col = labels
        else:
            logger.warning(f"Cannot stratify by 'label' for {file_name} due to insufficient samples per class. Using random split.")

    try:
        db_part, bench_part = train_test_split(
            to_split_df,
            train_size=split_ratio,
            stratify=stratify_col,
            random_state=42,
        )
    except ValueError as exc:
        if stratify_col is None:
            raise
        # The pre-check above cannot rule out every stratification failure
        # (e.g. a resulting partition smaller than the number of classes),
        # so fall back to the random split instead of skipping the file.
        logger.warning(f"Stratified split failed for {file_name} ({exc}). Using random split.")
        db_part, bench_part = train_test_split(
            to_split_df,
            train_size=split_ratio,
            random_state=42,
        )
    return db_part.index, bench_part.index


def split_single_csv(csv_path: Path, split_ratio: float, config: Config):
    """
    Splits a single CSV file into database and benchmark by filling the split column,
    and overwrites the original file.

    Rows whose split column already holds the database or benchmark value are
    left as they are; only the remaining rows are assigned. The split column is
    created (filled with ``config.cleared_split_value``) when it is missing.

    Args:
        csv_path: Path of the CSV file to read and overwrite.
        split_ratio: Fraction of the unassigned rows to mark as database.
        config: Provides ``split_column_name``, ``cleared_split_value``,
            ``database_split_value`` and ``benchmark_split_value``.

    Returns:
        None. Errors are logged rather than raised, and the file is left
        untouched when processing fails before the final write.
    """
    try:
        df = pd.read_csv(csv_path)
        logger.info(f"Successfully read {csv_path}. Shape: {df.shape}")

        split_col = config.split_column_name
        if split_col not in df.columns:
            df[split_col] = config.cleared_split_value
            logger.info(f"Added split column '{split_col}' to {csv_path.name}")

        df = df.sort_index()

        # Only rows that are not already marked as database/benchmark get split.
        mask_unassigned = ~df[split_col].isin([config.database_split_value, config.benchmark_split_value])
        to_split_df = df[mask_unassigned]

        if to_split_df.empty:
            logger.info(f"No rows to split in {csv_path.name}. All rows might be pre-assigned or file is empty.")
            # Still written back so a freshly added split column is persisted.
            df.to_csv(csv_path, index=False)
            return

        db_indices, bench_indices = _pick_split_indices(to_split_df, split_ratio, csv_path.name)

        # A column read as all-NaN is float64; assigning split values of a
        # different type through .loc is an incompatible-dtype write that
        # pandas deprecates, so make the column object-typed first.
        df[split_col] = df[split_col].astype(object)
        df.loc[db_indices, split_col] = config.database_split_value
        df.loc[bench_indices, split_col] = config.benchmark_split_value

        df.to_csv(csv_path, index=False)
        logger.info(f"Overwritten {csv_path} with updated split column.")

    except FileNotFoundError:
        logger.error(f"Error: File not found at {csv_path}")
    except pd.errors.EmptyDataError:
        logger.error(f"Error: File is empty at {csv_path}")
    except Exception as e:
        # Best-effort per file: log with traceback and let the caller move on.
        logger.error(f"An unexpected error occurred while processing {csv_path}: {e}", exc_info=True)

def main():
    """Command-line entry point.

    Parses ``input_path`` and the optional ``--split_ratio``, builds a
    ``Config``, and runs ``split_single_csv`` on every file returned by
    ``find_csv_files``. When no ratio is given on the command line,
    ``config.default_split_ratio`` is used. Invalid ratios and empty file
    lists are logged and end the run early.
    """
    cli = argparse.ArgumentParser(
        description="Split CSV file(s) into database and benchmark sets by updating the split column in-place."
    )
    cli.add_argument(
        "input_path",
        type=str,
        help="Path to a CSV file or a directory containing CSV files.",
    )
    cli.add_argument(
        "--split_ratio",
        type=float,
        help="Ratio for the 'database' split (e.g., 0.8 for 80%% database). Defaults to value from settings.yaml.",
    )

    parsed = cli.parse_args()
    settings = Config()

    target = Path(parsed.input_path)
    # The command-line value wins; the configured default is only read when
    # no ratio was supplied.
    if parsed.split_ratio is None:
        ratio = settings.default_split_ratio
    else:
        ratio = parsed.split_ratio

    ratio_in_range = 0.0 <= ratio <= 1.0
    if not ratio_in_range:
        logger.error("Split ratio must be between 0.0 and 1.0.")
        return

    pending_files = find_csv_files(target)
    if not pending_files:
        logger.warning(f"No CSV files found in {target}.")
        return

    separator = "-" * 30
    for current_file in pending_files:
        logger.info(f"Processing {current_file}...")
        split_single_csv(current_file, ratio, settings)
        logger.info(separator)

    logger.info("CSV splitting process completed.")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
