# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.lstm.ipynb.

# %% auto 0
__all__ = ['LSTM']

# %% ../../nbs/models.lstm.ipynb 6
from typing import Optional

import torch
import torch.nn as nn

from ..losses.pytorch import MAE
from ..common._base_recurrent import BaseRecurrent
from ..common._modules import MLP

# %% ../../nbs/models.lstm.ipynb 7

class ComplexLinear(nn.Module):
    """Affine projection followed by a learnable rational (reciprocal-style) map.

    For an input ``x`` whose last dimension is ``input_dim`` the layer computes,
    per output unit::

        a = x @ real_weights + real_bias           # shape (..., output_dim)
        b = imag_bias                              # shape (output_dim,)
        d = a**2 + b**2
        out = lambda_k1 * a / d + lambda_k2 * b / d

    Equivalently ``out = lambda_k1 * Re(1 / (a - i*b)) + lambda_k2 * Im(1 / (a - i*b))``.

    If ``x`` is a complex tensor, only its real part is used; the imaginary
    part is discarded.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: number of output units.
    """

    def __init__(self, input_dim, output_dim):
        super(ComplexLinear, self).__init__()
        # Registration order is kept as-is: it fixes both the RNG draw order
        # under a seed and the state_dict key order.
        # W ~ N(0, 0.1^2), shape (input_dim, output_dim)
        self.real_weights = nn.Parameter(torch.randn(input_dim, output_dim) * 0.1)
        # real bias ~ N(0, 0.01^2)
        self.real_bias = nn.Parameter(torch.randn(output_dim) * 0.01)
        # "imaginary" offset ~ N(-0.1, 0.01^2); initialised away from 0 so the
        # denominator a**2 + b**2 starts strictly positive.
        self.imag_bias = nn.Parameter(torch.randn(output_dim) * 0.01 - 0.1)
        # Per-unit mixing weights for the two terms, ~ N(0.5, 0.1^2)
        self.lambda_k1 = nn.Parameter(torch.randn(output_dim) * 0.1 + 0.5)
        self.lambda_k2 = nn.Parameter(torch.randn(output_dim) * 0.1 + 0.5)

    def forward(self, x):
        """Apply the layer.

        Args:
            x: real or complex tensor with last dimension ``input_dim``.

        Returns:
            Real tensor of shape ``(..., output_dim)``.
        """
        # Only the real part participates in the projection.
        x_real = x.real if torch.is_complex(x) else x

        a = torch.matmul(x_real, self.real_weights) + self.real_bias
        b = self.imag_bias
        # NOTE(review): d is exactly 0 (-> inf/NaN) only if a unit's
        # pre-activation a and its imag_bias b are both 0 at the same time.
        d = a ** 2 + b ** 2
        return (a / d) * self.lambda_k1 + (b / d) * self.lambda_k2

# Custom LSTM cell
class CustomLSTMCell(nn.Module):
    """One LSTM time step whose input projection is a ``ComplexLinear`` layer.

    Gate pre-activations are
    ``z = ComplexLinear(inputs) + h_tm1 @ recurrent_kernel + bias``.
    They are split along the last dimension into four ``hidden_size`` chunks:
    input gate, forget gate, candidate, and output gate.

    NOTE(review): by default neither the candidate nor the cell state is
    squashed, so the cell state is unbounded. A textbook LSTM applies ``tanh``
    to both. Pass ``activation=torch.tanh`` for that formulation. The default
    ``None`` keeps the original identity behaviour.

    Args:
        input_dim: size of the last dimension of ``inputs``.
        hidden_size: size of the hidden state ``h`` and cell state ``c``.
        activation: optional callable applied to the candidate chunk and to
            the cell state before the output gate. ``None`` means identity.
    """

    def __init__(self, input_dim, hidden_size, activation=None):
        super(CustomLSTMCell, self).__init__()
        self.hidden_size = hidden_size
        self.complex_linear = ComplexLinear(input_dim, hidden_size * 4)
        # Recurrent weights: hidden_size -> 4 * hidden_size (one block per gate).
        self.recurrent_kernel = nn.Parameter(torch.randn(hidden_size, hidden_size * 4) * 0.1)
        self.bias = nn.Parameter(torch.zeros(hidden_size * 4))
        self.activation = activation

    def forward(self, inputs, states):
        """Advance the cell by one step.

        Args:
            inputs: tensor whose last dimension is ``input_dim``. A complex
                tensor contributes only its real part.
            states: pair ``(h_tm1, c_tm1)`` of tensors whose last dimension
                is ``hidden_size``.

        Returns:
            ``(h, [h, c])``: the new hidden state, plus the new state pair as
            a list so it can be fed back in as ``states``.
        """
        h_prev, c_prev = states
        # getattr keeps whole-module pickles created before `activation`
        # existed loadable.
        act = getattr(self, "activation", None)

        z = (
            self.complex_linear(inputs)
            + torch.matmul(h_prev, self.recurrent_kernel)
            + self.bias
        )
        # The 4 * hidden_size gate axis is always the last one. dim=-1 is
        # identical to the former dim=1 for 2-D inputs, and also correct for
        # inputs with extra leading dimensions.
        z_i, z_f, z_c, z_o = torch.chunk(z, 4, dim=-1)

        candidate = z_c if act is None else act(z_c)
        i = torch.sigmoid(z_i)
        f = torch.sigmoid(z_f)
        c = f * c_prev + i * candidate
        o = torch.sigmoid(z_o)
        h = o * (c if act is None else act(c))
        return h, [h, c]

