import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import pi, sqrt


class ToyNet(nn.Module):
    """Single affine layer mapping 2 input features to ``noutputs`` values.

    Args:
        negative_slope: Stored on the instance; not used by ``forward``.
        noutputs: Number of output units (default 2).
    """

    def __init__(self, negative_slope, noutputs=2):
        super().__init__()
        self.name = "BasicNet"
        self.negative_slope = negative_slope
        self.fc1 = nn.Linear(in_features=2, out_features=noutputs, bias=True)

    def forward(self, x):
        """Return the affine map ``fc1(x)``."""
        return self.fc1(x)


class LinearReg(nn.Module):
    """Bias-free linear regression model: ``y = x @ W.T``.

    Args:
        negative_slope: Stored on the instance; not used by ``forward``.
        noutputs: Number of regression targets (default 1).
        nfeatures: Number of input features (default 2).
    """

    def __init__(self, negative_slope, noutputs=1, nfeatures=2):
        super().__init__()
        self.name = "LinearRegNet"
        self.negative_slope = negative_slope
        self.fc1 = nn.Linear(in_features=nfeatures, out_features=noutputs, bias=False)

    def forward(self, x):
        """Apply the bias-free linear layer to ``x``."""
        prediction = self.fc1(x)
        return prediction


class TwoLayerLinear(nn.Module):
    """Two stacked bias-free linear layers with no activation in between.

    Args:
        negative_slope: Stored on the instance; not used by ``forward``.
        noutputs: Output width of the second layer (default 1).
        nfeatures: Input width of the first layer (default 2).
        nhiddenneurons: Width of the hidden representation (default 10).
    """

    def __init__(self, negative_slope, noutputs=1, nfeatures=2, nhiddenneurons=10):
        super().__init__()
        self.name = "TwoLayerLinearNet"
        self.negative_slope = negative_slope
        self.fc1 = nn.Linear(nfeatures, nhiddenneurons, bias=False)
        self.fc2 = nn.Linear(nhiddenneurons, noutputs, bias=False)

    def forward(self, x):
        """Compute ``fc2(fc1(x))``."""
        hidden = self.fc1(x)
        return self.fc2(hidden)

    def set_weights(self, weights):
        """Overwrite both layer weights of this two-layer linear network.

        Args:
            weights: Indexable container. ``weights[0]`` becomes
                ``fc1.weight`` and ``weights[1]`` becomes ``fc2.weight``.
                They should have the ``(out_features, in_features)`` shapes
                that ``forward`` expects; no check is made here.

        Returns:
            None.

        Note:
            New ``nn.Parameter`` objects are installed. An optimizer created
            before this call therefore still references the old tensors.
        """
        first, second = weights[0], weights[1]
        with torch.no_grad():
            for layer, new_weight in ((self.fc1, first), (self.fc2, second)):
                layer.weight = nn.Parameter(new_weight)


class ToyTwoLinear(nn.Module):
    """Bias-free feature extractor ``fc1`` followed by a bias-free head ``fc2``.

    Args:
        n_input: Input width of ``fc1``.
        n_features: Width of the intermediate representation (default 2).
        noutputs: Output width of the head ``fc2`` (default 2).
    """

    def __init__(self, n_input, n_features=2, noutputs=2):
        super().__init__()
        self.name = "ToyTwoLinear"
        self.noutputs = noutputs
        self.n_features = n_features
        self.fc1 = nn.Linear(n_input, n_features, bias=False)
        self.fc2 = nn.Linear(n_features, noutputs, bias=False)

    def reset_head(self):
        """Replace ``fc2`` with a freshly initialised layer of the same size."""
        fresh_head = nn.Linear(self.n_features, self.noutputs, bias=False)
        self.fc2 = fresh_head

    def forward(self, x):
        """Compute ``fc2(fc1(x))``."""
        features = self.fc1(x)
        return self.fc2(features)

    def set_weights(self, weights):
        """Overwrite both layer weights of this two-layer linear network.

        Args:
            weights: Indexable container. ``weights[0]`` becomes
                ``fc1.weight`` and ``weights[1]`` becomes ``fc2.weight``.

        Returns:
            None.

        Note:
            New ``nn.Parameter`` objects are installed. An optimizer created
            before this call therefore still references the old tensors.
        """
        first, second = weights[0], weights[1]
        with torch.no_grad():
            for layer, new_weight in ((self.fc1, first), (self.fc2, second)):
                layer.weight = nn.Parameter(new_weight)


