# Copyright authors of TSPulse

import os
import glob
import re
import pandas as pd
import numpy as np
import torch
from sktime.datasets import load_from_tsfile_to_dataframe
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset, random_split, Subset
import torch.nn.functional as F


def interpolate_missing(y):
    """
    Fill NaN gaps in the pd.Series `y` by linear interpolation.

    Interpolation runs in both directions, so leading and trailing NaNs are filled as well.
    A series without NaNs is returned as-is (same object, no copy).
    """
    if not y.isna().any():
        return y
    return y.interpolate(method="linear", limit_direction="both")


class Normalizer(object):
    """
    Normalizes a dataframe whose rows are time steps and whose index identifies the sample each row belongs to.

    "standardization" / "minmax" use statistics pooled over ALL rows (time steps of all samples);
    "per_sample_std" / "per_sample_minmax" compute statistics separately for each index value (sample).
    """

    def __init__(
        self,
        norm_type="standardization",
        mean=None,
        std=None,
        min_val=None,
        max_val=None,
    ):
        """
        Args:
            norm_type: choose from:
                "standardization", "minmax": normalizes dataframe across ALL contained rows (time steps)
                "per_sample_std", "per_sample_minmax": normalizes each sample separately (i.e. across only its own rows)
            mean, std, min_val, max_val: optional (num_feat,) Series of pre-computed values. Any statistic left as
                None is fitted on the first call to `normalize` and reused afterwards.
        """

        self.norm_type = norm_type
        self.mean = mean
        self.std = std
        self.min_val = min_val
        self.max_val = max_val

    def normalize(self, df):
        """
        Args:
            df: input dataframe (rows = time steps, index = sample id for the per-sample modes)
        Returns:
            df: normalized dataframe
        Raises:
            NameError: if `norm_type` is not one of the supported methods.
        """
        # Added to every denominator so constant columns / constant samples don't divide by zero.
        eps = np.finfo(float).eps

        if self.norm_type == "standardization":
            # Each statistic is fitted independently, so a caller may supply only one of them.
            if self.mean is None:
                self.mean = df.mean()
            if self.std is None:
                self.std = df.std()
            return (df - self.mean) / (self.std + eps)

        elif self.norm_type == "minmax":
            if self.max_val is None:
                self.max_val = df.max()
            if self.min_val is None:
                self.min_val = df.min()
            return (df - self.min_val) / (self.max_val - self.min_val + eps)

        elif self.norm_type == "per_sample_std":
            grouped = df.groupby(by=df.index)
            # eps keeps constant samples at 0 instead of producing NaN (0 / 0), consistent with the other modes.
            return (df - grouped.transform("mean")) / (grouped.transform("std") + eps)

        elif self.norm_type == "per_sample_minmax":
            grouped = df.groupby(by=df.index)
            min_vals = grouped.transform("min")
            return (df - min_vals) / (grouped.transform("max") - min_vals + eps)

        else:
            raise NameError(f'Normalize method "{self.norm_type}" not implemented')


def k_fold_cv(skf, dataset, kth_fold):
    """
    Split `dataset` into train/validation subsets using fold `kth_fold` of the splitter `skf`.

    Args:
        skf: splitter exposing `split(X, y)` that yields (train_indices, val_indices) pairs,
            e.g. sklearn's StratifiedKFold (stratified on the samples' "target_values").
        dataset: indexable dataset whose items are dicts containing "target_values".
        kth_fold: 0-based index of the fold to use as validation set.
    Returns:
        (train_subset, val_subset) as torch `Subset`s of `dataset`.
    Raises:
        ValueError: if `kth_fold` is not smaller than the number of folds produced by `skf`.
    """
    # Labels drive the stratification; X is a dummy since only its length matters to the splitter.
    y = [dataset[idx]["target_values"] for idx in range(len(dataset))]

    n_folds_seen = 0
    for fold, (train_indices, val_indices) in enumerate(skf.split(np.zeros(len(y)), y)):
        if fold == kth_fold:
            return Subset(dataset, train_indices), Subset(dataset, val_indices)
        n_folds_seen += 1

    # Previously this fell through to an UnboundLocalError on the return statement.
    raise ValueError(f"kth_fold={kth_fold} is out of range: the splitter produced only {n_folds_seen} folds")


