import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.nn.parameter import Parameter

def calculate_mask_index(kernel_length_now, largest_kernel_length):
    """Return the ``(start, stop)`` slice a short kernel occupies inside the largest one.

    The kernel of length ``kernel_length_now`` is placed so that its centre
    tap -- the tap with ``ceil((length - 1) / 2)`` taps to its right -- lines
    up with the centre tap of a kernel of length ``largest_kernel_length``.
    The slice is half-open, so ``stop - start == kernel_length_now``.
    """
    taps_right_of_centre_largest = math.ceil((largest_kernel_length - 1) / 2)
    taps_right_of_centre_now = math.ceil((kernel_length_now - 1) / 2)
    # Zeros that must follow the short kernel so both centre taps coincide;
    # whatever length is left over becomes the leading zeros.
    trailing_zeros = taps_right_of_centre_largest - taps_right_of_centre_now
    start = (largest_kernel_length - kernel_length_now) - trailing_zeros
    stop = start + kernel_length_now
    return start, stop

def creat_mask(number_of_input_channel, number_of_output_channel, kernel_length_now, largest_kernel_length):
    """Build a float64 0/1 mask that keeps only the centred window of a shorter kernel.

    The result has shape
    ``(number_of_input_channel, number_of_output_channel, largest_kernel_length)``.
    It is 1 on the slice reported by ``calculate_mask_index`` along the last
    axis and 0 everywhere else.

    NOTE: despite the parameter names, ``creak_layer_mask`` passes
    ``(out_channels, in_channels, ...)`` here. That matches the
    ``(out, in, kernel)`` layout of the weight arrays it builds.
    """
    keep_from, keep_to = calculate_mask_index(kernel_length_now, largest_kernel_length)
    window_mask = np.zeros((number_of_input_channel, number_of_output_channel, largest_kernel_length))
    window_mask[..., keep_from:keep_to] = 1
    return window_mask

def use_dilation_to_insert_zeros_to_weight(weight, dilation):
    """Spread the taps of a kernel apart by inserting ``dilation - 1`` zeros between them.

    Only the last axis is dilated. A kernel of length ``K`` becomes one of
    length ``(K - 1) * dilation + 1``, with the original taps at positions
    ``0, dilation, 2 * dilation, ...``. Leading axes are kept unchanged, so
    the function accepts an ``(out, in, K)`` convolution weight as well as an
    array of any other rank.

    Args:
        weight: numpy array whose last axis holds the kernel taps.
        dilation: integer spacing between consecutive taps; must be >= 1.

    Returns:
        ``weight`` itself (not a copy) when ``dilation == 1``. Otherwise a new
        float64 array holding the dilated kernel.

    Raises:
        ValueError: if ``dilation`` is smaller than 1.
    """
    if dilation == 1:
        return weight
    if dilation < 1:
        raise ValueError(f"dilation must be >= 1, got {dilation}")

    kernel_length = weight.shape[-1]
    dilated_length = (kernel_length - 1) * dilation + 1
    dilated = np.zeros(tuple(weight.shape[:-1]) + (dilated_length,))
    # A stride slice selects exactly the tap positions 0, d, 2d, ..., (K-1)*d.
    dilated[..., ::dilation] = weight
    return dilated

def has_dilation(layer_parameter):
    """Tell whether a layer spec carries an explicit, non-trivial dilation.

    A spec is ``[in, out, kernel_size]`` or ``[in, out, kernel_size, dilation]``.
    Only the four-element form whose dilation differs from 1 counts as dilated.
    """
    return bool(len(layer_parameter) == 4 and layer_parameter[3] != 1)
    
    
def _layer_dilation(layer_parameter):
    """Return the dilation of one ``[in, out, kernel_size(, dilation)]`` spec, defaulting to 1."""
    dilation = layer_parameter[3] if has_dilation(layer_parameter) else 1
    if dilation < 1:
        raise ValueError(f"dilation must be >= 1, got {dilation} in {layer_parameter}")
    return dilation


