from typing import Tuple, List, Optional
import pandas as pd
import numpy as np
import json
from sklearn.preprocessing import LabelEncoder, StandardScaler
import joblib
import argparse
from sklearn.ensemble import RandomForestClassifier


# Per-label row counts used when the train/test CSV split is requested.
_TRAIN_ROWS_PER_LABEL = 30
_TEST_ROWS_PER_LABEL = 5


def _empty_result() -> Tuple[np.ndarray, np.ndarray, None, None]:
    """Return the (empty X, empty y, None, None) tuple that signals a loading failure."""
    return np.array([]), np.array([]), None, None


def _parse_token_times(raw):
    """Decode one JSON-encoded ``token_times`` cell; return ``[]`` if it cannot be decoded."""
    try:
        return json.loads(raw)
    except (TypeError, json.JSONDecodeError) as e:
        # TypeError covers non-string cells (e.g. NaN for a missing value).
        print(f"Error parsing line: {raw} - {e}")
        return []


def _report_csv_problems(file_path: str) -> None:
    """Print the line numbers that have an odd number of quotes or no comma at all."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                if line.count('"') % 2 != 0:
                    print(f"Unmatched quotes on line {line_no}: ")
                if ',' not in line:
                    print(f"Possible missing commas on line {line_no}: ")
    except (OSError, UnicodeError) as file_error:
        print(f"Failed to read file line by line: {file_error}")


def _head_per_label(df: pd.DataFrame, count: int) -> pd.DataFrame:
    """Return the first ``count`` rows of every label, with groups in sorted label order."""
    parts = [group.head(count) for _, group in df.groupby('label')]
    if not parts:
        # groupby yields nothing when every label is missing; keep the columns.
        return df.iloc[0:0]
    return pd.concat(parts).reset_index(drop=True)


def _save_artifacts(scaler, label_encoder, scaler_dump_path: str, label_dump_path: str) -> None:
    """Dump the fitted scaler and label encoder with joblib; failures are printed, not raised."""
    try:
        joblib.dump(scaler, scaler_dump_path)
        print(f"Scaler saved to {scaler_dump_path}.")
    except Exception as e:
        print(f"Error saving scaler to {scaler_dump_path}: {e}")

    try:
        joblib.dump(label_encoder, label_dump_path)
        print(f"Label encoder saved to {label_dump_path}.")
    except Exception as e:
        print(f"Error saving label encoder to {label_dump_path}: {e}")


def load_data_csv(
    file_path: str,
    total_samples: int,
    num_samples: int,
    scaler_dump_path: str,
    label_dump_path: str,
    exclude_labels: Optional[List[str]] = None,  # Optional parameter for excluded labels
    train_file_path: Optional[str] = None,      # Optional parameter for training CSV
    test_file_path: Optional[str] = None        # Optional parameter for test CSV
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Load trace data from a CSV file, encode labels, scale features and
    optionally write a per-label train/test split to CSV.

    The CSV must contain a ``label`` column and a ``token_times`` column whose
    cells are JSON-encoded lists. Cells that cannot be decoded become ``[]``.

    Args:
        file_path (str): Path to the CSV data file.
        total_samples (int): Currently unused; kept so existing callers keep working.
        num_samples (int): Number of leading rows to keep per label for the
            non-split result and for fitting the label encoder. If 0, all rows
            are used.
        scaler_dump_path (str): File path to save the fitted StandardScaler.
        label_dump_path (str): File path to save the fitted LabelEncoder.
        exclude_labels (List[str], optional): Labels whose rows are dropped. Defaults to None.
        train_file_path (str, optional): Path to save the training CSV file. Defaults to None.
        test_file_path (str, optional): Path to save the test CSV file. Defaults to None.

    Returns:
        Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]:
            ``(X_train, y_train, X_test, y_test)``.

            * Both CSV paths given and the split succeeds: scaled train/test
              arrays. The test entries are ``None`` when no label has rows
              beyond the training head.
            * Both CSV paths given but the split fails: the unscaled full
              ``X``/``y`` and ``None`` for the test entries.
            * Otherwise: scaled ``X``, encoded ``y`` and ``None`` for the test entries.
            * On read/validation failure: two empty arrays and two ``None``.
    """
    try:
        df = pd.read_csv(
            file_path,
            engine='python',
        )
    except pd.errors.ParserError as e:
        print(f"Error reading {file_path}: {e}")
        _report_csv_problems(file_path)
        return _empty_result()
    except Exception as e:
        print(f"An unexpected error occurred while reading {file_path}: {e}")
        return _empty_result()
    print(f"Successfully loaded {len(df)} rows from {file_path}")

    if 'label' not in df.columns:
        print("CSV file does not contain 'label' column. Please include labels.")
        return _empty_result()

    if 'token_times' not in df.columns:
        print("CSV file does not contain 'token_times' column. Please include traces.")
        return _empty_result()

    if exclude_labels:
        initial_count = len(df)
        df = df[~df['label'].isin(exclude_labels)]
        excluded_count = initial_count - len(df)
        print(f"Excluded {excluded_count} rows with labels: {exclude_labels}")

    if df.empty:
        print("No rows left to process after loading and label filtering.")
        return _empty_result()

    # Decode every trace exactly once; ``assign`` builds a new frame so the
    # filtered slice above is never written to in place.
    df = df.assign(token_times=df['token_times'].apply(_parse_token_times))

    print("List of labels: ", df['label'].tolist())

    label_encoder = LabelEncoder()
    scaler = StandardScaler()

    if num_samples > 0:
        # Explicit per-group concat instead of groupby().apply(): the latter
        # depends on the grouping column being handed to the callback.
        selected_df = _head_per_label(df, num_samples)
        print(f"Sampled {num_samples} samples per label.")
    else:
        selected_df = df

    traces = selected_df['token_times'].tolist()
    labels = selected_df['label'].tolist()

    try:
        X = np.array(traces)
    except Exception as e:
        print(f"Error converting traces to numpy array: {e}")
        X = np.array([])

    try:
        y = label_encoder.fit_transform(labels)
    except Exception as e:
        print(f"Error encoding labels: {e}")
        y = np.array([])

    if X.size == 0 or y.size == 0:
        print("Empty feature or label array due to previous errors.")
        return X, y, None, None

    # **Split the DataFrame into Training and Test Sets**
    # NOTE: the split works on the full (unsampled) frame, as before.
    if train_file_path and test_file_path:
        try:
            train_parts = []
            test_parts = []

            for _, group in df.groupby('label'):
                train_parts.append(group.iloc[:_TRAIN_ROWS_PER_LABEL])
                # Take the test rows only from what follows the training head
                # so a small group never has the same row in both sets.
                held_out = group.iloc[_TRAIN_ROWS_PER_LABEL:]
                test_parts.append(held_out.tail(_TEST_ROWS_PER_LABEL))

            train_df = pd.concat(train_parts).reset_index(drop=True)
            test_df = pd.concat(test_parts).reset_index(drop=True)

            # Serialise traces back to JSON so the written files have the
            # same ``token_times`` encoding this function reads.
            train_df.assign(
                token_times=train_df['token_times'].map(json.dumps)
            ).to_csv(train_file_path, index=False)
            print(f"Training data saved to {train_file_path} with {len(train_df)} rows.")

            test_df.assign(
                token_times=test_df['token_times'].map(json.dumps)
            ).to_csv(test_file_path, index=False)
            print(f"Test data saved to {test_file_path} with {len(test_df)} rows.")

            # The traces are already decoded lists here; no second JSON parse.
            X_train = np.array(train_df['token_times'].tolist())
            y_train = label_encoder.transform(train_df['label'].tolist())
            X_train = scaler.fit_transform(X_train)

            if test_df.empty:
                print("No rows beyond the training head; test set is empty.")
                X_test, y_test = None, None
            else:
                X_test = np.array(test_df['token_times'].tolist())
                y_test = label_encoder.transform(test_df['label'].tolist())
                X_test = scaler.transform(X_test)

            _save_artifacts(scaler, label_encoder, scaler_dump_path, label_dump_path)

            print("Shape of X_train:", X_train.shape)
            print("Shape of y_train:", y_train.shape)
            if X_test is not None:
                print("Shape of X_test:", X_test.shape)
                print("Shape of y_test:", y_test.shape)

            return X_train, y_train, X_test, y_test

        except Exception as e:
            print(f"Error during train-test split or saving CSV files: {e}")
            return X, y, None, None

    try:
        X = scaler.fit_transform(X)
    except Exception as e:
        print(f"Error during feature scaling: {e}")
        return _empty_result()

    _save_artifacts(scaler, label_encoder, scaler_dump_path, label_dump_path)

    print("Shape of X:", X.shape)
    print("Shape of y:", y.shape)

    return X, y, None, None


