# from data_provider.data_loader import (Dataset_mekong_target_station,
#                                        Dataset_mekong_for_gnn, Dataset_mekong_stations_list)
from data_provider.data_loader_LamaH import LamaHDataModule
from data_provider.data_loader_camels import CamelsDataModule
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset

# Registry of available data modules, keyed by the ``data_loader_name``
# that ``data_provider`` looks up. Both LamaH variants (daily / hourly)
# are served by the same module class.
data_dict = dict(
    LamaH_daily_mts=LamaHDataModule,
    LamaH_hourly_mts=LamaHDataModule,
    camels_mts=CamelsDataModule,
)


def collate_fn(batch):
    """Collate a list of per-sample tuples into batched tensors.

    Args:
        batch: list of tuples
            ``(seq_x, seq_y, seq_x_mark, seq_y_mark, cycle_index, station_id, time_idx)``,
            one per sample. The array-like fields (the first five) must have
            the same shape across samples so they can be stacked.

    Returns:
        Tuple ``(seq_x, seq_y, seq_x_mark, seq_y_mark, cycle_index, station_id, time_idx)``:
        the first four are ``float32`` tensors stacked along a new leading
        batch axis; ``cycle_index`` and ``time_idx`` are ``long`` tensors;
        ``station_id`` is a ``long`` tensor, unless any sample's station id is
        ``None``, in which case the raw tuple of ids is returned unchanged.
    """
    seq_x, seq_y, seq_x_mark, seq_y_mark, cycle_index, station_id, time_idx = zip(*batch)

    def _stack(items, dtype):
        # np.stack already allocates a fresh array, so torch.as_tensor avoids
        # the second copy that torch.tensor would always make.
        return torch.as_tensor(np.stack(items, axis=0), dtype=dtype)

    seq_x = _stack(seq_x, torch.float32)
    seq_y = _stack(seq_y, torch.float32)
    seq_x_mark = _stack(seq_x_mark, torch.float32)
    seq_y_mark = _stack(seq_y_mark, torch.float32)
    cycle_index = _stack(cycle_index, torch.long)

    # Identity check instead of `None in station_id`: `in` falls back to `==`,
    # which raises "truth value is ambiguous" for array-valued ids.
    if not any(sid is None for sid in station_id):
        station_id = torch.tensor(station_id, dtype=torch.long)
    time_idx = torch.tensor(time_idx, dtype=torch.long)
    return seq_x, seq_y, seq_x_mark, seq_y_mark, cycle_index, station_id, time_idx


def _build_loader(dataset, args, shuffle):
    """Wrap *dataset* in a DataLoader using the batching settings shared by all splits."""
    num_workers = args.num_workers
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=False,
        collate_fn=collate_fn,
        pin_memory=True,  # speeds up CPU -> GPU copies
        # Keep workers alive between epochs instead of restarting them each
        # epoch. DataLoader raises ValueError if this is True with
        # num_workers == 0, so only enable it when there are workers.
        persistent_workers=num_workers > 0,
    )


def data_provider(args, verbose=True, data_loader_name=None, target_station=None, station_list=None,
                  batch_flag='mini_batch', is_GNN=False):
    """Build train/val/test datasets and their DataLoaders.

    Args:
        args: config namespace; this function reads ``embed``, ``seq_len``,
            ``label_len``, ``pred_len``, ``features``, ``target``, ``freq``,
            ``batch_size`` and ``num_workers`` (and passes the whole object on
            to the data module).
        verbose: if True, print the length of each split.
        data_loader_name: key into ``data_dict`` selecting the data-module class.
        target_station, station_list, batch_flag, is_GNN: forwarded unchanged
            to the data-module constructor.

    Returns:
        ``(train_dataset, train_loader, val_dataset, val_loader,
        test_dataset, test_loader)``.

    Raises:
        KeyError: if ``data_loader_name`` is not a key of ``data_dict``.
    """
    if data_loader_name not in data_dict:
        raise KeyError(
            f"Unknown data_loader_name {data_loader_name!r}; "
            f"expected one of {sorted(data_dict)}"
        )
    Data = data_dict[data_loader_name]
    # Time-feature encoding flag: 1 when args.embed == 'timeF', else 0.
    timeenc = 0 if args.embed != 'timeF' else 1

    data_module = Data(
        args,
        target_station=target_station,
        station_list=station_list,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        scaler_dict=None,
        timeenc=timeenc,
        freq=args.freq,
        batch_flag=batch_flag,
        is_GNN=is_GNN
    )

    # Build datasets
    train_dataset = data_module.get_dataset(flag='train')
    val_dataset = data_module.get_dataset(flag='val')
    test_dataset = data_module.get_dataset(flag='test')

    if verbose:
        print('train', len(train_dataset))
        print('val', len(val_dataset))
        print('test', len(test_dataset))

    # Build DataLoaders
    train_loader = _build_loader(train_dataset, args, shuffle=True)
    # NOTE(review): validation is shuffled (kept for backward compatibility);
    # with drop_last=False and per-batch loss averaging this makes the val
    # loss vary slightly between epochs -- consider shuffle=False.
    val_loader = _build_loader(val_dataset, args, shuffle=True)
    test_loader = _build_loader(test_dataset, args, shuffle=False)

    return train_dataset, train_loader, val_dataset, val_loader, test_dataset, test_loader
