import torch.nn as nn
import torch.nn.functional as F
import torch 
import math

class CNN1D_Classifier(nn.Module):
    '''
    CNN Model with relatively small number of parameters. Uses residual connections.
    Adapted from https://github.com/eddymina/ECG_Classification_Pytorch

    Architecture: an unpadded stem convolution (kernel 5), then four residual
    blocks. Each block applies the *same* ``conv_pad`` layer (shared weights)
    twice, adds the block input back, applies ReLU and is followed by a
    max-pool (kernel 5, stride 2). The pooled features are flattened and fed
    to a small MLP that ends in ``LogSoftmax``.

    ``dense1`` expects exactly ``32 * 8`` flattened features, i.e. a final
    temporal length of 8. For example, an input of length 187 gives
    187 -> 183 (stem) -> 90 -> 43 -> 20 -> 8 (after each block's max-pool).

    Args:
        input_size: number of input channels (``in_channels`` of the stem conv).
        num_classes: number of output classes.
    '''
    def __init__(self, input_size, num_classes):
        super(CNN1D_Classifier, self).__init__()

        # Unpadded stem: shortens the sequence by kernel_size - 1 = 4 samples.
        self.conv = nn.Conv1d(in_channels=input_size, out_channels=32, kernel_size=5, stride=1)

        # Length-preserving conv; reused (same weights) by every residual block.
        self.conv_pad = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
        # NOTE: defined but not applied in forward().
        self.drop_50 = nn.Dropout(p=0.5)

        self.maxpool = nn.MaxPool1d(kernel_size=5, stride=2)

        self.dense1 = nn.Linear(32 * 8, 32)
        self.dense2 = nn.Linear(32, 32)

        self.dense_final = nn.Linear(32, num_classes)
        self.softmax = nn.LogSoftmax(dim=1)

    def _residual_block(self, residual):
        '''conv_pad -> ReLU -> conv_pad, add the block input back, then ReLU.'''
        x = F.relu(self.conv_pad(residual))
        x = self.conv_pad(x)
        return F.relu(x + residual)

    def forward(self, x, tsne_out=False):
        '''
        Args:
            x: tensor of shape (batch, input_size, length), or (batch, length),
               in which case a channel axis is inserted (only valid when
               ``input_size == 1``). The caller's tensor is not modified.
            tsne_out: if True, return the flattened (batch, 32 * 8) features
               that feed the MLP instead of class scores.

        Returns:
            (batch, num_classes) log-probabilities, or the flattened features
            when ``tsne_out`` is True.

        Raises:
            RuntimeError: if the conv stack does not produce ``32 * 8``
                features per sample (wrong input length).
        '''
        if x.dim() == 2:
            # Out-of-place: ``unsqueeze_`` would reshape the caller's tensor.
            x = x.unsqueeze(1)

        out = self.conv(x)
        # Four residual blocks, each followed by a stride-2 max-pool.
        for _ in range(4):
            out = self.maxpool(self._residual_block(out))

        # Keep the batch dimension explicit so a wrong input length fails loudly
        # instead of being folded into a bogus batch size.
        features = out.flatten(1)
        if features.size(1) != self.dense1.in_features:
            raise RuntimeError(
                f"expected {self.dense1.in_features} features after the conv "
                f"stack, got {features.size(1)}; check the input length"
            )
        if tsne_out:
            return features

        x = F.relu(self.dense1(features))
        x = self.dense2(x)
        return self.softmax(self.dense_final(x))
      