def get_uea_classification_data(args):
    """
    Build the train / validation / test datasets for one UEA classification dataset.

    Args:
        args: namespace-like config. Attributes read here: `data_path`, `dset`, `valid_split_strat`
            (None or "K-fold"), `split_valid_ratio` (when `valid_split_strat` is None), `num_folds` and
            `kth_fold` (when "K-fold"), and optionally `kfold_random_state` (default 0). `UEADataset`
            reads further attributes from `args`.
    Returns:
        dict with keys "dset_train", "dset_valid", "dset_test", "num_input_channels", "num_targets".
    Raises:
        ValueError: if `args.valid_split_strat` is neither None nor "K-fold".
    """
    # Validate before loading any data so a typo fails fast instead of as a later KeyError.
    if args.valid_split_strat not in (None, "K-fold"):
        raise ValueError(
            f'Unknown valid_split_strat "{args.valid_split_strat}"; expected None or "K-fold"'
        )

    root_path = os.path.join(args.data_path, args.dset)
    output = {}

    # Validation data is carved out of the TRAIN split; the TEST split is kept untouched.
    base_dataset = UEADataset(args=args, root_path=root_path, flag="TRAIN")

    if args.valid_split_strat is None:
        # Random hold-out of a fraction `args.split_valid_ratio` of the TRAIN samples.
        dataset_size = len(base_dataset)
        val_size = int(args.split_valid_ratio * dataset_size)
        train_size = dataset_size - val_size
        train_dataset, val_dataset = random_split(base_dataset, [train_size, val_size])
    else:  # "K-fold"
        # A fixed random_state gives every run the same fold partition, so runs using different
        # `kth_fold` values get disjoint validation folds (an unseeded shuffle could re-partition per run).
        skf = StratifiedKFold(
            n_splits=args.num_folds,
            shuffle=True,
            random_state=getattr(args, "kfold_random_state", 0),
        )
        train_dataset, val_dataset = k_fold_cv(skf, base_dataset, args.kth_fold)

    output["dset_train"] = train_dataset
    output["dset_valid"] = val_dataset
    output["dset_test"] = UEADataset(args=args, root_path=root_path, flag="TEST")

    # past_values is (context_points, num_channels); dim 1 is the channel count.
    output["num_input_channels"] = output["dset_train"][0]["past_values"].shape[1]
    output["num_targets"] = output["dset_test"].num_targets
    return output


