from typing import Optional, List
import torch
from torch import Tensor
from torch_geometric.utils import scatter


def p_norm_aggregation_w_weights(
    x: Tensor,
    weights: Optional[Tensor],
    batch: Optional[Tensor],
    p: Optional[int] = 1,
    size: Optional[int] = None,
) -> Tensor:
    """Aggregate ``x`` over its node axis with an optionally weighted p-norm.

    Computes ``(sum_i w_i * |x_i| ** p) ** (1 / p)``. The sum runs over axis
    ``-1`` for 1-D ``x`` and over axis ``-2`` otherwise. ``w_i`` comes from
    ``weights`` when it is given and is 1 otherwise.

    Args:
        x: Input features.
        weights: Optional tensor multiplied element-wise into ``|x| ** p``
            before the sum. It must broadcast against ``x``.
        batch: Optional group index per node, used by ``scatter``. When
            ``None``, all nodes are reduced together.
        p: Norm exponent.
        size: Number of output groups when ``batch`` is given. When ``None``
            it is inferred as ``batch.max() + 1``. Unused when ``batch`` is
            ``None``.

    Returns:
        The aggregated tensor.
        - Without ``batch``: the reduced axis is kept (length 1) for inputs
          with at most two dimensions and dropped otherwise.
        - With ``batch``: the reduced axis has length ``size``.
    """
    reduce_dim = -2 if x.dim() != 1 else -1

    # Raise magnitudes to the p-th power, then apply the optional weights.
    powered = torch.abs(x) ** p
    if weights is not None:
        powered = powered * weights

    if batch is None:
        total = powered.sum(dim=reduce_dim, keepdim=x.dim() <= 2)
    else:
        if size is None:
            size = int(batch.max().item() + 1)
        total = scatter(powered, batch, dim=reduce_dim, dim_size=size, reduce='sum')

    return total ** (1 / p)


def concatenated_p_norm_aggregation_w_weights(x: Tensor, weights: Optional[Tensor], batch: Optional[Tensor], Ps: List[int],
                    size: Optional[int] = None) -> Tensor:
    """Concatenate weighted p-norm aggregations of ``x`` for several exponents.

    Calls ``p_norm_aggregation_w_weights`` once per exponent in ``Ps`` and
    concatenates the results along the last (feature) axis.

    NOTE: the per-exponent blocks appear in *reverse* order of ``Ps``. The
    result for ``Ps[-1]`` comes first and the result for ``Ps[0]`` comes last.
    This is the ordering the original implementation produced, and it is kept
    because downstream layers may depend on it.

    Args:
        x: Input features, forwarded unchanged.
        weights: Optional per-element weights, forwarded unchanged.
        batch: Optional group index per node, forwarded unchanged.
        Ps: Non-empty sequence of norm exponents.
        size: Number of output groups, forwarded to every aggregation. When
            ``None`` and ``batch`` is given, it is inferred once as
            ``batch.max() + 1``.

    Returns:
        The ``len(Ps)`` aggregations concatenated along the last axis.

    Raises:
        ValueError: If ``Ps`` is empty.
    """
    if len(Ps) == 0:
        raise ValueError("Ps must contain at least one exponent")

    if batch is not None and size is None:
        # Infer the group count once rather than once per exponent. Each
        # inference calls ``.item()``, and every exponent must agree on it.
        size = int(batch.max().item() + 1)

    # Build every part first and concatenate once. Growing the result with
    # ``torch.cat`` inside a loop re-copies it on each iteration.
    parts = [
        p_norm_aggregation_w_weights(x=x, weights=weights, batch=batch, p=p, size=size)
        for p in reversed(Ps)
    ]
    # Concatenate along the feature axis. This matches the old ``dim=1`` for
    # 2-D inputs and stays correct for 1-D and batched 3-D inputs.
    return torch.cat(parts, dim=-1)


def p_norm_aggregation(x: Tensor, batch: Optional[Tensor], p: Optional[int] = 1,
                       size: Optional[int] = None) -> Tensor:
    """Aggregate ``x`` over its node axis with an unweighted p-norm.

    Computes ``(sum_i |x_i| ** p) ** (1 / p)``. The sum runs over axis ``-1``
    for 1-D ``x`` and over axis ``-2`` otherwise.

    Args:
        x: Input features.
        batch: Optional group index per node, used by ``scatter``. When
            ``None``, all nodes are reduced together.
        p: Norm exponent.
        size: Number of output groups when ``batch`` is given. When ``None``
            it is inferred as ``batch.max() + 1``. Unused when ``batch`` is
            ``None``.

    Returns:
        The aggregated tensor.
        - Without ``batch``: the reduced axis is kept (length 1) for inputs
          with at most two dimensions and dropped otherwise.
        - With ``batch``: the reduced axis has length ``size``.
    """
    powered = torch.abs(x) ** p
    reduce_dim = -1 if x.dim() == 1 else -2

    if batch is not None:
        n_groups = size if size is not None else int(batch.max().item() + 1)
        summed = scatter(powered, batch, dim=reduce_dim, dim_size=n_groups, reduce='sum')
        return summed ** (1 / p)

    return powered.sum(dim=reduce_dim, keepdim=x.dim() <= 2) ** (1 / p)


def concatenated_p_norm_aggregation(x: Tensor, batch: Optional[Tensor], P: Optional[int] = 1,
                    size: Optional[int] = None) -> Tensor:
    """Concatenate unweighted p-norm aggregations for ``p = P, P-1, ..., 1``.

    Calls ``p_norm_aggregation`` once per exponent and concatenates the
    results along the last (feature) axis. The highest exponent comes first
    and ``p = 1`` comes last, which is the ordering the original
    implementation produced.

    Args:
        x: Input features, forwarded unchanged.
        batch: Optional group index per node, forwarded unchanged.
        P: Largest integer exponent.
            - ``None`` is treated as 1, the declared default.
            - Values below 1 yield only the ``p = 1`` aggregation, as before.
        size: Number of output groups, forwarded to every aggregation. When
            ``None`` and ``batch`` is given, it is inferred once as
            ``batch.max() + 1``.

    Returns:
        The ``max(P, 1)`` aggregations concatenated along the last axis.
    """
    if P is None:
        P = 1

    if batch is not None and size is None:
        # Infer the group count once rather than once per exponent, so every
        # exponent uses the same value and ``.item()`` runs only once.
        size = int(batch.max().item() + 1)

    # Exponents in descending order: P, P-1, ..., 1.
    exponents = range(max(P, 1), 0, -1)
    # Single concatenation instead of growing the result inside a loop.
    parts = [p_norm_aggregation(x, batch, p, size) for p in exponents]
    # Use the feature axis. It is identical to the old ``dim=1`` for 2-D inputs.
    return torch.cat(parts, dim=-1)