from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import DenseGINConv as DenseGINConv_orig

from peagang.data.dense.utils.features import stable_sym_eigen
from peagang.utils.utils import kcycles
from peagang.models.components.utilities_classes import SpectralNorm, SpectralNormNonDiff, Swish, NodeFeatNorm, LinearTransmissionLayer
from pdb import set_trace


class DenseGINConv(DenseGINConv_orig):
    """Dense GIN convolution with an optional additive "transmission" branch.

    On top of the parent ``DenseGINConv`` output, the layer can add
    ``transmission_layer(x)``: a ``LinearTransmissionLayer`` applied directly to
    the layer input. The widths of that branch are inferred from the first and
    the last weight-carrying layers of ``model``.

    :param model: iterable and reversible container of layers (for example a
        ``torch.nn.Sequential``) handed to the parent convolution.
    :param args: extra positional arguments forwarded to the parent constructor.
    :param transmission_layer: build the transmission branch when truthy.
    :param spectral_norm: ``"diff"`` wraps the transmission branch in
        ``SpectralNorm``, ``"nondiff"`` in ``SpectralNormNonDiff``; any other
        value leaves the branch unwrapped.
    :param kwargs: extra keyword arguments forwarded to the parent constructor.
    :raises ValueError: if the transmission branch is requested but no layer of
        ``model`` exposes a ``weight``/``weight_bar`` tensor with at least two
        dimensions.
    """

    def __init__(
        self, model, *args, transmission_layer=True, spectral_norm=None, **kwargs
    ):
        # Pass the remaining arguments on to the parent; they used to be
        # accepted here and then silently discarded.
        super().__init__(model, *args, **kwargs)

        # name="B" is handed to the wrapper as-is -- presumably the name of the
        # LinearTransmissionLayer parameter to normalise (TODO confirm).
        if spectral_norm == "diff":

            def sn(layer):
                return SpectralNorm(layer, name="B")

        elif spectral_norm == "nondiff":

            def sn(layer):
                return SpectralNormNonDiff(layer, name="B")

        else:

            def sn(layer):
                return layer

        if transmission_layer:
            # Input width: column count of the first weight in the model.
            c_in = self._infer_width(model, axis=1)
            # Output width: row count of the last weight in the model.
            c_out = self._infer_width(reversed(model), axis=0)
            if c_in is None or c_out is None:
                raise ValueError(
                    "Cannot infer the transmission layer size: `model` contains "
                    "no layer with a 2-D `weight` or `weight_bar`."
                )
            self.transmission_layer = sn(LinearTransmissionLayer(c_in, c_out))
        else:
            self.transmission_layer = None

    @staticmethod
    def _infer_width(layers, axis):
        """Return ``shape[axis]`` of the first usable weight found in ``layers``.

        Spectral-norm wrappers are unwrapped through their ``module`` attribute.
        A layer is usable when it has a ``weight`` (or, failing that, a
        ``weight_bar``) tensor with at least two dimensions. Returns ``None``
        when no layer qualifies.
        """
        for layer in layers:
            if isinstance(layer, (SpectralNorm, SpectralNormNonDiff)):
                layer = layer.module
            for attr in ("weight", "weight_bar"):
                weight = getattr(layer, attr, None)
                if torch.is_tensor(weight) and weight.dim() >= 2:
                    return weight.shape[axis]
        return None

    def forward(self, x, adj, mask=None, add_loop=True):
        """Run the parent convolution and add the transmission branch, if any."""
        out = super().forward(x, adj, mask=mask, add_loop=add_loop)
        if self.transmission_layer is not None:
            out = out + self.transmission_layer(x)
        return out

    def extra_repr(self) -> str:
        """Extend the parent's representation with the transmission branch."""
        return f"{super().extra_repr()},\n transmission_layer={str(self.transmission_layer)}"


