from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np

from greatx.nn.layers import Sequential, activations
from greatx.utils import wrapper
from greatx.functional import spmm

from torch_geometric.utils import add_self_loops
from torch_sparse import SparseTensor, fill_diag


#######################
# Orthonormalization
#######################
def orthonormalize_weights(w, beta=0.5, iters=20, order=3):
    """Iteratively orthonormalize ``w`` with Björck-style polynomial updates.

    Each step forms the Gram matrix ``w.t() @ w`` and applies a truncated
    series update that pushes the columns of ``w`` toward an orthonormal
    set (assumes ``w`` is tall or square and suitably pre-scaled so the
    iteration converges — TODO confirm for other inputs).

    Args:
        w: 2-D tensor to orthonormalize.
        beta: step coefficient; only used by ``order == 1``. Orders 2 and 3
            hard-code their coefficients for ``beta = 0.5``.
        iters: number of update steps to run.
        order: truncation order of the update polynomial (1, 2 or 3).

    Returns:
        The updated tensor, same shape as ``w``.

    Raises:
        NotImplementedError: if ``order`` is not 1, 2 or 3.
        ValueError: if ``order`` is 2 or 3 and ``beta != 0.5``.
    """
    # Validate everything up front so no work is done on bad arguments.
    if order not in (1, 2, 3):
        raise NotImplementedError("Only order 1-3 supported.")
    if order != 1 and beta != 0.5:
        raise ValueError("Order >1 requires beta = 0.5")

    for _ in range(iters):
        gram = w.t().mm(w)
        if order == 1:
            w = (1 + beta) * w - beta * w.mm(gram)
        elif order == 2:
            gram_sq = gram.mm(gram)
            w = ((15 / 8) * w
                 - (5 / 4) * w.mm(gram)
                 + (3 / 8) * w.mm(gram_sq))
        else:
            gram_sq = gram.mm(gram)
            gram_cube = gram.mm(gram_sq)
            w = ((35 / 16) * w
                 - (35 / 16) * w.mm(gram)
                 + (21 / 16) * w.mm(gram_sq)
                 - (5 / 16) * w.mm(gram_cube))

    return w


def scale_values(weight):
    """Return ``sqrt(rows * cols)`` of ``weight`` as a 0-dim tensor.

    The result shares ``weight``'s dtype and device, so the caller can
    divide the full weight matrix by it directly.
    """
    rows, cols = weight.shape[0], weight.shape[1]
    value = np.sqrt(rows * cols)
    return torch.tensor(value, dtype=weight.dtype, device=weight.device)


#######################
# Graph Utilities
#######################
def dense_add_self_loops(adj: Tensor, fill_value: float = 1.0) -> Tensor:
    """Return ``adj`` with ``fill_value`` added to each diagonal entry.

    The diagonal is added to (not overwritten), and a new tensor is
    returned; ``adj`` itself is left unmodified. ``adj`` is expected to be
    a square ``(N, N)`` matrix.
    """
    num_nodes = adj.size(0)
    loops = adj.new_full((num_nodes, ), fill_value).diag()
    return adj + loops


def dense_gcn_norm(adj: Tensor, improved: bool = False,
                   add_self_loops: bool = True, rate: float = -0.5) -> Tensor:
    """Degree-normalize a dense adjacency matrix, GCN style.

    With ``d`` the row sums of the (optionally self-loop-augmented)
    matrix, entry ``(i, j)`` of the result is ``d[i]**rate * A[i, j] *
    d[j]**rate``. Rows whose scaling factor would be ``+inf`` (zero degree
    with a negative ``rate``) get a factor of 0 instead.

    Args:
        adj: dense ``(N, N)`` adjacency matrix.
        improved: if True, self-loops get weight 2 instead of 1.
        add_self_loops: whether to add self-loops before normalizing.
        rate: exponent applied to the degrees (``-0.5`` = symmetric norm).

    Returns:
        A new normalized ``(N, N)`` tensor.
    """
    if add_self_loops:
        loop_weight = 2. if improved else 1.
        adj = dense_add_self_loops(adj, loop_weight)

    # ``sum`` allocates a fresh tensor, so the in-place pow is safe.
    factor = adj.sum(dim=1).pow_(rate)
    factor.masked_fill_(factor == float('inf'), 0.)
    return factor.view(1, -1) * adj * factor.view(-1, 1)


