"""
This module is adapted from [moment] by [mononitogoswami]
Original source: [https://github.com/moment-timeseries-foundation-model/moment]
"""

import os
import sys
import tempfile
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, random_split
from tqdm import tqdm
from torch import nn
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from ..utils.torch_utility import get_trainable_param_count

from .base import BaseDetector
from ..utils.dataset import ReconstructDataset_TSPulse, ForecastDataset_TSPulse
from ..utils.torch_utility import get_gpu
from models.tspulse.modeling_tspulse import TSPulseForReconstruction
from models.tspulse.utils import (
    patchwise_stitched_reconstruction_vectorized_multikey,
    PatchMaskingDatasetWrapper,
)
import matplotlib.pyplot as plt
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments


def min_max_normalize(tensor):
    """Scale every slice along ``dim=1`` to the ``[0, 1]`` range.

    NaN entries are replaced by 0 before the statistics are computed.
    Slices whose minimum equals their maximum (constant slices) are mapped
    to all zeros instead of producing NaN from a 0/0 division. The guard is
    applied per slice, so one constant row no longer poisons a batch that
    also contains non-constant rows.

    Args:
        tensor: Tensor with at least two dimensions; min/max are reduced
            over dimension 1.

    Returns:
        Tensor of the same shape with each slice min-max normalised.
    """
    # Replace NaN values with 0
    tensor = torch.nan_to_num(tensor, nan=0.0)

    min_val, _ = torch.min(tensor, dim=1, keepdim=True)
    max_val, _ = torch.max(tensor, dim=1, keepdim=True)

    value_range = max_val - min_val
    # For constant slices the numerator is already 0, so dividing by 1
    # yields the desired all-zero output without a division by zero.
    safe_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)

    return (tensor - min_val) / safe_range


def zscore_normalize(tensor):
    """Standardise every slice along ``dim=1`` to zero mean and unit std.

    NaN entries are replaced by 0 before the statistics are computed.
    Slices with a zero (or undefined) standard deviation are mapped to all
    zeros. The guard is applied per slice, so a constant row inside an
    otherwise varying batch yields zeros rather than NaN.

    Args:
        tensor: Tensor with at least two dimensions; mean/std are reduced
            over dimension 1.

    Returns:
        Tensor of the same shape with each slice z-score normalised.
    """
    # Replace NaN values with 0
    tensor = torch.nan_to_num(tensor, nan=0.0)

    mean = torch.mean(tensor, dim=1, keepdim=True)
    std = torch.std(tensor, dim=1, keepdim=True)

    # std is 0 for constant slices and NaN when a slice has a single element.
    degenerate = (std == 0) | torch.isnan(std)
    safe_std = torch.where(degenerate, torch.ones_like(std), std)

    normalized_tensor = (tensor - mean) / safe_std
    # Force exact zeros on degenerate slices (mask broadcasts over dim 1).
    return torch.where(degenerate, torch.zeros_like(normalized_tensor), normalized_tensor)


def torch_nanmean(x, dim=None, keepdim=False):
    """Mean of ``x`` that ignores NaN entries.

    NaNs contribute neither to the sum nor to the element count. When a
    reduction contains only NaNs the count is clamped to 1, so the result
    for that position is 0 rather than NaN.

    Args:
        x: Input tensor.
        dim: Dimension(s) to reduce, forwarded to ``Tensor.sum``.
        keepdim: Whether the reduced dimension(s) are retained.

    Returns:
        Tensor holding the NaN-ignoring mean.
    """
    nan_positions = torch.isnan(x)
    zero = torch.tensor(0.0, device=x.device, dtype=x.dtype)

    # Sum with NaNs zeroed out, then divide by how many real values there were.
    finite_total = torch.where(nan_positions, zero, x).sum(dim=dim, keepdim=keepdim)
    finite_count = (~nan_positions).sum(dim=dim, keepdim=keepdim).clamp(min=1)

    return finite_total / finite_count


