from collections import OrderedDict
from typing import List, Tuple, Optional, Union, Any, Dict

import numpy as np
import pandas as pd

import torch

import os

from utils.metadata import DATA_DIRECTORY
from .dataset import DatasetMixin


class TEPDataset(torch.utils.data.Dataset, DatasetMixin):
    """Tennessee Eastman Process (TEP) dataset read from per-fault CSV files.

    Each item is one simulation run of one fault. A flat index ``item`` is
    decomposed as ``fault_position * len(self.runs) + run_position``:
    ``fault_position`` indexes the selected faults in ascending order, and
    ``run_position`` indexes ``self.runs`` in the order given.

    All runs of one fault live in a single CSV file:
    ``TEP_FaultFree_<split>.csv`` for fault 0 and
    ``TEP_Faulty_<split>_<NN>.csv`` otherwise. Parsed and preprocessed frames
    are kept in a bounded FIFO cache keyed by ``(fault, training)``, so items
    of the same fault do not re-read the file.

    Args:
        path: Directory holding the CSV files and, when ``standardize`` is
            set, ``TEP_FaultFree_Training_stats.npz``.
        faults: Fault number(s) to include; ``None`` selects 0..20.
        runs: Simulation run number(s) to include; ``None`` selects 1..500.
        training: Read the ``Training`` files if True, else ``Testing``.
        standardize: Subtract ``mean`` and divide by ``std``, both loaded
            from ``TEP_FaultFree_Training_stats.npz``.
        cache_size: Maximum number of fault frames held in memory. Values
            ``<= 0`` disable caching.
    """

    class _OutOfBoundsError(IndexError, ValueError):
        """Raised for an index outside ``[0, len(dataset))``.

        Derives from IndexError so that ``for x in dataset`` (Python's legacy
        sequence-iteration protocol) terminates cleanly. Derives from
        ValueError so that callers catching the previously raised ValueError
        keep working.
        """

    def __init__(self, path: str = os.path.join(DATA_DIRECTORY, 'TEP_harvard'),
                 faults: Optional[Union[int, List[int]]] = None,
                 runs: Optional[Union[int, List[int]]] = None,
                 training: bool = True, standardize: bool = True, cache_size: int = 21):
        self.cache = OrderedDict()  # (fault, training) -> preprocessed DataFrame
        self.path = path
        self.training = training
        self.standardize = standardize
        self.cache_size = cache_size

        # Stored as a set (duplicates collapse); __getitem__ orders it ascending.
        self.faults = set(self._as_index_list(faults, range(0, 20 + 1)))
        # Stored as a list so __getitem__ can index it positionally.
        self.runs = self._as_index_list(runs, range(1, 500 + 1))

        if self.standardize:
            with np.load(os.path.join(self.path, 'TEP_FaultFree_Training_stats.npz')) as d:
                stats = dict(d)
            self.mean = stats['mean']
            self.std = stats['std']

    @staticmethod
    def _as_index_list(value: Optional[Union[int, List[int]]], default) -> List[Any]:
        """Normalize a ``None`` / single-int / iterable selection into a list.

        ``None`` expands to ``list(default)``. A single integer (Python or
        NumPy) becomes a one-element list. Any other iterable is copied into
        a list unchanged.
        """
        if value is None:
            return list(default)
        if isinstance(value, (int, np.integer)):
            return [int(value)]
        return list(value)

    def _read_fault_frame(self, fault: int) -> pd.DataFrame:
        """Read the CSV for ``fault``, relabel warm-up rows, scale and cast features."""
        split = 'Training' if self.training else 'Testing'
        if fault == 0:
            fname = f'TEP_FaultFree_{split}.csv'
        else:
            fname = f'TEP_Faulty_{split}_{fault:02d}.csv'
        data = pd.read_csv(os.path.join(self.path, fname))

        # Set correct label for the first few samples, where the fault did not
        # occur yet: rows whose 'sample' index is <= warmup become class 0.
        warmup = 20 if self.training else 160
        data.loc[data['sample'] <= warmup, 'faultNumber'] = 0

        # Assumes the first three columns are metadata and the rest are
        # features (the code reads 'faultNumber', 'simulationRun' and
        # 'sample') — TODO confirm against the CSV header.
        features = data.columns[3:]
        if self.standardize:
            data[features] -= self.mean
            data[features] /= self.std
        data[features] = data[features].astype(np.float32)
        return data

    def load_data(self, fault: int, runs: Optional[Union[int, List[int]]] = None) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(inputs, targets)`` for the selected runs of one fault.

        Args:
            fault: Fault number; 0 selects the fault-free file.
            runs: Run number(s) to keep (matched against ``simulationRun``);
                ``None`` keeps runs 1..500.

        Returns:
            ``inputs``: the feature columns (all columns after the first three)
            of the kept rows, in file order, cast to float32.
            ``targets``: the ``faultNumber`` column for the same rows, with
            warm-up rows relabeled as 0.
        """
        runs = self._as_index_list(runs, range(1, 500 + 1))
        key = (fault, self.training)

        data = self.cache.get(key)
        if data is None:
            data = self._read_fault_frame(fault)
            # A non-positive cache_size disables caching entirely; this also
            # avoids calling popitem() on an empty OrderedDict (KeyError).
            if self.cache_size > 0:
                while len(self.cache) >= self.cache_size:
                    # Evict the oldest insertion (FIFO; hits do not reorder).
                    self.cache.popitem(last=False)
                self.cache[key] = data.copy()

        # Boolean filtering yields a new frame, so the cached frame is never mutated below.
        data = data.loc[data['simulationRun'].isin(runs)]

        targets = data['faultNumber'].to_numpy()
        inputs = data[data.columns[3:]].to_numpy()
        return inputs, targets

    def __getitem__(self, item: int) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Return ``((inputs,), (targets,))`` for one (fault, run) pair.

        Raises:
            IndexError: (also a ValueError) if ``item`` is outside ``[0, len(self))``.
        """
        if not (0 <= item < len(self)):
            raise self._OutOfBoundsError('Out of bounds')

        # Translate the flat index into positions within the fault and run
        # selections. The positions must then be mapped to the actual fault
        # and run numbers; using them directly would ignore the user's
        # `faults` / `runs` selection.
        fault_pos, run_pos = divmod(item, len(self.runs))
        fault = sorted(self.faults)[fault_pos]
        run = self.runs[run_pos]

        inputs, targets = self.load_data(fault, run)

        return (torch.as_tensor(inputs),), (torch.as_tensor(targets),)

    def __len__(self) -> int:
        """Number of (fault, run) pairs in the selection."""
        return len(self.faults) * len(self.runs)

    @property
    def seq_len(self) -> Optional[int]:
        """Sequence length reported for the split: 500 when training, else 960."""
        return 500 if self.training else 960

    @property
    def num_features(self) -> int:
        """Number of feature columns (hard-coded to 52)."""
        return 52

    @staticmethod
    def get_default_pipeline() -> Dict[str, Dict[str, Any]]:
        """Return the default preprocessing pipeline (empty for this dataset)."""
        return {}
