import random
import matplotlib.pyplot as plt
from momentfm import MOMENTPipeline
from momentfm.utils.masking import Masking
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch import nn
import math


from TSB_AD.models.base import BaseDetector

from utils.torch_utility import EarlyStoppingTorch, get_gpu

# set seed
#seed = 20240903
#random.seed(seed)
#np.random.seed(seed)

#torch.manual_seed(seed)
#torch.cuda.manual_seed_all(seed)

# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# Small positive constant. NOTE(review): not referenced anywhere in the visible
# part of this file (the dataset class takes its own `eps` argument) — confirm
# whether later code still needs it.
epsilon = 1e-8

import numpy as np
import torch
import numpy as np
import torch

class ReconstructDataset_Moment(torch.utils.data.Dataset):
    """Sliding-window dataset used for MOMENT reconstruction.

    Input routing (a 1-D input is first reshaped to ``(t, 1)``):
      * ``(t, d)`` with ``d >= 1``
            one series with ``d`` channels; standardised per channel over time
            and cut into windows. Note that *every* 2-D input with at least one
            column is windowed this way.
      * ``(n, 1, t)`` or ``(n, t, 1)``
            ``n`` univariate series. A size-1 axis 1 is taken as the channel
            axis first, so ``(n, 1, t)`` is transposed to ``(n, t, 1)``.
      * ``(n, a, b)`` with ``a > 1`` and ``b > 1``
            ``n`` multivariate series. The time axis is the one that is
            ``>= window_size`` while the other is ``< window_size``; when that
            rule is not decisive the larger axis is taken as time.
      * anything else (e.g. ``ndim >= 4`` or a zero-width second axis)
            flattened to ``(N, -1)`` rows, with no windowing and no
            normalisation.

    Attributes:
      samples:    float32 tensor ``(num_windows, window_size, C)`` for windowed
                  inputs, or ``(N, D)`` for the flattened fallback.
      targets:    a clone of ``samples``.
      sample_num: number of items, i.e. ``samples.shape[0]``.

    ``__getitem__`` returns ``(sample, input_mask)`` where ``input_mask`` is a
    float32 NumPy array of ones with length ``window_size`` for a 2-D sample,
    otherwise ``sample.shape[0]``.

    Raises:
      ValueError: if ``window_size`` or ``stride`` is not a positive integer.
    """

    def __init__(self, data, window_size, stride=1, normalize=True, eps=1e-8):
        super().__init__()
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        data = np.asarray(data)

        self.window_size = int(window_size)
        self.stride = int(stride)
        self.normalize = bool(normalize)
        self.eps = float(eps)

        # Fail early with a clear message instead of a ZeroDivisionError
        # (stride == 0) or malformed windows (negative values).
        if self.window_size <= 0:
            raise ValueError(f"window_size must be positive, got {window_size!r}")
        if self.stride <= 0:
            raise ValueError(f"stride must be positive, got {stride!r}")

        if data.ndim == 1:
            data = data.reshape(-1, 1)  # (t,) -> (t, 1)

        if data.ndim == 2 and data.shape[1] >= 1:
            # Single series, univariate or multivariate: standardise over time.
            windows = self._window_series(self._standardize(data, axis=0))
        elif data.ndim == 3 and (
            1 in data.shape[1:] or min(data.shape[1:]) > 1
        ):
            # Batch of series: bring to (N, T, D), standardise each series
            # over its own time axis, then window every series.
            data = self._to_batch_time_channel(data)
            windows = self._window_batch(self._standardize(data, axis=1))
        else:
            # Fallback: every row is one sample (no windowing, no normalisation).
            windows = data.reshape(data.shape[0], -1).astype(np.float32)

        self.samples = torch.from_numpy(windows).float()
        self.targets = self.samples.clone()
        self.sample_num = self.samples.shape[0]

    # --- helpers ---
    def _standardize(self, x, axis):
        """Z-score ``x`` along ``axis``; a zero std is replaced by ``eps``."""
        # Empty arrays are returned untouched: reducing over an empty axis
        # only yields NaNs plus NumPy "Mean of empty slice" warnings.
        if not self.normalize or x.size == 0:
            return x
        mean = np.mean(x, axis=axis, keepdims=True)
        std = np.std(x, axis=axis, keepdims=True)
        std = np.where(std == 0, self.eps, std)
        return (x - mean) / std

    def _to_batch_time_channel(self, x):
        """Return a 3-D batch arranged as (N, T, D), transposing if needed."""
        _, dim_a, dim_b = x.shape
        if dim_a == 1:                      # (N, 1, T) -> (N, T, 1)
            return np.transpose(x, (0, 2, 1))
        if dim_b == 1:                      # already (N, T, 1)
            return x
        win = self.window_size
        if dim_a >= win and dim_b < win:    # (N, T=a, D=b)
            return x
        if dim_b >= win and dim_a < win:    # (N, D=a, T=b)
            return np.transpose(x, (0, 2, 1))
        # Not decisive: treat the larger axis as time.
        return np.transpose(x, (0, 2, 1)) if dim_a < dim_b else x

    def _window_series(self, x_td):
        """(T, D) -> float32 array (num_windows, window_size, D)."""
        from numpy.lib.stride_tricks import sliding_window_view

        t_len, n_ch = x_td.shape
        n_windows = max(0, (t_len - self.window_size) // self.stride + 1)
        if n_windows == 0:
            return np.zeros((0, self.window_size, n_ch), dtype=np.float32)
        # The view has shape (T - W + 1, D, W); keep every `stride`-th start.
        view = sliding_window_view(x_td, self.window_size, axis=0)[::self.stride]
        # np.array always copies, so the result is writable and C-contiguous.
        return np.array(view.transpose(0, 2, 1), dtype=np.float32, order="C")

    def _window_batch(self, x_ntd):
        """(N, T, D) -> windows of all series concatenated along axis 0."""
        n_ch = x_ntd.shape[2]
        chunks = [self._window_series(series) for series in x_ntd]
        # Drop empty chunks so the concatenation is always well defined.
        chunks = [c for c in chunks if c.size > 0]
        if not chunks:
            return np.zeros((0, self.window_size, n_ch), dtype=np.float32)
        return np.concatenate(chunks, axis=0).astype(np.float32)

    def __len__(self):
        return self.sample_num

    def __getitem__(self, index):
        sample = self.samples[index]
        # The mask covers the time axis of a window, else the feature length.
        mask_len = self.window_size if sample.ndim == 2 else sample.shape[0]
        return sample, np.ones(mask_len, dtype=np.float32)


# class ReconstructDataset_Moment(torch.utils.data.Dataset):
#     """
#     Accepts:
#       - data: np.ndarray or torch.Tensor of shape (n,1,t) or (t,1) or (N, D)
#       - window_size: int
#       - stride: int
#       - normalize: bool
#     Produces:
#       - self.samples, self.targets: torch.Tensor
#       - __len__ and __getitem__ as desired
#     """
#     def __init__(self, data, window_size, stride=1, normalize=True):
#         super().__init__()
#         # Convert to numpy
#         if isinstance(data, torch.Tensor):
#             data = data.cpu().numpy()
#         data = np.array(data)
        
#         self.normalize = normalize
#         # Normalize across axis 0 (series axis)
#         if normalize:
#             data_mean = np.mean(data, axis=0, keepdims=True)
#             data_std  = np.std(data, axis=0, keepdims=True)
#             data_std  = np.where(data_std == 0, 1e-8, data_std)
#             data = (data - data_mean) / data_std
        
#         self.window_size = window_size
#         self.stride = stride

#         # Handle multi-series: shape (n,1,t)
#         if data.ndim == 3 and data.shape[1] == 1:
#             n, _, length = data.shape
#             all_windows = []
#             for i in range(n):
#                 series = data[i, 0, :]  # (t,)
#                 num_windows = max(0, (length - window_size) // stride + 1)
#                 # extract sliding windows
#                 windows = np.stack([
#                     series[j*stride : j*stride + window_size]
#                     for j in range(num_windows)
#                 ], axis=0)  # (num_windows, window_size)
#                 all_windows.append(windows)
#             if all_windows:
#                 X = np.vstack(all_windows)  # (total_samples, window_size)
#             else:
#                 X = np.zeros((0, window_size))
#             X_torch = torch.from_numpy(X).float()
#             # samples and targets both (N_samples, window_size, 1)
#             self.samples = X_torch.unsqueeze(-1)
#             self.targets = self.samples.clone()
#             self.sample_num = X.shape[0]

#         # Handle single-series univariate: shape (t,1)
#         elif data.ndim == 2 and data.shape[1] == 1:
#             series = data.squeeze()  # (t,)
#             length = series.shape[0]
#             self.sample_num = max(0, (length - window_size) // stride + 1)
#             X = torch.zeros((self.sample_num, window_size), dtype=torch.float)
#             for i in range(self.sample_num):
#                 window = series[i*stride : i*stride + window_size]
#                 X[i, :] = torch.from_numpy(window)
#             self.samples = X.unsqueeze(-1)  # (N, window_size, 1)
#             self.targets = self.samples.clone()

#         # Fallback for other shapes: each row is one sample
#         else:
#             # interpret first axis as samples
#             X = torch.from_numpy(data.reshape(data.shape[0], -1)).float()
#             self.samples = X
#             self.targets = X
#             self.sample_num = X.shape[0]

#     def __len__(self):
#         return self.sample_num

#     def __getitem__(self, index):
#         """
#         Returns:
#           - sample: Tensor of shape (window_size, 1) or (D,)
#           - input_mask: np.ndarray of ones length window_size
#         """
#         input_mask = np.ones(self.window_size, dtype=np.float32)
#         return self.samples[index], input_mask

# Example usage:
# ds = ReconstructDataset_Moment(np.random.randn(5,1,100), window_size=20, stride=5)
# print(len(ds), ds[0][0].shape, ds[0][1].shape)


from momentfm.common import TASKS
from momentfm.data.base import TimeseriesOutputs
#from momentfm.models.layers.embed import PatchEmbedding, Patching
from momentfm.models.layers.embed import  Patching
from momentfm.models.layers.revin import RevIN
from momentfm.utils.masking import Masking
from momentfm.utils.utils import (
    NamespaceWithDefaults,
    get_anomaly_criterion,
    get_huggingface_model_dimensions,
)
import logging
import warnings
from argparse import Namespace
from copy import deepcopy
from math import ceil
from momentfm import MOMENTPipeline

import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from transformers import T5Config, T5EncoderModel, T5Model

import math
# Hugging Face model identifiers (FLAN-T5 family).
# NOTE(review): not referenced in the visible part of this file and defined a
# second time further down — confirm whether it is still needed here.
SUPPORTED_HUGGINGFACE_MODELS = [
    "google/flan-t5-small",
    "google/flan-t5-base",
    "google/flan-t5-large",
    "google/flan-t5-xl",
    "google/flan-t5-xxl",
]


def embed_anomlay(
    self,
    *,
    x_enc: torch.Tensor,
    input_mask: torch.Tensor = None,
    reduction: str = "mean",
    **kwargs,
) -> TimeseriesOutputs:
    """Embedding forward pass, assigned to ``MOMENTPipeline.embed`` in this file.

    Args:
        x_enc: input unpacked as [batch_size x n_channels x seq_len].
        input_mask: [batch_size x seq_len]; defaults to all ones.
        reduction: ``"mean"`` averages the encoder output over channels and
            then over the patches kept by the patch-view mask;
            ``'None'`` (the string) keeps the per-channel, per-patch encoder
            output and appends the normalizer's mean and stdev as two extra
            features on the last axis.

    Returns:
        TimeseriesOutputs carrying ``embeddings``, ``input_mask`` and the
        ``reduction`` string as ``metadata``.

    Raises:
        NotImplementedError: for any other ``reduction`` (raised only after
            the encoder has run).
    """
    n_batch, n_vars, n_steps = x_enc.shape
    d_model = self.config.d_model

    if input_mask is None:
        input_mask = torch.ones((n_batch, n_steps)).to(x_enc.device)

    normed = self.normalizer(x=x_enc, mask=input_mask, mode="norm")
    normed = torch.nan_to_num(normed, nan=0, posinf=0, neginf=0)

    # Patch-level view of the mask; used for attention and for mean pooling.
    patch_mask = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)

    tokens = self.tokenizer(x=normed)
    embedded = self.patch_embedding(tokens, mask=input_mask)

    n_patches = embedded.shape[2]
    embedded = embedded.reshape((n_batch * n_vars, n_patches, d_model))

    attention_mask = patch_mask.repeat_interleave(n_vars, dim=0)
    encoded = self.encoder(inputs_embeds=embedded, attention_mask=attention_mask)
    hidden = encoded.last_hidden_state.reshape((-1, n_vars, n_patches, d_model))
    # [batch_size x n_channels x n_patches x d_model]

    if reduction == 'None':
        # Broadcast the normalizer statistics over the patch axis and append
        # them as two trailing feature columns.
        stats = [
            stat.expand(-1, -1, hidden.size(2)).unsqueeze(-1)
            for stat in (self.normalizer.mean, self.normalizer.stdev)
        ]
        return TimeseriesOutputs(
            embeddings=torch.cat([hidden, *stats], dim=-1),
            input_mask=input_mask,
            metadata=reduction,
        )

    if reduction != "mean":
        raise NotImplementedError(f"Reduction method {reduction} not implemented.")

    pooled = hidden.mean(dim=1, keepdim=False)  # mean across channels
    # [batch_size x n_patches x d_model]
    weights = patch_mask.unsqueeze(-1).repeat(1, 1, d_model)
    pooled = (weights * pooled).sum(dim=1) / weights.sum(dim=1)

    return TimeseriesOutputs(
        embeddings=pooled, input_mask=input_mask, metadata=reduction
    )
