import numpy as np
import time
import os
import argparse 
from sklearn.metrics import classification_report, f1_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
import xgboost as xgb
from catboost import CatBoostClassifier
import lightgbm as lgb
import warnings
# Silence every warning process-wide.
# NOTE(review): this also hides any warnings sklearn emits from the metric
# calls below (e.g. ill-defined precision/recall); consider narrowing it.
warnings.filterwarnings("ignore")

# Directory from which the preprocessed train/validation .npy arrays are loaded.
DATA_PATH_PREFIX = "processed"
# Directory that receives one text report per trained model.
OUTPUT_DIR = "final_industrial_results"
# Display names passed to classification_report as target_names; position i
# names label i (assumes labels are integer-encoded 0..4 — confirm upstream).
CLASS_NAMES = [
    "Normal",
    "Noise",
    "Surface",
    "Corona",
    "Void",
]

def _build_classifier(model_name):
    """Return an unfitted classifier for ``model_name``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """
    # Factories rather than instances, so only the requested model is built.
    factories = {
        'randomforest': lambda: RandomForestClassifier(
            random_state=42, n_jobs=-1, n_estimators=100),
        'xgboost': lambda: xgb.XGBClassifier(
            use_label_encoder=False, eval_metric='mlogloss',
            random_state=42, n_jobs=-1),
        'lightgbm': lambda: lgb.LGBMClassifier(random_state=42, n_jobs=-1),
        'catboost': lambda: CatBoostClassifier(
            objective='MultiClass', random_state=42,
            thread_count=-1, verbose=0),
    }
    try:
        factory = factories[model_name]
    except KeyError:
        raise ValueError(f"Unsupported model: {model_name}") from None
    return factory()


def _compute_edge_f1(y_true, y_pred, edge_ids):
    """Return ``{edge_id: macro F1}`` over the samples belonging to each edge."""
    scores = {}
    # np.unique only yields ids that actually occur, so every edge selects
    # at least one sample.
    for edge_id in np.unique(edge_ids):
        indices = np.where(edge_ids == edge_id)[0]
        scores[edge_id] = f1_score(
            y_true[indices], y_pred[indices],
            average='macro', zero_division=0)
    return scores


def _mean_inference_ms(model, X, num_samples=1000):
    """Return the mean latency (ms) of single-row ``predict`` calls on the first rows of ``X``.

    At most ``num_samples`` rows are timed, capped at ``X.shape[0]`` so that
    small evaluation sets never feed an empty slice to ``predict``. Returns
    NaN when ``X`` has no rows.
    """
    n = min(num_samples, X.shape[0])
    if n == 0:
        return float('nan')
    timings_ms = []
    for i in range(n):
        sample = X[i:i + 1]
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        model.predict(sample)
        timings_ms.append((time.perf_counter() - start) * 1000)
    return float(np.mean(timings_ms))


def run_experiment(model_name):
    """Train one classifier centrally and evaluate it on the validation split.

    The steps are:

    1. Load ``X_train``/``y_train`` and ``X_val``/``y_val``/``edge_ids_val``
       from ``DATA_PATH_PREFIX``.
    2. Flatten each sample to a 1-D feature vector and fit the requested model.
    3. Print the overall classification report and the per-edge macro F1.
       Both are also written to ``OUTPUT_DIR/report_<model_name>.txt``.
    4. Print the mean single-sample prediction latency.

    Args:
        model_name: ``'randomforest'``, ``'xgboost'``, ``'lightgbm'`` or
            ``'catboost'``.

    Raises:
        ValueError: If ``model_name`` is not a supported classifier. This is
            checked before any data is loaded.
    """
    model = _build_classifier(model_name)

    print("--- Loading data for Centralized Learning ---")
    X_train_orig = np.load(os.path.join(DATA_PATH_PREFIX, 'X_train.npy'))
    y_train_orig = np.load(os.path.join(DATA_PATH_PREFIX, 'y_train.npy'))
    X_test_orig = np.load(os.path.join(DATA_PATH_PREFIX, 'X_val.npy'))
    y_test_orig = np.load(os.path.join(DATA_PATH_PREFIX, 'y_val.npy'))
    edge_ids_test = np.load(os.path.join(DATA_PATH_PREFIX, 'edge_ids_val.npy'))

    # The classifiers take 2-D input: flatten all trailing axes per sample.
    X_train_reshaped = X_train_orig.reshape(X_train_orig.shape[0], -1)
    X_test_reshaped = X_test_orig.reshape(X_test_orig.shape[0], -1)

    print(f"\n--- Training {model_name} Classifier ---")
    model.fit(X_train_reshaped, y_train_orig)
    print("Training complete.")

    print("\n--- Overall Performance Evaluation ---")
    # CatBoost's multiclass predict() can return an (n, 1) column; ravel so
    # every model yields a 1-D label vector aligned with y_test_orig.
    y_pred_total = np.asarray(model.predict(X_test_reshaped)).ravel()
    # Explicit labels keep the report valid when a class is absent from both
    # truth and predictions. Without them, sklearn raises because target_names
    # outnumbers the observed labels. Labels must be 0..len(CLASS_NAMES)-1,
    # which XGBoost with use_label_encoder=False already requires.
    overall_report = classification_report(
        y_test_orig, y_pred_total,
        labels=np.arange(len(CLASS_NAMES)),
        target_names=CLASS_NAMES, digits=4)
    print("Overall Classification Report:")
    print(overall_report)

    print("\n--- Edge-wise F1-Score Evaluation ---")
    edge_f1_scores = _compute_edge_f1(y_test_orig, y_pred_total, edge_ids_test)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    report_path = os.path.join(OUTPUT_DIR, f'report_{model_name}.txt')
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(f"--- Overall Classification Report for {model_name} ---\n")
        f.write(overall_report + "\n")
        f.write("\n--- Edge-wise F1-Score Evaluation ---\n")
        for edge_id, f1 in edge_f1_scores.items():
            result_line = f"Edge {edge_id}: Macro F1-Score = {f1:.4f}\n"
            print(result_line, end='')
            f.write(result_line)

    print("\n--- Inference Time Evaluation ---")
    avg_inference_time = _mean_inference_ms(model, X_test_reshaped)
    print(f" Average Inference Time per sample: {avg_inference_time:.4f} ms")

if __name__ == '__main__':
    # Command-line entry point: pick one classifier and run the full experiment.
    supported_models = ['randomforest', 'xgboost', 'lightgbm', 'catboost']

    cli = argparse.ArgumentParser(
        description="Centralized learning with various classifiers.",
    )
    cli.add_argument(
        '--model',
        type=str,
        required=True,
        choices=supported_models,
        help="Classifier model to train.",
    )
    cli_args = cli.parse_args()

    run_experiment(model_name=cli_args.model)