# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/108c_models.TSTPlus.ipynb (unless otherwise specified).

# Names exported when this module is star-imported (`from <module> import *`).
__all__ = ['Transpose', 'LinBnDrop', 'SigmoidRange', 'Reshape', 'SimpleLinearHead', 'BiLinearHead', 'sigmoid_range', 'get_activation_fn']
           

import torch
from torch import nn

class Transpose(nn.Module):
    """Layer wrapper around `Tensor.transpose`.

    `dims` are forwarded unchanged to `Tensor.transpose`. When `contiguous`
    is true the transposed result is additionally passed through
    `.contiguous()` before being returned.
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        swapped = x.transpose(*self.dims)
        return swapped.contiguous() if self.contiguous else swapped


class LinBnDrop(nn.Sequential):
    """Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers.

    Args:
        n_in: input feature size of the linear layer.
        n_out: output feature size of the linear layer.
        bn: if true, include a `BatchNorm1d` layer (and build the linear
            layer without a bias term).
        p: dropout probability; no `Dropout` layer is added when it is 0.
        act: optional activation module appended right after the linear layer.
        lin_first: layer ordering.
            False -> [BatchNorm1d, Dropout, Linear, act]
            True  -> [Linear, act, BatchNorm1d, Dropout]
    """
    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):
        # Batch norm sees the linear layer's output when the linear layer comes
        # first, otherwise it sees the raw input features.
        # `nn.BatchNorm1d` is used because `nn.BatchNorm2d` neither matches the
        # input of a `Linear` layer nor accepts an `ndim` keyword.
        layers = [nn.BatchNorm1d(n_out if lin_first else n_in)] if bn else []
        if p != 0: layers.append(nn.Dropout(p))
        lin = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None: lin.append(act)
        layers = lin + layers if lin_first else layers + lin
        super().__init__(*layers)


class SigmoidRange(nn.Sequential):
    """Sigmoid layer whose output is rescaled to the range `(low, high)`.

    Args:
        low: lower end of the output range.
        high: upper end of the output range.

    NOTE: the `nn.Sequential` base (with no child layers) is kept only so
    existing `isinstance` checks keep working; the computation lives in the
    overridden `forward`.
    """
    def __init__(self, low, high):
        super().__init__()
        self.low, self.high = low, high

    def forward(self, x):
        # This method used to be misspelled `foward`, so the empty
        # `nn.Sequential.forward` ran instead and returned `x` unchanged.
        return torch.sigmoid(x) * (self.high - self.low) + self.low

    # Backward-compatible alias for the historical misspelling, in case any
    # caller invoked it explicitly.
    foward = forward

    
class Reshape(nn.Module):
    """Reshape the input to `(-1, c_out, target_len)`.

    Args:
        c_out: size of the second dimension of the output.
        target_len: size of the last dimension of the output.

    The leading dimension is inferred from the number of elements in the input.
    """
    def __init__(self, c_out, target_len):
        super().__init__()
        self.c_out = c_out
        self.target_len = target_len

    def forward(self, x):
        # `reshape` returns a view whenever `view` would have succeeded and
        # only copies for memory layouts `view` rejects (e.g. a tensor that
        # was transposed/permuted upstream), so non-contiguous inputs no
        # longer raise a RuntimeError.
        return x.reshape(-1, self.c_out, self.target_len)
    
    
class SimpleLinearHead(nn.Module):
    """Head that applies one linear layer over the last axis, flattens, then applies dropout.

    Args:
        seq_in: size of the last input dimension (input features of the linear layer).
        seq_out: size the last dimension is projected to.
        head_dropout: dropout probability applied to the flattened output.
    """

    def __init__(self, seq_in, seq_out, head_dropout=0):
        super().__init__()
        self.head = nn.Linear(seq_in, seq_out)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        # Expected input layout: [bs x d_model x seq_in]
        projected = self.head(x)                    # [bs x d_model x seq_out]
        flattened = projected.flatten(start_dim=1)  # [bs x d_model * seq_out]
        return self.dropout(flattened)
    
    
class BiLinearHead(nn.Module):
    """Head with two linear layers: one over the channel axis, one over the sequence axis.

    Args:
        c_in: input size of the channel-wise linear layer.
        seq_in: input size of the sequence-wise linear layer.
        c_out: output size of the channel-wise linear layer.
        seq_out: output size of the sequence-wise linear layer.
        head_dropout: dropout probability applied to the flattened output.
    """

    def __init__(self, c_in, seq_in, c_out, seq_out, head_dropout=0):
        super().__init__()
        self.linear1 = nn.Linear(c_in, c_out)
        self.linear2 = nn.Linear(seq_in, seq_out)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        # Expected input layout: [bs x c_in x seq_in]
        channels_last = x.permute(0, 2, 1)           # [bs x seq_in x c_in]
        mixed = self.linear1(channels_last)          # [bs x seq_in x c_out]
        seq_last = mixed.permute(0, 2, 1)            # [bs x c_out x seq_in]
        projected = self.linear2(seq_last)           # [bs x c_out x seq_out]
        flattened = nn.Flatten()(projected)          # [bs x c_out * seq_out]
        return self.dropout(flattened)
    

def sigmoid_range(x, low, high):
    """Apply a sigmoid to `x` and rescale the result to the range `(low, high)`.

    The sigmoid output is multiplied by the width `high - low` and then
    shifted by `low`.
    """
    width = high - low
    squashed = torch.sigmoid(x)
    return squashed * width + low

def get_activation_fn(activation):
    """Return an activation module for `activation`.

    Args:
        activation: one of
            - an `nn.Module` instance, returned unchanged;
            - any other callable (e.g. a module class such as `nn.GELU`),
              which is called with no arguments and its result returned;
            - the string "relu" or "gelu" (case-insensitive).

    Raises:
        ValueError: if `activation` is none of the above.
    """
    # A module instance is itself callable; calling it with no arguments would
    # invoke its forward pass without an input, so hand it back as-is.
    if isinstance(activation, nn.Module): return activation
    if callable(activation): return activation()
    # Guard on `str` so unsupported types reach the ValueError below instead of
    # failing on `.lower()`.
    if isinstance(activation, str):
        name = activation.lower()
        if name == "relu": return nn.ReLU()
        if name == "gelu": return nn.GELU()
    raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable')


