# Copyright (C) Authors of submission, all rights reserved


from abc import abstractmethod
from dataclasses import dataclass
import logging
from typing import Dict, List, Optional, Tuple
from dacite import from_dict, Config
import torch

from .mixed_stack import xLSTMMixedLargeConfig, xLSTMMixedLargeBlockStack

# Module-scoped logger. Using ``__name__`` instead of the unnamed root logger
# keeps any level/handler changes made through LOGGER local to this module;
# records still propagate to the root logger's handlers by default.
LOGGER = logging.getLogger(__name__)


class ResidualBlock(torch.nn.Module):
    """Two-layer ReLU MLP with a learned linear skip connection.

    The forward pass evaluates
    ``output_layer(relu(hidden_layer(x))) + residual_layer(x)``; the skip path
    is itself a linear projection from ``in_dim`` to ``out_dim``.
    """

    def __init__(
        self,
        in_dim: int,
        h_dim: int,
        out_dim: int,
    ) -> None:
        """Create the three linear layers and the activation.

        Args:
            in_dim: Size of the last dimension of the input.
            h_dim: Width of the hidden layer on the main path.
            out_dim: Size of the last dimension of the output.
        """
        super().__init__()
        # Main path: in_dim -> h_dim -> out_dim.
        self.hidden_layer = torch.nn.Linear(in_features=in_dim, out_features=h_dim)
        self.output_layer = torch.nn.Linear(in_features=h_dim, out_features=out_dim)
        # Skip path: direct projection in_dim -> out_dim.
        self.residual_layer = torch.nn.Linear(in_features=in_dim, out_features=out_dim)
        self.act = torch.nn.ReLU()

    def forward(self, x: torch.Tensor):
        """Return the sum of the MLP path and the linear skip path for ``x``."""
        main_path = self.output_layer(self.act(self.hidden_layer(x)))
        skip_path = self.residual_layer(x)
        return main_path + skip_path


@dataclass
class PatchDecOnlyLModelConfig:
    """Configuration consumed by ``_PatchDecOnlyLModel``.

    The model parses a plain dict into this dataclass with dacite in strict
    mode, so the field names below are the accepted config keys.
    """

    # Patch length on the input side; the input embedding's in_dim is twice
    # this value (values concatenated with their mask). The model requires it
    # to equal ``output_patch_size``.
    input_patch_size: int
    # Patch length on the output side; the output layer emits
    # ``len(quantiles) * output_patch_size`` values per token.
    output_patch_size: int
    # Quantile levels to predict; their count sets the number of quantile
    # outputs and they are registered on the model as a non-persistent buffer.
    quantiles: List[float]
    # Settings handed to ``init_block`` to build the block stack. Note that
    # ``_PatchDecOnlyLModel.__init__`` overwrites this field with the resolved
    # config object returned by ``init_block``, so after model construction it
    # is no longer a plain dict.
    block_kwargs: Dict
    # Hidden width of both the input and the output ``ResidualBlock``.
    input_ff_dim: int


