import torch
import torch.nn as nn
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearRegressor(nn.Module):
    """Linear regressor that predicts a scalar and appends it to its input features.

    The first ``model_input_dim`` columns of the input are fed through a single
    ``nn.Linear(model_input_dim, 1)`` layer; the forward pass returns those
    columns concatenated with the prediction.
    """

    def __init__(self, model_input_dim ):
        """
        Args:
            model_input_dim: Number of leading input columns used as features.
        """
        super().__init__()
        self.model_input_dim = model_input_dim
        self.model = nn.Linear(model_input_dim, 1)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize every linear layer with Kaiming-normal weights (linear gain) and zero bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='linear')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Tensor indexed as ``x[:, :model_input_dim]``, so it needs at
                least two axes; axis 0 is kept as the batch axis.

        Returns:
            dict with
                ``"sample"``: the selected feature columns concatenated with
                the prediction along the last axis, i.e. shape
                ``(B, model_input_dim + 1)`` for a 2-D input;
                ``"aux_loss"``: always ``0.0``.
        """
        # Flatten only the axes after the batch axis. A bare squeeze() would
        # also drop the batch axis when B == 1 and the feature axis when
        # model_input_dim == 1, which breaks the linear layer / output shape.
        x_input = x[:, :self.model_input_dim].flatten(start_dim=1)
        yhat = self.model(x_input)  # (B, 1)
        x_y_hat_concat = torch.cat((x_input, yhat), dim=-1)  # (x, yhat)
        return {"sample": x_y_hat_concat, "aux_loss": 0.0}

import torch
import torch.nn as nn

class MLPRegressor(nn.Module):
    """MLP regressor with two hidden ReLU layers and a bounded scalar output.

    The first ``model_input_dim`` columns of the input are mapped to a single
    value in ``(0, output_scale)`` and returned concatenated with those columns.
    """

    def __init__(self, model_input_dim, hidden_dim=16, output_scale=5.0):
        """
        Args:
            model_input_dim: Number of leading input columns used as features.
            hidden_dim: Width of both hidden layers.
            output_scale: Factor applied to the sigmoid output; predictions
                lie in ``(0, output_scale)``.
        """
        super().__init__()
        self.model_input_dim = model_input_dim
        self.output_scale = output_scale
        self.model = nn.Sequential(
            nn.Linear(model_input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            # Sigmoid, not Softmax: a softmax over a single logit is
            # identically 1, which pins y_hat to a constant and passes zero
            # gradient to every parameter. NOTE(review): a bounded output is
            # assumed to be the intent of the scale factor -- confirm.
            nn.Sigmoid(),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize every linear layer with Kaiming-normal weights (ReLU gain) and zero bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Tensor of shape ``(B, D)`` with ``D >= model_input_dim``.

        Returns:
            dict with
                ``"sample"``: ``(B, model_input_dim + 1)`` tensor holding the
                selected feature columns followed by ``y_hat``;
                ``"aux_loss"``: always ``0.0``.
        """
        x_input = x[:, :self.model_input_dim]
        y_hat = self.output_scale * self.model(x_input)  # (B, 1)
        x_y_hat = torch.cat((x_input, y_hat), dim=-1)
        return {"sample": x_y_hat, "aux_loss": 0.0}


import torch
import torch.nn as nn
import torch.nn.functional as F

class PolynomialRegressor(nn.Module):
    """Single-feature polynomial regressor built on one linear layer.

    Column 0 of the input is raised to the powers ``1..degree`` and a linear
    layer maps those powers to a scalar prediction. The layer's weight and
    bias both start at zero.
    """

    def __init__(self, degree=1):
        """
        Args:
            degree: Highest power of the input feature fed to the linear layer.
        """
        super().__init__()
        self.degree = degree
        self.model = nn.Linear(degree, 1)  # one coefficient per power of x
        self._initialize_weights()

    def _initialize_weights(self):
        """Zero the weight and (if present) the bias of every linear layer."""
        linear_layers = (layer for layer in self.modules() if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            nn.init.zeros_(layer.weight)
            if layer.bias is None:
                continue
            nn.init.zeros_(layer.bias)

    def _polynomial_features(self, x):
        """Return ``x**1, x**2, ..., x**degree`` concatenated along the last axis."""
        powers = [torch.pow(x, exponent) for exponent in range(1, self.degree + 1)]
        return torch.cat(powers, dim=-1)

    def get_params(self):
        """Return CPU NumPy copies of the linear layer's weight and bias in a dict."""
        layer = self.model
        weights = layer.weight.detach().clone().cpu().numpy()
        bias = layer.bias.detach().clone().cpu().numpy()
        return {"weights": weights, "bias": bias}

    def forward(self, x):
        """Predict from column 0 of ``x`` and return the pair ``(x, yhat)``.

        Returns:
            dict with ``"sample"`` (column 0 of ``x`` concatenated with the
            prediction along the last axis) and ``"aux_loss"`` (always ``0.0``).
        """
        first_feature = torch.unsqueeze(x[:, 0], -1)  # only column 0 is used
        prediction = self.model(self._polynomial_features(first_feature))
        sample = torch.cat((first_feature, prediction), dim=-1)
        return {"sample": sample, "aux_loss": 0.0}



# Example usage: build a LinearRegressor and print the shape of its output.
if __name__ == "__main__":
    batch_size = 16
    input_dim = 2
    # LinearRegressor has no default for model_input_dim, so pass it explicitly.
    model = LinearRegressor(model_input_dim=input_dim)

    sample_input = torch.randn(batch_size, input_dim)  # every column is used as a feature
    output = model(sample_input)
    # forward() returns a dict; the (x, yhat) tensor is stored under "sample".
    print("Output shape:", output["sample"].shape)  # Should be (batch_size, input_dim + 1)