class BatchPlotter:
    """Collect model inputs/outputs batch by batch and plot a few windows.

    Batches are accumulated with :meth:`add_data`; :meth:`plot_batches`
    then draws input vs. model output for a random sample of windows
    (preferring windows that contain at least one anomaly label) and saves
    the figure as a PDF.
    """

    def __init__(self, num_plots) -> None:
        # NOTE: `num_plots` is overwritten inside `plot_batches` (5 for
        # multi-channel data, 10 for single-channel data).
        self.num_plots = num_plots
        self.all_batchx = []
        self.all_batchy = []
        self.all_outputs = []
        self.all_scores = []
        self.channel_dim = -1

    def add_data(self, batch_x, batch_y, output) -> None:
        """Store one batch of inputs, labels and outputs as numpy arrays.

        All three tensors are detached and moved to CPU first, so tensors
        that live on an accelerator or track gradients are accepted.
        """
        self.all_batchx.append(batch_x.detach().cpu().numpy())
        self.all_batchy.append(batch_y.detach().cpu().numpy())
        self.all_outputs.append(output.detach().cpu().numpy())

    def _select_window_indices(self, labels, num_windows, max_plots):
        """Pick up to `max_plots` distinct window indices to draw.

        Windows with at least one positive label are preferred; when there
        are none, the sample is drawn from all `num_windows` windows. The
        sample size never exceeds the number of candidates.
        """
        # Collapse everything but the window axis so each window yields
        # exactly one anomaly count.
        per_window_anomalies = labels.reshape(labels.shape[0], -1).sum(axis=1)
        anomaly_indices = np.where(per_window_anomalies > 0)[0]

        candidates = anomaly_indices if len(anomaly_indices) > 0 else np.arange(num_windows)
        num_samples = min(max_plots, len(candidates))  # avoid sampling more than available
        return np.random.choice(candidates, size=num_samples, replace=False)

    def plot_batches(self, save_dir, filename, channel_last=True, aggr_window_loc=416):
        """Plot sampled windows and save them to ``{save_dir}/{filename}.pdf``.

        Args:
            save_dir: Output directory; created if it does not exist.
            filename: File name without extension.
            channel_last: If False, inputs and outputs are transposed from
                (batch, channel, time) to (batch, time, channel) first.
            aggr_window_loc: x position of the vertical marker line.

        The accumulated batches are left untouched, so more data can be
        added and plotted again afterwards.
        """
        batch_x = np.concatenate(self.all_batchx, axis=0)
        labels = np.concatenate(self.all_batchy, axis=0)
        outputs = np.concatenate(self.all_outputs, axis=0)

        if not channel_last:
            batch_x = np.transpose(batch_x, (0, 2, 1))
            outputs = np.transpose(outputs, (0, 2, 1))

        num_channels = batch_x.shape[self.channel_dim]
        # Fewer subplots for multi-channel data, more for single-channel data.
        self.num_plots = 5 if num_channels > 1 else 10

        rand_indx = self._select_window_indices(labels, batch_x.shape[0], self.num_plots)
        num_axes = len(rand_indx)
        if num_axes == 0:
            print("No windows available to plot.")
            return

        # One subplot per sampled window; squeeze=False keeps `axs` a 2-D
        # array even when only a single window is drawn.
        fig, axs = plt.subplots(num_axes, 1, figsize=(20, 6 * num_axes), squeeze=False)
        # Flatten the axs array for easier iteration
        self.axs = axs.flatten()

        colors = ["blue", "green", "red", "purple"]  # Define a list of colors to use

        for ax, window_idx in zip(self.axs, rand_indx):
            # At most one line pair per available color.
            for j in range(min(num_channels, len(colors))):
                color = colors[j % len(colors)]

                # Plot reconstruction data with dotted line
                ax.plot(
                    outputs[window_idx, :, j],
                    label=f"recon_{j}",
                    linewidth=4,
                    linestyle="--",
                    color=color,
                )

                # Plot original data
                ax.plot(
                    batch_x[window_idx, :, j],
                    label=f"data_{j}",
                    linewidth=4,
                    linestyle="-",
                    color=color,
                    alpha=0.5,
                )

            # Vertical marker line (drawn once per subplot)
            ax.axvline(x=aggr_window_loc, color="red", linestyle="--", linewidth=4)

            ax.legend(loc="upper center")
            ax.set_title(f"Index {window_idx}")

        # Remove the gaps between subplots
        plt.tight_layout()
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(f"{save_dir}/{filename}.pdf")
        plt.close(fig)


