#  Copyright (c) 2024, Salesforce, Inc.
#  SPDX-License-Identifier: Apache-2
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from functools import partial
from typing import Callable, Optional

import hydra
import lightning as L
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.utils._pytree import tree_map
from torch.utils.data import Dataset, DistributedSampler
from uni2ts.model.moirai import MoiraiForecast
from uni2ts.common import hydra_util  # noqa: hydra resolvers
from uni2ts.data.loader import DataLoader
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split


class DataModule(L.LightningDataModule):
    """Lightning data module that builds loaders from the Hydra config.

    Loader factories are instantiated from ``cfg.train_dataloader`` and
    ``cfg.val_dataloader``; the effective per-process batch sizes are derived
    from the configured batch size, the trainer's world size and its gradient
    accumulation setting.
    """

    def __init__(
        self,
        cfg: DictConfig,
        train_dataset: Dataset,
        val_dataset: Optional[Dataset | list[Dataset]],
    ):
        """Store the config and datasets.

        Args:
            cfg: Config holding the ``train_dataloader`` / ``val_dataloader``
                sections read by this class.
            train_dataset: Dataset served by :meth:`train_dataloader`.
            val_dataset: Optional dataset (or list of datasets) used for
                validation; when ``None`` no validation loader hook is set.
        """
        super().__init__()
        self.cfg = cfg
        self.train_dataset = train_dataset

        if val_dataset is None:
            return

        # Only expose a validation loader when validation data was given.
        # NOTE(review): presumably this lets Lightning treat the hook as
        # "not overridden" otherwise -- confirm against Lightning's docs.
        self.val_dataset = val_dataset
        self.val_dataloader = self._val_dataloader

    @staticmethod
    def get_dataloader(
        dataset: Dataset,
        dataloader_func: Callable[..., DataLoader],
        shuffle: bool,
        world_size: int,
        batch_size: int,
        num_batches_per_epoch: Optional[int] = None,
    ) -> DataLoader:
        """Create a loader for ``dataset`` via ``dataloader_func``.

        Args:
            dataset: Dataset to load from.
            dataloader_func: Factory called with ``dataset``, ``shuffle``,
                ``sampler``, ``batch_size`` and ``num_batches_per_epoch``.
            shuffle: Whether samples should be shuffled.
            world_size: Number of processes; a ``DistributedSampler`` is used
                when this is greater than one.
            batch_size: Batch size handed to the factory.
            num_batches_per_epoch: Optional epoch length handed to the factory.

        Returns:
            Whatever ``dataloader_func`` returns.
        """
        sampler = None
        if world_size > 1:
            sampler = DistributedSampler(
                dataset,
                num_replicas=None,
                rank=None,
                shuffle=shuffle,
                seed=0,
                drop_last=False,
            )

        # With a sampler in place the shuffling is delegated to it, so the
        # loader itself receives ``shuffle=None``.
        loader_shuffle = shuffle if sampler is None else None
        return dataloader_func(
            dataset=dataset,
            shuffle=loader_shuffle,
            sampler=sampler,
            batch_size=batch_size,
            num_batches_per_epoch=num_batches_per_epoch,
        )

    def train_dataloader(self) -> DataLoader:
        """Build the training loader from ``cfg.train_dataloader``."""
        loader_cfg = self.cfg.train_dataloader
        return self.get_dataloader(
            dataset=self.train_dataset,
            dataloader_func=instantiate(loader_cfg, _partial_=True),
            shuffle=loader_cfg.shuffle,
            world_size=self.trainer.world_size,
            batch_size=self.train_batch_size,
            num_batches_per_epoch=self.train_num_batches_per_epoch,
        )

    def _val_dataloader(self) -> DataLoader | list[DataLoader]:
        """Build one validation loader per dataset in ``self.val_dataset``."""
        loader_cfg = self.cfg.val_dataloader
        build_loader = partial(
            self.get_dataloader,
            dataloader_func=instantiate(loader_cfg, _partial_=True),
            shuffle=loader_cfg.shuffle,
            world_size=self.trainer.world_size,
            batch_size=self.val_batch_size,
            num_batches_per_epoch=None,
        )
        # tree_map applies the builder to a single dataset or to every
        # dataset of a nested container alike.
        return tree_map(build_loader, self.val_dataset)

    @property
    def train_batch_size(self) -> int:
        """Configured training batch size split over processes and accumulation steps."""
        divisor = self.trainer.world_size * self.trainer.accumulate_grad_batches
        return self.cfg.train_dataloader.batch_size // divisor

    @property
    def val_batch_size(self) -> int:
        """Configured validation batch size split over processes and accumulation steps."""
        divisor = self.trainer.world_size * self.trainer.accumulate_grad_batches
        return self.cfg.val_dataloader.batch_size // divisor

    @property
    def train_num_batches_per_epoch(self) -> int:
        """Configured batches per epoch multiplied by the accumulation steps."""
        configured = self.cfg.train_dataloader.num_batches_per_epoch
        return configured * self.trainer.accumulate_grad_batches
