from torch import Tensor
from net.decomposition import get_zero_sum_vectors, get_iso_matrix, batched_matrix_decomposition
from .layer_utils import Dense
import torch.nn as nn
from easydict import EasyDict
import torch


class SchurLayer(nn.Module):
    """
    Schur layer.

    Decomposes the input with ``batched_matrix_decomposition``, mixes the
    resulting parts with a learnable 9x9 weight that is multiplied element-wise
    by the matrix returned from ``get_iso_matrix`` and, when
    ``config.siamese == 'DSS'``, adds a learnable multiple of a mean taken over
    the input.
    """

    def __init__(self, config, device='cuda'):
        """
        The init.
        Args:
            config: The config. Only ``config.siamese`` is read by this layer.
            device: The device on which the weights and the iso matrix are created.
        """
        super().__init__()
        # Keep the iso matrix on the requested device instead of unconditionally
        # on CUDA, so it always lives next to ``self.weight``. Registering it as a
        # buffer makes it follow ``module.to(...)``; ``persistent=False`` keeps it
        # out of the state dict so existing checkpoints still load.
        self.register_buffer('iso_matrix', get_iso_matrix().to(device), persistent=False)
        # Allocate directly on the target device (no CPU allocation + copy).
        self.weight = nn.Parameter(torch.empty(9, 9, device=device), requires_grad=True)
        self.siamese = config.siamese
        if self.siamese == 'DSS':
            # Scalar (1x1) coefficient of the shared mean term added in forward().
            self.alpha = nn.Parameter(torch.empty(1, 1, device=device), requires_grad=True)
        self.init_weights()

    def init_weights(self):
        """
        Init the weights.

        Applies Xavier-uniform initialisation to ``weight`` and, in 'DSS' mode,
        to ``alpha``.
        """
        torch.nn.init.xavier_uniform_(self.weight)
        if self.siamese == 'DSS':
            torch.nn.init.xavier_uniform_(self.alpha)

    def forward(self, X: Tensor) -> Tensor:
        """
        Forward of the block.
        Args:
            X: The input. NOTE(review): the reductions below index five axes, so
                X is presumably 5-D with the two trailing axes forming the
                matrix - confirm against the caller.

        Returns: The forwarded output.

        """
        # get all irrep projections
        composed_parts = batched_matrix_decomposition(X, siamese=self.siamese)
        # Element-wise product with the iso matrix; presumably this restricts the
        # learnable mixing to the entries the iso matrix allows - confirm against
        # net.decomposition.
        weight = self.weight * self.iso_matrix
        # apply schur and calculate next layer irreps.
        # The permute moves axis 2 to the last position.
        weighted_sum = (composed_parts @ weight).sum(-1).permute(0, 1, 3, 4, 2)
        # reconstruct
        if self.siamese == 'DSS':
            # Mean over the two trailing axes, then over axis 1 (kept as a
            # size-1 axis), scaled by alpha and broadcast onto the result.
            all_graphs_sum = X.mean(-1).mean(-1)
            all_graphs_sum = all_graphs_sum.mean(1).unsqueeze(1) * self.alpha
            weighted_sum = all_graphs_sum.unsqueeze(-2).unsqueeze(-2) + weighted_sum

        return weighted_sum


class LinierAnomalyLayer(nn.Module):
    """
    Parameter-free head that sums the input over its two trailing axes.
    """

    def __init__(self):
        super().__init__()

    def forward(self, X):
        """
        Forward of the head.
        Args:
            X: The input.

        Returns: X with axis 2 squeezed and the last two axes summed out.

        """
        # squeeze only drops axis 2 when its size is 1; otherwise X is left as is.
        squeezed = torch.squeeze(X, 2)
        # Reduce the last axis, then the (new) last axis, one after the other.
        inner_total = torch.sum(squeezed, dim=-1)
        return torch.sum(inner_total, dim=-1)

class AnomalyDtectionModel(nn.Module):
    """
    Stack of (SchurLayer, Dense) pairs followed by a LinierAnomalyLayer head.
    """

    def __init__(self, config: EasyDict):
        """
        Build the model.
        Args:
            config: The config. ``config.dims`` lists the feature sizes; every
                consecutive pair of entries produces one SchurLayer followed by
                one Dense layer. The config is also handed to each SchurLayer.
        """
        super().__init__()
        self.dims = config.dims
        stages = []
        # Walk consecutive (in, out) sizes; each pair contributes two modules.
        for fan_in, fan_out in zip(self.dims, self.dims[1:]):
            stages.append(SchurLayer(config=config))
            stages.append(
                Dense(seed=0, in_features=fan_in, out_features=fan_out, activation_fn=nn.ReLU())
            )
        self.layers = nn.Sequential(*stages)
        self.head = LinierAnomalyLayer()

    def forward(self, A):
        """
        Run the input through all layers and then through the head.
        Args:
            A: The input.

        Returns: The output of the head.

        """
        hidden = self.layers(A)
        return self.head(hidden)