class DenseGIN(torch.nn.Module):
    """Stack of ``DenseGINConv`` layers with concatenating connectivity.

    Every layer receives the concatenation (along the last axis) of the
    original input and the outputs of all previous layers; the module returns
    the concatenation of the input and every layer output.

    :param channels: ``channels[0]`` is the input width and ``channels[i]`` the
        output width of layer ``i``. Only used when ``models`` is ``None``.
    :param gin_hidden_width: hidden width of each automatically built MLP.
    :param models: optional pre-built update networks, one per layer.
    :param act: activation class instantiated inside each automatically built MLP.
    :param spectral_norm: ``"diff"`` / ``"nondiff"`` select ``SpectralNorm`` /
        ``SpectralNormNonDiff`` for the linear layers; anything else disables it.
    :param dropout: dropout probability; falsy values disable dropout.
    """

    def __init__(
        self,
        channels: List[int],
        gin_hidden_width=32,
        models=None,
        act=torch.nn.ReLU,
        spectral_norm=None,
        dropout=None,
    ):
        super().__init__()

        def _identity(layer):
            return layer

        # Start from "no normalisation" and override for the known modes.
        sn = _identity
        if spectral_norm == "diff":
            sn = SpectralNorm
        elif spectral_norm == "nondiff":
            sn = SpectralNormNonDiff

        self.dropout = torch.nn.Dropout(dropout) if dropout else None

        if models is None:
            models = []
            width_so_far = 0
            for previous, c_o in zip(channels, channels[1:]):
                # A layer sees the input plus all earlier outputs concatenated.
                width_so_far += previous
                models.append(
                    torch.nn.Sequential(
                        NodeFeatNorm(width_so_far),
                        sn(torch.nn.Linear(width_so_far, gin_hidden_width)),
                        act(),
                        sn(torch.nn.Linear(gin_hidden_width, c_o)),
                    )
                )

        # All update networks exist before any of them is wrapped.
        self.gins = torch.nn.ModuleList(
            [DenseGINConv(net, spectral_norm=spectral_norm) for net in models]
        )

    def forward(self, x, A):
        """
        Apply the layers, feeding each one everything computed so far.

        :param x: B N F
        :param A: adjacency passed unchanged to every layer
        :return: concatenation of ``x`` and all layer outputs along the last axis
        """
        collected = [x]
        for layer in self.gins:
            # aggregate node features
            stacked = torch.cat(collected, dim=-1)
            if self.dropout is not None:
                stacked = self.dropout(stacked)
            collected.append(layer(stacked, A))
        return torch.cat(collected, dim=-1)


class GCNSkip(torch.nn.Module):
    """Pre-activation graph convolution with a skip connection.

    The input is instance-normalised, passed through the activation and a
    ``DenseGINConv``, then added to a skip path. The skip path is the input
    itself, or a single-linear-layer ``DenseGINConv`` projection when
    ``in_channels`` and ``out_channels`` differ.

    :param in_channels: width of the incoming node features.
    :param out_channels: width of the produced node features.
    :param gin_hidden_width: hidden width of the default update MLP.
    :param model: optional pre-built update network replacing the default MLP.
    :param swish: use ``Swish`` instead of ``ReLU`` as activation.
    :param spectral_norm: ``"diff"`` / ``"nondiff"`` select ``SpectralNorm`` /
        ``SpectralNormNonDiff``; anything else disables it.
    :param dropout: dropout probability; falsy values disable dropout.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        gin_hidden_width=32,
        model=None,
        swish=False,
        spectral_norm=None,
        dropout=None,
    ):
        super().__init__()
        self.norm = torch.nn.InstanceNorm1d(in_channels)

        act = Swish if swish else torch.nn.ReLU

        def _identity(layer):
            return layer

        sn = _identity
        if spectral_norm == "diff":
            sn = SpectralNorm
        elif spectral_norm == "nondiff":
            sn = SpectralNormNonDiff

        # TODO/Note: Before we were doing tanh on the output of the GCN instead
        self.act = act()

        if model is None:
            model = torch.nn.Sequential(
                sn(torch.nn.Linear(in_channels, gin_hidden_width)),
                act(),
                sn(torch.nn.Linear(gin_hidden_width, out_channels)),
            )
        self.gcn = DenseGINConv(model, spectral_norm=spectral_norm)

        self.dropout = torch.nn.Dropout(dropout) if dropout else None

        if in_channels == out_channels:
            # Widths already match, so the input can be added as-is.
            self.proj = None
        else:
            projection = torch.nn.Sequential(
                sn(torch.nn.Linear(in_channels, out_channels))
            )
            self.proj = DenseGINConv(projection, spectral_norm=spectral_norm)

    def forward(self, x, A):
        """Return ``gcn(act(norm(x)), A)`` plus the (possibly projected) input."""
        if self.dropout is not None:
            x = self.dropout(x)

        shortcut = x if self.proj is None else self.proj(x, A)

        # Swap the last two axes around the norm and swap them back afterwards.
        normed = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1)
        activated = self.act(normed)
        return self.gcn(activated, A) + shortcut


class GCNSkipBlock(torch.nn.Module):
    """Sequence of ``GCNSkip`` layers applied one after another.

    The first layer maps ``in_channels`` to ``out_channels``; each of the
    remaining ``layers - 1`` layers maps ``out_channels`` to ``out_channels``.

    :param in_channels: width of the incoming node features.
    :param out_channels: width produced by every layer of the block.
    :param layers: total number of ``GCNSkip`` layers (at least one is built).
    :param swish: forwarded to every ``GCNSkip``.
    :param spectral_norm: forwarded to every ``GCNSkip``.
    :param dropout: forwarded to every ``GCNSkip``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        layers=1,
        swish=False,
        spectral_norm=None,
        dropout=None,
    ):
        # TODO rediscuss this, need to take a look at norm vs conv and variable number of nodes
        super().__init__()
        # Input width of each layer: only the first one differs.
        input_widths = [in_channels] + [out_channels] * (layers - 1)
        self.skip_layers = torch.nn.ModuleList(
            [
                GCNSkip(
                    width,
                    out_channels,
                    swish=swish,
                    spectral_norm=spectral_norm,
                    dropout=dropout,
                )
                for width in input_widths
            ]
        )

    def forward(self, x, A):
        """Feed ``x`` through every skip layer in order, reusing ``A`` each time."""
        hidden = x
        for skip_layer in self.skip_layers:
            hidden = skip_layer(hidden, A)
        return hidden


