import ast
import csv
import errno
import functools
import hashlib
import json
import logging
import os
import shutil
import tempfile
import urllib.request
import zipfile
from typing import Tuple, List, Optional, Union, Dict, Any

import numpy as np
import torch

from utils.utils import getitem
from .dataset import DatasetMixin
from utils.metadata import DATA_DIRECTORY, PROJECT_ROOT

# Telemanom archive holding the raw SMAP/MSL train/test ``.npy`` sequences.
DATASET_URL = 'https://s3-us-west-2.amazonaws.com/telemanom/data.zip'
# CSV listing every channel together with its labelled anomaly ranges.
LABELS_URL = 'https://raw.githubusercontent.com/khundman/telemanom/master/labeled_anomalies.csv'
# Chunk size (16 MiB) used when streaming downloads and hashing files.
BUFFER_SIZE = 16 * 1024 * 1024
# Expected SHA-256 hex digest of the downloaded ``data.zip``.
ZIP_CHECKSUM = 'b4d66deb492d9b0a353b51879152687ed9313897e8e19320d2dc853d738ed8a7'
# JSON file mapping file paths (relative to the data directory) to their expected SHA-256 digests.
FILE_CHECKSUMS = os.path.join(PROJECT_ROOT, 'data', 'smap', 'smap_checksums.json')


class SMAPDownloader:
    """Download, verify and extract the Telemanom SMAP/MSL data into a local directory.

    The archive at ``DATASET_URL`` is checked against ``ZIP_CHECKSUM`` before extraction.
    Its ``train``/``test`` folders and the label CSV from ``LABELS_URL`` are then verified
    against the per-file SHA-256 digests stored in the JSON file ``FILE_CHECKSUMS``.
    """

    def __init__(self, data_path: str = os.path.join(DATA_DIRECTORY, 'smap')):
        # Directory that receives the extracted ``train``/``test`` folders and the label CSV.
        self.data_path = data_path

    @staticmethod
    def compute_sha256(file, buffer_size: int = BUFFER_SIZE) -> str:
        """Return the hex SHA-256 digest of the remaining content of ``file``.

        Reading starts at the file's current position and runs to EOF. Callers must seek
        beforehand if they want the digest of the whole file.

        :param file: Binary file-like object to hash.
        :param buffer_size: Number of bytes read per chunk.
        :return: Hexadecimal digest string.
        """
        hasher = hashlib.sha256()
        while True:
            chunk = file.read(buffer_size)
            if not chunk:
                break
            hasher.update(chunk)
        return hasher.hexdigest()

    def check_existing_files(self) -> bool:
        """Check that every file listed in ``FILE_CHECKSUMS`` exists in ``data_path`` with the expected digest.

        :return: ``True`` if all listed files are present and match; ``False`` otherwise.
            A digest mismatch is additionally logged as an error.
        """
        if not os.path.isdir(self.data_path):
            return False

        with open(FILE_CHECKSUMS, mode='r') as f:
            checksums = json.load(f)

        for file, expected in checksums.items():
            file_path = os.path.join(self.data_path, file)
            if not os.path.isfile(file_path):
                return False

            with open(file_path, mode='rb') as f:
                actual = self.compute_sha256(f)
            if actual != expected:
                logging.error(f'SHA-256 checksum of file {file} is not correct! Expected "{expected}", got "{actual}".')
                return False

        return True

    @staticmethod
    def download_to_file(url, file, buffer_size: int = BUFFER_SIZE):
        """Stream the resource at ``url`` into the binary file object ``file``, then rewind ``file``.

        :param url: URL to download.
        :param file: Writable (and seekable) binary file object.
        :param buffer_size: Copy chunk size in bytes.
        """
        logging.info(f'Downloading "{url}"...')
        with urllib.request.urlopen(url) as response:
            shutil.copyfileobj(response, file, buffer_size)
        # Rewind so callers can immediately hash or read what was just written.
        file.seek(0)

    def download_data(self):
        """Download and extract the dataset into ``data_path`` unless verified files already exist.

        :raises RuntimeError: If the downloaded archive, or any file afterwards, fails its SHA-256 check.
        """
        if self.check_existing_files():
            return

        os.makedirs(self.data_path, exist_ok=True)

        prefix = 'data/'
        with tempfile.TemporaryFile() as tmp:
            self.download_to_file(DATASET_URL, tmp)

            if self.compute_sha256(tmp) != ZIP_CHECKSUM:
                raise RuntimeError('The SHA-256 Hash of the downloaded zip is not correct!')
            # Hashing consumed the file; rewind before handing it to zipfile.
            tmp.seek(0)

            logging.info('Extracting data...')
            with zipfile.ZipFile(tmp, 'r') as archive:
                for info in archive.infolist():
                    if info.filename.startswith(('data/test', 'data/train')):
                        # Strip only the leading "data/" so the folders land directly in
                        # data_path/train and data_path/test.
                        info.filename = info.filename[len(prefix):]
                        archive.extract(info, self.data_path)
            logging.info('Done!')

        with open(os.path.join(self.data_path, 'labeled_anomalies.csv'), mode='wb') as f:
            self.download_to_file(LABELS_URL, f)

        logging.info('Checking SHA-256 checksum of downloaded files...')
        if not self.check_existing_files():
            logging.critical('FAILURE!')
            raise RuntimeError('The SHA-256 Hash of the downloaded files is not correct!')


