import torch
import torch.nn as nn
from torch.distributions import RelaxedOneHotCategorical


def restore_parameters(model, best_model):
    '''Copy parameter values from best_model into model, in place.

    Parameters are paired positionally via ``parameters()`` of each module.
    Values are copied rather than rebound, so the two models keep separate
    storage: later in-place updates to ``model`` (e.g. further optimizer
    steps) do not alter ``best_model``.

    Args:
      model: module whose parameters are overwritten.
      best_model: module providing the parameter values.
    '''
    pairs = zip(model.parameters(), best_model.parameters())
    # no_grad: in-place writes to leaf tensors that require grad are only
    # permitted outside of autograd tracking.
    with torch.no_grad():
        for param, best_param in pairs:
            param.copy_(best_param)


class MaskLayer(nn.Module):
    '''Apply an elementwise mask to the input.

    Args:
      append: if truthy, the mask is concatenated to the masked input
        along dim 1.
    '''

    def __init__(self, append):
        super().__init__()
        self.append = append

    def forward(self, x, m):
        '''Return x * m, optionally followed by m along dim 1.'''
        masked = x * m
        if not self.append:
            return masked
        return torch.cat([masked, m], dim=1)


# class SoftmaxSelector(nn.Module):
#     '''Not successful.'''

#     def __init__(self):
#         super().__init__()
    
#     def forward(self, logits, temp):
#         return torch.softmax(logits / temp, dim=1)


# class STSelector(nn.Module):
#     '''Not successful.'''

#     def __init__(self):
#         super().__init__()
    
#     def forward(self, logits, temp):
#         argmax = torch.argmax(logits, dim=1)
#         onehot = torch.zeros(logits.shape, dtype=logits.dtype,
#                              device=logits.device)
#         onehot[torch.arange(len(logits)), argmax] = 1
#         softmax = torch.softmax(logits / temp, dim=1)
#         return softmax + (onehot - softmax).detach()


class ConcreteSelector(nn.Module):
    '''Output layer for selector models.

    Draws a reparameterized sample from a RelaxedOneHotCategorical
    distribution whose logits are the input logits divided by gamma.

    Args:
      gamma: divisor applied to the logits before sampling.
    '''

    def __init__(self, gamma=0.2):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, temp):
        '''Sample from the relaxed distribution at temperature temp.'''
        scaled_logits = logits / self.gamma
        return RelaxedOneHotCategorical(temp, logits=scaled_logits).rsample()


# class ConcreteSTSelector(nn.Module):
#     '''Not successful.'''

#     def __init__(self, gamma=0.2):
#         super().__init__()
#         self.gamma = gamma
        
#     def forward(self, logits, temp):
#         dist = RelaxedOneHotCategorical(temp, logits=logits / self.gamma)
#         sample = dist.rsample()
#         argmax = torch.argmax(sample, dim=1)
#         onehot = torch.zeros(sample.shape, dtype=sample.dtype,
#                              device=sample.device)
#         onehot[torch.arange(len(sample)), argmax] = 1
#         return sample + (onehot - sample).detach()


class ConcreteMask(nn.Module):
    '''For global feature selection baseline.

    Holds a learnable (num, dim) logits matrix, initialized from a standard
    normal. Each forward pass draws one relaxed one-hot sample per input
    row and per logits row, then reduces over the num axis with a max to
    form a mask that multiplies the input.

    Args:
      dim: size of the last axis of the logits.
      num: number of logits rows.
      append: if truthy, the mask is concatenated to the masked input
        along dim 1.
    '''

    def __init__(self, dim, num, append=True):
        super().__init__()
        self.logits = nn.Parameter(torch.randn(num, dim, dtype=torch.float32))
        self.append = append

    def forward(self, x, temp):
        '''Sample a mask at temperature temp and apply it to x.'''
        batch_size = len(x)
        relaxed = RelaxedOneHotCategorical(temp, logits=self.logits)
        # Samples have shape (batch, num, dim); max over the num axis.
        draws = relaxed.rsample([batch_size])
        m, _ = torch.max(draws, dim=1)
        masked = x * m
        if not self.append:
            return masked
        return torch.cat([masked, m], dim=1)


class Accuracy(nn.Module):
    '''0-1 accuracy: mean of (argmax of pred over dim 1 == target).'''

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        '''Return the fraction of entries whose argmax matches target.'''
        predicted_labels = pred.argmax(dim=1)
        hits = predicted_labels.eq(target)
        return hits.float().mean()


class NegAccuracy(nn.Module):
    '''Negative accuracy, for use as validation loss.'''

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        '''Return minus the fraction of entries whose argmax matches target.'''
        predicted_labels = pred.argmax(dim=1)
        accuracy = predicted_labels.eq(target).float().mean()
        return -accuracy
