from typing import Union, List, Tuple, Optional, Dict
from torch import Tensor

import torch
from torch import nn
from torch_scatter import scatter
from torch_geometric.utils import degree

# from torch_geometric.nn.aggr import DegreeScalerAggregation
from .gatv1 import GATv1Layer
from .gatv2 import GATv2Layer
from torch_geometric.nn import PNAConv


class GATPNAv1Layer(GATv1Layer):
    """GATv1 layer whose aggregation step can use PNA-style aggregation.

    When :meth:`aggregate` is called with ``convolve=False`` (its default),
    messages are reduced per target index with every configured aggregator,
    the concatenated result is rescaled by every configured degree scaler,
    and the final concatenation is projected by ``self.post_nns``.  With
    ``convolve=True`` the aggregation of the parent class is used instead.

    Args:
        in_channels: Forwarded to the parent layer; also used to size the
            input of ``post_nns``.
        out_channels: Forwarded to the parent layer; output size of
            ``post_nns``.
        aggregators: Aggregator names.  Supported: ``'sum'``, ``'mean'``,
            ``'min'``, ``'max'``, ``'var'``, ``'std'``.
        scalers: Scaler names.  Supported: ``'identity'``,
            ``'amplification'``, ``'attenuation'``, ``'linear'``,
            ``'inverse_linear'``.
        deg: Tensor from which the normalisation constants in
            ``self.avg_deg`` are computed.
        negative_slope, add_self_loops, heads, bias, convolve, lambda_policy,
        gcn_mode, share_weights_score, share_weights_value, aggr, **kwargs:
            Passed through unchanged to the parent constructor.
    """

    def __init__(
            self,
            in_channels: Union[int, Tuple[int, int]],
            out_channels: int,
            aggregators: List[str],
            scalers: List[str],
            deg: Tensor,
            negative_slope: float = 0.2,
            add_self_loops: bool = True,
            heads: int = 1,
            bias: bool = True,
            convolve: bool = True,
            lambda_policy: Optional[str] = None,  # [None, 'learn1', 'learn2', 'learn12', 'gcn_gat', 'individual']
            gcn_mode: bool = False,
            share_weights_score: bool = False,
            share_weights_value: bool = False,
            aggr: str = 'add',
            **kwargs,
    ):
        super().__init__(in_channels, out_channels, negative_slope, add_self_loops,
                         heads, bias, convolve, lambda_policy, gcn_mode, share_weights_score,
                         share_weights_value, aggr, **kwargs)

        self.aggregators = aggregators
        self.scalers = scalers

        # NOTE(review): the statistics below are plain means over the entries
        # of `deg`.  If callers pass a degree *histogram* (as PyG's PNAConv
        # expects) these are not true average degrees -- confirm with callers.
        deg = deg.to(torch.float)
        self.avg_deg: Dict[str, float] = {
            'lin': deg.mean().item(),
            'log': (deg + 1).log().mean().item(),
            # Not read by any scaler in this class; kept so the attribute
            # layout stays unchanged for external readers.
            'exp': deg.exp().mean().item(),
        }

        # Every (aggregator, scaler) pair contributes one feature slice.
        # NOTE(review): this multiplication only yields a valid Linear size
        # when `in_channels` is an int; a tuple would be repeated instead --
        # confirm whether tuple inputs are ever used with this layer.
        in_channels = len(aggregators) * len(scalers) * in_channels
        self.post_nns = nn.Linear(in_channels, out_channels)

    def aggregate(self, inputs: Tensor, index: Tensor,
                  dim_size: Optional[int] = None, convolve: bool = False) -> Tensor:
        """Aggregate messages per target index.

        Args:
            inputs: Messages to aggregate; reduced along dimension 0.
            index: Target index of every message.
            dim_size: Number of target entries in the output, if known.
            convolve: If ``True``, delegate to the parent's ``aggregate``.

        Returns:
            The parent's aggregation when ``convolve`` is true; otherwise
            ``post_nns`` applied to the concatenation (last dimension) of
            each scaler applied to the concatenated aggregator outputs.

        Raises:
            ValueError: On an unsupported aggregator or scaler name.
        """
        if convolve:
            return super().aggregate(inputs, index, dim_size=dim_size)

        outs = []
        for aggregator in self.aggregators:
            if aggregator in ('sum', 'mean', 'min', 'max'):
                out = scatter(inputs, index, 0, None, dim_size, reduce=aggregator)
            elif aggregator in ('var', 'std'):
                # Var[x] = E[x^2] - E[x]^2, computed per target index.
                mean = scatter(inputs, index, 0, None, dim_size, reduce='mean')
                mean_squares = scatter(inputs * inputs, index, 0, None,
                                       dim_size, reduce='mean')
                out = mean_squares - mean * mean
                if aggregator == 'std':
                    # relu removes negative round-off; the epsilon keeps the
                    # argument of sqrt strictly positive.
                    out = torch.sqrt(torch.relu(out) + 1e-5)
            else:
                raise ValueError(f'Unknown aggregator "{aggregator}".')
            outs.append(out)
        out = torch.cat(outs, dim=-1)

        # Clamp to >= 1 so the dividing scalers never divide by zero for
        # targets that received no message.  The (-1, 1, 1) view assumes the
        # aggregated tensor is 3-D -- TODO confirm against the parent layer.
        deg = degree(index, dim_size, dtype=inputs.dtype)
        deg = deg.clamp_(1).view(-1, 1, 1)

        outs = []
        for scaler in self.scalers:
            # Each scaler is applied to the *unscaled* `out`.  Rebinding
            # `out` here would make later scalers compound with earlier ones.
            if scaler == 'identity':
                scaled = out
            elif scaler == 'amplification':
                scaled = out * (torch.log(deg + 1) / self.avg_deg['log'])
            elif scaler == 'attenuation':
                scaled = out * (self.avg_deg['log'] / torch.log(deg + 1))
            elif scaler == 'linear':
                scaled = out * (deg / self.avg_deg['lin'])
            elif scaler == 'inverse_linear':
                scaled = out * (self.avg_deg['lin'] / deg)
            else:
                raise ValueError(f'Unknown scaler "{scaler}".')
            outs.append(scaled)

        return self.post_nns(torch.cat(outs, dim=-1))