def fit_model(X: np.ndarray, y: np.ndarray, dump_path: str):
    """
    Fit a Random Forest classifier on the given data and persist it with joblib.

    The hyperparameters are fixed inside this function.

    Args:
        X (np.ndarray): Feature matrix to train on.
        y (np.ndarray): Target labels, one per row of ``X``.
        dump_path (str): Destination file for the fitted model.
    """
    # Fixed seed so repeated runs on the same data build the same forest.
    hyperparameters = {
        "n_estimators": 250,
        "max_depth": 15,
        "min_samples_split": 10,
        "min_samples_leaf": 1,
        "max_features": 'sqrt',
        "random_state": 42,
    }
    forest = RandomForestClassifier(**hyperparameters)
    forest.fit(X, y)

    joblib.dump(forest, dump_path)
    print(f"Random Forest model saved to {dump_path}")

if __name__ == "__main__":
    # Script entry point: parse CLI options, load and preprocess the trace CSV,
    # then train and save the Random Forest model.
    from pathlib import Path

    scaler_path = "model/scaler.pkl"
    label_path = "model/label_encoder.pkl"
    model_path = "model/random_forest_model.pkl"

    # NOTE: --Inkind, --folder, --temperature and --exp are parsed but not
    # read anywhere in this block; only --size, --train_csv and --test_csv are.
    parser = argparse.ArgumentParser(description="Train Random Forest on Trace Data")
    parser.add_argument(
        "--Inkind",
        type=str,
        required=True,
        help="The trace Input kinds",
    )
    parser.add_argument(
        "--folder",
        type=str,
        required=True,
        help="The trace folder",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The temperature for sampling.",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=0,
        help="The training size.",
    )
    parser.add_argument(
        "--exp",
        type=int,
        default=0,
        help="The experiment performed.",
    )
    # **Add arguments for train and test CSV file paths**
    parser.add_argument(
        "--train_csv",
        type=str,
        default="data/train_data.csv",
        help="Path to save the training CSV file.",
    )
    parser.add_argument(
        "--test_csv",
        type=str,
        default="data/test_data.csv",
        help="Path to save the test CSV file.",
    )
    args = parser.parse_args()

    csv_file = "data/bild_prompts_token_times_combined_temp0.8_fb0.7_rb15.csv"

    labels_to_exclude = []

    # Create every output directory up front so a missing folder cannot make
    # the model dump fail after training has already finished.
    for output_path in (scaler_path, label_path, model_path, args.train_csv, args.test_csv):
        if output_path:
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    X_train, y_train, X_test, y_test = load_data_csv(
        file_path=csv_file,
        total_samples=30,
        num_samples=args.size,
        scaler_dump_path=scaler_path,
        label_dump_path=label_path,
        exclude_labels=labels_to_exclude,
        train_file_path=args.train_csv,
        test_file_path=args.test_csv
    )

    if X_train.size == 0 or y_train.size == 0:
        print("Training data loading failed. Exiting.")
        # SystemExit instead of the site-provided exit() helper, which is
        # not guaranteed to be defined in every interpreter setup.
        raise SystemExit(1)
    if X_test is None or y_test is None:
        print("Test data loading failed or was not provided.")
    else:
        print("Training and Test data loaded successfully.")

    fit_model(X_train, y_train, model_path)
