"""
Load data for Monash benchmarks from monash_data.py
"""
import numpy as np
import matplotlib.pyplot as plt

from dataloaders.monash_data import MonashDataset


def load_data(config_dataset, config_loader, args=None):
    """Build the Monash dataset and its train / val / test dataloaders.

    Args:
        config_dataset: Mapping of keyword arguments forwarded to
            ``MonashDataset``. It may contain a ``data_dir`` entry to
            override the default dataset location.
        config_loader: Mapping of keyword arguments forwarded to the
            dataset's ``train_dataloader``, ``val_dataloader`` and
            ``test_dataloader`` methods.
        args: Optional namespace-like object. When provided, its ``horizon``
            and ``lag`` attributes are set from the training dataset's
            ``forecast_horizon`` and ``lag``.

    Returns:
        A pair ``((train_loader, val_loader, test_loader), dataset)``.
    """
    # Default dataset root. A 'data_dir' key in config_dataset takes
    # precedence; passing both as keywords would raise a duplicate-keyword
    # TypeError, so merge them into a single kwargs dict instead.
    dataset_kwargs = {
        'data_dir': '/dfs/scratch1/common/public-datasets/monash',
        **config_dataset,
    }
    dataset = MonashDataset(**dataset_kwargs)
    dataset.init()
    dataset.setup()

    train_loader = dataset.train_dataloader(**config_loader)
    # Eval loaders are dictionaries where key is resolution, value is
    # dataloader. Only the entry stored under the ``None`` key is used
    # (presumably the default / native resolution -- TODO confirm).
    val_loader = dataset.val_dataloader(**config_loader)[None]
    test_loader = dataset.test_dataloader(**config_loader)[None]

    if args is not None:
        # Expose the forecasting setup of the training split to the caller.
        args.horizon = train_loader.dataset.forecast_horizon
        args.lag = train_loader.dataset.lag

    return (train_loader, val_loader, test_loader), dataset


def visualize_data(dataloaders, splits=('train', 'val', 'test'),
                   save=False, args=None, title=None):
    """Plot the first time series of each split back-to-back on one axis.

    Each split is drawn as its own labelled line, and its x-coordinates are
    offset by the total length of the splits plotted before it.

    Args:
        dataloaders: Sequence of dataloaders, one per entry in ``splits``.
            Each must expose ``.dataset.data``, which is indexed by time
            series.
        splits: Labels for the dataloaders, in the same order.
        save: Currently unused; accepted for interface compatibility.
        args: Optional namespace-like object. Its ``variant`` attribute is
            used to build the default title when ``title`` is None.
        title: Plot title. Defaults to ``'Monash {args.variant}'``, or plain
            ``'Monash'`` when ``args`` is also None.
    """
    ts_idx = 0  # Just visualize first time series

    assert len(splits) == len(dataloaders), (
        f'Got {len(splits)} split names for {len(dataloaders)} dataloaders'
    )
    start_idx = 0
    for split, loader in zip(splits, dataloaders):
        y = loader.dataset.data[ts_idx]
        # Shift each split so it starts where the previous one ended.
        x = np.arange(len(y)) + start_idx
        plt.plot(x, y, label=split)
        start_idx += len(x)

    if title is None:
        # ``args`` defaults to None, so do not dereference it blindly.
        title = 'Monash' if args is None else f'Monash {args.variant}'
    plt.title(title)
    plt.legend()
    plt.show()
    