import torch.nn as nn
from .modules import *


class ElementWiseLayer(nn.Module):
    """Sequential block: MVLinear -> MVSiLU -> ElementWiseGP -> NormalizationLayer.

    Args:
        algebra: Algebra object handed to every submodule.
        in_channels: Input channel count of the linear stage.
        out_channels: Output channel count of the linear stage; also the
            channel count of the activation, GP and normalization stages.
        use_eigenvalue: Forwarded to the linear and GP stages.
        eigval_encoder: Forwarded to the linear and GP stages.
        restrict_grade: Forwarded to the linear, GP and normalization stages.
        share_weights_from: Optional existing ``ElementWiseLayer``. When given,
            the linear stage is obtained from
            ``share_weights_from.net[0].share_weights_with(algebra)`` and the
            GP stage receives ``share_weights_from=share_weights_from.net[2]``.
            The activation and normalization stages are always freshly built.
    """

    def __init__(self, algebra, in_channels: int, out_channels: int,
                 use_eigenvalue=False, eigval_encoder=None,
                 restrict_grade=None, share_weights_from=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.algebra = algebra
        self.use_eigenvalue = use_eigenvalue
        self.restrict_grade = restrict_grade

        if share_weights_from is None:
            linear = MVLinear(self.algebra, in_channels, out_channels,
                              use_eigenvalue=use_eigenvalue,
                              eigval_encoder=eigval_encoder,
                              restrict_grade=restrict_grade)
            # The unshared GP is built without a ``share_weights_from`` kwarg
            # at all, so its own default applies.
            gp_sharing = {}
        else:
            # NOTE(review): in_channels / out_channels / eigval_encoder are not
            # forwarded to the shared linear stage; presumably they are taken
            # from the source layer -- confirm against MVLinear.
            linear = share_weights_from.net[0].share_weights_with(self.algebra)
            gp_sharing = {"share_weights_from": share_weights_from.net[2]}

        # Construction order (linear, SiLU, GP, norm) is kept as-is so random
        # parameter initialisation consumes the RNG in the same sequence.
        self.net = nn.ModuleList([
            linear,
            MVSiLU(self.algebra, out_channels),
            ElementWiseGP(self.algebra, out_channels,
                          use_eigenvalue=use_eigenvalue,
                          eigval_encoder=eigval_encoder,
                          restrict_grade=restrict_grade,
                          **gp_sharing),
            NormalizationLayer(self.algebra, out_channels,
                               restrict_grade=restrict_grade),
        ])

    def to(self, device):
        """Move the layer and every submodule to ``device``; return ``self``.

        Returning ``self`` matches ``nn.Module.to`` so that
        ``layer = ElementWiseLayer(...).to(device)`` works.
        """
        super().to(device)
        # nn.Module.to only moves registered parameters/buffers via _apply and
        # never calls children's own ``to``; call it explicitly in case a
        # submodule overrides ``to`` to move extra state (presumably the
        # algebra tensors -- confirm in .modules).
        for module in self.net:
            module.to(device)
        return self

    def forward(self, x, eigenvalue=None):
        """Run ``x`` through the four stages in order.

        ``eigenvalue`` is passed to the linear and GP stages only; the
        activation and normalization stages receive just the running value.
        Shapes of ``x`` / ``eigenvalue`` are defined by the submodules in
        ``.modules`` (presumably multivector tensors -- not visible here).
        """
        y = self.net[0](x, eigenvalue)
        y = self.net[1](y)
        y = self.net[2](y, eigenvalue)
        return self.net[3](y)
    

class FullyConnectedLayer(nn.Module):
    """Sequential block: FullyConnectedGP -> NormalizationLayer.

    Args:
        algebra: Algebra object handed to both submodules.
        in_channels: Input channel count of the GP stage.
        out_channels: Output channel count of the GP stage and channel count
            of the normalization stage.
        use_eigenvalue: Forwarded to the GP stage as both
            ``use_eigenvalue_1`` and ``use_eigenvalue_2``. When true,
            ``in_channels`` must equal ``out_channels``.
        eigval_encoder: Forwarded to the GP stage as both
            ``eigval_encoder_1`` and ``eigval_encoder_2``.
        restrict_grade: Forwarded to both stages.
        share_weights_from: Optional existing ``FullyConnectedLayer``. When
            given, the GP stage receives
            ``share_weights_from=share_weights_from.net[0]``. The normalization
            stage is always freshly built.

    Raises:
        AssertionError: if ``use_eigenvalue`` is true and
            ``in_channels != out_channels``.
    """

    def __init__(self, algebra, in_channels: int, out_channels: int,
                 use_eigenvalue=False, eigval_encoder=None,
                 restrict_grade=None, share_weights_from=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.algebra = algebra
        self.use_eigenvalue = use_eigenvalue
        # Stored for parity with ElementWiseLayer.
        self.restrict_grade = restrict_grade

        if self.use_eigenvalue:
            # Kept as ``assert`` to preserve the exception type callers may
            # rely on; now with a diagnostic message.
            assert in_channels == out_channels, (
                "use_eigenvalue=True requires in_channels == out_channels, "
                f"got in_channels={in_channels}, out_channels={out_channels}"
            )

        shared_gp = (
            share_weights_from.net[0]
            if share_weights_from is not None else None
        )
        self.net = nn.ModuleList([
            FullyConnectedGP(self.algebra, in_channels, out_channels,
                             use_eigenvalue_1=use_eigenvalue,
                             use_eigenvalue_2=use_eigenvalue,
                             eigval_encoder_1=eigval_encoder,
                             eigval_encoder_2=eigval_encoder,
                             restrict_grade=restrict_grade,
                             share_weights_from=shared_gp),
            NormalizationLayer(self.algebra, out_channels,
                               restrict_grade=restrict_grade),
        ])

    def to(self, device):
        """Move the layer and both submodules to ``device``; return ``self``.

        Returning ``self`` matches ``nn.Module.to`` so that
        ``layer = FullyConnectedLayer(...).to(device)`` works.
        """
        super().to(device)
        # nn.Module.to does not invoke children's own ``to``; call it
        # explicitly in case a submodule overrides it to move extra state.
        for module in self.net:
            module.to(device)
        return self

    def forward(self, x, eigenvalue=None):
        """Apply the GP stage (with ``eigenvalue``) and then normalization.

        Input/output shapes are defined by the submodules in ``.modules``
        (presumably multivector tensors -- not visible here).
        """
        y = self.net[0](x, eigenvalue)
        return self.net[1](y)

