import numpy as np
from math import *

import torch,math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

class add_Module(nn.Module):
    """Apply two sub-modules to the same input and add their outputs.

    The original author's note gives the input layout as
    [batch, dim, length]; this wrapper itself imposes no layout, it only
    requires that both sub-modules accept the input and that their outputs
    can be added.
    """

    def __init__(self, Module1, Module2):
        super().__init__()
        # Assigning nn.Module instances registers them as sub-modules.
        self.Module1 = Module1
        self.Module2 = Module2

    def forward(self, input):
        # Module1 is evaluated before Module2, then the results are summed.
        return self.Module1(input) + self.Module2(input)

class Linear_N_H(nn.Module):
    """Fully connected network with ReLU between consecutive layers.

    ``shape`` lists the layer widths ``[inputs, h1, h2, ..., outputs]``; one
    ``nn.Linear`` is created for every adjacent pair of widths. No
    activation is applied after the last layer.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape
        # Pair each width with its successor: (inputs, h1), (h1, h2), ...
        self.layer_list = nn.ModuleList(
            [
                nn.Linear(in_features=n_in, out_features=n_out)
                for n_in, n_out in zip(shape[:-1], shape[1:])
            ]
        )

    def forward(self, input):
        last = len(self.layer_list) - 1
        for idx, layer in enumerate(self.layer_list):
            input = layer(input)
            # Every layer except the final one is followed by ReLU.
            if idx != last:
                input = F.relu(input)
        return input

class Linear_N_H_1d(nn.Module):
    """Fully connected network that flattens its input per sample first.

    ``shape`` lists the layer widths ``[inputs, h1, h2, ..., outputs]``; one
    ``nn.Linear`` is created for every adjacent pair of widths. The forward
    pass collapses all non-batch dimensions into one before the first
    layer, and applies ReLU after every layer but the last.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape
        self.layer_list = nn.ModuleList(
            [
                nn.Linear(in_features=n_in, out_features=n_out)
                for n_in, n_out in zip(shape[:-1], shape[1:])
            ]
        )

    def forward(self, input):
        # Keep dim 0 (batch) and merge everything else into one axis.
        batch = input.shape[0]
        input = input.view(batch, -1)
        last = len(self.layer_list) - 1
        for idx, layer in enumerate(self.layer_list):
            input = layer(input)
            if idx != last:
                input = F.relu(input)
        return input

class LSTM_1_H(nn.Module):
    """One unidirectional LSTM followed by a fully connected head.

    ``shape`` is ``[dim, h1, fc_hidden..., outputs]``: ``dim`` is the LSTM
    input size, ``h1`` its hidden size, and the remaining entries are the
    widths of the linear layers fed by the final hidden state.

    ``forward`` takes input laid out as [batch, dim, length] and swaps the
    last two axes to match the batch-first LSTM.
    """

    def __init__(self, shape):
        super().__init__()
        self.LSTM = nn.LSTM(shape[0], shape[1], batch_first=True, bidirectional=False)
        self.shape = shape
        # Head widths start at the LSTM hidden size: (h1, fc1), (fc1, fc2), ...
        widths = shape[1:]
        self.layer_list = nn.ModuleList(
            [
                nn.Linear(in_features=n_in, out_features=n_out)
                for n_in, n_out in zip(widths[:-1], widths[1:])
            ]
        )

    def forward(self, input):
        sequence = input.transpose(1, 2).contiguous()
        _, (h_n, _) = self.LSTM(sequence)
        # h_n is [layers, batch, hidden]; put batch first, then flatten.
        features = h_n.permute(1, 0, 2)
        features = features.view(features.shape[0], -1)
        last = len(self.layer_list) - 1
        for idx, layer in enumerate(self.layer_list):
            features = layer(features)
            if idx != last:
                features = F.relu(features)
        return features
