import torch
import torch.nn as nn
from layers.sampling import SamplingLayer
from layers.anti_aliasing import AntiAliasingLayer
from graph.graph_constructors import GraphConstructor
from utils.group_utils import *

class SubgroupDownsample(nn.Module):
    """Downsample a signal on a group to a subgroup, optionally anti-aliasing first.

    Builds the ambient group ``G`` and the subgroup ``G_sub`` via ``get_group``,
    the corresponding graphs via ``GraphConstructor``, a ``SamplingLayer`` that
    maps between the two node sets, and (optionally) an ``AntiAliasingLayer``
    that is applied before sampling in ``forward``.

    Args:
        group_type: Identifier of the ambient group, passed to ``get_group``.
        order: Order parameter of the ambient group, passed to ``get_group``.
        sub_group_type: Identifier of the target subgroup.
        subsampling_factor: Factor by which the group is subsampled.
        num_features: Number of feature channels (stored only; not used by
            the code in this class).
        generator: Group generator spec forwarded to ``GraphConstructor``.
        device: Device the sampling / anti-aliasing layers are moved to.
        dtype: Dtype the sampling / anti-aliasing layers are moved to.
        sample_type: Sampling mode forwarded to ``SamplingLayer`` as ``type``.
        apply_antialiasing: If True, build an ``AntiAliasingLayer`` and apply
            it before sampling in ``forward``.
        anti_aliasing_kwargs: Extra keyword arguments for ``AntiAliasingLayer``.
            ``None`` (the default) selects the built-in defaults returned by
            ``_default_anti_aliasing_kwargs``. A shallow copy is stored, so the
            caller's dict and other instances are never aliased.
        cannonicalize: Accepted for interface compatibility; currently ignored
            (``self.cannonicalize`` is always set to ``None``).
    """

    @staticmethod
    def _default_anti_aliasing_kwargs() -> dict:
        """Return a fresh dict of the default ``AntiAliasingLayer`` kwargs."""
        # Built on every call so instances never share (and mutate) one dict.
        return {
            "smooth_operator": 'adjacency',
            "mode": 'linear_optim',
            "iterations": 100000,
            "smoothness_loss_weight": 5.0,
            "threshold": 0.0,
            "equi_constraint": True,
            "equi_correction": True,
        }

    def __init__(self,
                 group_type: str,
                 order: int,
                 sub_group_type: str,
                 subsampling_factor: int,
                 num_features: int,
                 generator: str = 'r-s',
                 device: str = 'cpu',
                 dtype: torch.dtype = torch.float32,
                 sample_type: str = 'sample',
                 apply_antialiasing: bool = False,
                 anti_aliasing_kwargs: dict = None,
                 cannonicalize: bool = True,
                 ):
        super().__init__()
        self.group_type = group_type
        self.order = order
        self.sub_group_type = sub_group_type
        self.subsampling_factor = subsampling_factor
        self.num_features = num_features
        self.generator = generator
        # Never keep a reference to a shared default or to the caller's dict.
        if anti_aliasing_kwargs is None:
            anti_aliasing_kwargs = self._default_anti_aliasing_kwargs()
        self.anti_aliasing_kwargs = dict(anti_aliasing_kwargs)
        self.device = device
        self.dtype = dtype

        self.sample_type = sample_type
        self.apply_antialiasing = apply_antialiasing

        self.G = get_group(group_type, order)
        # Same group family: the order shrinks by the full factor. Otherwise
        # the factor is halved (at least 1). NOTE(review): presumably because
        # half of the reduction comes from changing the group family (e.g.
        # dropping reflections) — confirm against get_group.
        if group_type == sub_group_type:
            sub_order = order // subsampling_factor
        else:
            sub_order = order // max(subsampling_factor // 2, 1)
        self.G_sub = get_group(sub_group_type, sub_order)

        self.graphs = GraphConstructor(group_size=self.G.order(),
                                       group_type=self.group_type,
                                       group_generator=self.generator,
                                       subgroup_type=self.sub_group_type,
                                       subsampling_factor=self.subsampling_factor)

        self.sample = SamplingLayer(sampling_factor=self.subsampling_factor,
                                    nodes=self.graphs.graph.nodes,
                                    subsample_nodes=self.graphs.subgroup_graph.nodes,
                                    type=sample_type)

        self.sample.to(device=self.device, dtype=self.dtype)

        if apply_antialiasing:
            print("initializing anti aliasing layer")
            # NOTE: the keyword is spelled ``adjaceny_matrix`` here;
            # presumably it matches AntiAliasingLayer's signature — do not
            # "fix" it without changing that class too.
            self.anti_aliaser = AntiAliasingLayer(nodes=self.graphs.graph.nodes,
                                                  adjaceny_matrix=self.graphs.graph.adjacency_matrix,
                                                  basis=self.graphs.graph.fourier_basis,
                                                  subsample_nodes=self.graphs.subgroup_graph.nodes,
                                                  subsample_adjacency_matrix=self.graphs.subgroup_graph.adjacency_matrix,
                                                  sub_basis=self.graphs.subgroup_graph.fourier_basis,
                                                  dtype=self.dtype,
                                                  device=self.device,
                                                  equi_raynold_op=self.graphs.graph.equi_raynold_op,
                                                  **self.anti_aliasing_kwargs)
            self.anti_aliaser.to(device=self.device, dtype=self.dtype)
        else:
            print("anti aliasing layer not applied")
            self.anti_aliaser = None

        # NOTE(review): the ``cannonicalize`` argument is ignored and this is
        # always None; kept as-is because external code may rely on it.
        self.cannonicalize = None

    def forward(self, x: torch.Tensor):
        """Anti-alias ``x`` (if enabled), then sample it onto the subgroup.

        Returns:
            A tuple ``(x_sub, v)``. ``v`` is always the constant
            ``[(-1, -1)]``; presumably a placeholder for canonicalization
            information, since ``cannonicalize`` is not implemented.
        """
        v = [(-1, -1)]

        if self.anti_aliaser is not None:
            x = self.anti_aliaser(x)

        x = self.sample(x)

        return x, v

    def upsample(self, x: torch.Tensor, v: list = None):
        """Map a subgroup signal ``x`` back onto the full group.

        Delegates to ``anti_aliaser.up_sample`` when anti-aliasing is enabled,
        otherwise to ``sample.up_sample``. ``v`` is accepted for symmetry with
        ``forward``'s return value but is not used.
        """
        if self.anti_aliaser is not None:
            x = self.anti_aliaser.up_sample(x)
        else:
            x = self.sample.up_sample(x)

        return x