class TwoLayerRelu(nn.Module):
    """Two bias-free linear layers with a ReLU on the hidden activations.

    Args:
        negative_slope: Stored on the instance; ``forward`` uses plain ReLU.
        noutputs: Output width of the second layer (default 1).
        nfeatures: Input width of the first layer (default 2).
        nhiddenneurons: Width of the hidden layer (default 10).
    """

    def __init__(self, negative_slope, noutputs=1, nfeatures=2, nhiddenneurons=10):
        super().__init__()
        self.name = "TwoLayerReluNet"
        self.negative_slope = negative_slope
        self.fc1 = nn.Linear(nfeatures, nhiddenneurons, bias=False)
        self.fc2 = nn.Linear(nhiddenneurons, noutputs, bias=False)

    def forward(self, x):
        """Compute ``fc2(relu(fc1(x)))``."""
        pre_activation = self.fc1(x)
        hidden = F.relu(pre_activation)
        return self.fc2(hidden)

    def set_weights(self, weights):
        """Overwrite both layer weights of this two-layer network.

        Args:
            weights: Indexable container. ``weights[0]`` becomes
                ``fc1.weight`` and ``weights[1]`` becomes ``fc2.weight``.

        Returns:
            None.
        """
        first, second = weights[0], weights[1]
        with torch.no_grad():
            for layer, new_weight in ((self.fc1, first), (self.fc2, second)):
                layer.weight = nn.Parameter(new_weight)


class ToyBatchNet(nn.Module):
    """Affine layer (2 -> ``noutputs``) followed by 1-D batch normalization.

    Args:
        negative_slope: Stored on the instance; not used by ``forward``.
        noutputs: Width of the linear layer and of the batch-norm layer
            (default 2).
    """

    def __init__(self, negative_slope, noutputs=2):
        super(ToyBatchNet, self).__init__()
        self.name = "ToyBatchNet"
        self.negative_slope = negative_slope
        self.fc1 = nn.Linear(2, noutputs, bias=True)
        # The batch-norm layer must normalize exactly the features fc1 produces.
        # A hard-coded width of 2 crashed for every noutputs != 2.
        self.bn1 = nn.BatchNorm1d(num_features=noutputs)

    def forward(self, x):
        """Return ``bn1(fc1(x))``."""
        x = self.bn1(self.fc1(x))
        return x


class ToyReluNet(nn.Module):
    """Affine layer (2 -> ``noutputs``) followed by a leaky ReLU.

    Args:
        negative_slope: Slope applied to negative pre-activations.
        noutputs: Number of output units (default 2).
    """

    def __init__(self, negative_slope, noutputs=2):
        super().__init__()
        self.name = "ToyReluNet"
        self.negative_slope = negative_slope
        self.fc1 = nn.Linear(2, noutputs, bias=True)

    def forward(self, x):
        """Return ``leaky_relu(fc1(x), negative_slope)``."""
        pre_activation = self.fc1(x)
        return F.leaky_relu(pre_activation, self.negative_slope)


class ToyBatchReluNet(nn.Module):
    """Affine layer (2 -> ``noutputs``), batch normalization, then leaky ReLU.

    Args:
        negative_slope: Slope applied to negative normalized activations.
        noutputs: Width of the linear layer and of the batch-norm layer
            (default 2).
    """

    def __init__(self, negative_slope, noutputs=2):
        super(ToyBatchReluNet, self).__init__()
        self.name = "ToyBatchReluNet"
        self.negative_slope = negative_slope
        self.fc1 = nn.Linear(2, noutputs, bias=True)
        # Must match fc1's output width; a hard-coded 2 broke noutputs != 2.
        self.bn1 = nn.BatchNorm1d(num_features=noutputs)

    def forward(self, x):
        """Return ``leaky_relu(bn1(fc1(x)), negative_slope)``."""
        x = F.leaky_relu(self.bn1(self.fc1(x)), self.negative_slope)
        return x


