import sys

sys.path.append("../")


import os
import glob
import pandas as pd
import torch
import numpy as np
import random, argparse
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import defaultdict
from sklearn.preprocessing import MinMaxScaler
from .evaluation.metrics import get_metrics
from .utils.slidingWindows import find_length_rank
from .model_wrapper import *
from .HP_list import Optimal_Uni_algo_HP_dict, Optimal_Multi_algo_HP_dict
from .get_args import get_args

# Use a large default font for every figure this script saves.
plt.rcParams.update({"font.size": 24})

# Seed all random number generators used by the detectors so runs are repeatable.
seed = 2024
for _seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    np.random.seed,
    random.seed,
):
    _seed_fn(seed)
del _seed_fn

# Disable cuDNN autotuning and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

print("CUDA Available: ", torch.cuda.is_available())
print("cuDNN Version: ", torch.backends.cudnn.version())


import numpy as np
import matplotlib.pyplot as plt
import os


import numpy as np


import numpy as np
import matplotlib.pyplot as plt
import os


class TSPlotter:
    """Render windows of a time series alongside its anomaly labels and scores.

    Each subplot shows ``plot_len`` consecutive samples of the data (black),
    the ground-truth label (green) and the detector score (red). The figure
    is written to ``<save_dir>/plots/full_<filename>.pdf``.
    """

    def __init__(self, plot_len=512, find_anomaly_region=True, num_subplots=3) -> None:
        """Store plotting options.

        Args:
            plot_len: Number of consecutive samples shown in each subplot.
            find_anomaly_region: If True, show the windows with the highest
                fraction of labelled anomalies; otherwise show random windows.
            num_subplots: Number of windows (rows) in the figure.
        """
        self.plot_len = plot_len
        self.find_anomaly_region = find_anomaly_region
        self.num_subplots = num_subplots

    def _find_top_k_anomaly_regions(self, binary_array, L, threshold=0.01, K=3):
        """Find up to ``K`` non-overlapping windows with the most anomalies.

        Args:
            binary_array: 1-D sequence of anomaly labels (expected 0/1).
            L: Window length in samples.
            threshold: Minimum fraction of anomalous samples a window needs
                to be eligible.
            K: Maximum number of windows to return.

        Returns:
            ``(indices, percentages)``. ``indices`` holds the start of each
            selected window (covering ``binary_array[start:start + L]``) and
            ``percentages`` holds its anomaly fraction, ordered by decreasing
            fraction. The selected windows are at least ``L`` apart, so they
            do not overlap. Both are empty when no window qualifies.
        """
        binary_array = np.asarray(binary_array)
        if L <= 0 or K <= 0 or binary_array.shape[0] < L:
            return [], []

        # Prepending 0 makes window_sums[i] == binary_array[i:i + L].sum() for
        # every start i in [0, n - L]. Without it, each sum is shifted one
        # sample to the right and the final window is never considered.
        cumulative_sum = np.concatenate(([0], np.cumsum(binary_array)))
        window_sums = cumulative_sum[L:] - cumulative_sum[:-L]
        anomaly_percentages = window_sums / L

        valid_indices = np.flatnonzero(anomaly_percentages >= threshold)
        if valid_indices.size == 0:
            return [], []

        # Sort by highest fraction first. A stable sort breaks ties by position,
        # so the result is deterministic.
        order = np.argsort(-anomaly_percentages[valid_indices], kind="stable")
        sorted_indices = valid_indices[order]

        # Greedily keep the best windows that do not overlap any already kept.
        top_k_indices = []
        for idx in sorted_indices:
            if all(abs(idx - prev_idx) >= L for prev_idx in top_k_indices):
                top_k_indices.append(idx)
                if len(top_k_indices) == K:
                    break

        top_k_percentages = anomaly_percentages[top_k_indices]
        return top_k_indices, top_k_percentages

    def plot(self, data, label, output, filename, save_dir=None) -> None:
        """Save a PDF with one subplot per selected window.

        Args:
            data: Time series indexed along axis 0. A 2-D array draws one line
                per column.
            label: Ground-truth anomaly labels aligned with ``data``.
            output: Anomaly scores aligned with ``data``.
            filename: Stem of the output file ``full_<filename>.pdf``.
            save_dir: Directory under which ``plots/`` is created. If None,
                falls back to the script-level ``args.save_dir``, which is only
                defined when this module runs as ``__main__``.
        """
        if save_dir is None:
            save_dir = args.save_dir
        L = self.plot_len

        if self.find_anomaly_region:
            starts, _ = self._find_top_k_anomaly_regions(label, L, K=self.num_subplots)
        else:
            # randint's ``high`` is exclusive, so +1 makes the final window
            # reachable. max(..., 1) avoids a ValueError when len <= L.
            starts = np.random.randint(
                low=0,
                high=max(output.shape[0] - L + 1, 1),
                size=self.num_subplots,
            )

        # squeeze=False always yields a 2-D array of Axes. This keeps indexing
        # valid when num_subplots == 1, where subplots() returns a bare Axes.
        fig, axs = plt.subplots(
            self.num_subplots, 1, figsize=(30, 4 * self.num_subplots), squeeze=False
        )
        series = (
            (data, "data", "black"),
            (label, "label", "green"),
            (output, "score", "red"),
        )
        try:
            for ax, start in zip(axs[:, 0], starts):
                for values, name, color in series:
                    ax.plot(
                        values[start : start + L, ...],
                        label=name,
                        linewidth=4,
                        linestyle="-",
                        color=color,
                    )
                ax.legend()

            plot_dir = os.path.join(save_dir, "plots")
            os.makedirs(plot_dir, exist_ok=True)
            fig.savefig(os.path.join(plot_dir, f"full_{filename}.pdf"))
        finally:
            # Release this figure even if saving fails.
            plt.close(fig)


