import os
import pickle
import pandas as pd
import numpy as np

from .data import data_path, resource

# Only ``load`` is part of this module's public API.
__all__ = ['load']

# Pickle file that ``_parse`` writes and ``load`` reads back.
cache_tridaily = data_path(
    'japan_daily', 'tridaily', 'experiment.pickle'
)


def load():
    """Return the tridaily experiment data, building the cache on first use.

    ``_fetch`` is always called first. If no pickle exists yet at
    ``cache_tridaily``, ``_parse`` is run to generate it; the cached object
    is then unpickled and returned.

    Returns:
        The object that ``_parse`` pickled to ``cache_tridaily``.
    """
    _fetch()

    cache_missing = not os.path.exists(cache_tridaily)
    if cache_missing:
        # First use: generate the DataFrames and write the cache.
        _parse()

    # NOTE: the cache is produced locally by ``_parse``. Unpickling can run
    # arbitrary code, so never point ``cache_tridaily`` at untrusted files.
    with open(cache_tridaily, 'rb') as handle:
        experiment = pickle.load(handle)
    return experiment


def _fetch():
    """Ensure the raw ``japan_80s.csv`` data file is available locally.

    Hands the local target path and the Dropbox download URL to
    ``resource``. (Presumably ``resource`` only downloads when the target is
    missing -- confirm in ``.data``.)
    """
    # All data.
    csv_url = ('https://www.dropbox.com/s/kokgr6ekb0gh7wq/japan_80s.csv'
               '?dl=1')
    csv_target = data_path('japan_daily', 'japan_80s.csv')
    resource(target=csv_target, url=csv_url)


def _parse():
    """Build the tridaily train/test experiment from the raw CSV and cache it.

    Reads ``japan_80s.csv`` and keeps the rows dated 1980 or 1981. It then
    splits the timeline into consecutive three-day groups and holds out
    ``TAVG`` on the middle day of every group as the test target. The
    resulting dictionary is pickled to ``cache_tridaily``, from where
    ``load`` reads it.

    Keys of the pickled dictionary:

    * ``'all'``: the filtered, preprocessed DataFrame.
    * ``'train'``: list of per-group DataFrames of the unmasked 1980 data.
    * ``'train_1980'`` / ``'train_1981'``: per-group DataFrames with
      ``TAVG`` set to NaN on the held-out days.
    * ``'test_1980'`` / ``'test_1981'``: per-group DataFrames that keep only
      the held-out ``TAVG`` values; ``TAVG`` on training days and all of
      ``PRCP``, ``SNWD``, ``TMAX``, ``TMIN`` are NaN.
    * ``'dates_1980'`` / ``'dates_1981'``: the ``group`` id of each entry of
      ``'train_1980'`` / ``'train_1981'``.
    """
    raw = pd.read_csv(data_path('japan_daily', 'japan_80s.csv'))

    # Extract 1980 and 1981 data. ``.copy()`` makes ``df_`` an independent
    # frame, so the column assignments below neither warn nor touch ``raw``.
    df_ = raw[raw['date'].str[:4].isin(['1980', '1981'])].copy()

    # Rescale TAVG.
    # NOTE(review): presumably tenths of a degree -> degrees; confirm.
    df_['TAVG'] = df_['TAVG'] / 10

    # Convert dates to datetimes.
    df_['date'] = pd.to_datetime(df_['date'])

    # Create tridaily groups: ``group`` numbers consecutive three-day windows
    # counted from the first row's date, and ``time`` is the day (0, 1 or 2)
    # within the window.
    # NOTE(review): assumes the first row holds the earliest date, i.e. the
    # CSV is sorted by date -- confirm.
    n = 3
    days = (df_['date'] - df_['date'].iloc[0]).dt.days
    df_['group'] = days // n
    df_['time'] = days % n

    # Hold out TAVG on the middle day of each group for the test split; all
    # other days are training days.
    is_test = df_['time'] == 1

    train = df_.copy()
    test = df_.copy()

    # Write through ``.loc`` so the change lands in the frame itself; chained
    # indexing (``frame['TAVG'][idx] = ...``) is a silent no-op under pandas
    # Copy-on-Write.
    train.loc[is_test, 'TAVG'] = np.nan
    test.loc[~is_test, 'TAVG'] = np.nan

    # The test frames carry only the held-out TAVG targets.
    for obs in ['PRCP', 'SNWD', 'TMAX', 'TMIN']:
        test[obs] = np.nan

    train_1980 = train[train['date'].dt.year == 1980]
    train_1981 = train[train['date'].dt.year == 1981]
    test_1980 = test[test['date'].dt.year == 1980]
    test_1981 = test[test['date'].dt.year == 1981]

    # Train on all (unmasked) data from 1980.
    train_all_1980 = df_[df_['date'].dt.year == 1980]

    data = {'all': df_}
    splits = [
        ('train', train_all_1980),
        ('train_1980', train_1980),
        ('train_1981', train_1981),
        ('test_1980', test_1980),
        ('test_1981', test_1981),
    ]
    # Extract stations per group. Iterating the groupby yields groups in
    # sorted key order with every column (including ``group``) present.
    for name, frame in splits:
        data[name] = [group.copy() for _, group in frame.groupby('group')]

    # Extract the group id of each group.
    # NOTE(review): the keys say "dates" but the values are ``group`` ids,
    # not calendar dates -- confirm this is intended.
    data['dates_1980'] = [g.iloc[0]['group'] for g in data['train_1980']]
    data['dates_1981'] = [g.iloc[0]['group'] for g in data['train_1981']]

    # Save experiment data. Create the cache directory if needed. Write to a
    # temporary file and rename it into place, so an interrupted dump never
    # leaves a truncated pickle behind for ``load`` to fail on.
    cache_dir = os.path.dirname(cache_tridaily)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    tmp_path = str(cache_tridaily) + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(data, f)
    os.replace(tmp_path, cache_tridaily)
