import functools
import os
from typing import Tuple, Optional, Union, Callable, Dict, Any, List

import numpy as np
import pandas as pd
import torch

from utils.metadata import DATA_DIRECTORY
from .dataset import DatasetMixin
from .preprocessing import minmax_scaler


# Training CSV files (all apps), read from ``<path>/train`` by ExathlonDataset.load_data.
# load_data picks the files of one app by the '<app_id>_' name prefix, so the first '_'-separated field of each
# name is the app id. Every entry here has 0 as its second field (the test files all have a non-zero value there);
# presumably that field is the Exathlon trace type, with 0 meaning an undisturbed trace — TODO confirm against the
# Exathlon data documentation. The meaning of the remaining fields is not established by this module.
# The number of files per app must equal len(TRAIN_LENGTHS[app_id]), which ExathlonDataset.__len__ reports.
TRAIN_FILES = (
    '1_0_1000000_14.csv',
    '1_0_100000_15.csv',
    '1_0_100000_16.csv',
    '1_0_10000_17.csv',
    '1_0_500000_18.csv',
    '1_0_500000_19.csv',
    '2_0_100000_20.csv',
    '2_0_100000_22.csv',
    '2_0_1200000_21.csv',
    '3_0_100000_24.csv',
    '3_0_100000_25.csv',
    '3_0_100000_26.csv',
    '3_0_1200000_23.csv',
    '4_0_1000000_31.csv',
    '4_0_100000_27.csv',
    '4_0_100000_28.csv',
    '4_0_100000_29.csv',
    '4_0_100000_30.csv',
    '4_0_100000_32.csv',
    '5_0_100000_33.csv',
    '5_0_100000_34.csv',
    '5_0_100000_35.csv',
    '5_0_100000_36.csv',
    '5_0_100000_37.csv',
    '5_0_100000_40.csv',
    '5_0_50000_38.csv',
    '5_0_50000_39.csv',
    '6_0_100000_42.csv',
    '6_0_100000_43.csv',
    '6_0_100000_44.csv',
    '6_0_100000_45.csv',
    '6_0_100000_46.csv',
    '6_0_100000_52.csv',
    '6_0_1200000_41.csv',
    '6_0_300000_50.csv',
    '6_0_50000_47.csv',
    '6_0_50000_48.csv',
    '6_0_50000_49.csv',
    '6_0_50000_51.csv',
    '9_0_100000_1.csv',
    '9_0_100000_3.csv',
    '9_0_100000_4.csv',
    '9_0_100000_6.csv',
    '9_0_1200000_2.csv',
    '9_0_300000_5.csv',
    '10_0_100000_10.csv',
    '10_0_100000_11.csv',
    '10_0_100000_13.csv',
    '10_0_100000_8.csv',
    '10_0_100000_9.csv',
    '10_0_1200000_7.csv',
    '10_0_300000_12.csv',
)


# Test CSV files (all apps), read from ``<path>/test`` by ExathlonDataset.load_data.
# As with the training files, the first '_'-separated field is the app id used for filtering. Every entry here has
# a non-zero second field; presumably that is the Exathlon trace/anomaly type — TODO confirm against the Exathlon
# data documentation.
# The number of files per app must equal len(TEST_LENGTHS[app_id]), which ExathlonDataset.__len__ reports.
TEST_FILES = (
    '1_2_100000_68.csv',
    '1_4_1000000_80.csv',
    '1_5_1000000_86.csv',
    '2_1_100000_60.csv',
    '2_2_200000_69.csv',
    '2_5_1000000_87.csv',
    '2_5_1000000_88.csv',
    '3_2_1000000_71.csv',
    '3_2_500000_70.csv',
    '3_4_1000000_81.csv',
    '3_5_1000000_89.csv',
    '4_1_100000_61.csv',
    '4_5_1000000_90.csv',
    '5_1_100000_63.csv',
    '5_1_100000_64.csv',
    '5_1_500000_62.csv',
    '5_2_1000000_72.csv',
    '5_4_1000000_82.csv',
    '5_5_1000000_91.csv',
    '5_5_1000000_92.csv',
    '6_1_500000_65.csv',
    '6_3_200000_76.csv',
    '6_5_1000000_93.csv',
    '9_2_1000000_66.csv',
    '9_3_500000_74.csv',
    '9_4_1000000_78.csv',
    '9_5_1000000_84.csv',
    '10_2_1000000_67.csv',
    '10_3_1000000_75.csv',
    '10_4_1000000_79.csv',
    '10_5_1000000_85.csv',
)


