from escnn import gspaces
import escnn.nn as enn
import torch.nn as nn
from utils.group_utils import *
class rnConv(nn.Module):
    """Group-equivariant convolution layer wrapping escnn's ``enn.R2Conv``.

    The input and output field types are built with ``get_gspace`` (from
    ``utils.group_utils``) and may use different groups, orders and
    representations. ``forward`` accepts and returns plain tensors; wrapping
    into / unwrapping from escnn objects happens internally.
    """

    def __init__(self,
                 *,
                 in_group_type: str,
                 in_order: int,
                 in_num_features: int,
                 in_representation: str,
                 out_group_type: str,
                 out_order: int,
                 out_num_features: int,
                 out_representation: str,
                 domain: int = 2,
                 kernel_size: int = 3,
                 layer_kwargs: "dict | None" = None):
        """Build the input/output field types and the underlying convolution.

        Args:
            in_group_type: group identifier for the input side, forwarded to
                ``get_group`` / ``get_gspace``.
            in_order: group order for the input side.
            in_num_features: number of input features, forwarded to
                ``get_gspace``.
            in_representation: representation name for the input side,
                forwarded to ``get_gspace``.
            out_group_type: group identifier for the output side.
            out_order: group order for the output side.
            out_num_features: number of output features.
            out_representation: representation name for the output side.
            domain: dimensionality of the base space. Only ``2`` is supported
                (uses ``enn.R2Conv``).
            kernel_size: convolution kernel size. Padding is
                ``(kernel_size - 1) // 2``. That keeps the spatial size for odd
                kernel sizes, assuming stride 1 (the R2Conv default unless
                overridden via ``layer_kwargs``) — TODO confirm for other
                settings.
            layer_kwargs: extra keyword arguments forwarded verbatim to
                ``enn.R2Conv``. ``None`` means no extra arguments.

        Raises:
            ValueError: if ``domain`` is not supported.
        """
        super().__init__()
        # Validate cheap arguments before building groups / field types.
        if domain != 2:
            raise ValueError(f'Domain {domain} not found')
        # None instead of a shared mutable ``{}`` default.
        layer_kwargs = {} if layer_kwargs is None else layer_kwargs

        self.in_group_type = in_group_type
        self.in_order = in_order
        self.in_num_features = in_num_features
        self.in_representation = in_representation
        self.out_group_type = out_group_type
        self.out_order = out_order
        self.out_num_features = out_num_features
        self.out_representation = out_representation
        self.domain = domain
        self.kernel_size = kernel_size

        # NOTE: these bare calls resolve to the module-level helpers imported
        # from utils.group_utils, not to the accessor methods of this class.
        self.G_in = get_group(in_group_type, in_order)
        self.G_out = get_group(out_group_type, out_order)
        self.gspace_in = get_gspace(group_type=in_group_type,
                                    order=in_order,
                                    num_features=in_num_features,
                                    representation=in_representation)
        self.gspace_out = get_gspace(group_type=out_group_type,
                                     order=out_order,
                                     num_features=out_num_features,
                                     representation=out_representation)

        self.conv = enn.R2Conv(in_type=self.gspace_in,
                               out_type=self.gspace_out,
                               kernel_size=kernel_size,
                               padding=(kernel_size - 1) // 2,
                               **layer_kwargs)

    def forward(self, x):
        """Apply the equivariant convolution to a plain tensor.

        ``x`` is wrapped via the input field type (``self.gspace_in(x)``) and
        passed through ``self.conv``. The raw ``.tensor`` of the result is
        returned. Presumably ``x`` is ``(batch, channels, H, W)``, with
        ``channels`` matching the input field type's size — TODO confirm.
        """
        f_x = self.gspace_in(x)
        f_x_out = self.conv(f_x)
        return f_x_out.tensor

    @staticmethod
    def _pick(which, in_value, out_value):
        """Return ``in_value`` for ``which == "in"``, ``out_value`` for ``"out"``."""
        if which == "in":
            return in_value
        if which == "out":
            return out_value
        raise ValueError(f"which must be 'in' or 'out', got {which!r}")

    def get_group(self, which: str = "in"):
        """Return the group for the requested side (``"in"`` by default).

        Previously this returned the never-assigned ``self.G`` and always
        raised ``AttributeError``.
        """
        return self._pick(which, self.G_in, self.G_out)

    def get_gspace(self, which: str = "in"):
        """Return the field type for the requested side (``"in"`` by default).

        Previously this returned the never-assigned ``self.gspace`` and always
        raised ``AttributeError``.
        """
        return self._pick(which, self.gspace_in, self.gspace_out)