from typing import List, Tuple
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from .propagation_ import OurPropagate

from .utils import MixedLinear, MixedDropout


class Id(nn.Module):
    """Propagation stand-in that performs no propagation.

    It is called the same way as a branch's ``propagation`` module, i.e.
    ``propagation(predictions, idx)``, and hands ``predictions`` back
    untouched. ``idx`` is ignored.
    """

    def forward(
            self, predictions: torch.Tensor, idx: torch.LongTensor
    ) -> torch.Tensor:
        # ``idx`` is accepted only so the call signature matches the real
        # propagation module; it plays no role here.
        unchanged = predictions
        return unchanged


class OurBranch(nn.Module):
    """A ``MixedLinear`` transform followed by a propagation module.

    Args:
        in_features: Input dimension of the transform.
        out_features: Output dimension of the transform.
        propagation: Module invoked as ``propagation(logits, idx)`` on the
            transformed features.
        bias: Whether the transform has a bias term.
    """

    def __init__(
            self, in_features: int, out_features: int, propagation: nn.Module,
            bias: bool = False):
        super().__init__()
        self.transform = MixedLinear(in_features, out_features, bias=bias)
        self.propagation = propagation

    def forward(
            self, attr_matrix: torch.Tensor, idx: torch.LongTensor
    ) -> torch.Tensor:
        """Transform ``attr_matrix`` and pass the result, with ``idx``, to ``propagation``."""
        return self.propagation(self.transform(attr_matrix), idx)


class OurLayer(nn.Module):
    """Sum of a local linear branch and a propagated linear branch.

    Both branches receive the same (optionally dropped-out) input:

    * ``branch_0`` applies its own ``MixedLinear`` transform and then
      :class:`Id`, so no propagation happens;
    * ``branch_m1`` applies its own ``MixedLinear`` transform followed by
      ``propagation``.

    The layer consumes and returns a ``(tensor, idx)`` tuple so it can be
    chained inside ``nn.Sequential``; ``idx`` is forwarded unchanged.

    Args:
        in_features: Input dimension of both branch transforms.
        out_features: Output dimension of both branch transforms.
        drop_prob: Dropout probability on the layer input; ``0.0`` disables it.
        propagation: Module used by ``branch_m1``, called as
            ``propagation(logits, idx)``.
        bias: Whether the branch transforms have a bias term.
    """

    def __init__(
            self, in_features: int, out_features: int, drop_prob: float,
            propagation: nn.Module, bias: bool = False):
        super().__init__()
        # Use nn.Identity rather than ``lambda x: x``: a lambda attribute makes
        # the module (and any model containing it) unpicklable, which breaks
        # ``torch.save(model)``. Identity has no parameters, so state_dict keys
        # are unchanged.
        if drop_prob == 0.0:
            self.dropout = nn.Identity()
        else:
            self.dropout = MixedDropout(drop_prob)
        self.branch_0 = OurBranch(
            in_features, out_features, Id(), bias=bias)
        self.branch_m1 = OurBranch(
            in_features, out_features, propagation, bias=bias)

    def forward(
            self, t_and_idx: Tuple[torch.Tensor, torch.LongTensor]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Apply dropout once, then return ``(branch_0 + branch_m1, idx)``."""
        t, idx = t_and_idx
        # A single dropout mask is shared by both branches.
        t = self.dropout(t)
        return self.branch_0(t, idx=idx) + self.branch_m1(t, idx=idx), idx


class ReLUWithIdx(nn.Module):
    """Apply ReLU to the tensor of a ``(tensor, idx)`` pair and pass ``idx`` through.

    This lets a plain activation sit between :class:`OurLayer` instances
    inside an ``nn.Sequential`` that threads the index tensor along.
    """

    def forward(
            self, t_and_idx: Tuple[torch.Tensor, torch.LongTensor]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        features, indices = t_and_idx
        activated = F.relu(features)
        return activated, indices


class ResolvNet(nn.Module):
    """Stack of :class:`OurLayer` blocks followed by a linear classifier.

    One ``OurPropagate`` module is built from ``adj_matrix`` and the
    propagation hyper-parameters. It is registered here as
    ``self.propagate`` and the same instance is handed to every hidden
    ``OurLayer``. Each hidden layer is followed by a ReLU.

    Args:
        nfeatures: Input feature dimension.
        nclasses: Number of output classes.
        hiddenunits: Width of each hidden layer, in order. May be empty; the
            classifier is then applied directly to the input features (which
            must then be something ``nn.Linear`` accepts).
        drop_prob: Dropout probability (``0.0`` disables dropout); also
            forwarded to ``OurPropagate``.
        adj_matrix: Adjacency matrix forwarded to ``OurPropagate``.
        omega: Forwarded to ``OurPropagate``.
        bias: Whether the hidden ``OurLayer`` transforms use a bias.
        normalizing_factor: Forwarded to ``OurPropagate``.
    """

    def __init__(
            self, nfeatures: int, nclasses: int, hiddenunits: List[int],
            drop_prob: float = 0.5, adj_matrix: sp.spmatrix = None,
            omega: float = -1.0, bias: bool = True, normalizing_factor: float = 1.0):
        super().__init__()
        self.propagate = OurPropagate(
            adj_matrix, omega=omega, drop_prob=drop_prob, normalizing_factor=normalizing_factor)
        sizes = [nfeatures] + list(hiddenunits)
        our_layers = []
        for s_in, s_out in zip(sizes[:-1], sizes[1:]):
            our_layers.append(OurLayer(
                s_in, s_out, drop_prob, self.propagate, bias=bias))
            our_layers.append(ReLUWithIdx())
        self.hidden_layers = nn.Sequential(*our_layers)
        # nn.Identity rather than a lambda keeps the model picklable
        # (``torch.save(model)``); it has no parameters, so state_dict is unchanged.
        if drop_prob == 0.0:
            self.output_dropout = nn.Identity()
        else:
            self.output_dropout = MixedDropout(drop_prob)
        # sizes[-1] equals hiddenunits[-1] when hidden layers exist and falls
        # back to nfeatures when hiddenunits is empty. Indexing hiddenunits
        # directly raised IndexError in that case, and also failed for
        # one-shot iterables already consumed by list().
        self.output_layer = nn.Linear(sizes[-1], nclasses, bias=True)
        self.reg_params = list(self.parameters())

    def forward(self, t: torch.Tensor, idx: torch.LongTensor) -> torch.Tensor:
        """Return log-class-probabilities for the rows selected by ``idx``.

        Args:
            t: Input feature matrix (presumably one row per node; confirm
                against callers).
            idx: Indices of the rows to classify. Also threaded through the
                hidden layers to the propagation module.

        Returns:
            ``log_softmax`` over the last dimension of the classifier output
            for the rows ``idx``.
        """
        hidden = self.hidden_layers((t, idx))[0]
        # Dropout is applied to the full hidden matrix, as before, so RNG
        # consumption is unchanged.
        hidden = self.output_dropout(hidden)
        # nn.Linear acts row-wise, so selecting the rows first gives the
        # same logits while running the classifier only on returned rows.
        logits = self.output_layer(hidden[idx])
        return F.log_softmax(logits, dim=-1)