# Per-app sequence lengths of the training time series, keyed by app id. The keys double as the set of valid
# app ids (checked in ExathlonDataset.__init__). ExathlonDataset.seq_len returns these lists and __len__ uses
# their length as the number of series, so each list needs exactly one entry per training file of that app.
# Presumably each value is the row count of the matching file, listed in TRAIN_FILES order — TODO confirm.
TRAIN_LENGTHS = {
    1: [14391, 2690, 3591, 3591, 2728, 14391],
    2: [28725, 4269, 35923],
    3: [28790, 28790, 28789, 28791],
    4: [7191, 28789, 28769, 28790, 28790, 86391],
    5: [28790, 28790, 28791, 4742, 3590, 28790, 7191, 2727],
    6: [28757, 28789, 28790, 28790, 3588, 86390, 28790, 53990, 7190, 2634, 2690, 2689],
    9: [28790, 3345, 86341, 14390, 86391, 53990],
    10: [14356, 13224, 28790, 28746, 28790, 28790, 35989]
}


# Per-app sequence lengths of the test time series, keyed by app id (same keys as TRAIN_LENGTHS).
# ExathlonDataset.seq_len returns these lists and __len__ uses their length as the number of series, so each list
# needs exactly one entry per test file of that app. Presumably each value is the row count of the matching file,
# listed in TEST_FILES order — TODO confirm.
TEST_LENGTHS = {
    1: [2945, 43233, 3632],
    2: [46791, 2883, 43230, 3631],
    3: [2482, 2620, 4231, 5937],
    4: [129591, 3632],
    5: [43191, 46791, 46810, 2489, 4232, 43230, 3629],
    6: [46807, 46785, 3629],
    9: [7506, 46808, 43259, 5938],
    10: [10284, 46807, 43230, 5930],
}


class _OutOfBoundsError(KeyError, IndexError):
    """Raised by ``ExathlonDataset.__getitem__`` for an index outside ``[0, len(dataset))``.

    It subclasses ``KeyError`` so existing ``except KeyError`` handlers keep working. It also subclasses
    ``IndexError`` so Python's fallback ``__getitem__`` iteration (``for ts in dataset`` / ``list(dataset)``)
    terminates cleanly after the last series instead of propagating an error.
    """


