import torch.nn.functional as F
import torch.nn as nn
import torch
from torch.distributions import Normal, Independent, kl
from torch.autograd import Variable
import numpy as np

import math




class ReflectedConvolution(nn.Module):
    """
    Learnable zero-mean filter bank applied to pairwise log-RGB differences.

    A single bank of ``kernel_nums`` single-channel filters is shared by the
    three channel pairs (red-green, green-blue, red-blue). Subclasses expose
    the computation as a layer by calling ``ReflectedMap`` from ``forward``.

    Args:
        conv_type: Stored on the instance as ``self.conv_type``; it is not
            read anywhere else in this class.
        kernel_nums: Number of filters in the shared bank.
        kernel_size: Spatial size of each (square) filter. Convolutions pad
            by ``kernel_size // 2``, so the input's spatial size is only
            preserved for odd values.
    """

    def __init__(self, conv_type='softmax', kernel_nums = 8, kernel_size = 3):
        super(ReflectedConvolution, self).__init__()

        self.conv_type = conv_type
        self.kernel_nums = kernel_nums
        self.kernel_size = kernel_size

        # One bank of single-input-channel kernels, reused for every colour pair.
        self.filter = torch.nn.Parameter(
            torch.randn(self.kernel_nums, 1, self.kernel_size, self.kernel_size)
        )
        torch.nn.init.kaiming_normal_(self.filter)

        # ReflectedMap() normalises each pairwise response through these
        # layers; they were previously referenced without ever being created,
        # so every call raised AttributeError.
        # NOTE(review): BatchNorm2d is inferred from the `_bn` suffix and the
        # (N, kernel_nums, H, W) tensors passed in - confirm the intended type.
        self.rg_bn = nn.BatchNorm2d(self.kernel_nums)
        self.gb_bn = nn.BatchNorm2d(self.kernel_nums)
        self.rb_bn = nn.BatchNorm2d(self.kernel_nums)

    def mean_constraint(self, kernel):
        """
        Return ``kernel`` with the mean of each leading-dimension slice removed.

        Args:
            kernel: 4-D tensor; the mean is taken over all elements that share
                the same index along dimension 0.

        Returns:
            Tensor with the same shape as ``kernel`` whose slices along
            dimension 0 each have zero mean.
        """
        flat = kernel.view(kernel.shape[0], -1)
        centered = flat - torch.mean(flat, dim=1, keepdim=True)
        return centered.view(kernel.shape)

    def ReflectedMap(self, img):
        """
        Compute normalised, filtered log-channel differences of ``img``.

        Args:
            img: 4-D tensor ``(N, C, H, W)`` whose channels 0, 1 and 2 are
                used as red, green and blue. Values are passed through
                ``log(img + 1e-7)`` - presumably non-negative; confirm
                against the caller.

        Returns:
            Tensor ``(N, 3 * kernel_nums, H, W)`` (for odd ``kernel_size``)
            holding the red-green, green-blue and red-blue responses
            concatenated along the channel dimension, in that order.
        """
        # Positions whose raw value is exactly zero are blanked in the output.
        zero_mask = img == 0

        log_img = torch.log(img + 1e-7)

        kernel = self.mean_constraint(self.filter)
        pad = self.kernel_size // 2

        # Convolution is linear in its weights, so conv(x, -w) == -conv(x, w):
        # one response per channel is enough to form every pairwise
        # difference (the original ran six convolutions for the same values).
        resp_r = F.conv2d(log_img[:, 0:1], weight=kernel, padding=pad)
        resp_g = F.conv2d(log_img[:, 1:2], weight=kernel, padding=pad)
        resp_b = F.conv2d(log_img[:, 2:3], weight=kernel, padding=pad)

        # Red-Green
        rg = self.rg_bn(resp_r - resp_g)
        # Green-Blue
        gb = self.gb_bn(resp_g - resp_b)
        # Red-Blue
        rb = self.rb_bn(resp_r - resp_b)

        # Each pair is masked by a single input channel (red, green and blue
        # respectively), as in the original code; the (N, 1, H, W) mask
        # broadcasts over the kernel_nums output channels.
        rg = rg.masked_fill(zero_mask[:, 0:1], 0)
        gb = gb.masked_fill(zero_mask[:, 1:2], 0)
        rb = rb.masked_fill(zero_mask[:, 2:3], 0)

        return torch.cat([rg, gb, rb], dim=1)


        
class gaussCrossRatioCal(ReflectedConvolution):
    """
    Layer that returns ``ReflectedConvolution.ReflectedMap`` of its input,
    built with the base class's default constructor arguments.

    NOTE(review): this class previously derived from ``ratioCals`` and called
    ``self.getGaussCrossRatios``; neither name is defined or imported in this
    file, so importing the module raised NameError. They look like leftovers
    from a rename to ``ReflectedConvolution`` / ``ReflectedMap`` - confirm
    that this is the intended replacement.
    """

    def __init__(self):
        super(gaussCrossRatioCal, self).__init__()

    def forward(self, img):
        """Return the pairwise channel-difference map computed for ``img``."""
        crossRatio = self.ReflectedMap(img)
        return crossRatio
    
    
    
    
class RFConv(ReflectedConvolution):
    """
    ``ReflectedConvolution`` exposed as a callable layer.

    The constructor arguments are forwarded unchanged to the base class, and
    ``forward`` returns the base class's ``ReflectedMap`` of the input.
    """

    def __init__(self, conv_type, kernel_nums=8, kernel_size=3):
        # The base constructor takes the same three arguments in this order.
        super().__init__(conv_type, kernel_nums, kernel_size)

    def forward(self, img):
        """Return ``self.ReflectedMap(img)``."""
        return self.ReflectedMap(img)