class UEADataset(Dataset):
    """
    Dataset class for datasets included in:
        Time Series Classification Archive (www.timeseriesclassification.com)
    Arguments:
        args: namespace-like config; `args.context_points` is the length every sample is resampled to.
        root_path: directory containing the dataset's .ts files.
        file_list: optional list of file paths within `root_path` to consider (default: all files).
        limit_size: for debugging; a value > 1 keeps that many samples, a value in (0, 1] keeps that fraction.
        flag: regex (e.g. "TRAIN" / "TEST") matched against the file NAMES to select the split.
    Attributes:
        all_df: (num_samples * seq_len, num_columns) dataframe indexed by integer indices, with multiple rows corresponding to the same index (sample).
            Each row is a time step; each column is one dimension (feature) of the series.
        feature_df: globally standardized copy of `all_df` (statistics computed on this split only)
        feature_names: names of columns contained in `feature_df` (same as feature_df.columns)
        all_IDs: (num_samples,) index of IDs contained in `all_df`/`feature_df` (same as all_df.index.unique() )
        labels_df: (num_samples, 1) pd.DataFrame of int8 class codes, one row per sample
        class_names: categories corresponding to the class codes
        num_targets: number of classes
        max_seq_len: length of the first-dimension series; the maximum over samples when lengths vary.
    """

    def __init__(self, args, root_path, file_list=None, limit_size=None, flag=None):
        self.args = args
        self.root_path = root_path
        self.flag = flag
        self.all_df, self.labels_df, self.num_targets = self.load_all(root_path, file_list=file_list, flag=flag)

        self.all_IDs = self.all_df.index.unique()  # all sample IDs (integer indices 0 ... num_samples-1)

        if limit_size is not None:
            if limit_size > 1:
                limit_size = int(limit_size)
            else:  # interpret as proportion if in (0, 1]
                limit_size = int(limit_size * len(self.all_IDs))
            self.all_IDs = self.all_IDs[:limit_size]
            self.all_df = self.all_df.loc[self.all_IDs]

        # use all features
        self.feature_names = self.all_df.columns
        self.feature_df = self.all_df

        # Global standardization over all rows of THIS split (train and test are normalized independently).
        normalizer = Normalizer()
        self.feature_df = normalizer.normalize(self.feature_df)

    def load_all(self, root_path, file_list=None, flag=None):
        """
        Loads the dataset from the .ts file(s) contained in `root_path`.
        Args:
            root_path: directory containing the .ts files
            file_list: optionally, provide a list of file paths within `root_path` to consider.
                Otherwise, entire `root_path` contents will be used.
            flag: optional regex; only files whose NAME matches it are considered.
        Returns:
            all_df: a single dataframe with all time steps of all samples (see class docstring)
            labels_df: dataframe containing label(s) for each sample
            num_targets: number of classes
        Raises:
            Exception: if no files, or no matching .ts files, are found.
        """
        # Select paths for training and evaluation
        if file_list is None:
            data_paths = glob.glob(os.path.join(root_path, "*"))  # list of all paths
        else:
            data_paths = [os.path.join(root_path, p) for p in file_list]
        if len(data_paths) == 0:
            raise Exception("No files found using: {}".format(os.path.join(root_path, "*")))
        if flag is not None:
            # Match against the file name only: a directory component containing e.g. "TRAIN" must not
            # make every file in it (including the TEST file) match.
            data_paths = [p for p in data_paths if re.search(flag, os.path.basename(p))]
        # sorted() makes the file picked below independent of glob's arbitrary ordering.
        input_paths = sorted(p for p in data_paths if os.path.isfile(p) and p.endswith(".ts"))
        if len(input_paths) == 0:
            pattern = "*.ts"
            raise Exception("No .ts files found using pattern: '{}'".format(pattern))

        # a single file contains the whole split
        all_df, labels_df, num_targets = self.load_single(input_paths[0])
        return all_df, labels_df, num_targets

    @staticmethod
    def _subsample(y, limit=256, factor=2):
        """Return `y` subsampled by `factor` (index reset) if it is longer than `limit`, else `y` unchanged."""
        if len(y) > limit:
            return y[::factor].reset_index(drop=True)
        return y

    @staticmethod
    def _elementwise(df, func):
        """Apply `func` to every cell of `df` (DataFrame.map on pandas >= 2.1, where applymap is deprecated)."""
        if hasattr(pd.DataFrame, "map"):
            return df.map(func)
        return df.applymap(func)

    def load_single(self, filepath):
        """
        Load one .ts file into the long (one row per time step) format described in the class docstring.

        Returns:
            (df, labels_df, num_targets); also sets `self.class_names` and `self.max_seq_len`.
        """
        df, labels = load_from_tsfile_to_dataframe(
            filepath, return_separate_X_and_y=True, replace_missing_vals_with="NaN"
        )
        labels = pd.Series(labels, dtype="category")
        self.class_names = labels.cat.categories
        num_targets = len(self.class_names)
        # Class codes stored as int8; __getitem__ converts them to int64 (long) tensors.
        labels_df = pd.DataFrame(labels.cat.codes, dtype=np.int8)

        # (num_samples, num_dimensions) array containing the length of each series
        lengths = self._elementwise(df, len).values
        horiz_diffs = np.abs(lengths - np.expand_dims(lengths[:, 0], -1))
        if np.sum(horiz_diffs) > 0:  # if any row (sample) has varying length across dimensions
            df = self._elementwise(df, self._subsample)

        lengths = self._elementwise(df, len).values
        # Equals lengths[0, 0] when all samples share one length.
        self.max_seq_len = int(np.max(lengths[:, 0]))

        frames = []
        for row in range(df.shape[0]):
            frame = pd.DataFrame({col: df.loc[row, col] for col in df.columns}).reset_index(drop=True)
            # Index by the frame's actual length: dimensions of unequal length are aligned (shorter ones
            # padded with NaN, filled below), so dimension 0's length is not necessarily the row count.
            frames.append(frame.set_index(pd.Series(len(frame) * [row])))
        df = pd.concat(frames, axis=0)

        # Replace NaN values within each sample by linear interpolation
        grp = df.groupby(by=df.index)
        df = grp.transform(interpolate_missing)

        return df, labels_df, num_targets

    def instance_norm(self, case):
        """
        Extra per-sample normalization, applied only when `root_path` contains "EthanolConcentration"
        (special processing for numerical stability); otherwise `case` is returned unchanged.

        Args:
            case: (seq_len, num_channels) tensor for one sample.
        """
        if "EthanolConcentration" not in self.root_path:
            return case
        mean = case.mean(0, keepdim=True)  # per-channel mean over time steps
        case = case - mean
        # NOTE(review): the variance is taken over dim=1 (across channels at each time step), not over time
        # like the mean above; kept as-is to preserve existing results -- confirm this is intended.
        stdev = torch.sqrt(torch.var(case, dim=1, keepdim=True, unbiased=False) + 1e-5)
        case /= stdev
        return case

    def __getitem__(self, ind):
        """
        Returns:
            dict with "past_values": float32 tensor (args.context_points, num_channels) and
            "target_values": int64 scalar tensor with the class code.
        """
        sample_id = self.all_IDs[ind]
        batch_x = self.feature_df.loc[sample_id].values  # (seq_len, num_channels): one row per time step
        labels = self.labels_df.loc[sample_id].values
        X = self.instance_norm(torch.from_numpy(batch_x))

        # F.interpolate needs a (batch, channels, length) tensor: (l, c) -> (1, c, l)
        X = X.transpose(1, 0).unsqueeze(dim=0)
        # Resample every sample to a fixed length: (1, c, l) -> (1, c, context_points)
        X = F.interpolate(X, self.args.context_points, mode="linear")
        X = X.squeeze(dim=0).transpose(0, 1)  # (1, c, cp) -> (cp, c)

        return {
            "past_values": X.float(),
            "target_values": torch.squeeze(torch.from_numpy(labels)).long(),
        }

    def __len__(self):
        return len(self.all_IDs)
