import math
from typing import Optional

import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import vmap

from .encoder import Encoder
from .decoder import Decoder
from .query import Query_Gen_transformer, MultiHeadAttention_kqv
from .hypernn_utils import create_functional_target_network

class target_net(nn.Module):
    """Small fully connected critic whose parameters are supplied externally.

    Architecture: four ``Linear`` layers, each followed by LeakyReLU with
    negative slope 0.1, then a final ``Linear`` projection to
    ``out_features`` with no activation.

    Args:
        in_features: width of each input row.
        hidden_features: width of every hidden layer.
        out_features: width of the output.
    """

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        # NOTE(review): these per-layer width lists describe only three
        # layers, while the network below has five Linear layers; confirm
        # whether any caller of get_in_dims/get_out_dims relies on this.
        self.in_feats = [in_features] + [hidden_features] * 2
        self.out_feats = [hidden_features] * 2 + [out_features]

        # Registration order (linear1..linear5, then the activation) is kept
        # so parameter ordering and state_dict keys stay stable.
        self.linear1 = nn.Linear(in_features, hidden_features)
        self.linear2 = nn.Linear(hidden_features, hidden_features)
        self.linear3 = nn.Linear(hidden_features, hidden_features)
        self.linear4 = nn.Linear(hidden_features, hidden_features)
        self.linear5 = nn.Linear(hidden_features, out_features)
        self.leakyrelu = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        """Run the four activated hidden layers, then the linear output head."""
        for hidden in (self.linear1, self.linear2, self.linear3, self.linear4):
            x = self.leakyrelu(hidden(x))
        return self.linear5(x)

    def get_in_dims(self):
        """Return the stored list of input widths (the list itself, not a copy)."""
        return self.in_feats

    def get_out_dims(self):
        """Return the stored list of output widths (the list itself, not a copy)."""
        return self.out_feats

    def get_submodules(self):
        """Return the four hidden Linear layers; ``linear5`` is not included."""
        return [self.linear1, self.linear2, self.linear3, self.linear4]


def log_mean_exp(a, dim=1):
    """Return ``log(mean(exp(a), dim=dim))`` computed in a numerically stable way.

    Implemented as ``logsumexp(a, dim) - log(n)``, where ``n`` is the size of
    the reduced dimension. Unlike a manual max-shift, this stays correct when
    every entry along ``dim`` is ``-inf`` (result ``-inf`` instead of NaN) or
    when an entry is ``+inf`` (result ``+inf`` instead of NaN).

    Args:
        a: input tensor.
        dim: dimension to reduce; defaults to 1, the previously fixed axis.

    Returns:
        A tensor shaped like ``a`` with ``dim`` removed.

    Raises:
        ValueError: if ``a`` has no elements along ``dim`` (the mean is undefined).
    """
    n = a.shape[dim]
    if n == 0:
        raise ValueError("log_mean_exp requires a non-empty reduction dimension")
    # Dividing by n inside the log is the same as subtracting log(n) outside it.
    return torch.logsumexp(a, dim=dim) - math.log(n)