class ExathlonDataset(torch.utils.data.Dataset, DatasetMixin):
    def __init__(self, path: str = os.path.join(DATA_DIRECTORY, 'exathlon', 'data', 'processed'), app_id: int = 1,
                 training: bool = True, standardize: Union[bool, Callable[[pd.DataFrame, Dict], pd.DataFrame]] = True):
        """
        Create and initialise a dataset. Each dataset must implement DatasetMixin and inherit
        from torch.utils.data.Dataset. The class must be called <name-of-dataset>Dataset and the name of the file
        must be <name-of-dataset>.lower()_dataset.py.

        The CSV data itself is loaded lazily on the first ``__getitem__`` call. Only the training statistics needed
        for standardization are read here.

        :param path: Folder from which to load the dataset. It must contain ``train`` and ``test`` subfolders.
        :param app_id: Data from which app to load. Must be one of the keys of ``TRAIN_LENGTHS``.
        :param training: Whether to load the training or the test set.
        :param standardize: Can be either a bool that decides whether to apply the dataset-dependent default
            standardization or a function with signature (dataframe, stats) -> dataframe, where stats is a dictionary of
            common statistics on the training dataset (i.e., mean, std, median, etc. for each feature)
        :raises ValueError: If ``app_id`` is not a known app id.
        """
        # TRAIN_LENGTHS and TEST_LENGTHS share the same keys; use one table for both the check and the message.
        if app_id not in TRAIN_LENGTHS:
            raise ValueError(f'App ID must be one of {list(TRAIN_LENGTHS.keys())}')

        self.path = path
        self.app_id = app_id
        self.training = training

        # Filled in lazily by __getitem__ via load_data().
        self.inputs = None
        self.targets = None

        # Check callable first so a callable object whose truth value is False is still honoured.
        if callable(standardize) or standardize:
            # Both the user-supplied and the default scaler get statistics of the *training* split, so the
            # test split is scaled with the same parameters as the training split.
            with np.load(os.path.join(self.path, 'train', f'train_stats_{app_id}.npz')) as d:
                stats = dict(d)
            scaler = standardize if callable(standardize) else minmax_scaler
            self.standardize_fn = functools.partial(scaler, stats=stats)
        else:
            self.standardize_fn = None

    def load_data(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Read every CSV file of the selected app and split.

        :return: Tuple ``(inputs, targets)`` of equally long lists with one entry per file. Each input is a float32
            array of all CSV columns except the ``t`` index and ``Anomaly`` (standardized if configured). Each target
            is an int64 array with one label per row: all zeros for the training split, and for the test split 1
            wherever ``Anomaly`` is non-zero, else 0.
        """
        split = 'train' if self.training else 'test'
        load_path = os.path.join(self.path, split)

        # The first '_'-separated field of a file name is the app id. The trailing '_' keeps app 1 from
        # matching app 10's files.
        prefix = f'{self.app_id}_'
        candidates = TRAIN_FILES if self.training else TEST_FILES
        files = [f for f in candidates if f.startswith(prefix)]

        inputs, targets = [], []
        for file_name in files:
            data = pd.read_csv(os.path.join(load_path, file_name), index_col='t')

            if self.training:
                # Every training row is labelled 0; the Anomaly column is not consulted.
                target = np.zeros(len(data), dtype=np.int64)
            else:
                # Collapse any non-zero Anomaly value into the binary label 1.
                target = (data['Anomaly'].to_numpy() != 0).astype(np.int64)

            features = data.drop(columns=['Anomaly'])
            if self.standardize_fn is not None:
                features = self.standardize_fn(features)

            inputs.append(features.astype(np.float32).to_numpy())
            targets.append(target)

        return inputs, targets

    def __getitem__(self, item: int) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """
        This should return the time series of the dataset. I.e., if the dataset has 5 independent time-series,
        passing 0, ..., 4 as item should return these time series. The format is (inputs, targets), where inputs
        and targets are tuples of torch.Tensors.

        All files are loaded on the first call and kept in memory afterwards.

        :param item: Index of the time series to return. Negative indices are not supported.
        :return: ``((inputs,), (targets,))`` for the requested time series.
        :raises KeyError: If ``item`` is out of bounds. The raised exception is also an ``IndexError``.
        """
        if not (0 <= item < len(self)):
            raise _OutOfBoundsError('Out of bounds')

        if self.inputs is None or self.targets is None:
            self.inputs, self.targets = self.load_data()

        return (torch.as_tensor(self.inputs[item]),), (torch.as_tensor(self.targets[item]),)

    def __len__(self) -> int:
        """Number of independent time series in the selected app/split (taken from the static length tables)."""
        lengths = TRAIN_LENGTHS if self.training else TEST_LENGTHS
        return len(lengths[self.app_id])

    @property
    def seq_len(self) -> List[int]:
        """Lengths of the time series of the selected app/split, one entry per series."""
        lengths = TRAIN_LENGTHS if self.training else TEST_LENGTHS
        # Return a copy so callers cannot accidentally mutate the shared module-level table.
        return list(lengths[self.app_id])

    @property
    def num_features(self) -> int:
        """Number of input features (columns) per time step."""
        # Derived from the feature-name list so the two can never disagree (currently 19).
        return len(self.get_feature_names())

    @staticmethod
    def get_default_pipeline() -> Dict[str, Dict[str, Any]]:
        """Default transform pipeline for this dataset: mean-subsample by a factor of 5, then cache."""
        return {
            'subsample': {'class': 'SubsampleTransform', 'args': {'subsampling_factor': 5, 'aggregation': 'mean'}},
            'cache': {'class': 'CacheTransform', 'args': {}}
        }

    @staticmethod
    def get_feature_names() -> List[str]:
        """Names of the input features, presumably in CSV column order — TODO confirm against the processed files."""
        return ['driver_StreamingMetrics_streaming_lastCompletedBatch_processingDelay_value',
                'driver_StreamingMetrics_streaming_lastCompletedBatch_schedulingDelay_value',
                'driver_StreamingMetrics_streaming_lastCompletedBatch_totalDelay_value',
                '1_diff_driver_StreamingMetrics_streaming_totalCompletedBatches_value',
                '1_diff_driver_StreamingMetrics_streaming_totalProcessedRecords_value',
                '1_diff_driver_StreamingMetrics_streaming_totalReceivedRecords_value',
                '1_diff_driver_StreamingMetrics_streaming_lastReceivedBatch_records_value',
                '1_diff_driver_BlockManager_memory_memUsed_MB_value',
                '1_diff_driver_jvm_heap_used_value', '1_diff_node5_CPU_ALL_Idle%', '1_diff_node6_CPU_ALL_Idle%',
                '1_diff_node7_CPU_ALL_Idle%', '1_diff_node8_CPU_ALL_Idle%',
                '1_diff_avg_executor_filesystem_hdfs_write_ops_value', '1_diff_avg_executor_cpuTime_count',
                '1_diff_avg_executor_runTime_count', '1_diff_avg_executor_shuffleRecordsRead_count',
                '1_diff_avg_executor_shuffleRecordsWritten_count', '1_diff_avg_jvm_heap_used_value']
