import torch
from torch_geometric.nn import TransformerConv
from torch.nn import Linear


class GraphAttentionEmbedding(torch.nn.Module):
    """Graph attention layer whose edge features combine a time encoding and a message.

    Wraps a two-head ``TransformerConv`` (each head ``out_channels // 2`` wide)
    and feeds it edge attributes built from the encoded relative time of each
    edge concatenated with the edge message.
    """

    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        """
        in_channels: size of the node features passed to the convolution
        out_channels: total output size; split evenly over the two heads
        msg_dim: size of the last dimension of `msg`
        time_enc: callable time encoder exposing an `out_channels` attribute
        """
        super().__init__()
        self.time_enc = time_enc
        self.conv = TransformerConv(
            in_channels,
            out_channels // 2,
            heads=2,
            dropout=0.1,
            # Edge features are [time encoding | message], hence the summed width.
            edge_dim=msg_dim + time_enc.out_channels,
        )

    def forward(self, x, last_update, edge_index, t, msg):
        """Run the attention convolution over `x` using time-aware edge features.

        The relative time of an edge is `last_update` of the node in row 0 of
        `edge_index` minus the edge time `t`, cast to the dtype of `x`.
        """
        first_row_nodes = edge_index[0]
        elapsed = (last_update[first_row_nodes] - t).to(x.dtype)
        edge_features = torch.cat((self.time_enc(elapsed), msg), dim=-1)
        return self.conv(x, edge_index, edge_features)


class LinkPredictor(torch.nn.Module):
    """Score a (source, destination) embedding pair with a two-branch MLP."""

    def __init__(self, in_channels):
        """in_channels: size of the last dimension of both input embeddings."""
        super().__init__()
        width = in_channels
        self.lin_src = Linear(width, width)
        self.lin_dst = Linear(width, width)
        self.lin_final = Linear(width, 1)

    def forward(self, z_src, z_dst):
        """Return one unnormalized score per (z_src[i], z_dst[i]) pair.

        Both inputs are projected separately, summed, passed through a ReLU
        and mapped to a single output channel.
        """
        projected_src = self.lin_src(z_src)
        projected_dst = self.lin_dst(z_dst)
        hidden = torch.relu(projected_src + projected_dst)
        return self.lin_final(hidden)


class ProductLayer(torch.nn.Module):
    """Score every source embedding against every destination embedding."""

    def __init__(self, in_channels, out_channels=1):
        """
        in_channels: size of the last dimension of both input embeddings
        out_channels: number of scores produced per (src, dst) pair
        """
        super().__init__()
        width = in_channels
        self.lin_src = Linear(width, width)
        self.lin_dst = Linear(width, width)
        self.lin_final = Linear(width, out_channels)

    def forward(self, z_src, z_dst):
        """Return the all-pairs score tensor.

        With the default `out_channels=1` the result has shape
        (z_src.shape[0], z_dst.shape[0]) and element (i, j) is the
        unnormalized score for an interaction between src `i` and dst `j`.
        """
        # No activation is applied to the individual projections; the only
        # ReLU acts on their broadcast sum.
        src_proj = torch.unsqueeze(self.lin_src(z_src), 1)
        dst_proj = self.lin_dst(z_dst)
        # (n_dst, C) + (n_src, 1, C) broadcasts to (n_src, n_dst, C).
        pairwise = torch.relu(dst_proj + src_proj)
        scores = self.lin_final(pairwise)
        # Drops the channel axis only when it has size 1.
        return torch.squeeze(scores, 2)


class MergeLayer(torch.nn.Module):
    """Combine paired source/destination embeddings into one output vector."""

    def __init__(self, in_channels, out_channels=1):
        """
        in_channels: size of the last dimension of both input embeddings
        out_channels: size of the last dimension of the merged output
        """
        super().__init__()
        width = in_channels
        self.lin_src = Linear(width, width)
        self.lin_dst = Linear(width, width)
        self.lin_final = Linear(width, out_channels)

    def forward(self, z_src, z_dst):
        """Merge `z_src[i]` with `z_dst[i]` element by element.

        Each input is projected and rectified on its own, the two results are
        summed and rectified again, then mapped to `out_channels` features.
        Return: tensor whose last dimension is `out_channels`; entry `i` is the
        transformed embedding obtained combining z_src[i] and z_dst[i].
        """
        src_act = torch.relu(self.lin_src(z_src))
        dst_act = torch.relu(self.lin_dst(z_dst))
        merged = torch.relu(dst_act + src_act)
        return self.lin_final(merged)


