import torch 
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def sample_gumbel(shape, eps=1e-8):
    """Draw standard Gumbel(0, 1) noise of the requested shape.

    Uses the inverse-CDF transform ``-log(-log(U))`` with ``U ~ Uniform[0, 1)``.
    ``eps`` is added inside both logarithms so neither one is evaluated at 0.

    Args:
        shape: Shape accepted by ``torch.rand`` (tuple, ``torch.Size``, ...).
        eps: Small positive constant guarding the two logarithms.

    Returns:
        Tensor of Gumbel samples created by ``torch.rand`` with its default
        device and dtype; callers move it to another device themselves.
    """
    uniform = torch.rand(shape)
    # Inner transform: exponential-distributed variate from the uniform draw.
    neg_log_u = (uniform + eps).log().neg()
    # Outer transform: turn the exponential variate into a Gumbel variate.
    return (neg_log_u + eps).log().neg()

def hard_sample(out):
    """Round ``out`` to the nearest integer with a straight-through gradient.

    The forward value is ``torch.round(out)`` (up to floating-point rounding of
    the add below); the backward pass treats the operation as the identity
    because the rounding residual is detached from the graph.

    Args:
        out: Tensor of soft values (e.g. sigmoid probabilities).

    Returns:
        Tensor with the same shape as ``out`` holding the rounded values.
    """
    # Residual between the rounded and soft value carries no gradient ...
    residual = (out.round() - out).detach()
    # ... so d(result)/d(out) == 1 while the value equals the rounded one.
    return residual + out

