from cmath import isnan
import torch
import torch.nn as nn
from collections import OrderedDict

class PFL_Hypernet(nn.Module):
    """Network that emits a flat vector and reshapes it into named weight tensors.

    The model is an embedding stage followed by a stack of linear layers whose
    final output has ``out_dim`` elements, where ``out_dim`` is the total number
    of elements over all shapes in ``out_params_shapes``. ``forward`` slices
    that vector into one tensor per entry of ``out_params_shapes``.

    Two input modes are supported:

    * ``input_size == 1``: ``x`` is an index looked up in an ``nn.Embedding``
      with ``n_nodes`` rows.
    * ``input_size != 1``: ``x`` is a sequence whose first element is skipped;
      ``x[1:]`` (``input_size - 1`` values) is passed through a linear
      embedding and then concatenated with the raw ``x[1:]`` values.

    Args:
        n_nodes: Number of rows of the embedding table (used when
            ``input_size == 1``). Must be a positive int.
        embedding_dim: Size of the embedding vector. Must be a non-negative int.
        num_layers: Number of linear layers after the embedding stage. Must be
            a positive int.
        num_hidden: Width of the hidden linear layers. Must be a positive int.
        out_params_shapes: Mapping from parameter name to its shape (a sequence
            of ints). Iteration order defines the layout of the flat output.
        lr: Stored on the instance as ``self.lr``; not used by this class.
            Must be positive.
        device: Device the parameters and sub-modules are placed on.
        input_size: Length of the input in the feature mode described above;
            ``1`` selects the index/embedding mode.
        use_layernorm: If true, add ``nn.LayerNorm(num_hidden)`` after every
            ReLU.
        dropout: Dropout probability; an ``nn.Dropout`` is added after every
            ReLU (and LayerNorm, if any) when this is greater than 0.

    Raises:
        AssertionError: If ``validate_inputs`` rejects an argument.
    """

    def __init__(self, n_nodes, embedding_dim, num_layers, num_hidden, out_params_shapes, lr, device, input_size=1, use_layernorm=False, dropout=0):
        super().__init__()
        # Fail fast, before any tensors or sub-modules are allocated.
        self.validate_inputs(n_nodes, embedding_dim, num_layers, num_hidden, lr)

        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.lr = lr
        self.n_nodes = n_nodes
        self.embedding_dim = embedding_dim
        self.device = device
        self.input_size = input_size

        self.out_params_dict = out_params_shapes
        self.out_dim = self.calculate_out_dim()
        self.dropout = dropout
        self.use_layernorm = use_layernorm

        # Learnable per-element shift and scale applied in create_weight_dict.
        shifts = nn.Parameter(torch.zeros(self.out_dim, device=self.device))
        scales = nn.Parameter(torch.ones(self.out_dim, device=self.device))
        self.register_parameter(name='shifts', param=shifts)
        self.register_parameter(name='scales', param=scales)

        if self.input_size == 1:
            self.embedding = nn.Embedding(num_embeddings=n_nodes, embedding_dim=embedding_dim)
        else:
            self.embedding = nn.Linear(input_size - 1, embedding_dim)

        # The layer list must exist in BOTH input modes; it always starts with
        # the embedding stage. It is a plain Python list, so the modules are
        # registered through ``self.embedding`` and ``self.net`` below.
        self.layers = [self.embedding]

        # In feature mode the raw features are concatenated to the embedding,
        # hence the extra ``input_size - 1`` inputs (0 in index mode).
        first_in_features = embedding_dim + input_size - 1
        if num_layers == 1:
            self.layers.append(nn.Linear(first_in_features, self.out_dim))
        else:
            self.layers.append(nn.Linear(first_in_features, num_hidden))
            for _ in range(1, num_layers - 1):
                self.layers.extend(self._activation_block())
                self.layers.append(nn.Linear(num_hidden, num_hidden))
            # Activation block in front of the output layer.
            self.layers.extend(self._activation_block())
            self.layers.append(nn.Linear(num_hidden, self.out_dim))

        if self.input_size != 1:
            # The embedding is applied by hand in forward(), so it is kept out
            # of the sequential stack.
            self.embedding = self.embedding.to(device)
            self.net = nn.Sequential(*self.layers[1:]).to(device)
        else:
            self.net = nn.Sequential(*self.layers).to(device)

    def _activation_block(self):
        """Return the modules placed between two linear layers: ReLU, then optional LayerNorm and Dropout."""
        block = [nn.ReLU()]
        if self.use_layernorm:
            block.append(nn.LayerNorm(self.num_hidden))
        if self.dropout > 0:
            block.append(nn.Dropout(self.dropout))
        return block

    @staticmethod
    def _to_new_tensor(data):
        """Copy ``data`` into a new tensor without the warning ``torch.tensor`` emits for tensor inputs."""
        if isinstance(data, torch.Tensor):
            return data.detach().clone()
        return torch.tensor(data)

    def calculate_out_dim(self):
        """Compute the total number of output elements.

        Also fills ``self.products_dict`` with the element count of every entry
        of ``self.out_params_dict``.

        Returns:
            The sum of the element counts over all target shapes.
        """
        dim = 0
        self.products_dict = {}
        for name, shape in self.out_params_dict.items():
            product = 1
            for extent in shape:
                product *= extent
            dim += product
            self.products_dict[name] = product
        return dim

    def create_weight_dict(self, weight_vector):
        """Slice a flat output vector into named, reshaped tensors.

        Entries whose target shape has more than one dimension are standardised
        (zero mean, unit std over the whole tensor), multiplied by the matching
        slice of ``self.scales`` and by ``sqrt(2 / (shape[0] + shape[1]))``
        (similar to Xavier initialisation), and shifted by the matching slice of
        ``self.shifts``. One-dimensional entries are passed through unchanged.

        Args:
            weight_vector: Tensor holding ``self.out_dim`` elements once
                size-1 dimensions are squeezed away.

        Returns:
            An ``OrderedDict`` mapping each key of ``self.out_params_dict`` to
            its tensor, moved to the CPU.
        """
        weight_vector = weight_vector.squeeze()
        return_dict = OrderedDict()
        index = 0
        for name, shape in self.out_params_dict.items():
            stop = index + self.products_dict[name]
            weights = weight_vector[index:stop].reshape(shape)

            if len(shape) > 1:
                shifts = self.shifts[index:stop].reshape(shape)
                scales = self.scales[index:stop].reshape(shape)
                fan_in = shape[0]
                fan_out = shape[1]
                # NOTE: no epsilon is applied, so a constant or single-element
                # tensor yields a zero/NaN std and non-finite weights.
                weights = (weights - weights.mean()) / weights.std()
                weights = weights * (scales * ((2 / (fan_in + fan_out)) ** 0.5))
                weights = weights + shifts

            return_dict[name] = weights.to("cpu")
            index = stop
        return return_dict

    def validate_inputs(self, n_nodes, embedding_dim, num_layers, num_hidden, lr):
        """Check the constructor arguments.

        Type checks run before range checks so that a value of the wrong type
        fails with the descriptive assertion rather than a comparison error.

        Raises:
            AssertionError: If an argument has the wrong type or is out of range.
        """
        assert isinstance(n_nodes, int), "n_nodes must be an int"
        assert n_nodes > 0, "n_nodes <= 0"
        assert isinstance(embedding_dim, int), "embedding_dim must be an int"
        # Zero is accepted here; only negative values are rejected.
        assert embedding_dim >= 0, "embedding_dim < 0"
        assert isinstance(num_layers, int), "num_layers must be an int"
        assert num_layers > 0, "num_layers <= 0"
        assert isinstance(num_hidden, int), "num_hidden must be an int"
        assert num_hidden > 0, "num_hidden <= 0"
        assert lr > 0, "lr <= 0"

    def forward(self, x):
        """Generate the dictionary of weight tensors for input ``x``.

        Args:
            x: When ``input_size == 1``, an integer index (or integer tensor)
                for the embedding table. Otherwise an indexable sequence whose
                elements from position 1 onward are the ``input_size - 1``
                features; element 0 is not used by this method.

        Returns:
            The ``OrderedDict`` produced by ``create_weight_dict``.
        """
        if self.input_size != 1:
            features = self._to_new_tensor(x[1:]).float().to(self.device)
            embed = self.embedding(features)
            # The raw features are fed to the network alongside their embedding.
            net_input = torch.cat([embed, features], dim=0)
            return self.create_weight_dict(self.net(net_input.unsqueeze(0)))
        return self.create_weight_dict(self.net(self._to_new_tensor(x).to(self.device)))