from torch.utils.data import Dataset
from uni2ts.data.dataset import TimeSeriesDataset
class TrimSeries(Dataset):
    """Dataset wrapper that shortens the time-dependent fields of each sample.

    Every sample of the wrapped dataset is returned with the arrays stored
    under ``"target"`` and ``"observed_mask"`` cut down to their first
    ``train_index`` steps. All other fields are passed through untouched.
    """

    def __init__(self, base_ds: TimeSeriesDataset, train_index: int):
        """Remember the wrapped dataset and the cut-off position.

        Args:
            base_ds: Dataset whose samples are dict-like (they must offer
                ``.copy()`` and key lookup).
            train_index: Number of leading time steps kept per array.
        """
        self.base_ds = base_ds
        self.train_index = train_index

    def __len__(self):
        """Return the number of samples, which equals that of the wrapped dataset."""
        return len(self.base_ds)

    def __getitem__(self, idx: int):
        """Return a shallow copy of sample ``idx`` with its time fields trimmed."""
        # Work on a shallow copy so the wrapped dataset's sample keeps its
        # original (untrimmed) field lists.
        trimmed = self.base_ds[idx].copy()
        keep = slice(self.train_index)

        for field in ("target", "observed_mask"):
            if field not in trimmed:
                continue
            trimmed[field] = [series[keep] for series in trimmed[field]]

        # Other time-dependent fields (e.g. past_feat_dynamic_real) are not
        # trimmed here; extend the tuple above if that is required.
        return trimmed
import numpy as np
import os
from torch.utils.data import Subset
def _find_test_file(test_dir: str, data_number: str) -> str:
    """Return the name of the test CSV in ``test_dir`` selected by ``data_number``.

    A file matches when ``data_number`` occurs within the first three
    characters of its name. Candidates are sorted so that the choice is
    deterministic (``os.listdir`` returns names in arbitrary order).

    Raises:
        FileNotFoundError: If no file in ``test_dir`` matches.
    """
    matches = sorted(
        name for name in os.listdir(test_dir) if data_number in name[:3]
    )
    if not matches:
        raise FileNotFoundError(
            f"No test file whose name starts with {data_number!r} in {test_dir!r}"
        )
    return matches[0]


def _score_windows(forecasts, labels) -> list:
    """Build one ``[true..., pred..., mse]`` row per rolling window.

    Args:
        forecasts: Iterable of forecast objects exposing ``.mean``.
        labels: Iterable of label entries exposing ``["target"]``.

    Returns:
        A list of rows; each row holds the flattened true values, then the
        flattened predicted means, then their mean squared error.
    """
    rows = []
    for fct, lbl in zip(forecasts, labels):
        y_pred = fct.mean.reshape(-1)
        y_true = lbl["target"].reshape(-1)
        # Mean over all channels gives a single anomaly score per window.
        err = np.mean((y_true - y_pred) ** 2)
        # Order must match the column names used by the caller: true, pred, mse.
        rows.append(y_true.tolist() + y_pred.tolist() + [err])
    return rows