class SMAPDataset(torch.utils.data.Dataset, DatasetMixin):
    """SMAP channels of the Telemanom dataset, one item per channel.

    Item ``i`` is returned as ``((data,), (labels,))``:
    - ``data`` is a float32 tensor loaded from ``<data_path>/{train,test}/<chan_id>.npy``.
    - ``labels`` is an int64 tensor that is 1 at anomalous time steps. It is all zeros for
      the training split.

    Channels are ordered by channel id.
    """

    def __init__(self, data_path: str = os.path.join(DATA_DIRECTORY, 'smap'),
                 training: bool = True, download: bool = True):
        """
        :param data_path: Directory holding the ``train``/``test`` folders and ``labeled_anomalies.csv``.
        :param training: Use the training split if ``True``, otherwise the test split.
        :param download: Download and verify the data into ``data_path`` if it is not already there.
        """
        super().__init__()
        self.data_path = data_path
        self.training = training
        # Download into the same directory that load_data() later reads from.
        self.downloader = SMAPDownloader(data_path)

        # TODO: Perhaps we should consider each sequence in SMAP as a separate Dataset, because each sequence has a
        #   different feature and they are not generated by the same process

        if download:
            self.downloader.download_data()

        # Loaded lazily on first __getitem__ call.
        self.data = self.labels = None

    def load_data(self) -> Tuple[List[np.ndarray], ...]:
        """Load every SMAP sequence of the selected split together with its point-wise labels.

        :return: ``(data, labels)``, two equally long lists ordered by channel id. ``data``
            holds float32 arrays and ``labels`` holds int64 arrays.
        """
        with open(os.path.join(self.data_path, 'labeled_anomalies.csv'), 'r') as file:
            rows = list(csv.reader(file, delimiter=','))[1:]  # drop the header row
        rows = sorted(rows, key=functools.partial(getitem, item=0))
        # Note: the OmniAnomaly code excludes channel P-2 for some reason; all SMAP channels are kept here.
        data_info = [row for row in rows if row[1] == 'SMAP']

        category = 'train' if self.training else 'test'
        data = [np.load(os.path.join(self.data_path, category, row[0] + '.npy')).astype(np.float32)
                for row in data_info]

        if self.training:
            # The training split carries no labelled anomalies.
            labels = [np.zeros((seq.shape[0],), dtype=np.int64) for seq in data]
        else:
            labels = []
            for row in data_info:
                # Column 2 is a Python-literal list of [start, end] index pairs (end inclusive);
                # the last column is the sequence length.
                anomalies = ast.literal_eval(row[2])
                label = np.zeros([int(row[-1])], dtype=np.int64)
                for anomaly in anomalies:
                    label[anomaly[0]:anomaly[1] + 1] = 1
                labels.append(label)

        return data, labels

    def __len__(self) -> int:
        return 55

    @property
    def seq_len(self) -> Union[int, List[int]]:
        """Per-channel sequence lengths of the selected split, in channel-id order."""
        if self.training:
            return [2880, 2648, 2736, 2690, 705, 682, 2879, 762, 762, 2435, 2849, 2611, 312, 1490, 2880, 2880, 2833,
                    2561, 2594, 2583, 2602, 2583, 2880, 2880, 2880, 2880, 2880, 2880, 2880, 2880, 2880, 2880, 2769,
                    2880, 2880, 2869, 2861, 2880, 2820, 2478, 2624, 2551, 2881, 2446, 2872, 2821, 2821, 2855, 2609,
                    2853, 2874, 2818, 2875, 2855, 2876]

        return [8640, 7914, 8205, 8080, 4693, 4453, 8631, 8375, 8434, 8044, 8509, 7431, 7918, 7663, 8595, 8640, 8473,
                7628, 7884, 7642, 7874, 7406, 8516, 8505, 8514, 8512, 8640, 8532, 8307, 8354, 8294, 8300, 8310, 8532,
                8302, 8584, 8626, 8376, 8469, 7361, 7907, 7632, 8640, 8029, 8505, 8209, 8209, 8493, 7783, 8071, 7244,
                7331, 8612, 8625, 8579]

    @property
    def num_features(self) -> Union[int, Tuple[int, ...]]:
        return 25

    def __getitem__(self, item: int) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Return ``((data,), (labels,))`` for channel ``item``, loading all data on first access."""
        if self.data is None:
            self.data, self.labels = self.load_data()

        return (torch.as_tensor(self.data[item]),), (torch.as_tensor(self.labels[item]),)

    @staticmethod
    def get_default_pipeline() -> Dict[str, Dict[str, Any]]:
        # TODO
        return {}

    @staticmethod
    def get_feature_names() -> List[str]:
        # Feature names not really specified
        return [''] * 25


class MSLDataset(torch.utils.data.Dataset, DatasetMixin):
    """MSL channels of the Telemanom dataset, one item per channel.

    Item ``i`` is returned as ``((data,), (labels,))``:
    - ``data`` is a float32 tensor loaded from ``<data_path>/{train,test}/<chan_id>.npy``.
    - ``labels`` is an int64 tensor that is 1 at anomalous time steps. It is all zeros for
      the training split.

    Channels are ordered by channel id. The MSL channels ship in the same archive and label
    CSV as SMAP, which is why the default directory is the ``smap`` one.
    """

    def __init__(self, data_path: str = os.path.join(DATA_DIRECTORY, 'smap'), training: bool = True,
                 download: bool = True):
        """
        :param data_path: Directory holding the ``train``/``test`` folders and ``labeled_anomalies.csv``.
        :param training: Use the training split if ``True``, otherwise the test split.
        :param download: Download and verify the data into ``data_path`` if it is not already there.
        """
        super().__init__()

        self.data_path = data_path
        self.training = training
        # Download into the same directory that load_data() later reads from.
        self.downloader = SMAPDownloader(data_path)

        if download:
            self.downloader.download_data()

        # Loaded lazily on first __getitem__ call.
        self.data = self.labels = None

    def load_data(self) -> Tuple[List[np.ndarray], ...]:
        """Load every MSL sequence of the selected split together with its point-wise labels.

        :return: ``(data, labels)``, two equally long lists ordered by channel id. ``data``
            holds float32 arrays and ``labels`` holds int64 arrays.
        """
        with open(os.path.join(self.data_path, 'labeled_anomalies.csv'), 'r') as file:
            rows = list(csv.reader(file, delimiter=','))[1:]  # drop the header row
        rows = sorted(rows, key=functools.partial(getitem, item=0))
        data_info = [row for row in rows if row[1] == 'MSL']

        category = 'train' if self.training else 'test'
        data = [np.load(os.path.join(self.data_path, category, row[0] + '.npy')).astype(np.float32)
                for row in data_info]

        if self.training:
            # The training split carries no labelled anomalies.
            labels = [np.zeros((seq.shape[0],), dtype=np.int64) for seq in data]
        else:
            labels = []
            for row in data_info:
                # Column 2 is a Python-literal list of [start, end] index pairs (end inclusive);
                # the last column is the sequence length.
                anomalies = ast.literal_eval(row[2])
                label = np.zeros([int(row[-1])], dtype=np.int64)
                for anomaly in anomalies:
                    label[anomaly[0]:anomaly[1] + 1] = 1
                labels.append(label)

        return data, labels

    def __len__(self) -> int:
        return 27

    @property
    def seq_len(self) -> Union[int, List[int]]:
        """Per-channel sequence lengths of the selected split, in channel-id order."""
        if self.training:
            return [2158, 764, 3675, 2074, 1451, 2244, 2598, 2511, 3342, 2209, 2208, 2037, 2076, 2032, 1565, 1587,
                    4308, 3969, 2880, 3682, 926, 1145, 1145, 2272, 2272, 748, 439]

        return [2264, 2051, 2625, 2158, 2191, 3422, 3922, 5054, 2487, 2277, 2277, 2127, 2038, 2303, 2049, 2156,
                6100, 3535, 6100, 2856, 1827, 2430, 2430, 2217, 2218, 1519, 1096]

    @property
    def num_features(self) -> Union[int, Tuple[int, ...]]:
        return 55

    def __getitem__(self, item: int) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Return ``((data,), (labels,))`` for channel ``item``, loading all data on first access."""
        if self.data is None:
            self.data, self.labels = self.load_data()

        return (torch.as_tensor(self.data[item]),), (torch.as_tensor(self.labels[item]),)

    @staticmethod
    def get_default_pipeline() -> Dict[str, Dict[str, Any]]:
        # TODO
        return {}

    @staticmethod
    def get_feature_names() -> List[str]:
        # Feature names not really specified
        return [''] * 55
