import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import r2_score
from sklearn.preprocessing import PolynomialFeatures

from learning.criterion import Criterion
from learning.model.base_models import BaseRegressor
from learning.optimizer import Optimizer


class LinearRegressionModel(BaseRegressor):
    """Single-output linear regressor backed by ``nn.Linear(input_dim, 1)``.

    The optimizer is bound to this model's parameters during construction via
    ``optimizer_instance.set_from_model(self)``, after which the base class's
    ``apply_weight_initialization`` is invoked.
    """

    def __init__(
        self, input_dim, optimizer_instance: Optimizer, criterion_instance: Criterion
    ):
        super().__init__(optimizer_instance, criterion_instance)
        self.linear = nn.Linear(input_dim, 1)
        # Must run after self.linear exists so the optimizer sees its parameters.
        self.optimizer_instance.set_from_model(self)
        self.apply_weight_initialization()

    def forward(self, x):
        """Return ``self.linear(x)``; the last output dimension has size 1."""
        return self.linear(x)

    def score(self, X_test, y_test):
        """Return the R^2 score of this model's predictions on ``X_test``.

        Puts the module in eval mode and computes predictions without
        gradient tracking.

        Args:
            X_test: Input tensor accepted by :meth:`forward`.
            y_test: Target values, as a tensor (any device) or array-like.
                Flattened to 1-D before scoring.

        Returns:
            The value of ``sklearn.metrics.r2_score(y_true, y_pred)``.
        """
        self.eval()
        with torch.no_grad():
            y_pred = self.forward(X_test)

        # reshape(-1) instead of squeeze(): keeps a 1-element batch as shape
        # (1,) rather than collapsing it to a 0-d array.
        # .cpu() is required because Tensor.numpy() raises for CUDA tensors.
        y_pred_np = y_pred.reshape(-1).cpu().numpy()

        if isinstance(y_test, torch.Tensor):
            # detach(): Tensor.numpy() raises on tensors that require grad.
            y_test_np = y_test.detach().cpu().numpy()
        else:
            y_test_np = np.asarray(y_test)
        y_test_np = y_test_np.reshape(-1)

        return r2_score(y_test_np, y_pred_np)


class PolynomialRegressionModel(BaseRegressor):
    """Linear regression on a polynomial expansion of the input.

    Inputs are expanded with ``sklearn.preprocessing.PolynomialFeatures``
    (with sklearn's default settings, i.e. bias column included) inside
    :meth:`forward`, and the expanded features feed a single-output
    ``nn.Linear``.
    """

    def __init__(self, input_dim, degree):
        # NOTE(review): unlike LinearRegressionModel, no optimizer/criterion is
        # passed to BaseRegressor here — confirm BaseRegressor supports that.
        super().__init__()
        self.degree = degree
        self.poly = PolynomialFeatures(degree=self.degree)
        # The expansion depends only on input_dim and degree, so fit once on a
        # dummy row; this fixes the feature width used to size nn.Linear.
        self.poly.fit(np.zeros((1, input_dim)))
        self.linear = nn.Linear(self.poly.n_output_features_, 1)

    def forward(self, x):
        """Expand ``x`` to polynomial features and apply the linear layer."""
        x_poly = self._transform_to_polynomial(x)
        return self.linear(x_poly)

    def _transform_to_polynomial(self, x):
        """Return the polynomial expansion of ``x`` as a tensor on ``x``'s device.

        The expansion is computed in NumPy on a detached copy, so no gradient
        flows back to ``x`` (gradients still reach ``self.linear``).

        Raises:
            ValueError: (from sklearn) if ``x`` does not have ``input_dim``
                features.
        """
        x_np = x.detach().cpu().numpy()
        # transform, not fit_transform: the expander was already fitted in
        # __init__, and transform rejects inputs of the wrong width.
        x_poly_np = self.poly.transform(x_np)
        # Match the linear layer's dtype (float32 unless the module was cast).
        return torch.as_tensor(
            x_poly_np, dtype=self.linear.weight.dtype, device=x.device
        )

    def fit(self, X_train, y_train):
        """Train on the raw (unexpanded) features ``X_train``.

        :meth:`forward` already applies the polynomial expansion. Raw inputs
        are therefore passed through unchanged; expanding here too would
        expand twice. This assumes ``BaseRegressor.fit`` feeds inputs through
        ``self.forward`` / ``self(...)`` — TODO confirm.
        """
        return super().fit(X_train, y_train)


class MLPRegressor(BaseRegressor):
    """Two-layer perceptron regressor.

    Architecture: ``Linear(input_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, output_dim)``, with no activation on the output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim=1):
        super().__init__()
        # Layers are created in this order (fc1 then fc2) so random
        # initialization consumes the RNG identically.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` through the hidden ReLU layer to the output layer."""
        return self.fc2(torch.relu(self.fc1(x)))