def make_self_loops(edge_index, edge_weight=None, num_nodes=None,
                    fill_value=1.0, improved=False):
    """Add self-loops to a graph stored in any supported adjacency format.

    Dispatches on the representation of ``edge_index``:

    * ``torch.long`` tensor (edge list): forwarded, with ``edge_weight``,
      ``fill_value`` and ``num_nodes``, to PyG's ``add_self_loops``.
    * floating-point tensor (dense adjacency, any float dtype):
      ``fill_value`` is added to the diagonal via ``dense_add_self_loops``.
    * ``SparseTensor``: passed to ``torch_sparse.fill_diag`` with
      ``fill_value``.

    Args:
        edge_index: edge list, dense adjacency matrix, or ``SparseTensor``.
        edge_weight: optional edge weights; only consumed by the edge-list
            branch and returned unchanged for the other formats.
        num_nodes: number of nodes, forwarded to PyG's ``add_self_loops``.
        fill_value: weight assigned to each self-loop.
        improved: if True, overrides ``fill_value`` with 2.0 (the
            "improved" GCN self-loop weight).

    Returns:
        ``(edge_index, edge_weight)`` in the same format as the input.

    Raises:
        ValueError: if ``edge_index`` is not one of the supported formats.
    """
    # Previously ``fill_value`` was unconditionally overwritten here, so a
    # caller-supplied value was silently ignored. Only ``improved`` may
    # override it now.
    if improved:
        fill_value = 2.

    if isinstance(edge_index, Tensor) and edge_index.dtype == torch.long:
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 fill_value=fill_value,
                                                 num_nodes=num_nodes)
    elif isinstance(edge_index, Tensor) and edge_index.is_floating_point():
        # Accept any float dtype (float16/32/64, bfloat16), not just float32.
        edge_index = dense_add_self_loops(edge_index, fill_value)
    elif isinstance(edge_index, SparseTensor):
        edge_index = fill_diag(edge_index, fill_value)
    else:
        raise ValueError(f"Type {type(edge_index)} is not supported.")
    return edge_index, edge_weight


def make_gcn_norm(edge_index, edge_weight=None, num_nodes=None,
                  add_self_loops=True, dtype=None):
    """Apply GCN degree normalization to a graph in any supported format.

    Dispatches on the representation of ``edge_index``:

    * ``torch.long`` tensor (edge list): PyG's ``gcn_norm``, which returns
      the new edge list together with the normalized edge weights.
    * floating-point tensor (dense adjacency, any float dtype):
      ``dense_gcn_norm``; ``edge_weight`` is returned unchanged.
    * ``SparseTensor``: PyG's ``gcn_norm``; its single ``SparseTensor``
      result replaces ``edge_index`` and ``edge_weight`` is returned
      unchanged.

    Args:
        edge_index: edge list, dense adjacency matrix, or ``SparseTensor``.
        edge_weight: optional edge weights (consumed by the edge-list branch).
        num_nodes: number of nodes, forwarded to ``gcn_norm``.
        add_self_loops: whether the normalization should add self-loops.
        dtype: forwarded to ``gcn_norm``; unused for dense input.

    Returns:
        ``(edge_index, edge_weight)`` in the same format as the input.

    Raises:
        ValueError: if ``edge_index`` is not one of the supported formats.
    """
    # Imported inside the function (reason not visible here; possibly to
    # avoid an eager PyG import or an import cycle).
    from torch_geometric.nn.conv.gcn_conv import gcn_norm
    if isinstance(edge_index, Tensor) and edge_index.dtype == torch.long:
        edge_index, edge_weight = gcn_norm(edge_index, edge_weight,
                                           num_nodes=num_nodes, improved=False,
                                           add_self_loops=add_self_loops,
                                           dtype=dtype)
    elif isinstance(edge_index, Tensor) and edge_index.is_floating_point():
        # Accept any float dtype; dense_gcn_norm is dtype-agnostic.
        edge_index = dense_gcn_norm(edge_index, improved=False,
                                    add_self_loops=add_self_loops)
    elif isinstance(edge_index, SparseTensor):
        edge_index = gcn_norm(edge_index, num_nodes=num_nodes,
                              improved=False, add_self_loops=add_self_loops,
                              dtype=dtype)
    else:
        raise ValueError(f"Type {type(edge_index)} is not supported.")
    return edge_index, edge_weight