class GATPNAv2Layer(GATv2Layer):
    """GATv2 layer whose aggregation step can use PNA-style aggregation.

    When :meth:`aggregate` is called with ``convolve=False`` (its default),
    messages are reduced per target index with every configured aggregator,
    the concatenated result is rescaled by every configured degree scaler,
    and the final concatenation is projected by ``self.post_nns``.  With
    ``convolve=True`` the aggregation of the parent class is used instead.

    Args:
        in_channels: Forwarded to the parent layer; also used to size the
            input of ``post_nns``.
        out_channels: Forwarded to the parent layer; output size of
            ``post_nns``.
        aggregators: Aggregator names.  Supported: ``'sum'``, ``'mean'``,
            ``'min'``, ``'max'``, ``'var'``, ``'std'``.
        scalers: Scaler names.  Supported: ``'identity'``,
            ``'amplification'``, ``'attenuation'``, ``'linear'``,
            ``'inverse_linear'``.
        deg: Tensor from which the normalisation constants in
            ``self.avg_deg`` are computed.
        negative_slope, add_self_loops, heads, bias, convolve, lambda_policy,
        gcn_mode, share_weights_score, share_weights_value, aggr, **kwargs:
            Passed through unchanged to the parent constructor.
    """

    def __init__(
            self,
            in_channels: Union[int, Tuple[int, int]],
            out_channels: int,
            aggregators: List[str],
            scalers: List[str],
            deg: Tensor,
            negative_slope: float = 0.2,
            add_self_loops: bool = True,
            heads: int = 1,
            bias: bool = True,
            convolve: bool = True,
            lambda_policy: Optional[str] = None,  # [None, 'learn1', 'learn2', 'learn12', 'gcn_gat', 'individual']
            gcn_mode: bool = False,
            share_weights_score: bool = False,
            share_weights_value: bool = False,
            aggr: str = 'add',
            **kwargs,
    ):
        super().__init__(in_channels, out_channels, negative_slope, add_self_loops,
                         heads, bias, convolve, lambda_policy, gcn_mode, share_weights_score,
                         share_weights_value, aggr, **kwargs)

        self.aggregators = aggregators
        self.scalers = scalers

        # NOTE(review): the statistics below are plain means over the entries
        # of `deg`.  If callers pass a degree *histogram* (as PyG's PNAConv
        # expects) these are not true average degrees -- confirm with callers.
        deg = deg.to(torch.float)
        self.avg_deg: Dict[str, float] = {
            'lin': deg.mean().item(),
            'log': (deg + 1).log().mean().item(),
            # Not read by any scaler in this class; kept so the attribute
            # layout stays unchanged for external readers.
            'exp': deg.exp().mean().item(),
        }

        # Every (aggregator, scaler) pair contributes one feature slice.
        # NOTE(review): this multiplication only yields a valid Linear size
        # when `in_channels` is an int; a tuple would be repeated instead --
        # confirm whether tuple inputs are ever used with this layer.
        in_channels = len(aggregators) * len(scalers) * in_channels
        self.post_nns = nn.Linear(in_channels, out_channels)

    def aggregate(self, inputs: Tensor, index: Tensor,
                  dim_size: Optional[int] = None, convolve: bool = False) -> Tensor:
        """Aggregate messages per target index.

        Args:
            inputs: Messages to aggregate; reduced along dimension 0.
            index: Target index of every message.
            dim_size: Number of target entries in the output, if known.
            convolve: If ``True``, delegate to the parent's ``aggregate``.

        Returns:
            The parent's aggregation when ``convolve`` is true; otherwise
            ``post_nns`` applied to the concatenation (last dimension) of
            each scaler applied to the concatenated aggregator outputs.

        Raises:
            ValueError: On an unsupported aggregator or scaler name.
        """
        if convolve:
            return super().aggregate(inputs, index, dim_size=dim_size)

        outs = []
        for aggregator in self.aggregators:
            if aggregator in ('sum', 'mean', 'min', 'max'):
                out = scatter(inputs, index, 0, None, dim_size, reduce=aggregator)
            elif aggregator in ('var', 'std'):
                # Var[x] = E[x^2] - E[x]^2, computed per target index.
                mean = scatter(inputs, index, 0, None, dim_size, reduce='mean')
                mean_squares = scatter(inputs * inputs, index, 0, None,
                                       dim_size, reduce='mean')
                out = mean_squares - mean * mean
                if aggregator == 'std':
                    # relu removes negative round-off; the epsilon keeps the
                    # argument of sqrt strictly positive.
                    out = torch.sqrt(torch.relu(out) + 1e-5)
            else:
                raise ValueError(f'Unknown aggregator "{aggregator}".')
            outs.append(out)
        out = torch.cat(outs, dim=-1)

        # Clamp to >= 1 so the dividing scalers never divide by zero for
        # targets that received no message.  The (-1, 1, 1) view assumes the
        # aggregated tensor is 3-D -- TODO confirm against the parent layer.
        deg = degree(index, dim_size, dtype=inputs.dtype)
        deg = deg.clamp_(1).view(-1, 1, 1)

        outs = []
        for scaler in self.scalers:
            # Each scaler is applied to the *unscaled* `out`.  Rebinding
            # `out` here would make later scalers compound with earlier ones.
            if scaler == 'identity':
                scaled = out
            elif scaler == 'amplification':
                scaled = out * (torch.log(deg + 1) / self.avg_deg['log'])
            elif scaler == 'attenuation':
                scaled = out * (self.avg_deg['log'] / torch.log(deg + 1))
            elif scaler == 'linear':
                scaled = out * (deg / self.avg_deg['lin'])
            elif scaler == 'inverse_linear':
                scaled = out * (self.avg_deg['lin'] / deg)
            else:
                raise ValueError(f'Unknown scaler "{scaler}".')
            outs.append(scaled)

        return self.post_nns(torch.cat(outs, dim=-1))
