import numpy as np
import cv2 as cv
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
import os
import matplotlib.pyplot as plt
import seaborn as sns

def process_data(path):
    """Load a folder-per-class image dataset as flattened grayscale vectors.

    Each immediate subdirectory of *path* is treated as one class. Classes are
    labelled 0..K-1 in sorted (lexicographic) order of their folder names, so
    two calls on directories with the same class folders (e.g. TRAIN and TEST)
    always produce the same folder -> label mapping. Entries in *path* that are
    not directories are ignored.

    Args:
        path: Root directory containing one subdirectory per class.

    Returns:
        A tuple ``(data, label)`` of NumPy arrays: ``data`` has one flattened
        image per row, ``label`` holds the integer class index of each row.

    Raises:
        ValueError: If a file inside a class folder cannot be read as an image.
    """
    data = []
    label = []
    # os.listdir returns entries in arbitrary, filesystem-dependent order; sort
    # so the label assigned to each folder is deterministic and consistent
    # between separate calls (train vs. test).
    all_folders = sorted(
        folder for folder in os.listdir(path)
        if os.path.isdir(os.path.join(path, folder))
    )
    for idx, folder in enumerate(all_folders):
        folder_path = os.path.join(path, folder)
        for number in sorted(os.listdir(folder_path)):
            file_path = os.path.join(folder_path, number)
            # Flag 0 is cv.IMREAD_GRAYSCALE. imread signals failure by
            # returning None rather than raising, so check explicitly.
            img = cv.imread(file_path, 0)
            if img is None:
                raise ValueError(f"Could not read image file: {file_path}")
            data.append(img.reshape(-1))
            label.append(idx)
    return np.array(data), np.array(label)

def train_and_evaluate(train_dir='../MNIST/raw/TRAIN/',
                       test_dir='../MNIST/raw/TEST/',
                       output_path='confusion_matrix.png'):
    """Train an SGD logistic-regression classifier and report test metrics.

    Loads the train and test sets with ``process_data`` and standardizes
    features using statistics from the training set only. It then fits an
    elastic-net ``SGDClassifier`` and prints the test accuracy and the number
    of learned parameters. Finally it saves a confusion-matrix heatmap.

    Args:
        train_dir: Root folder of the training set (one subfolder per class).
        test_dir: Root folder of the test set (same class subfolders).
        output_path: File the confusion-matrix image is written to.

    Returns:
        The test-set accuracy as a float.
    """
    train_data, train_label = process_data(train_dir)
    test_data, test_label = process_data(test_dir)

    print("Standardizing data...")
    # Fit the scaler on training data only, then apply the same transform to
    # the test data so no test statistics leak into training.
    scaler = StandardScaler()
    train_data = scaler.fit_transform(train_data)
    test_data = scaler.transform(test_data)

    print("Training the model...")
    model = SGDClassifier(
        loss='log_loss',
        penalty='elasticnet',
        l1_ratio=0.5,
        alpha=0.0001,
        max_iter=2000,
        tol=1e-3,
        random_state=42,
    )
    model.fit(train_data, train_label)

    predictions = model.predict(test_data)
    accuracy = accuracy_score(test_label, predictions)
    print(f'Test accuracy: {accuracy:.4f}')

    # Count what the model actually learned. For a binary problem coef_ has a
    # single row, so n_features * n_classes would overcount.
    total_params = model.coef_.size + model.intercept_.size
    print(f'Total number of parameters: {total_params}')

    # Every class the model knows plus any class present in the test labels.
    # Predictions are always drawn from model.classes_, so this covers every
    # row and column of the confusion matrix.
    class_labels = np.union1d(model.classes_, test_label)
    _save_confusion_matrix(test_label, predictions, class_labels, output_path)
    return accuracy


def _save_confusion_matrix(true_labels, predicted_labels, class_labels, output_path):
    """Render a confusion-matrix heatmap for *class_labels* and save it to *output_path*."""
    # Pass labels explicitly so the matrix rows/columns and the tick labels
    # are guaranteed to have the same length and order.
    cm = confusion_matrix(true_labels, predicted_labels, labels=class_labels)
    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                    xticklabels=class_labels, yticklabels=class_labels)
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.savefig(output_path)
    finally:
        # pyplot keeps figures alive until explicitly closed; release it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

if __name__ == '__main__':
    # Script entry point: run the full train/evaluate pipeline with defaults.
    train_and_evaluate()