class DiscriminatorReadout(torch.nn.Module):
    """Graph-level readout: sum node features, append extras, apply a 2-layer MLP.

    :param node_features: width of the incoming node features.
    :param num_hidden: width of the hidden linear layer.
    :param out_features: width of the output.
    :param kc_flag: append the result of ``kcycles().k_cycles(adj)``; the hidden
        layer is sized for 4 extra features.
    :param swish: use ``Swish`` instead of ``ReLU``.
    :param spectral_norm: ``"diff"`` / ``"nondiff"`` wrap both linear layers in
        ``SpectralNorm`` / ``SpectralNormNonDiff``; anything else leaves them bare.
    :param dropout: dropout probability; falsy values disable dropout.
    :param eigenfeat4: append the first 4 columns of ``stable_sym_eigen(adj)``.
    """

    def __init__(self, node_features, num_hidden, out_features, kc_flag=True, swish=False, spectral_norm=None,
                 dropout=None, eigenfeat4=False):
        super().__init__()
        self.kc_flag = kc_flag
        self.eigenfeat4 = eigenfeat4
        if self.kc_flag:
            self.kcycles = kcycles()

        # Each enabled extra feature group widens the hidden layer's input by 4.
        extra_feats = (4 if self.kc_flag else 0) + (4 if self.eigenfeat4 else 0)
        self.hidden_lin = torch.nn.Linear(node_features + extra_feats, num_hidden)

        self.act = Swish() if swish else torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(dropout) if dropout else None
        self.out_lin = torch.nn.Linear(num_hidden, out_features)

        wrapper = None
        if spectral_norm == "diff":
            wrapper = SpectralNorm
        elif spectral_norm == "nondiff":
            wrapper = SpectralNormNonDiff
        if wrapper is not None:
            self.hidden_lin = wrapper(self.hidden_lin)
            self.out_lin = wrapper(self.out_lin)

    def forward(self, X, adj):
        """Pool ``X`` over its second-to-last axis and map the result to the output."""
        if self.dropout is not None:
            X = self.dropout(X)

        pieces = [X.sum(dim=-2)]
        if self.kc_flag:
            pieces.append(self.kcycles.k_cycles(adj))
        if self.eigenfeat4:
            pieces.append(stable_sym_eigen(adj)[:, :4])
        pooled = pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=-1)

        hidden = self.act(self.hidden_lin(pooled))
        if self.dropout is not None:
            hidden = self.dropout(hidden)
        return self.out_lin(hidden)


