import torch
from torch import nn
from torch.nn import functional as F
from ..utils.model import fc_nn


class LinearGenerator(nn.Module):
    """Affine transform of ``y`` whose slope and intercept are linear in ``x``.

    The module computes ``y * slope(x) + intercept(x)`` where ``slope`` and
    ``intercept`` are independent ``nn.Linear(xdim, 1)`` layers. The slope is
    the raw linear output, so it is not constrained to be positive.

    Args:
        xdim: Number of input features of ``x``.
        scale_x: Optional replacement for the slope layer's weight.
        scale_intercept: Optional replacement for the slope layer's bias.
        shift_x: Optional replacement for the intercept layer's weight.
        shift_intercept: Optional replacement for the intercept layer's bias.
        seed: Optional seed for the random initialisation of both layers.
            When given, the layers are initialised from that seed and the
            global CPU RNG state is restored afterwards. When ``None``, the
            layers are initialised from the current global RNG.

    Note:
        Override tensors are assigned directly as parameter data: they are
        neither copied nor shape-checked. ``nn.Linear(xdim, 1)`` itself uses
        a weight of shape ``(1, xdim)`` and a bias of shape ``(1,)``.
    """

    def __init__(self, xdim: int, scale_x: torch.Tensor = None, scale_intercept: torch.Tensor = None,
                 shift_x: torch.Tensor = None, shift_intercept: torch.Tensor = None, seed=None):
        super().__init__()
        if seed is None:
            # No seed requested: behave like any other module and consume the
            # global RNG, so a caller-level torch.manual_seed stays effective.
            self._build_layers(xdim)
        else:
            # Snapshot the RNG so that seeding this module does not disturb
            # the caller's random stream.
            # NOTE: torch.manual_seed also reseeds CUDA generators; only the
            # CPU generator state is restored here.
            rng_state = torch.get_rng_state()
            try:
                torch.manual_seed(seed)
                self._build_layers(xdim)
            finally:
                torch.set_rng_state(rng_state)

        if scale_x is not None:
            self.modelslope.weight.data = scale_x
        if scale_intercept is not None:
            self.modelslope.bias.data = scale_intercept
        if shift_x is not None:
            self.modelintercept.weight.data = shift_x
        if shift_intercept is not None:
            self.modelintercept.bias.data = shift_intercept

    def _build_layers(self, xdim: int) -> None:
        """Create the slope and intercept layers (slope first, to keep the RNG draw order)."""
        self.modelslope = nn.Linear(xdim, 1)
        self.modelintercept = nn.Linear(xdim, 1)

    def forward(self, y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``y * slope(x) + intercept(x)``.

        Args:
            y: Values to transform; must broadcast against the leading
                dimensions of ``x``.
            x: Conditioning features whose last dimension is ``xdim``.

        Returns:
            The transformed values, with the trailing size-1 output dimension
            of the linear layers removed.
        """
        slope = self.modelslope(x).squeeze(-1)
        intercept = self.modelintercept(x).squeeze(-1)
        return y*slope+intercept


class MLPGenerator(nn.Module):
    """Generator that feeds the concatenation of ``x`` and ``y`` through one network.

    The network is built by the project helper ``fc_nn`` with ``xdim + 1``
    inputs and 1 output (presumably a fully connected net; confirm against
    ``utils.model.fc_nn``).

    Args:
        xdim: Number of features in ``x``.
        hidden_layers: Hidden layer sizes passed to ``fc_nn``. Defaults to
            ``[20, 20]``.
        activation: Activation module passed to ``fc_nn``. Defaults to a
            fresh ``nn.ReLU()``.
    """

    def __init__(self, xdim: int, hidden_layers: list = None, activation: nn.Module = None):
        super().__init__()
        # Defaults are resolved per call so that no list or module instance
        # is shared between separately constructed generators.
        if hidden_layers is None:
            hidden_layers = [20, 20]
        if activation is None:
            activation = nn.ReLU()
        self.model = fc_nn(xdim+1, hidden_layers, 1, activation)

    def forward(self, y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to ``[x, y]`` joined along the last dimension.

        Args:
            y: Tensor with one dimension fewer than ``x``; it is appended to
                ``x`` as one extra trailing feature.
            x: Tensor of conditioning features.

        Returns:
            The network output with its last dimension squeezed.
        """
        return self.model(torch.cat([x, y.unsqueeze(-1)], dim=-1)).squeeze(-1)


class RFFGenerator(nn.Module):
    """Generator built on random Fourier features of the concatenated ``[x, y]``.

    Fixed random projections ``w_vectors`` and phases ``bias`` define the
    features ``cos([x, y] @ w_vectors + bias) / sqrt(n_features)``; the output
    is their linear combination with the learnable weights ``theta``.

    Args:
        xdim: Number of features in ``x``.
        n_features: Number of random Fourier features.
        sigma: Divisor applied to the standard-normal projection vectors.
        seed: Seed for the random features and the initial ``theta``. When
            ``None``, a seed is drawn from the global RNG. The seed in use is
            stored in ``init_seed``.

    Note:
        ``bias`` is registered as a non-persistent buffer: it follows the
        module across devices and dtypes like ``w_vectors`` but is left out of
        ``state_dict``, so the ``state_dict`` keys are ``w_vectors`` and
        ``theta`` only.
    """

    def __init__(self, xdim: int, n_features: int = 20, sigma: float = 1.0, seed: int = None):
        super().__init__()
        if seed is None:
            # Drawn before the RNG snapshot below so that successive unseeded
            # constructions advance the global stream and get distinct seeds.
            seed = torch.randint(0, 1000000, (1,)).item()
        self.init_seed = seed
        self.n_features = n_features

        # Seed only for the duration of the draws, then hand the caller's
        # RNG state back untouched.
        # NOTE: torch.manual_seed also reseeds CUDA generators; only the CPU
        # generator state is restored here.
        rng_state = torch.get_rng_state()
        try:
            torch.manual_seed(seed)
            # Draw order (w_vectors, bias, theta) determines the values
            # obtained for a given seed; keep it stable.
            w_vectors = torch.randn(xdim + 1, n_features) / sigma
            bias = torch.rand(n_features) * 2 * torch.pi
            theta = torch.randn(n_features) * 0.1
        finally:
            torch.set_rng_state(rng_state)

        self.w_vectors: torch.Tensor
        self.bias: torch.Tensor
        self.register_buffer('w_vectors', w_vectors)
        self.register_buffer('bias', bias, persistent=False)
        self.theta = nn.Parameter(theta)

    def forward(self, y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the random-feature model on ``[x, y]``.

        Args:
            y: Tensor with one dimension fewer than ``x``; it is appended to
                ``x`` as one extra trailing feature.
            x: Tensor whose last dimension is ``xdim``.

        Returns:
            Tensor shaped like the leading dimensions of ``x``. A batch of
            size 1 keeps its batch dimension.
        """
        # Compute RFF features
        inputs = torch.cat([x, y.unsqueeze(-1)], dim=-1)
        projection = inputs @ self.w_vectors+self.bias  # Shape: (..., n_features)
        rff_features = torch.cos(projection)/float(self.n_features)**0.5
        # Linear combination with learnable weights; the matrix-vector product
        # already removes the feature dimension, so no squeeze is needed.
        return rff_features @ self.theta


def myplus(x):
    """Map ``x`` to a strictly positive value.

    Evaluates to ``exp(x)`` for negative inputs and ``1 + x`` for
    non-negative inputs; the two pieces meet at ``x = 0`` with value 1.
    """
    positive_part = F.relu(x)
    negative_magnitude = F.relu(x.neg())  # max(-x, 0)
    return negative_magnitude.neg().exp() + positive_part


class PositiveMLPGenerator(nn.Module):
    """Affine transform of ``y`` with a network-predicted shift and positive scale.

    A network built by the project helper ``fc_nn`` maps ``x`` to two
    outputs: channel 0 is used as the shift and channel 1, passed through
    ``myplus``, as the scale. ``forward`` returns ``y * scale + shift``.

    Args:
        xdim: Number of features in ``x``.
        hidden_layers: Hidden layer sizes passed to ``fc_nn``. Defaults to
            ``[20, 20]``.
        activation: Activation module passed to ``fc_nn``. Defaults to a
            fresh ``nn.ReLU()``.
        weight_scale: Factor applied to every network parameter after
            construction; ``1.`` leaves the initialisation untouched.
    """

    def __init__(self, xdim: int, hidden_layers: list = None, activation: nn.Module = None, weight_scale=1.):
        super().__init__()
        # Defaults are resolved per call so that no list or module instance
        # is shared between separately constructed generators.
        if hidden_layers is None:
            hidden_layers = [20, 20]
        if activation is None:
            activation = nn.ReLU()
        self.model = fc_nn(xdim, hidden_layers, 2, activation)
        self.weight_scale = weight_scale
        if self.weight_scale != 1.:
            for param in self.model.parameters():
                param.data = weight_scale*param.data
        self.myplus = myplus

    def get_shift_and_scale(self, x: torch.Tensor) -> tuple:
        """Return the ``(shift, scale)`` pair predicted for ``x``.

        Args:
            x: Tensor of conditioning features accepted by the network.

        Returns:
            A 2-tuple of tensors: the raw channel-0 output (shift) and
            ``myplus`` applied to the channel-1 output (scale).
        """
        model_output = self.model(x)
        return model_output[..., 0], self.myplus(model_output[..., 1])

    def forward(self, y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``y * scale(x) + shift(x)``.

        Args:
            y: Values to transform; must broadcast against the shift and
                scale tensors.
            x: Tensor of conditioning features accepted by the network.
        """
        shift, scale = self.get_shift_and_scale(x)
        return y*scale + shift


class ShiftGenerator(nn.Module):
    """Network that maps ``x`` to a single shift value per sample.

    The network is built by the project helper ``fc_nn`` with ``xdim`` inputs
    and 1 output.

    Args:
        xdim: Number of features in ``x``.
        hidden_layers: Hidden layer sizes passed to ``fc_nn``. Defaults to
            ``[20, 20]``.
        activation: Activation module passed to ``fc_nn``. Defaults to a
            fresh ``nn.ReLU()``.
        weight_scale: Factor applied to every network parameter after
            construction.
    """

    def __init__(self, xdim: int, hidden_layers: list = None, activation: nn.Module = None, weight_scale=1.):
        super().__init__()
        # Defaults are resolved per call so that no list or module instance
        # is shared between separately constructed generators.
        if hidden_layers is None:
            hidden_layers = [20, 20]
        if activation is None:
            activation = nn.ReLU()
        self.model = fc_nn(xdim, hidden_layers, 1, activation)
        for param in self.model.parameters():
            param.data = weight_scale*param.data

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network output for ``x`` with its last dimension squeezed."""
        return self.model(x).squeeze(-1)