class TSPulse(BaseDetector):
    """Anomaly detector built on a pretrained TSPulse reconstruction model.

    Scores are the mean squared error between the input and either the
    patch-wise stitched reconstruction, the time series reconstructed from
    the FFT branch (`use_ts_from_fft`), or the first forecast point
    (`use_forecast`). `fit` optionally fine-tunes the model;
    `decision_function` / `zero_shot` compute the scores.
    """

    def __init__(
        self,
        win_size: int = None,
        input_c: int = 1,
        use_ts_from_fft: bool = False,
        use_forecast: bool = False,
        aggr_win_size: int = None,
        windowed_detector: bool = False,
        mask_type: str = "user",
        tspulse_decoder_mode: str = None,
        batch_size: int = 50_000,
        finetune_num_epochs: int = 20,
        validation_size: float = 0.2,
        lr: float = 1e-4,
        plot: bool = False,
        save_models: bool = False,
        num_plots: int = 5,
        save_dir: str = None,
        seed: int = 42,
        filename: str = None,
        model_path: str = None,
        window_position: str = "last",
        freeze_backbone: bool = False,
        finetune: bool = False,
    ):
        """Load the pretrained model and derive the runtime configuration.

        Args:
            win_size: Window length fed to the model; defaults to the
                model's context length when None.
            input_c: Number of input channels.
            use_ts_from_fft: Score with the time series reconstructed from
                the FFT branch. Mutually exclusive with `use_forecast`.
            use_forecast: Score with the first point of the forecast output.
            aggr_win_size: Length of the window region that is scored.
                Required when `windowed_detector` is True; otherwise it
                falls back to the model's context length.
            windowed_detector: If True (and `finetune` is False) the model
                is loaded with the given `mask_type`.
            mask_type: Mask type forwarded to the model in windowed mode.
                Ignored when `finetune` is True ("user" is used instead).
            tspulse_decoder_mode: Decoder mode; when None it is chosen from
                `input_c` ("mix_channel" if > 1, else "common_channel").
            batch_size: Budget that is divided by channels and patches per
                window to obtain the actual batch size.
            finetune_num_epochs: Maximum number of fine-tuning epochs.
            validation_size: Fraction of the data held out in `fit`.
            lr: Fine-tuning learning rate.
            plot: Whether `zero_shot` saves sample plots.
            save_models: Whether `fit` saves the fine-tuned model.
            num_plots: Passed to `BatchPlotter`.
            save_dir: Directory for plots and saved models.
            seed: Seed for Python's and NumPy's random generators.
            filename: Base name for plot and model files.
            model_path: Path or identifier of the pretrained model.
            window_position: "first" or "last"; which end of the window is
                reconstructed and scored.
            freeze_backbone: Freeze `model.backbone` during fine-tuning.
            finetune: Load the model for fine-tuning (mask type "user").

        Raises:
            ValueError: If both `use_ts_from_fft` and `use_forecast` are set.
        """
        self.model_name = "TSPulse"
        self.win_size = win_size
        self.input_c = input_c
        self.batch_size = batch_size
        # Element-wise squared error; `reduction="none"` replaces the
        # deprecated `reduce=False` with identical behavior.
        self.anomaly_criterion = nn.MSELoss(reduction="none")
        self.finetune_num_epochs = finetune_num_epochs
        self.validation_size = validation_size
        self.lr = lr
        self.plot = plot
        self.num_plots = num_plots
        self.save_models = save_models
        self.save_dir = save_dir
        self.filename = filename
        self.use_ts_from_fft = use_ts_from_fft
        self.aggr_win_size = aggr_win_size
        self.windowed_detector = windowed_detector
        self.model_path = model_path
        self.window_position = window_position
        self.use_forecast = use_forecast
        self.forecast_window_size = None

        if tspulse_decoder_mode is None:
            if self.input_c > 1:
                self.tspulse_decoder_mode = "mix_channel"
            else:
                self.tspulse_decoder_mode = "common_channel"
            print(
                f"Forcing TSPulse's decoder_mode to {self.tspulse_decoder_mode} based on number of channels since `tspulse_decoder_mode=None`."
            )
        else:
            self.tspulse_decoder_mode = tspulse_decoder_mode
        print("TSPulse decoder mode =", self.tspulse_decoder_mode)

        if [self.use_ts_from_fft, self.use_forecast].count(True) > 1:
            raise ValueError("Only one can be True among these: use_ts_from_fft, use_forecast")

        # NOTE(review): only Python's and NumPy's generators are seeded
        # here; torch's RNG is left untouched.
        random.seed(seed)
        np.random.seed(seed)

        self.cuda = True
        self.device = get_gpu(self.cuda)

        if finetune:
            # Load model with "user" mask_type
            self.model = TSPulseForReconstruction.from_pretrained(
                self.model_path,
                num_input_channels=self.input_c,
                decoder_mode=self.tspulse_decoder_mode,
                scaling="revin",
                mask_type="user",
            ).to(self.device)
        else:
            if self.windowed_detector:
                assert self.aggr_win_size is not None, (
                    "`aggr_win_size` should not be `None` when `windowed_detector` is used."
                )

                self.model = TSPulseForReconstruction.from_pretrained(
                    model_path,
                    num_input_channels=self.input_c,
                    mask_type=mask_type,
                    decoder_mode=self.tspulse_decoder_mode,
                    scaling="revin",
                )
            else:
                # don't pass mask type
                self.model = TSPulseForReconstruction.from_pretrained(
                    model_path,
                    num_input_channels=self.input_c,
                    decoder_mode=self.tspulse_decoder_mode,
                    scaling="revin",
                )
                # Aggregation win size = context length of the model
                self.aggr_win_size = self.model.config.context_length

        print("Loaded TSPulse model from ", model_path)

        # In fine-tune mode `aggr_win_size` may still be None here; fall
        # back to the full context so the batch-size computation below
        # (and later slicing) does not fail on None.
        if self.aggr_win_size is None:
            self.aggr_win_size = self.model.config.context_length

        if self.win_size is None:
            print("Getting win_size from the model's context length")
            self.win_size = self.model.config.context_length

        self.model = self.model.to(self.device).float()

        # Get models' metadata
        if self.use_forecast:
            self.forecast_window_size = self.model.config.prediction_length

        # For fine-tuning
        self.freeze_backbone = freeze_backbone

        # Reduce batch size to avoid OOMs (A100-80GB Gpu)
        # Both the patch count and the resulting batch size are clamped to
        # at least 1: a window shorter than one patch would otherwise cause
        # a division by zero, and a batch size of 0 is invalid for DataLoader.
        patches_per_window = max(1, self.aggr_win_size // self.model.config.patch_length)
        self.batch_size = max(1, int(self.batch_size // (self.input_c * patches_per_window)))
        print(
            f"Recalculated Batch Size = {self.batch_size} since, num_channels = {self.input_c}, aggr_win_size = {self.aggr_win_size}, patch_len={self.model.config.patch_length}"
        )

    @staticmethod
    def _count_trainable_params(model):
        """Return the number of parameters of `model` that require gradients."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def _inference(self, test_loader, reconstruct_start, reconstruct_end):
        """Run the model over `test_loader` and append per-window scores.

        For every batch the mean squared error between input and model
        output is computed over time steps
        `reconstruct_start:reconstruct_end` (or over the first forecast
        point in forecast mode) and averaged over time and channels. The
        resulting 1-D array is appended to `self.score_list`. When plotting
        is enabled the batch is also handed to `self.plotter`.
        """
        with torch.no_grad():
            for dict_batch in tqdm(test_loader, total=len(test_loader)):
                batch_x, batch_y, batch_masks = (
                    dict_batch["past_values"],
                    dict_batch["anomaly_labels"],
                    dict_batch["past_observed_mask"],
                )
                if self.use_forecast:
                    batch_future_values = dict_batch["future_values"].to(self.device)

                # Move to device
                batch_x = batch_x.to(self.device).float()
                batch_masks = batch_masks.to(self.device)
                plot_inp = batch_x

                # Get TSPulse zeroshot output with stitched masked reconstruction
                if self.use_ts_from_fft:
                    keys_to_stitch = [
                        "reconstruction_outputs",
                        "reconstructed_ts_from_fft",
                        "fft_reconstruction_outputs",
                        "original_past_values_fft",
                    ]
                else:
                    keys_to_stitch = ["reconstruction_outputs"]

                if self.use_forecast:
                    model_forward_output = self.model(past_values=batch_x)
                else:
                    stitched_dict = patchwise_stitched_reconstruction_vectorized_multikey(
                        model=self.model,
                        past_values=batch_x,
                        patch_size=self.model.config.patch_length,
                        keys_to_stitch=keys_to_stitch,
                        keys_to_aggregate=[],
                        reconstruct_start=reconstruct_start,
                        reconstruct_end=reconstruct_end,
                    )

                # Get desired output from TSPulse outputs
                # output shape: [batch_size, window_size, n_channels]
                if self.use_ts_from_fft:
                    # time reconstruction from fft
                    output = stitched_dict["reconstructed_ts_from_fft"]
                elif self.use_forecast:
                    output = model_forward_output.forecast_output
                else:
                    # time reconstruction
                    output = stitched_dict["reconstruction_outputs"]

                # plot input/output
                if self.use_forecast:
                    # Pad the forecast with NaNs over the context so that it
                    # lines up with context + future values on the time axis.
                    plot_op = torch.cat((torch.full_like(plot_inp, fill_value=torch.nan), output), dim=1)
                    plot_inp = torch.cat((plot_inp, batch_future_values), dim=1)
                else:
                    plot_op = output

                # Calculate pointwise error
                if self.use_forecast:
                    # Similar to TimesFM, take the first point in the horizon
                    pointwise_score = self.anomaly_criterion(batch_future_values[:, 0, :], output[:, 0, :]).unsqueeze(
                        1
                    )
                else:
                    pointwise_score = self.anomaly_criterion(
                        batch_x[:, reconstruct_start:reconstruct_end, :],
                        output[:, reconstruct_start:reconstruct_end, :],
                    )

                # Aggregated score for this time point
                score = torch.mean(pointwise_score, dim=[1, 2]).detach().cpu().numpy()

                self.score_list.append(score)

                if self.plot:
                    self.plotter.add_data(
                        plot_inp,
                        batch_y,
                        plot_op,
                    )

    def zero_shot(self, data, label):
        """Score `data` with the current model and store the result.

        Builds a sliding-window dataset, runs `_inference`, pads the scores
        at both ends (by repeating the first/last score) when fewer scores
        than data points were produced, and stores the smoothed scores in
        `self.decision_scores_`.

        Args:
            data: Time series to score; indexed along axis 0.
            label: Anomaly labels, only used for plotting.
        """
        # label is only used for plotting
        if self.use_forecast:
            test_dataset = ForecastDataset_TSPulse(
                data=data,
                window_size=self.win_size,
                forecast_horizon=1,  # only first point will be used
                label=label,
            )
        else:
            test_dataset = ReconstructDataset_TSPulse(
                data=data,
                window_size=self.win_size,
                aggr_window_size=self.aggr_win_size,
                label=label,
                return_dict=True,
            )
        test_loader = DataLoader(
            dataset=test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=2,
        )

        # Initialize score list and plotter
        self.score_list = []
        if self.plot:
            self.plotter = BatchPlotter(self.num_plots)

        if self.window_position == "first" and not self.use_forecast:
            # Start masking reconstruction
            self._inference(
                test_loader=test_loader,
                reconstruct_start=0,
                reconstruct_end=self.aggr_win_size,
            )

            # Have start masking in general
            # But, have end masking at the end to avoid end-point aberration
            start_of_end_masking_idx = max(
                (data.shape[0] - 2 * self.model.config.context_length + self.aggr_win_size + 1),
                0,
            )
            # Create a subset from start_idx to the end
            subset = Subset(test_dataset, list(range(start_of_end_masking_idx, len(test_dataset))))
            subset_loader = DataLoader(subset, batch_size=self.batch_size, shuffle=False)

            self._inference(
                test_loader=subset_loader,
                reconstruct_start=self.model.config.context_length - self.aggr_win_size,
                reconstruct_end=self.model.config.context_length,
            )
        else:
            # End masking reconstruction
            self._inference(
                test_loader=test_loader,
                reconstruct_start=self.model.config.context_length - self.aggr_win_size,
                reconstruct_end=self.model.config.context_length,
            )

        if self.plot:
            if self.use_forecast:
                vline_loc = self.model.config.context_length
            elif self.window_position == "last":
                vline_loc = self.model.config.context_length - self.aggr_win_size
            else:
                vline_loc = self.aggr_win_size
            self.plotter.plot_batches(
                os.path.join(self.save_dir, "plots"),
                self.filename,
                aggr_window_loc=vline_loc,
            )

        self.__anomaly_score = np.concatenate(self.score_list, axis=0).reshape(-1)

        if self.__anomaly_score.shape[0] < len(data):
            if self.use_forecast:
                start_pad_len = self.win_size
                end_pad_len = 0
            else:
                if self.window_position == "first":
                    start_pad_len = math.ceil((self.aggr_win_size - 1) / 2)
                    end_pad_len = (self.aggr_win_size - 1) // 2
                else:
                    start_pad_len = self.win_size - self.aggr_win_size // 2 - 1
                    end_pad_len = self.aggr_win_size // 2

            # Edge-pad: repeat the first score in front and the last behind.
            self.__anomaly_score = np.array(
                [self.__anomaly_score[0]] * start_pad_len
                + list(self.__anomaly_score)
                + [self.__anomaly_score[-1]] * end_pad_len
            )
        self.decision_scores_ = self._smooth_scores(self.__anomaly_score)

    def _smooth_scores(self, scores, window_size=8):
        """Moving-average smoothing of `scores`.

        Forecast-mode scores are returned unchanged. Otherwise the scores
        are convolved with a uniform kernel of length `window_size`
        (`mode="same"`, so the output has the same length as the input).

        Raises:
            ValueError: If `window_size` is smaller than 1.
        """
        if self.use_forecast:
            return scores

        # Ensure the window size is valid
        if window_size < 1:
            raise ValueError("Window size must be at least 1")

        # Use numpy's convolve function to smooth the scores
        smoothed_scores = np.convolve(scores, np.ones(window_size) / window_size, mode="same")

        return smoothed_scores

    def fit(self, data):
        """Fine-tune the model on `data` with masked-patch reconstruction.

        The series is split chronologically into train/validation parts
        (no validation split for series shorter than 3000 points, in which
        case the training set doubles as evaluation set and epochs are
        capped at 5). Fine-tuning is skipped for series shorter than the
        model's context length or with fewer than 100 training windows.
        Training uses the HuggingFace `Trainer` with early stopping, and
        checkpoints go to a temporary directory that is removed afterwards.

        Any exception is printed and terminates the process via
        `sys.exit(-1)`.

        Args:
            data: Training time series; indexed along axis 0.
        """
        try:
            print("Fine-tuning TSPulse.")
            create_valid = True
            if data.shape[0] < 3000:  # 20% of this should be > context_len
                print("Data too small to create a validation set.")
                create_valid = False
                self.validation_size = 0.0

            if data.shape[0] < self.model.config.context_length:
                print("Skipping fine-tuning due to very short length")
                return

            split_idx = int((1 - self.validation_size) * len(data))
            tsTrain = data[:split_idx]
            if create_valid:
                tsValid = data[split_idx:]

            train_dataset = PatchMaskingDatasetWrapper(
                ReconstructDataset_TSPulse(tsTrain, window_size=self.win_size, return_dict=True),
                window_length=self.aggr_win_size,
                patch_length=self.model.config.patch_length,
                window_position=self.window_position,
            )
            if len(train_dataset) < 100:
                print("Skipping fine-tuning due to very few training samples")
                return

            if create_valid:
                valid_dataset = PatchMaskingDatasetWrapper(
                    ReconstructDataset_TSPulse(tsValid, window_size=self.win_size, return_dict=True),
                    window_length=self.aggr_win_size,
                    patch_length=self.model.config.patch_length,
                    window_position=self.window_position,
                )
            else:
                valid_dataset = train_dataset

            max_finetune_samples = 100_000
            if len(train_dataset) > max_finetune_samples:
                use_fraction = max_finetune_samples / len(train_dataset)
                # Randomly select use_fraction samples to make finetuning faster
                train_dataset, _ = random_split(train_dataset, [use_fraction, 1 - use_fraction])
                valid_dataset, _ = random_split(valid_dataset, [use_fraction, 1 - use_fraction])
                print(
                    f"Training samples are > max_finetune_samples ({max_finetune_samples}), using {round(use_fraction * 100)}% for faster fine-tuning."
                )

            # Freeze the backbone
            if self.freeze_backbone:
                print(
                    "Number of params before freezing backbone",
                    self._count_trainable_params(self.model),
                )

                # Freeze the backbone of the model
                for param in self.model.backbone.parameters():
                    param.requires_grad = False

                # Count params
                print(
                    "Number of params after freezing the backbone",
                    self._count_trainable_params(self.model),
                )

            suggested_lr = self.lr
            finetune_num_epochs = self.finetune_num_epochs
            if not create_valid:
                finetune_num_epochs = min(5, finetune_num_epochs)

            finetune_batch_size = self.batch_size
            if len(train_dataset) < 500:
                finetune_batch_size = 8
            num_workers = 4
            num_gpus = 1

            print(f"Fine-tune: Train samples = {len(train_dataset)}, Valid Samples = {len(valid_dataset)}")

            # Local import: only needed for cleaning up the checkpoint dir.
            import shutil

            temp_dir = tempfile.mkdtemp()
            try:
                finetune_args = TrainingArguments(
                    output_dir=temp_dir,
                    overwrite_output_dir=True,
                    learning_rate=suggested_lr,
                    num_train_epochs=finetune_num_epochs,
                    do_eval=True,
                    eval_strategy="epoch",
                    per_device_train_batch_size=finetune_batch_size,
                    per_device_eval_batch_size=finetune_batch_size * 10,
                    dataloader_num_workers=num_workers,
                    report_to="tensorboard",
                    save_strategy="epoch",
                    logging_strategy="epoch",
                    save_total_limit=1,
                    logging_dir=temp_dir,  # Make sure to specify a logging directory
                    load_best_model_at_end=True,  # Load the best model when training ends
                    metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
                    greater_is_better=False,  # For loss
                )

                # Create the early stopping callback
                early_stopping_callback = EarlyStoppingCallback(
                    early_stopping_patience=5,  # Number of epochs with no improvement after which to stop
                    early_stopping_threshold=0.0001,  # Minimum improvement required to consider as improvement
                )

                # Optimizer and scheduler
                optimizer = AdamW(self.model.parameters(), lr=suggested_lr)
                scheduler = OneCycleLR(
                    optimizer,
                    suggested_lr,
                    epochs=finetune_num_epochs,
                    steps_per_epoch=math.ceil(len(train_dataset) / (finetune_batch_size * num_gpus)),
                )

                finetune_trainer = Trainer(
                    model=self.model,
                    args=finetune_args,
                    train_dataset=train_dataset,
                    eval_dataset=valid_dataset,
                    callbacks=[early_stopping_callback],
                    optimizers=(optimizer, scheduler),
                )

                # Fine tune
                finetune_trainer.train()

                # save model
                if self.save_models:
                    finetune_trainer.save_model(f"{self.save_dir}/tspulse_finetuned_model__{self.filename}")
            finally:
                # The best weights are already loaded into self.model, so the
                # intermediate checkpoints and logs are no longer needed.
                shutil.rmtree(temp_dir, ignore_errors=True)

            print("Successfully completed finetuning.")

        except Exception as e:
            print("Error occured in finetune. Error =", e)
            sys.exit(-1)

    def decision_function(self, data, label):
        """Compute and return anomaly scores for `data`.

        Delegates to `zero_shot` and returns `self.decision_scores_`.

        Args:
            data: Time series to score.
            label: Anomaly labels, used for plotting only.

        Returns:
            The (smoothed) anomaly scores stored in `self.decision_scores_`.
        """

        self.zero_shot(data, label)
        return self.decision_scores_