class ToyLongBatchReluNet(nn.Module):
    """Two affine + batch-norm stages with a leaky ReLU between them.

    Layout: ``fc1 (2 -> noutputs) -> bn1 -> leaky_relu -> fc2 (noutputs ->
    noutputs) -> bn2``.

    Args:
        negative_slope: Slope of the leaky ReLU between the two stages.
        noutputs: Width of every layer after the input (default 2).
    """

    def __init__(self, negative_slope, noutputs=2):
        super(ToyLongBatchReluNet, self).__init__()
        self.name = "ToyLongBatchReluNet"
        self.negative_slope = negative_slope
        self.fc1 = nn.Linear(2, noutputs, bias=True)
        # bn1 and fc2 consume fc1's output, and bn2 consumes fc2's output.
        # Their widths therefore follow noutputs; hard-coded 2s broke
        # noutputs != 2.
        self.bn1 = nn.BatchNorm1d(num_features=noutputs)
        self.fc2 = nn.Linear(noutputs, noutputs, bias=True)
        self.bn2 = nn.BatchNorm1d(num_features=noutputs)

    def forward(self, x):
        """Return ``bn2(fc2(leaky_relu(bn1(fc1(x)), negative_slope)))``."""
        hidden = F.leaky_relu(self.bn1(self.fc1(x)), self.negative_slope)
        x = self.bn2(self.fc2(hidden))
        return x


class ToyResNet(nn.Module):
    """Residual block on 2-D inputs: ``x + leaky_relu(fc1(x))``.

    Args:
        negative_slope: Slope of the leaky ReLU on the residual branch.
        noutputs: Accepted for signature compatibility; not used (the
            residual connection fixes the width at ``n_features == 2``).
    """

    def __init__(self, negative_slope, noutputs=2):
        super().__init__()
        self.name = "ToyResNet"
        self.negative_slope = negative_slope
        self.n_features = 2
        self.fc1 = nn.Linear(2, self.n_features, bias=True)

    def forward(self, x):
        """Add the leaky-ReLU residual branch to the input."""
        branch = F.leaky_relu(self.fc1(x), self.negative_slope)
        return x + branch


class ToyBatchResNet(nn.Module):
    """Residual block on 2-D inputs: ``x + leaky_relu(bn1(fc1(x)))``.

    Args:
        negative_slope: Slope of the leaky ReLU on the residual branch.
        noutputs: Accepted for signature compatibility; not used.
    """

    def __init__(self, negative_slope, noutputs=2):
        super().__init__()
        self.name = "ToyBatchResNet"
        self.negative_slope = negative_slope
        self.n_features = 2
        width = self.n_features
        self.fc1 = nn.Linear(2, width, bias=True)
        self.bn1 = nn.BatchNorm1d(num_features=width)

    def forward(self, x):
        """Add the batch-normalized leaky-ReLU branch to the input."""
        branch = self.bn1(self.fc1(x))
        branch = F.leaky_relu(branch, self.negative_slope)
        return x + branch


class ToyLongBatchResNet(nn.Module):
    """Residual block with two affine + batch-norm stages on 2-D inputs.

    Computes ``x + bn2(fc2(leaky_relu(bn1(fc1(x)))))``.

    Args:
        negative_slope: Slope of the leaky ReLU inside the residual branch.
        noutputs: Accepted for signature compatibility; not used (the
            residual connection fixes the width at ``n_features == 2``).
    """

    def __init__(self, negative_slope, noutputs=2):
        super(ToyLongBatchResNet, self).__init__()
        self.name = "ToyLongBatchResNet"
        self.negative_slope = negative_slope
        self.n_features = 2
        self.fc1 = nn.Linear(2, self.n_features, bias=True)
        self.bn1 = nn.BatchNorm1d(num_features=self.n_features)
        self.fc2 = nn.Linear(self.n_features, self.n_features, bias=True)
        self.bn2 = nn.BatchNorm1d(num_features=self.n_features)

    def forward(self, x):
        """Return the input plus the two-stage batch-normalized branch."""
        hidden = F.leaky_relu(self.bn1(self.fc1(x)), self.negative_slope)
        x = x + self.bn2(self.fc2(hidden))
        # The original body computed x but never returned it, so the caller
        # always received None.
        return x


