import torch
import torch.nn as nn
from torch.distributions import RelaxedOneHotCategorical


def restore_parameters(model, best_model):
    '''
    Copy parameter values from the best model into the current model.

    Args:
      model: module whose parameters are overwritten.
      best_model: module providing the values; its parameters are paired
        with those of `model` in `parameters()` order.
    '''
    for param, best_param in zip(model.parameters(), best_model.parameters()):
        # Clone so the two models never share storage; otherwise later
        # in-place updates to `model` (e.g. optimizer steps) would silently
        # modify `best_model` as well.
        param.data = best_param.data.clone()
        
        
def generate_bernoulli_mask(shape, p):
    '''
    Sample a random binary mask with independent entries.

    Args:
      shape: size passed to torch.rand.
      p: threshold compared against uniform noise; each entry is 1 where the
        noise is below p and 0 otherwise.
    '''
    noise = torch.rand(shape)
    return torch.lt(noise, p).float()


def generate_uniform_mask(batch_size, num_features):
    '''
    Sample a random binary mask using one uniform threshold per row.

    Args:
      batch_size: number of rows in the mask.
      num_features: number of columns in the mask.
    '''
    # Feature noise is drawn before the thresholds; keep this order so that
    # seeded runs produce the same masks.
    feature_noise = torch.rand(batch_size, num_features)
    row_threshold = torch.rand(batch_size, 1)

    # The (batch_size, 1) thresholds broadcast across each row's features.
    return torch.gt(feature_noise, row_threshold).float()
    


def make_onehot(x):
    '''
    Make an approximately one-hot vector one-hot.

    Args:
      x: tensor whose rows are converted; the largest entry along dim 1 of
        each row becomes 1 and every other entry becomes 0.
    '''
    # Start from zeros with the same shape, dtype and device as the input.
    hot = x.new_zeros(x.shape)

    # Pair every row index with the column of that row's maximum.
    row_idx = torch.arange(len(x))
    col_idx = x.argmax(dim=1)
    hot[row_idx, col_idx] = 1
    return hot


class MaskLayer(nn.Module):
    '''
    Mask layer for tabular data.

    Args:
      append: whether to concatenate the mask to the masked input along
        dim 1.
      mask_size: stored as an attribute; not used by this layer itself.
    '''
    def __init__(self, append, mask_size=None):
        super().__init__()
        self.mask_size = mask_size
        self.append = append

    def forward(self, x, m):
        # Zero out the entries of x where the mask is 0.
        masked = x * m
        if not self.append:
            return masked

        # Expose the mask itself alongside the masked values.
        return torch.cat([masked, m], dim=1)
    
    
class MaskLayerGrouped(nn.Module):
    '''
    Mask layer for tabular data with feature grouping.

    Args:
      group_matrix: binary matrix with one row per group; every column must
        sum to exactly one. It is stored as a float buffer and used to
        expand a group-level mask via `m @ group_matrix`.
      append: whether to concatenate the group-level mask to the masked
        input along dim 1.
    '''
    def __init__(self, group_matrix, append):
        # Verify group matrix: each column sums to one, entries are 0 or 1.
        column_totals = group_matrix.sum(dim=0)
        assert torch.all(column_totals == 1)
        is_zero = group_matrix == 0
        is_one = group_matrix == 1
        assert torch.all(is_zero | is_one)

        # Initialize.
        super().__init__()
        self.append = append
        self.mask_size = len(group_matrix)
        self.register_buffer('group_matrix', group_matrix.float())

    def forward(self, x, m):
        # Expand the group-level mask before applying it to x.
        expanded_mask = torch.matmul(m, self.group_matrix)
        masked = x * expanded_mask
        if not self.append:
            return masked

        # The appended mask stays at group granularity.
        return torch.cat([masked, m], dim=1)