def creak_layer_mask(layer_parameter_list):
    """Build the fused mask, initial weight and initial bias for one masked Conv1d layer.

    Each entry of ``layer_parameter_list`` describes one sub-convolution:

        [[in, out, kernel_size, dilation],
         [in, out, kernel_size, dilation],
         [in, out, kernel_size]]          # dilation is optional

    Each sub-kernel gets PyTorch's default initialisation and is dilated if
    requested. It is then embedded, centred via ``calculate_mask_index``, into
    a kernel of the largest effective length in the list. The results are
    stacked along the output-channel axis.

    Returns:
        ``(mask, init_weight, init_bias)`` as float32 numpy arrays. ``mask``
        and ``init_weight`` have shape
        ``(sum of out, in, largest_kernel_length)``; ``init_bias`` has shape
        ``(sum of out,)``.

    Raises:
        ValueError: if the list is empty, a dilation is smaller than 1, or the
            entries do not share the same ``in`` channel count.
    """
    if len(layer_parameter_list) == 0:
        raise ValueError("layer_parameter_list must contain at least one layer")
    in_channel_set = {layer_parameter[0] for layer_parameter in layer_parameter_list}
    if len(in_channel_set) != 1:
        # The sub-kernels are concatenated along axis 0, which requires equal in_channels.
        raise ValueError(f"all layers must share the same in_channels, got {sorted(in_channel_set)}")

    dilations = [_layer_dilation(layer_parameter) for layer_parameter in layer_parameter_list]
    # Effective (dilated) kernel length; equals kernel_size when dilation is 1.
    kernel_lengths = [
        1 + (layer_parameter[2] - 1) * dilation
        for layer_parameter, dilation in zip(layer_parameter_list, dilations)
    ]
    largest_kernel_length = max(kernel_lengths)

    mask_list = []
    init_weight_list = []
    bias_list = []
    for layer_parameter, dilation, kernel_length in zip(layer_parameter_list, dilations, kernel_lengths):
        in_channels, out_channels, kernel_size = layer_parameter[0], layer_parameter[1], layer_parameter[2]
        # A throw-away Conv1d supplies the default initial weight/bias for this sub-kernel.
        conv = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size)
        weight = conv.weight.detach().numpy()

        ind_l, ind_r = calculate_mask_index(kernel_length, largest_kernel_length)
        big_weight = np.zeros((out_channels, in_channels, largest_kernel_length))
        big_weight[:, :, ind_l:ind_r] = use_dilation_to_insert_zeros_to_weight(weight, dilation)

        mask = creat_mask(out_channels, in_channels, kernel_length, largest_kernel_length)
        # Inside the kernel's window keep only the real taps. Dilating a block of
        # ones marks exactly those positions, so every inserted hole is masked
        # for any dilation (not only dilation 2). Otherwise training could fill
        # the holes in.
        mask[:, :, ind_l:ind_r] = use_dilation_to_insert_zeros_to_weight(np.ones_like(weight), dilation)

        bias_list.append(conv.bias.detach().numpy())
        init_weight_list.append(big_weight)
        mask_list.append(mask)

    mask = np.concatenate(mask_list, axis=0)
    init_weight = np.concatenate(init_weight_list, axis=0)
    init_bias = np.concatenate(bias_list, axis=0)
    return mask.astype(np.float32), init_weight.astype(np.float32), init_bias.astype(np.float32)


class Conv1d_Layer_with_Mask(nn.Module):
    """Several 1-D convolutions fused into a single masked ``Conv1d``.

    ``layer_parameters`` is forwarded to ``creak_layer_mask``, which supplies
    the fixed mask, the initial weight and the initial bias. At the start of
    every forward pass the mask is re-applied to the convolution weight, so
    masked-out taps are reset to zero even after optimizer updates.

    Args:
        layer_parameters: list of ``[in, out, kernel_size(, dilation)]`` specs.
        use_bias: add a learnable bias to the fused convolution.
        use_batch_Norm: apply ``BatchNorm1d`` after the convolution.
        use_relu: apply ReLU as the final step.
    """

    def __init__(self, layer_parameters, use_bias=True, use_batch_Norm=True, use_relu=True):
        super().__init__()

        self.use_bias = use_bias
        self.use_batch_Norm = use_batch_Norm
        self.use_relu = use_relu

        os_mask, init_weight, init_bias = creak_layer_mask(layer_parameters)

        # Mask/weight layout is (out_channels, in_channels, kernel_length).
        out_channels = os_mask.shape[0]
        in_channels = os_mask.shape[1]
        max_kernel_size = os_mask.shape[-1]

        # Fixed mask: registered as a frozen Parameter, as before, so state_dict keys are unchanged.
        self.weight_mask = nn.Parameter(torch.from_numpy(os_mask), requires_grad=False)

        # Asymmetric zero padding: (K-1)//2 + K//2 == K-1, so the output length
        # equals the input length for odd and even K alike.
        self.padding = nn.ConstantPad1d(((max_kernel_size - 1) // 2, max_kernel_size // 2), 0)

        self.conv1d = torch.nn.Conv1d(in_channels=in_channels,
                                      out_channels=out_channels,
                                      kernel_size=max_kernel_size,
                                      bias=self.use_bias)

        self.conv1d.weight = nn.Parameter(torch.from_numpy(init_weight), requires_grad=True)
        if self.use_bias:
            # Attach the initial bias only when requested. Assigning it
            # unconditionally would re-enable a bias for use_bias=False.
            self.conv1d.bias = nn.Parameter(torch.from_numpy(init_bias), requires_grad=True)
        # Created even when use_batch_Norm is False (it is then unused).
        self.bn = nn.BatchNorm1d(num_features=out_channels)

    def forward(self, X):
        """Re-mask the weight, then pad -> conv -> optional BatchNorm -> optional ReLU."""
        # Zero the masked taps again in place. Going through `.data` keeps this
        # update out of autograd (no throw-away graph node per call).
        self.conv1d.weight.data.mul_(self.weight_mask)

        result = self.conv1d(self.padding(X))
        if self.use_batch_Norm:
            result = self.bn(result)
        if self.use_relu:
            result = F.relu(result)
        return result