class LSTM(BaseRecurrent):
    """Recurrent forecaster: LSTM encoder -> context adapter -> MLP decoder.

    Architecture:
    - The encoder is ``torch.nn.LSTM``. It runs over the concatenation of the
      target, the historic exogenous features and the (time-repeated) static
      exogenous features at every input position.
    - Each encoder hidden state is optionally concatenated with that
      position's flattened future exogenous features.
    - A linear "context adapter" projects the result into ``h`` context
      vectors of width ``context_size``.
    - An MLP maps every context vector, plus that step's future exogenous
      features when present, to ``loss.outputsize_multiplier`` outputs.

    NOTE(review): the encoder is ``nn.LSTM``. The ``CustomLSTMCell`` /
    ``ComplexLinear`` classes defined in this module are not used here.

    Args:
        h: number of future steps produced for every input position
            (presumably the forecast horizon).
        input_size, inference_input_size: forwarded unchanged to ``BaseRecurrent``.
        encoder_n_layers, encoder_hidden_size, encoder_bias, encoder_dropout:
            passed to ``nn.LSTM`` as ``num_layers``, ``hidden_size``, ``bias``
            and ``dropout``.
        context_size: width of each per-step context vector produced by the
            context adapter.
        decoder_hidden_size, decoder_layers: passed to the ``MLP`` decoder as
            ``hidden_size`` and ``num_layers``.
        futr_exog_list, hist_exog_list, stat_exog_list: forwarded to
            ``BaseRecurrent``. The ``futr/hist/stat_exog_size`` attributes that
            size the layers are presumably derived from these by the base
            class (TODO confirm).
        loss: forwarded to ``BaseRecurrent`` (presumably stored as
            ``self.loss``). Its ``outputsize_multiplier`` sets the decoder
            output width and its ``domain_map`` is applied to the decoder
            output. NOTE(review): the default ``MAE()`` is evaluated once, when
            the function is defined, so every instance built with the default
            shares one loss object.
        valid_loss, max_steps, learning_rate, num_lr_decays,
        early_stop_patience_steps, val_check_steps, batch_size,
        valid_batch_size, scaler_type, random_seed, num_workers_loader,
        drop_last_loader, optimizer, optimizer_kwargs, lr_scheduler,
        lr_scheduler_kwargs, **trainer_kwargs:
            forwarded unchanged to ``BaseRecurrent``.
    """

    # Class attributes
    # (presumably read by the base class / framework to choose the sampling
    # strategy and which exogenous inputs this model accepts; TODO confirm)
    SAMPLING_TYPE = "recurrent"
    EXOGENOUS_FUTR = True
    EXOGENOUS_HIST = True
    EXOGENOUS_STAT = True

    def __init__(
        self,
        h: int,
        input_size: int = -1,
        inference_input_size: int = -1,
        encoder_n_layers: int = 2,
        encoder_hidden_size: int = 200,
        encoder_bias: bool = True,
        encoder_dropout: float = 0.0,
        context_size: int = 10,
        decoder_hidden_size: int = 200,
        decoder_layers: int = 2,
        futr_exog_list=None,
        hist_exog_list=None,
        stat_exog_list=None,
        loss=MAE(),
        valid_loss=None,
        max_steps: int = 1000,
        learning_rate: float = 1e-3,
        num_lr_decays: int = -1,
        early_stop_patience_steps: int = -1,
        val_check_steps: int = 100,
        batch_size=32,
        valid_batch_size: Optional[int] = None,
        scaler_type: str = "robust",
        random_seed=1,
        num_workers_loader=0,
        drop_last_loader=False,
        optimizer=None,
        optimizer_kwargs=None,
        lr_scheduler=None,
        lr_scheduler_kwargs=None,
        **trainer_kwargs
    ):
        # Training, data-loading and exogenous bookkeeping are delegated to
        # BaseRecurrent; this class only builds the network below.
        super(LSTM, self).__init__(
            h=h,
            input_size=input_size,
            inference_input_size=inference_input_size,
            loss=loss,
            valid_loss=valid_loss,
            max_steps=max_steps,
            learning_rate=learning_rate,
            num_lr_decays=num_lr_decays,
            early_stop_patience_steps=early_stop_patience_steps,
            val_check_steps=val_check_steps,
            batch_size=batch_size,
            valid_batch_size=valid_batch_size,
            scaler_type=scaler_type,
            futr_exog_list=futr_exog_list,
            hist_exog_list=hist_exog_list,
            stat_exog_list=stat_exog_list,
            num_workers_loader=num_workers_loader,
            drop_last_loader=drop_last_loader,
            random_seed=random_seed,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            lr_scheduler=lr_scheduler,
            lr_scheduler_kwargs=lr_scheduler_kwargs,
            **trainer_kwargs
        )

        # LSTM encoder hyperparameters
        self.encoder_n_layers = encoder_n_layers
        self.encoder_hidden_size = encoder_hidden_size
        self.encoder_bias = encoder_bias
        self.encoder_dropout = encoder_dropout

        # Context adapter
        self.context_size = context_size

        # MLP decoder
        self.decoder_hidden_size = decoder_hidden_size
        self.decoder_layers = decoder_layers

        # LSTM input size: 1 channel for the target y, plus one per historic
        # and per static exogenous variable (concatenated in forward()).
        input_encoder = 1 + self.hist_exog_size + self.stat_exog_size

        # Instantiate model
        self.hist_encoder = nn.LSTM(
            input_size=input_encoder,
            hidden_size=self.encoder_hidden_size,
            num_layers=self.encoder_n_layers,
            bias=self.encoder_bias,
            dropout=self.encoder_dropout,
            batch_first=True,
        )

        # Context adapter: maps each encoder state, plus the flattened h x F
        # future exogenous features when present, to h context vectors of
        # size context_size (reshaped in forward()).
        self.context_adapter = nn.Linear(
            in_features=self.encoder_hidden_size + self.futr_exog_size * h,
            out_features=self.context_size * h,
        )

        # Decoder MLP, applied per horizon step. Its input is one context
        # vector plus that step's F future exogenous features.
        self.mlp_decoder = MLP(
            in_features=self.context_size + self.futr_exog_size,
            out_features=self.loss.outputsize_multiplier,
            hidden_size=self.decoder_hidden_size,
            num_layers=self.decoder_layers,
            activation="ReLU",
            dropout=0.0,
        )

    def forward(self, windows_batch):
        """Run encoder, context adapter and decoder over a batch of windows.

        Args:
            windows_batch: mapping that must contain the keys ``"insample_y"``,
                ``"futr_exog"``, ``"hist_exog"`` and ``"stat_exog"``. The
                shapes implied by the reshapes below are
                insample_y [B, seq_len, 1], hist_exog [B, X, seq_len, 1],
                stat_exog [B, S] and futr_exog [B, F, seq_len, 1 + h]
                (TODO confirm against BaseRecurrent). Each exogenous entry
                is only used when the matching ``*_exog_size`` is > 0.

        Returns:
            ``self.loss.domain_map`` applied to the decoder output, which is
            presumably of shape [B, seq_len, h, loss.outputsize_multiplier].
        """

        # Parse windows_batch
        encoder_input = windows_batch["insample_y"]  # [B, seq_len, 1]
        futr_exog = windows_batch["futr_exog"]
        hist_exog = windows_batch["hist_exog"]
        stat_exog = windows_batch["stat_exog"]

        # Concatenate y, historic and static inputs along the channel axis:
        # [ Y_t | X_{t-L},..., X_{t} | S ] -> [B, seq_len, 1 + X + S]
        batch_size, seq_len = encoder_input.shape[:2]
        if self.hist_exog_size > 0:
            hist_exog = hist_exog.permute(0, 2, 1, 3).squeeze(
                -1
            )  # [B, X, seq_len, 1] -> [B, seq_len, X]
            encoder_input = torch.cat((encoder_input, hist_exog), dim=2)

        if self.stat_exog_size > 0:
            stat_exog = stat_exog.unsqueeze(1).repeat(
                1, seq_len, 1
            )  # [B, S] -> [B, seq_len, S]
            encoder_input = torch.cat((encoder_input, stat_exog), dim=2)

        # RNN forward (final (h_n, c_n) states are discarded)
        hidden_state, _ = self.hist_encoder(
            encoder_input
        )  # [B, seq_len, rnn_hidden_state]

        if self.futr_exog_size > 0:
            # Drop index 0 of the 1+H future axis, keeping the H horizon steps.
            futr_exog = futr_exog.permute(0, 2, 3, 1)[
                :, :, 1:, :
            ]  # [B, F, seq_len, 1+H] -> [B, seq_len, H, F]
            hidden_state = torch.cat(
                (hidden_state, futr_exog.reshape(batch_size, seq_len, -1)), dim=2
            )

        # Context adapter: [B, seq_len, *] -> [B, seq_len, h, context_size]
        context = self.context_adapter(hidden_state)
        context = context.reshape(batch_size, seq_len, self.h, self.context_size)

        # Skip connection: concatenate each horizon step's future exogenous
        # features onto its context vector -> [B, seq_len, h, context_size + F]
        if self.futr_exog_size > 0:
            context = torch.cat((context, futr_exog), dim=-1)

        # Final forecast
        output = self.mlp_decoder(context)
        output = self.loss.domain_map(output)

        return output