if __name__ == "__main__":
    args = get_args()

    # Data directories containing "TSB-AD-M" select the "multi" file lists and
    # hyper-parameter table; everything else uses the "uni" ones.
    prefix = "multi" if "TSB-AD-M" in args.data_direc else "uni"

    if args.filename is not None:
        all_files = [args.filename]
    else:
        # Restrict the run to the files listed for the requested split.
        suite = "U" if prefix == "uni" else "M"
        split = "Tuning" if args.do_tuning else "Eva"
        df_eval = pd.read_csv(f"Datasets/File_List/TSB-AD-{suite}-{split}.csv")
        all_files = list(df_eval["file_name"])
        print(f"Running AD with {args.AD_Name} on {len(all_files)} data files...")

    # Keep only files whose name contains the requested dataset name.
    if args.dataset_name is not None:
        all_files = [f for f in all_files if args.dataset_name in f]
        print(f"Files for dataset {args.dataset_name} after filtering = {len(all_files)}")

    hp_table = Optimal_Uni_algo_HP_dict if prefix == "uni" else Optimal_Multi_algo_HP_dict

    def tspulse_kwargs(label, filename, model_path):
        """Return the keyword arguments shared by both TSPulse runner calls."""
        return dict(
            label=label,
            filename=filename,
            plot=args.plot,
            save_dir=args.save_dir,
            save_models=args.save_models,
            windowed_detector=args.windowed_detector,
            aggr_win_size=args.aggr_win_size,
            use_ts_from_fft=args.use_ts_from_fft,
            use_forecast=args.use_forecast,
            ensemble_outputs=args.ensemble_outputs,
            model_path=model_path,
            tspulse_decoder_mode=args.tspulse_decoder_mode,
            window_position=args.window_position,
            batch_size=args.batch_size,
        )

    plotter = None
    if args.plot:
        # Plots are written to <save_dir>/plots/; make sure it exists.
        os.makedirs(os.path.join(args.save_dir, "plots"), exist_ok=True)
        plotter = TSPlotter()

    all_results = defaultdict(list)
    vus_pr = {"file": [], args.AD_Name: []}
    for filename in all_files:
        print("filename =", filename)
        file_base_name = filename.split(".")[0]
        df = pd.read_csv(os.path.join(args.data_direc, filename)).dropna()
        # All columns but the last are features; labels come from "Label".
        data = df.iloc[:, 0:-1].values.astype(float)
        label = df["Label"].astype(int).to_numpy()

        slidingWindow = find_length_rank(data, rank=1)
        # The third-from-last "_"-separated field of the file name is the
        # length of the prefix used as training data.
        train_index = int(file_base_name.split("_")[-3])
        data_train = data[:train_index, :]
        Optimal_Det_HP = hp_table[args.AD_Name]

        model_path = args.model_path
        if args.use_finetuned_models and args.AD_Name == "TSPulse_ZS":
            # Use the per-dataset finetuned checkpoint.
            dataset_name = filename.split("_")[1]
            model_path = os.path.join(args.model_path, dataset_name, "tspulse_finetuned_model")

        if args.AD_Name in Semisupervise_AD_Pool:
            if "TSPulse" in args.AD_Name:
                output = run_Semisupervise_AD(
                    args.AD_Name,
                    data_train,
                    data,
                    **tspulse_kwargs(label, file_base_name, model_path),
                    finetune_num_epochs=args.finetune_num_epochs,
                    freeze_backbone=args.freeze_backbone,
                    **Optimal_Det_HP,
                )
            else:
                output = run_Semisupervise_AD(args.AD_Name, data_train, data, **Optimal_Det_HP)
        elif args.AD_Name in Unsupervise_AD_Pool:
            if "TSPulse" in args.AD_Name:
                output = run_Unsupervise_AD(
                    args.AD_Name,
                    data,
                    **tspulse_kwargs(label, file_base_name, model_path),
                    **Optimal_Det_HP,
                )
            elif "MOMENT_ZS" in args.AD_Name:
                output = run_Unsupervise_AD(
                    args.AD_Name,
                    data,
                    label=label,
                    filename=file_base_name,
                    plot=args.plot,
                    save_dir=args.save_dir,
                    **Optimal_Det_HP,
                )
            else:
                output = run_Unsupervise_AD(args.AD_Name, data, **Optimal_Det_HP)
        else:
            raise Exception(f"{args.AD_Name} is not defined")

        if not isinstance(output, np.ndarray):
            # Non-array results are only reported. Formatting with an f-string
            # avoids a TypeError when the result is not a str.
            print(f"At {filename}: {output}")
            continue

        output = MinMaxScaler(feature_range=(0, 1)).fit_transform(output.reshape(-1, 1)).ravel()

        if args.plot or args.save_results:
            data = MinMaxScaler(feature_range=(0, 1)).fit_transform(data)

        if plotter is not None:
            plotter.plot(data, label, output, args.save_prefix + "_" + file_base_name)

        if args.save_results and len(all_files) == 1:
            output_save_dir = os.path.join(args.save_dir, f"{args.AD_Name}_output_dump")
            os.makedirs(output_save_dir, exist_ok=True)
            np.save(f"{output_save_dir}/{file_base_name}_data.npy", data)
            np.save(f"{output_save_dir}/{file_base_name}_label.npy", label)
            np.save(f"{output_save_dir}/{file_base_name}_scores.npy", output)

        # Point predictions flag scores more than 3 std above the mean.
        evaluation_result = get_metrics(
            output,
            label,
            slidingWindow=slidingWindow,
            pred=output > (np.mean(output) + 3 * np.std(output)),
        )
        print(f"Evaluation Result for {filename}: ", evaluation_result)

        all_results["file"].append(filename)
        for k, value in evaluation_result.items():
            all_results[k].append(value)
            if k == "VUS-PR":
                vus_pr["file"].append(filename)
                vus_pr[args.AD_Name].append(value)

    print("=" * 50)
    print("FINAL RESULTS")
    print("-" * 50)
    df = pd.DataFrame()
    df["Stat"] = ["mean", "std", "median"]
    for k, values in all_results.items():
        if k != "file":
            df[k] = [
                round(np.nanmean(values), 2),
                round(np.nanstd(values), 2),
                round(np.nanmedian(values), 2),
            ]
    print(df)
    print("=" * 50)

    if args.save_results:
        os.makedirs(args.save_dir, exist_ok=True)
        df.to_csv(f"{args.save_dir}/{args.save_prefix}_results.csv")

        df_vus_pr = pd.DataFrame(vus_pr)
        df_vus_pr.to_csv(f"{args.save_dir}/{args.save_prefix}_VUS-PR.csv")

        df_all = pd.DataFrame(all_results)
        df_all.to_csv(f"{args.save_dir}/{args.save_prefix}_ALL-metrics.csv")
