import torch
import torch.nn as nn
import torch.nn.functional as F


def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    :param seed: Non-negative integer used for every generator. NumPy requires
        it to fit in an unsigned 32-bit integer.
    :return: None
    """
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only
    # influences child processes spawned after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run, which
    # defeats reproducibility, so it has to be switched off here.
    torch.backends.cudnn.benchmark = False


def block_pooling(x, block_sizes, pooling='amax', dim=0):
    """Reduce consecutive blocks of ``x`` along ``dim`` into one slice each.

    The first ``block_sizes[0]`` slices along ``dim`` form block 0, the next
    ``block_sizes[1]`` slices form block 1, and so on. Each block is collapsed
    with the reduction named by ``pooling``.

    :param x: Input tensor of any rank >= 1.
    :param block_sizes: 1-D integer tensor (or sequence) of block lengths; it
        is expected to sum to ``x.shape[dim]``.
    :param pooling: Reduction name forwarded to ``Tensor.index_reduce_``
        (e.g. ``'amax'``, ``'amin'``, ``'mean'``, ``'prod'``).
    :param dim: Dimension of ``x`` that is partitioned into blocks.
    :return: Tensor shaped like ``x`` except that ``dim`` has length
        ``len(block_sizes)``. Blocks of size zero are left at 0.
    """
    if dim < 0:
        dim += x.dim()

    # Build the indices on x's device so index_reduce_ never sees mixed devices.
    block_sizes = torch.as_tensor(block_sizes, device=x.device)
    num_blocks = block_sizes.numel()
    block_indices = torch.repeat_interleave(
        torch.arange(num_blocks, device=x.device), block_sizes
    )

    # Same shape as x, with the pooled dimension shrunk to one entry per block.
    out_shape = list(x.shape)
    out_shape[dim] = num_blocks
    output_matrix = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

    # include_self=False: the zero initial values do not take part in the
    # reduction, so e.g. 'amax' over all-negative blocks is still correct.
    output_matrix.index_reduce_(dim, block_indices, x, pooling, include_self=False)
    return output_matrix


class DeepSet(nn.Module):
    """ Permutation-invariant set encoder.

    A two-layer network ``phi`` is applied to every element, the element
    features of each set are max-pooled with ``block_pooling``, and a two-layer
    network ``rho`` maps each pooled vector to a single scalar.
    """

    def __init__(self, input_dim=1, feature_dim=10, multiplier=2):
        """ Build the phi / rho layers.
        :param input_dim: size of the last dimension of the input elements
        :param feature_dim: output width of the first phi layer
        :param multiplier: the pooled representation has width feature_dim * multiplier
        """
        super().__init__()
        hidden_dim = feature_dim * multiplier
        self.feature_dim = feature_dim

        # Per-element encoder.
        self.phi1 = nn.Linear(input_dim, feature_dim)
        self.phi2 = nn.Linear(feature_dim, hidden_dim)
        # Per-set decoder applied after pooling.
        self.rho1 = nn.Linear(hidden_dim, hidden_dim)
        self.rho2 = nn.Linear(hidden_dim, 1)

        self.silu = nn.SiLU()

        self.reset_parameters()

    def forward(self, x, block_sizes):
        """ Encode the elements, pool them per set and score each set.
        :param x: element features; assumed 2-D (num_elements, input_dim) with
            the elements of one set stored contiguously -- confirm against caller
        :param block_sizes: integer tensor with the number of elements in each set
        :return: tuple (score, x_rep), where x_rep is the max-pooled
            representation of every set and score is rho applied to it
        """
        element_features = self.phi2(self.silu(self.phi1(x)))
        x_rep = block_pooling(element_features, block_sizes, pooling='amax', dim=0)

        hidden = self.silu(self.rho1(self.silu(x_rep)))
        # No activation after the last layer: the score is unconstrained.
        score = self.rho2(hidden)
        return score, x_rep

    def reset_parameters(self):
        """ Re-initialize the weights of all linear layers (Xavier uniform).
        The biases are not touched and keep the nn.Linear default initialization.
        :return: None
        """
        for layer in (self.phi1, self.phi2, self.rho1, self.rho2):
            nn.init.xavier_uniform_(layer.weight)


if __name__ == "__main__":
    # Smoke test: push 3000 sets of 3000 scalar elements through the network.
    num_sets, set_size = 3000, 3000
    # Fall back to the CPU so the script also runs on machines without a GPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # phi1 is Linear(1, ...), so every element must be a row with one feature;
    # a flat 1-D tensor would be rejected by the first layer.
    x = torch.randn(num_sets * set_size, 1).to(device)
    block_sizes = torch.full((num_sets,), set_size, dtype=torch.long, device=device)

    net = DeepSet(1, 16).to(device)

    # Forward pass only: no_grad avoids keeping the large activations alive.
    with torch.no_grad():
        out, rep = net(x, block_sizes)
    print(tuple(out.shape), tuple(rep.shape))

