# coding=utf-8
# Copyright 2023 IBM and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch PatchTSMixer model. """

import inspect
import itertools
import random
import tempfile
import unittest
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from huggingface_hub import hf_hub_download
from parameterized import parameterized

from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


# Absolute tolerance (`atol`) for the `torch.allclose` checks in the integration tests.
TOLERANCE = 0.0001

if is_torch_available():
    import torch

    from transformers import (
        MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING,
        MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING,
        PatchTSMixerConfig,
        PatchTSMixerForPrediction,
        PatchTSMixerForPretraining,
        PatchTSMixerForRegression,
        PatchTSMixerForTimeSeriesClassification,
        PatchTSMixerModel,
    )
    from transformers.models.patchtsmixer.modeling_patchtsmixer import (
        PatchTSMixerEncoder,
        PatchTSMixerForPredictionHead,
        PatchTSMixerForPredictionOutput,
        PatchTSMixerForRegressionOutput,
        PatchTSMixerForTimeSeriesClassificationOutput,
        PatchTSMixerLinearHead,
        PatchTSMixerPretrainHead,
    )


@require_torch
class PatchTSMixerModelTester:
    """Build small PatchTSMixer configs and random inputs for `PatchTSMixerModelTest`.

    ``hidden_size`` and ``num_hidden_layers`` use the attribute names read by the common
    tests; :meth:`get_config` forwards them to the config as ``d_model`` and ``num_layers``.
    """

    def __init__(
        self,
        context_length: int = 32,
        patch_length: int = 8,
        num_input_channels: int = 3,
        patch_stride: int = 8,
        hidden_size: int = 8,  # forwarded to the config as `d_model`
        num_hidden_layers: int = 2,  # forwarded to the config as `num_layers`
        expansion_factor: int = 2,
        dropout: float = 0.5,
        mode: str = "common_channel",
        gated_attn: bool = True,
        norm_mlp="LayerNorm",
        swin_hier: int = 0,
        # masking related
        mask_type: str = "forecast",
        random_mask_ratio=0.5,
        mask_patches: Optional[list] = None,  # None -> [2, 3]
        forecast_mask_ratios: Optional[list] = None,  # None -> [1, 1]
        mask_value=0,
        masked_loss: bool = False,
        mask_mode: str = "mask_before_encoder",
        channel_consistent_masking: bool = True,
        scaling: Optional[Union[str, bool]] = "std",
        # Head related
        head_dropout: float = 0.2,
        # forecast related
        prediction_length: int = 16,
        out_channels: Optional[int] = None,
        # Classification/regression related
        num_targets: int = 3,
        output_range: Optional[list] = None,
        head_aggregation: Optional[str] = None,
        # Trainer related
        batch_size=13,
        is_training=True,
        seed_number=42,
        post_init=True,
        num_parallel_samples=4,
    ):
        self.num_input_channels = num_input_channels
        self.context_length = context_length
        self.patch_length = patch_length
        self.patch_stride = patch_stride
        self.hidden_size = hidden_size
        self.expansion_factor = expansion_factor
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.mode = mode
        self.gated_attn = gated_attn
        self.norm_mlp = norm_mlp
        self.swin_hier = swin_hier
        self.scaling = scaling
        self.head_dropout = head_dropout
        # masking related
        self.mask_type = mask_type
        self.random_mask_ratio = random_mask_ratio
        # Fresh lists per instance: list literals as defaults would be shared by every tester.
        self.mask_patches = [2, 3] if mask_patches is None else mask_patches
        self.forecast_mask_ratios = [1, 1] if forecast_mask_ratios is None else forecast_mask_ratios
        self.mask_value = mask_value
        self.channel_consistent_masking = channel_consistent_masking
        self.mask_mode = mask_mode
        self.masked_loss = masked_loss
        # patching related
        self.patch_last = True
        # forecast related
        self.prediction_length = prediction_length
        self.out_channels = out_channels
        # classification/regression related
        self.num_targets = num_targets
        self.output_range = output_range
        self.head_aggregation = head_aggregation
        # Trainer related
        self.batch_size = batch_size
        self.is_training = is_training
        self.seed_number = seed_number
        self.post_init = post_init
        # NOTE(review): stored but not passed to `PatchTSMixerConfig` in `get_config`.
        self.num_parallel_samples = num_parallel_samples

    def get_config(self):
        """Return a `PatchTSMixerConfig` built from this tester's attributes.

        Side effect: stores ``config.num_patches`` on ``self.num_patches`` (read by the
        hidden-state shape test).
        """
        config_ = PatchTSMixerConfig(
            num_input_channels=self.num_input_channels,
            context_length=self.context_length,
            patch_length=self.patch_length,
            patch_stride=self.patch_stride,
            d_model=self.hidden_size,
            expansion_factor=self.expansion_factor,
            num_layers=self.num_hidden_layers,
            dropout=self.dropout,
            mode=self.mode,
            gated_attn=self.gated_attn,
            norm_mlp=self.norm_mlp,
            swin_hier=self.swin_hier,
            scaling=self.scaling,
            head_dropout=self.head_dropout,
            mask_type=self.mask_type,
            random_mask_ratio=self.random_mask_ratio,
            mask_patches=self.mask_patches,
            forecast_mask_ratios=self.forecast_mask_ratios,
            mask_value=self.mask_value,
            channel_consistent_masking=self.channel_consistent_masking,
            mask_mode=self.mask_mode,
            masked_loss=self.masked_loss,
            prediction_length=self.prediction_length,
            out_channels=self.out_channels,
            num_targets=self.num_targets,
            output_range=self.output_range,
            head_aggregation=self.head_aggregation,
            post_init=self.post_init,
        )
        self.num_patches = config_.num_patches
        return config_

    def prepare_patchtsmixer_inputs_dict(self, config):
        """Return ``{"past_values": tensor}`` shaped [batch_size, context_length, num_input_channels]."""
        past_values = floats_tensor([self.batch_size, config.context_length, self.num_input_channels])
        return {"past_values": past_values}

    def prepare_config_and_inputs(self):
        """Return a fresh ``(config, inputs_dict)`` pair."""
        config = self.get_config()
        inputs_dict = self.prepare_patchtsmixer_inputs_dict(config)
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        """Entry point used by the common test mixins; same as `prepare_config_and_inputs`."""
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict


