import torch
import torch.nn as nn
from layers.rnconv import *
from thirdparty.blurpool2d import BlurPool2d
from layers.downsampling import *
from utils.group_utils import *
from einops import rearrange


class Gcnn(nn.Module):
    """Group-equivariant CNN built from a stack of ``rnConv`` stages.

    Each of the ``num_layers`` stages applies, in order:

    1. an ``rnConv`` convolution and ReLU,
    2. a ``BlurPool2d`` spatial downsampling (``nn.Identity`` when disabled),
    3. an optional ``SubgroupDownsample`` onto a subgroup,
    4. optional dropout.

    The resulting features are pooled jointly over the group and spatial
    axes (see ``pooling``). Unless ``fully_convolutional`` is set, the pooled
    features then go through a linear classifier.

    Args:
        num_layers: Number of convolutional stages.
        num_channels: Feature counts at stage boundaries. Stage ``i`` maps
            ``num_channels[i]`` to ``num_channels[i + 1]`` features, so
            ``num_layers + 1`` entries are needed.
        kernel_sizes: Kernel size of each stage's ``rnConv``.
        num_classes: Output size of the final linear layer.
        dwn_group_types: Per stage, an indexable pair. Element ``[0]`` is the
            group type used by that stage's ``rnConv`` (input and output).
            Element ``[1]`` is the ``sub_group_type`` given to its
            ``SubgroupDownsample``.
        dwn_orders: Per stage, an indexable pair. Element ``[0]`` is the group
            order used by that stage's ``rnConv`` and ``SubgroupDownsample``.
            ``dwn_orders[-1][1]`` (with ``dwn_group_types[-1][1]``) sizes the
            group axis collapsed in ``pooling``.
        spatial_subsampling_factors: Per-stage ``BlurPool2d`` stride. A value
            ``<= 1`` disables spatial downsampling for that stage.
        subsampling_factors: Per-stage subgroup subsampling factor. A value
            ``<= 1`` disables subgroup downsampling for that stage.
        domain: Forwarded to every ``rnConv``.
        pooling_type: ``'max'`` or ``'mean'``; the reduction used by
            ``pooling``.
        apply_antialiasing: Forwarded to ``SubgroupDownsample``.
        cannonicalize: Forwarded to ``SubgroupDownsample``.
        antialiasing_kwargs: Forwarded to ``SubgroupDownsample`` as
            ``anti_aliasing_kwargs``.
        dropout_rate: Dropout probability applied after every stage. ``0``
            (or less) disables dropout.
        fully_convolutional: If true, ``forward`` returns the pooled features
            and skips the linear classifier.
        layer_kwargs: Extra keyword arguments forwarded to every ``rnConv``.
            ``None`` (the default) means no extra arguments.
    """

    def __init__(self,
                 *,
                 num_layers: int,
                 num_channels: list[int],
                 kernel_sizes: list[int],
                 num_classes: int,
                 dwn_group_types: list,
                 dwn_orders: list,
                 spatial_subsampling_factors: list[int],
                 subsampling_factors,
                 domain,
                 pooling_type,
                 apply_antialiasing,
                 cannonicalize,
                 antialiasing_kwargs,
                 dropout_rate,
                 fully_convolutional=False,
                 layer_kwargs=None) -> None:
        super().__init__()
        if layer_kwargs is None:
            # Fresh dict per instance instead of a shared mutable default.
            layer_kwargs = {}

        self.num_layers = num_layers
        self.num_channels = num_channels
        self.kernel_sizes = kernel_sizes
        self.num_classes = num_classes
        self.dwn_group_types = dwn_group_types
        self.dwn_orders = dwn_orders
        self.spatial_subsampling_factors = spatial_subsampling_factors
        self.subsampling_factors = subsampling_factors
        self.domain = domain
        self.pooling_type = pooling_type
        self.apply_antialiasing = apply_antialiasing
        self.cannonicalize = cannonicalize
        self.antialiasing_kwargs = antialiasing_kwargs
        self.dropout_rate = dropout_rate
        self.fully_convolutional = fully_convolutional

        self.conv_layers = nn.ModuleList()
        # Holds None for stages without subgroup downsampling.
        self.sampling_layers = nn.ModuleList()
        # Holds nn.Identity for stages without spatial downsampling.
        self.spatial_sampling_layers = nn.ModuleList()

        for i in range(num_layers):
            # NOTE(review): each stage's input group is dwn_group_types[i][0],
            # not the previous stage's subgroup dwn_group_types[i - 1][1].
            # This presumably relies on callers keeping the two in agreement;
            # confirm.
            group_type = dwn_group_types[i][0]
            order = dwn_orders[i][0]
            out_features = num_channels[i + 1]

            # The first stage reads plain input channels; later stages read
            # the regular-representation output of the previous stage.
            rep = 'trivial' if i == 0 else 'regular'

            conv = rnConv(in_group_type=group_type,
                          in_order=order,
                          in_num_features=num_channels[i],
                          in_representation=rep,
                          out_group_type=group_type,
                          out_order=order,
                          out_num_features=out_features,
                          out_representation='regular',
                          domain=domain,
                          kernel_size=kernel_sizes[i],
                          layer_kwargs=layer_kwargs)
            self.conv_layers.append(conv)

            if subsampling_factors[i] > 1:
                # NOTE(review): device='cpu' is hard-coded here; confirm that
                # SubgroupDownsample state follows a later `.to(device)`.
                sampling_layer = SubgroupDownsample(group_type=group_type,
                                                    order=order,
                                                    sub_group_type=dwn_group_types[i][1],
                                                    subsampling_factor=subsampling_factors[i],
                                                    num_features=out_features,
                                                    generator='r-s',
                                                    device='cpu',
                                                    dtype=torch.float32,
                                                    sample_type='sample',
                                                    apply_antialiasing=self.apply_antialiasing,
                                                    anti_aliasing_kwargs=self.antialiasing_kwargs,
                                                    cannonicalize=self.cannonicalize)
            else:
                sampling_layer = None

            if spatial_subsampling_factors[i] > 1:
                # get_feature applies this before the subgroup downsampling.
                # It therefore sees out_features * |G_out| channels.
                spatial_sampling_layer = BlurPool2d(channels=out_features * conv.G_out.order(),
                                                    stride=spatial_subsampling_factors[i])
            else:
                spatial_sampling_layer = nn.Identity()

            self.sampling_layers.append(sampling_layer)
            self.spatial_sampling_layers.append(spatial_sampling_layer)

        # NOTE(review): pooling assumes the final features live on the group
        # given by dwn_group_types[-1][1] / dwn_orders[-1][1]. If the last
        # stage does not subsample, confirm this pair equals that stage's own
        # group.
        self.last_g_size = get_group(dwn_group_types[-1][1], dwn_orders[-1][1]).order()
        self.linear_layer = nn.Linear(num_channels[-1], num_classes)

    def pooling(self, x):
        """Pool ``x`` jointly over its group and spatial axes.

        ``x`` is rearranged with ``"b (c g) h w -> b c (g h w)"``, where
        ``g = self.last_g_size``. The last axis is then reduced with max or
        mean, giving a ``(b, c)`` tensor.

        Raises:
            ValueError: If ``self.pooling_type`` is neither ``'max'`` nor
                ``'mean'``. Previously such values silently returned the
                unpooled 3-D tensor.
        """
        if self.pooling_type not in ('max', 'mean'):
            raise ValueError(
                f"Unsupported pooling_type {self.pooling_type!r}; "
                "expected 'max' or 'mean'."
            )
        x = rearrange(x, "b (c g) h w -> b c (g h w)", g=self.last_g_size)
        if self.pooling_type == 'max':
            return torch.max(x, dim=-1).values
        return torch.mean(x, dim=-1)

    def get_feature(self, x):
        """Run every stage on ``x`` and return the pooled features.

        Per stage, this applies conv and ReLU, spatial downsampling
        (``nn.Identity`` when disabled), optional subgroup downsampling, and
        dropout when ``dropout_rate > 0``. The result is then passed through
        ``pooling``.
        """
        stages = zip(self.conv_layers,
                     self.spatial_sampling_layers,
                     self.sampling_layers)
        for conv, spatial_sampling, sampling in stages:
            x = torch.relu(conv(x))
            x = spatial_sampling(x)
            if sampling is not None:
                # SubgroupDownsample returns a pair; only the features are kept.
                x, _ = sampling(x)
            if self.dropout_rate > 0:
                x = nn.functional.dropout(x, p=self.dropout_rate, training=self.training)
        return self.pooling(x)

    def forward(self, x):
        """Return class logits, or pooled features if ``fully_convolutional``."""
        x = self.get_feature(x)
        if not self.fully_convolutional:
            x = self.linear_layer(x)
        return x
