import glob
import multiprocessing
import os
import pickle
import pdb
from itertools import product

import numpy as np
import pandas as pd

from .data import data_path, resource, dependency

__all__ = ['load']

# Pickle of the parsed recordings for every subject, nested as
# data[partition][subject_number]; written and re-read by `_parse`.
cache_data = data_path('eeg', 'data.pickle')
# Pickle of the experiment tuple (train, test, train_test, eval_test)
# written by `_parse` and returned by `load`.
cache_experiment = data_path('eeg', 'new_experiment.pickle')


def load():
    """Return the EEG experiment, building its cache on first use.

    Calls ``_fetch`` to obtain and unpack the raw archives. If
    ``cache_experiment`` does not exist yet, it is generated by ``_parse``.

    Returns:
        The unpickled contents of ``cache_experiment``, i.e. the tuple
        ``(train, test, train_test, eval_test)`` as written by ``_parse``.
    """
    _fetch()

    cache_missing = not os.path.exists(cache_experiment)
    if cache_missing:
        _parse()

    with open(cache_experiment, 'rb') as handle:
        experiment = pickle.load(handle)
    return experiment


def _fetch():
    """Download and unpack the UCI SMNI EEG training and test archives.

    For each partition ('train', then 'test') the ``SMNI_CMI_<PARTITION>``
    archive is registered as a download with ``resource``. It is then
    registered with ``dependency`` as the source of
    ``data_path('eeg', <partition>)``. The shell commands passed along untar
    the archive, rename the extracted directory to the partition name, and
    gunzip every ``*gz`` file inside it.
    """
    base_url = ('https://archive.ics.uci.edu/ml/'
                'machine-learning-databases/eeg-mld/')

    for partition in ('train', 'test'):
        stem = 'SMNI_CMI_' + partition.upper()
        archive = stem + '.tar.gz'

        resource(target=data_path('eeg', archive),
                 url=base_url + archive)
        dependency(target=data_path('eeg', partition),
                   source=data_path('eeg', archive),
                   commands=['tar xzf ' + archive,
                             'mv ' + stem + ' ' + partition,
                             'find ' + partition + ' | grep gz$ | xargs gunzip'])


def _parse_trial(fp):
    """Parse a single EEG trial file into a time-indexed data frame.

    Lines starting with ``#`` are skipped (``np.genfromtxt`` default). Every
    other line is split on single spaces. Field 1 is used as the site name,
    field 2 as the sample index and field 3 as the measured value.

    Args:
        fp: Path of the trial file (anything ``np.genfromtxt`` accepts).

    Returns:
        dict: ``{'df': DataFrame}`` with one float column per site, sorted by
        site name, and an index named ``'time'`` equal to sample index / 256.

    Raises:
        ValueError: If the file contains no data rows.
    """
    records = np.genfromtxt(fp, delimiter=' ', dtype=str)
    if records.size == 0:
        raise ValueError('Trial file {!r} contains no data.'.format(fp))
    # A file with exactly one data row is returned as a 1-D array; promote it
    # so that iterating below always yields one record per row.
    records = np.atleast_2d(records)

    # Group (sample index, value) pairs by site, preserving file order.
    sites = {}
    for record in records:
        sites.setdefault(record[1], []).append((record[2], record[3]))

    site_names = tuple(sorted(sites))

    # Build one value column per site. The time axis is taken from the last
    # site, assuming every site was sampled at the same sample indices.
    time = None
    columns = []
    for name in site_names:
        samples = np.array(sites[name], dtype=float)
        time = samples[:, 0] / 256.  # Sampled at 256 Hz.
        columns.append(samples[:, 1])

    return {'df': pd.DataFrame(np.stack(columns, axis=1),
                               index=pd.Index(time, name='time'),
                               columns=site_names)}


def _extract_trials(fps, processes=8):
    """Parse a collection of trial files in parallel.

    Args:
        fps: Iterable of trial file paths. Each path must contain ``'.rd.'``
            followed by the integer trial number.
        processes: Number of worker processes used for parsing.

    Returns:
        dict: Maps trial number to the dict returned by ``_parse_trial``, in
        increasing trial-number order. Empty when ``fps`` is empty.

    Raises:
        AssertionError: If the parsed trials do not all share the same site
            names.
    """
    # Materialise ``fps`` once, because it is needed for both the numbers and
    # the parsing. Previously a one-shot iterator was consumed by
    # ``pool.map``, leaving nothing to zip. Sorting by trial number makes the
    # result order independent of the (filesystem-dependent) listing order.
    numbered = sorted((int(fp.split('.rd.')[1]), fp) for fp in fps)
    if not numbered:
        # Nothing to parse; avoid spawning a pool and the vacuous
        # "inconsistent site names" failure.
        return {}
    trial_numbers = [number for number, _ in numbered]
    paths = [fp for _, fp in numbered]

    # Parse trials.
    with multiprocessing.Pool(processes=processes) as pool:
        parsed_trials = pool.map(_parse_trial, paths)

    # Check that all lists of site names are the same.
    if len({tuple(d['df'].columns) for d in parsed_trials}) != 1:
        raise AssertionError('Site names are inconsistent between trials.')

    return dict(zip(trial_numbers, parsed_trials))


