"""
Load data for National Illness from informer.py
"""
import numpy as np
import matplotlib.pyplot as plt

from dataloaders.informer import Traffic


def load_data(config_dataset, config_loader, args=None):
    """Instantiate the Traffic dataset and build its train/val/test loaders.

    Args:
        config_dataset: Keyword arguments forwarded to the ``Traffic``
            constructor.
        config_loader: Keyword arguments forwarded to each of the
            ``*_dataloader`` factory methods.
        args: Unused here.

    Returns:
        A pair ``((train_loader, val_loader, test_loader), dataset)``.
    """
    dataset = Traffic(**config_dataset)
    dataset.setup()

    # The val/test factories return a mapping of loaders (keyed by
    # resolution, per the original note); only the entry stored under the
    # key ``None`` is used here.
    loaders = (
        dataset.train_dataloader(**config_loader),
        dataset.val_dataloader(**config_loader)[None],
        dataset.test_dataloader(**config_loader)[None],
    )
    return loaders, dataset


def visualize_data(dataloaders, splits=('train', 'val', 'test'),
                   save=False, args=None, title=None):
    """Plot every split's ``dataset.data_x`` back-to-back on one x-axis.

    Args:
        dataloaders: Sized collection of dataloaders, one per split; each
            must expose ``.dataset.data_x``.
        splits: Legend labels, in the same order as ``dataloaders``.
        save: Currently unused; the figure is only shown, never written.
        args: Currently unused.
        title: Figure title. Defaults to ``'Traffic'`` when ``None``.

    Raises:
        AssertionError: If ``splits`` and ``dataloaders`` differ in length.
    """
    assert len(splits) == len(dataloaders), (
        'splits and dataloaders must have the same length'
    )
    offset = 0
    for split, loader in zip(splits, dataloaders):
        series = loader.dataset.data_x
        # Shift each split so it begins where the previous one ended.
        x = np.arange(len(series)) + offset
        plt.plot(x, series, label=split)
        offset += len(series)
    plt.title('Traffic' if title is None else title)
    plt.legend()
    plt.show()