import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from .utils.Parallel_Conv1d_achieved_by_mask import Parallel_Conv1d_achieved_by_mask


def get_Prime_number_in_a_range(start, end):
    """Return the integers in ``[start, end]`` that have no divisor in ``[2, val - 1]``.

    The result is in ascending order.  Values below 2 (1, 0, negatives) have
    no candidate divisor at all, so they are kept in the result; this is not
    the textbook definition of a prime, but it is the behaviour this function
    has always had and it is preserved here.

    Args:
        start: first integer to test (inclusive).
        end: last integer to test (inclusive).

    Returns:
        A list of the qualifying integers; empty when ``start > end`` or when
        no value in the range qualifies.
    """
    Prime_list = []
    for val in range(start, end + 1):
        has_divisor = False
        divisor = 2
        # A composite number always has a factor no larger than its square
        # root, so testing divisors up to sqrt(val) is sufficient.  Comparing
        # divisor * divisor with val keeps the bound exact (no floats).
        while divisor * divisor <= val:
            if val % divisor == 0:
                has_divisor = True
                break
            divisor += 1
        if not has_divisor:
            Prime_list.append(val)
    return Prime_list


#### below is OS_CNN_structure_builder ####

def get_out_channel_number(paramenter_layer, in_channel, prime_list):
    """Return ``paramenter_layer / (in_channel * sum(prime_list))`` truncated to an int.

    Args:
        paramenter_layer: numeric budget for one layer (presumably a weight
            count; confirm against the caller).
        in_channel: number of input channels of the layer.
        prime_list: values whose sum scales the divisor.

    Returns:
        The quotient truncated toward zero by ``int``.

    Raises:
        ZeroDivisionError: if ``in_channel`` is 0 or ``prime_list`` sums to 0
            (for example when it is empty).
    """
    cost_per_out_channel = in_channel * sum(prime_list)
    ratio = paramenter_layer / cost_per_out_channel
    return int(ratio)

def generate_layer_parameter_list(start,end,paramenter_number_of_layer_list, in_channel = 1):
    """Build the per-layer list of ``(in_channel, out_channel, size)`` tuples.

    For every budget in ``paramenter_number_of_layer_list`` one layer is
    produced that holds one tuple per value returned by
    ``get_Prime_number_in_a_range(start, end)``.  The first budget is
    multiplied by ``in_channel`` before use.  A final layer with exactly two
    tuples, using sizes ``start`` and ``start + 1``, is appended at the end.

    Args:
        start: smallest size passed to ``get_Prime_number_in_a_range``; also
            the size of the first tuple of the final layer.
        end: largest size passed to ``get_Prime_number_in_a_range``.
        paramenter_number_of_layer_list: one numeric budget per generated
            layer.  It is read only; the caller's list is never modified.
        in_channel: number of channels feeding the first layer.

    Returns:
        A list of layers, each layer being a list of
        ``(in_channel, out_channel, size)`` tuples.

    Raises:
        ValueError: if no size is available in ``[start, end]``.
        IndexError: if ``paramenter_number_of_layer_list`` is empty.
    """
    prime_list = get_Prime_number_in_a_range(start, end)
    if not prime_list:
        # Previously this only printed a note and then failed later with a
        # ZeroDivisionError; fail early with an accurate message instead.
        raise ValueError(
            'no kernel size available between start = {} and end = {}'.format(start, end))

    # Work on a copy: scaling the first budget in place would change the
    # caller's list, so a second call with the same list (or with a shared
    # default list) would apply the in_channel factor again.
    layer_budgets = list(paramenter_number_of_layer_list)
    layer_budgets[0] = layer_budgets[0] * in_channel

    input_in_channel = in_channel
    layer_parameter_list = []

    for layer_budget in layer_budgets:
        out_channel = get_out_channel_number(layer_budget, in_channel, prime_list)
        layer_parameter_list.append(
            [(in_channel, out_channel, prime) for prime in prime_list])
        # The next layer receives the outputs of all tuples of this layer.
        in_channel = len(prime_list) * out_channel

    # The final layer reuses the total output width of the first layer.
    first_out_channel = len(prime_list) * get_out_channel_number(
        layer_budgets[0], input_in_channel, prime_list)
    layer_parameter_list.append([
        (in_channel, first_out_channel, start),
        (in_channel, first_out_channel, start + 1),
    ])
    return layer_parameter_list