import importlib
class MOMENT_custom(BaseDetector):
    """Anomaly detector that fine-tunes the pretrained MOMENT model.

    Two pipelines are loaded from ``AutonLab/MOMENT-1-base``:
      * ``self.model``       - reconstruction mode, with its head replaced by
                               ``PretrainHead_anomaly``; trained by ``fit``.
      * ``self.model_embed`` - embedding mode, kept in eval mode.

    ``MOMENTPipeline.embed`` / ``reconstruction`` / ``reconstruct`` are
    monkey-patched at class level with the functions defined in this file.

    Args:
        win_size: sliding-window length fed to the model.
        input_c: number of input channels (stored only).
        batch_size: DataLoader batch size.
        epochs: number of fine-tuning epochs.
        validation_size: fraction of the data held out from training in
            ``fit`` (the held-out part is not used for validation).
        lr: Adam learning rate.
    """

    def __init__(self,
                 win_size=256,
                 input_c=1,
                 batch_size=128,
                 epochs=2,
                 validation_size=0,
                 lr=1e-4):

        self.model_name = 'MOMENT'
        self.win_size = win_size
        self.input_c = input_c
        self.batch_size = batch_size
        # Element-wise squared error (`reduce=False` is the deprecated spelling).
        self.anomaly_criterion = nn.MSELoss(reduction='none')
        self.epochs = epochs
        self.validation_size = validation_size
        self.lr = lr

        cuda = True
        self.cuda = cuda
        # Single source of truth for the device: models and batches all go
        # here, so they can never end up on different devices.
        self.device = get_gpu(self.cuda)

        # NOTE(review): the reload / re-import sequence below is kept in its
        # original order; it looks redundant, but the patching relies on which
        # class object `MOMENTPipeline` is bound to — confirm before pruning.
        import momentfm
        importlib.reload(momentfm)
        from momentfm import MOMENTPipeline

        self.model = MOMENTPipeline.from_pretrained(
            "AutonLab/MOMENT-1-base",
            # For anomaly detection, MOMENT is loaded in `reconstruction` mode
            model_kwargs={"task_name": "reconstruction"},
        )
        # Head input is d_model + 2 — presumably for the mean/stdev features
        # the patched forward appends; confirm against reconstruction_anomaly.
        self.model.head = PretrainHead_anomaly(
            self.model.config.d_model + 2,
            self.model.config.patch_len,
            self.model.config.getattr("head_dropout", 0.1),
            self.model.config.getattr("orth_gain", 1.41),
        )
        self.model.init()
        self.model = self.model.to(self.device).float()

        importlib.reload(momentfm)
        from momentfm import MOMENTPipeline
        MOMENTPipeline.embed = embed_anomlay
        importlib.reload(momentfm)
        from momentfm import MOMENTPipeline
        MOMENTPipeline.reconstruction = reconstruction_anomaly
        MOMENTPipeline.reconstruct = reconstruct_anomaly
        importlib.reload(momentfm)
        self.model_embed = MOMENTPipeline.from_pretrained(
            "AutonLab/MOMENT-1-base",
            # Loaded in `embedding` mode to obtain representations
            model_kwargs={'task_name': 'embedding'},
            # local_files_only=True,  # only look at local files (no download)
        )

        self.model_embed.init()
        self.model_embed.to(self.device).float()
        self.model_embed.eval()

        # Mean squared error, optimised with Adam
        self.criterion = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.75)
        self.save_path = None
        self.early_stopping = EarlyStoppingTorch(save_path=self.save_path, patience=3)

    # ------------------------------------------------------------------ helpers
    def _make_loader(self, data, shuffle=False):
        """Wrap ``data`` in a windowed dataset and a DataLoader."""
        return DataLoader(
            dataset=ReconstructDataset_Moment(data, window_size=self.win_size),
            batch_size=self.batch_size,
            shuffle=shuffle)

    def _to_model_input(self, batch_x, batch_masks):
        """Move a batch to ``self.device``; windows become [B, C, S]."""
        batch_x = batch_x.to(self.device).float().permute(0, 2, 1)
        return batch_x, batch_masks.to(self.device)

    def _window_score(self, batch_x, reconstruction):
        """Squared error averaged over the last axis, taken at the last index
        of axis 1 (the last channel of a [B, C, S] batch) -> array of shape (B,)."""
        per_channel = torch.mean(self.anomaly_criterion(batch_x, reconstruction), dim=-1)
        return per_channel.detach().cpu().numpy()[:, -1]

    @staticmethod
    def _stack_scores(score_list):
        """Concatenate per-batch arrays into one flat array."""
        if not score_list:
            raise ValueError(
                "no windows were produced; the input is shorter than win_size")
        return np.concatenate(score_list, axis=0).reshape(-1)

    def _pad_scores(self, scores, target_len):
        """Edge-pad ``scores`` by ``win_size - 1`` values (ceil half in front,
        floor half behind) when it is shorter than ``target_len``."""
        if scores.shape[0] >= target_len:
            return scores
        front = math.ceil((self.win_size - 1) / 2)
        back = (self.win_size - 1) // 2
        return np.pad(scores, (front, back), mode="edge")

    def _padded_reconstruction(self, batch_x, batch_masks, out_shape, mask=None):
        """Run the model and return the reconstruction, right-padded by
        replication to the window length and reshaped to ``out_shape``."""
        extra = {} if mask is None else {"mask": mask}
        recon = self.model(x_enc=batch_x, input_mask=batch_masks, **extra).reconstruction
        recon = torch.nn.functional.pad(
            recon, (0, out_shape[2] - recon.size(2)), mode='replicate')
        return recon.reshape(out_shape[0], out_shape[1], self.win_size)

    def _cv_masked_score(self, batch_x, batch_masks, patch_len, n_split):
        """Average window score over ``n_split`` runs, each hiding a disjoint
        random group of patches."""
        n_batch, seq_len = batch_x.shape[0], batch_x.shape[2]
        patch_order = np.random.permutation((seq_len // patch_len) + 1)
        patches_per_split = int(len(patch_order) / n_split)

        split_scores = []
        for split in range(n_split):
            keep = torch.ones((n_batch, seq_len), device=self.device)
            chosen = patch_order[split * patches_per_split:(split + 1) * patches_per_split]
            for p in chosen:
                # A patch index addresses `patch_len` consecutive time steps.
                # (A trailing index past the window yields an empty slice.)
                keep[:, p * patch_len:(p + 1) * patch_len] = 0
            output = self.model(x_enc=batch_x, input_mask=batch_masks, mask=keep)
            split_scores.append(self._window_score(batch_x, output.reconstruction))
        return np.mean(split_scores, axis=0)

    def _ratio_masked_score(self, batch_x, batch_masks, mask_ratio):
        """Window score with ``mask_ratio`` of the time steps hidden at random,
        drawn independently for every window."""
        n_batch, seq_len = batch_x.shape[0], batch_x.shape[2]
        keep = torch.ones((n_batch, seq_len), device=self.device)
        num_mask = int(seq_len * mask_ratio)
        for b in range(n_batch):
            hidden_idx = np.random.choice(seq_len, num_mask, replace=False)
            keep[b, hidden_idx] = 0
        output = self.model(x_enc=batch_x, input_mask=batch_masks, mask=keep)
        return self._window_score(batch_x, output.reconstruction)

    # --------------------------------------------------------------- public API
    def zero_shot(self, data):
        """Score ``data`` with the current model; the result is stored in
        ``self.decision_scores_`` (nothing is returned)."""
        test_loader = self._make_loader(data)

        self.score_list = []
        self.model.eval()  # disable dropout while scoring
        with torch.no_grad():
            for batch_x, batch_masks in tqdm(test_loader, total=len(test_loader)):
                batch_x, batch_masks = self._to_model_input(batch_x, batch_masks)
                # batch_x: [batch_size, n_channels, window_size]
                # batch_masks: [batch_size, window_size]
                output = self.model(x_enc=batch_x, input_mask=batch_masks)
                self.score_list.append(self._window_score(batch_x, output.reconstruction))

        self.__anomaly_score = self._pad_scores(
            self._stack_scores(self.score_list), len(data))
        self.decision_scores_ = self.__anomaly_score

    def embed_return(self, data, reduction='mean'):
        """Embed every window of ``data`` with ``self.model_embed``.

        Per-batch results are stored as lists on the instance:
        ``embeddings`` and ``input_mask`` (NumPy arrays), ``means`` and
        ``stds`` (tensors reshaped to one column).
        """
        test_loader = self._make_loader(data)

        embedding_list, mean_list, std_list, input_mask_list = [], [], [], []
        self.model_embed.eval()
        with torch.no_grad():
            for train_batch_x, _ in tqdm(test_loader, total=len(test_loader)):
                x_cpu = train_batch_x.float().permute(0, 2, 1)
                embed_output = self.model_embed(
                    x_enc=x_cpu.to(self.device), reduction=reduction)

                # Statistics are recomputed on the un-moved copy, exactly as
                # before. NOTE(review): presumably because _get_statistics
                # builds its default mask on the CPU — confirm.
                self.model_embed.normalizer._get_statistics(x_cpu)
                mean = self.model_embed.normalizer.mean.reshape(-1, 1)
                stddev = self.model_embed.normalizer.stdev.reshape(-1, 1)

                embedding_list.append(embed_output.embeddings.detach().cpu().numpy())
                mean_list.append(mean)
                std_list.append(stddev)
                input_mask_list.append(embed_output.input_mask.detach().cpu().numpy())

        self.embeddings = embedding_list
        self.means = mean_list
        self.stds = std_list
        self.input_mask = input_mask_list

    def fit(self, data):
        """Fine-tune ``self.model`` by masked reconstruction.

        Only the first ``1 - validation_size`` fraction of ``data`` is used;
        the remainder is held out but not evaluated. 30% of the patches are
        masked at random in every step.
        """
        tsTrain = data[:int((1 - self.validation_size) * len(data))]
        train_loader = self._make_loader(tsTrain, shuffle=True)

        mask_generator = Masking(mask_ratio=0.3)  # mask 30% of patches randomly

        import momentfm
        importlib.reload(momentfm)
        from momentfm import MOMENTPipeline
        MOMENTPipeline.reconstruction = reconstruction_anomaly
        MOMENTPipeline.reconstruct = reconstruct_anomaly

        for epoch in range(1, self.epochs + 1):
            self.model.train()
            for batch_x, batch_masks in tqdm(train_loader, total=len(train_loader)):
                original, batch_masks = self._to_model_input(batch_x, batch_masks)
                n_channels = original.shape[1]

                # Every channel becomes its own univariate sample:
                # [batch_size * n_channels, 1, window_size]
                batch_x = original.reshape((-1, 1, self.win_size))
                batch_masks = batch_masks.long().repeat_interleave(n_channels, dim=0)

                # Randomly mask some patches of data
                mask = mask_generator.generate_mask(
                    x=batch_x, input_mask=batch_masks).to(self.device).long()
                mask = torch.nn.functional.pad(
                    mask, (0, batch_masks.size(1) - mask.size(1)), mode='constant', value=1)

                # Forward
                output = self._padded_reconstruction(
                    batch_x, batch_masks, original.shape, mask=mask)

                # Backward
                loss = self.criterion(output, original)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def reconstruction_loss(self, data, masking=True):
        """Return ``(reconstruction_loss, prediction_loss)`` as flat arrays.

        ``prediction_loss`` comes from a forward pass with 30% of the patches
        masked explicitly, ``reconstruction_loss`` from a pass without an
        explicit mask. ``masking`` is accepted but not used. The raw per-batch
        arrays are kept in ``self.prediction_loss_list`` and
        ``self.reconstruction_loss_list``.
        """
        test_loader = self._make_loader(data)

        import momentfm
        importlib.reload(momentfm)

        mask_generator = Masking(mask_ratio=0.3)  # mask 30% of patches randomly
        self.score_list = []
        prediction_loss_list, reconstruction_loss_list = [], []

        self.model.eval()  # disable dropout while scoring
        with torch.no_grad():
            for batch_x, batch_masks in tqdm(test_loader, total=len(test_loader)):
                batch_x, batch_masks = self._to_model_input(batch_x, batch_masks)
                original = batch_x

                mask = mask_generator.generate_mask(
                    x=batch_x, input_mask=batch_masks).to(self.device).long()
                mask = torch.nn.functional.pad(
                    mask, (0, batch_masks.size(1) - mask.size(1)), mode='constant', value=1)

                # NOTE(review): both errors are root-sum-squares over axis 0,
                # the batch axis, so each batch yields a [C, win_size] array
                # rather than one value per window — confirm this is intended.
                output = self._padded_reconstruction(
                    batch_x, batch_masks, original.shape, mask=mask)
                prediction_loss_list.append(
                    torch.sqrt(torch.sum((output - original) ** 2, 0)).detach().cpu().numpy())

                output = self._padded_reconstruction(batch_x, batch_masks, original.shape)
                reconstruction_loss_list.append(
                    torch.sqrt(torch.sum((output - original) ** 2, 0)).detach().cpu().numpy())

            self.reconstruction_loss_list = reconstruction_loss_list
            self.prediction_loss_list = prediction_loss_list

        prediction_loss = self._stack_scores(prediction_loss_list)
        reconstruction_loss = self._stack_scores(reconstruction_loss_list)

        # Both arrays are padded together, decided by the prediction loss length.
        if prediction_loss.shape[0] < len(data):
            prediction_loss = self._pad_scores(prediction_loss, len(data))
            reconstruction_loss = self._pad_scores(reconstruction_loss, len(data))

        return reconstruction_loss, prediction_loss

    def decision_function(self, data, masking=False, masking_mode="cv", boundary_copy=True,
                          patch_len=8, n_split=3, mask_ratio=0.2):
        """Return one anomaly score per window of ``data``.

        Args:
            data: series to score.
            masking: when True, score reconstructions of masked inputs.
            masking_mode: ``"cv"`` averages over ``n_split`` runs that each
                hide a disjoint random group of ``patch_len``-long patches;
                ``"ratio"`` hides ``mask_ratio`` of the time steps at random.
            boundary_copy: edge-pad the scores by ``win_size - 1`` values when
                there are fewer scores than ``len(data)``.
            patch_len: patch length used by the ``"cv"`` mode.
            n_split: number of masking rounds in the ``"cv"`` mode.
            mask_ratio: fraction of time steps hidden in the ``"ratio"`` mode.

        Raises:
            ValueError: for an unknown ``masking_mode`` when ``masking`` is True.
        """
        if masking and masking_mode not in ("cv", "ratio"):
            raise ValueError(f"Unknown masking_mode: {masking_mode}")

        test_loader = self._make_loader(data)
        self.score_list = []

        self.model.eval()  # disable dropout while scoring
        with torch.no_grad():
            for batch_x, batch_masks in tqdm(test_loader, total=len(test_loader)):
                batch_x, batch_masks = self._to_model_input(batch_x, batch_masks)  # [B, C, S]

                if not masking:
                    # No masking: score the plain reconstruction.
                    output = self.model(x_enc=batch_x, input_mask=batch_masks)
                    final_score = self._window_score(batch_x, output.reconstruction)
                elif masking_mode == "cv":
                    final_score = self._cv_masked_score(
                        batch_x, batch_masks, patch_len, n_split)
                else:
                    final_score = self._ratio_masked_score(
                        batch_x, batch_masks, mask_ratio)

                self.score_list.append(final_score)

        self.__anomaly_score = self._stack_scores(self.score_list)
        if boundary_copy:
            self.__anomaly_score = self._pad_scores(self.__anomaly_score, len(data))

        return self.__anomaly_score


def run_MOMENT_FT(data_train, data_test, win_size=256):
    """Fine-tune a ``MOMENT_custom`` detector and score the test data.

    The detector is built with ``input_c`` taken from ``data_test.shape[1]``,
    fitted on ``data_train`` and evaluated on ``data_test`` without masking.

    Returns:
        The flattened (``ravel``) anomaly scores from ``decision_function``.
    """
    detector = MOMENT_custom(win_size=win_size, input_c=data_test.shape[1])
    detector.fit(data_train)  # fine-tune
    scores = detector.decision_function(data_test, masking=False)
    return scores.ravel()



from momentfm.common import TASKS
from momentfm.data.base import TimeseriesOutputs
#from momentfm.models.layers.embed import PatchEmbedding, Patching
from momentfm.models.layers.embed import  Patching
from momentfm.models.layers.revin import RevIN
from momentfm.utils.masking import Masking
from momentfm.utils.utils import (
    NamespaceWithDefaults,
    get_anomaly_criterion,
    get_huggingface_model_dimensions,
)
import logging
import warnings
from argparse import Namespace
from copy import deepcopy
from math import ceil
from momentfm import MOMENTPipeline

import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from transformers import T5Config, T5EncoderModel, T5Model

import math
# Hugging Face model identifiers (FLAN-T5 family).
# NOTE(review): this rebinds the identical list already defined earlier in
# this file; the duplicate has no effect and is a candidate for removal.
SUPPORTED_HUGGINGFACE_MODELS = [
    "google/flan-t5-small",
    "google/flan-t5-base",
    "google/flan-t5-large",
    "google/flan-t5-xl",
    "google/flan-t5-xxl",
]
def reconstruction_anomaly(
    self,
    *,
    x_enc: torch.Tensor,
    input_mask: torch.Tensor = None,
    mask: torch.Tensor = None,
    **kwargs,
) -> TimeseriesOutputs:
    """Masked reconstruction whose head also sees the normalisation statistics.

    Free function taking ``self``; it reads ``mask_generator``,
    ``normalizer``, ``tokenizer``, ``patch_embedding``, ``encoder``, ``head``,
    ``patch_len`` and ``config`` from it (presumably a MOMENT model this is
    bound to -- confirm at the binding site).

    The encoder output of every patch is concatenated with the normalizer's
    ``mean`` and ``stdev`` before being passed to ``self.head``, so the head
    must accept ``d_model + 2`` input features.

    Args:
        x_enc: Input series of shape [batch_size x n_channels x seq_len].
        input_mask: [batch_size x seq_len] mask of observed time steps. When
            omitted, every time step is treated as observed.
        mask: [batch_size x seq_len] pre-training mask. When omitted, one is
            drawn from ``self.mask_generator``.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        ``TimeseriesOutputs`` carrying ``input_mask``, the de-normalised
        ``reconstruction``, the ``pretrain_mask`` that was used, and
        ``illegal_output`` (only populated when ``config`` has ``debug`` set).
    """
    batch_size, n_channels, seq_len = x_enc.shape

    if input_mask is None:
        # The signature allows None, but every use below needs a tensor:
        # fall back to "all time steps observed".
        input_mask = torch.ones((batch_size, seq_len), device=x_enc.device)

    if mask is None:
        mask = self.mask_generator.generate_mask(x=x_enc, input_mask=input_mask)
        mask = mask.to(x_enc.device)  # mask: [batch_size x seq_len]

    x_enc = self.normalizer(x=x_enc, mask=mask * input_mask, mode="norm")
    # Prevent too short time-series from causing NaNs
    x_enc = torch.nan_to_num(x_enc, nan=0, posinf=0, neginf=0)

    x_enc = self.tokenizer(x=x_enc)
    enc_in = self.patch_embedding(x_enc, mask=mask)

    n_patches = enc_in.shape[2]
    # Fold channels into the batch axis:
    # [batch_size * n_channels x n_patches x d_model]
    enc_in = enc_in.reshape(
        (batch_size * n_channels, n_patches, self.config.d_model)
    )

    patch_view_mask = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
    attention_mask = patch_view_mask.repeat_interleave(n_channels, dim=0)
    if self.config.transformer_type == "encoder_decoder":
        outputs = self.encoder(
            inputs_embeds=enc_in,
            decoder_inputs_embeds=enc_in,
            attention_mask=attention_mask,
        )
    else:
        outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)

    # [batch_size x n_channels x n_patches x d_model]
    enc_out = outputs.last_hidden_state.reshape(
        (-1, n_channels, n_patches, self.config.d_model)
    )

    # Broadcast the normalizer statistics over the patch axis and give them a
    # trailing feature axis so they can be appended to the d_model features.
    # assumes mean/stdev are [batch_size x n_channels x 1] -- TODO confirm
    mean_feature = self.normalizer.mean.expand(-1, -1, n_patches).unsqueeze(-1)
    std_feature = self.normalizer.stdev.expand(-1, -1, n_patches).unsqueeze(-1)

    enc_concat = torch.cat([enc_out, mean_feature, std_feature], dim=-1)
    dec_out = self.head(enc_concat)  # [batch_size x n_channels x seq_len]
    dec_out = self.normalizer(x=dec_out, mode="denorm")

    if self.config.getattr("debug", False):
        illegal_output = self._check_model_weights_for_illegal_values()
    else:
        illegal_output = None

    return TimeseriesOutputs(
        input_mask=input_mask,
        reconstruction=dec_out,
        pretrain_mask=mask,
        illegal_output=illegal_output,
    )

def reconstruct_anomaly(
    self,
    *,
    x_enc: torch.Tensor,
    input_mask: torch.Tensor = None,
    mask: torch.Tensor = None,
    **kwargs,
) -> TimeseriesOutputs:
    """Reconstruct ``x_enc`` with a head that also sees normalisation statistics.

    Unlike ``reconstruction_anomaly`` no random mask is generated: when
    ``mask`` is omitted an all-ones mask is used. Free function taking
    ``self``; it reads ``normalizer``, ``tokenizer``, ``patch_embedding``,
    ``encoder``, ``head``, ``patch_len`` and ``config`` from it (presumably a
    MOMENT model this is bound to -- confirm at the binding site).

    Args:
        x_enc: Input series of shape [batch_size x n_channels x seq_len].
        input_mask: [batch_size x seq_len] mask of observed time steps. When
            omitted, every time step is treated as observed.
        mask: Mask with the same shape as ``input_mask``; defaults to all
            ones.
        **kwargs: Only ``prompt_embeds`` is read. It may be a tensor or an
            ``nn.Embedding``; its tokens are prepended to the encoder input
            and the matching encoder outputs are dropped again.

    Returns:
        ``TimeseriesOutputs`` with ``input_mask`` and the de-normalised
        ``reconstruction``.
    """
    batch_size, n_channels, seq_len = x_enc.shape

    if input_mask is None:
        # The signature allows None, but every use below needs a tensor:
        # fall back to "all time steps observed".
        input_mask = torch.ones((batch_size, seq_len), device=x_enc.device)
    if mask is None:
        mask = torch.ones_like(input_mask)

    x_enc = self.normalizer(x=x_enc, mask=mask * input_mask, mode="norm")

    x_enc = self.tokenizer(x=x_enc)
    enc_in = self.patch_embedding(x_enc, mask=mask)

    n_patches = enc_in.shape[2]
    # [batch_size * n_channels x n_patches x d_model]
    enc_in = enc_in.reshape(
        (batch_size * n_channels, n_patches, self.config.d_model)
    )

    patch_view_mask = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
    attention_mask = patch_view_mask.repeat_interleave(n_channels, dim=0).to(
        x_enc.device
    )

    n_tokens = 0
    if "prompt_embeds" in kwargs:
        prompt_embeds = kwargs["prompt_embeds"].to(x_enc.device)

        if isinstance(prompt_embeds, nn.Embedding):
            prompt_embeds = prompt_embeds.weight.data.unsqueeze(0)

        n_tokens = prompt_embeds.shape[1]

        enc_in = self._cat_learned_embedding_to_input(prompt_embeds, enc_in)
        attention_mask = self._extend_attention_mask(attention_mask, n_tokens)

    if self.config.transformer_type == "encoder_decoder":
        outputs = self.encoder(
            inputs_embeds=enc_in,
            decoder_inputs_embeds=enc_in,
            attention_mask=attention_mask,
        )
    else:
        outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)

    # Drop the outputs that belong to the prepended prompt tokens (if any).
    enc_out = outputs.last_hidden_state[:, n_tokens:, :]

    # [batch_size x n_channels x n_patches x d_model]
    enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))

    # Broadcast the normalizer statistics over the patch axis and give them a
    # trailing feature axis so they can be appended to the d_model features.
    # assumes mean/stdev are [batch_size x n_channels x 1] -- TODO confirm
    mean_feature = self.normalizer.mean.expand(-1, -1, n_patches).unsqueeze(-1)
    std_feature = self.normalizer.stdev.expand(-1, -1, n_patches).unsqueeze(-1)

    enc_concat = torch.cat([enc_out, mean_feature, std_feature], dim=-1)
    dec_out = self.head(enc_concat)  # [batch_size x n_channels x seq_len]
    dec_out = self.normalizer(x=dec_out, mode="denorm")

    return TimeseriesOutputs(input_mask=input_mask, reconstruction=dec_out)