class DensenetDiscriminator(nn.Module):
    """Discriminator made of a ``DenseGIN`` trunk followed by a ``DiscriminatorReadout``.

    :param node_feature_dim: node feature width before the extra ``+1`` feature.
    :param conv_channels: list of output widths, one per trunk layer.
    :param readout_hidden: hidden width of the readout.
    :param swish: use ``Swish`` in the trunk and the readout instead of ``ReLU``.
    :param spectral_norm: forwarded to the trunk and the readout.
    :param dropout: forwarded to the trunk and the readout.
    :param kc_flag: forwarded to the readout.
    """

    def __init__(
        self,
        node_feature_dim,
        conv_channels,
        readout_hidden=32,
        swish=True,
        spectral_norm=None,
        dropout=None,
        kc_flag=True,
    ):
        super().__init__()

        self.swish = swish

        # Convolutional layers
        # +1 for num_nodes
        all_channels = [node_feature_dim + 1] + conv_channels

        trunk_options = {
            "channels": all_channels,
            "spectral_norm": spectral_norm,
            "dropout": dropout,
        }
        if swish:
            # Without this entry DenseGIN falls back to its own default activation.
            trunk_options["act"] = Swish
        self.trunk = DenseGIN(**trunk_options)

        # The trunk concatenates its input with every layer output.
        self.read_out = DiscriminatorReadout(
            sum(all_channels),
            readout_hidden,
            out_features=1,
            kc_flag=kc_flag,
            swish=swish,
            spectral_norm=spectral_norm,
            dropout=dropout,
        )

    def forward(self, x, adj):
        """Score the graph(s) given node features ``x`` and adjacency ``adj``."""
        if x.dim() < 3:
            # Add the missing leading batch axis to both inputs
            # (the original note says InstanceNorm1D errors without it).
            x, adj = x.unsqueeze(0), adj.unsqueeze(0)

        features = self.trunk(x, adj)
        return self.read_out(features, adj)


class Discriminator(nn.Module):
    """GCN encoder with residual connections.

    A chain of single-layer ``GCNSkipBlock`` modules; the outputs of all blocks
    are concatenated along the last axis and scored by a ``DiscriminatorReadout``.

    :param node_feature_dim: node feature width before the extra ``+1`` feature.
    :param conv_channels: output width of each block, in order.
    :param readout_hidden: hidden width of the readout.
    :param swish: forwarded to the blocks and the readout.
    :param spectral_norm: forwarded to the blocks and the readout.
    :param dropout: forwarded to the blocks and the readout.
    :param kc_flag: forwarded to the readout.
    :param eigenfeat4: forwarded to the readout.
    """

    def __init__(
        self,
        node_feature_dim,
        conv_channels,
        readout_hidden=64,
        swish=False,
        spectral_norm=None,
        dropout=None,
        kc_flag=True,
        eigenfeat4=False
    ):
        super().__init__()

        # Convolutional layers
        self.gcn_layers = torch.nn.ModuleList()
        previous_width = node_feature_dim + 1  # +1 for num nodes
        for width in conv_channels:
            self.gcn_layers.append(
                GCNSkipBlock(
                    in_channels=previous_width,
                    out_channels=width,
                    swish=swish,
                    spectral_norm=spectral_norm,
                    dropout=dropout,
                )
            )
            # The next block consumes what this one produces.
            previous_width = width

        # Every block output is concatenated, so the readout sees the summed widths.
        self.read_out = DiscriminatorReadout(
            sum(conv_channels),
            readout_hidden,
            out_features=1,
            swish=swish,
            spectral_norm=spectral_norm,
            dropout=dropout,
            kc_flag=kc_flag,
            eigenfeat4=eigenfeat4,
        )

    def forward(self, x, adj):
        """Score the graph(s) given node features ``x`` and adjacency ``adj``."""
        if x.dim() < 3:
            # Add the missing leading batch axis to both inputs
            # (the original note says InstanceNorm1D errors without it).
            x, adj = x.unsqueeze(0), adj.unsqueeze(0)

        hidden = x
        block_outputs = []
        for block in self.gcn_layers:
            hidden = block(hidden, adj)
            block_outputs.append(hidden)

        return self.read_out(torch.cat(block_outputs, -1), adj)


