import json
import numpy as np
import joblib
import argparse
import pandas as pd
import os

from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder, StandardScaler
from typing import Tuple


def load_model(scaler_dump_path: str, label_dump_path: str, model_dump_path: str):
    """
    Restores the three persisted artifacts needed for inference.

    Each artifact is loaded independently with joblib. A failure on one of
    them is reported on stdout and leaves None in that slot; the remaining
    artifacts are still attempted.

    NOTE: joblib.load unpickles the file contents, so these paths should only
    ever point at trusted files.

    Args:
        scaler_dump_path (str): Path to the saved StandardScaler.
        label_dump_path (str): Path to the saved LabelEncoder.
        model_dump_path (str): Path to the saved RandomForest model.

    Returns:
        Tuple[StandardScaler, LabelEncoder, RandomForestClassifier]: The scaler,
        label encoder, and model, in that order. An entry is None when the
        corresponding file could not be loaded.
    """
    # Every slot starts out as "not loaded" and is only filled in once the
    # load (and its success message) went through.
    scaler = None
    label_encoder = None
    clf = None

    try:
        restored_scaler = joblib.load(scaler_dump_path)
        print(f"Scaler loaded from {scaler_dump_path}")
        scaler = restored_scaler
    except Exception as err:
        print(f"Error loading scaler from {scaler_dump_path}: {err}")

    try:
        restored_encoder = joblib.load(label_dump_path)
        print(f"Label Encoder loaded from {label_dump_path}")
        label_encoder = restored_encoder
    except Exception as err:
        print(f"Error loading label encoder from {label_dump_path}: {err}")

    try:
        restored_model = joblib.load(model_dump_path)
        print(f"Random Forest model loaded from {model_dump_path}")
        clf = restored_model
    except Exception as err:
        print(f"Error loading model from {model_dump_path}: {err}")

    return scaler, label_encoder, clf


def _read_raw_test_traces(file_path: str, num_samples_per_label: int) -> np.ndarray:
    """Read the unscaled trace features of the selected test rows from a CSV file.

    Returns an empty array (after printing the reason) when the file cannot be
    read, a required column is missing, or the traces do not form a numeric
    rectangular matrix.
    """
    try:
        df = pd.read_csv(file_path, engine='python')
        print(f"Loaded {len(df)} rows from {file_path}")
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return np.array([])

    if 'label' not in df.columns:
        print("CSV file does not contain 'label' column. Please include labels.")
        return np.array([])
    if 'token_times' not in df.columns:
        print("CSV file does not contain 'token_times' column. Please include trace data.")
        return np.array([])

    if num_samples_per_label > 0:
        groups = [group for _, group in df.groupby('label')]
        if groups:
            df_selected = pd.concat([group.tail(num_samples_per_label) for group in groups])
            # Rows are ordered label by label (sorted labels). When no label
            # had to be trimmed, file order is kept instead; this is intended
            # to match the row order of the groupby/apply-based selection in
            # load_data_csv, which this function previously relied on.
            if all(len(group) <= num_samples_per_label for group in groups):
                df_selected = df_selected.sort_index()
        else:
            df_selected = df.iloc[0:0]
        print(f"Selected the last {num_samples_per_label} samples per label.")
    else:
        df_selected = df

    try:
        traces = [json.loads(cell) for cell in df_selected['token_times']]
        # dtype=float makes ragged or non-numeric traces fail loudly here
        # instead of silently producing an object array.
        return np.asarray(traces, dtype=float)
    except (TypeError, ValueError) as e:
        # json.JSONDecodeError is a ValueError; a NaN/non-string cell raises TypeError.
        print(f"Error parsing token_times: {e}")
        return np.array([])


def load_test_data(file_path: str, scaler: StandardScaler, num_samples_per_label: int = 5) -> np.ndarray:
    """
    Loads test traces from a CSV file and scales them with the given scaler.

    The last 'num_samples_per_label' rows of every label are selected (all rows
    when the value is 0). The raw features are scaled exactly once, with the
    scaler passed in. Earlier this function obtained its features from
    load_data_csv, which already standardizes them with a scaler fitted on the
    test rows themselves, so the supplied scaler was applied to data that had
    already been scaled.

    Args:
        file_path (str): Path to the test CSV file. It must contain a 'label'
            column and a 'token_times' column holding JSON-encoded lists.
        scaler (StandardScaler): Fitted scaler for feature scaling.
        num_samples_per_label (int): Number of last samples to select per label.

    Returns:
        np.ndarray: Scaled feature matrix, or an empty array when the file is
        missing or unreadable, the scaler is None, no usable rows were found,
        or scaling failed.
    """
    if not os.path.exists(file_path):
        print(f"Test file {file_path} does not exist.")
        return np.array([])

    X = _read_raw_test_traces(file_path, num_samples_per_label)

    if scaler is None or X.size == 0:
        print("Scaler is not loaded or test data is empty.")
        return np.array([])

    try:
        X_scaled = scaler.transform(X)
        print("Test data scaled successfully.")
        return X_scaled
    except Exception as e:
        print(f"Error during scaling test data: {e}")
        return np.array([])