def reconstruction_normal(
    self,
    *,
    x_enc: torch.Tensor,
    input_mask: torch.Tensor = None,
    mask: torch.Tensor = None,
    **kwargs,
) -> TimeseriesOutputs:
    """Masked reconstruction that feeds the plain encoder output to the head.

    Counterpart of ``reconstruction_anomaly`` without the extra mean/stdev
    features: ``self.head`` receives ``d_model`` features per patch. Free
    function taking ``self``; it reads ``mask_generator``, ``normalizer``,
    ``tokenizer``, ``patch_embedding``, ``encoder``, ``head``, ``patch_len``
    and ``config`` from it (presumably a MOMENT model this is bound to --
    confirm at the binding site).

    Args:
        x_enc: Input series of shape [batch_size x n_channels x seq_len].
        input_mask: [batch_size x seq_len] mask of observed time steps. When
            omitted, every time step is treated as observed.
        mask: [batch_size x seq_len] pre-training mask. When omitted, one is
            drawn from ``self.mask_generator``.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        ``TimeseriesOutputs`` carrying ``input_mask``, the de-normalised
        ``reconstruction``, the ``pretrain_mask`` that was used, and
        ``illegal_output`` (only populated when ``config`` has ``debug`` set).
    """
    batch_size, n_channels, seq_len = x_enc.shape

    if input_mask is None:
        # The signature allows None, but every use below needs a tensor:
        # fall back to "all time steps observed".
        input_mask = torch.ones((batch_size, seq_len), device=x_enc.device)

    if mask is None:
        mask = self.mask_generator.generate_mask(x=x_enc, input_mask=input_mask)
        mask = mask.to(x_enc.device)  # mask: [batch_size x seq_len]

    x_enc = self.normalizer(x=x_enc, mask=mask * input_mask, mode="norm")
    # Prevent too short time-series from causing NaNs
    x_enc = torch.nan_to_num(x_enc, nan=0, posinf=0, neginf=0)

    x_enc = self.tokenizer(x=x_enc)
    enc_in = self.patch_embedding(x_enc, mask=mask)

    n_patches = enc_in.shape[2]
    # Fold channels into the batch axis:
    # [batch_size * n_channels x n_patches x d_model]
    enc_in = enc_in.reshape(
        (batch_size * n_channels, n_patches, self.config.d_model)
    )

    patch_view_mask = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
    attention_mask = patch_view_mask.repeat_interleave(n_channels, dim=0)
    if self.config.transformer_type == "encoder_decoder":
        outputs = self.encoder(
            inputs_embeds=enc_in,
            decoder_inputs_embeds=enc_in,
            attention_mask=attention_mask,
        )
    else:
        outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)

    # [batch_size x n_channels x n_patches x d_model]
    enc_out = outputs.last_hidden_state.reshape(
        (-1, n_channels, n_patches, self.config.d_model)
    )

    dec_out = self.head(enc_out)  # [batch_size x n_channels x seq_len]
    dec_out = self.normalizer(x=dec_out, mode="denorm")

    if self.config.getattr("debug", False):
        illegal_output = self._check_model_weights_for_illegal_values()
    else:
        illegal_output = None

    return TimeseriesOutputs(
        input_mask=input_mask,
        reconstruction=dec_out,
        pretrain_mask=mask,
        illegal_output=illegal_output,
    )

