from typing import Optional

import einops
import torch
import torch.nn as nn

from .encoder import Encoder
from .decoder import Decoder
from .query import Query_Gen_transformer, MultiHeadAttention_kqv
import torch.nn.functional as F
import numpy as np
from .hypernn_utils import create_functional_target_network
from torch import vmap

def log_mean_exp(a):
    """Return ``log(mean(exp(a)))`` reduced over dimension 1, computed stably.

    Args:
        a: Tensor with at least two dimensions; dimension 1 is reduced.

    Returns:
        A tensor shaped like ``a`` with dimension 1 removed.
    """
    import math  # local import: the module header does not import ``math``

    # log(mean(exp(a))) == logsumexp(a) - log(N), where N is the size of the
    # reduced dimension. torch.logsumexp applies the max-shift internally and,
    # unlike a manual ``a - max(a)`` shift, does not yield NaN when a slice is
    # entirely -inf or contains +inf.
    return torch.logsumexp(a, dim=1) - math.log(a.shape[1])

class target_net(nn.Module):
    def __init__(self, in_features_x, in_features_y, out_features=1, hidden_features=512, num_layers=6, negative_slope=0.1):
        super().__init__()

        self.in_feats = [in_features_x, hidden_features, hidden_features]
        self.out_feats = [hidden_features, hidden_features, out_features]
        
        self.fc1_x = nn.Linear(in_features_x, hidden_features, bias=False)
        self.fc1_y = nn.Linear(in_features_y, hidden_features, bias=False)
        self.fc1_bias = nn.Parameter(torch.zeros(hidden_features))
        
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(hidden_features, hidden_features) for _ in range(num_layers - 2)]
        )
        
        self.fc_out = nn.Linear(hidden_features, out_features)
        self.negative_slope = negative_slope
        self.in_features_x = in_features_x
        self.in_features_y = in_features_y

    def forward(self, input):
        
        x = input[:, :self.in_features_x]
        y = input[:, self.in_features_x:]
        x = self.fc1_x(x)
        y = self.fc1_y(y)
        x = F.leaky_relu(x + y + self.fc1_bias, negative_slope=self.negative_slope)
        
        for layer in self.hidden_layers:
            x = F.leaky_relu(layer(x), negative_slope=self.negative_slope)
        
        x = self.fc_out(x)
        return x

    def get_in_dims(self):
        return self.in_feats

    def get_out_dims(self):
        return self.out_feats