def _align_labels(label, n_rows: int):
    """Return ``label`` values laid out positionally for ``n_rows`` result rows.

    The label series carries the CSV's own index while the result frame uses
    a fresh integer index, so the values must be assigned by position rather
    than by index alignment.
    """
    values = label.to_numpy()
    if len(values) == n_rows:
        return values
    if len(values) and n_rows % len(values) == 0:
        # NOTE(review): assumes the rolled windows are emitted series by
        # series (all windows of one column, then the next), so the label
        # sequence simply repeats once per series -- confirm against GluonTS.
        return np.tile(values, n_rows // len(values))
    print(
        f"WARNING: {n_rows} result rows cannot be aligned with "
        f"{len(values)} labels; 'label' column left empty"
    )
    return np.full(n_rows, np.nan)


@hydra.main(version_base="1.3", config_name="default.yaml")
def main(cfg: DictConfig):
    """Fine-tune the configured model, then score a test CSV with 1-step forecasts.

    Steps:
        1. Locate the test CSV whose name starts with the first three
           characters of ``cfg.data.dataset``.
        2. Instantiate model, trainer and datasets from ``cfg`` and run
           ``trainer.fit``.
        3. Wrap the fine-tuned module in ``MoiraiForecast`` and predict one
           step per rolling window over the test CSV.
        4. Write true values, predictions, per-window MSE and the CSV's
           ``Label`` column to the results directory.

    Args:
        cfg: Hydra config; this function reads ``data``, ``tf32``,
            ``trainer``, ``model``, ``compile``, ``val_data`` (optional),
            ``seed`` and ``ckpt_path``.

    Raises:
        FileNotFoundError: If no matching test CSV exists.
        RuntimeError: If the predictor yields no forecast windows.
    """
    test_dir = '/home/root/dataset/Moirai/test'
    results_dir = '/home/root/Foundation_model/results/Moirai_FT'
    prediction_length = 1

    data_number = cfg.data.dataset[:3]
    # NOTE: the third-from-last '_'-separated field of the file name used to
    # be parsed here as a train index (intended for TrimSeries), but nothing
    # consumed it, so that parsing was dropped.
    real_data_dir = _find_test_file(test_dir, data_number)
    url = test_dir + '/' + real_data_dir

    print('==================================================================')
    print(url)
    print('==================================================================')
    if cfg.tf32:
        assert cfg.trainer.precision == 32
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    model: L.LightningModule = instantiate(cfg.model, _convert_="all")

    if cfg.compile:
        model.module.compile(mode=cfg.compile)
    trainer: L.Trainer = instantiate(cfg.trainer)
    train_dataset: Dataset = instantiate(cfg.data).load_dataset(
        model.train_transform_map
    )
    val_dataset: Optional[Dataset | list[Dataset]] = (
        tree_map(
            lambda ds: ds.load_dataset(model.val_transform_map),
            instantiate(cfg.val_data, _convert_="all"),
        )
        if "val_data" in cfg
        else None
    )
    L.seed_everything(cfg.seed + trainer.logger.version, workers=True)

    trainer.fit(
        model,
        datamodule=DataModule(cfg, train_dataset, val_dataset),
        ckpt_path=cfg.ckpt_path,
    )

    # Reuse the fine-tuned weights for forecasting.
    forecast_model = MoiraiForecast(
        module=model.module,
        prediction_length=prediction_length,
        context_length=256,
        patch_size=64,
        num_samples=100,
        target_dim=1,
        feat_dynamic_real_dim=0,
        past_feat_dynamic_real_dim=0,
    )
    forecast_model.eval()
    predictor = forecast_model.create_predictor(batch_size=256)

    df = pd.read_csv(url, index_col=0, parse_dates=True)
    # Separate the anomaly labels from the value columns.
    label = df.pop('Label')

    # Convert into a GluonTS dataset (one entry per remaining column).
    ds = PandasDataset(dict(df))

    # offset=0 puts the split point at the very start of each series, so the
    # rolling windows below are generated over the whole file.
    _, test_template = split(ds, offset=0)

    # One window per row of the CSV, advancing one step at a time.
    test_data_rolled = test_template.generate_instances(
        prediction_length=prediction_length,
        windows=len(df),
        distance=1,
    )

    # ---------------------------------------
    # Prediction & anomaly scoring (MSE)
    # ---------------------------------------
    # Forecasts are consumed inside the no_grad context, since the loop in
    # _score_windows is what pulls them from the predictor.
    with torch.no_grad():
        forecasts = predictor.predict(test_data_rolled.input)
        anomaly_scores = _score_windows(forecasts, test_data_rolled.label)

    if not anomaly_scores:
        raise RuntimeError(f"No forecast windows were produced for {url}")

    # Each row is [true x n_ch, pred x n_ch, mse].
    n_ch = (len(anomaly_scores[0]) - 1) // 2
    if n_ch == 1:
        res_df_col = ['true_y', 'pred_y', 'mse_score']
    else:
        res_df_col = (
            [f"true_ch{ch}" for ch in range(n_ch)]
            + [f"pred_ch{ch}" for ch in range(n_ch)]
            + ['mse_score']
        )
    res_df = pd.DataFrame(anomaly_scores, columns=res_df_col)

    # Assign labels by position: `label` is indexed by the CSV's index column
    # while res_df has a plain integer index, so index-aligned assignment
    # would not match rows up.
    res_df['label'] = _align_labels(label, len(res_df))

    # Make sure the output directory exists so a long run is not lost at the
    # very last step.
    os.makedirs(results_dir, exist_ok=True)
    res_df.to_csv(results_dir + '/' + data_number + '_U_MOIRAI_FT.csv')
# Script entry point: ``main`` is called without arguments here; its ``cfg``
# parameter is expected to be supplied by the ``@hydra.main`` decorator.
if __name__ == "__main__":
    main()