def predict_labels(X: np.ndarray, clf: RandomForestClassifier) -> np.ndarray:
    """
    Runs the classifier on the given features.

    Args:
        X (np.ndarray): Scaled feature matrix.
        clf (RandomForestClassifier): Trained Random Forest model, or None if
            it could not be loaded.

    Returns:
        np.ndarray: Whatever clf.predict returns for X (the encoded labels),
        or an empty array when clf is None or prediction raised.
    """
    if clf is not None:
        try:
            encoded = clf.predict(X)
            print("Predictions generated successfully.")
        except Exception as err:
            # Any failure inside the model is reported and turned into "no predictions".
            print(f"Error during prediction: {err}")
            encoded = np.array([])
        return encoded

    print("Classifier is not loaded.")
    return np.array([])


def save_predicted_labels(predicted_labels: np.ndarray, label_encoder: LabelEncoder, output_file: str):
    """
    Saves the predicted labels to a text file, one label per line.

    The labels are decoded with the label encoder when one is available. If
    the encoder is None or decoding fails, the encoded values are written
    instead. The parent directory of the output file is created when it does
    not exist yet, so a fresh checkout without that directory no longer loses
    the predictions.

    Args:
        predicted_labels (np.ndarray): Predicted label indices.
        label_encoder (LabelEncoder): Fitted label encoder, or None.
        output_file (str): Path to the output text file.
    """
    if predicted_labels.size == 0:
        print("No predictions to save.")
        return

    # Fall back to the encoded values unless decoding succeeds.
    predicted_labels_decoded = predicted_labels
    if label_encoder is None:
        print("Label encoder is not loaded. Saving encoded labels.")
    else:
        try:
            predicted_labels_decoded = label_encoder.inverse_transform(predicted_labels)
            print("Labels inverse transformed successfully.")
        except Exception as e:
            print(f"Error during label inverse transformation: {e}")
            predicted_labels_decoded = predicted_labels

    try:
        parent_dir = os.path.dirname(output_file)
        if parent_dir:
            # A bare file name has no parent component to create.
            os.makedirs(parent_dir, exist_ok=True)
        with open(output_file, "w") as file:
            file.writelines(f"{label}\n" for label in predicted_labels_decoded)
        print(f"Predicted labels have been written to {output_file}")
    except Exception as e:
        print(f"Error writing predicted labels to {output_file}: {e}")


