import torch
import torch.nn as nn
from torch.distributions.transforms import Transform

class BijectLeakyReLU(nn.LeakyReLU):
    """Leaky ReLU that also exposes its inverse and log|det Jacobian|.

    The forward pass is inherited unchanged from ``nn.LeakyReLU``. The map is
    only invertible when ``negative_slope`` is non-zero; the inverse below
    additionally assumes a positive slope, so that the sign of the output
    identifies which linear piece produced it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def inverse(self, x):
        """Undo the activation elementwise.

        Args:
            x: output of the forward pass (tensor or scalar).

        Returns:
            Tensor with positive entries kept as-is and all other entries
            divided by ``negative_slope``.
        """
        x = torch.as_tensor(x)
        return torch.where(x > 0, x, x / self.negative_slope)

    def log_abs_det_jacobian(self, x):
        """Return the per-sample log|det J| of the activation evaluated at ``x``.

        The Jacobian is diagonal with entries 1 where ``x > 0`` and
        ``negative_slope`` elsewhere, so the log-determinant of one sample is
        (number of non-positive entries) * log|negative_slope|.

        Args:
            x: input of the forward pass; the first dimension is the batch.

        Returns:
            Tensor of shape ``(batch_size,)``.
        """
        batch_size = x.shape[0]
        # Built from ``x`` so the result lives on the same device/dtype.
        log_slope = torch.log(x.new_tensor(abs(self.negative_slope)))
        n_scaled = (x <= 0).reshape(batch_size, -1).sum(dim=1)
        return n_scaled.to(log_slope.dtype) * log_slope

class BijectBatchNorm(nn.BatchNorm1d):
    """``nn.BatchNorm1d`` that always tracks running statistics.

    ``inverse`` and ``log_abs_det_jacobian`` are placeholders: they are not
    implemented yet and return ``None``.
    """

    def __init__(self, *args, **kwargs):
        # ensure we ALWAYS track running stats, whether the caller passed the
        # flag by keyword or positionally. The positional order of BatchNorm1d
        # is (num_features, eps, momentum, affine, track_running_stats, ...),
        # so the flag sits at index 4.
        if len(args) > 4:
            args = args[:4] + (True,) + args[5:]
        else:
            kwargs['track_running_stats'] = True

        super().__init__(*args, **kwargs)

    def forward(self, x):
        """Apply batch normalisation exactly as the parent class does."""
        return super().forward(x)

    def inverse(self):
        """Placeholder for the inverse transform (not implemented)."""
        pass

    def log_abs_det_jacobian(self):
        """Placeholder for the log|det J| term (not implemented)."""
        pass

class AffineCoupling(nn.Module):
    """Affine coupling layer.

    Entries selected by ``mask`` pass through unchanged and condition an
    elementwise affine map applied to the remaining entries::

        y[mask]  = x[mask]
        y[~mask] = x[~mask] * scale_fn(x[mask]) + bias_fn(x[mask])

    Args:
        scale_fn: callable mapping ``x[mask]`` to a scale that broadcasts
            against ``x[~mask]``.
        bias_fn: callable mapping ``x[mask]`` to a bias that broadcasts
            against ``x[~mask]``.
        mask: boolean mask used to index the input directly (``x[mask]``).
            NOTE(review): the indexing implies the mask covers the leading
            dimensions of ``x`` (batch included) - confirm with callers.
    """

    def __init__(self, scale_fn, bias_fn, mask):
        super().__init__()
        # Stored as a buffer so the mask follows the module across devices.
        self.register_buffer('_mask', torch.as_tensor(mask, dtype=torch.bool))
        self.scale_fn = scale_fn
        self.bias_fn = bias_fn

    def forward(self, x):
        """Transform ``x``; the masked entries are copied through unchanged."""
        mask = self._mask
        conditioner = x[mask]
        y = torch.zeros_like(x)
        y[mask] = conditioner
        y[~mask] = (x[~mask]*self.scale_fn(conditioner) +
                    self.bias_fn(conditioner))

        return y

    def inverse(self, y):
        """Recover ``x`` from ``y``.

        The masked entries of ``y`` equal those of ``x``, so the same scale
        and bias used in the forward pass can be recomputed from them.
        """
        mask = self._mask
        conditioner = y[mask]
        x = torch.zeros_like(y)
        x[mask] = conditioner
        x[~mask] = ((y[~mask] - self.bias_fn(conditioner)) /
                    self.scale_fn(conditioner))

        return x

    def log_abs_det_jacobian(self, x):
        """Return per-sample log|det J|: the sum of log|scale| over the
        transformed entries, shape ``(batch_size,)``."""
        batch_size = x.size(0)
        scale = self.scale_fn(x[self._mask])
        return torch.log(torch.abs(scale)).reshape(batch_size, -1).sum(dim=1)

class ActNorm(nn.Module):
    """Per-channel affine transform with a single scale and bias per channel.

    The parameters broadcast against the last dimension of the input, and
    ``log_abs_det_jacobian`` unpacks the input shape as
    ``(batch, H, W, channels)``, i.e. a channels-last layout.
    """

    def __init__(self, n_channels):
        """single scale and bias per channel"""
        super().__init__()

        self.scale = nn.Parameter(torch.zeros(n_channels).normal_())
        self.bias = nn.Parameter(torch.zeros(n_channels).normal_())

    def forward(self, x):
        """Return ``x * scale + bias``."""
        return x*self.scale + self.bias

    def inverse(self, y):
        """Return ``(y - bias) / scale``."""
        return (y-self.bias)/self.scale

    def log_abs_det_jacobian(self, x):
        """Return per-sample log|det J| with shape ``(batch_size,)``.

        Every channel scale is applied once per spatial position, hence the
        factor ``H * W``.
        """
        batch_size, H, W, _ = x.shape
        return H*W*torch.log(torch.abs(self.scale)).sum().expand(batch_size)


def get_linear_ar_mask(c_out, c_in, upper=True):
    """Build a square triangular 0/1 mask for an autoregressive channel mix.

    Args:
        c_out: number of output channels (rows).
        c_in: number of input channels (columns); must equal ``c_out``.
        upper: if True, keep the upper triangle (diagonal included);
            otherwise keep the lower triangle (diagonal included).

    Returns:
        ``float32`` tensor of shape ``(c_out, c_in)``.

    Raises:
        ValueError: if ``c_out != c_in``.
    """
    if c_out != c_in:
        raise ValueError("Not equal!! : c_out=%d, c_in=%d" % (c_out, c_in))

    full = torch.ones(c_out, c_in, dtype=torch.float32)
    return torch.triu(full) if upper else torch.tril(full)

def get_invertible_filter_mask(kernel_size, n_channels, upper=True):
    """Build the 0/1 mask of a convolution filter with a triangular centre tap.

    With ``m = (kernel_size - 1) // 2`` as the centre index:

    * ``upper=True`` zeroes the first ``m`` rows and columns, leaving the
      bottom-right corner of the filter, and makes the channel matrix at the
      centre tap upper triangular.
    * ``upper=False`` zeroes the last ``m`` rows and columns, leaving the
      top-left corner, and makes the centre tap lower triangular.

    Args:
        kernel_size: spatial size of the (square) filter.
        n_channels: number of input and output channels.
        upper: selects which of the two masks to build.

    Returns:
        Tuple ``(mask, m)`` where ``mask`` is a ``long`` tensor of shape
        ``(n_channels, n_channels, kernel_size, kernel_size)`` and ``m`` is
        the centre index (also the padding a "same" convolution needs).
    """
    m = (kernel_size - 1) // 2
    mask = torch.ones([n_channels, n_channels, kernel_size, kernel_size],
                      dtype=torch.long)
    # Same triangular convention as get_linear_ar_mask (diagonal included).
    channel_ones = torch.ones(n_channels, n_channels, dtype=torch.long)
    if upper:
        mask[:, :, :m, :] = 0
        mask[:, :, :, :m] = 0
        mask[:, :, m, m] = torch.triu(channel_ones)
    else:
        # For m == 0 the slice ``-m:`` selects everything; the centre tap is
        # written afterwards, so the result is still correct.
        mask[:, :, -m:, :] = 0
        mask[:, :, :, -m:] = 0
        mask[:, :, m, m] = torch.tril(channel_ones)
    return mask, m

def init_emerging_filters(c_out, c_in, kernel_size, mask):
    """Draw a masked random filter whose centre tap starts as the identity.

    Args:
        c_out: unused; kept for interface compatibility.
        c_in: unused; kept for interface compatibility.
        kernel_size: spatial filter size, used to locate the centre tap.
        mask: tensor whose shape defines the filter shape and which is
            multiplied elementwise into the result.

    Returns:
        Tensor shaped like ``mask``: standard-normal entries, identity channel
        matrix at the centre tap, all multiplied by ``mask``.
    """
    center = kernel_size // 2
    weights = torch.zeros(mask.shape)
    weights.normal_()
    # An identity centre tap keeps the initial filter well conditioned, so the
    # inversion is stable (in the spirit of adding to the diagonal, see
    # https://en.wikipedia.org/wiki/Tikhonov_regularization).
    center_tap = weights[:, :, center, center]
    torch.nn.init.eye_(center_tap)
    return weights * mask

class OneConv(nn.Module):
    """Invertible 1x1 convolution: a learned channel-mixing matrix.

    The input is multiplied on its last dimension by ``W``, and
    ``log_abs_det_jacobian`` unpacks the input shape as
    ``(batch, H, W, channels)``, i.e. a channels-last layout.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.W = nn.Parameter(torch.zeros(n_channels, n_channels).normal_())

    def forward(self, x):
        """Return ``x @ W`` (mixes the last dimension of ``x``)."""
        # CHECK X has shape B x H x W x C

        return torch.matmul(x, self.W)

    def W_invert(self):
        """Return the inverse of the mixing matrix."""
        # See various decompositions for cheaper/more stable alternatives.
        return torch.inverse(self.W)

    def inverse(self, y):
        """Return ``y @ W^-1``."""
        return torch.matmul(y, self.W_invert())

    def log_abs_det_jacobian(self, x):
        """Return per-sample log|det J| with shape ``(batch_size,)``.

        The same matrix is applied at each of the ``H * W`` positions, so the
        log-determinant is ``H * W * log|det W|``.
        """
        # See paper for potentially better LU decomposition
        batch_size, H, W, _ = x.shape
        _, log_abs_det = torch.slogdet(self.W)
        return H*W*log_abs_det.expand(batch_size)

