from typing import Optional, Tuple
from torch import Tensor
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SimpleConv
from torch_geometric.nn.aggr import Aggregation


class AggConcat(Aggregation):
    """Aggregation that lays messages out in fixed slots and concatenates them.

    Row ``i`` of the message tensor is written to slot
    ``source_index[i] % max_num_elements`` of group ``index[i]`` (via the module-level
    ``to_dense_batch``); unused slots keep ``fill_value``. ``forward`` then flattens the
    slots, so each group's output has ``max_num_elements * D`` features.

    ``source_index`` is not part of the ``Aggregation`` call signature, so the caller
    has to assign it (one entry per message row, aligned with ``index``) before every
    call; ``MessageConcat.forward`` does this with ``edge_index[0]``. Note that a value
    left over from a previous call is reused silently.
    """

    def __init__(self, max_num_elements: int, **kwargs):
        super().__init__()
        # Slots per group; also the modulus applied to source ids.
        # Extra keyword arguments are accepted and ignored.
        self.max_num_elements = max_num_elements
        # Per-message source ids, assigned externally before each aggregation.
        # Initialised here so a missing assignment yields a clear error below
        # instead of an AttributeError.
        self.source_index: Optional[Tensor] = None

    def to_dense_batch(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        fill_value: float = 0.0,
        max_num_elements: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Scatter the 2-D message tensor ``x`` into ``[num_groups, max_num_elements, D]``.

        ``ptr`` is accepted for interface compatibility but ignored; ``dim`` is only
        used by the input-shape check.

        Returns:
            ``(out, mask)`` as produced by the module-level ``to_dense_batch``.

        Raises:
            RuntimeError: if ``source_index`` has not been assigned.
        """
        self.assert_index_present(index)
        self.assert_sorted_index(index)
        self.assert_two_dimensional_input(x, dim)
        if self.source_index is None:
            raise RuntimeError(
                "AggConcat.source_index must be assigned (one source id per message "
                "row) before aggregation")

        # Resolves to the module-level to_dense_batch function, not this method.
        return to_dense_batch(
            x,
            index,
            batch_size=dim_size,
            fill_value=fill_value,
            max_num_nodes=max_num_elements,
            source_index=self.source_index,
        )

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        """Return slot-concatenated messages of shape ``[num_groups, max_num_elements * D]``."""
        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
                                   max_num_elements=self.max_num_elements)
        # [num_groups, slots, D] -> [num_groups, slots * D]
        return x.view(-1, x.size(1) * x.size(2))

class MessageConcat(nn.Module):
    """Send every node's features to its neighbours and concatenate what each receives.

    ``forward`` takes ``x`` of shape ``[B, P, D]`` and flattens it to ``[B*P, D]`` for a
    ``SimpleConv`` whose aggregator is ``AggConcat``. It then reshapes the result back
    to ``[B, P, -1]``. Per the original comment, ``edge_index`` is expected to address
    the flattened ``B*P`` nodes, i.e. the per-sample graph repeated for each sample in
    the batch.
    """

    def __init__(self, max_num_elements):
        super().__init__()
        aggregator = AggConcat(max_num_elements=max_num_elements)
        self.message_passing = SimpleConv(aggr=aggregator)

    def forward(self, x, edge_index):
        batch_size, num_parts, feat_dim = x.shape[0], x.shape[1], x.shape[2]
        flat = x.view(-1, feat_dim)
        # AggConcat cannot receive source ids through the aggregation call,
        # so they are handed over as an attribute right before propagating.
        self.message_passing.aggr_module.source_index = edge_index[0]
        gathered = self.message_passing(flat, edge_index)
        return gathered.view(batch_size, num_parts, -1)


class MessageMean(nn.Module):
    """Send every node's features to its neighbours and average what each receives.

    ``forward`` takes ``x`` of shape ``[B, P, D]``, runs a mean-aggregating
    ``SimpleConv`` on the flattened ``[B*P, D]`` tensor and restores the leading
    ``[B, P]`` dimensions. Per the original comment, ``edge_index`` is expected to
    address the flattened ``B*P`` nodes (the per-sample graph repeated per sample).
    """

    def __init__(self):
        super().__init__()
        self.message_passing = SimpleConv(aggr='mean')

    def forward(self, x, edge_index):
        leading_dims = x.shape[:2]
        flat = x.view(-1, x.shape[2])
        averaged = self.message_passing(flat, edge_index)
        return averaged.view(*leading_dims, -1)

class DevicePerceptron(nn.Module):
    """A bank of independent linear layers, one per device, evaluated in parallel.

    For input ``x`` of shape ``[B, P, d_in]`` (with ``P == n_device``), device ``p``
    computes ``x[:, p] @ weight[p] + bias[p]``, giving ``[B, P, d_out]``. The activation
    is applied afterwards when ``use_relu`` is true.

    Args:
        n_device: number of devices P (one weight matrix and bias each).
        d_in: input features per device.
        d_out: output features per device.
        activation: ``'relu'`` or ``'leakyrelu'``.
        use_relu: apply ``activation`` to the output. Despite the name, this toggles
            whichever activation was chosen, including leaky ReLU.

    Raises:
        ValueError: if ``activation`` is not supported.
    """

    def __init__(self, n_device, d_in, d_out, activation='relu', use_relu=True):
        super().__init__()
        self.device = n_device  # number of devices (attribute name kept for compatibility)
        self.use_relu = use_relu
        self.weight = torch.nn.Parameter(torch.empty(n_device, d_in, d_out))
        self.bias = torch.nn.Parameter(torch.empty(n_device, d_out))
        if activation == 'relu':
            self.activation = torch.nn.functional.relu
        elif activation == 'leakyrelu':
            self.activation = torch.nn.functional.leaky_relu
        else:
            raise ValueError('activation not supported')
        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise each device's layer like ``nn.Linear``: U(-1/sqrt(d_in), 1/sqrt(d_in)).

        ``nn.Linear`` uses ``kaiming_uniform_(a=sqrt(5))``, which reduces to that bound
        with fan_in = in_features, and draws its bias from the same range. Running
        kaiming on the stacked ``[n_device, d_in, d_out]`` tensor would instead treat it
        like a conv kernel and use fan_in = d_in * d_out. That would shrink the init by
        a factor of sqrt(d_out).
        """
        fan_in = self.weight.size(1)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, edge_index=None):
        """Apply every device's layer to its own slice of ``x``.

        Args:
            x: tensor of shape ``[B, P, d_in]``.
            edge_index: unused; accepted for call-signature compatibility.

        Returns:
            Tensor of shape ``[B, P, d_out]``.
        """
        # b: batch, p: device, i: d_in, o: d_out; bias [P, d_out] broadcasts over b.
        out = torch.einsum('bpi,pio->bpo', x, self.weight) + self.bias
        return self.activation(out) if self.use_relu else out

