import math
import torch
import torch.nn as nn

from torch.nn.parameter import Parameter
from torch.nn import init
from autograd_funcs import UnBlind as unblind_func
from autograd_funcs import Blind as blind_func
from autograd_funcs import Conv2d_sp as conv2d_func
from autograd_funcs import Linear_sp as linear_func

class Blind(torch.nn.Module):
    '''
        Module wrapper that applies the ``blind_func`` autograd function.

        The blind matrix and the inverse-gradient matrix are stored as plain
        attributes (they are not registered as parameters or buffers) and are
        handed to ``blind_func.apply`` together with the input on every call.
        NOTE(review): because they are not buffers, ``.to()`` / ``.cuda()`` and
        ``state_dict()`` do not touch them -- confirm this is intended.
    '''
    def __init__(self, blind_matrix, inverse_gradient_maxtrix):
        # The parameter spelling ("maxtrix") is kept for keyword-argument callers.
        super(Blind, self).__init__()
        self.bm, self.inv_gm = blind_matrix, inverse_gradient_maxtrix
        # Running-statistics slots; forward() in this file does not update them.
        self.mean = self.std = 0.0
        self.count = 1
        self.max = self.min = 0.0

    def forward(self, x):
        matrices = (self.bm, self.inv_gm)
        return blind_func.apply(x, *matrices)

class UnBlind(torch.nn.Module):
    '''
        Module wrapper that applies the ``unblind_func`` autograd function,
        optionally with a learnable bias.
    '''
    def __init__(self, unblind_matrix, gradient_matrix, bias_size, use_bias=True):
        '''
            Save the unblind and gradient matrices and create the bias.

            Args:
                unblind_matrix: stored as ``self.um`` and passed unchanged to
                    ``unblind_func.apply``.
                gradient_matrix: stored as ``self.gm`` and passed unchanged to
                    ``unblind_func.apply``.
                bias_size: shape of the bias parameter (int or sequence of ints).
                use_bias: when False no bias is created and ``None`` is passed
                    to ``unblind_func.apply`` instead.
        '''
        super(UnBlind, self).__init__()
        self.um = unblind_matrix
        self.gm = gradient_matrix
        if use_bias:
            # torch.empty accepts an int or a shape tuple (torch.Tensor(tuple)
            # would treat the tuple as *data*). Its memory is uninitialised,
            # so init_bias() below zeroes it before the layer is ever used.
            self.bias = Parameter(torch.empty(bias_size))
        else:
            # Register the empty slot the same way Linear_sp does.
            self.register_parameter('bias', None)
        self.init_bias()

    def forward(self, input):
        # Return whatever unblind_func produces for (input, um, gm, bias).
        return unblind_func.apply(input, self.um, self.gm, self.bias)

    def init_bias(self):
        '''
            Reset the bias to zero. This is a no-op when the module has no bias.
        '''
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

# subclass of nn.conv2d
class Conv2d_sp(torch.nn.Conv2d):
    '''
        ``nn.Conv2d`` subclass whose forward pass is delegated to ``conv2d_func``.

        The layer is always built with ``groups=1`` and the default
        ``padding_mode``. Other values for those arguments are accepted for
        signature compatibility with ``nn.Conv2d``, but they are ignored and a
        warning is emitted instead of dropping them silently.
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):
        import warnings

        if groups != 1:
            warnings.warn(
                "Conv2d_sp ignores groups={!r}; the layer is built with "
                "groups=1".format(groups), stacklevel=2)
        if padding_mode != 'zeros':
            warnings.warn(
                "Conv2d_sp ignores padding_mode={!r}; the layer is built with "
                "padding_mode='zeros'".format(padding_mode), stacklevel=2)

        super(Conv2d_sp, self).__init__(
            in_channels, out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=1, bias=bias)

    def forward(self, input):
        # Only input, weight and bias are handed to conv2d_func; self.stride,
        # self.padding and self.dilation set in __init__ are not passed here.
        # NOTE(review): confirm conv2d_func applies matching settings.
        y, bias = conv2d_func.apply(input, self.weight, self.bias)

        return (y, bias)

# special linear layer
class Linear_sp(torch.nn.Module):
    '''
        Fully-connected layer whose forward pass is delegated to ``linear_func``.

        Parameters are shaped and initialised like ``torch.nn.Linear``:
        ``weight`` is (out_features, in_features) and ``bias`` is (out_features,).
    '''
    def __init__(self, in_features, out_features, bias=True):
        super(Linear_sp, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        '''
            Re-initialise weight and bias with the ``nn.Linear`` scheme.
        '''
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            # in_features == 0 gives fan_in == 0; guard the division as
            # upstream nn.Linear does (the bias then initialises to zero).
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # Return the (y, bias) pair produced by linear_func.
        y, bias = linear_func.apply(input, self.weight, self.bias)
        return (y, bias)