import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Any, Callable, List, Optional, Tuple
import torch


class MaskNet(nn.Module):
    r"""Produce (optionally stochastic) masks from a summarizer's scores.

    The masking strategy is selected by ``act``:

    * ``"gumbelHard"`` / ``"gumbelSoft"``: a learnable two-way head
      (``linear1`` -> ``linear2``) on top of ``summarizer.embed``, sampled with
      ``F.gumbel_softmax`` (hard or soft); column 0 is used as the mask.
    * ``"sigmoidSoft"`` / ``"sigmoidHard"``: ``summarizer.forward`` output used
      as the mask directly, or Bernoulli-sampled with a straight-through
      gradient.
    * ``"sigmoidConcreteSoft"`` / ``"sigmoidConcreteHard"``: logits from
      ``summarizer.read_out`` relaxed with :meth:`concrete_sample`, optionally
      Bernoulli-sampled with a straight-through gradient.
    * ``"sigratio"``: ``int(numel * ratio)`` positions of the
      ``summarizer.forward`` output drawn without replacement, with a
      straight-through gradient.

    Args:
        summarizer: object exposing ``dim`` plus whichever of ``embed``,
            ``read_out`` and ``forward`` the chosen ``act`` uses. In the hard
            sigmoid modes its output is passed to ``torch.bernoulli`` and must
            therefore lie in [0, 1].
        act: name of the masking strategy (see above).
    """

    def __init__(
        self,
        summarizer,
        act,
    ) -> None:
        super().__init__()
        self.summarizer = summarizer
        self.dim = summarizer.dim
        # Uniform noise in ``concrete_sample`` is drawn from
        # U[sample_bias, 1 - sample_bias).
        self.sample_bias = 0.0

        # Only the Gumbel strategies own a learnable head; every other strategy
        # takes its scores straight from the summarizer.
        if "gumbel" in act.lower():
            self.linear1 = nn.Sequential(
                nn.Linear(in_features=self.dim, out_features=self.dim),
                nn.ReLU(),
                nn.Dropout(p=0.5),
                nn.LayerNorm(normalized_shape=self.dim, eps=1e-6)
            )
            self.linear2 = nn.Sequential(
                nn.Linear(in_features=self.dim, out_features=2),
            )
        self.act = act
        self.EPS = 1e-6  # NOTE(review): not read anywhere in this class.

    def concrete_sample(self, log_alpha: Tensor, temperature: float = 1.0) -> Tensor:
        r"""Sample from a binary concrete (relaxed Bernoulli) distribution.

        In training mode, logistic noise ``log(u) - log(1 - u)`` with
        ``u ~ U[sample_bias, 1 - sample_bias)`` is added to ``log_alpha``, the sum
        is divided by ``temperature`` and squashed with a sigmoid. In eval mode
        the noise and the temperature are skipped: ``sigmoid(log_alpha)``.

        Args:
            log_alpha: logits of any shape.
            temperature: relaxation temperature (divides the noisy logits).

        Returns:
            Tensor shaped like ``log_alpha`` with values in [0, 1].
        """
        if not self.training:
            return log_alpha.sigmoid()
        bias = self.sample_bias
        # Generate the noise on log_alpha's device rather than on the CPU
        # followed by a host-to-device copy on every training step.
        u = torch.rand(log_alpha.shape, device=log_alpha.device) * (1 - 2 * bias) + bias
        logistic_noise = torch.log(u) - torch.log(1.0 - u)
        return ((logistic_noise + log_alpha) / temperature).sigmoid()

    def forward(self, feature: Tensor) -> None:
        """No-op that returns ``None``; masks are produced by :meth:`sample_mask`."""
        return None

    def gumbel_read_out(self, feature: Tensor) -> Tensor:
        """Compute two-way logits with the Gumbel head.

        Only usable when ``act`` contains ``"gumbel"`` (case-insensitive), since
        ``linear1``/``linear2`` are created only in that case.

        Returns:
            Logits whose last dimension has size 2.
        """
        x = self.summarizer.embed(feature)  # presumably (T, dim) — TODO confirm
        x = self.linear1(x)
        x = self.linear2(x)
        return x

    def calculate_entropy(self, feature):
        """Mean binary entropy (nats) of ``sigmoid(summarizer.read_out(feature))``.

        Only supported when ``act == "sigmoidConcreteSoft"``.

        Raises:
            NotImplementedError: for any other ``act`` (a ``RuntimeError``
                subclass).
        """
        if self.act != "sigmoidConcreteSoft":
            raise NotImplementedError(
                f"calculate_entropy is not supported for act={self.act!r}"
            )
        logits = self.summarizer.read_out(feature)
        p = logits.sigmoid()
        eps = 1e-10  # keeps log() finite when p saturates at 0 or 1
        entropy = -p * torch.log(p + eps) - (1 - p) * torch.log(1 - p + eps)
        return entropy.mean()

    @staticmethod
    def _straight_through(hard: Tensor, soft: Tensor) -> Tensor:
        """Return ``hard`` in the forward pass while routing gradients to ``soft``."""
        return hard.detach() - soft.detach() + soft

    @staticmethod
    def _ratio_sample(soft_mask: Tensor, ratio: float) -> Tensor:
        """Binary tensor shaped like ``soft_mask`` containing a ``ratio`` share of ones.

        ``int(numel * ratio)`` positions, clamped to ``[0, numel]``, are drawn
        without replacement with probabilities ``softmax(soft_mask.flatten())``.
        """
        numel = soft_mask.numel()
        num_samples = min(max(int(numel * ratio), 0), numel)
        hard = soft_mask.new_zeros(numel)
        if num_samples > 0:
            # torch.multinomial rejects a zero sample count and, without
            # replacement, more samples than categories; reshape (not view)
            # also copes with non-contiguous summarizer outputs.
            probs = soft_mask.detach().reshape(-1).softmax(0)
            hard[torch.multinomial(probs, num_samples, replacement=False)] = 1.0
        return hard.reshape(soft_mask.shape)

    def sample_mask(
        self, feature: Tensor, tau: float = 0.5, ratio: float = 0.6
    ) -> Tuple[Tensor, dict]:
        """Sample a mask for ``feature`` using the strategy selected by ``act``.

        Args:
            feature: input passed to the summarizer (or the Gumbel head).
            tau: temperature for the Gumbel-softmax and concrete strategies.
            ratio: share of elements set to 1 by the ``"sigratio"`` strategy.

        Returns:
            ``(mask, mask_loss)`` where ``mask_loss["size"]`` is the mean of the
            soft mask and ``mask_loss["ent"]`` is the mean of ``p * (1 - p)``
            over the soft mask.

        Raises:
            NotImplementedError: if ``act`` is not one of the known strategies.
        """
        act = self.act
        if act in ("gumbelHard", "gumbelSoft"):
            bi_logits = self.gumbel_read_out(feature)  # (T, 2)
            # Column 0 of the two-way distribution is used as the mask.
            soft_mask = bi_logits.softmax(dim=-1)[:, 0]
            mask = F.gumbel_softmax(
                bi_logits, tau=tau, hard=(act == "gumbelHard"), dim=-1
            )[:, 0]
        elif act == "sigmoidSoft":
            mask = soft_mask = self.summarizer.forward(feature)
        elif act == "sigmoidHard":
            soft_mask = self.summarizer.forward(feature)
            mask = self._straight_through(torch.bernoulli(soft_mask), soft_mask)
        elif act in ("sigmoidConcreteSoft", "sigmoidConcreteHard"):
            logits = self.summarizer.read_out(feature)
            soft_mask = self.concrete_sample(logits, temperature=tau)
            if act == "sigmoidConcreteHard":
                mask = self._straight_through(torch.bernoulli(soft_mask), soft_mask)
            else:
                mask = soft_mask
        elif act == "sigratio":
            soft_mask = self.summarizer.forward(feature)
            mask = self._straight_through(self._ratio_sample(soft_mask, ratio), soft_mask)
        else:
            raise NotImplementedError(f"sample_mask does not support act={act!r}")

        mask_loss = {
            "size": soft_mask.mean(),
            "ent": (soft_mask * (1 - soft_mask)).mean(),
        }
        return mask, mask_loss