@require_torch
class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common model and pipeline tests run against every PatchTSMixer model class."""

    all_model_classes = (
        (
            PatchTSMixerModel,
            PatchTSMixerForPrediction,
            PatchTSMixerForPretraining,
            PatchTSMixerForTimeSeriesClassification,
            PatchTSMixerForRegression,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (
        (PatchTSMixerForPrediction, PatchTSMixerForPretraining) if is_torch_available() else ()
    )
    pipeline_model_mapping = {"feature-extraction": PatchTSMixerModel} if is_torch_available() else {}
    is_encoder_decoder = False
    test_pruning = False
    test_head_masking = False
    test_missing_keys = False
    test_torchscript = False
    test_inputs_embeds = False
    test_model_common_attributes = False

    test_resize_embeddings = True
    test_resize_position_embeddings = False
    test_mismatched_shapes = True
    test_model_parallel = False
    has_attentions = False

    def setUp(self):
        self.model_tester = PatchTSMixerModelTester()
        self.config_tester = ConfigTester(
            self,
            config_class=PatchTSMixerConfig,
            has_text_modality=False,
            prediction_length=self.model_tester.prediction_length,
            common_properties=["hidden_size", "expansion_factor", "num_hidden_layers"],
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        """Add the label tensor each head expects and always request hidden states.

        Prediction heads get float ``future_values`` [batch, prediction_length, channels].
        Classification heads get integer ``target_values`` [batch], and regression heads get
        float ``target_values`` [batch, num_targets]. Labels are drawn from a `random.Random`
        seeded with ``seed_number``.
        """
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
        tester = self.model_tester
        rng = random.Random(tester.seed_number)

        if model_class == PatchTSMixerForPrediction:
            inputs_dict["future_values"] = floats_tensor(
                [tester.batch_size, tester.prediction_length, tester.num_input_channels],
                rng=rng,
            )
        elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
            inputs_dict["target_values"] = ids_tensor([tester.batch_size], tester.num_targets, rng=rng)
        elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING):
            inputs_dict["target_values"] = floats_tensor([tester.batch_size, tester.num_targets], rng=rng)

        inputs_dict["output_hidden_states"] = True
        return inputs_dict

    def test_save_load_strict(self):
        config, _ = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                _, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester,
                "expected_num_hidden_layers",
                self.model_tester.num_hidden_layers,
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            expected_hidden_size = self.model_tester.hidden_size
            self.assertEqual(hidden_states[0].shape[-1], expected_hidden_size)

            # `num_patches` is set on the tester as a side effect of `get_config`.
            num_patch = self.model_tester.num_patches
            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [num_patch, self.model_tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            check_hidden_states_output(inputs_dict, config, model_class)

    @unittest.skip("No tokens embeddings")
    def test_resize_tokens_embeddings(self):
        pass

    def test_model_outputs_equivalence(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            # NaN is the only value that is not equal to itself; zero those entries in place.
            t[t != t] = 0
            return t

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs=None):
            # `None` instead of a shared mutable `{}` default.
            additional_kwargs = {} if additional_kwargs is None else additional_kwargs
            with torch.no_grad():
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                output_ = model(**dict_inputs, return_dict=True, **additional_kwargs)
                # `vars()` keeps every field of the output object, including `None` ones.
                # NOTE(review): presumably chosen to line up positionally with the
                # `return_dict=False` tuple — confirm against the model's tuple output.
                attributes_ = vars(output_)
                dict_output = tuple(attributes_.values())

                def recursive_check(tuple_object, dict_object):
                    if isinstance(tuple_object, (List, Tuple)):
                        for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif isinstance(tuple_object, Dict):
                        for tuple_iterable_value, dict_iterable_value in zip(
                            tuple_object.values(), dict_object.values()
                        ):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif tuple_object is None:
                        return
                    else:
                        # Record NaN presence before the in-place zeroing below erases it.
                        tuple_has_nan = torch.isnan(tuple_object).any()
                        dict_has_nan = torch.isnan(dict_object).any()
                        tuple_object = set_nan_tensor_to_zero(tuple_object)
                        dict_object = set_nan_tensor_to_zero(dict_object)
                        self.assertTrue(
                            torch.allclose(tuple_object, dict_object, atol=1e-5),
                            msg=(
                                "Tuple and dict output are not equal. Difference:"
                                f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                                f" {tuple_has_nan} and `inf`: {torch.isinf(tuple_object).any()}. Dict has"
                                f" `nan`: {dict_has_nan} and `inf`: {torch.isinf(dict_object).any()}."
                            ),
                        )

                recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            tuple_inputs.update({"output_hidden_states": False})
            dict_inputs.update({"output_hidden_states": False})
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            tuple_inputs.update({"output_hidden_states": False})
            dict_inputs.update({"output_hidden_states": False})
            check_equivalence(model, tuple_inputs, dict_inputs)

    def test_model_main_input_name(self):
        model_signature = inspect.signature(getattr(PatchTSMixerModel, "forward"))
        # The main input is the name of the argument after `self`
        observed_main_input_name = list(model_signature.parameters.keys())[1]
        self.assertEqual(PatchTSMixerModel.main_input_name, observed_main_input_name)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            if model_class == PatchTSMixerForPretraining:
                expected_arg_names = [
                    "past_values",
                    "observed_mask",
                    "output_hidden_states",
                    "return_loss",
                ]
            elif model_class == PatchTSMixerModel:
                expected_arg_names = [
                    "past_values",
                    "observed_mask",
                    "output_hidden_states",
                ]
            elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING) or model_class in get_values(
                MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING
            ):
                expected_arg_names = [
                    "past_values",
                    "target_values",
                    "output_hidden_states",
                    "return_loss",
                ]
            else:
                # PatchTSMixerForPrediction
                expected_arg_names = [
                    "past_values",
                    "observed_mask",
                    "future_values",
                    "output_hidden_states",
                    "return_loss",
                ]

            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    @is_flaky()
    def test_retain_grad_hidden_states_attentions(self):
        super().test_retain_grad_hidden_states_attentions()


def prepare_batch(repo_id="ibm/patchtsmixer-etth1-test-data", file="pretrain_batch.pt"):
    """Download a serialized test batch from a Hub dataset repo and load it onto `torch_device`.

    Args:
        repo_id: Hub dataset repository that holds the batch files.
        file: Name of the ``.pt`` file inside ``repo_id``.

    Returns:
        The object restored by ``torch.load``; the integration tests index it like a dict
        (e.g. ``batch["past_values"]``).
    """
    # TODO: Make repo public
    local_path = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
    # NOTE(review): `torch.load` unpickles the downloaded file. That is only acceptable because
    # the repo is a trusted test fixture; consider `weights_only=True` if the torch version allows.
    return torch.load(local_path, map_location=torch_device)


@require_torch
@slow
class PatchTSMixerModelIntegrationTests(unittest.TestCase):
    """Slow tests comparing pretrained IBM checkpoints against recorded reference outputs."""

    def test_pretrain_head(self):
        model = PatchTSMixerForPretraining.from_pretrained("ibm/patchtsmixer-etth1-pretrain").to(torch_device)
        batch = prepare_batch()

        # Fixed seed so the expected slice below is reproducible.
        torch.manual_seed(0)
        with torch.no_grad():
            output = model(past_values=batch["past_values"].to(torch_device)).prediction_outputs
        num_patch = (
            max(model.config.context_length, model.config.patch_length) - model.config.patch_length
        ) // model.config.patch_stride + 1
        expected_shape = torch.Size(
            [
                64,
                model.config.num_input_channels,
                num_patch,
                model.config.patch_length,
            ]
        )
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor([[[[-0.9106]],[[1.5326]],[[-0.8245]],[[0.7439]],[[-0.7830]],[[2.6256]],[[-0.6485]],]],device=torch_device)  # fmt: skip
        self.assertTrue(torch.allclose(output[0, :7, :1, :1], expected_slice, atol=TOLERANCE))

    def test_forecasting_head(self):
        model = PatchTSMixerForPrediction.from_pretrained("ibm/patchtsmixer-etth1-forecasting").to(torch_device)
        batch = prepare_batch(file="forecast_batch.pt")

        model.eval()
        torch.manual_seed(0)
        with torch.no_grad():
            output = model(
                past_values=batch["past_values"].to(torch_device),
                future_values=batch["future_values"].to(torch_device),
            ).prediction_outputs

        expected_shape = torch.Size([64, model.config.prediction_length, model.config.num_input_channels])
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor(
            [[0.2471, 0.5036, 0.3596, 0.5401, -0.0985, 0.3423, -0.8439]],
            device=torch_device,
        )
        self.assertTrue(torch.allclose(output[0, :1, :7], expected_slice, atol=TOLERANCE))

    def test_prediction_generation(self):
        model = PatchTSMixerForPrediction.from_pretrained("ibm/patchtsmixer-etth1-generate").to(torch_device)
        batch = prepare_batch(file="forecast_batch.pt")

        torch.manual_seed(0)
        model.eval()
        with torch.no_grad():
            outputs = model.generate(past_values=batch["past_values"].to(torch_device))
        # The second dimension is the number of generated samples per series.
        expected_shape = torch.Size((64, 1, model.config.prediction_length, model.config.num_input_channels))

        self.assertEqual(outputs.sequences.shape, expected_shape)

        expected_slice = torch.tensor(
            [[0.4308, -0.4731, 1.3512, -0.1038, -0.4655, 1.1279, -0.7179]],
            device=torch_device,
        )

        mean_prediction = outputs.sequences.mean(dim=1)

        self.assertTrue(torch.allclose(mean_prediction[0, -1:], expected_slice, atol=TOLERANCE))


@require_torch
class PatchTSMixerFunctionalTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """Setup method: called once before the test cases run.

        Builds the shared config kwargs (``cls.params``) and random tensors. The functional
        tests use these tensors as model inputs and as reference shapes for the outputs.
        """
        cls.params = dict(
            context_length=32,
            patch_length=8,
            num_input_channels=3,
            patch_stride=8,
            d_model=4,
            expansion_factor=2,
            num_layers=3,
            dropout=0.2,
            mode="common_channel",  # common_channel,  mix_channel
            gated_attn=True,
            norm_mlp="LayerNorm",
            mask_type="random",
            random_mask_ratio=0.5,
            mask_patches=[2, 3],
            forecast_mask_ratios=[1, 1],
            mask_value=0,
            masked_loss=True,
            channel_consistent_masking=True,
            head_dropout=0.2,
            prediction_length=64,
            out_channels=None,
            num_targets=3,
            output_range=None,
            head_aggregation=None,
            scaling="std",
            use_positional_encoding=False,
            positional_encoding="sincos",
            self_attn=False,
            self_attn_heads=1,
            num_parallel_samples=4,
        )

        # Number of patches produced by a sliding window of `patch_length` with step `patch_stride`.
        cls.num_patches = (
            max(cls.params["context_length"], cls.params["patch_length"]) - cls.params["patch_length"]
        ) // cls.params["patch_stride"] + 1

        batch_size = 2

        # Raw model input: [batch, context_length, channels].
        cls.data = torch.rand(
            batch_size,
            cls.params["context_length"],
            cls.params["num_input_channels"],
        )

        # Patched encoder input: [batch, channels, num_patches, patch_length].
        cls.enc_data = torch.rand(
            batch_size,
            cls.params["num_input_channels"],
            cls.num_patches,
            cls.params["patch_length"],
        )

        # Reference encoder output shape: [batch, channels, num_patches, d_model].
        cls.enc_output = torch.rand(
            batch_size,
            cls.params["num_input_channels"],
            cls.num_patches,
            cls.params["d_model"],
        )

        cls.flat_enc_output = torch.rand(
            batch_size,
            cls.num_patches,
            cls.params["d_model"],
        )

        cls.correct_pred_output = torch.rand(
            batch_size,
            cls.params["prediction_length"],
            cls.params["num_input_channels"],
        )
        cls.correct_regression_output = torch.rand(batch_size, cls.params["num_targets"])

        cls.correct_pretrain_output = torch.rand(
            batch_size,
            cls.params["num_input_channels"],
            cls.num_patches,
            cls.params["patch_length"],
        )

        cls.correct_forecast_output = torch.rand(
            batch_size,
            cls.params["prediction_length"],
            cls.params["num_input_channels"],
        )

        # Forecast restricted to 2 channels (used when `prediction_channel_indices` is set).
        cls.correct_sel_forecast_output = torch.rand(batch_size, cls.params["prediction_length"], 2)

        cls.correct_classification_output = torch.rand(
            batch_size,
            cls.params["num_targets"],
        )

        cls.correct_classification_classes = torch.randint(0, cls.params["num_targets"], (batch_size,))

    def test_patchtsmixer_encoder(self):
        """The encoder's last hidden state must match the reference `enc_output` shape."""
        encoder = PatchTSMixerEncoder(PatchTSMixerConfig(**self.params))
        result = encoder(self.enc_data)
        self.assertEqual(result.last_hidden_state.shape, self.enc_output.shape)

    def test_patchmodel(self):
        """The base model returns encoder-shaped hidden states and the patched input."""
        model = PatchTSMixerModel(PatchTSMixerConfig(**self.params))
        result = model(self.data)
        expected_hidden, expected_patches = self.enc_output.shape, self.enc_data.shape
        self.assertEqual(result.last_hidden_state.shape, expected_hidden)
        self.assertEqual(result.patch_input.shape, expected_patches)

    def test_pretrainhead(self):
        """The pretraining head maps encoder output back to patch space."""
        pretrain_head = PatchTSMixerPretrainHead(config=PatchTSMixerConfig(**self.params))
        reconstruction = pretrain_head(self.enc_output)
        self.assertEqual(reconstruction.shape, self.correct_pretrain_output.shape)

    def test_pretrain_full(self):
        """End-to-end pretraining: check output shapes and that the loss is below +inf."""
        model = PatchTSMixerForPretraining(PatchTSMixerConfig(**self.params))
        result = model(self.data)

        self.assertEqual(result.prediction_outputs.shape, self.correct_pretrain_output.shape)
        self.assertEqual(result.last_hidden_state.shape, self.enc_output.shape)
        loss_is_below_inf = result.loss.item() < np.inf
        self.assertEqual(loss_is_below_inf, True)

    def test_pretrain_full_with_return_dict(self):
        """With ``return_dict=False`` the outputs are positional: (loss, predictions, last_hidden_state, ...)."""
        model = PatchTSMixerForPretraining(PatchTSMixerConfig(**self.params))
        outputs = model(self.data, return_dict=False)

        predictions = outputs[1]
        last_hidden_state = outputs[2]
        loss = outputs[0]
        self.assertEqual(predictions.shape, self.correct_pretrain_output.shape)
        self.assertEqual(last_hidden_state.shape, self.enc_output.shape)
        self.assertEqual(loss.item() < np.inf, True)

    def test_forecast_head(self):
        """The forecasting head maps encoder output to [batch, prediction_length, channels]."""
        forecast_head = PatchTSMixerForPredictionHead(config=PatchTSMixerConfig(**self.params))
        forecast = forecast_head(self.enc_output)
        self.assertEqual(forecast.shape, self.correct_forecast_output.shape)

    def check_module(
        self,
        task,
        params=None,
        output_hidden_states=True,
    ):
        """Build the model for ``task`` and check output shapes, hidden states and loss.

        Args:
            task: One of ``"forecast"``, ``"classification"``, ``"regression"`` or ``"pretrain"``.
            params: Keyword arguments for ``PatchTSMixerConfig``. Defaults to the shared
                class-level ``params``.
            output_hidden_states: Forwarded to the model. When ``True``, the number of hidden
                states must equal ``params["num_layers"]``; otherwise they must be ``None``.

        Raises:
            ValueError: If ``task`` is not one of the supported values.
        """
        if params is None:
            params = self.__class__.params
        config = PatchTSMixerConfig(**params)
        if task == "forecast":
            mdl = PatchTSMixerForPrediction(config)
            target_input = self.__class__.correct_forecast_output
            if config.prediction_channel_indices is not None:
                target_output = self.__class__.correct_sel_forecast_output
            else:
                target_output = target_input
            # Shape expected from `generate`: one extra axis of `num_parallel_samples` samples.
            ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1, -1)
            ground_truth_arg = "future_values"
            output_predictions_arg = "prediction_outputs"
        elif task == "classification":
            mdl = PatchTSMixerForTimeSeriesClassification(config)
            target_input = self.__class__.correct_classification_classes
            target_output = self.__class__.correct_classification_output
            ground_truth_arg = "target_values"
            output_predictions_arg = "prediction_outputs"
        elif task == "regression":
            mdl = PatchTSMixerForRegression(config)
            target_input = self.__class__.correct_regression_output
            target_output = self.__class__.correct_regression_output
            ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1)
            ground_truth_arg = "target_values"
            output_predictions_arg = "regression_outputs"
        elif task == "pretrain":
            mdl = PatchTSMixerForPretraining(config)
            target_input = None
            target_output = self.__class__.correct_pretrain_output
            ground_truth_arg = None
            output_predictions_arg = "prediction_outputs"
        else:
            # Fail loudly; otherwise the names above would be unbound further down.
            raise ValueError(f"invalid task: {task!r}")

        enc_output = self.__class__.enc_output

        if target_input is None:
            output = mdl(self.__class__.data, output_hidden_states=output_hidden_states)
        else:
            output = mdl(
                self.__class__.data,
                **{
                    ground_truth_arg: target_input,
                    "output_hidden_states": output_hidden_states,
                },
            )

        prediction_outputs = getattr(output, output_predictions_arg)
        # Some heads may return a tuple of tensors; each one must match the reference shape.
        if isinstance(prediction_outputs, tuple):
            for t in prediction_outputs:
                self.assertEqual(t.shape, target_output.shape)
        else:
            self.assertEqual(prediction_outputs.shape, target_output.shape)

        self.assertEqual(output.last_hidden_state.shape, enc_output.shape)

        if output_hidden_states is True:
            self.assertEqual(len(output.hidden_states), params["num_layers"])
        else:
            self.assertEqual(output.hidden_states, None)

        self.assertEqual(output.loss.item() < np.inf, True)

        # Sampling via `generate` is only checked for the probabilistic (nll) heads.
        if config.loss == "nll" and task in ["forecast", "regression"]:
            samples = mdl.generate(self.__class__.data)
            self.assertEqual(samples.sequences.shape, ref_samples.shape)

    @parameterized.expand(
        list(
            itertools.product(
                ["common_channel", "mix_channel"],
                [True, False],
                [True, False, "mean", "std"],
                [True, False],
                [None, [0, 2]],
                ["mse", "nll"],
            )
        )
    )
    def test_forecast(self, mode, self_attn, scaling, gated_attn, prediction_channel_indices, loss):
        """Forecasting over the mode/attention/scaling/gating/channel-selection/loss grid."""
        overrides = {
            "mode": mode,
            "self_attn": self_attn,
            "scaling": scaling,
            "prediction_channel_indices": prediction_channel_indices,
            "gated_attn": gated_attn,
            "loss": loss,
        }
        self.check_module(task="forecast", params={**self.params, **overrides})

    @parameterized.expand(
        list(
            itertools.product(
                ["common_channel", "mix_channel"],
                [True, False],
                [True, False, "mean", "std"],
                [True, False],
                ["max_pool", "avg_pool"],
            )
        )
    )
    def test_classification(self, mode, self_attn, scaling, gated_attn, head_aggregation):
        """Classification over the mode/attention/scaling/gating/aggregation grid."""
        overrides = {
            "mode": mode,
            "self_attn": self_attn,
            "scaling": scaling,
            "head_aggregation": head_aggregation,
            "gated_attn": gated_attn,
        }
        self.check_module(task="classification", params={**self.params, **overrides})

    @parameterized.expand(
        list(
            itertools.product(
                ["common_channel", "mix_channel"],
                [True, False],
                [True, False, "mean", "std"],
                [True, False],
                ["max_pool", "avg_pool"],
                ["mse", "nll"],
            )
        )
    )
    def test_regression(self, mode, self_attn, scaling, gated_attn, head_aggregation, loss):
        """Regression over the mode/attention/scaling/gating/aggregation/loss grid."""
        config_kwargs = dict(
            self.params,
            mode=mode,
            self_attn=self_attn,
            scaling=scaling,
            head_aggregation=head_aggregation,
            gated_attn=gated_attn,
            loss=loss,
        )
        self.check_module(task="regression", params=config_kwargs)

    @parameterized.expand(
        list(
            itertools.product(
                ["common_channel", "mix_channel"],
                [True, False],
                [True, False, "mean", "std"],
                [True, False],
                ["random", "forecast"],
                [True, False],
                [True, False],
            )
        )
    )
    def test_pretrain(
        self,
        mode,
        self_attn,
        scaling,
        gated_attn,
        mask_type,
        masked_loss,
        channel_consistent_masking,
    ):
        """Pretraining over the mode/attention/scaling/gating and masking-strategy grid."""
        config_kwargs = dict(
            self.params,
            mode=mode,
            self_attn=self_attn,
            scaling=scaling,
            gated_attn=gated_attn,
            mask_type=mask_type,
            masked_loss=masked_loss,
            channel_consistent_masking=channel_consistent_masking,
        )
        self.check_module(task="pretrain", params=config_kwargs)

    def forecast_full_module(self, params=None, output_hidden_states=False, return_dict=None):
        """Build a PatchTSMixerForPrediction, run one forward pass and check its outputs.

        Args:
            params: keyword arguments for ``PatchTSMixerConfig``. Defaults to the shared
                class-level params when ``None``. Must contain ``num_layers`` and, when
                ``loss == "nll"``, ``num_parallel_samples``.
            output_hidden_states: forwarded to the model. When truthy, one hidden state per
                layer is expected; otherwise ``hidden_states`` must be ``None``.
            return_dict: forwarded to the model. A tuple result is re-wrapped into
                ``PatchTSMixerForPredictionOutput`` before the checks run.
        """
        if params is None:
            params = self.__class__.params
        config = PatchTSMixerConfig(**params)
        model = PatchTSMixerForPrediction(config)

        # When only a subset of channels is predicted, outputs carry the reduced channel count.
        if config.prediction_channel_indices is not None:
            target_val = self.__class__.correct_sel_forecast_output
        else:
            target_val = self.__class__.correct_forecast_output
        enc_output = self.__class__.enc_output

        # The all-channel target is passed even when prediction_channel_indices is set;
        # presumably the model selects those channels from future_values itself — TODO confirm.
        output = model(
            self.__class__.data,
            future_values=self.__class__.correct_forecast_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if isinstance(output, tuple):
            output = PatchTSMixerForPredictionOutput(*output)

        # Only point forecasts ("mse") get a direct shape check; with "nll" the prediction
        # outputs are presumably distribution parameters rather than a single tensor.
        if config.loss == "mse":
            self.assertEqual(output.prediction_outputs.shape, target_val.shape)
        self.assertEqual(output.last_hidden_state.shape, enc_output.shape)

        if output_hidden_states:
            self.assertEqual(len(output.hidden_states), params["num_layers"])
        else:
            self.assertIsNone(output.hidden_states)

        self.assertLess(output.loss.item(), np.inf)

        if config.loss == "nll":
            samples = model.generate(self.__class__.data)
            # Sampled sequences carry an extra axis of size num_parallel_samples after the batch axis.
            ref_samples = target_val.unsqueeze(1).expand(-1, params["num_parallel_samples"], -1, -1)
            self.assertEqual(samples.sequences.shape, ref_samples.shape)

    def test_forecast_full(self):
        """Run the shared forecast check on the base params with hidden states requested."""
        self.check_module(task="forecast", params=self.__class__.params, output_hidden_states=True)

    def test_forecast_full_2(self):
        """Forecast end-to-end in ``mix_channel`` mode with hidden states requested."""
        mixed = dict(self.__class__.params, mode="mix_channel")
        self.forecast_full_module(mixed, output_hidden_states=True)

    def test_forecast_full_2_with_return_dict(self):
        """Forecast in ``mix_channel`` mode with ``return_dict=False`` (tuple-output path).

        Despite the name, this passes ``return_dict=False``; forecast_full_module re-wraps
        a tuple result into ``PatchTSMixerForPredictionOutput`` before checking it.
        """
        mixed = dict(self.__class__.params, mode="mix_channel")
        self.forecast_full_module(mixed, output_hidden_states=True, return_dict=False)

    def test_forecast_full_3(self):
        """Forecast end-to-end in ``mix_channel`` mode with hidden states requested.

        NOTE(review): this configuration is identical to test_forecast_full_2; consider
        varying it so the test adds coverage.
        """
        mixed = dict(self.__class__.params, mode="mix_channel")
        self.forecast_full_module(mixed, output_hidden_states=True)

    def test_forecast_full_5(self):
        """Forecast end-to-end with self-attention and ``"sincos"`` positional encoding enabled."""
        params = {
            **self.__class__.params,
            "self_attn": True,
            "use_positional_encoding": True,
            "positional_encoding": "sincos",
        }
        self.forecast_full_module(params, output_hidden_states=True)

    def test_forecast_full_4(self):
        """Forecast in ``mix_channel`` mode predicting only channels 0 and 2."""
        params = dict(self.__class__.params, mode="mix_channel", prediction_channel_indices=[0, 2])
        self.forecast_full_module(params)

    def test_forecast_full_distributional(self):
        """Forecast channels 0 and 2 with an NLL loss and a ``"normal"`` distribution head."""
        overrides = {
            "mode": "mix_channel",
            "prediction_channel_indices": [0, 2],
            "loss": "nll",
            "distribution_output": "normal",
        }
        self.forecast_full_module({**self.__class__.params, **overrides})

    def test_forecast_full_distributional_2(self):
        """Forecast channels 0 and 2 with an NLL loss and the config's default distribution head."""
        params = self.__class__.params.copy()
        # distribution_output is intentionally not set, so whatever PatchTSMixerConfig
        # defaults to (or the shared params specify) is exercised.
        params.update(
            mode="mix_channel",
            prediction_channel_indices=[0, 2],
            loss="nll",
        )
        self.forecast_full_module(params)

    def test_forecast_full_distributional_3(self):
        """Forecast all channels (no channel selection) with an NLL loss and a ``"normal"`` head."""
        params = self.__class__.params.copy()
        params.update(
            mode="mix_channel",
            loss="nll",
            distribution_output="normal",
        )
        self.forecast_full_module(params)

    def test_forecast_full_distributional_4(self):
        """Forecast all channels (no channel selection) with an NLL loss and a ``"normal"`` head.

        NOTE(review): this configuration is identical to test_forecast_full_distributional_3;
        consider varying it so the test adds coverage.
        """
        params = self.__class__.params.copy()
        params.update(
            mode="mix_channel",
            loss="nll",
            distribution_output="normal",
        )
        self.forecast_full_module(params)

    def test_classification_head(self):
        """The linear head maps the shared encoder output to the classification output shape."""
        config = PatchTSMixerConfig(**self.__class__.params)
        head = PatchTSMixerLinearHead(config=config)
        output = head(self.__class__.enc_output)
        self.assertEqual(output.shape, self.__class__.correct_classification_output.shape)

    def _check_classification_full(self, **forward_kwargs):
        """Run PatchTSMixerForTimeSeriesClassification on the shared data and check its outputs.

        Extra keyword arguments (e.g. ``return_dict=False``) are forwarded to the model call.
        A tuple result is re-wrapped into ``PatchTSMixerForTimeSeriesClassificationOutput``.
        """
        config = PatchTSMixerConfig(**self.__class__.params)
        model = PatchTSMixerForTimeSeriesClassification(config)
        output = model(
            self.__class__.data,
            target_values=self.__class__.correct_classification_classes,
            **forward_kwargs,
        )
        if isinstance(output, tuple):
            output = PatchTSMixerForTimeSeriesClassificationOutput(*output)
        self.assertEqual(
            output.prediction_outputs.shape,
            self.__class__.correct_classification_output.shape,
        )
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertLess(output.loss.item(), np.inf)

    def test_classification_full(self):
        """Classification end-to-end without passing ``return_dict``."""
        self._check_classification_full()

    def test_classification_full_with_return_dict(self):
        """Classification end-to-end with ``return_dict=False`` (tuple-output path)."""
        self._check_classification_full(return_dict=False)

    def test_regression_head(self):
        """The linear head maps the shared encoder output to the regression output shape."""
        head = PatchTSMixerLinearHead(config=PatchTSMixerConfig(**self.__class__.params))
        produced = head(self.__class__.enc_output)
        expected_shape = self.__class__.correct_regression_output.shape
        self.assertEqual(produced.shape, expected_shape)

    def _check_regression_full(self, **forward_kwargs):
        """Run PatchTSMixerForRegression on the shared data and check its outputs.

        Extra keyword arguments (e.g. ``return_dict=False``) are forwarded to the model call.
        A tuple result is re-wrapped into ``PatchTSMixerForRegressionOutput``.
        """
        config = PatchTSMixerConfig(**self.__class__.params)
        model = PatchTSMixerForRegression(config)
        output = model(
            self.__class__.data,
            target_values=self.__class__.correct_regression_output,
            **forward_kwargs,
        )
        if isinstance(output, tuple):
            output = PatchTSMixerForRegressionOutput(*output)
        self.assertEqual(
            output.regression_outputs.shape,
            self.__class__.correct_regression_output.shape,
        )
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertLess(output.loss.item(), np.inf)

    def test_regression_full(self):
        """Regression end-to-end without passing ``return_dict``."""
        self._check_regression_full()

    def test_regression_full_with_return_dict(self):
        """Regression end-to-end with ``return_dict=False`` (tuple-output path)."""
        self._check_regression_full(return_dict=False)

    def _check_regression_distribution(self, distribution_output):
        """Run PatchTSMixerForRegression with an NLL loss and the given distribution head.

        Checks that:
        - the first two entries of ``regression_outputs`` have the target's shape;
        - the encoder output has the expected shape;
        - the loss is below infinity;
        - ``generate`` returns one sample per ``num_parallel_samples``.
        """
        params = {**self.__class__.params, "loss": "nll", "distribution_output": distribution_output}
        config = PatchTSMixerConfig(**params)
        model = PatchTSMixerForRegression(config)
        target = self.__class__.correct_regression_output

        output = model(self.__class__.data, target_values=target)
        # Indexed explicitly so a too-short regression_outputs raises instead of passing silently.
        for idx in (0, 1):
            self.assertEqual(output.regression_outputs[idx].shape, target.shape)
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertLess(output.loss.item(), np.inf)

        # loss is always "nll" here, so sampling is checked unconditionally.
        samples = model.generate(self.__class__.data)
        ref_samples = target.unsqueeze(1).expand(-1, params["num_parallel_samples"], -1)
        self.assertEqual(samples.sequences.shape, ref_samples.shape)

    def test_regression_full_distribute(self):
        """Regression with an NLL loss and a ``"normal"`` distribution head."""
        self._check_regression_distribution("normal")

    def test_regression_full_distribute_2(self):
        """Regression with an NLL loss and a ``"student_t"`` distribution head."""
        self._check_regression_distribution("student_t")