class GRU_1_H(nn.Module):#input:[batch,dim,length]
    """One unidirectional GRU followed by a fully connected head.

    ``shape`` is ``[dim, h1, fc_hidden..., outputs]``: ``dim`` is the GRU
    input size, ``h1`` its hidden size, and the remaining entries are the
    widths of the linear layers fed by the final hidden state.

    ``forward`` takes input laid out as [batch, dim, length] and returns a
    tensor of shape [batch, outputs].
    """
    def __init__(self, shape):#shape:[dim,h1,fc_hidden,outputs]
        super(GRU_1_H, self).__init__()
        # The attribute keeps its historical name ``LSTM`` so parameter names
        # (and therefore state_dict keys) are unchanged.
        self.LSTM = nn.GRU(shape[0], shape[1], batch_first=True, bidirectional=False)
        self.shape = shape
        self.layer_list = nn.ModuleList([])
        h0 = shape[1]
        for width in shape[2:]:
            self.layer_list.append(nn.Linear(in_features=h0, out_features=width))
            h0 = width

    def forward(self, input):
        # [batch, dim, length] -> [batch, length, dim] for the batch-first GRU.
        input = input.transpose(1, 2).contiguous()
        # nn.GRU returns (output, h_n); unlike nn.LSTM there is no cell state,
        # so unpacking an (h_n, c_n) pair here would raise ValueError.
        _, h_n = self.LSTM(input)
        # h_n is [1, batch, hidden]; put batch first, then flatten.
        input = h_n.permute(1, 0, 2)
        input = input.reshape(input.shape[0], -1)
        last = len(self.layer_list) - 1
        for li, layer in enumerate(self.layer_list):
            input = layer(input)
            # ReLU after every layer except the output layer.
            if li < last:
                input = F.relu(input)
        return input
# class LSTM_5_s_H(nn.Module):#input:[batch,dim,length]
#     def __init__(self, shape):#shape:[dim,h1,fc_hidden,outputs]
#         super(LSTM_5_s_H, self).__init__()
#         self.LSTM = nn.LSTM(shape[0], shape[1], batch_first=True, bidirectional=False)
#         self.shape = shape
#         self.layer_list = nn.ModuleList([])
#         h0 = shape[1]
#         shapefc=shape[2:]
#         for i in range(len(shapefc)):
#             fc = nn.Linear(in_features=h0, out_features=shapefc[i])
#             h0 = shapefc[i]
#             self.layer_list.append(fc)
#     def forward(self, input):
#         input = input.transpose(1, 2).contiguous()
#         hidden, (h_n, c_n) = self.LSTM(input)
#         input=h_n.permute(1, 0, 2)
#
#         input = input.view(input.shape[0], -1)
#         for li in range(len(self.layer_list)):
#             input=self.layer_list[li](input)
#             if(li<len(self.layer_list)-1):
#                 input=F.relu(input)
#         return input
class LSTM_2_H(nn.Module):
    """Two chained unidirectional LSTMs followed by a fully connected head.

    ``shape`` is ``[dim, h1, h2, fc_hidden..., outputs]``: the first LSTM
    maps ``dim`` to ``h1``, the second maps ``h1`` to ``h2``, and the
    remaining entries are the widths of the linear layers fed by the second
    LSTM's final hidden state.

    ``forward`` takes input laid out as [batch, dim, length].
    """

    def __init__(self, shape):
        super().__init__()
        self.LSTM1 = nn.LSTM(shape[0], shape[1], batch_first=True, bidirectional=False)
        self.LSTM2 = nn.LSTM(shape[1], shape[2], batch_first=True, bidirectional=False)
        self.shape = shape
        # Head widths start at the second LSTM's hidden size.
        widths = shape[2:]
        self.layer_list = nn.ModuleList(
            [
                nn.Linear(in_features=n_in, out_features=n_out)
                for n_in, n_out in zip(widths[:-1], widths[1:])
            ]
        )

    def forward(self, input):
        sequence = input.transpose(1, 2).contiguous()
        # The full output sequence of LSTM1 is the input sequence of LSTM2.
        first_out, _ = self.LSTM1(sequence)
        _, (h_n, _) = self.LSTM2(first_out)
        features = h_n.permute(1, 0, 2)
        features = features.view(features.shape[0], -1)
        last = len(self.layer_list) - 1
        for idx, layer in enumerate(self.layer_list):
            features = layer(features)
            if idx != last:
                features = F.relu(features)
        return features
class LSTM_N_S_H(nn.Module):
    """Stacked LSTM whose top-layer final hidden state feeds an FC head.

    ``shape`` is ``[N, dim, h1, fc_hidden..., outputs]``: ``N`` is the
    number of stacked LSTM layers (``num_layers``), ``dim`` the input size,
    ``h1`` the hidden size, and the remaining entries are the widths of the
    linear layers. ``self.shape`` stores the list without the leading ``N``.

    ``forward`` takes input laid out as [batch, dim, length].
    """

    def __init__(self, shape):
        super().__init__()
        num_layers, shape = shape[0], shape[1:]
        # Left empty; retained so the attribute still exists on the module.
        self.LSTM_list = nn.ModuleList([])
        self.LSTM = nn.LSTM(
            shape[0],
            shape[1],
            batch_first=True,
            bidirectional=False,
            num_layers=num_layers,
        )
        self.shape = shape
        widths = shape[1:]
        self.layer_list = nn.ModuleList(
            [
                nn.Linear(in_features=n_in, out_features=n_out)
                for n_in, n_out in zip(widths[:-1], widths[1:])
            ]
        )

    def forward(self, input):
        sequence = input.transpose(1, 2).contiguous()
        _, (h_n, _) = self.LSTM(sequence)
        # Only the last (top) layer's final hidden state is used.
        top = h_n[-1].unsqueeze(0).permute(1, 0, 2)
        features = top.view(top.shape[0], -1)
        last = len(self.layer_list) - 1
        for idx, layer in enumerate(self.layer_list):
            features = layer(features)
            if idx != last:
                features = F.relu(features)
        return features
