import os
import pickle
import pandas as pd
import numpy as np

from .data import data_path, resource

__all__ = ['load']


# Seed the global NumPy RNG once, at import time, so that the random
# per-month draw made in ``_parse`` (via ``np.random.randint``) is
# reproducible. Keep the seed fixed.
np.random.seed(seed=1)

# Location of the pickled experiment cache. Despite the historical name, the
# cached data is grouped by month (note the 'monthly' path component).
cache_weekly = data_path('japan_daily', 'monthly', 'experiment.pickle')


def load():
    """Load the Japan 1980/1981 experiment data.

    First calls ``_fetch`` to hand the raw CSV's location to ``resource``.
    Then, if the pickled cache at ``cache_weekly`` does not exist yet, builds
    it with ``_parse``.

    Returns:
        The unpickled contents of ``cache_weekly`` (the dictionary written by
        ``_parse``).
    """
    _fetch()

    # Generate the DataFrames only when no cache file exists yet.
    cache_missing = not os.path.exists(cache_weekly)
    if cache_missing:
        _parse()

    with open(cache_weekly, 'rb') as cache_file:
        experiment = pickle.load(cache_file)
    return experiment


def _fetch():
    """Register the raw ``japan_80s.csv`` file with ``resource``.

    Passes the local target path and the download URL to ``resource``, which
    presumably downloads the file when it is missing (behavior defined in
    ``.data`` — confirm there).
    """
    # All data.
    csv_target = data_path('japan_daily', 'japan_80s.csv')
    csv_url = 'https://www.dropbox.com/s/kokgr6ekb0gh7wq/japan_80s.csv?dl=1'
    resource(target=csv_target, url=csv_url)


def _split_by_group(frame):
    """Return an independent copy of each ``group`` of *frame*, in ascending group order."""
    return [group_df.copy() for _, group_df in frame.groupby('group')]


def _parse():
    """Build the 1980/1981 experiment data and pickle it to ``cache_weekly``.

    Reads ``japan_80s.csv``, keeps the rows dated 1980 or 1981, and derives
    train/test frames. For every month, days 8-21 of ``TAVG`` are held out:
    they are NaN in the train frames and are the only non-NaN ``TAVG`` values
    in the test frames. The pickled dictionary holds:

    - ``'all'``: the filtered, preprocessed frame;
    - ``'train'``: unmasked 1980 data, one frame per month;
    - ``'train_1980'``, ``'train_1981'``, ``'test_1980'``, ``'test_1981'``:
      the masked frames, one frame per month;
    - ``'dates_1980'``, ``'dates_1981'``: the ``group`` id of each monthly
      frame in ``'train_1980'`` / ``'train_1981'``.

    Uses the global NumPy RNG, which this module seeds at import time.

    Raises:
        ValueError: If the CSV contains no rows dated 1980 or 1981.
    """
    df = pd.read_csv(data_path('japan_daily', 'japan_80s.csv'))

    # Extract 1980 and 1981 data. ``.copy()`` makes ``df_`` an independent
    # frame, so the column assignments below do not act on a view of ``df``.
    df_ = df[df['date'].str[:4].isin(['1980', '1981'])].copy()
    if df_.empty:
        raise ValueError('No 1980 or 1981 rows found in japan_80s.csv.')

    # Rescale TAVG.
    df_['TAVG'] = df_['TAVG'] / 10

    # Convert dates to datetimes.
    df_['date'] = pd.to_datetime(df_['date'])

    # Create monthly groups numbered consecutively from January of the
    # earliest year (1..12 for that year, 13..24 for the next). Using the
    # minimum year rather than the first row makes this independent of row order.
    first_year = df_['date'].dt.year.min()
    df_['group'] = (df_['date'].dt.month
                    + (df_['date'].dt.year - first_year) * 12)
    df_['time'] = df_['date'].dt.day

    # Hold out the middle two weeks (days 8-21) of every month. Each month's
    # held-out rows go to one randomly drawn entry of ``observations``. Only
    # TAVG is listed at present, so the draw always selects it, but it still
    # advances the global NumPy RNG once per month.
    observations = ['TAVG']
    test_idx = [[] for _ in observations]
    for _, monthly_df in df_.groupby('group'):
        mid_df = monthly_df[monthly_df['time'].between(8, 21)]
        test_idx[np.random.randint(len(observations))] += list(mid_df.index)

    train = df_.copy()
    test = df_.copy()
    for obs, test_idx_ in zip(observations, test_idx):
        # Train indices are all those that aren't in the test indices.
        train_idx_ = df_.index.difference(test_idx_)
        # Write through ``.loc`` on the frame itself. Chained indexing such as
        # ``train[obs][idx] = ...`` may modify a temporary copy and be lost
        # (it always is under pandas Copy-on-Write).
        train.loc[test_idx_, obs] = np.nan
        test.loc[train_idx_, obs] = np.nan

    # Blank out all other measurements in the test frame.
    for obs in ['PRCP', 'SNWD', 'TMIN', 'TMAX']:
        test[obs] = np.nan

    # ``train``/``test`` are copies of ``df_``, so they share its index and
    # these masks align with all three frames.
    is_1980 = df_['date'].dt.year == 1980
    is_1981 = df_['date'].dt.year == 1981

    # Extract the per-month frames.
    data = {
        'all': df_,
        # Train on all (unmasked) data from 1980.
        'train': _split_by_group(df_[is_1980]),
        'train_1980': _split_by_group(train[is_1980]),
        'train_1981': _split_by_group(train[is_1981]),
        'test_1980': _split_by_group(test[is_1980]),
        'test_1981': _split_by_group(test[is_1981]),
    }

    # Record the group id of each monthly frame.
    data['dates_1980'] = [x.iloc[0]['group'] for x in data['train_1980']]
    data['dates_1981'] = [x.iloc[0]['group'] for x in data['train_1981']]

    # Save experiment data. Dump to a temporary file and rename it into place,
    # so an interrupted dump cannot leave a truncated cache file that
    # ``load`` (which only checks for existence) would later trust.
    os.makedirs(os.path.dirname(cache_weekly), exist_ok=True)
    tmp_path = str(cache_weekly) + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(data, f)
    os.replace(tmp_path, cache_weekly)