class MultiHeadDevicePerceptron(nn.Module):
    """Per-device perceptrons with several independent heads whose outputs are averaged.

    For head ``h`` and device ``p`` it computes
    ``act(x[:, p] @ weight[h, p] + bias[h, p])``. The result is the mean over heads, with
    shape ``[B, P, d_out]``. The activation is applied before averaging.

    Args:
        n_head: number of heads H.
        n_device: number of devices P (``x.shape[1]`` in ``forward``).
        d_in: input features per device.
        d_out: output features per device.
        activation: ``'relu'`` or ``'leakyrelu'``.

    Raises:
        ValueError: if ``activation`` is not supported.
    """

    def __init__(self, n_head, n_device, d_in, d_out, activation='relu'):
        super().__init__()
        self.n_head = n_head
        self.n_device = n_device
        self.weight = torch.nn.Parameter(torch.empty(n_head, n_device, d_in, d_out))
        self.bias = torch.nn.Parameter(torch.empty(n_head, n_device, d_out))
        if activation == 'relu':
            self.activation = torch.nn.functional.relu
        elif activation == 'leakyrelu':
            self.activation = torch.nn.functional.leaky_relu
        else:
            raise ValueError('activation not supported')
        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise every (head, device) layer like ``nn.Linear``: U(-1/sqrt(d_in), 1/sqrt(d_in)).

        Running kaiming on the stacked ``[H, P, d_in, d_out]`` tensor would compute
        fan_in = P * d_in * d_out (it treats dim 1 as input channels). That would shrink
        the init far below ``nn.Linear``'s.
        """
        fan_in = self.weight.size(2)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, edge_index=None):
        """Apply all heads per device and average them.

        Args:
            x: tensor of shape ``[B, P, d_in]``.
            edge_index: unused; accepted for call-signature compatibility.

        Returns:
            Tensor of shape ``[B, P, d_out]``.
        """
        # [B, P, D_in] -> [1, P, B, D_in]; matmul broadcasts the leading dim over the
        # H heads of weight [H, P, D_in, D_out], so no per-head copy of x is needed.
        x = x.permute(1, 0, 2).unsqueeze(0)
        x = torch.matmul(x, self.weight) + self.bias.unsqueeze(2)  # [H, P, B, D_out]
        x = self.activation(x).mean(dim=0)  # [P, B, D_out]
        return x.permute(1, 0, 2)  # [B, P, D_out]


class DeepMVFL(nn.Module):
    """
    Deep MVFL model.

    Pipeline (see ``forward``):
    1. per-device ``init_layers``, followed by a message exchange;
    2. ``repeat`` times: a multi-head per-device layer, followed by a message exchange;
    3. a per-device ``classifier`` with 10 outputs.

    ``forward`` returns the classifier outputs without any softmax. The ``gossip_*``
    methods exist but are not called by ``forward``.

    Args:
        n_device: number of devices; expected to equal ``x.shape[1]`` in ``forward``.
        d_inter_init: width of the first per-device hidden layer.
        d_inter: width of the per-device embedding exchanged between devices.
        repeat: number of deep layers, between 1 and ``n_device``; each deep layer has
            ``int(n_device / repeat)`` heads.
        gossip_mode: ``'gm'``, ``'ugm'``, ``'logit'`` or ``'ugmrm'``; selects
            ``self.gossip_layer``.
        n_gossip: number of gossip rounds performed by the ``gossip_*`` methods.
        image_shape: ``(H, W, C)``; each device receives ``int(H*W*C / n_device)``
            input features.
        activation: ``'relu'`` or ``'leakyrelu'``.

    Raises:
        ValueError: if ``repeat`` is outside ``[1, n_device]`` (0 would divide by zero,
            more than ``n_device`` would give 0 heads and a NaN head-mean), or if
            ``gossip_mode`` is not supported.
    """

    def __init__(self, n_device, d_inter_init=16, d_inter=4, repeat=1, gossip_mode='gm', n_gossip=0,
                 image_shape=(28, 28, 1), activation='relu'):
        super().__init__()

        d_init = int(image_shape[0] * image_shape[1] * image_shape[2] / n_device)
        if repeat < 1 or repeat > n_device:
            raise ValueError(
                f"repeat must be between 1 and n_device ({n_device}), got {repeat}")
        self.repeat = repeat
        self.n_head = int(n_device / repeat)
        self.gossip_mode = gossip_mode
        gossip_layers = {
            'gm': self.gossip_gm,
            'ugm': self.gossip_ugm,
            'logit': self.gossip_logit,
            'ugmrm': self.gossip_ugmrm,
        }
        if gossip_mode not in gossip_layers:
            raise ValueError('gossip mode not supported')
        self.gossip_layer = gossip_layers[gossip_mode]
        self.n_gossip = n_gossip

        # Each device: d_init -> d_inter_init -> d_inter.
        self.init_layers = nn.Sequential(
            DevicePerceptron(n_device, d_init, d_inter_init, activation=activation),  # not sure if this is the best setup
            DevicePerceptron(n_device, d_inter_init, d_inter, activation=activation),
        )

        # After an exchange every device holds all n_device embeddings of width d_inter.
        # The Sequential wrapper is kept so state_dict keys stay unchanged.
        self.deep_layers = nn.ModuleList([
            nn.Sequential(
                MultiHeadDevicePerceptron(self.n_head, n_device, d_inter * n_device, d_inter,
                                          activation=activation),
            )
            for _ in range(repeat)
        ])

        self.classifier = nn.Sequential(
            DevicePerceptron(n_device, d_inter * n_device, 10, activation=activation, use_relu=False))

        self.message_passing = MessageConcat(max_num_elements=n_device)

        self.gossip_passing = MessageMean()

    def forward(self, x, edge_index):
        # Flatten each device's input slice: [B, P, ...] -> [B, P, features].
        x = x.view(x.shape[0], x.shape[1], -1)
        x = self.init_layers(x).contiguous()
        x = self.message_passing(x, edge_index)

        for layer in self.deep_layers:
            x = layer(x).contiguous()
            x = self.message_passing(x, edge_index)

        # NOTE(review): MVFL.forward applies log_softmax, but this returns raw classifier
        # outputs; confirm the training loss expects logits for this model.
        return self.classifier(x)

    def gossip_gm(self, x, edge_index):
        """log_softmax, then n_gossip rounds of neighbour-averaging with renormalisation."""
        x = nn.functional.log_softmax(x, dim=-1)
        for _ in range(self.n_gossip):
            x = self.gossip_passing(x, edge_index)
            x = x - torch.logsumexp(x, dim=-1, keepdim=True)
        return x

    def gossip_ugm(self, x, edge_index):
        """log_softmax, then n_gossip rounds of neighbour-averaging without renormalisation."""
        x = nn.functional.log_softmax(x, dim=-1)
        for _ in range(self.n_gossip):
            x = self.gossip_passing(x, edge_index)
        return x

    def gossip_logit(self, x, edge_index):
        """n_gossip rounds of neighbour-averaging on raw outputs, then log_softmax."""
        x = x.contiguous()
        for _ in range(self.n_gossip):
            x = self.gossip_passing(x, edge_index)
        return nn.functional.log_softmax(x, dim=-1)

    def gossip_ugmrm(self, x, edge_index):
        """log_softmax, n_gossip unnormalised averaging rounds, then a single renormalisation."""
        x = nn.functional.log_softmax(x, dim=-1)
        for _ in range(self.n_gossip):
            x = self.gossip_passing(x, edge_index)
        return x - torch.logsumexp(x, dim=-1, keepdim=True)

class MVFL(nn.Module):
    """
    MVFL model with a single embedding exchange.

    Every device embeds its own input slice with two ``DevicePerceptron`` layers. The
    embeddings are exchanged once via ``MessageConcat``, and a two-layer per-device
    classifier maps the concatenated embeddings to 10 outputs. ``forward`` returns
    ``log_softmax`` over the last dimension. Gossiping is not applied in ``forward``;
    it is available through the ``gossip_*`` methods (one of which is bound to
    ``self.gossip_layer``).

    Args:
        n_device: number of devices.
        d_inter_init: width of the first per-device hidden layer.
        d_inter: width of the per-device embedding that is exchanged.
        repeat: accepted but currently unused (the repeat logic is commented out).
        gossip_mode: ``'gm'``, ``'ugm'``, ``'logit'`` or ``'ugmrm'``.
        n_gossip: number of gossip rounds used by the ``gossip_*`` methods.
        image_shape: ``(H, W, C)``; each device gets ``int(H*W*C / n_device)`` inputs,
            e.g. 28x28x1 over 16 devices gives 49 (a 7x7 patch).
        activation: ``'relu'`` or ``'leakyrelu'``.

    Raises:
        ValueError: if ``gossip_mode`` is not supported.
    """

    def __init__(self, n_device, d_inter_init=16, d_inter=4, repeat=0, gossip_mode='gm', n_gossip=0,
                 image_shape=(28, 28, 1), activation='relu'):
        super().__init__()

        pixels = image_shape[0] * image_shape[1] * image_shape[2]
        d_init = int(pixels / n_device)

        self.gossip_mode = gossip_mode
        candidates = (
            ('gm', self.gossip_gm),
            ('ugm', self.gossip_ugm),
            ('logit', self.gossip_logit),
            ('ugmrm', self.gossip_ugmrm),
        )
        chosen = next((fn for mode, fn in candidates if gossip_mode == mode), None)
        if chosen is None:
            raise ValueError('gossip mode not supported')
        self.gossip_layer = chosen
        self.n_gossip = n_gossip

        exchanged_width = d_inter * n_device
        self.init_layers = nn.Sequential(
            DevicePerceptron(n_device, d_init, d_inter_init, activation=activation),
            DevicePerceptron(n_device, d_inter_init, d_inter, activation=activation),
        )
        self.classifier = nn.Sequential(
            DevicePerceptron(n_device, exchanged_width, exchanged_width, activation=activation),
            DevicePerceptron(n_device, exchanged_width, 10, activation=activation, use_relu=False),
        )

        self.message_passing = MessageConcat(max_num_elements=n_device)

        self.gossip_passing = MessageMean()

    def forward(self, x, edge_index):
        batch, parts = x.shape[0], x.shape[1]
        patches = x.view(batch, parts, -1)
        embedded = self.init_layers(patches).contiguous()
        exchanged = self.message_passing(embedded, edge_index)
        logits = self.classifier(exchanged)
        # Gossip is deliberately not applied here so it can be used only at inference.
        return nn.functional.log_softmax(logits, dim=-1)

    def _gossip_rounds(self):
        """Iterate once per gossip round; yields nothing when n_gossip is not positive."""
        return range(self.n_gossip) if self.n_gossip > 0 else range(0)

    def gossip_gm(self, x, edge_index):
        """log_softmax, then averaging rounds each followed by renormalisation."""
        log_probs = nn.functional.log_softmax(x, dim=-1)
        for _ in self._gossip_rounds():
            mixed = self.gossip_passing(log_probs, edge_index)
            log_probs = mixed - torch.logsumexp(mixed, dim=-1, keepdim=True)
        return log_probs

    def gossip_ugm(self, x, edge_index):
        """log_softmax, then averaging rounds without renormalisation."""
        log_probs = nn.functional.log_softmax(x, dim=-1)
        for _ in self._gossip_rounds():
            log_probs = self.gossip_passing(log_probs, edge_index)
        return log_probs

    def gossip_ugmrm(self, x, edge_index):
        """log_softmax, unnormalised averaging rounds, then one final renormalisation."""
        log_probs = nn.functional.log_softmax(x, dim=-1)
        for _ in self._gossip_rounds():
            log_probs = self.gossip_passing(log_probs, edge_index)
        return log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True)

    def gossip_logit(self, x, edge_index):
        """Averaging rounds on the raw outputs, then log_softmax."""
        raw = x.contiguous()
        for _ in self._gossip_rounds():
            raw = self.gossip_passing(raw, edge_index)
        return nn.functional.log_softmax(raw, dim=-1)

class VFL(MVFL):
    """
    VFL model.

    Reuses the MVFL architecture unchanged, with ``repeat=0`` and no gossip rounds.
    Per the original author's note, the difference from MVFL is controlled entirely
    by the ``edge_index`` passed to ``forward``, not by the model code.
    """

    def __init__(self, n_device, d_inter_init, d_inter, image_shape, activation):
        # n_gossip is a round count: pass 0 rather than the boolean False.
        super().__init__(n_device=n_device, d_inter_init=d_inter_init, d_inter=d_inter,
                         repeat=0, n_gossip=0, image_shape=image_shape,
                         activation=activation)



def to_dense_batch(
        x: Tensor,
        batch: Optional[Tensor] = None,
        fill_value: float = 0.0,
        max_num_nodes: Optional[int] = None,
        batch_size: Optional[int] = None,
        source_index: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    """Scatter rows of ``x`` into a dense ``[batch_size, max_num_nodes, *x.shape[1:]]`` tensor.

    Unlike ``torch_geometric.utils.to_dense_batch``, which places rows by their position
    inside each group, row ``i`` is written to
    ``out[batch[i], source_index[i] % max_num_nodes]``. If two rows map to the same
    slot, only one of them survives (which one is unspecified for indexed assignment
    with duplicate indices).

    Args:
        x: rows to scatter.
        batch: group id per row; defaults to all zeros.
        fill_value: value for slots that receive no row.
        max_num_nodes: number of slots per group (also the modulus for source ids).
        batch_size: number of groups; defaults to ``batch.max() + 1`` (0 for empty input).
        source_index: id per row that selects its slot.

    Returns:
        ``(out, mask)`` where ``mask[g, s]`` is True iff slot ``s`` of group ``g`` was written.
        When both ``batch`` and ``max_num_nodes`` are None, returns ``x.unsqueeze(0)`` and
        an all-True mask of shape ``[1, x.size(0)]``.

    Raises:
        ValueError: if ``max_num_nodes`` or ``source_index`` is missing outside the
            trivial case above.
    """
    if batch is None and max_num_nodes is None:
        mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device)
        return x.unsqueeze(0), mask

    if max_num_nodes is None:
        raise ValueError("to_dense_batch requires 'max_num_nodes' when 'batch' is given")
    if source_index is None:
        raise ValueError("to_dense_batch requires 'source_index' to choose each row's slot")

    if batch is None:
        batch = x.new_zeros(x.size(0), dtype=torch.long)

    if batch_size is None:
        batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 0

    # Flat position in a (batch_size * max_num_nodes) buffer: group offset + slot.
    idx = source_index % max_num_nodes + batch * max_num_nodes

    feature_shape = list(x.size())[1:]
    out = x.new_full([batch_size * max_num_nodes] + feature_shape, fill_value)
    out[idx] = x
    out = out.view([batch_size, max_num_nodes] + feature_shape)

    mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool, device=x.device)
    mask[idx] = True
    mask = mask.view(batch_size, max_num_nodes)

    return out, mask