class LSTM_N_S_H_2(nn.Module):
    """Stacked LSTM whose final hidden states of all layers feed an FC head.

    ``shape`` is ``[N, dim, h1, fc_hidden..., outputs]``: ``N`` is the
    number of stacked LSTM layers (``num_layers``), ``dim`` the input size,
    ``h1`` the hidden size, and the remaining entries are the widths of the
    linear layers. The first linear layer takes ``h1 * N`` features because
    every layer's final hidden state is concatenated. ``self.shape`` stores
    the list without the leading ``N``.

    ``forward`` takes input laid out as [batch, dim, length].
    """

    def __init__(self, shape):
        super().__init__()
        num_layers, shape = shape[0], shape[1:]
        # Left empty; retained so the attribute still exists on the module.
        self.LSTM_list = nn.ModuleList([])
        self.LSTM = nn.LSTM(
            shape[0],
            shape[1],
            batch_first=True,
            bidirectional=False,
            num_layers=num_layers,
        )
        self.shape = shape
        # Head input is the concatenation of all layers' hidden states.
        widths = [shape[1] * num_layers] + list(shape[2:])
        self.layer_list = nn.ModuleList(
            [
                nn.Linear(in_features=n_in, out_features=n_out)
                for n_in, n_out in zip(widths[:-1], widths[1:])
            ]
        )

    def forward(self, input):
        sequence = input.transpose(1, 2).contiguous()
        _, (h_n, _) = self.LSTM(sequence)
        # [layers, batch, hidden] -> [batch, layers * hidden]
        stacked = h_n.permute(1, 0, 2)
        features = stacked.reshape(stacked.shape[0], -1)
        last = len(self.layer_list) - 1
        for idx, layer in enumerate(self.layer_list):
            features = layer(features)
            if idx != last:
                features = F.relu(features)
        return features
# class LSTM_N_S_H(nn.Module):#input:[batch,dim,length]
#     def __init__(self, shape):#shape:[N-same,dim,h1,fc_hidden,outputs]
#         super(LSTM_N_S_H, self).__init__()
#         N_LSTM=shape[0]
#         shape = shape[1:]
#         self.LSTM_list = nn.ModuleList([])
#         self.LSTM = nn.LSTM(shape[0], shape[1], batch_first=True, bidirectional=False)
#         self.LSTM_list.append(self.LSTM)
#         for ii in range(N_LSTM-1):
#             LSTM_same=nn.LSTM(shape[1], shape[1], batch_first=True, bidirectional=False)
#             self.LSTM_list.append(LSTM_same)
#         self.shape = shape
#         self.layer_list = nn.ModuleList([])
#         h0 = shape[1]
#         shapefc = shape[2:]
#         for i in range(len(shapefc)):
#             fc = nn.Linear(in_features=h0, out_features=shapefc[i])
#             h0 = shapefc[i]
#             self.layer_list.append(fc)
#
#     def forward(self, input):
#         input = input.transpose(1, 2).contiguous()
#         for LSTMi in range(len(self.LSTM_list)):
#             input, (h_n, c_n) = self.LSTM_list[LSTMi](input)
#
#         input = h_n.permute(1, 0, 2)
#
#         input = input.view(input.shape[0], -1)
#         for li in range(len(self.layer_list)):
#             input = self.layer_list[li](input)
#             if (li < len(self.layer_list) - 1):
#                 input = F.relu(input)
#         return input
class CNN_LSTM_1_H(nn.Module):
    """Conv1d + ReLU front end, one LSTM, then a fully connected head.

    ``shape`` is ``[kernel, in_dim, conv_dim, lstm_dim, fc_hidden...,
    outputs]``: the convolution maps ``in_dim`` channels to ``conv_dim``
    with the given kernel size (stride 1), the LSTM maps ``conv_dim`` to
    ``lstm_dim``, and the remaining entries are the widths of the linear
    layers fed by the LSTM's final hidden state.

    ``forward`` takes input laid out as [batch, in_dim, length].
    """

    def __init__(self, shape):
        super().__init__()
        self.mainconv1 = nn.Conv1d(shape[1], shape[2], kernel_size=shape[0], stride=1)
        self.LSTM = nn.LSTM(shape[2], shape[3], batch_first=True, bidirectional=False)
        self.shape = shape
        # Head widths start at the LSTM hidden size.
        widths = shape[3:]
        self.layer_list = nn.ModuleList(
            [
                nn.Linear(in_features=n_in, out_features=n_out)
                for n_in, n_out in zip(widths[:-1], widths[1:])
            ]
        )

    def forward(self, input):
        conv_out = F.relu(self.mainconv1(input))
        # Channels become the per-step features of the batch-first LSTM.
        sequence = conv_out.transpose(1, 2).contiguous()
        _, (h_n, _) = self.LSTM(sequence)
        features = h_n.permute(1, 0, 2)
        features = features.view(features.shape[0], -1)
        last = len(self.layer_list) - 1
        for idx, layer in enumerate(self.layer_list):
            features = layer(features)
            if idx != last:
                features = F.relu(features)
        return features