#######################
# GCORN Conv Layer
#######################
class GCORNConv(nn.Module):
    """Graph convolution whose weight is orthonormalized on every forward.

    Each forward pass:

    1. divides the raw weight by ``scale_values(weight)``;
    2. runs ``orthonormalize_weights`` on the transposed weight;
    3. applies the result as a linear map to the node features; and
    4. aggregates the transformed features with greatx's ``spmm`` over
       ``edge_index`` / ``edge_weight``.

    Args:
        in_channels: size of each input feature vector.
        out_channels: size of each output feature vector.
        beta: step coefficient forwarded to ``orthonormalize_weights``.
        iters: orthonormalization iterations per forward pass.
        order: expansion order forwarded to ``orthonormalize_weights``
            (orders 2 and 3 require ``beta == 0.5``).
        bias: if True, learn an additive bias of size ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 beta: float = 0.5, iters: int = 20, order: int = 2,
                 bias: bool = True):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.beta, self.iters, self.order = beta, iters, order

        # Raw weight laid out (out_channels, in_channels) like nn.Linear;
        # values are filled in by reset_parameters().
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Orthogonal weight init with gain ``1 / sqrt(in_channels)``; zero bias."""
        gain = 1. / np.sqrt(self.weight.size(1))
        nn.init.orthogonal_(self.weight, gain=gain)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x: Tensor, edge_index, edge_weight=None) -> Tensor:
        """Transform ``x`` with the orthonormalized weight, then aggregate.

        ``edge_index`` / ``edge_weight`` are handed to ``spmm`` as-is; any
        normalization or self-loops must already be applied by the caller.
        """
        raw = self.weight
        # Shrink by sqrt(out * in) first; presumably this keeps the
        # iteration inside its convergence region.
        shrunk = raw / scale_values(raw)
        # Orthonormalize the (in, out) view, then transpose back to the
        # (out, in) layout F.linear expects.
        ortho = orthonormalize_weights(shrunk.t(), beta=self.beta,
                                       iters=self.iters,
                                       order=self.order).t()
        out = spmm(F.linear(x, ortho), edge_index, edge_weight)

        if self.bias is not None:
            out += self.bias  # in-place on the freshly computed result
        return out


class GCORN(nn.Module):
    """Graph Convolutional Orthonormal Robust Network (GCORN).

    Encoder: for each ``(hid, act)`` pair, a ``GCORNConv`` layer, an
    optional ``BatchNorm1d``, the activation, and dropout. A final
    ``nn.Linear`` maps the last hidden size to ``out_channels``.
    """

    @wrapper
    def __init__(self, in_channels: int, out_channels: int,
                 hids: List[int] = [64], acts: List[str] = ['relu'],
                 dropout: float = 0.5, bias: bool = True,
                 bn: bool = False, normalize: bool = True,
                 beta: float = 0.5, iters: int = 20, order: int = 2):
        """Build the encoder and classifier.

        Args:
            in_channels: input feature size.
            out_channels: output size of the final linear classifier.
            hids: hidden size of each ``GCORNConv`` layer.
            acts: activation name per layer (resolved via
                ``activations.get``); must match ``hids`` in length.
            dropout: dropout probability after each activation.
            bias: whether ``GCORNConv`` layers learn a bias.
            bn: if True, insert ``BatchNorm1d`` after each ``GCORNConv``.
            normalize: if True (default), ``forward`` applies GCN degree
                normalization to the graph; self-loops are added either way.
            beta, iters, order: forwarded to every ``GCORNConv``.

        Raises:
            ValueError: if ``hids`` and ``acts`` differ in length.
        """
        # NOTE: list defaults are kept as-is because ``@wrapper`` may rely on
        # them (TODO confirm); they are only iterated, never mutated, here.
        super().__init__()

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(hids) != len(acts):
            raise ValueError(
                f"`hids` and `acts` must have the same length, "
                f"got {len(hids)} and {len(acts)}.")

        # Previously accepted but silently ignored; now honored in forward().
        self.normalize = normalize

        layers = []
        for hid, act in zip(hids, acts):
            layers.append(
                GCORNConv(in_channels, hid, beta=beta,
                          iters=iters, order=order, bias=bias))
            if bn:
                layers.append(nn.BatchNorm1d(hid))
            layers.append(activations.get(act))
            layers.append(nn.Dropout(dropout))
            in_channels = hid

        self.encoder = Sequential(*layers)

        # Final MLP classifier
        self.classifier = nn.Linear(in_channels, out_channels)

    def reset_parameters(self):
        """Re-initialize the encoder and the classifier."""
        self.encoder.reset_parameters()
        self.classifier.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Add self-loops, optionally normalize, encode, then classify."""
        num_nodes = x.size(0)
        edge_index, edge_weight = make_self_loops(edge_index, edge_weight,
                                                  num_nodes=num_nodes)
        if self.normalize:
            # Self-loops were added just above, so don't add them again.
            edge_index, edge_weight = make_gcn_norm(edge_index, edge_weight,
                                                    num_nodes=num_nodes,
                                                    dtype=x.dtype,
                                                    add_self_loops=False)
        x = self.encoder(x, edge_index, edge_weight)
        return self.classifier(x)
