from data_provider.data_loader_old import Dataset_ETT_hour, \
    Dataset_ETT_minute, Dataset_Custom, Dataset_Pred
from torch.utils.data import DataLoader
import torch
from torch.nn.utils.rnn import pad_sequence

# Registry used by ``data_provider``: maps the value of ``args.data`` to the
# dataset class to instantiate. Both hourly ETT names share one class, as do
# both minute-level ETT names.
data_dict = {
    **dict.fromkeys(('ETTh1', 'ETTh2'), Dataset_ETT_hour),
    **dict.fromkeys(('ETTm1', 'ETTm2'), Dataset_ETT_minute),
    'custom': Dataset_Custom,
}

def default_collate_market(batch_items):
    """Collate a list of sample dicts into a single dict.

    The keys of the first sample define the keys of the result. For each key
    the per-sample values are gathered in batch order; if the first gathered
    value is a ``torch.Tensor`` the values are stacked along a new leading
    dimension, otherwise they are returned as a plain list.

    Args:
        batch_items: Non-empty list of dicts; every dict must contain all
            keys of the first one.

    Returns:
        Dict with one entry per key of the first sample.
    """

    def merge(values):
        # Only the first value decides whether the column gets stacked.
        if isinstance(values[0], torch.Tensor):
            return torch.stack(values, dim=0)
        return values

    return {
        key: merge([item[key] for item in batch_items])
        for key in batch_items[0].keys()
    }

from torch.nn.utils.rnn import pad_sequence
import torch

def _to_float_tensor(value):
    """Return ``value`` unchanged if it is a tensor, else a float32 tensor copy."""
    if torch.is_tensor(value):
        return value
    return torch.tensor(value, dtype=torch.float32)


def _stack_field(values):
    """Stack per-sample arrays of equal shape along a new leading batch axis."""
    return torch.stack([_to_float_tensor(v) for v in values], dim=0)


def _pad_embeddings(values):
    """Zero-pad per-sample embeddings along their first axis and batch them."""
    # pad_sequence only accepts tensors, so convert array-like samples first
    # (the sequence fields already tolerate non-tensor input the same way).
    return pad_sequence([_to_float_tensor(v) for v in values], batch_first=True)


def _collate_group(seq_xs, seq_ys, seq_x_marks, seq_y_marks, embeds):
    """Collate one (seq_x, seq_y, seq_x_mark, seq_y_mark, embeddings) group."""
    return (
        _stack_field(seq_xs),
        _stack_field(seq_ys),
        _stack_field(seq_x_marks),
        _stack_field(seq_y_marks),
        _pad_embeddings(embeds),
    )


def collate_fn(batch):
    """Collate samples made of one or two five-field groups.

    Each sample is a tuple ``(seq_x, seq_y, seq_x_mark, seq_y_mark,
    embeddings)``, optionally followed by a second "flow" group with the same
    five fields. Within a group the first four fields are stacked along a new
    batch axis and the embeddings are zero-padded along their first axis to
    the longest one in the batch (``pad_sequence(..., batch_first=True)``).
    Per the original comments the sequences are expected to be (L, D) /
    (L, F) and the embeddings (news_len, embed_dim) -- not enforced here.

    Values that are already tensors are used as-is; anything else is
    converted with ``torch.tensor(..., dtype=torch.float32)``.

    Args:
        batch: Non-empty list of sample tuples with 5 or 10 fields each.

    Returns:
        A 5-tuple of tensors, or a 10-tuple (main group followed by the flow
        group) when the samples carry more than five fields.

    Raises:
        ValueError: If the samples do not provide exactly 5 or 10 fields.
    """
    num_fields = len(batch[0])
    expected = 10 if num_fields > 5 else 5

    # Transpose list-of-samples into per-field columns.
    columns = tuple(zip(*batch))
    if len(columns) != expected:
        raise ValueError(
            f"collate_fn expected {expected} fields per sample, got {len(columns)}"
        )

    main = _collate_group(*columns[:5])
    if expected == 10:
        return main + _collate_group(*columns[5:])
    return main

def data_provider(args, flag):
    """Build one dataset per market plus loaders for 'a_share' and 'us'.

    The dataset class is looked up in ``data_dict`` via ``args.data`` and
    replaced by ``Dataset_Pred`` when ``flag == 'pred'``.

    Loader settings by flag:
        * 'test' / 'val': no shuffling, drop last batch, ``args.batch_size``.
        * 'pred': no shuffling, keep last batch, batch size 1.
        * anything else: shuffling, drop last batch, ``args.batch_size``.

    Args:
        args: Configuration object; reads ``data``, ``embed``, ``batch_size``,
            ``freq``, ``root_path``, ``data_path``, ``seq_len``,
            ``label_len``, ``pred_len``, ``features`` and ``target``.
        flag: Split selector as described above.

    Returns:
        ``(market_datasets, data_loader_a_share, data_loader_us)`` where
        ``market_datasets`` maps 'a_share', 'btc' and 'us' to their datasets.
        No loader is created for 'btc'; its dataset is still built and
        returned in the mapping.
    """
    dataset_cls = data_dict[args.data]
    timeenc = 1 if args.embed == 'timeF' else 0

    is_pred = flag == 'pred'
    if is_pred:
        dataset_cls = Dataset_Pred

    shuffle_flag = flag not in ('test', 'val', 'pred')
    drop_last = not is_pred
    batch_size = 1 if is_pred else args.batch_size
    freq = args.freq

    market_datasets = {
        name: dataset_cls(
            root_path=args.root_path,
            domain=name,
            data_path=args.data_path,
            flag=flag,
            size=[args.seq_len, args.label_len, args.pred_len],
            features=args.features,
            target=args.target,
            timeenc=timeenc,
            freq=freq,
        )
        for name in ('a_share', 'btc', 'us')
    }

    def build_loader(name):
        # Worker count is fixed at 6 rather than taken from args.
        return DataLoader(
            market_datasets[name],
            batch_size=batch_size,
            collate_fn=collate_fn,
            shuffle=shuffle_flag,
            num_workers=6,
            drop_last=drop_last,
        )

    data_loader_a_share = build_loader('a_share')
    data_loader_us = build_loader('us')
    return market_datasets, data_loader_a_share, data_loader_us