class OS_CNN_structure_builder():
    """Collect the settings needed to generate an OS-CNN layer specification.

    Attributes set by ``__init__``:
        start_kernel_size: lower bound handed to the generator as ``start``.
        Max_kernel_size: upper cap applied to ``receptive_field_shape``.
        receptive_field_shape: ``min(int(length_of_TS_data / 4), Max_kernel_size)``;
            handed to the generator as ``end``.
        paramenter_number_of_layer_list: private copy of the per-layer budgets.
        in_channel: number of input channels.
    """

    def __init__(self,
                 length_of_TS_data,
                 in_channel = 1,
                 start_kernel_size = 1,
                 Max_kernel_size = 89, 
                 paramenter_number_of_layer_list = [8*128, 5*128*256 + 2*256*128]):
        """Store the configuration.

        Args:
            length_of_TS_data: length of the input series; a quarter of it
                (truncated) bounds the largest size, capped by
                ``Max_kernel_size``.
            in_channel: number of input channels.
            start_kernel_size: smallest size to consider.
            Max_kernel_size: largest size to consider.
            paramenter_number_of_layer_list: one numeric budget per layer.
        """
        super(OS_CNN_structure_builder,self).__init__()

        self.start_kernel_size = start_kernel_size
        self.Max_kernel_size = Max_kernel_size
        self.receptive_field_shape = min(int(length_of_TS_data/4), self.Max_kernel_size)
        # Keep a private copy.  The default above is a single list object
        # shared by every call, so storing it by reference would let a change
        # made through one instance leak into all later instances.
        self.paramenter_number_of_layer_list = list(paramenter_number_of_layer_list)
        self.in_channel = in_channel

    def get_OS_CNN_structure(self):
        """Return the layer specification built by ``generate_layer_parameter_list``.

        A fresh copy of the budgets is passed on every call so that repeated
        calls return the same result even if the callee modifies its argument.
        """
        return generate_layer_parameter_list(start = self.start_kernel_size,
                                             end = self.receptive_field_shape,
                                             paramenter_number_of_layer_list = list(self.paramenter_number_of_layer_list),
                                             in_channel = self.in_channel)



class OS_CNN(nn.Module):
    """Stack of masked parallel 1-D convolutions followed by pooling and a linear head.

    Each entry of ``layer_parameter_list`` yields one
    ``Parallel_Conv1d_achieved_by_mask`` -> ``BatchNorm1d`` -> ``ReLU`` stage.
    The output is average-pooled to length 1 and, unless feature output is
    requested, mapped to ``n_class`` values by a linear layer.
    """

    def __init__(self,layer_parameter_list, n_class, out_put_feature = False):
        """Build the network.

        Args:
            layer_parameter_list: one entry per stage; each entry is a
                sequence of tuples whose element ``[1]`` is summed to get the
                stage's channel count.
            n_class: output size of the final linear layer.
            out_put_feature: when true, ``forward`` returns the pooled
                features instead of the linear layer's output.
        """
        super(OS_CNN, self).__init__()
        self.out_put_feature = out_put_feature
        self.layer_parameter_list = layer_parameter_list

        stages = []
        for kernel_specs in layer_parameter_list:
            # BatchNorm must match the summed width of all tuples in the stage.
            stage_width = sum(spec[1] for spec in kernel_specs)
            stages.extend([
                Parallel_Conv1d_achieved_by_mask(kernel_specs),
                nn.BatchNorm1d(num_features=stage_width),
                nn.ReLU(),
            ])
        self.os_layer = nn.Sequential(*stages)
        self.averagepool = nn.AdaptiveAvgPool1d(1)

        # The linear head consumes the summed width of the last stage.
        feature_width = sum(spec[1] for spec in layer_parameter_list[-1])
        self.hidden = nn.Linear(feature_width, n_class)

    def forward(self, X):
        """Run the stages, pool, and optionally apply the linear head.

        Args:
            X: input tensor (assumed to be ``(batch, channel, length)`` -
                TODO confirm against ``Parallel_Conv1d_achieved_by_mask``).

        Returns:
            The pooled features with the last dimension squeezed away when
            ``out_put_feature`` is true, otherwise the linear layer's output.
        """
        pooled = self.averagepool(self.os_layer(X))
        features = pooled.squeeze_(-1)
        if self.out_put_feature:
            return features
        return self.hidden(features)

        
    
    