class PositionalEncoding(nn.Module):
    '''
    Add fixed sinusoidal positional encodings to a sequence-first tensor,
    then apply dropout.

    Args:
        d_model: feature size of the input (its last dimension). May be odd.
        dropout: dropout probability applied after adding the encodings.
        max_len: longest sequence length the precomputed table supports.
    '''

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # Use the caller-supplied probability rather than a fixed value.
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, 1, d_model): the singleton axis broadcasts over the batch.
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        Args:
            x: (seq_len, batch, d_model) tensor with seq_len <= max_len.

        Returns:
            Tensor of the same shape: ``x`` plus the encodings, after dropout.
        '''
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
        


class CNN_NoRes1D_Classifier(nn.Module):
    '''
    Plain (non-residual) 1D CNN classifier.

    Pipeline:
        conv_in (k=5) -> BN -> ReLU -> Dropout1d
        conv_1  (k=5) -> BN -> ReLU -> Dropout1d
        max-pool (k=5, stride 2, padding 1)
        conv_2  (k=1) -> BN -> Dropout
        global max over time -> Linear

    All convolutions use ``padding='same'``; the max-pool needs an input
    length of at least 3.

    Args:
        in_channel: number of input channels.
        n_classes: number of output classes.
        out_channel: width of every conv layer (defaults to 64).
    '''

    def __init__(self, in_channel, n_classes, out_channel=None) -> None:
        super(CNN_NoRes1D_Classifier, self).__init__()
        self.in_channel = in_channel
        out_channel = out_channel if out_channel is not None else 64
        self.out_channel = out_channel
        self.n_classes = n_classes

        self.conv_in = nn.Conv1d(in_channels=in_channel, out_channels=out_channel,
                                 kernel_size=5, stride=1, padding='same')
        self.bn1 = nn.BatchNorm1d(out_channel)
        self.dp1 = nn.Dropout1d(0.1)

        self.conv_1 = nn.Conv1d(in_channels=out_channel, out_channels=out_channel,
                                kernel_size=5, stride=1, padding='same')
        self.bn2 = nn.BatchNorm1d(out_channel)
        self.dp2 = nn.Dropout1d(0.1)

        # Pointwise (kernel 1) conv applied after pooling.
        self.conv_2 = nn.Conv1d(in_channels=out_channel, out_channels=out_channel,
                                kernel_size=1, stride=1, padding='same')

        self.maxpool = nn.MaxPool1d(kernel_size=5, stride=2, padding=1)

        self.bn3 = nn.BatchNorm1d(out_channel)
        self.dp3 = nn.Dropout(0.1)

        self.classifier = nn.Linear(out_channel, n_classes)

    def forward(self, x):
        '''
        Args:
            x: (batch, in_channel, length) tensor, or (batch, length), in which
               case a channel axis is inserted. The caller's tensor is not
               modified.

        Returns:
            (batch, n_classes) unnormalised scores from ``classifier``.
        '''
        if x.dim() == 2:
            # Out-of-place: ``unsqueeze_`` would reshape the caller's tensor.
            x = x.unsqueeze(1)

        hidden = self.conv_in(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        hidden = self.dp1(hidden)

        hidden = self.conv_1(hidden)
        hidden = self.bn2(hidden)
        hidden = F.relu(hidden)
        hidden = self.dp2(hidden)

        hidden = self.maxpool(hidden)
        hidden = self.conv_2(hidden)
        hidden = self.bn3(hidden)
        # NOTE(review): unlike the first two stages there is no ReLU after bn3;
        # confirm whether that is intentional.
        hidden = self.dp3(hidden)

        # Global max over the time axis -> (batch, out_channel).
        hidden_max_pooled = torch.amax(hidden, dim=2)
        return self.classifier(hidden_max_pooled)

class Transformer1D_Classifier(nn.Module):
    '''
    Convolutional front-end followed by a Transformer encoder.

    Two 'same'-padded Conv1d layers (kernel 5) lift the input to
    ``out_channel`` features per time step. Sinusoidal positional encodings
    are then added, and the sequence is processed by ``n_transformer_layer``
    encoder layers. The encoded sequence is mean-pooled over time and mapped
    to class logits.

    Args:
        in_channel: number of input channels.
        out_channel: model width (``d_model``); must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        n_transformer_layer: number of stacked encoder layers.
        dim_feedforward: hidden size of each encoder layer's feed-forward net.
        n_classes: number of output classes.
    '''

    def __init__(self,in_channel,out_channel,n_heads,n_transformer_layer,dim_feedforward,n_classes) -> None:
        super(Transformer1D_Classifier,self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.n_heads = n_heads
        self.n_transformer_layer = n_transformer_layer
        self.n_classes = n_classes

        self.conv_in = nn.Conv1d(in_channels=in_channel, out_channels=out_channel,
                                 kernel_size=5, stride=1, padding='same')

        self.conv_1 = nn.Conv1d(in_channels=out_channel, out_channels=out_channel,
                                 kernel_size=5, stride=1, padding='same')

        # Table covers sequences of up to 1000 time steps.
        self.position_encoding = PositionalEncoding(out_channel,dropout=0.2,max_len=1000)

        # Default batch_first=False: the encoder consumes (length, batch, d_model).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=out_channel,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=0.2,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_transformer_layer,
        )

        self.classifier = nn.Linear(out_channel, n_classes)

    def forward(self, x):
        '''
        Args:
            x: (batch, in_channel, length) tensor, or (batch, length) when
               ``in_channel == 1``; length must not exceed 1000.

        Returns:
            (batch, n_classes) unnormalised logits.
        '''
        if x.dim() == 2:
            x = x.unsqueeze(1)

        hidden = F.relu(self.conv_in(x))       # (batch, out_channel, length)
        hidden = F.relu(self.conv_1(hidden))

        # Both the positional table (indexed by dim 0) and the encoder
        # (batch_first=False) expect sequence-first input.
        hidden = hidden.permute(2, 0, 1)       # (length, batch, out_channel)
        hidden = self.position_encoding(hidden)
        hidden = self.transformer_encoder(hidden)

        # Mean-pool over time: one vector per sample.
        pooled = hidden.mean(dim=0)            # (batch, out_channel)
        return self.classifier(pooled)




# Registry: key -> [model class, '.pth' path, '.pkl' results path]
# (presumably saved weights and pickled evaluation results; confirm with callers).
MODEL_DICT = {
    'cnn1d': [
        CNN1D_Classifier,
        'out/cnn1d.pth',
        'out/cnn1d_result.pkl',
    ],
    'cnn_nores1d': [
        CNN_NoRes1D_Classifier,
        'out/cnn_nores1d.pth',
        'out/cnn_nores1d_result.pkl',
    ],
}


