import argparse
import json
import os

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import LabelEncoder


def load_data(file_path: str, total_samples: int, num_samples: int):
    """Load traces and labels from a JSON file and integer-encode the labels.

    The JSON document must contain a ``"traces"`` list (one feature vector per
    sample) and a parallel ``"labels"`` list (one label per sample).

    When ``num_samples`` is non-zero, the samples are treated as consecutive
    chunks of ``total_samples`` entries (presumably one chunk per label --
    TODO confirm against the trace generator), and only the first
    ``num_samples`` entries of every chunk are kept.

    Args:
        file_path: Path to the JSON traces file.
        total_samples: Size of each consecutive chunk. Only used when
            ``num_samples`` is non-zero, in which case it must be positive.
        num_samples: Number of leading samples to keep from each chunk, or
            ``0`` to keep every sample. Values larger than ``total_samples``
            are clamped so that neighbouring chunks never overlap.

    Returns:
        A tuple ``(X, y, label_encoder)``: ``X`` is ``np.array(traces)``,
        ``y`` is the array of integer-encoded labels, and ``label_encoder``
        is the fitted ``LabelEncoder`` that maps between the two.

    Raises:
        ValueError: If ``num_samples`` is negative, if ``num_samples`` is
            non-zero while ``total_samples`` is not positive, or if the file
            holds a different number of traces than labels.
    """
    # Validate arguments before touching the file so bad CLI input fails fast.
    if num_samples < 0:
        raise ValueError(f"num_samples must be >= 0, got {num_samples}")
    if num_samples != 0 and total_samples <= 0:
        raise ValueError(
            f"total_samples must be > 0 when num_samples is set, got {total_samples}"
        )

    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    traces = data["traces"]
    labels = data["labels"]
    if len(traces) != len(labels):
        # A mismatch would silently pair traces with the wrong labels below.
        raise ValueError(
            f"{file_path}: {len(traces)} traces but {len(labels)} labels"
        )

    if num_samples != 0:
        # Never take more than one chunk holds; otherwise the slice starting in
        # one chunk runs into the next and the same sample is selected twice.
        per_chunk = min(num_samples, total_samples)
        selected_traces, selected_labels = [], []
        for start in range(0, len(labels), total_samples):
            stop = start + per_chunk
            selected_traces.extend(traces[start:stop])
            selected_labels.extend(labels[start:stop])
        traces, labels = selected_traces, selected_labels

    # Map label values to consecutive integers; the encoder is returned so
    # callers can recover the original names via ``label_encoder.classes_``.
    label_encoder = LabelEncoder()
    y = np.array(label_encoder.fit_transform(labels))
    X = np.array(traces)

    return X, y, label_encoder


def eval(
    data_size: int,
    num_samples: int,
    *,
    file_path: str = "out_trace/original_trace_30.out",
    n_runs: int = 10,
):
    """Grid-search a random forest on the trace data and write a pooled report.

    Each run splits the data (stratified by label), fits a ``GridSearchCV``
    over ``RandomForestClassifier`` on the training part, and predicts the
    test part with the best estimator. Each run uses a different random seed,
    so repeating actually averages out split/model randomness. Predictions
    from all runs are pooled into a single classification report, which is
    written to ``grid_search_report/``.

    Args:
        data_size: Chunk size forwarded to ``load_data`` as ``total_samples``.
            It is also used as the report tag when ``num_samples`` is 0.
        num_samples: Samples kept from each chunk (``0`` keeps all). Must be
            0 or at least 2; with a single sample per label the stratified
            split has nothing left to train on.
        file_path: JSON traces file to load.
        n_runs: Number of independent split/train/evaluate repetitions.

    Raises:
        ValueError: If ``num_samples`` is negative or 1, or ``n_runs`` < 1.
    """
    # Reject unusable settings before the (slow) data load and grid search.
    if num_samples < 0 or num_samples == 1:
        raise ValueError(f"num_samples must be 0 or >= 2, got {num_samples}")
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    y_pred_full, y_test_full = [], []

    X, y, label_encoder = load_data(file_path, data_size, num_samples)

    # Hold out at least 30% of the data. With few samples per label, grow the
    # test fraction to 1/num_samples, presumably so each label still lands at
    # least one sample in the stratified test set.
    split = 0.3
    if num_samples != 0:
        split = max(split, 1.0 / num_samples)

    # Hyper-parameter grid searched in every run (loop-invariant).
    param_grid = {
        'n_estimators': [100, 150, 200],
        'max_depth': [10, 15, 20],
        'min_samples_split': [2, 5, 10],
        'min_samples_leaf': [1, 2, 4],
        'max_features': ['sqrt', 'log2']
    }

    for i in range(n_runs):
        # A fixed seed made every repetition identical; vary it per run while
        # keeping run 0 on the original seed (42) for reproducibility.
        seed = 42 + i

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=split, random_state=seed, stratify=y
        )

        clf = RandomForestClassifier(random_state=seed)
        grid_search = GridSearchCV(estimator=clf, param_grid=param_grid, cv=3, n_jobs=-1, verbose=2)
        grid_search.fit(X_train, y_train)

        print(f"At iteration {i} Best parameters found: ", grid_search.best_params_)

        # Evaluate the refitted best model on this run's held-out data.
        y_pred = grid_search.best_estimator_.predict(X_test)

        # Do not modify the next two lines
        y_test_full.extend(y_test)
        y_pred_full.extend(y_pred)

    report = classification_report(y_test_full, y_pred_full, target_names=label_encoder.classes_)

    tag = num_samples if num_samples != 0 else data_size
    report_dir = "grid_search_report"
    # open() below fails if the directory is missing on a fresh checkout.
    os.makedirs(report_dir, exist_ok=True)
    # The suffix is the index of the last run (9 by default); it is kept so
    # report file names match those produced by earlier versions.
    with open(f"{report_dir}/classification_report_{tag}_{n_runs - 1}.txt", "w", encoding="utf-8") as f:
        f.write(report)


if __name__ == "__main__":
    # Command-line entry point: read the two sizing options and run the experiment.
    cli = argparse.ArgumentParser(description="ChatExperiment script")
    cli.add_argument(
        "--data_size",
        type=int,
        default=30,
        help="Size of the data to be collected",
    )
    cli.add_argument(
        "--num_samples",
        type=int,
        default=0,
        help="Number of samples to select from each label",
    )
    options = cli.parse_args()
    eval(data_size=options.data_size, num_samples=options.num_samples)