def reconstruct_normal(
    self,
    *,
    x_enc: torch.Tensor,
    input_mask: torch.Tensor = None,
    mask: torch.Tensor = None,
    **kwargs,
) -> TimeseriesOutputs:
    """Reconstruct ``x_enc`` by feeding the plain encoder output to the head.

    Counterpart of ``reconstruct_anomaly`` without the extra mean/stdev
    features. No random mask is generated: when ``mask`` is omitted an
    all-ones mask is used. Free function taking ``self``; it reads
    ``normalizer``, ``tokenizer``, ``patch_embedding``, ``encoder``, ``head``,
    ``patch_len`` and ``config`` from it (presumably a MOMENT model this is
    bound to -- confirm at the binding site).

    Args:
        x_enc: Input series of shape [batch_size x n_channels x seq_len].
        input_mask: [batch_size x seq_len] mask of observed time steps. When
            omitted, every time step is treated as observed.
        mask: Mask with the same shape as ``input_mask``; defaults to all
            ones.
        **kwargs: Only ``prompt_embeds`` is read. It may be a tensor or an
            ``nn.Embedding``; its tokens are prepended to the encoder input
            and the matching encoder outputs are dropped again.

    Returns:
        ``TimeseriesOutputs`` with ``input_mask`` and the de-normalised
        ``reconstruction``.
    """
    batch_size, n_channels, seq_len = x_enc.shape

    if input_mask is None:
        # The signature allows None, but every use below needs a tensor:
        # fall back to "all time steps observed".
        input_mask = torch.ones((batch_size, seq_len), device=x_enc.device)
    if mask is None:
        mask = torch.ones_like(input_mask)

    x_enc = self.normalizer(x=x_enc, mask=mask * input_mask, mode="norm")

    x_enc = self.tokenizer(x=x_enc)
    enc_in = self.patch_embedding(x_enc, mask=mask)

    n_patches = enc_in.shape[2]
    # [batch_size * n_channels x n_patches x d_model]
    enc_in = enc_in.reshape(
        (batch_size * n_channels, n_patches, self.config.d_model)
    )

    patch_view_mask = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
    attention_mask = patch_view_mask.repeat_interleave(n_channels, dim=0).to(
        x_enc.device
    )

    n_tokens = 0
    if "prompt_embeds" in kwargs:
        prompt_embeds = kwargs["prompt_embeds"].to(x_enc.device)

        if isinstance(prompt_embeds, nn.Embedding):
            prompt_embeds = prompt_embeds.weight.data.unsqueeze(0)

        n_tokens = prompt_embeds.shape[1]

        enc_in = self._cat_learned_embedding_to_input(prompt_embeds, enc_in)
        attention_mask = self._extend_attention_mask(attention_mask, n_tokens)

    if self.config.transformer_type == "encoder_decoder":
        outputs = self.encoder(
            inputs_embeds=enc_in,
            decoder_inputs_embeds=enc_in,
            attention_mask=attention_mask,
        )
    else:
        outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)

    # Drop the outputs that belong to the prepended prompt tokens (if any).
    enc_out = outputs.last_hidden_state[:, n_tokens:, :]

    # [batch_size x n_channels x n_patches x d_model]
    enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))

    dec_out = self.head(enc_out)  # [batch_size x n_channels x seq_len]
    dec_out = self.normalizer(x=dec_out, mode="denorm")

    return TimeseriesOutputs(input_mask=input_mask, reconstruction=dec_out)



