#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 11 15:04:49 2025

This module contains basic functionality for fitting and testing supervised
machine learning models for classification.
"""

import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC, LinearSVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.decomposition import PCA


class SupervisedMachineLearning:
    """
    Wrapper class for fitting and testing supervised models for classification.

    A single string (``type_``) selects one of several scikit-learn classifiers; the
    wrapper then exposes the same ``fit`` / ``predict`` / ``softmax_scores`` / ``score``
    interface regardless of which model was chosen.

    Attributes:
    -----------
        type_ : str
            The selected model type (one of ``TYPE_OPTIONS``).
        SMLmodel : scikit-learn estimator
            The classifier; for "SVC" and "PCA_SVC" it is a StandardScaler + SVC pipeline.
        PCAmodel : sklearn.decomposition.PCA
            Only present for the "PCA_*" types; projects features before classification.
        labels : ndarray
            Sorted unique training labels; set by ``fit``.
    """

    # Every supported value of ``type_``.
    TYPE_OPTIONS = ("LogisticRegression", "KNeighborsClassifier", "SVC", "LinearSVC",
                    "MLPClassifier", "GradientBoostingClassifier", "PCA_LogisticRegression",
                    "PCA_SVC")
    # Types whose features are projected by a PCA fitted on the training features.
    _PCA_TYPES = ("PCA_LogisticRegression", "PCA_SVC")

    def __init__(self, type_, **kwargs):
        """
        Inputs:
        -------
            type_ : str
                Options are "LogisticRegression", "KNeighborsClassifier", "SVC", "LinearSVC",
                "MLPClassifier", "GradientBoostingClassifier", "PCA_LogisticRegression", "PCA_SVC".
            **kwargs
                Keyword arguments forwarded to the classifier's constructor. For the "PCA_*"
                types, "n_components" is removed and passed to PCA instead (when omitted,
                PCA(n_components=None) is used); the remaining arguments go to the classifier.
                For "SVC" and "PCA_SVC" the classifier is an SVC preceded by a StandardScaler.

        Raises:
        -------
            ValueError
                If ``type_`` is not one of the supported options.
        """
        super().__init__()
        if type_ not in self.TYPE_OPTIONS:
            # Explicit raise rather than ``assert`` so the check survives ``python -O``.
            raise ValueError("The chosen supervised machine learning model is not supported.")
        self.type_ = type_

        if type_ in self._PCA_TYPES:
            # ``kwargs`` is a fresh dict built for this call, so popping never touches the caller's data.
            self.PCAmodel = PCA(n_components=kwargs.pop("n_components", None))

        if type_ in ("LogisticRegression", "PCA_LogisticRegression"):
            self.SMLmodel = LogisticRegression(**kwargs)
        elif type_ in ("SVC", "PCA_SVC"):
            self.SMLmodel = make_pipeline(StandardScaler(), SVC(**kwargs))
        elif type_ == "KNeighborsClassifier":
            self.SMLmodel = KNeighborsClassifier(**kwargs)
        elif type_ == "LinearSVC":
            self.SMLmodel = LinearSVC(**kwargs)
        elif type_ == "MLPClassifier":
            self.SMLmodel = MLPClassifier(**kwargs)
        else:  # "GradientBoostingClassifier" (membership was validated above)
            self.SMLmodel = GradientBoostingClassifier(**kwargs)

    def one_hot_encoding(self, labels):
        """
        One-hot encode a 1-D label array.

        The encoder is fitted on ``labels`` itself, so the columns follow that array's
        distinct values; two calls only yield aligned columns when both inputs contain
        the same set of labels.

        Inputs:
        -------
            labels : ndarray, size=(n,)

        Output:
        -------
            one_hot_labels : ndarray, size=(n, d')
                Dense 0/1 array with one column per distinct label.
        """
        column = np.expand_dims(labels, axis=1)
        return OneHotEncoder().fit_transform(column).toarray()

    def _model_input(self, features):
        """Return what the classifier consumes: PCA-projected features for "PCA_*" types, else the input."""
        if self.type_ in self._PCA_TYPES:
            return self.PCAmodel.transform(features)
        return features

    def fit(self, train_features, train_labels):
        """
        Fit the selected model (and, for "PCA_*" types, the PCA projection first).

        Inputs:
        -------
            train_features : ndarray, size=(n, d)
                The input features used for fitting the supervised model.
            train_labels : ndarray, size=(n,)
                The labels used for fitting the supervised model.
        """
        if self.type_ in self._PCA_TYPES:
            # The projection is learned from the training features only.
            self.PCAmodel.fit(train_features)
        if self.type_ == "MLPClassifier":
            # NOTE(review): fitting on one-hot targets makes scikit-learn treat this as a
            # multilabel problem, so ``predict`` returns an (n, n_classes) indicator array
            # and ``predict_proba`` gives independent per-class probabilities rather than a
            # softmax. Kept as-is because callers may rely on those shapes.
            targets = self.one_hot_encoding(train_labels)
        else:
            targets = train_labels
        self.SMLmodel.fit(self._model_input(train_features), targets)
        self.labels = np.unique(train_labels)

    def predict(self, test_features):
        """
        Predict with the fitted model.

        Inputs:
        -------
            test_features : ndarray, size=(n, d)
                Data to predict on (projected by the fitted PCA for "PCA_*" types).

        Output:
        -------
            predictions : ndarray
                Whatever ``SMLmodel.predict`` returns; for "MLPClassifier" this is one-hot
                shaped because the model is fitted on one-hot targets.
        """
        return self.SMLmodel.predict(self._model_input(test_features))

    def softmax_scores(self, test_features):
        """
        Return ``SMLmodel.predict_proba`` on the (PCA-projected, where applicable) features.

        Inputs:
        -------
            test_features : ndarray, size=(n, d)

        Output:
        -------
            probs : ndarray
                Per-class scores from ``predict_proba``.

        NOTE(review): per the scikit-learn docs, LinearSVC has no ``predict_proba`` and
        SVC only provides it when constructed with ``probability=True``; for those models
        this call raises AttributeError — confirm callers pass the right kwargs.
        """
        return self.SMLmodel.predict_proba(self._model_input(test_features))

    def score(self, test_features, test_labels):
        """
        Return ``SMLmodel.score`` (mean accuracy for scikit-learn classifiers) on test data.

        Inputs:
        -------
            test_features : ndarray, size=(n, d)
            test_labels : ndarray, size=(n,)
                Must contain exactly the same set of distinct labels as the training labels.

        Raises:
        -------
            ValueError
                If the distinct test labels differ from the distinct training labels.
        """
        # Validate before scoring: with mismatched label sets the MLP one-hot targets would
        # have a different column count and sklearn would fail before this message is seen.
        if not np.array_equal(self.labels, np.unique(test_labels)):
            raise ValueError("Training and test labels do not match!")
        if self.type_ == "MLPClassifier":
            targets = self.one_hot_encoding(test_labels)
        else:
            targets = test_labels
        return self.SMLmodel.score(self._model_input(test_features), targets)