class EmergingConv(nn.Module):
    """Invertible convolution composed of two masked ("emerging") filters.

    The forward pass applies ``kernel1`` (masked with the ``upper`` mask from
    ``get_invertible_filter_mask``) followed by ``kernel2`` (``lower`` mask).
    Inputs are expected as ``N x C x H x W`` tensors with
    ``(H, W) == input_shape``.

    Args:
        input_shape: ``(H, W)`` of the feature maps.
        kernel_size: odd spatial size of the square filters.
        n_channels: number of input (= output) channels.
        bias: currently unused; accepted for interface compatibility.
    """

    # currently no dilation!
    def __init__(self, input_shape, kernel_size, n_channels, bias=False):
        super().__init__()

        assert (kernel_size - 1) % 2 == 0, "kernel size (dxd) should have d uneven"
        self.H, self.W = input_shape
        self.kernel_size = kernel_size
        self.n_channels = n_channels

        mask_w1, pad_size = get_invertible_filter_mask(kernel_size, n_channels, upper=True)
        mask_w2, _ = get_invertible_filter_mask(kernel_size, n_channels, upper=False)
        # Buffers (not parameters): fixed masks that follow the module's device.
        self.register_buffer('mask_w1', mask_w1.float())
        self.register_buffer('mask_w2', mask_w2.float())
        self._pad = pad_size

        # consider making diagonal strictly positive as init --> ensure positive definite
        kernel_shape = (n_channels, n_channels, kernel_size, kernel_size)
        self.kernel1 = nn.Parameter(torch.zeros(kernel_shape).normal_()*self.mask_w1)
        self.kernel2 = nn.Parameter(torch.zeros(kernel_shape).normal_()*self.mask_w2)

        # padding left, right, top, bottom of a N x C x H x W tensor
        self._pad_w1 = nn.ConstantPad2d((0, pad_size, 0, pad_size), 0.)
        self._pad_w2 = nn.ConstantPad2d((pad_size, 0, pad_size, 0), 0.)

        # Permutation taking a channel-major (C, H*W) flattening to a
        # pixel-major one (pixel 0: all channels, pixel 1: all channels, ...).
        HW = self.H*self.W
        index = torch.arange(HW*n_channels).view(n_channels, HW).t().reshape(-1)
        self.register_buffer('alternate_channel_index', index)

    def _masked_kernel(self, upper):
        """Return the full-size filter with its mask re-applied, so masked
        entries stay zero even after gradient updates."""
        if upper:
            return self.kernel1*self.mask_w1
        return self.kernel2*self.mask_w2

    def forward(self, x):
        """Apply both masked convolutions; the output has the input's shape."""
        # for fast evaluation crop the zero-valued elements and adjust using padding
        m = self._pad
        # upper mask: only rows/cols >= m are non-zero -> pad right/bottom
        k1 = self._masked_kernel(upper=True)[:, :, m:, m:]
        # lower mask: only rows/cols <= m are non-zero -> pad left/top
        k2 = self._masked_kernel(upper=False)[:, :, :m + 1, :m + 1]

        x = nn.functional.conv2d(self._pad_w1(x), k1)
        x = nn.functional.conv2d(self._pad_w2(x), k2)

        return x

    def _solve_lower(self, y, kernel, d):
        """Invert one lower-masked convolution with the sequential solver.

        The solver works on a pixel-major, channel-minor flattening, so the
        ``N x C x H x W`` tensor is moved to channels-last before flattening
        and restored afterwards.
        """
        batch_size = y.size(0)
        y_flat = y.permute(0, 2, 3, 1).reshape(batch_size, -1)
        # Detached copy: the solver is handed its own filter tensor.
        x_flat = conv2d_tri_inv_filter_lower(y_flat, kernel.detach().clone(), d,
                                             self.W, self.n_channels)
        x = x_flat.reshape(batch_size, self.H, self.W, self.n_channels)
        return x.permute(0, 3, 1, 2)

    def inverse(self, y):
        """Invert ``forward``: undo ``kernel2`` first, then ``kernel1``.

        ``kernel1`` uses the upper mask. Flipping the channel and both spatial
        axes of its input/output, together with all four axes of the filter,
        turns it into an equivalent lower-masked convolution, which the
        lower-triangular solver can then handle.
        NOTE(review): relies on ``conv2d_tri_inv_filter_lower`` solving the
        lower-triangular system correctly - verify numerically against
        ``forward``.
        """
        d = (self.kernel_size + 1)//2

        z = self._solve_lower(y, self._masked_kernel(upper=False), d)

        flipped_kernel = self._masked_kernel(upper=True).flip(0, 1, 2, 3)
        x = self._solve_lower(z.flip(1, 2, 3), flipped_kernel, d).flip(1, 2, 3)
        return x.reshape(y.shape)

    def log_abs_det_jacobian(self, x):
        """Return per-sample log|det J| with shape ``(batch_size,)``.

        Each masked convolution is a triangular linear map whose diagonal
        consists of the centre-tap diagonal ``kernel[c, c, m, m]`` repeated at
        every spatial position, so the log-determinant is ``H * W`` times the
        sum of the log-absolute diagonal entries of both filters.
        """
        m = self._pad
        diag1 = torch.diagonal(self.kernel1[:, :, m, m])
        diag2 = torch.diagonal(self.kernel2[:, :, m, m])
        logdet = torch.log(torch.abs(diag1)).sum() + torch.log(torch.abs(diag2)).sum()
        return self.H*self.W*logdet.expand(x.size(0))

    def visualize_kernel(self, upper=True):
        """ Visualize one of the kernels as a matrix such that

        y=Kx, where y is the convolution of x

        K should be either upper or lower triangular

        Args:
            upper: show ``kernel1`` (upper mask) if True, else ``kernel2``.
        """
        import matplotlib.pyplot as plt
        pad = self._pad
        new_H = self.H + 2*pad
        new_W = self.W + 2*pad
        kernel = self._masked_kernel(upper).detach().cpu()
        K = convmatrix2d(kernel, self.n_channels, new_H, new_W)

        # Columns of K index the padded input (channel, row, col); keep only
        # the columns that belong to real (unpadded) pixels.
        N = self.H*self.W*self.n_channels
        K = K.reshape(N, self.n_channels, new_H, new_W)
        K = K[:, :, pad:pad + self.H, pad:pad + self.W].reshape(N, N)

        # Reorder rows and columns to pixel-major order so the triangular
        # structure becomes visible.
        order = self.alternate_channel_index.cpu()
        K = K[order, :]
        K = K[:, order]
        plt.imshow(K.numpy())
        plt.show()

