from collections import OrderedDict

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import spectral_norm
import numpy as np
from torch_geometric.nn import GCNConv, InnerProductDecoder
from torch_geometric.utils import negative_sampling

# Small constant added inside torch.log() in GAEHyper.recon_loss so that a
# predicted probability of exactly 0 (or 1 for negative edges) does not
# produce -inf.
EPS = 1e-15


class CNNHyper(nn.Module):
    """Hypernetwork that emits every parameter tensor of a ``CNNTarget``.

    A learned per-node embedding is run through an MLP trunk. One linear head
    per target tensor then produces that tensor's flattened values, which
    ``forward`` reshapes to the matching ``CNNTarget`` parameter shape.

    Args:
        n_nodes: Number of distinct node embeddings.
        embedding_dim: Width of each embedding vector.
        in_channels: Input channels of the target's first convolution.
        out_dim: Output size of the target's final linear layer.
        n_kernels: Filters in the target's first convolution (the second
            convolution has ``2 * n_kernels``).
        hidden_dim: Width of the trunk and input size of every head.
        spec_norm: Wrap every trunk and head ``nn.Linear`` in spectral norm.
        n_hidden: Number of extra ReLU + Linear stages in the trunk.
    """

    def __init__(
            self, n_nodes, embedding_dim, in_channels=3, out_dim=10, n_kernels=16, hidden_dim=100,
            spec_norm=False, n_hidden=1):
        super().__init__()

        self.in_channels = in_channels
        self.out_dim = out_dim
        self.n_kernels = n_kernels
        self.embeddings = nn.Embedding(num_embeddings=n_nodes, embedding_dim=embedding_dim)

        def trunk_linear(fan_in, fan_out):
            # Trunk layers are wrapped in spectral norm right after creation.
            layer = nn.Linear(fan_in, fan_out)
            return spectral_norm(layer) if spec_norm else layer

        trunk = [trunk_linear(embedding_dim, hidden_dim)]
        for _ in range(n_hidden):
            trunk += [nn.ReLU(inplace=True), trunk_linear(hidden_dim, hidden_dim)]
        self.mlp = nn.Sequential(*trunk)

        # Generator heads: (attribute name, flattened size of the target tensor).
        flat_conv2 = 2 * n_kernels * 5 * 5
        heads = (
            ("c1_weights", n_kernels * in_channels * 5 * 5),
            ("c1_bias", n_kernels),
            ("c2_weights", 2 * n_kernels * n_kernels * 5 * 5),
            ("c2_bias", 2 * n_kernels),
            ("l1_weights", 120 * flat_conv2),
            ("l1_bias", 120),
            ("l2_weights", 84 * 120),
            ("l2_bias", 84),
            ("l3_weights", out_dim * 84),
            ("l3_bias", out_dim),
        )
        # Every head is built before any is wrapped, so random initial values
        # are drawn in the same order as building and then normalizing them.
        for attr, size in heads:
            setattr(self, attr, nn.Linear(hidden_dim, size))
        if spec_norm:
            for attr, _ in heads:
                setattr(self, attr, spectral_norm(getattr(self, attr)))

    def forward(self, idx):
        """Generate ``CNNTarget`` parameters for node ``idx``.

        Returns an ``OrderedDict`` keyed by ``CNNTarget`` parameter names. The
        reshapes assume ``idx`` selects a single embedding.
        """
        features = self.mlp(self.embeddings(idx))
        k, c = self.n_kernels, self.in_channels
        layout = (
            ("conv1.weight", self.c1_weights, (k, c, 5, 5)),
            ("conv1.bias", self.c1_bias, (-1,)),
            ("conv2.weight", self.c2_weights, (2 * k, k, 5, 5)),
            ("conv2.bias", self.c2_bias, (-1,)),
            ("fc1.weight", self.l1_weights, (120, 2 * k * 5 * 5)),
            ("fc1.bias", self.l1_bias, (-1,)),
            ("fc2.weight", self.l2_weights, (84, 120)),
            ("fc2.bias", self.l2_bias, (-1,)),
            ("fc3.weight", self.l3_weights, (self.out_dim, 84)),
            ("fc3.bias", self.l3_bias, (-1,)),
        )
        return OrderedDict((name, head(features).view(shape)) for name, head, shape in layout)