class InfoNet(nn.Module):
    """Hypernetwork-based estimator of a mutual-information lower bound.

    ``forward`` encodes the sample set and decodes query tokens against the
    latents. It then maps the decoded tokens through ``weight_gen`` to a flat
    parameter vector for a ``target_net`` critic T, one vector per batch
    element. The critic is evaluated with ``vmap`` on the joint samples and on
    ``num_marginal_samples`` (K) shuffled copies whose y-part is permuted
    along the sample axis, and the module returns::

        mean_i T(joint_i) - (1/K) * sum_k log_mean_exp_i T(marginal_k_i)

    which has the Donsker-Varadhan form of a lower bound on I(X; Y).

    Args:
        encoder, decoder, query_gen: project modules that produce the decoder
            output tokens.
        decoder_query_dim: width of each decoder output token (the input width
            of ``weight_gen``). It is also the divisor used to size each
            token's parameter chunk.
        input_dim_x, input_dim_y: feature widths of the x and y parts of a
            sample; the critic consumes their concatenation.
        num_mlp_layer, hidden_dim, dropout: stored on the module but not used
            by this class.
        targetnet_hiddim: hidden width of the ``target_net`` critic.
        hypermlp_hiddim: hidden width of the ``weight_gen`` MLP.
        num_marginal_samples: number of random permutations K averaged for the
            marginal term. Defaults to 5, the previously hard-coded value.

    Raises:
        ValueError: if ``num_marginal_samples`` is less than 1.
    """

    def __init__(
            self,
            encoder: Encoder,
            decoder: Decoder,
            query_gen: Query_Gen_transformer,
            decoder_query_dim: int,
            input_dim_x: int = 3,
            input_dim_y: int = 3,
            num_mlp_layer: int = 1,
            hidden_dim: int = 512,
            dropout=float(0.0),
            targetnet_hiddim=int(256),
            hypermlp_hiddim=int(2048),
            num_marginal_samples: int = 5,
    ):

        super().__init__()
        if num_marginal_samples < 1:
            raise ValueError("num_marginal_samples must be at least 1")

        self.encoder = encoder
        self.decoder = decoder
        self.query_gen = query_gen
        self.num_mlp_layer = num_mlp_layer
        self.input_dim_x = input_dim_x
        self.input_dim_y = input_dim_y
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.decoder_query_dim = decoder_query_dim
        self.targetnet_hiddim = targetnet_hiddim
        self.num_marginal_samples = num_marginal_samples

        # The critic's parameters come from weight_gen; this instance is only
        # a template for the functional wrapper and for counting parameters.
        target_network = target_net(input_dim_x + input_dim_y, self.targetnet_hiddim, 1)
        target_total_params = sum(p.numel() for p in target_network.parameters())
        self.functional_target_network = create_functional_target_network(
            target_network
        )

        # NOTE: each decoder token is mapped to
        # target_total_params // decoder_query_dim + 1 values. After flattening
        # in forward(), this yields at least target_total_params values when
        # the decoder emits decoder_query_dim tokens, which is presumably the
        # case (TODO confirm). How surplus entries are consumed is decided by
        # create_functional_target_network, which is not visible here.
        self.weight_gen = nn.Sequential(
            nn.Linear(decoder_query_dim, hypermlp_hiddim),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hypermlp_hiddim, hypermlp_hiddim),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hypermlp_hiddim, hypermlp_hiddim),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hypermlp_hiddim, hypermlp_hiddim),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hypermlp_hiddim, hypermlp_hiddim),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hypermlp_hiddim, target_total_params // decoder_query_dim + 1)
        )

    def forward(
            self,
            inputs: Optional[torch.Tensor],
            query: Optional[torch.Tensor] = None,
            input_mask: Optional[torch.Tensor] = None,
            query_mask: Optional[torch.Tensor] = None,
            early_sup=False
    ):
        """Estimate the MI lower bound for each batch element.

        Args:
            inputs: the sample set. It is sliced as ``inputs[:, :, :input_dim_x]``
                (x part) and ``inputs[:, :, input_dim_x:]`` (y part), and axis 1
                is the sample axis that gets permuted. It is presumably shaped
                (batch, num_samples, input_dim_x + input_dim_y); TODO confirm.
            query: currently ignored; queries are always regenerated from
                ``inputs`` by ``query_gen``.
            input_mask: forwarded to the encoder.
            query_mask: forwarded to the decoder.
            early_sup: currently unused.

        Returns:
            ``mean(T_joint, dim=1) - mean_k log_mean_exp(T_marginal_k)``. If the
            critic emits one value per sample, this has shape (batch, 1).
        """
        latents = self.encoder(inputs, input_mask)
        # The `query` argument is overwritten: queries always come from inputs.
        query = self.query_gen(inputs)

        outputs = self.decoder(
            x_q=query,
            latents=latents,
            query_mask=query_mask
        )

        # Per-token parameter chunks, flattened into one parameter vector per
        # batch element; vmap then runs one critic per batch element.
        weights = einops.rearrange(self.weight_gen(outputs), 'b n m -> b (n m)')
        batched_forward = vmap(self.functional_target_network, in_dims=(0, 0))

        x_part = inputs[:, :, :self.input_dim_x]
        y_part = inputs[:, :, self.input_dim_x:]
        num_samples = inputs.shape[1]

        # Shuffling y along the sample axis breaks the x/y pairing, which
        # approximates draws from the product of marginals.
        log_mean_exp_et = 0
        for _ in range(self.num_marginal_samples):
            # Create the permutation on the inputs' device to avoid a
            # host-to-device copy when indexing GPU tensors.
            perm = torch.randperm(num_samples, device=inputs.device)
            marginal = torch.cat((x_part, y_part[:, perm]), dim=2)
            et = batched_forward(weights, marginal)
            log_mean_exp_et = log_mean_exp_et + log_mean_exp(et)

        t = batched_forward(weights, inputs)

        mi_lb = torch.mean(t, dim=1) - log_mean_exp_et / self.num_marginal_samples
        return mi_lb
