import torch.nn as nn
import torch
import numpy as np

class EnhancementLayer(nn.Module):
    """Builds an auxiliary entity representation from temporally close entities.

    Two code paths are provided:

    * ``forward`` – time-decayed weighted average of neighbour embeddings.
    * ``forward_`` – per-relation history, grouped by timestamp, encoded by a
      transformer and projected together with the relation embedding.

    NOTE(review): the two paths read ``batch`` differently (``forward`` takes
    times from ``batch[2]`` and neighbours from ``batch[-1]``, whereas
    ``get_similar_entities_`` takes times from ``batch[-1]``), so they
    presumably expect differently laid-out batches — confirm against callers.
    """

    def __init__(self, ent_embeds, rel_embeds, degree, temporal_rel_adj=None, config=None, time_span=24):
        """Store the lookup tables and build the trainable sub-modules.

        Args:
            ent_embeds: callable invoked as ``ent_embeds(ids, zeros_like(ids))``
                whose result is used as the entity embeddings.
            rel_embeds: callable invoked as ``rel_embeds(rels)``.
            degree: values convertible to a float32 tensor; indexed with the
                entity ids of a batch.
            temporal_rel_adj: container indexed by relation id; ``[r][0]`` is
                read as timestamps and ``[r][1]`` as entity ids. Timestamps
                ``<= -1`` are ignored (presumably padding — confirm).
            config: mapping; the keys read by this class are ``'cuda'``,
                ``'ent_dim'``, ``'rel_dim'``, ``'nheads'``, ``'num_layers'``,
                ``'lmbda'`` and ``'degree'``. Despite the ``None`` default a
                real mapping is required.
            time_span: divisor applied to time differences in ``forward``.
        """
        super().__init__()
        self.config = config
        self.time_span = time_span

        self.degree = torch.tensor(degree, dtype=torch.float32)
        if self.config['cuda']:
            self.degree = self.degree.cuda()
        self.temporal_rel_adj = temporal_rel_adj
        self.model_embeds = ent_embeds
        self.relation_embeds = rel_embeds

        # The encoder layer is created with the library default layout
        # (sequence-first); forward_ transposes its batch-first input to match.
        transformer_layer = nn.TransformerEncoderLayer(d_model=config['ent_dim'], nhead=config['nheads'])
        self.temp_encoder = nn.TransformerEncoder(transformer_layer, num_layers=config['num_layers'])
        self.linear_projection = nn.Linear(config['rel_dim'] + config['ent_dim'], config['ent_dim'])

        # (relation, time) -> (entity ids, timestamps) strictly before `time`.
        # Lives for the lifetime of the module and is never evicted.
        self.sim_ent_cach = {}

    def get_similar_entities(self, batch):
        """Placeholder for a vectorised variant; does nothing and returns None."""
        # Unfinished sketch kept for reference:
        # mask = (-1 < self.temporal_rel_adj[rels][:, 0, :]) & (self.temporal_rel_adj[rels][:, 0, :] < times.unsqueeze(1))
        # similar_ent = self.temporal_rel_adj[rels][:, 1, :][mask]
        # embedding_matrix = self.model_embeds(similar_ent, torch.zeros_like(similar_ent))
        #
        # group_sizes = mask.sum(dim=1).tolist()
        # groups = torch.split(embedding_matrix, group_sizes)
        # group_means = torch.stack([group.mean(dim=0) for group in groups])
        #
        # nan_mask = torch.isnan(group_means).any(dim=1)
        # group_means[nan_mask] = 0
        pass

    def _mean_embedding_per_timestamp(self, similar_ent, timestamps):
        """Average the embeddings of ``similar_ent`` within each distinct timestamp."""
        ent_vecs = self.model_embeds(similar_ent, torch.zeros_like(similar_ent))
        unique_timestamps, group_indices = torch.unique(timestamps, return_inverse=True)

        # One-hot membership matrix (entity x timestamp group), allocated on
        # the embeddings' device so the code runs on CPU as well as on GPU.
        membership = torch.zeros(
            (len(similar_ent), len(unique_timestamps)),
            device=ent_vecs.device,
            dtype=ent_vecs.dtype,
        )
        rows = torch.arange(len(similar_ent), device=group_indices.device)
        membership[rows, group_indices] = 1

        group_sizes = torch.sum(membership, dim=0)[:, None]
        return membership.T @ ent_vecs / group_sizes

    def get_similar_entities_(self, batch):
        """Return, per sample, the per-timestamp mean embeddings of past entities.

        Reads relations from ``batch[1]`` and times from ``batch[-1]``. For a
        sample ``(r, t)`` the entities stored under relation ``r`` with a
        timestamp in ``(-1, t)`` are grouped by timestamp and averaged.

        Returns:
            list with one 2-D tensor per sample; the first dimension is the
            number of distinct earlier timestamps (possibly 0). A relation
            missing from ``temporal_rel_adj`` yields a single all-zero row.
        """
        rels = batch[1]
        times = batch[-1]

        all_emb_matrices = []
        cache = {}  # per-call memo of the computed matrices
        for t, r in zip(times, rels):
            r = r.item()
            t = t.item()
            key = (r, t)

            if key not in cache:
                # Membership is tested BEFORE indexing so that an unknown
                # relation reaches the fallback instead of raising.
                if r in self.temporal_rel_adj:
                    if key not in self.sim_ent_cach:
                        timestamps = self.temporal_rel_adj[r][0]
                        mask = (-1 < timestamps) & (timestamps < t)
                        similar_ent = self.temporal_rel_adj[r][1][mask]
                        self.sim_ent_cach[key] = (similar_ent, timestamps[mask])

                    similar_ent, timestamps = self.sim_ent_cach[key]
                    embedding_matrix = self._mean_embedding_per_timestamp(similar_ent, timestamps)
                else:
                    # 2-D (one zero row) so it can be padded and stacked with
                    # the other matrices; placed on the module's device.
                    embedding_matrix = torch.zeros(
                        (1, self.config['ent_dim']), device=self.degree.device
                    )
                cache[key] = embedding_matrix

            all_emb_matrices.append(cache[key])
        return all_emb_matrices

    def forward(self, batch):
        """Time-decayed weighted average of neighbour embeddings.

        Reads entity ids from ``batch[0]``, query times from ``batch[2]`` and,
        from ``batch[-1]``, neighbour ids (``[:, 0, :]``) and their timestamps
        (``[:, 1, :]``). Each neighbour is weighted by
        ``sigmoid(-lmbda * (time - timestamp) / time_span)``.

        Returns:
            The weighted average over the neighbour axis, with non-finite
            entries replaced by 0; when ``config['degree']`` is truthy the
            result is additionally scaled by ``1 / (1 + exp(degree))``.
        """
        ents = batch[0]
        times = batch[2]
        neighbors = batch[-1][:, 0, :]
        timestamps = batch[-1][:, 1, :]

        emb = self.model_embeds(neighbors, torch.zeros_like(neighbors))

        # Decay coefficients: older neighbours (larger time gap) weigh less.
        elapsed = (times.unsqueeze(1) - timestamps) / self.time_span
        coeffs = torch.sigmoid(-self.config['lmbda'] * elapsed)

        weighted_sum = (emb * coeffs.unsqueeze(-1)).sum(dim=1)
        temporal_avg = weighted_sum / coeffs.sum(dim=1).unsqueeze(-1)

        # All coefficients can underflow to 0 (0/0 -> NaN). Zero out NaN/inf
        # instead of dropping into an interactive debugger mid-training.
        temporal_avg = torch.where(
            torch.isfinite(temporal_avg), temporal_avg, torch.zeros_like(temporal_avg)
        )

        if self.config['degree']:
            degree = self.degree[ents]
            temporal_avg = temporal_avg * (1.0 / (1 + torch.exp(degree))).unsqueeze(-1)

        return temporal_avg

    def forward_(self, batch):
        """Transformer-encoded relation history, projected and degree-scaled.

        Reads entity ids from ``batch[0]`` and relations from ``batch[1]``;
        ``get_similar_entities_`` additionally reads times from ``batch[-1]``.

        Returns:
            ``linear_projection([encoded history ; relation embedding])``
            divided by ``degree + 1`` of each entity.
        """
        rels = batch[1]
        ents = batch[0]

        history = self.get_similar_entities_(batch)

        # (batch, max_steps, dim); shorter histories are zero-padded.
        # NOTE(review): padded steps are neither masked in attention nor
        # excluded from the mean below — confirm whether that is intended.
        padded = torch.nn.utils.rnn.pad_sequence(history, batch_first=True, padding_value=0)

        # The encoder expects (steps, batch, dim); without the transpose the
        # attention would mix different samples of the batch instead of
        # attending over time.
        encoded = self.temp_encoder(padded.transpose(0, 1)).transpose(0, 1)
        ent_embed = encoded.mean(dim=1)

        relation_embed = self.relation_embeds(rels)
        concat_embed = torch.cat([ent_embed, relation_embed], dim=1)
        degree = self.degree[ents] + 1

        out = self.linear_projection(concat_embed)
        out = out * (1.0 / degree.unsqueeze(-1))

        return out