class CNN_2_H(nn.Module):#input:[batch,dim,length]
    """Two Conv1d + ReLU + AvgPool1d stages followed by a fully connected head.

    ``shape`` is ``[length_in, kernal1, kernal2, in_dim, out_dim1, out_dim2,
    pools1, pools2, fc_hidden..., outputs]``:

    * ``length_in`` -- input sequence length, needed to size the first
      linear layer.
    * ``kernal1`` / ``kernal2`` -- kernel size of stage 1 / stage 2; each
      stage's average pool reuses the kernel size of its convolution.
    * ``in_dim`` / ``out_dim1`` / ``out_dim2`` -- channel counts.
    * ``pools1`` / ``pools2`` -- stride of the stage 1 / stage 2 average
      pool (the convolutions always use stride 1).
    * remaining entries -- widths of the linear layers.

    ``forward`` takes input laid out as [batch, in_dim, length_in].
    """
    def __init__(self, shape):#hidden:[length_in,kernal1,kernal2,in_dim,out_dim1,out_dim2,pools1,pools2,fc_hidden,outputs]
        super(CNN_2_H, self).__init__()
        self.shape = shape
        length_in = shape[0]
        kernal = shape[1:3]
        dim = shape[3:6]
        stride = shape[6:8]
        shapefc = shape[8:]
        self.layer_list = nn.ModuleList([])

        def window_out_len(length, kernel, step):
            # Output length of an un-padded, un-dilated 1-D conv/pool window.
            return math.floor((length - kernel) / step + 1)

        length = window_out_len(length_in, kernal[0], 1)
        self.mainconv1 = nn.Conv1d(dim[0], dim[1], kernal[0], stride=1)
        length = window_out_len(length, kernal[0], stride[0])
        self.mainavgpool1 = nn.AvgPool1d(kernal[0], stride=stride[0])
        # The second convolution uses kernal[1]; sizing it with kernal[0]
        # gave the first linear layer the wrong in_features whenever the two
        # kernel sizes differed.
        length = window_out_len(length, kernal[1], 1)
        self.mainconv2 = nn.Conv1d(dim[1], dim[2], kernal[1], stride=1)
        length = window_out_len(length, kernal[1], stride[1])
        self.mainavgpool2 = nn.AvgPool1d(kernal[1], stride=stride[1])

        # Flattened feature count after the second pooling stage.
        h0 = dim[2] * length
        for width in shapefc:
            self.layer_list.append(nn.Linear(in_features=h0, out_features=width))
            h0 = width

    def forward(self, input):
        input = self.mainconv1(input)
        input = F.relu(input)
        input = self.mainavgpool1(input)
        input = self.mainconv2(input)
        input = F.relu(input)
        input = self.mainavgpool2(input)
        # Flatten [batch, channels, length] to [batch, channels * length].
        input = input.view(input.shape[0], -1)
        last = len(self.layer_list) - 1
        for li, layer in enumerate(self.layer_list):
            input = layer(input)
            # ReLU after every layer except the output layer.
            if li < last:
                input = F.relu(input)
        return input