def _parse():
    """Build the experiment cache that ``load`` reads.

    Obtains the parsed recordings, either from ``cache_data`` or by parsing
    every subject and caching the result. It then masks the trials of
    subject 337 and pickles ``(train, test, train_test, eval_test)`` to
    ``cache_experiment``.

    Raises:
        FileNotFoundError: If a subject directory contains no trial files.
        AssertionError: If site names are inconsistent between trials or
            between subjects.
    """
    data = _load_parsed_data()
    experiment = _make_experiment(data)

    # Save experiment data.
    with open(cache_experiment, 'wb') as f:
        pickle.dump(experiment, f)


def _load_parsed_data():
    """Return the parsed recordings, parsing and caching them if needed."""
    if os.path.exists(cache_data):
        # Local cache produced below by this module (trusted input).
        with open(cache_data, 'rb') as f:
            return pickle.load(f)

    print('Parsing EEG data. This may take a while.')
    data = _parse_all_subjects()

    # Dump extracted data to file.
    with open(cache_data, 'wb') as f:
        pickle.dump(data, f)
    return data


def _parse_all_subjects():
    """Parse every trial of every subject in both partitions.

    Returns:
        dict: ``data[partition][subject_n]`` is ``{'type': subject_type,
        'trials': {trial_number: trial}}``. ``subject_type`` is ``'a'`` or
        ``'c'``, and every trial dict carries a ``'label'`` of
        ``(subject_n, subject_type)``.

    Raises:
        FileNotFoundError: If a subject directory contains no trial files.
        AssertionError: If site names differ between trials or subjects.
    """
    subjects = ([('c', n) for n in [337, 338, 339, 340, 341,
                                     342, 344, 345, 346, 347]] +
                [('a', n) for n in [364, 365, 368, 369, 370,
                                     371, 372, 375, 377, 378]])
    partitions = ['train', 'test']
    subject_dir_format = data_path('eeg', '{partition}',
                                   'co2{type}{subject_n:07d}')

    data = {partition: {} for partition in partitions}
    for partition, (subject_type, subject_n) in product(partitions, subjects):
        subject_dir = subject_dir_format.format(partition=partition,
                                                type=subject_type,
                                                subject_n=subject_n)

        # Sort so that trial order, and therefore the seeded masks built from
        # it, does not depend on directory-listing order.
        trial_files = sorted(glob.glob(os.path.join(subject_dir, '*.rd.*')))
        if not trial_files:
            raise FileNotFoundError(
                'No trial files found in {}.'.format(subject_dir))
        trials = _extract_trials(trial_files)

        # Tag data series with subject number and subject type.
        for trial in trials.values():
            trial['label'] = (subject_n, subject_type)

        data[partition][subject_n] = {'type': subject_type, 'trials': trials}

    # All trials of all subjects must share a single tuple of site names.
    site_sets = {tuple(trial['df'].columns)
                 for by_subject in data.values()
                 for subject in by_subject.values()
                 for trial in subject['trials'].values()}
    if len(site_sets) != 1:
        raise AssertionError('Site names are inconsistent between subjects.')
    return data


def _make_experiment(data, subject_n=337, seed=1997):
    """Mask one subject's trials to form a missing-data experiment.

    Every mask below is a 256 x 64 array, so each trial frame is assumed to
    have 256 rows (time) and 64 columns (sites). TODO confirm for other
    subjects.

    Returns:
        tuple: ``(train, test, train_test, eval_test)``.

        - ``train``: copies of the subject's training trials, with their
          noise entries set to NaN.
        - ``test``: unmodified copies of the test trials.
        - ``train_test``: the test trials with every masked entry set to NaN.
        - ``eval_test``: the test trials with only the masked entries kept
          (everything else NaN).
    """
    train = [trial['df'].copy()
             for trial in data['train'][subject_n]['trials'].values()]
    test = [trial['df'].copy()
            for trial in data['test'][subject_n]['trials'].values()]

    # Private generator: draws the same numbers as seeding the global RNG
    # with ``seed`` would, without clobbering NumPy's global random state.
    rng = np.random.RandomState(seed)

    # Evaluation mask: each block of 4 consecutive rows is hidden across all
    # 64 columns with probability 0.25.
    eval_masks = [rng.binomial(n=1, p=0.25, size=(64, 1)).repeat(
        4, axis=0).repeat(64, axis=1) for _ in range(len(test))]

    # Noise masks: per column, each run of 16 consecutive rows is hidden
    # with probability 0.25. Draw order (eval, train, test) matches the
    # original, so the generated masks are unchanged.
    train_masks = [rng.binomial(n=1, p=0.25, size=(16, 64)).repeat(
        16, axis=0) for _ in range(len(train))]
    test_masks = [rng.binomial(n=1, p=0.25, size=(16, 64)).repeat(
        16, axis=0) + mask for mask in eval_masks]

    # Mask train and test separately: a single zip over both lists would
    # silently skip trials whenever their counts differ.
    for train_df, train_m in zip(train, train_masks):
        train_df.mask(train_m > 0, np.nan, inplace=True)

    train_test = [test_df.mask(test_m > 0, np.nan)
                  for test_df, test_m in zip(test, test_masks)]
    eval_test = [test_df.mask(test_m == 0, np.nan)
                 for test_df, test_m in zip(test, test_masks)]

    return train, test, train_test, eval_test
