#  Copyright (c) 2024, Salesforce, Inc.
#  SPDX-License-Identifier: Apache-2
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from functools import partial
from typing import Any, Tuple, List, Literal

import torch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from hydra.utils import instantiate
from jaxtyping import Bool, Float, Int
from torch import nn, Tensor
from torch.distributions import Distribution
from torch.utils._pytree import tree_map

from uni2ts.common.torch_util import mask_fill, packed_attention_mask, packed_causal_attention_mask
from uni2ts.distribution import DistributionOutput
from uni2ts.module.norm import RMSNorm
from uni2ts.module.packed_scaler import PackedNOPScaler, PackedStdScaler
from uni2ts.module.position import (
    BinaryAttentionBias,
    QueryKeyProjection,
    RotaryProjection,
)
from uni2ts.module.transformer import TransformerEncoder
from uni2ts.module.ts_embed import MultiInSizeLinear


def encode_distr_output(
    distr_output: DistributionOutput,
) -> dict[str, str | float | int]:
    """
    Serialize a DistributionOutput into a Hydra-style config dict.

    Every DistributionOutput encountered becomes a dict holding a ``_target_``
    key (its fully-qualified ``module.ClassName``), followed by its instance
    attributes. The attributes are walked with ``tree_map`` so that a
    DistributionOutput nested inside another one is serialized the same way.
    Any other value is returned unchanged.
    """

    def _to_config(node):
        if isinstance(node, DistributionOutput):
            cls = node.__class__
            config = {"_target_": f"{cls.__module__}.{cls.__name__}"}
            # Later keys win, exactly as with dict unpacking after "_target_".
            config.update(tree_map(_to_config, node.__dict__))
            return config
        return node

    return _to_config(distr_output)


def decode_distr_output(config: dict[str, str | float | int]) -> DistributionOutput:
    """
    Rebuild a DistributionOutput from a config dict.

    This is the inverse of ``encode_distr_output``. The config is passed to
    Hydra's ``instantiate`` with ``_convert_="all"``. Per Hydra's
    documentation, that setting makes nested containers arrive as plain Python
    objects rather than OmegaConf containers.
    """
    rebuilt = instantiate(config, _convert_="all")
    return rebuilt