class CNN_1_H(nn.Module):#input:[batch,dim,length]
    """One Conv1d + ReLU stage followed by a fully connected head.

    ``shape`` is ``[kernal, length_in, in_dim, out_dim, fc_hidden...,
    outputs]``:

    * ``kernal`` -- convolution kernel size (stride 1, no padding).
    * ``length_in`` -- input sequence length, needed to size the first
      linear layer.
    * ``in_dim`` / ``out_dim`` -- convolution channel counts.
    * remaining entries -- widths of the linear layers.

    ``self.shape`` keeps the full list as passed in. ``forward`` takes input
    laid out as [batch, in_dim, length_in].
    """
    def __init__(self, shape):#hidden:[kernal,length_in,in_dim,out_dim,fc_hidden,outputs]
        super(CNN_1_H, self).__init__()
        self.shape = shape
        kernal = shape[0]
        length_in = shape[1]
        in_dim, out_dim = shape[2], shape[3]
        shapefc = shape[4:]
        self.layer_list = nn.ModuleList([])
        pad_size, dilation, stride = 0, 1, 1
        # Standard Conv1d output-length formula.
        length = math.floor((length_in + 2 * pad_size - dilation * (kernal - 1) - 1) / stride + 1)
        self.mainconv1 = nn.Conv1d(in_dim, out_dim, kernal, stride=1)
        #self.adaavgpool = nn.AdaptiveAvgPool1d(1)
        # Flattened feature count after the convolution.
        h0 = out_dim * length
        # Widths come from the FC section of ``shape``; indexing the
        # channel section instead sized the head from in_dim/out_dim.
        for width in shapefc:
            self.layer_list.append(nn.Linear(in_features=h0, out_features=width))
            h0 = width

    def forward(self, input):
        input = self.mainconv1(input)
        input = F.relu(input)
        #input = self.adaavgpool(input)
        # Flatten [batch, channels, length] to [batch, channels * length].
        input = input.view(input.shape[0], -1)
        last = len(self.layer_list) - 1
        for li, layer in enumerate(self.layer_list):
            input = layer(input)
            # ReLU after every layer except the output layer.
            if li < last:
                input = F.relu(input)
        return input


def initial_models(models_shapeA,models_shapeB,test_train_size,model_num,device,models_typeA='Linear_N_H',models_typeB='Linear_N_H'):
    """Build ``model_num`` A/B model pairs and one random index draw per pair.

    ``models_shapeA`` / ``models_shapeB`` are dash-separated integer strings
    (e.g. ``"4-8-2"``) that are parsed into shape lists and handed to
    ``model_and_shape`` together with ``models_typeA`` / ``models_typeB``.
    Every model is moved to ``device``.

    For each pair, ``test_train_size[1]`` values are drawn without
    replacement via ``np.random.choice(test_train_size[0], ...)``.

    Returns ``(modelA_list, modelB_list, train_choicelist)``: two
    ``nn.ModuleList`` objects and a list of index arrays, all of length
    ``model_num``.
    """
    shape_a = [int(token) for token in models_shapeA.split("-")]
    shape_b = [int(token) for token in models_shapeB.split("-")]

    modelA_list = nn.ModuleList([])
    modelB_list = nn.ModuleList([])
    train_choicelist = []

    population, sample_size = test_train_size[0], test_train_size[1]
    for _ in range(model_num):
        # Order matters for RNG reproducibility: model A, model B, then draw.
        modelA_list.append(model_and_shape(shape_a, models_typeA).to(device))
        modelB_list.append(model_and_shape(shape_b, models_typeB).to(device))
        train_choicelist.append(
            np.random.choice(population, size=sample_size, replace=False)
        )
    return modelA_list, modelB_list, train_choicelist


def model_and_shape(models_shape,models_type):
    """Instantiate the model class named by ``models_type``.

    ``models_shape`` is passed unchanged to the selected class constructor.
    If ``models_type`` matches none of the known names, the text
    'model type error' is printed and ``None`` is returned.
    """
    # Checked in order with ``==``, mirroring a plain if/elif chain.
    known_models = (
        ('LSTM_1_H', LSTM_1_H),
        ('LSTM_2_H', LSTM_2_H),
        ('LSTM_N_S_H', LSTM_N_S_H),
        ('LSTM_N_S_H_2', LSTM_N_S_H_2),
        ('CNN_LSTM_1_H', CNN_LSTM_1_H),
        ('CNN_1_H', CNN_1_H),
        ('CNN_2_H', CNN_2_H),
        ('Linear_N_H', Linear_N_H),
        ('Linear_N_H_1d', Linear_N_H_1d),
    )
    for type_name, model_cls in known_models:
        if models_type == type_name:
            return model_cls(models_shape)

    print('model type error')
    return None






