import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score

from learning.criterion import Criterion
from learning.model.base_models import BaseClassifier
from learning.optimizer import Optimizer


class LogisticRegressionModel(BaseClassifier):
    """Binary logistic-regression classifier: one linear layer followed by a sigmoid.

    The model emits a single probability per sample. ``num_class`` is only
    forwarded to ``BaseClassifier``; the output layer always has width 1.
    """

    def __init__(
        self,
        input_dim,
        optimizer_instance: Optimizer,
        criterion_instance: Criterion,
        num_class: int = 2,
    ):
        """Build the linear layer and hand the model to the optimizer/initializer.

        Args:
            input_dim: Number of input features (``in_features`` of the linear layer).
            optimizer_instance: Project optimizer wrapper; ``set_from_model`` is
                called on it after the layer exists (presumably to register this
                model's parameters; confirm in ``learning.optimizer``).
            criterion_instance: Project loss wrapper, forwarded to ``BaseClassifier``.
            num_class: Forwarded to ``BaseClassifier``; it does not change the
                output width, which is fixed at 1.
        """
        super().__init__(optimizer_instance, criterion_instance, num_class)
        self.linear = nn.Linear(input_dim, 1)
        # Called only after the layer is created, so the optimizer and the
        # initializer can see its parameters (assumed; both live in the base/project).
        self.optimizer_instance.set_from_model(self)
        self.apply_weight_initialization()

    def forward(self, x):
        """Return sigmoid probabilities of shape ``(*, 1)`` for inputs of shape ``(*, input_dim)``."""
        return torch.sigmoid(self.linear(x))

    def score(self, X_test, y_test):
        """Return the fraction of samples whose thresholded prediction matches ``y_test``.

        Switches the model to eval mode (and leaves it there), runs
        ``self.predict`` without gradients, and labels a sample 1 when its value
        is >= 0.5, otherwise 0. The result is scored with
        ``sklearn.metrics.accuracy_score``.

        Args:
            X_test: Inputs accepted by ``self.predict``.
            y_test: Ground-truth 0/1 labels. Either a tensor of shape ``(N,)`` or
                ``(N, 1)`` on any device, or any array-like sklearn accepts.

        Returns:
            Accuracy as a float in ``[0, 1]``.
        """
        self.eval()
        with torch.no_grad():
            y_prob = self.predict(X_test)
            # Flatten the (N, 1) head output and move it to host memory. sklearn
            # converts its inputs through NumPy, which cannot read CUDA tensors.
            y_pred = (y_prob >= 0.5).long().reshape(-1).cpu()
        if isinstance(y_test, torch.Tensor):
            # Same reasoning for tensor labels: flatten and bring to CPU.
            y_test = y_test.detach().reshape(-1).cpu()
        return accuracy_score(y_test, y_pred)


class MLPClassifier(BaseClassifier):
    """One-hidden-layer perceptron: Linear -> ReLU -> Linear, emitting raw logits.

    The output layer has ``num_class`` units, and ``forward`` applies no
    activation to it.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        optimizer_instance: Optimizer,
        criterion_instance: Criterion,
        num_class,
        output_dim=None,
    ):
        """Build both linear layers and hand the model to the optimizer/initializer.

        Args:
            input_dim: Number of input features.
            hidden_dim: Width of the hidden layer.
            optimizer_instance: Project optimizer wrapper; ``set_from_model`` is
                called on it after the layers exist (presumably to register this
                model's parameters; confirm in ``learning.optimizer``).
            criterion_instance: Project loss wrapper, forwarded to ``BaseClassifier``.
            num_class: Number of classes. It is also the width of the output layer.
            output_dim: Accepted for backward compatibility but ignored; the
                output width is always ``num_class``.
        """
        super().__init__(optimizer_instance, criterion_instance, num_class)

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # NOTE(review): ``output_dim`` is deliberately not used here. Honouring it
        # would silently change the head width for existing callers.
        self.fc2 = nn.Linear(hidden_dim, num_class)

        # Called only after the layers are created, so the optimizer and the
        # initializer can see their parameters (assumed; both live in the base/project).
        self.optimizer_instance.set_from_model(self)
        self.apply_weight_initialization()

    def forward(self, x):
        """Return raw logits of shape ``(*, num_class)`` for inputs of shape ``(*, input_dim)``."""
        x = torch.relu(self.fc1(x))
        logits = self.fc2(x)
        return logits

    def score(self, X_test, y_test):
        """Return the fraction of samples whose predicted class matches ``y_test``.

        Switches the model to eval mode (and leaves it there) and runs
        ``self.predict`` without gradients. The prediction is then converted to
        class labels:

        * A binary model (``self.is_binary``) whose prediction has one value per
          sample, shape ``(N,)`` or ``(N, 1)``, is thresholded at 0.5.
        * Any other prediction is treated as ``(N, C)`` per-class scores and
          reduced with ``argmax``.

        Args:
            X_test: Inputs accepted by ``self.predict``.
            y_test: Ground-truth class indices. Either a tensor, NumPy array or
                list of shape ``(N,)`` or ``(N, 1)``.

        Returns:
            Accuracy as a Python float. It is NaN when ``X_test`` is empty.
        """
        self.eval()
        with torch.no_grad():
            y_pred = self.predict(X_test)

            # Threshold only when there is a single score per sample. A binary model
            # with a (N, 2) head must use argmax: thresholding it column-wise would
            # produce a (N, 2) matrix that broadcasts wrongly against (N,) labels.
            single_column = y_pred.dim() == 1 or y_pred.size(-1) == 1
            if self.is_binary and single_column:
                predicted_classes = (y_pred.reshape(-1) >= 0.5).long()
            else:
                predicted_classes = torch.argmax(y_pred, dim=1)

            # Flatten the labels to (N,) and place them on the predictions' device.
            # This stops (N, 1) targets from broadcasting to (N, N), and avoids a
            # device mismatch for CPU or NumPy labels.
            targets = torch.as_tensor(
                y_test, device=predicted_classes.device
            ).reshape(-1)
            accuracy = (predicted_classes == targets).float().mean().item()

        return accuracy
