import torch
from torch import nn


class concrete_selector(torch.nn.Module):
    """Concrete (Gumbel-softmax) feature selector.

    Holds a learnable ``k x d`` logit matrix, initialised to all ones so every
    input feature starts out equally likely. Each call draws one relaxed
    one-hot sample per row (a softmax over the ``d`` input features) using the
    Gumbel-softmax trick, at a temperature obtained from a
    ``Temperature_Adjuster``.

    Args:
        input_shape: ``(k, d)`` pair -- number of rows of the selection matrix
            ``k`` and number of input features ``d``.
    """

    def __init__(self, input_shape):
        super(concrete_selector, self).__init__()
        k, d = input_shape
        self.logits = nn.Parameter(torch.ones(k, d), requires_grad=True)
        # Stateful temperature schedule; advanced once per gumbel_softmax() call.
        self.temp_hook = Temperature_Adjuster()

    def gumbel_softmax(self):
        """Sample a relaxed selection matrix from ``self.logits``.

        Returns:
            ``(samples.T, temperature)``: a ``d x k`` matrix whose columns each
            sum to 1, and the temperature used for this sample. Calling this
            advances ``self.temp_hook`` by one step.
        """
        # Gumbel(0, 1) noise is -log(-log(u)) with u ~ U(0, 1). torch.rand_like
        # samples from [0, 1), so clamp u away from 0 to keep the noise finite.
        tiny = torch.finfo(self.logits.dtype).tiny
        uniform = torch.rand_like(self.logits).clamp(min=tiny)
        gumbel = -torch.log(-torch.log(uniform))

        temperature = self.temp_hook.update()
        noisy_logits = (self.logits + gumbel) / temperature
        # Softmax over the last dim: each of the k rows becomes a distribution
        # over the d input features.
        samples = torch.softmax(noisy_logits, dim=-1)

        return samples.T, temperature

    def forward(self, input=None):
        """Draw a selection matrix and optionally apply it to ``input``.

        Args:
            input: optional tensor whose last dimension is ``d`` (presumably
                ``n x d``); when ``None`` only the matrix is returned.

        Returns:
            ``(FS_res, FS_matrix, temperature)`` when ``input`` is given, where
            ``FS_res = input @ FS_matrix`` (``n x k``) and ``FS_matrix`` is
            ``d x k``; otherwise ``(FS_matrix, temperature)``.
        """
        FS_matrix, temperature = self.gumbel_softmax()

        # Identity check: ``!= None`` would dispatch to Tensor.__ne__.
        if input is not None:
            FS_res = torch.matmul(input, FS_matrix)
            return FS_res, FS_matrix, temperature  # n x k, d x k
        return FS_matrix, temperature  # d x k
  

class concrete_selector_orth(torch.nn.Module):
    """Concrete (Gumbel-softmax) feature selector with orthonormalised logits.

    Stores a learnable ``d x k`` logit matrix ``W``. Before sampling, ``W`` is
    orthonormalised column-wise via a Cholesky factorisation of its Gram
    matrix: with ``L L^T = W^T W + 1e-6 I``, the columns of ``W L^{-T}`` are
    (up to the jitter) orthonormal, and its transpose (``k x d``) is used as
    the logits for Gumbel-softmax sampling.

    Args:
        input_shape: ``(k, d)`` pair -- number of rows of the selection matrix
            ``k`` and number of input features ``d``.
    """

    def __init__(self, input_shape):
        super(concrete_selector_orth, self).__init__()
        k, d = input_shape
        # NOTE(review): every column of ``torch.ones(d, k)`` is identical, so
        # ``logits.T @ logits`` is rank 1 at initialisation and only
        # ``eye_const`` (1e-6 * I) keeps it positive definite. For k > 1 and
        # larger d that jitter can be lost to float32 rounding, in which case
        # ``torch.linalg.cholesky`` raises on the first forward pass -- confirm
        # whether a full-rank initialisation is intended.
        self.logits = nn.Parameter(torch.ones(d, k), requires_grad=True)
        # Stateful temperature schedule; advanced once per gumbel_softmax() call.
        self.temp_hook = Temperature_Adjuster()
        # Diagonal jitter added to the Gram matrix before Cholesky. Kept as a
        # frozen Parameter so parameters()/state_dict() stay unchanged.
        self.eye_const = nn.Parameter((1e-6) * torch.eye(k), requires_grad=False)

    def gumbel_softmax(self, logits):
        """Draw one relaxed one-hot sample per row of ``logits``.

        Args:
            logits: ``k x d`` tensor of logits.

        Returns:
            ``(samples.T, temperature)``: a ``d x k`` matrix whose columns each
            sum to 1, and the temperature used for this sample. Calling this
            advances ``self.temp_hook`` by one step.
        """
        # Gumbel(0, 1) noise is -log(-log(u)) with u ~ U(0, 1). torch.rand_like
        # samples from [0, 1), so clamp u away from 0 to keep the noise finite.
        tiny = torch.finfo(logits.dtype).tiny
        uniform = torch.rand_like(logits).clamp(min=tiny)
        gumbel = -torch.log(-torch.log(uniform))

        temperature = self.temp_hook.update()
        noisy_logits = (logits + gumbel) / temperature
        # Softmax over the last dim: each of the k rows becomes a distribution
        # over the d input features.
        samples = torch.softmax(noisy_logits, dim=-1)

        return samples.T, temperature  # d x k

    def forward(self, input=None):
        """Orthonormalise the logits, sample, and optionally apply to ``input``.

        Args:
            input: optional tensor whose last dimension is ``d`` (presumably
                ``n x d``); when ``None`` only the matrix is returned.

        Returns:
            ``(FS_res, FS_matrix, temperature)`` when ``input`` is given, where
            ``FS_res = input @ FS_matrix`` (``n x k``) and ``FS_matrix`` is
            ``d x k``; otherwise ``(FS_matrix, temperature)``.

        Raises:
            torch.linalg.LinAlgError: from ``torch.linalg.cholesky`` if the
                jittered Gram matrix is not positive definite.
        """
        gram = torch.matmul(self.logits.T, self.logits) + self.eye_const  # k x k
        L = torch.linalg.cholesky(gram)  # k x k, lower triangular
        L_inv = torch.linalg.inv(L)  # k x k
        # W @ L^{-T} has (approximately) orthonormal columns; transpose to k x d.
        orth_logits = torch.matmul(self.logits, L_inv.T).T

        FS_matrix, temperature = self.gumbel_softmax(orth_logits)

        # Identity check: ``!= None`` would dispatch to Tensor.__ne__.
        if input is not None:
            FS_res = torch.matmul(input, FS_matrix)
            return FS_res, FS_matrix, temperature  # n x k, d x k
        return FS_matrix, temperature  # d x k
  

class Temperature_Adjuster:
    """Geometric temperature schedule with a lower bound.

    Every call to ``update`` multiplies the current temperature by ``alpha``
    and then clips the result from below at ``min_temp``.

    Args:
        start_temp: temperature held before the first ``update()`` call.
        min_temp: lower bound applied on each ``update()``.
        alpha: multiplicative factor applied per ``update()`` (a value below 1
            makes the temperature decay).
    """

    def __init__(self, start_temp=4.0, min_temp=0.01, alpha=0.99):
        self.temperature = start_temp
        self.alpha = alpha
        self.min_temp = min_temp

    def update(self):
        """Advance the schedule by one step and return the new temperature."""
        decayed = self.temperature * self.alpha
        # Same selection rule as max(min_temp, decayed).
        self.temperature = decayed if decayed > self.min_temp else self.min_temp
        return self.temperature