class ToySVM(nn.Module):
    """Linear decision function ``w . x + b`` on 2 input features.

    Args:
        negative_slope: Stored on the instance; not used by ``forward``.
        noutputs: Number of decision values produced (default 1).
    """

    def __init__(self, negative_slope, noutputs=1):
        super().__init__()
        self.name = "ToySVM"
        self.negative_slope = negative_slope
        self.n_features = 2
        self.fc1 = nn.Linear(in_features=self.n_features, out_features=noutputs)

    def forward(self, x):
        """Return the raw (unthresholded) decision values ``fc1(x)``."""
        scores = self.fc1(x)
        return scores


class ToySVMpoly(nn.Module):
    """Linear model on the explicit feature map of a degree-2 polynomial kernel.

    For 2-D inputs the map ``phi`` satisfies
    ``phi(x) . phi(x') = (1 + x^T x')^2``, so a linear layer on ``phi(x)``
    is the primal form of a polynomial-kernel SVM.

    Args:
        negative_slope: Stored on the instance; not used by ``forward``.
        noutputs: Number of decision values produced (default 1).
    """

    def __init__(self, negative_slope, noutputs=1):
        super().__init__()
        self.name = "ToySVMpoly"
        self.negative_slope = negative_slope
        self.degree = 2
        self.n_features = 7
        self.fc1 = nn.Linear(self.n_features, noutputs)

    def forward(self, x):
        """Map ``x[:, 0], x[:, 1]`` through ``phi`` and apply ``fc1``."""
        first, second = x[:, 0], x[:, 1]
        root_two = sqrt(2)
        # Column order: 1, sqrt2*a, sqrt2*b, a*a, a*b, b*a, b*b.
        # Listing both a*b and b*a supplies the cross-term factor of 2.
        columns = (
            1 + 0 * first,
            root_two * first,
            root_two * second,
            first * first,
            first * second,
            second * first,
            second * second,
        )
        return self.fc1(torch.stack(columns, dim=1))


class ToyGaussianNaiveBayes(nn.Module):
    """Per-class axis-aligned Gaussian likelihoods on 2 input features.

    Parameters registered:
        means: ``(noutputs, n_features)`` class means, initialised to the
            identity pattern.
        variances: ``(noutputs, n_features)``. Despite the name, ``forward``
            uses these values as standard deviations: the exponent divides
            by ``2 * variances ** 2`` and the normaliser by their product.
        priors: ``(noutputs, 1)``. Registered but not used by ``forward``
            (NOTE(review): confirm whether callers expect prior weighting).

    Args:
        negative_slope: Stored on the instance; not used by ``forward``.
        noutputs: Number of classes (default 2).
    """

    def __init__(self, negative_slope, noutputs=2):
        super(ToyGaussianNaiveBayes, self).__init__()
        self.name = "ToyGaussianNaiveBayes"
        self.negative_slope = negative_slope
        self.n_features = 2
        self.register_parameter("means", nn.Parameter(torch.eye(noutputs, self.n_features)))
        self.register_parameter("variances", nn.Parameter(torch.ones(noutputs, self.n_features)))
        self.register_parameter("priors", nn.Parameter(torch.empty(noutputs, 1).uniform_(0, 1)))

    def forward(self, x):
        """Return the per-class Gaussian density of each sample.

        Args:
            x: ``(batch, n_features)`` inputs. A single 1-D sample is treated
                as a batch of one.

        Returns:
            ``(batch, noutputs)`` tensor; entry ``[i, c]`` is
            ``exp(-sum_f (x[i, f] - means[c, f])**2 / (2 * variances[c, f]**2))``
            scaled by ``1 / (2 * pi * prod_f variances[c, f])``.
        """
        if x.dim() == 1:
            # Preserve the original behaviour of returning shape (1, noutputs).
            x = x.unsqueeze(0)
        # Normaliser for a 2-D diagonal Gaussian, one value per class.
        coef = (1 / (2 * pi)) / torch.prod(self.variances, 1)
        # Broadcast (batch, 1, F) against (C, F) to get (batch, C, F).
        # This replaces the old repeat/slice code that only worked for C == 2.
        diff = self.means - x.unsqueeze(1)
        exponent = torch.sum(-(diff ** 2) / (2 * self.variances ** 2), dim=2)
        return coef * torch.exp(exponent)