from torch import nn
class PretrainHead_anomaly(nn.Module):
    """Linear reconstruction head mapping per-patch features to patch values.

    Applies dropout followed by one linear layer from ``d_model`` features to
    ``patch_len`` values, then merges dimensions 2 and 3 of the result.

    Args:
        d_model: Number of input features per patch. This is the full width
            fed to the linear layer, including any features beyond 768.
        patch_len: Number of values produced for every patch.
        head_dropout: Dropout probability applied to the input.
        orth_gain: Gain for orthogonal weight initialisation (bias is zeroed).
            Pass ``None`` to keep the default ``nn.Linear`` initialisation.
    """

    def __init__(
        self,
        d_model: int = 768,
        patch_len: int = 8,
        head_dropout: float = 0.1,
        orth_gain: float = 1.41,
    ):
        super().__init__()
        self.dropout = nn.Dropout(head_dropout)

        # The projection consumes every incoming feature, so its input width
        # is simply d_model, whether or not that exceeds 768.
        self.linear1 = nn.Linear(d_model, patch_len)

        # Kept as an attribute although forward() does not apply it.
        self.relu = nn.SiLU()

        if orth_gain is not None:
            nn.init.orthogonal_(self.linear1.weight, gain=orth_gain)
            self.linear1.bias.data.zero_()

    def forward(self, x):
        """Project ``x`` and merge dims 2 and 3 of the projection.

        ``x`` needs at least four dimensions with ``d_model`` last; callers in
        this file pass [batch_size x n_channels x n_patches x d_model].
        """
        projected = self.linear1(self.dropout(x))
        return projected.flatten(start_dim=2, end_dim=3)
    
