import torch
import torch.nn as nn
import torch.nn.functional as F


class TBRModel(nn.Module):
    """
    Linear model whose per-sample weight vector is a shared root vector plus the sum of learned
    per-edge "delta" vectors selected by a path matrix (presumably root-to-node paths in a tree).
    """
    def __init__(self, device, root_weights, path_mat, delta_mat, p=1, init_deltas=False, init_roots=True):
        """
        :param device: cpu or gpu
        :param root_weights: np array or tensor of weights (initially zeros)
        :param path_mat: np array or tensor (num_edges, num_nodes)
        :param delta_mat: np array or tensor (feature_dim, num_edges)
        :param p: norm for delta loss, defaults to 1-norm, approximating 0-norm
        :param init_deltas: flag to instantiate deltas to small non-zero values before training
        :param init_roots: flag to instantiate root parameters to small non-zero values before training
        """
        super(TBRModel, self).__init__()
        self.device = device
        # Fixed structure, not a Parameter. NOTE: as a plain attribute it is not moved by
        # Module.to(); it is placed on `device` here instead.
        self.path_mat = self._as_double_tensor(path_mat, device)
        self.p = p
        self.root_weights = nn.Parameter(self._as_double_tensor(root_weights, device))
        if init_roots:
            torch.nn.init.normal_(self.root_weights, mean=0.0, std=0.01)

        self.delta_mat = nn.Parameter(self._as_double_tensor(delta_mat, device))
        if init_deltas:
            torch.nn.init.normal_(self.delta_mat, mean=0.0, std=0.01)

    @staticmethod
    def _as_double_tensor(data, device):
        """Return a fresh, detached float64 copy of ``data`` (array-like or tensor) on ``device``.

        Unlike ``torch.tensor(existing_tensor)``, this does not emit PyTorch's copy-construct
        warning for tensor inputs. It still always copies, so callers' arrays are never aliased.
        """
        return torch.as_tensor(data, dtype=torch.double, device=device).detach().clone()

    def _on_path_edges(self, idx):
        """Boolean mask over edges whose path-matrix entry is 1.0 for at least one node in ``idx``."""
        cols = self.path_mat[:, idx]
        if cols.dim() == 1:  # a single integer node index selects one column as a 1-D tensor
            cols = cols.unsqueeze(1)
        if cols.shape[1] == 0:  # no nodes -> no edges; amax would fail on an empty dimension
            return torch.zeros(cols.shape[0], dtype=torch.bool, device=cols.device)
        return cols.amax(dim=1) == 1.0

    def delta_loss(self, idx, rows=None):
        """
        p-norm (``self.p``) over delta entries, used as a sparsity penalty.

        :param idx: node indices; if given, only edges on those nodes' paths are penalised,
                    otherwise every edge is. An empty selection yields a zero norm.
        :param rows: optional index into the feature rows of ``delta_mat`` to restrict the penalty to
        :return: 0-d tensor holding the norm taken over all selected entries
        """
        # Slice rows only when requested, to avoid materialising a copy of the full matrix.
        mat = self.delta_mat if rows is None else self.delta_mat[rows]
        if idx is not None:
            # Transpose to (edges, features) and keep the on-path edges, in increasing edge order.
            mat = mat.T[self._on_path_edges(idx)]
        return torch.norm(mat, p=self.p)

    # node_idx identifies the paths relevant to all samples in x, in the same order
    def forward(self, x, node_idx):
        """
        :param x: inputs; broadcasting below implies shape (len(node_idx), feature_dim)
        :param node_idx: one node index per sample, selecting that sample's path column
        :return: one prediction per sample, sum(x * (root_weights + delta_mat @ path))
        """
        # (feature_dim, num_edges) @ (num_edges, batch) -> transposed to (batch, feature_dim)
        effective_weights = torch.add(self.root_weights, torch.matmul(self.delta_mat, self.path_mat[:, node_idx]).T)
        # this works for linreg with bias-in only (presumably a constant column in x carries the bias)
        return torch.sum((x * effective_weights), dim=1)


class CellEmbeddingLinear(nn.Module):
    """
    Linear embedding layer shared by the DendroNet and baseline experiments.

    Projects the simulated gene-expression input onto a latent embedding of the requested
    size using a single ``nn.Linear`` layer (bias optional, disabled by default).
    """

    def __init__(self, input_dim, output_dim, use_bias=False):
        super().__init__()
        self.lin_1 = nn.Linear(in_features=input_dim, out_features=output_dim, bias=use_bias)

    def forward(self, x):
        embedding = self.lin_1(x)
        return embedding


class EmbeddingNN(nn.Module):
    """
    Small multilayer perceptron for non-linear embeddings.

    Linear(input_dim -> hidden_0) -> LeakyReLU -> Linear(hidden_0 -> hidden_1) -> LeakyReLU
    -> Linear(hidden_1 -> output_dim), with optional dropout after each hidden activation or
    optional batch norm on the output (see ``forward``).
    """

    def __init__(self, input_dim, output_dim, hidden_0=256, hidden_1=64, p=0.5):
        super().__init__()
        # Keep this registration order: it fixes parameter/state_dict ordering and the order
        # in which the Linear layers draw from the RNG during initialisation.
        self.batchnorm0 = nn.BatchNorm1d(num_features=output_dim)
        self.lin_0 = nn.Linear(input_dim, hidden_0)
        self.drop0 = nn.Dropout(p=p)
        self.lin_1 = nn.Linear(hidden_0, hidden_1)
        self.drop1 = nn.Dropout(p=p)
        self.output = nn.Linear(hidden_1, output_dim)
        self.relu = nn.LeakyReLU()

    def forward(self, x, dropout=False, batchnorm=False):
        """
        :param x: input batch
        :param dropout: apply drop0 / drop1 after the two hidden activations
        :param batchnorm: apply batchnorm0 to the output layer; ignored when dropout is set
        :return: embedding produced by the output layer
        """
        hidden = self.relu(self.lin_0(x))
        if dropout:
            hidden = self.drop0(hidden)
        hidden = self.relu(self.lin_1(hidden))
        if dropout:
            hidden = self.drop1(hidden)
        out = self.output(hidden)
        # Dropout and batch norm are mutually exclusive: dropout wins if both flags are set.
        if batchnorm and not dropout:
            out = self.batchnorm0(out)
        return out


class LinRegModel(nn.Module):
    """
    Linear regression model producing one prediction per sample (bias optional, off by default).
    """
    def __init__(self, input_dim, use_bias=False):
        super(LinRegModel, self).__init__()
        self.lin_1 = nn.Linear(input_dim, 1, bias=use_bias)

    def forward(self, x):
        """
        :param x: inputs of shape (batch, input_dim) or (input_dim,)
        :return: predictions of shape (batch,), or a 0-d tensor for a single unbatched sample
        """
        # Drop only the trailing size-1 output dim. A bare squeeze() would also collapse a batch
        # of size 1, returning a 0-d tensor instead of shape (1,).
        return self.lin_1(x).squeeze(-1)