class CNNTarget(nn.Module):
    """LeNet-style CNN whose parameter names/shapes match ``CNNHyper`` output.

    Two stages of 5x5 convolution, ReLU and 2x2 max-pooling are followed by
    three linear layers (120 -> 84 -> ``out_dim``). ``fc1`` expects
    ``2 * n_kernels`` feature maps of 5x5 after the second pool, which is what
    a 32x32 input produces.
    """

    def __init__(self, in_channels=3, n_kernels=16, out_dim=10):
        super(CNNTarget, self).__init__()
        wide = 2 * n_kernels
        self.conv1 = nn.Conv2d(in_channels, n_kernels, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(n_kernels, wide, 5)
        self.fc1 = nn.Linear(wide * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_dim)

    def forward(self, x):
        """Return raw (un-normalized) scores of shape ``(batch, out_dim)``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(x.size(0), -1)  # flatten everything except the batch dim
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)

class HyperNet(nn.Module):
    """Bank of spectrally-normalized linear heads, one per target tensor.

    Args:
        hidden_dim: Size of the feature vector fed to every head.
        target_dnn_struct: Ordered mapping ``parameter name -> shape``. Each
            head outputs ``shape.numel()`` values.
    """

    def __init__(self, hidden_dim, target_dnn_struct):
        super(HyperNet, self).__init__()
        self.mlps = nn.ModuleDict()
        self.mapped_name = {}
        for param_name, shape in target_dnn_struct.items():
            # ModuleDict keys may not contain '.', so dotted names are flattened.
            key = param_name.replace('.', '_')
            self.mapped_name[param_name] = key
            self.mlps[key] = spectral_norm(nn.Linear(hidden_dim, shape.numel()))

    def _forward(self, x):
        """Return the per-tensor head outputs as a list, in ``target_dnn_struct`` order."""
        return [self.mlps[key](x) for key in self.mapped_name.values()]

    def forward(self, x):
        """Concatenate all head outputs along the last dimension."""
        return torch.cat(self._forward(x), dim=-1)
        
         

class BaseHyper(nn.Module):
    """Embedding-based hypernetwork mapping a node index to target weights.

    Args:
        n_nodes: Number of node embeddings.
        embedding_dim: Embedding width. The base ``encode`` feeds embeddings
            straight into the heads, so this must equal ``hidden_dim`` unless a
            subclass overrides ``encode``.
        hidden_dim: Input width of the ``HyperNet`` heads.
        target_dnn_struct: Ordered mapping ``parameter name -> torch.Size``.
    """

    def __init__(self, n_nodes, embedding_dim, hidden_dim, target_dnn_struct):
        super(BaseHyper, self).__init__()
        self.embeddings = nn.Embedding(num_embeddings=n_nodes, embedding_dim=embedding_dim)
        self.target_dnn_struct = target_dnn_struct
        # NOTE(review): not read by any code in this class; kept for compatibility.
        self.mapped_name = {}
        self.mlps = HyperNet(hidden_dim, target_dnn_struct)

    def generate_weights(self, node_z):
        """Run every head on ``node_z`` and reshape each output to its target shape."""
        chunks = self.mlps._forward(node_z)
        weights = OrderedDict()
        for position, name in enumerate(self.target_dnn_struct):
            weights[name] = chunks[position].view(self.target_dnn_struct[name])
        return weights

    def construct_weights(self, x):
        """Split a flat parameter vector ``x`` (assumed 1-D) into named, shaped tensors."""
        weights = OrderedDict()
        offset = 0
        for name, shape in self.target_dnn_struct.items():
            count = shape.numel()
            weights[name] = x[offset:offset + count].view(shape)
            offset += count
        return weights

    def encode(self, node_idx):
        """Look up the learned embedding(s) for ``node_idx``."""
        return self.embeddings(node_idx)

    def forward(self, node_idx, data=None):
        """Return target-network parameters for ``node_idx``; ``data`` is unused here."""
        return self.generate_weights(self.encode(node_idx))

class MLPHyper(BaseHyper):
    """Hypernetwork that refines node embeddings with an MLP before the heads.

    Args:
        n_nodes, embedding_dim, hidden_dim, target_dnn_struct: as for BaseHyper.
        n_hidden: Number of extra ReLU + Linear(hidden_dim, hidden_dim) stages
            after the first Linear(embedding_dim, hidden_dim).
    """

    def __init__(self, n_nodes, embedding_dim, hidden_dim, target_dnn_struct, n_hidden=1):
        super(MLPHyper, self).__init__(n_nodes, embedding_dim, hidden_dim, target_dnn_struct)
        stages = [nn.Linear(embedding_dim, hidden_dim)]
        for _ in range(n_hidden):
            stages += [nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim)]
        self.hypernet = nn.Sequential(*stages)
        # Index set of all nodes, used when encode() receives no node_idx.
        self.register_buffer('node_idx', torch.arange(n_nodes))

    def encode(self, node_idx=None, data=None):
        """Embed ``node_idx`` (every node when None) and apply the MLP; ``data`` is ignored."""
        indices = self.node_idx if node_idx is None else node_idx
        return self.hypernet(self.embeddings(indices))

class GNNHyper(BaseHyper):
    """Hypernetwork whose node representations come from a 3-layer GCN.

    ``data.x`` is passed through ``self.embeddings``, so it is expected to hold
    integer node ids (presumably global client ids — TODO confirm). ``data.x``
    and ``data.edge_index`` are stored as buffers and used whenever ``encode``
    is called without an explicit graph.
    """

    def __init__(self, data, n_nodes, embedding_dim, hidden_dim, target_dnn_struct):
        super(GNNHyper, self).__init__(n_nodes, embedding_dim, hidden_dim, target_dnn_struct)
        self.conv1 = GCNConv(embedding_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim*2)
        self.conv3 = GCNConv(hidden_dim*2, hidden_dim)
        head = [nn.Linear(hidden_dim, hidden_dim)]
        for _ in range(2):
            head += [nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim)]
        self.fc = nn.Sequential(*head)
        self.register_buffer('x', data.x)
        self.register_buffer('edge_index', data.edge_index)

    def encode(self, node_idx=None, data=None):
        """Run the GCN and return ``fc`` features for ``node_idx`` (all nodes when None).

        When a (sub)graph ``data`` is supplied, ``node_idx`` is matched against
        the ids in ``data.x`` via ``torch.where``, so it selects the rows whose
        stored id equals ``node_idx`` rather than a position.
        NOTE(review): for 1-D ``data.x`` a single match keeps a leading
        dimension of size 1 — confirm callers expect that.
        """
        if data is not None:
            x, edge_index = data.x, data.edge_index
            if node_idx is not None:
                # Translate the requested id into positions within the subgraph.
                node_idx = torch.where(x == node_idx)
        else:
            x, edge_index = self.x, self.edge_index
        h = self.embeddings(x)
        for conv in (self.conv1, self.conv2, self.conv3):
            h = F.relu(conv(h, edge_index))
        if node_idx is None:
            return self.fc(h)
        return self.fc(h[node_idx])

    def forward(self, node_idx, data=None):
        """Encode ``node_idx`` with the GCN and generate its target-network parameters."""
        return self.generate_weights(self.encode(node_idx, data))


class MLPInnerProductDecoder(InnerProductDecoder):
    """Inner-product edge decoder that first maps the latent codes through an MLP.

    Every call transforms ``z`` with ``self.fc`` (``n_layers`` Linear layers
    separated by ReLU) and then scores node pairs with the parent
    ``InnerProductDecoder``.

    Args:
        hidden_dim: Width of the latent codes. Every Linear layer maps
            ``hidden_dim -> hidden_dim``.
        n_layers: Number of Linear layers in ``self.fc``. Values below 1 still
            build a single Linear layer.
    """

    def __init__(self, hidden_dim, n_layers):
        super(MLPInnerProductDecoder, self).__init__()
        layers = [
            nn.Linear(hidden_dim, hidden_dim),
        ]
        for _ in range(n_layers-1):
            layers.append(nn.ReLU(inplace=True))
            layers.append(
                nn.Linear(hidden_dim, hidden_dim),
            )

        self.fc = nn.Sequential(*layers)

    def forward(self, z, edge_index, sigmoid=True):
        """Score the node pairs in ``edge_index`` after transforming ``z`` with ``self.fc``.

        This method was previously misspelled ``foward``, so it never
        overrode ``forward``. Calling the module then bypassed ``self.fc``.
        """
        z = self.fc(z)
        # Delegate to InnerProductDecoder itself. The old
        # ``super(InnerProductDecoder, self)`` skipped past it to nn.Module,
        # so the parent's decoding methods were never reached.
        return super(MLPInnerProductDecoder, self).forward(z, edge_index, sigmoid=sigmoid)

    # Backwards-compatible alias for the previously misspelled method name.
    foward = forward

    def forward_all(self, z, sigmoid=True):
        """Score all node pairs (dense adjacency) after transforming ``z`` with ``self.fc``."""
        z = self.fc(z)
        return super(MLPInnerProductDecoder, self).forward_all(z, sigmoid=sigmoid)


class GAEHyper(GNNHyper):
    """GNNHyper with a graph-autoencoder style edge-reconstruction loss.

    Args:
        data, n_nodes, embedding_dim, hidden_dim, target_dnn_struct: as for GNNHyper.
        decoder: Module called as ``decoder(z, edge_index, sigmoid=True)``.
            Defaults to a fresh ``InnerProductDecoder()``.
    """

    def __init__(self, data, n_nodes, embedding_dim, hidden_dim, target_dnn_struct, decoder=None):
        super(GAEHyper, self).__init__(data, n_nodes, embedding_dim, hidden_dim, target_dnn_struct)
        self.decoder = decoder if decoder is not None else InnerProductDecoder()

    def recon_loss(self, z, pos_edge_index=None, neg_edge_index=None):
        r"""Binary cross-entropy reconstruction loss over positive and negative edges.

        Args:
            z (Tensor): The latent space :math:`\mathbf{Z}`.
            pos_edge_index (LongTensor, optional): The positive edges to train
                against. If not given, the stored ``self.edge_index`` is used.
            neg_edge_index (LongTensor, optional): The negative edges to train
                against. If not given, they are drawn with ``negative_sampling``
                over ``z.size(0)`` nodes.

        Returns:
            Tensor: mean positive-edge loss plus mean negative-edge loss.
        """
        if pos_edge_index is None:
            pos_edge_index = self.edge_index
        pos_prob = self.decoder(z, pos_edge_index, sigmoid=True)
        pos_loss = -(pos_prob + EPS).log().mean()

        if neg_edge_index is None:
            neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
        neg_prob = self.decoder(z, neg_edge_index, sigmoid=True)
        neg_loss = -(1 - neg_prob + EPS).log().mean()

        return pos_loss + neg_loss


class MLPTarget(nn.Module):
    """Small three-layer MLP with two ReLU hidden layers.

    The result goes through ``squeeze()``, so every size-1 dimension is
    dropped. With ``out_dim=1`` a batch of B rows yields shape ``(B,)``, and a
    single row yields a 0-d tensor.
    """

    def __init__(self, input_dim, hidden_dim=16, out_dim=1):
        super(MLPTarget, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden).squeeze()

      
class GRUTarget(nn.Module):
    """GRU encoder-decoder for multi-step, per-node sequence forecasting.

    The encoder consumes the history (``x`` concatenated with ``x_attr``). Its
    final hidden state initializes the decoder, which produces the forecast
    horizon. Batch and node dimensions are folded together into the GRU batch
    dimension.

    Args:
        input_size: Feature size fed to both GRUs, i.e. the last dimension of
            ``cat(x, x_attr)`` and of ``cat(y, y_attr)``.
        hidden_dim: GRU hidden size.
        out_dim: Features produced per step by ``out_net``.
            NOTE(review): in step-by-step decoding ``cat(out, y_attr)`` is fed
            back into the decoder, so ``out_dim`` presumably must equal the
            feature size of ``y`` — confirm with callers.
        dropout: Passed to both ``nn.GRU`` modules.
        cl_decay_steps: Decay constant of the scheduled-sampling threshold.
            With 0 the threshold is 0, so ground truth is (almost) never fed back.
        use_curriculum_learning: If False, training runs full teacher forcing
            in one decoder call. If True, training decodes step by step with
            scheduled sampling.
        gru_num_layers: Number of stacked layers in each GRU.
        *args, **kwargs: Accepted and ignored.
    """
    def __init__(self, input_size, hidden_dim=64, out_dim=12, dropout=0,
        cl_decay_steps=0, use_curriculum_learning=True, gru_num_layers=1,
        *args, **kwargs):
        super(GRUTarget, self).__init__()
        self.cl_decay_steps = cl_decay_steps
        self.use_curriculum_learning = use_curriculum_learning
        self.encoder = nn.GRU(
            input_size, hidden_dim, num_layers=gru_num_layers, dropout=dropout
        )
        self.decoder = nn.GRU(
            input_size, hidden_dim, num_layers=gru_num_layers, dropout=dropout
        )
        self.out_net = nn.Linear(hidden_dim, out_dim)
        # Counter that drives the scheduled-sampling threshold; incremented in forward().
        self.batches_seen = 0

    def _compute_sampling_threshold(self):
        """Probability of feeding ground truth instead of the model's own prediction.

        Returns ``k / (k + exp(batches_seen / k))`` with ``k = cl_decay_steps``.
        This shrinks towards 0 as ``batches_seen`` grows. Returns 0 when
        ``cl_decay_steps`` is 0.
        """
        # NOTE(review): np.exp overflows to inf (with a RuntimeWarning) once
        # batches_seen / cl_decay_steps exceeds ~709; the result is then 0.0.
        if self.cl_decay_steps == 0:
            return 0
        else:
            return self.cl_decay_steps / (
                    self.cl_decay_steps + np.exp(self.batches_seen / self.cl_decay_steps))

    def forward(self, data, return_encoding=False):
        """Forecast future values from the history contained in ``data``.

        Args:
            data: Tuple ``(x, x_attr, y, y_attr)`` of tensors laid out as
                B x T x N x F (batch, time, node, feature).
                - ``y`` is the ground truth used as decoder input during
                  training. It is unpacked and reshaped in eval too, so it
                  must always be supplied.
                - ``y_attr`` provides the decoder covariates and, in
                  step-by-step decoding, the horizon length.
            return_encoding: Also return the encoder's final hidden state.

        Returns:
            Predictions of shape B x T_out x N x out_dim, where T_out is
            ``y_attr.shape[1]`` when decoding step by step, or ``y.shape[1]``
            under teacher forcing. With ``return_encoding`` the result is
            ``(out, encoder_h)``.
        """
        # B x T x N x F
        x, x_attr, y, y_attr = data
        batch_num, node_num = x.shape[0], x.shape[2]
        x_input = torch.cat((x, x_attr), dim=-1).permute(1, 0, 2, 3).flatten(1, 2) # T x (B x N) x F
        _, h_encode = self.encoder(x_input)
        encoder_h = h_encode # num_layers x (B x N) x hidden_dim (nn.GRU h_n layout)
        if self.training and (not self.use_curriculum_learning):
            # Teacher forcing: the decoder input at step t is the ground truth
            # of step t-1, with the last history step as the first input.
            y_input = torch.cat((y, y_attr), dim=-1).permute(1, 0, 2, 3).flatten(1, 2)
            y_input = torch.cat((x_input[-1:], y_input[:-1]), dim=0)
            out_hidden, _ = self.decoder(y_input, h_encode)
            out = self.out_net(out_hidden)
            out = out.view(out.shape[0], batch_num, node_num, out.shape[-1]).permute(1, 0, 2, 3)
        else:
            # Step-by-step decoding, starting from the last history step.
            last_input = x_input[-1:]
            last_hidden = h_encode
            step_num = y_attr.shape[1]
            out_steps = []
            y_input = y.permute(1, 0, 2, 3).flatten(1, 2)
            y_attr_input = y_attr.permute(1, 0, 2, 3).flatten(1, 2)
            for t in range(step_num):
                out_hidden, last_hidden = self.decoder(last_input, last_hidden)
                out = self.out_net(out_hidden) # 1 x (B x N) x out_dim
                out_steps.append(out)
                # Candidate next inputs: this step's covariates paired with
                # either the prediction or the ground-truth value.
                last_input_from_output = torch.cat((out, y_attr_input[t:t+1]), dim=-1)
                last_input_from_gt = torch.cat((y_input[t:t+1], y_attr_input[t:t+1]), dim=-1)
                if self.training:
                    # NOTE(review): incremented once per decoding step, not once
                    # per batch as the name suggests — confirm this is intended.
                    self.batches_seen += 1
                    p_gt = self._compute_sampling_threshold()
                    # One random draw per step decides for the whole batch.
                    p = torch.rand(1).item()
                    if p <= p_gt:
                        last_input = last_input_from_gt
                    else:
                        last_input = last_input_from_output
                else:
                    last_input = last_input_from_output
            out = torch.cat(out_steps, dim=0)
            out = out.view(out.shape[0], batch_num, node_num, out.shape[-1]).permute(1, 0, 2, 3)
        if return_encoding:
            return out, encoder_h
        else:
            return out