class ReshapeLayer(torch.nn.Module):
    """Two-layer perceptron that changes the feature width of an embedding."""

    def __init__(self, in_channels, out_channels):
        """
        in_channels: size of the last dimension of the input
        out_channels: size of the last dimension of the output
        """
        super().__init__()
        self.lin_src = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, out_channels)

    def forward(self, z):
        """
        Takes a tensor of shape (z.shape[0], in_channels)
        Return: tensor of shape (z.shape[0], out_channels)
        """
        hidden = torch.relu(self.lin_src(z))
        return self.lin_final(hidden)


class Column:
    """Describe one element (column) of a sequence and its output distribution."""

    def __init__(self, idx, seq_len, distrib, params):
        """
        idx: int, position of this column in the sequence (0 is the first)

        seq_len: int, total length of the sequence (number of columns)

        distrib: class of `torch.distributions` built by `ptdist`

        params: dict describing how to compute each constructor argument of `distrib`.
            key is a string with the name of an argument of `distrib`
            value is either
              * a callable applied to the selected slice of the decoder output,
                whose result is used as the argument value, or
              * a tuple `(factory, sub_params)` where `sub_params` is a dict of
                such callables; `factory(**computed_sub_params)` becomes the
                argument value (used for nested distributions).

        Example:

            Column(0,
                   seq_len,
                   distrib=torch.distributions.Normal,
                   params={"loc": model.layer_mu,
                           "scale": lambda z: model.layer_std(z).abs()})
        """
        super().__init__()
        self.idx = idx
        self.seq_len = seq_len
        # The mask depends on idx and seq_len, so it is built after both are set.
        self.mask = self.col2mask()
        self.distrib = distrib
        self.params = params

    def col2mask(self):
        """Return a boolean vector of length `seq_len` that is True only at `idx`."""
        selector = torch.zeros(self.seq_len, dtype=torch.bool)
        selector[self.idx] = True
        return selector

    def ptdist(self, x):
        """Instantiate `distrib` from the slice of `x` belonging to this column.

        `x` is indexed as `x[:, self.mask]`, i.e. the mask selects along
        dimension 1, and every callable in `params` receives that slice.
        """

        def column_slice():
            return x[:, self.mask]

        kwargs = {}
        for name, spec in self.params.items():
            if isinstance(spec, tuple):
                # Nested form: build an inner object from its own sub-parameters.
                inner_kwargs = {
                    sub_name: sub_fn(column_slice())
                    for sub_name, sub_fn in spec[1].items()
                }
                kwargs[name] = spec[0](**inner_kwargs)
            else:
                kwargs[name] = spec(column_slice())
        return self.distrib(**kwargs)


class MLP(torch.nn.Module):
    """Perceptron with a single ReLU hidden layer."""

    def __init__(self, in_channels, out_channels, h_channels=None):
        """
        in_channels: size of the last dimension of the input
        out_channels: size of the last dimension of the output
        h_channels: width of the hidden layer; defaults to `in_channels`
        """
        super().__init__()
        hidden_width = in_channels if h_channels is None else h_channels
        self.lin_first = torch.nn.Linear(in_channels, hidden_width)
        self.lin_final = torch.nn.Linear(hidden_width, out_channels)

    def forward(self, x):
        """Apply the first linear layer, a ReLU, then the final linear layer."""
        hidden = torch.relu(self.lin_first(x))
        return self.lin_final(hidden)


