import torch
from ripser import ripser
from tqdm import tqdm


class Net(torch.nn.Module):
    """
    A simple feed-forward neural network with two hidden layers.

    The first two linear layers produce an embedding (see ``encode``);
    a third linear layer followed by a softmax maps that embedding to
    class probabilities (see ``forward``).

    This is used in the toy example.

    """

    def __init__(
        self,
        input_units: int,
        middle_layer_units: int = 20,
        normalize_embeddings=False,
        n_output=2,
        embedding_units: int = 10,
    ):
        """
        Params:
            - input_units: int, dim of the input
            - middle_layer_units: int, dim of the middle layer
            - normalize_embeddings: bool, whether to normalize the embeddings.
            - n_output: int, number of output units (classes).
            - embedding_units: int, dim of the embedding produced by the
                second layer (defaults to 10).

        """
        super().__init__()

        self.fc1 = torch.nn.Linear(input_units, middle_layer_units)
        self.fc2 = torch.nn.Linear(middle_layer_units, embedding_units)

        self.fc3 = torch.nn.Linear(embedding_units, n_output)

        self.normalize_embeddings = normalize_embeddings

        self.relu = torch.nn.ReLU()
        self.softmax = torch.nn.Softmax(dim=1)

        self.__initialise_weights()

    def first_two_layers(self, x):
        """
        Map a batch of inputs to embeddings: fc1 -> ReLU -> fc2.

        When ``normalize_embeddings`` is set, the result is passed through
        ``torch.nn.functional.normalize`` with its default arguments.
        """
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)

        return torch.nn.functional.normalize(x) if self.normalize_embeddings else x

    def forward(self, x):
        """Return the softmax (over dim 1) of fc3 applied to the embedding."""
        x = self.first_two_layers(x)
        return self.softmax(self.fc3(x))

    def encode(self, x):
        """
        Return the embedding of ``x`` as a numpy array.

        No autograd graph is recorded, and the result is moved to the CPU
        first so the conversion also works when the model lives on another
        device.
        """
        with torch.no_grad():
            embedding = self.first_two_layers(x)
        return embedding.detach().cpu().numpy()

    def __initialise_weights(self):
        # xavier initialization of the weight matrices; the biases keep
        # the default torch.nn.Linear initialisation.
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        torch.nn.init.xavier_uniform_(self.fc2.weight)
        torch.nn.init.xavier_uniform_(self.fc3.weight)


def normalised_h0(data, normalised=False):
    """
    Compute the (possibly normalised) persistence times
    for the 0-th homology group.

    Args:
        - data: np.array, the data to use ripser on.
        - normalised: bool, whether to normalise the
            persistence times.

    Returns:
        The second column of the H0 diagram with its last row dropped.
        When ``normalised`` is set the values are divided by their
        maximum; if there are no values, or the maximum is zero, they are
        returned unscaled instead of raising or producing NaNs.
    """

    h0_diagram = ripser(data)["dgms"][0]
    # The last row is dropped: presumably the infinite bar of the
    # component that never dies -- confirm against the ripser docs.
    v = h0_diagram[:-1, 1]

    if not normalised:
        return v

    # Guard the division: an empty diagram has no maximum, and a zero
    # maximum (all points coincide) would turn every entry into NaN.
    if v.size == 0:
        return v
    longest = v.max()
    if longest == 0:
        return v
    return v / longest


def train_net(
    net,
    X_train,
    y_train,
    X_test,
    epochs=10000,
    normalised=False,
    only_at_end=False,
    every=10,
    use_tqdm=True,
    lr=1e-5,
):
    """
    Train ``net`` with Adam on the full training set and record how the
    test-set embedding (and its H0 persistence times) evolves.

    Args:
        - net: model exposing ``parameters()``, ``__call__`` and
            ``encode`` (e.g. ``Net``).
        - X_train: training inputs, converted to a float tensor.
        - y_train: training targets passed to ``CrossEntropyLoss``;
            integer targets are cast to int64.
        - X_test: inputs whose encoding is recorded, converted to a
            float tensor.
        - epochs: int, number of full-batch optimisation steps.
        - normalised: bool, forwarded to ``normalised_h0``.
        - only_at_end: bool, if True record only the initial and the
            final encoding; ``every`` is then ignored.
        - every: int, record a snapshot after each epoch whose index is
            a multiple of this value (epoch 0 included).
        - use_tqdm: bool, whether to show a progress bar.
        - lr: float, Adam learning rate.

    Returns:
        ``(net, encodings, h0_data)`` where ``encodings[i]`` is the
        output of ``net.encode`` on the test inputs and ``h0_data[i]``
        the matching ``normalised_h0`` result. Index 0 is always the
        encoding before any training step.

    NOTE(review): ``CrossEntropyLoss`` expects unnormalised logits, but
    ``Net.forward`` in this file already applies a softmax. The loss is
    kept as-is so results do not change -- confirm this is intended.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=0)

    # Convert the data once instead of on every epoch. ``as_tensor``
    # also accepts inputs that are already tensors without the
    # copy-construct warning ``torch.tensor`` emits.
    train_inputs = torch.as_tensor(X_train).float()
    test_inputs = torch.as_tensor(X_test).float()
    targets = torch.as_tensor(y_train)
    if not (targets.is_floating_point() or targets.is_complex()):
        # Class-index targets must be int64 for CrossEntropyLoss.
        targets = targets.long()

    encodings = []
    h0_data = []

    def record_snapshot():
        # Store the current test-set encoding and its H0 persistence times.
        encoding = net.encode(test_inputs)
        encodings.append(encoding)
        h0_data.append(normalised_h0(encoding, normalised=normalised))

    # zero-epoch encoding
    record_snapshot()

    for epoch in tqdm(range(epochs), disable=not use_tqdm):
        optimizer.zero_grad()
        output = net(train_inputs)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()

        # save the encoding; ``only_at_end`` is checked first so that
        # ``every`` is not evaluated when no periodic snapshot is wanted.
        if not only_at_end and epoch % every == 0:
            record_snapshot()

    if only_at_end:
        record_snapshot()

    return net, encodings, h0_data