class GraphConvolution(nn.Module):
    """Linear graph convolution computing ``adj @ W(h) + W(h)``.

    ``h`` is the node tensor, optionally prefixed (along the last axis) by a
    hidden tensor from a previous layer. An optional activation and dropout are
    applied to the result.

    :param input_dim: width of the concatenated annotations fed to the linear map.
    :param out_dim: width of the output.
    :param dropout_rate: probability used by the trailing dropout.
    """

    def __init__(self, input_dim, out_dim, dropout_rate=0.0):
        super().__init__()

        # Kept inside a Sequential so parameter names stay ``linear.0.*``.
        self.linear = nn.Sequential(nn.Linear(input_dim, out_dim))
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, inputs, activation=None):
        """
        :param inputs: tuple ``(adjacency_tensor, hidden_tensor, node_tensor)``;
            ``hidden_tensor`` may be ``None``.
        :param activation: optional callable applied before dropout.
        :return: the convolved (and optionally activated) features after dropout.
        """
        adjacency_tensor, hidden_tensor, node_tensor = inputs

        if hidden_tensor is not None:
            annotations = torch.cat((hidden_tensor, node_tensor), -1)
        else:
            annotations = node_tensor

        # The projection feeds both the neighbour term and the self term, so
        # compute it once instead of running the linear layer twice.
        projected = self.linear(annotations)
        output = torch.matmul(adjacency_tensor, projected) + projected

        if activation is not None:
            output = activation(output)
        return self.dropout(output)


class GraphAggregation(nn.Module):
    """Gated sum-pooling of node annotations into one vector per graph.

    Two parallel linear maps (one followed by a sigmoid, one by a tanh) are
    multiplied element-wise and summed over axis 1.

    :param in_features: first part of the annotation width.
    :param out_features: width of the aggregated output.
    :param n_dim: second part of the annotation width; both linear maps take
        ``in_features + n_dim`` inputs.
    :param dropout_rate: probability used by the trailing dropout.
    """

    def __init__(self, in_features, out_features, n_dim, dropout_rate=0):
        super().__init__()
        self.sigmoid_linear = nn.Sequential(
            nn.Linear(in_features + n_dim, out_features), nn.Sigmoid()
        )
        self.tanh_linear = nn.Sequential(
            nn.Linear(in_features + n_dim, out_features), nn.Tanh()
        )
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, inputs, activation=None):
        """
        :param inputs: annotations whose last axis has ``in_features + n_dim`` entries.
        :param activation: optional callable applied to the pooled vector; it
            defaults to ``None`` (no activation), matching ``GraphConvolution.forward``.
        :return: pooled features after the optional activation and dropout.
        """
        gate = self.sigmoid_linear(inputs)  # (..., out_features)
        value = self.tanh_linear(inputs)  # (..., out_features)
        output = torch.sum(torch.mul(gate, value), 1)  # axis 1 is summed away
        if activation is not None:
            output = activation(output)
        return self.dropout(output)


class MolGAN_Discriminator(nn.Module):
    """Discriminator built from two ``GraphConvolution`` layers and a ``GraphAggregation``.

    The pipeline is fixed: graph convolutions ``node_feature_dim -> 128`` and
    ``128 + node_feature_dim -> 64``, an aggregation to 128 features, a
    ``128 -> 64`` linear layer with tanh, and a final ``64 -> 1`` linear layer.

    :param node_feature_dim: width of the node features.
    :param conv_channels: accepted but not used by this class.
    :param readout_hidden: accepted but not used by this class.
    :param swish: accepted but not used by this class.
    :param spectral_norm: accepted but not used by this class.
    """

    def __init__(
        self,
        node_feature_dim,
        conv_channels=None,
        readout_hidden=64,
        swish=False,
        spectral_norm=None,
    ):
        super().__init__()

        auxiliary_dim = 128
        # (input width, output width) of each graph convolution; from the
        # second layer on the node features are concatenated to the hidden state.
        self.layers_ = [[node_feature_dim, 128], [128 + node_feature_dim, 64]]

        # Empty container; nothing in this class adds to it.
        self.bn = torch.nn.ModuleList()
        self.gcn_layers = torch.nn.ModuleList(
            [GraphConvolution(in_dim, out_dim) for in_dim, out_dim in self.layers_]
        )

        self.agg_layer = GraphAggregation(64, auxiliary_dim, node_feature_dim)

        # Dense layer [128 -> 64] followed by tanh
        self.linear_layer = nn.Sequential(nn.Linear(auxiliary_dim, 64), nn.Tanh())

        # Linear map [64 -> 1]
        self.output_layer = nn.Linear(64, 1)

    def forward(self, node, adj):
        """Return the discriminator output for node features ``node`` and adjacency ``adj``."""
        hidden = None
        for gcn in self.gcn_layers:
            hidden = gcn(inputs=(adj, hidden, node))

        annotations = torch.cat((hidden, node), -1)
        pooled = self.agg_layer(annotations, torch.tanh)
        return self.output_layer(self.linear_layer(pooled))