class SATSModule(
    nn.Module,
    PyTorchModelHubMixin,
    coders={DistributionOutput: (encode_distr_output, decode_distr_output)},
):
    """
    Contains components of SATS, to ensure implementation is identical across models.
    Subclasses huggingface_hub.PyTorchModelHubMixin to support loading from HuggingFace Hub.
    """

    def __init__(
        self,
        distr_output: DistributionOutput,
        d_model: int,
        num_layers: int,
        patch_sizes: tuple[int, ...],  # tuple[int, ...] | list[int]
        max_seq_len: int,
        attn_dropout_p: float,
        dropout_p: float,
        attn_mask_type: str = "default",
        scaling: bool = True,
    ):
        """
        :param distr_output: distribution output object
        :param d_model: model hidden dimensions
        :param num_layers: number of transformer layers
        :param patch_sizes: sequence of patch sizes
        :param max_seq_len: maximum sequence length for inputs
        :param attn_dropout_p: dropout probability for attention layers
        :param dropout_p: dropout probability for all other layers
        :param attn_mask_type: attention mask builder used in ``forward``.
            "default" uses ``packed_attention_mask(sample_id)``.
            "causal" uses ``packed_causal_attention_mask(sample_id, time_id)``.
            Any other value raises NotImplementedError when ``forward`` is called.
        :param scaling: whether to apply scaling (standardization)
        """
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.patch_sizes = patch_sizes
        self.max_seq_len = max_seq_len
        self.scaling = scaling
        self.attn_mask_type = attn_mask_type

        self.mask_encoding = nn.Embedding(num_embeddings=1, embedding_dim=d_model)
        self.scaler = PackedStdScaler() if scaling else PackedNOPScaler()
        self.in_proj = MultiInSizeLinear(
            in_features_ls=patch_sizes,
            out_features=d_model,
        )
        self.encoder = TransformerEncoder(
            d_model,
            num_layers,
            num_heads=None,
            pre_norm=True,
            attn_dropout_p=attn_dropout_p,
            dropout_p=dropout_p,
            norm_layer=RMSNorm,
            activation=F.silu,
            use_glu=True,
            use_qk_norm=True,
            var_attn_bias_layer=partial(BinaryAttentionBias),
            time_qk_proj_layer=partial(
                QueryKeyProjection,
                proj_layer=RotaryProjection,
                kwargs=dict(max_len=max_seq_len),
                partial_factor=(0.0, 0.5),
            ),
            shared_var_attn_bias=False,
            shared_time_qk_proj=True,
            d_ff=None,
        )
        self.distr_output = distr_output
        self.param_proj = self.distr_output.get_param_proj(d_model, patch_sizes)

    def forward(
        self,
        target: Float[torch.Tensor, "*batch seq_len max_patch"],
        observed_mask: Bool[torch.Tensor, "*batch seq_len max_patch"],
        sample_id: Int[torch.Tensor, "*batch seq_len"],
        time_id: Int[torch.Tensor, "*batch seq_len"],
        variate_id: Int[torch.Tensor, "*batch seq_len"],
        prediction_mask: Bool[torch.Tensor, "*batch seq_len"],
        patch_size: Int[torch.Tensor, "*batch seq_len"],
    ) -> tuple[Distribution, Tensor | None, Tensor | None] | Distribution:
        """
        Defines the forward pass of SATSModule.
        This method expects processed inputs.

        1. Apply scaling to observations
        2. Project from observations to representations
        3. (training only) Aggregate the projected embeddings per patch size
        4. Replace prediction window with learnable mask
        5. Apply transformer layers
        6. Project from representations to distribution parameters
        7. Return distribution object

        :param target: input data
        :param observed_mask: binary mask for missing values, 1 if observed, 0 otherwise
        :param sample_id: indices indicating the sample index (for packing)
        :param time_id: indices indicating the time index
        :param variate_id: indices indicating the variate index
        :param prediction_mask: binary mask for prediction horizon, 1 if part of the horizon, 0 otherwise
        :param patch_size: patch size for each token
        :return: in training mode, a tuple ``(distr, agg_mean, agg_max)``.
            ``agg_mean`` and ``agg_max`` are the input projections aggregated per
            patch size (see ``aggregate_patch_embeddings``, modes "mean" and "max").
            They are computed before the prediction window is masked.
            In eval mode, only the predictive distribution ``distr`` is returned.
        :raises NotImplementedError: if ``self.attn_mask_type`` is not "default" or "causal"
        """
        # RevIN-style normalization. The statistics use only observed values
        # outside the prediction window.
        loc, scale = self.scaler(
            target,
            observed_mask * ~prediction_mask.unsqueeze(-1),
            sample_id,
            variate_id,
        )
        scaled_target = (target - loc) / scale
        # Multi-patch-size input projection.
        reprs = self.in_proj(scaled_target, patch_size)
        if self.training:
            agg_reprs = self.aggregate_patch_embeddings(
                reprs=reprs, patch_size=patch_size, patch_sizes=self.patch_sizes, mode="mean"
            )
            agg_reprs_2 = self.aggregate_patch_embeddings(
                reprs=reprs, patch_size=patch_size, patch_sizes=self.patch_sizes, mode="max"
            )
        else:
            agg_reprs = None
            agg_reprs_2 = None
        # Replace every token flagged by prediction_mask with one shared,
        # learnable mask embedding.
        masked_reprs = mask_fill(reprs, prediction_mask, self.mask_encoding.weight)

        if self.attn_mask_type == "default":
            attn_mask = packed_attention_mask(sample_id)
        elif self.attn_mask_type == "causal":
            attn_mask = packed_causal_attention_mask(sample_id, time_id)
        else:
            raise NotImplementedError(
                f"Unsupported attn_mask_type: {self.attn_mask_type!r} "
                "(expected 'default' or 'causal')"
            )
        reprs = self.encoder(
            masked_reprs,
            attn_mask,
            time_id=time_id,
            var_id=variate_id,
        )
        # NOTE(review): per the original author's note, for a mixture
        # distr_output, distr_param is a dict with:
        #   - "weights_logits": float32 tensor of shape [B, L, P, components_num]
        #   - "components": a list of length components_num. Each entry is a
        #     dict of tensors of shape [B, L, P] parameterizing one component.
        # The exact structure depends on self.distr_output; confirm there.
        distr_param = self.param_proj(reprs, patch_size)
        distr = self.distr_output.distribution(distr_param, loc=loc, scale=scale)
        if self.training:
            return distr, agg_reprs, agg_reprs_2
        else:
            return distr

    @staticmethod
    def aggregate_patch_embeddings(
        reprs: Float[torch.Tensor, "batch seq_len dim"],
        patch_size: Int[torch.Tensor, "batch seq_len"],
        patch_sizes: Tuple[int, ...] | List[int],
        mode: Literal["mean", "max", "min", "random"] = "mean",
    ) -> Float[torch.Tensor, "batch num_patch_size dim"]:
        """
        Aggregate token embeddings separately for each patch size.

        For every sample and every value in ``patch_sizes``, the tokens whose
        ``patch_size`` equals that value are reduced along the sequence axis.

        :param reprs: token representations, shape [B, L, D]
        :param patch_size: patch size label of each token, shape [B, L]
        :param patch_sizes: patch size values to aggregate, e.g. (8, 16, 32, 64, 128)
        :param mode: reduction to apply. Supported values:
            "mean", "max", "min" (element-wise over matching tokens), and
            "random" (one uniformly sampled matching token).
        :return: aggregated embeddings, shape [B, len(patch_sizes), D], with the
            same dtype as ``reprs``. In every mode, a (sample, patch size) pair
            with no matching token yields a zero vector.
        :raises ValueError: if ``mode`` is not supported
        """
        B, L, D = reprs.shape
        num_sizes = len(patch_sizes)
        sizes = torch.as_tensor(
            list(patch_sizes), dtype=patch_size.dtype, device=patch_size.device
        )
        # member[b, l, i] is True when token l of sample b has patch size patch_sizes[i].
        member = patch_size.unsqueeze(-1) == sizes  # [B, L, P]
        # present[b, i] is True when sample b has at least one token of patch_sizes[i].
        present = member.any(dim=1)  # [B, P]
        agg = reprs.new_zeros(B, num_sizes, D)

        if mode == "mean":
            for i in range(num_sizes):
                mask = member[..., i].unsqueeze(-1)  # [B, L, 1]
                # Use masked_fill rather than multiplication, so non-finite
                # values in non-matching tokens cannot leak in via 0 * inf.
                total = reprs.masked_fill(~mask, 0).sum(dim=1)  # [B, D]
                count = mask.sum(dim=1).clamp(min=1)  # [B, 1]; clamp avoids 0 / 0
                agg[:, i] = total / count

        elif mode == "max" or mode == "min":
            finfo = torch.finfo(reprs.dtype)
            # Non-matching tokens get a sentinel that can never win the reduction.
            sentinel = finfo.min if mode == "max" else finfo.max
            for i in range(num_sizes):
                mask = member[..., i].unsqueeze(-1)  # [B, L, 1]
                filled = reprs.masked_fill(~mask, sentinel)
                reduced = filled.max(dim=1) if mode == "max" else filled.min(dim=1)
                agg[:, i] = reduced.values

        elif mode == "random":
            for i in range(num_sizes):
                for b in range(B):
                    indices = torch.nonzero(member[b, :, i], as_tuple=False).squeeze(1)
                    if indices.numel() > 0:
                        pick = indices[torch.randint(0, indices.size(0), (1,))]  # [1]
                        agg[b, i] = reprs[b, pick].squeeze(0)  # [D]

        else:
            raise ValueError(f"Unsupported aggregation mode: {mode}")

        # Zero every (sample, patch size) pair that has no matching token, so all
        # modes agree. For max/min this replaces the +/-finfo sentinel, which
        # would otherwise reach downstream losses as an extreme value.
        return agg.masked_fill(~present.unsqueeze(-1), 0)