@dataclass
class PatchOutput:
    """Container returned by ``_PatchDecOnlyLModel.forward``."""

    # Output-layer predictions after unflattening the last dimension into
    # (num_quantiles, output_patch_size) and swapping dimensions 1 and 2.
    quantile_preds: Optional[torch.FloatTensor] = None
    # Output of the block stack (its first element when the stack returns a
    # tuple). NOTE(review): annotated as a tuple of tensors, but ``forward``
    # stores a single value here - confirm the intended type.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class _PatchDecOnlyLModel(torch.nn.Module):
    """Patch-based decoder-only quantile model.

    Pipeline: (patch values, mask) -> input ``ResidualBlock`` -> block stack ->
    output ``ResidualBlock`` -> per-token quantile predictions.

    This is a base class: subclasses provide the block stack by overriding
    ``init_block``.
    """

    def __init__(self, model_config: dict):
        """Parse the config and build the embedding, block stack and output head.

        Args:
            model_config: Mapping with the fields of ``PatchDecOnlyLModelConfig``.
                It is parsed with dacite using ``Config(strict=True)``.

        Raises:
            ValueError: If ``input_patch_size`` differs from ``output_patch_size``.
            NotImplementedError: If the subclass does not override ``init_block``.
        """
        super().__init__()
        self.model_config: PatchDecOnlyLModelConfig = from_dict(
            PatchDecOnlyLModelConfig, model_config, config=Config(strict=True)
        )
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if self.model_config.input_patch_size != self.model_config.output_patch_size:
            raise ValueError(
                "input_patch_size must equal output_patch_size, got "
                f"{self.model_config.input_patch_size} and "
                f"{self.model_config.output_patch_size}"
            )

        # Setup Model
        # Replacement value for NaN entries of the input before embedding.
        self.nan_mask_value = 0

        # Block Stack
        # The raw block kwargs are replaced by the resolved config object so
        # that its attributes (e.g. ``embedding_dim``) can be read below.
        self.block_stack, resolved_config = self.init_block(self.model_config.block_kwargs)
        self.model_config.block_kwargs = resolved_config

        # Input Layer
        # in_dim is doubled because forward() concatenates the patch values
        # with their mask along the last dimension.
        self.input_patch_embedding = ResidualBlock(
            in_dim=self.model_config.input_patch_size * 2,
            h_dim=self.model_config.input_ff_dim,
            out_dim=self.model_config.block_kwargs.embedding_dim,
        )

        # Output Layer
        self.num_quantiles = len(self.model_config.quantiles)
        quantiles = torch.tensor(self.model_config.quantiles)
        # Non-persistent: follows the module across devices but is not stored
        # in the state dict.
        self.register_buffer("quantiles", quantiles, persistent=False)

        self.output_patch_embedding = ResidualBlock(
            in_dim=self.model_config.block_kwargs.embedding_dim,
            h_dim=self.model_config.input_ff_dim,
            out_dim=self.num_quantiles * self.model_config.output_patch_size,
        )

    @abstractmethod
    def init_block(self, block_kwargs):
        """Build the block stack; must be overridden by subclasses.

        Args:
            block_kwargs: The ``block_kwargs`` entry of the model config.

        Returns:
            A ``(block_stack, resolved_config)`` pair. ``block_stack`` is called
            with the input embeddings in ``forward``; ``resolved_config`` must
            expose an ``embedding_dim`` attribute.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # ``@abstractmethod`` is not enforced here because torch.nn.Module does
        # not use ABCMeta, so fail loudly instead of returning None (which
        # would surface as an opaque tuple-unpacking error in ``__init__``).
        raise NotImplementedError(f"{type(self).__name__} must implement init_block()")

    def forward(
        self,
        input_token,
        input_mask=None,
    ) -> "PatchOutput":
        """Embed the patches, run the block stack and predict quantiles.

        Args:
            input_token: 3-D tensor; NaN entries are replaced by
                ``self.nan_mask_value`` before embedding. Presumably shaped
                (batch, num_tokens, input_patch_size) - the last dimension
                together with the mask must match the input embedding's in_dim.
            input_mask: Optional tensor concatenated with ``input_token`` along
                dim 2 after casting to its dtype. When ``None``, it is derived
                from the input as 1 where the value is not NaN and 0 where it is.

        Returns:
            ``PatchOutput`` with ``hidden_states`` set to the block-stack output
            (first element if the stack returns a tuple) and ``quantile_preds``
            set to the output-layer result with its last dimension split into
            (num_quantiles, output_patch_size) and dimensions 1 and 2 swapped.

        Raises:
            ValueError: If ``input_token`` is not 3-dimensional.
        """
        if input_token.dim() != 3:
            raise ValueError(
                f"input_token must be 3-dimensional, got shape {tuple(input_token.shape)}"
            )

        # The fallback mask must be computed before the NaNs are overwritten.
        if input_mask is None:
            input_mask = torch.isnan(input_token).logical_not()
        input_mask = input_mask.to(input_token.dtype)

        input_token = torch.nan_to_num(input_token, nan=self.nan_mask_value)

        input_embeds = self.input_patch_embedding(torch.cat((input_token, input_mask), dim=2))
        x = self.block_stack(input_embeds)
        # Some block stacks return (hidden_states, extra...); keep only the first.
        hidden_states = x[0] if isinstance(x, tuple) else x

        quantile_preds = self.output_patch_embedding(hidden_states)
        quantile_preds = torch.unflatten(
            quantile_preds, -1, (self.num_quantiles, self.model_config.output_patch_size)
        )
        # Swap the token dimension (1) with the quantile dimension (2).
        quantile_preds = torch.transpose(quantile_preds, 1, 2)

        return PatchOutput(
            quantile_preds=quantile_preds,
            hidden_states=hidden_states,
        )


class TiRex(_PatchDecOnlyLModel):
    """Patch decoder-only model whose block stack is an ``xLSTMMixedLargeBlockStack``."""

    def init_block(self, block_kwargs):
        """Resolve ``block_kwargs`` into an xLSTM config and build the stack.

        Returns:
            Tuple of ``(block_stack, resolved_config)``.
        """
        stack_config = from_dict(xLSTMMixedLargeConfig, block_kwargs)
        block_stack = xLSTMMixedLargeBlockStack(stack_config)
        return block_stack, stack_config