def load_data_csv(
    file_path: str,
    total_samples: int,
    num_samples: int,
    scaler_dump_path: str,
    label_dump_path: str
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Loads and preprocesses data from a CSV file containing trace data.

    The file must provide a 'label' column and a 'token_times' column whose
    cells are JSON-encoded lists. A new StandardScaler and LabelEncoder are
    fitted on the selected rows every time this function is called, and they
    are optionally saved with joblib.

    Args:
        file_path (str): Path to the CSV data file.
        total_samples (int): Total samples per label (unused in current logic).
        num_samples (int): Number of last samples to select per label. If 0, use all samples.
        scaler_dump_path (str): File path to save the fitted StandardScaler.
            Nothing is saved when this is None or empty.
        label_dump_path (str): File path to save the fitted LabelEncoder.
            Nothing is saved when this is None or empty.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Scaled features X and encoded labels y.
        Both are empty arrays when the file cannot be read, a required column
        is missing, or scaling fails.
    """
    try:
        df = pd.read_csv(
            file_path,
            engine='python',
        )
        print(f"Loaded {len(df)} rows from {file_path}")
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return np.array([]), np.array([])

    if 'label' not in df.columns:
        print("CSV file does not contain 'label' column. Please include labels.")
        return np.array([]), np.array([])

    if 'token_times' not in df.columns:
        # Without this check the column access below raises an uncaught KeyError.
        print("CSV file does not contain 'token_times' column. Please include trace data.")
        return np.array([]), np.array([])

    def parse_token_times(x):
        try:
            return json.loads(x)
        except (json.JSONDecodeError, TypeError) as e:
            # TypeError covers empty cells, which pandas reads as float NaN.
            print(f"Error parsing line: {x} - {e}")
            return []

    df['token_times'] = df['token_times'].apply(parse_token_times)

    if num_samples > 0:
        # Take each label's tail explicitly instead of going through
        # groupby.apply, whose handling of the grouping column is deprecated
        # in recent pandas releases.
        groups = [group for _, group in df.groupby('label')]
        if groups:
            sampled_df = pd.concat([group.tail(num_samples) for group in groups])
            # Rows are ordered label by label (sorted labels). When no label
            # had to be trimmed, file order is kept instead; this is intended
            # to match the row order of the groupby/apply call it replaces.
            if all(len(group) <= num_samples for group in groups):
                sampled_df = sampled_df.sort_index()
        else:
            sampled_df = df.iloc[0:0]
        print(f"Selected the last {num_samples} samples per label.")
    else:
        sampled_df = df

    traces = sampled_df['token_times'].tolist()
    labels = sampled_df['label'].tolist()

    try:
        X = np.array(traces)
    except Exception as e:
        print(f"Error converting traces to numpy array: {e}")
        X = np.array([])

    label_encoder = LabelEncoder()
    try:
        y = label_encoder.fit_transform(labels)
    except Exception as e:
        print(f"Error encoding labels: {e}")
        y = np.array([])

    if X.size == 0 or y.size == 0:
        print("Empty feature or label array due to previous errors.")
        return X, y

    scaler = StandardScaler()
    try:
        X = scaler.fit_transform(X)
    except Exception as e:
        print(f"Error during feature scaling: {e}")
        return np.array([]), np.array([])

    # Saving the fitted objects is best-effort: a failed dump is reported but
    # the preprocessed data is still returned.
    if scaler_dump_path:
        try:
            joblib.dump(scaler, scaler_dump_path)
            print(f"Scaler saved to {scaler_dump_path}")
        except Exception as e:
            print(f"Error saving scaler to {scaler_dump_path}: {e}")

    if label_dump_path:
        try:
            joblib.dump(label_encoder, label_dump_path)
            print(f"Label Encoder saved to {label_dump_path}")
        except Exception as e:
            print(f"Error saving label encoder to {label_dump_path}: {e}")

    # Print shapes for verification
    print("Shape of X:", X.shape)
    print("Shape of y:", y.shape)

    return X, y


def eval_model(tester: np.ndarray, label_encoder: LabelEncoder, clf: RandomForestClassifier, out_file: str):
    """
    Predicts labels for the test data and writes them to a file, one per line.

    Predictions are decoded with the label encoder when one is available. If
    the encoder is None or decoding fails, the encoded values are written
    instead. The parent directory of the output file is created when it does
    not exist yet, so the predictions are not lost on a fresh checkout.

    Nothing is written when the test data is empty, the classifier is None,
    or prediction raises.

    Args:
        tester (np.ndarray): Scaled feature matrix for testing.
        label_encoder (LabelEncoder): Fitted label encoder, or None.
        clf (RandomForestClassifier): Trained Random Forest model, or None.
        out_file (str): Path to save the predicted labels.
    """
    if tester.size == 0:
        print("Tester data is empty. Cannot perform evaluation.")
        return

    if clf is None:
        print("Classifier is not loaded. Cannot perform evaluation.")
        return

    if label_encoder is None:
        print("Label encoder is not loaded. Predicted labels will be saved as encoded values.")

    try:
        predicted_labels_encoded = clf.predict(tester)
        print("Predictions generated successfully.")
    except Exception as e:
        print(f"Error during prediction: {e}")
        return

    # Fall back to the encoded values unless decoding succeeds.
    predicted_labels = predicted_labels_encoded
    if label_encoder is not None:
        try:
            predicted_labels = label_encoder.inverse_transform(predicted_labels_encoded)
            print("Predicted labels inverse transformed successfully.")
        except Exception as e:
            print(f"Error during label inverse transformation: {e}")
            predicted_labels = predicted_labels_encoded
    else:
        print("Label encoder is not loaded. Saving encoded labels.")

    try:
        parent_dir = os.path.dirname(out_file)
        if parent_dir:
            # A bare file name has no parent component to create.
            os.makedirs(parent_dir, exist_ok=True)
        with open(out_file, "w") as file:
            file.writelines(f"{label}\n" for label in predicted_labels)
        print(f"Predicted labels have been written to {out_file}")
    except Exception as e:
        print(f"Error writing predicted labels to {out_file}: {e}")


if __name__ == "__main__":
    # Fallback artifact locations, used when the matching flag is not given.
    scaler_path = "model/scaler.pkl"
    label_path = "model/label_encoder.pkl"
    model_path = "model/random_forest_model.pkl"

    parser = argparse.ArgumentParser(description="Test Random Forest Model on Trace Data")

    # The test CSV is the only mandatory argument.
    parser.add_argument(
        "--test_csv",
        type=str,
        required=True,
        help="Path to the test CSV file containing trace data.",
    )

    # Optional flags as (flag, type, default, help), registered in this order.
    optional_flags = (
        ("--scaler_path", str, scaler_path, "Path to the saved StandardScaler."),
        ("--label_path", str, label_path, "Path to the saved LabelEncoder."),
        ("--model_path", str, model_path, "Path to the saved RandomForest model."),
        ("--output_file", str, "labels/predicted_labels.txt", "Path to save the predicted labels."),
        ("--num_samples_per_label", int, 5, "Number of last samples to select per label for testing."),
    )
    for flag, value_type, fallback, description in optional_flags:
        parser.add_argument(flag, type=value_type, default=fallback, help=description)

    args = parser.parse_args()

    # Restore the persisted scaler, label encoder, and classifier.
    scaler, label_encoder, clf = load_model(args.scaler_path, args.label_path, args.model_path)

    # Read the test CSV, keeping the last --num_samples_per_label rows of each label.
    X_test = load_test_data(args.test_csv, scaler, args.num_samples_per_label)

    # Predict and write the labels to the output file.
    eval_model(X_test, label_encoder, clf, args.output_file)