class PeriodicConv(nn.Module):
    """Placeholder for a periodic (circular) invertible convolution.

    None of the transforms are implemented yet; every method returns ``None``.
    """

    def __init__(self):
        # Required so the instance is a properly initialised nn.Module.
        super().__init__()

    def inverse(self):
        """Placeholder for the inverse transform (not implemented)."""
        pass

    def invese(self):
        """Misspelled legacy name kept for compatibility; use ``inverse``."""
        return self.inverse()

    def log_abs_det_jacobian(self, x):
        """Placeholder for the log|det J| term (not implemented)."""
        pass

def channel_squeeze(x):
    """Placeholder for a channel-squeeze operation (not implemented).

    Args:
        x: ignored.

    Returns:
        Always ``None``.
    """
    return None


def conv2d_tri_inv_filter_lower(y, w, d, img_w, n_channels):
    """ Inverts y from a emerging convolution

    Solves K x = y by forward substitution, one (pixel, channel) entry at a
    time, where K is the matrix form of the convolution with filter ``w``.

    Args:

    y (:obj: tensor): tensor to be inverted must be flatten to be shape (batch_size, -1).
                      Entries are ordered pixel by pixel with the channels of a
                      pixel adjacent (first pixel first channel, first pixel
                      second channel, etc.).
    w (:obj:tensor): c_out, c_in, kH, kW, convolution filter used in the emerging context. Must
                     satisfy in_channel==out_channel when looking at the equivalent convolution
                     form y=Kx, K must be LOWER TRIANGULAR. Only the top-left
                     d x d corner is read; the filter is not modified.
    d (int): effective (non-zero) width of filters
    img_w (int): non-zero padded image width
    n_channels (int): number of input and output channels

    Returns:

    x = inv(y), a float tensor with the shape of y

    """
    _, N = y.shape
    x = torch.zeros_like(y, dtype=torch.float)
    jumps_up = 0
    dc = d*n_channels
    imgwc = img_w*n_channels

    # Scale the filter up when its smallest diagonal entry is below 1 in
    # magnitude, and undo the scaling on the solution at the end. A scaled
    # copy is used so the caller's filter is left untouched.
    min_abs_diag = w[:, :, d-1, d-1].diag().abs().min().detach()
    normalizer = torch.clamp(1 / min_abs_diag, min=1)
    w_scaled = w * normalizer

    for n in range(N):
        idx = n
        c = n % n_channels
        w_out = w_scaled[c, :]  # in_channel x kH x kW
        solved_x_weighted_sum = 0.
        if n > 0:
            pix_out = n // n_channels
            # get the n closest to the value we try and solve for

            tmp = (n % imgwc)
            if tmp < dc:
                n_closest = tmp
            else:
                n_closest = dc - n_channels + c

            n_closest_indices = torch.arange(n-n_closest, n)

            # get the n trailing, which are spanning backwards in the
            # original image - they are lagging behind by an "n_channels"
            # times img_w number of indices. This occurs n_channels times
            n_trailing_base_indices = torch.arange(n-n_closest, n - c + n_channels)
            jumps_up = pix_out // img_w if jumps_up < d - 1 else d - 1
            list_block_indices = [n_trailing_base_indices - imgwc*jump for jump in range(1, jumps_up+1)[::-1]]

            # will contain <=d number of unique lists
            if n_closest > 0:
                list_block_indices.append(n_closest_indices)

            init_row = d - jumps_up - 1
            for rel_row, index_list in enumerate(list_block_indices):
                solved_x = x[:, index_list]
                channel_list = index_list % n_channels
                n_moves = len(index_list) // n_channels
                row = init_row + rel_row
                if row == d - 1:
                    offsets = range(0, n_moves + 1)[::-1]
                else:
                    offsets = range(0, n_moves)[::-1]
                # Filter weight that multiplies each already-solved entry.
                solved_x_weights = torch.stack(
                    [w_out[c_in, row, d - offsets[i//n_channels] - 1]
                     for i, c_in in enumerate(channel_list)]).to(x.dtype)
                solved_x_weighted_sum += torch.matmul(solved_x, solved_x_weights)

        # Forward substitution: subtract the known contributions and divide
        # by the diagonal entry.
        x[:, idx] = (y[:, idx] - solved_x_weighted_sum) / w_out[c, d-1, d-1]

    x *= normalizer
    return x


def convmatrix2d(kernel, n_channels, input_h, input_w):
    """Expand a convolution filter into its dense matrix form.

    Builds the matrix ``K`` of an unpadded, stride-1 sliding-window
    correlation: each row holds the filter placed at one output position.

    Args:
        kernel: filter of shape ``(c_out, n_channels, kH, kW)``.
        n_channels: number of input channels; the row count also uses this
            value, so ``c_out`` is expected to equal it.
        input_h: height of the input image.
        input_w: width of the input image.

    Returns:
        Tensor of shape ``(out_h * out_w * n_channels,
        n_channels * input_h * input_w)`` with ``out_h = input_h - kH + 1``
        and ``out_w = input_w - kW + 1``.
    """
    k_h, k_w = kernel.size(2), kernel.size(3)
    out_h = input_h - k_h + 1
    out_w = input_w - k_w + 1
    n_cols = n_channels * input_h * input_w
    n_rows = out_h * out_w * n_channels

    dense = torch.zeros(kernel.size(0), out_h, out_w, n_channels,
                        input_h, input_w)
    for top in range(out_h):
        bottom = top + k_h
        for left in range(out_w):
            # Place the filter with its top-left corner at (top, left).
            dense[:, top, left, :, top:bottom, left:left + k_w] = kernel

    return dense.reshape(n_rows, n_cols)