class InfoNet(nn.Module):
    """Hypernetwork that predicts ``target_net`` weights and evaluates a bound.

    ``forward`` encodes the inputs, decodes them against a generated query,
    maps the decoder output through an MLP (``weight_gen``) to a flat weight
    vector per batch element, and evaluates a functional ``target_net`` with
    those weights on the original rows and on rows whose ``y`` columns have
    been shuffled. The returned value is named ``mi_lb``; judging by the name
    and formula it is presumably a Donsker-Varadhan style lower bound on
    mutual information (confirm against the training code).
    """

    def __init__(
            self,
            encoder: Encoder,
            decoder: Decoder,
            query_gen: Query_Gen_transformer,
            decoder_query_dim: int,
            input_dim_x: int = 3,
            input_dim_y: int = 3,
            num_mlp_layer: int = 1,
            hidden_dim: int = 512,
            dropout=float(0.0),
            targetnet_hiddim=int(256),
            hypermlp_hiddim=int(2048)
    ):
        """Store the sub-networks and build the weight-generating MLP.

        Args:
            encoder: Module called as ``encoder(inputs, input_mask)``.
            decoder: Module called as
                ``decoder(x_q=..., latents=..., query_mask=...)``.
            query_gen: Module called as ``query_gen(inputs)``.
            decoder_query_dim: Input width of ``weight_gen``; also used to
                size its output layer.
            input_dim_x: Number of leading feature columns treated as ``x``.
            input_dim_y: Number of feature columns treated as ``y``.
            num_mlp_layer: Stored on the instance; not used in this class.
            hidden_dim: Stored on the instance; not used in this class.
            dropout: Stored on the instance; not used in this class.
            targetnet_hiddim: Hidden width of the target network.
            hypermlp_hiddim: Hidden width of ``weight_gen``.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.query_gen = query_gen
        self.num_mlp_layer = num_mlp_layer
        self.input_dim_x = input_dim_x
        self.input_dim_y = input_dim_y
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.decoder_query_dim = decoder_query_dim
        self.targetnet_hiddim = targetnet_hiddim

        # The target network is only a template: its parameter count sizes the
        # hypernetwork output, and it is wrapped into a functional form.
        target_network = target_net(
            input_dim_x, input_dim_y, 1, self.targetnet_hiddim,
            num_layers=6, negative_slope=0.1,
        )
        target_total_params = sum(p.numel() for p in target_network.parameters())
        self.functional_target_network = create_functional_target_network(
            target_network
        )

        # Each decoder output row yields
        # ``target_total_params // decoder_query_dim + 1`` values; presumably
        # ``decoder_query_dim`` rows then cover every target parameter
        # (TODO confirm against create_functional_target_network).
        # Layer order and indices (Linear at 0, 2, ..., 10) are kept fixed so
        # existing state dicts continue to load.
        self.weight_gen = nn.Sequential(
            nn.Linear(decoder_query_dim, hypermlp_hiddim),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hypermlp_hiddim, hypermlp_hiddim),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hypermlp_hiddim, hypermlp_hiddim),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hypermlp_hiddim, hypermlp_hiddim),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hypermlp_hiddim, hypermlp_hiddim),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hypermlp_hiddim, target_total_params // decoder_query_dim + 1)
        )

    def forward(
            self,
            inputs: Optional[torch.Tensor],
            query: Optional[torch.Tensor] = None,
            input_mask: Optional[torch.Tensor] = None,
            query_mask: Optional[torch.Tensor] = None,
            early_sup=False,
            num_marginal_samples: int = 20,
    ):
        """Compute ``mi_lb`` for a batch of inputs.

        Args:
            inputs: Tensor indexed as ``inputs[:, :, :]``. The last dimension
                holds ``input_dim_x`` x-columns followed by the y-columns;
                dimension 1 is the one that gets permuted (assumed to be the
                sample dimension - TODO confirm).
            query: Ignored. The query is always regenerated with
                ``self.query_gen(inputs)``.
            input_mask: Forwarded to the encoder.
            query_mask: Forwarded to the decoder.
            early_sup: Unused.
            num_marginal_samples: Number of random permutations of the y
                columns averaged in the second term. Defaults to 20.

        Returns:
            ``mean(t, dim=1)`` minus the average over permutations of
            ``log_mean_exp(et)``, where ``t`` and ``et`` are the target
            network outputs on the original and shuffled rows respectively.

        Raises:
            ValueError: If ``num_marginal_samples`` is smaller than 1.
        """
        if num_marginal_samples < 1:
            raise ValueError("num_marginal_samples must be a positive integer")

        latents = self.encoder(inputs, input_mask)
        # The caller-supplied ``query`` is deliberately not consulted here.
        query = self.query_gen(inputs)

        outputs = self.decoder(
            x_q=query,
            latents=latents,
            query_mask=query_mask
        )

        # Flatten the generated values into one weight vector per batch item,
        # then evaluate the functional target network per batch item.
        outputs = self.weight_gen(outputs)
        weights = einops.rearrange(outputs, 'b n m -> b (n m)')
        batched_forward = vmap(self.functional_target_network, in_dims=(0, 0))

        # Loop-invariant column split, hoisted out of the sampling loop.
        x_part = inputs[:, :, 0:self.input_dim_x]
        y_part = inputs[:, :, self.input_dim_x:]

        log_mean_exp_et = 0
        for _ in range(num_marginal_samples):
            # One permutation is shared by every batch element; shuffling only
            # the y columns along dim 1 breaks the row-wise x/y pairing.
            perm = torch.randperm(inputs.shape[1])
            marginal = torch.cat((x_part, y_part[:, perm]), dim=2)
            et = batched_forward(weights, marginal)
            # Out-of-place accumulation keeps the autograd graph free of
            # in-place modifications.
            log_mean_exp_et = log_mean_exp_et + log_mean_exp(et)

        t = batched_forward(weights, inputs)

        mi_lb = torch.mean(t, dim=1) - log_mean_exp_et / num_marginal_samples

        return mi_lb