class dyn_hypernetwork(nn.Module):
    """Hypernetwork producing one mask per layer from activations and LRP scores.

    Every layer ``i`` owns a small MLP mapping a vector of size
    ``t_structures[i]`` to a score vector of the same size. The score is
    optionally perturbed with Gumbel noise, shifted by ``lrp * lrp_scale``,
    rescaled by a learnable per-layer gain (``softplus(layer_importance[i])``)
    and bias (``layer_importance_bias[i]``), and finally squashed with
    ``sigmoid(y / T)``.

    Args:
        t_structures: Sequence with the width of each layer; entry ``i`` is the
            input and output size of the MLP for layer ``i``.
        lrp_scale: Multiplier applied to the LRP tensor before it is added to
            the MLP score.
        base: Stored on the instance; not read by any method of this class.
        T_start: Initial temperature of the sigmoid; also the initial ``T``.
        T_end: Lower bound for the temperature computed in ``forward``.
        target_sparsity: Stored on the instance; not read by any method of this
            class.
        hidden_dim: Hidden width of every per-layer MLP.
    """

    def __init__(self, t_structures, lrp_scale=0.4, base=0.5, T_start=0.5, T_end=0.1, target_sparsity=0.2, hidden_dim=32):
        super().__init__()
        self.T = T_start
        self.T_start, self.T_end = T_start, T_end
        self.base = base
        self.lrp_scale = lrp_scale
        self.target_sparsity = target_sparsity
        self.t_sp = t_structures

        # One independent MLP per layer. The Sequential layout is kept as-is so
        # existing state-dict keys and the parameter-init RNG order stay valid.
        self.layer_mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(int(width), hidden_dim),
                nn.GELU(),
                nn.LayerNorm(hidden_dim),
                nn.Linear(hidden_dim, int(width)),
            ) for width in self.t_sp
        ])

        # Learnable per-layer multiplicative gain (passed through softplus
        # before use, so the effective gain is always positive).
        self.layer_importance = nn.Parameter(torch.ones(len(self.t_sp)))
        nn.init.normal_(self.layer_importance, mean=1.0, std=0.1)

        # Learnable per-layer additive bias; the initial ones are overwritten
        # by the normal init right below.
        self.layer_importance_bias = nn.Parameter(torch.ones(len(self.t_sp)))
        nn.init.normal_(self.layer_importance_bias, mean=0.0, std=0.1)

    def _check_inputs(self, layer_activations, input_lrp):
        """Raise ``ValueError`` unless both inputs have one entry per layer."""
        if len(layer_activations) != len(self.t_sp):
            raise ValueError(f"layer activation count ({len(layer_activations)}) doesn't match structure count ({len(self.t_sp)})")

        if len(input_lrp) != len(self.t_sp):
            raise ValueError(f"LRP count ({len(input_lrp)}) doesn't match structure count ({len(self.t_sp)})")

    def _soft_scores(self, layer_activations, input_lrp, add_noise):
        """Compute ``sigmoid(y / T)`` per layer; shared by all public methods.

        LRP tensors are moved to the device of the first activation. When
        ``add_noise`` is true, Gumbel noise is added to the raw MLP score
        before the LRP term and the per-layer gain/bias are applied.
        """
        if not layer_activations:
            # Zero-layer configuration: nothing to score.
            return []

        device = layer_activations[0].device
        gains = F.softplus(self.layer_importance).to(device)
        biases = self.layer_importance_bias.to(device)

        scores = []
        for i, (activation, lrp, mlp) in enumerate(zip(layer_activations, input_lrp, self.layer_mlps)):
            lrp = lrp.to(device)
            if activation.shape != lrp.shape:
                raise ValueError(f"layer {i} activation shape {activation.shape} doesn't match LRP shape {lrp.shape}")

            # The MLP maps the last dimension t_sp[i] -> hidden_dim -> t_sp[i],
            # so the score has the same shape as the activation.
            y = mlp(activation)
            if add_noise:
                y = y + sample_gumbel(y.size()).to(device)

            y = y + lrp * self.lrp_scale
            y = y * gains[i] + biases[i]
            scores.append(torch.sigmoid(y / self.T))
        return scores

    @staticmethod
    def _binarize(soft_masks):
        """Round soft masks (straight-through) and revive all-zero samples.

        For every sample whose rounded mask is entirely zero, the position of
        the largest soft value is set to 1 so no sample ends up fully masked.
        Tensors are expected to have a leading batch dimension.
        """
        hard_masks = []
        for soft in soft_masks:
            hard = hard_sample(soft)
            # For 2-D tensors flatten(1) returns the tensor itself; for
            # contiguous higher-rank tensors it is a view, so the in-place
            # write below lands in ``hard``.
            flat_hard = hard.flatten(1)
            dead = flat_hard.sum(dim=1) == 0
            if dead.any():
                rows = dead.nonzero(as_tuple=True)[0]
                cols = soft.flatten(1)[rows].argmax(dim=1)
                flat_hard[rows, cols] = 1
            hard_masks.append(hard)
        return hard_masks

    def forward(self, layer_activations, input_lrp):
        """Return per-layer masks: soft in training mode, hard in eval mode.

        Args:
            layer_activations: Sequence of tensors, one per layer; the last
                dimension of entry ``i`` must equal ``t_structures[i]``.
            input_lrp: Sequence of tensors with the same shapes as the
                activations.

        Returns:
            List of tensors shaped like the activations. In training mode these
            are the noisy sigmoid outputs; in eval mode they are rounded with a
            straight-through estimator and every sample keeps at least one
            non-zero entry.

        Raises:
            ValueError: On a layer-count or per-layer shape mismatch.
        """
        self._check_inputs(layer_activations, input_lrp)

        # Temperature is re-derived only in eval mode, from an optional
        # ``_prog`` attribute that this class never sets itself (defaults to 0,
        # i.e. T == T_start unless T_start < T_end).
        # NOTE(review): confirm that annealing in eval rather than training
        # mode is intended.
        if not self.training:
            progress = getattr(self, "_prog", 0.0)
            self.T = max(self.T_end,
                         self.T_start - (self.T_start - self.T_end) * progress)

        soft_masks = self._soft_scores(layer_activations, input_lrp, add_noise=True)
        if self.training:
            return soft_masks
        return self._binarize(soft_masks)

    def get_imp_score(self, layer_activations, input_lrp):
        """Return noise-free soft importance scores for every layer.

        Same pipeline as ``forward`` but without Gumbel noise and without
        binarization; the current ``self.T`` is used unchanged.

        Args:
            layer_activations: Sequence of tensors, one per layer.
            input_lrp: Sequence of tensors with the same shapes as the
                activations; moved to the activation device if necessary.

        Returns:
            List of sigmoid score tensors shaped like the activations.

        Raises:
            ValueError: On a layer-count or per-layer shape mismatch.
        """
        self._check_inputs(layer_activations, input_lrp)
        return self._soft_scores(layer_activations, input_lrp, add_noise=False)

    def hard_output(self, layer_activations, input_lrp):
        """Return hard (rounded) masks regardless of training/eval mode.

        Uses the noisy scoring pipeline with the current ``self.T`` and then
        binarizes, guaranteeing at least one non-zero entry per sample.

        Args:
            layer_activations: Sequence of tensors, one per layer.
            input_lrp: Sequence of tensors with the same shapes as the
                activations.

        Returns:
            List of rounded mask tensors shaped like the activations.

        Raises:
            ValueError: On a layer-count or per-layer shape mismatch.
        """
        self._check_inputs(layer_activations, input_lrp)
        soft_masks = self._soft_scores(layer_activations, input_lrp, add_noise=True)
        return self._binarize(soft_masks)