from torch import nn
class PretrainHead_dense_anomaly(nn.Module):
    """Multi-layer reconstruction head mapping per-patch features to patch values.

    A stack of ``number_of_layers`` blocks (Linear -> LayerNorm -> SiLU ->
    Dropout) followed by a final linear layer to ``patch_len`` values; dims 2
    and 3 of the result are then merged.

    Args:
        d_model: Number of input features per patch. The first linear layer
            consumes all of them; hidden layers are capped at 768 features.
        patch_len: Number of values produced for every patch.
        head_dropout: Dropout probability used after every hidden block.
        orth_gain: Gain for orthogonal weight initialisation of every linear
            layer (biases are zeroed). Pass ``None`` to keep the default
            ``nn.Linear`` initialisation.
        number_of_layers: Number of hidden blocks before the final layer.
    """

    def __init__(
        self,
        d_model: int = 768,
        patch_len: int = 8,
        head_dropout: float = 0.1,
        orth_gain: float = 1.41,
        number_of_layers: int = 4,
    ):
        super().__init__()
        self.dropout = nn.Dropout(head_dropout)

        # Hidden width never exceeds 768; anything above that only widens the
        # input of the first layer.
        hidden_width = min(d_model, 768)

        # Build the hidden stack; only the first layer sees the full d_model.
        self.linears = nn.ModuleList()
        fan_in = d_model
        for _ in range(number_of_layers):
            self.linears.append(nn.Linear(fan_in, hidden_width))
            fan_in = hidden_width

        self.layer_norms = nn.ModuleList(
            nn.LayerNorm(hidden_width) for _ in range(number_of_layers)
        )
        self.final_linear = nn.Linear(hidden_width, patch_len)  # output layer

        self.relu = nn.SiLU()

        # Orthogonal initialisation: hidden layers first, output layer last.
        if orth_gain is not None:
            for layer in (*self.linears, self.final_linear):
                nn.init.orthogonal_(layer.weight, gain=orth_gain)
                layer.bias.data.zero_()

    def forward(self, x):
        """Run the hidden stack and final projection, then merge dims 2 and 3.

        ``x`` needs at least four dimensions with ``d_model`` last; callers in
        this file pass [batch_size x n_channels x n_patches x d_model].
        """
        hidden = x
        for dense, norm in zip(self.linears, self.layer_norms):
            hidden = self.dropout(self.relu(norm(dense(hidden))))

        return self.final_linear(hidden).flatten(start_dim=2, end_dim=3)