class col_RNN(torch.nn.Module):
    """GRU decoder that models a sequence of columns, one distribution per column.

    Each position of the sequence is described by a `Column` whose distribution
    parameters are computed from the GRU output at that position.
    """

    def __init__(
        self, seq_len, col2K, embed_dim=1, hidden_size=8, num_layers=1, n_comp=3
    ):
        """
        seq_len: int, number of columns in the sequence

        col2K: dict mapping each column index in `range(seq_len)` to its type:
            an int -> categorical column with that many classes,
            "exponential" -> exponential column,
            "normal" -> normal column,
            anything else -> mixture of `n_comp` normal components

        embed_dim: int, input feature size of the GRU

        hidden_size: int, hidden size of the GRU and input size of every head

        num_layers: int, number of stacked GRU layers

        n_comp: int, number of components of the mixture heads
        """
        super(col_RNN, self).__init__()

        self.seq_len = seq_len
        # Kept on the instance because `create_column` needs it to size the
        # categorical heads.
        self.hidden_size = hidden_size
        self.decoder = torch.nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )

        # Heads below are shared by every column of the same type.
        # exponential: Softplus keeps the rate positive
        self.layer_exp = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 1), torch.nn.Softplus()
        )
        # normal: Softplus keeps the scale positive
        self.layer_mu = torch.nn.Linear(hidden_size, 1)
        self.layer_std = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 1), torch.nn.Softplus()
        )
        # gaussian mixture
        self.layer_gmm_mix = torch.nn.Linear(hidden_size, n_comp)
        self.layer_gmm_mu = torch.nn.Linear(hidden_size, n_comp)
        self.layer_gmm_std = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, n_comp), torch.nn.Softplus()
        )

        # Per-column categorical heads. They live inside plain `Column`
        # objects, so they are also stored in a ModuleDict to make their
        # parameters visible to `parameters()`, `state_dict()` and `.to()`.
        self.categorical_heads = torch.nn.ModuleDict()

        # create columns
        self.col2K = col2K
        self.columns = [self.create_column(self.col2K[i], i) for i in range(seq_len)]
        self.categorical_idx = [i for i, v in col2K.items() if isinstance(v, int)]

    def create_column(self, col, i):
        """Build the `Column` for position `i` from its type description `col`.

        An int creates a categorical column with its own linear head sized from
        `self.col2K[i]`; "exponential" and "normal" reuse the shared heads; any
        other value creates a mixture-of-normals column.
        """
        if isinstance(col, int):
            head = torch.nn.Linear(self.hidden_size, self.col2K[i])
            # ModuleDict keys must be strings.
            self.categorical_heads[str(i)] = head
            column = Column(
                i,
                self.seq_len,
                distrib=torch.distributions.Categorical,
                params={"logits": head},
            )

        elif col == "exponential":
            column = Column(
                i,
                self.seq_len,
                distrib=torch.distributions.Exponential,
                params={"rate": self.layer_exp},
            )

        elif col == "normal":
            column = Column(
                i,
                self.seq_len,
                distrib=torch.distributions.Normal,
                params={"loc": self.layer_mu, "scale": self.layer_std},
            )
        else:
            column = Column(
                i,
                self.seq_len,
                distrib=torch.distributions.MixtureSameFamily,
                params={
                    "mixture_distribution": (
                        torch.distributions.Categorical,
                        {"logits": self.layer_gmm_mix},
                    ),
                    "component_distribution": (
                        torch.distributions.Normal,
                        {"loc": self.layer_gmm_mu, "scale": self.layer_gmm_std},
                    ),
                },
            )
        return column

    def forward(self, x, hx=None):
        """Return the GRU outputs for input `x`, discarding the final hidden state."""
        x, _ = self.decoder(x, hx)
        return x

    def sample(self, num_samples, preprocess, hx=None):
        """Draw `num_samples` sequences column by column.

        num_samples: int, number of sequences to generate

        preprocess: callable applied to a (num_samples, seq_len) tensor; its
            result is indexed as a 3-D tensor with one extra leading position
            along dimension 1 (the start-of-sequence slot).

        hx: optional initial hidden state for the GRU; its dimension 1 must
            equal `num_samples`.

        Return: tensor of shape (num_samples, seq_len) with the sampled values.
        """
        seq_len = self.seq_len

        if hx is not None:
            assert (
                hx.shape[1] == num_samples
            ), "Dimension 1 of `hx` should match `num_samples`."

        # Sample on whatever device the model lives on instead of assuming CUDA.
        device = next(self.parameters()).device
        single_value_families = (
            torch.distributions.Normal,
            torch.distributions.Exponential,
        )

        with torch.no_grad():
            # Zeros rather than uninitialised memory: every slot is overwritten
            # below, but this keeps the not-yet-sampled inputs well defined.
            y = torch.zeros((num_samples, seq_len), dtype=torch.float, device=device)
            # `preprocess` is caller-supplied, so make sure its result stays
            # on the model's device.
            y = preprocess(y).to(device)

            for i, c in enumerate(self.columns):
                x = self(y[:, :-1], hx)
                # Build the distribution once and reuse it for the type check.
                dist = c.ptdist(x)
                a = dist.sample()
                if isinstance(dist, single_value_families):
                    # These heads emit a trailing size-1 feature axis, so the
                    # sample carries one more dimension than the other cases.
                    y[:, i + 1, :] = a[:, 0, :]
                else:
                    y[:, i + 1, :] = a
        # Drop the start-of-sequence slot and the trailing feature axis.
        y0 = y[:, 1:, 0]
        return y0


def preprocess(smp):
    """Prepend a start-of-sequence token and add a trailing feature axis.

    smp: tensor whose last dimension is the sequence axis.

    Return: tensor on the same device as `smp`, with a column of ones inserted
    at the front of the last dimension and a new size-1 dimension appended,
    i.e. shape `smp.shape[:-1] + (smp.shape[-1] + 1, 1)`.
    """
    # Allocate the start token next to the data instead of forcing CUDA, so
    # the function also works for CPU tensors and on CPU-only hosts.
    sos = torch.ones(smp.shape[:-1] + (1,), device=smp.device)
    smp = torch.cat((sos, smp), -1)
    return smp.unsqueeze(-1)