class MaskLayer2d(nn.Module):
    '''
    Mask layer for 2d image data.

    Args:
      append: whether to concatenate the upsampled mask to the masked input
        along dim 1.
      mask_width: side length of the square mask; flat masks are reshaped to
        (-1, 1, mask_width, mask_width).
      patch_size: upsampling factor applied to the mask before it is
        multiplied with the input; must be >= 1.
    '''

    # TODO change argument order, including in CIFAR notebooks
    def __init__(self, append, mask_width, patch_size):
        super().__init__()
        self.append = append
        self.mask_width = mask_width
        self.mask_size = mask_width ** 2
        self.patch_size = patch_size

        # Set up upsampling; a patch size of one needs no resizing.
        if patch_size > 1:
            self.upsample = nn.Upsample(scale_factor=patch_size)
        elif patch_size == 1:
            self.upsample = nn.Identity()
        else:
            raise ValueError('patch_size should be int >= 1')

    def forward(self, x, m):
        # Only flat (2d) or already-shaped (4d) masks are supported.
        rank = m.dim()
        if rank not in (2, 4):
            raise ValueError(f'cannot determine how to reshape mask with shape = {m.shape}')
        if rank == 2:
            width = self.mask_width
            m = m.reshape(-1, 1, width, width)

        # Apply mask.
        m = self.upsample(m)
        masked = x * m
        if not self.append:
            return masked
        return torch.cat([masked, m], dim=1)

class StaticMaskLayer2d(nn.Module):
    '''
    Static mask layer for 2d image data.

    The mask is fixed at construction time: it is reshaped if necessary,
    upsampled once, and then applied to every input passed to `forward`.

    Args:
      mask: mask tensor, either 2d (reshaped to
        (-1, 1, mask_width, mask_width)) or already 4d.
      mask_width: side length of the square mask.
      patch_size: upsampling factor applied to the mask; must be >= 1.
      append: whether to concatenate the upsampled mask to the masked input
        along dim 1.

    Raises:
      ValueError: if the mask is neither 2d nor 4d, or if patch_size < 1.
    '''

    def __init__(self, mask, mask_width, patch_size, append=False):
        super().__init__()
        self.mask_width = mask_width
        self.mask_size = mask_width ** 2
        self.patch_size = patch_size
        self.append = append

        # Reshape if necessary.
        mask = torch.as_tensor(mask)
        if mask.dim() == 2:
            mask = mask.reshape(-1, 1, self.mask_width, self.mask_width)
        elif mask.dim() != 4:
            raise ValueError(f'cannot determine how to reshape mask with shape = {mask.shape}')

        # Set up upsampling.
        if patch_size == 1:
            self.upsample = nn.Identity()
        elif patch_size > 1:
            self.upsample = nn.Upsample(scale_factor=patch_size)
        else:
            raise ValueError('patch_size should be int >= 1')

        # Store the upsampled mask as a buffer so it follows the module
        # across devices and dtypes. It is non-persistent so the state_dict
        # keys are the same as when the mask was a plain attribute.
        self.register_buffer('mask', self.upsample(mask), persistent=False)

    def forward(self, x):
        # Apply mask.
        out = x * self.mask
        if self.append:
            # Concatenation needs matching batch sizes, so broadcast the
            # stored mask across the batch dimension of x.
            mask = self.mask.expand(x.shape[0], -1, -1, -1)
            out = torch.cat([out, mask], dim=1)
        return out

class ConcreteSelector(nn.Module):
    '''
    Output layer for selector models.

    Args:
      gamma: factor that the logits are divided by before sampling or
        applying the softmax.
    '''

    def __init__(self, gamma=0.2):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, temp, deterministic=False):
        if not deterministic:
            # Draw a reparameterized sample from the relaxed one-hot
            # distribution over the rescaled logits.
            scaled_logits = logits / self.gamma
            relaxed = RelaxedOneHotCategorical(temp, logits=scaled_logits)
            return relaxed.rsample()

        # TODO this is somewhat untested, but seems like best way to preserve argmax
        scale = self.gamma * temp
        return torch.softmax(logits / scale, dim=-1)
