from data_provider.data_loader_ele import Dataset_ETT_hour, \
    Dataset_ETT_minute, Dataset_Custom, Dataset_Pred,Dataset_Custom_Finance
from torch.utils.data import DataLoader,Subset
import torch
from torch.nn.utils.rnn import pad_sequence
import numpy as np
# Registry mapping the ``args.data`` name to the Dataset class that loads it
# (looked up by ``data_provider``). Both ETTh* names share Dataset_ETT_hour and
# both ETTm* names share Dataset_ETT_minute.
data_dict = dict(
    ETTh1=Dataset_ETT_hour,
    ETTh2=Dataset_ETT_hour,
    ETTm1=Dataset_ETT_minute,
    ETTm2=Dataset_ETT_minute,
    custom=Dataset_Custom,
    finance=Dataset_Custom_Finance,
)

def default_collate_market(batch_items):
    """Collate a list of per-sample dicts into a single batch dict.

    The keys of the first sample drive the output. For each key, the value is
    gathered from every sample. If the first sample's value is a tensor, the
    gathered values are stacked along a new leading dimension 0. Otherwise they
    are returned as a plain list.

    Args:
        batch_items: non-empty list of dicts sharing the first dict's keys.

    Returns:
        dict with the same keys as ``batch_items[0]``.
    """
    def _merge(key):
        column = [item[key] for item in batch_items]
        if isinstance(column[0], torch.Tensor):
            return torch.stack(column, dim=0)
        return column

    return {key: _merge(key) for key in batch_items[0]}

from torch.nn.utils.rnn import pad_sequence
import torch

# Number of fields in one modality group of a sample:
# (seq_x, seq_y, seq_x_mark, seq_y_mark, embeddings).
_FIELDS_PER_GROUP = 5


def _as_float_tensor(value):
    """Convert a non-tensor (e.g. numpy array or list) to float32; pass tensors through."""
    # NOTE: values that are already tensors keep their original dtype.
    if torch.is_tensor(value):
        return value
    return torch.tensor(value, dtype=torch.float32)


def _collate_group(seq_xs, seq_ys, seq_x_marks, seq_y_marks, embeds):
    """Batch one 5-field group whose per-field values were gathered across samples.

    The four sequence fields are stacked along a new leading batch dimension,
    so every sample must have the same shape for a given field. The
    embeddings are zero-padded along their first dimension to the longest
    sample via ``pad_sequence(..., batch_first=True)``.
    """
    def stack(values):
        return torch.stack([_as_float_tensor(v) for v in values], dim=0)

    return (
        stack(seq_xs),
        stack(seq_ys),
        stack(seq_x_marks),
        stack(seq_y_marks),
        pad_sequence(list(embeds), batch_first=True),
    )


def collate_fn(batch):
    """Collate samples made of one or more 5-field groups into batched tensors.

    Each sample is a tuple ``(seq_x, seq_y, seq_x_mark, seq_y_mark,
    embeddings)``. More groups of the same five fields may follow; a 10-field
    sample, for example, carries a second "flow" group. Each group is collated
    independently, and the results are returned in field order.

    Args:
        batch: non-empty list of sample tuples. Every sample must have the same
            number of fields, and that number must be a positive multiple of 5.
            Embedding fields must be tensors. They are presumably
            ``(news_len, embed_dim)`` per sample (TODO: confirm against the
            dataset).

    Returns:
        Tuple with one entry per field. Sequence/mark fields are stacked to
        ``(B, ...)``; non-tensor inputs become float32 and tensor inputs keep
        their dtype. Embedding fields are zero-padded to ``(B, max_len, ...)``.

    Raises:
        ValueError: if the field count is not a positive multiple of 5, or if
            samples disagree on their field count.
    """
    num_fields = len(batch[0])
    if num_fields == 0 or num_fields % _FIELDS_PER_GROUP != 0:
        raise ValueError(
            f"collate_fn expects samples with a multiple of {_FIELDS_PER_GROUP} "
            f"fields, got {num_fields}"
        )
    # zip() would silently truncate ragged samples, so reject them explicitly.
    if any(len(sample) != num_fields for sample in batch):
        raise ValueError(
            "collate_fn expects every sample in a batch to have the same number of fields"
        )

    # Transpose: per-sample tuples -> per-field tuples across the batch.
    columns = list(zip(*batch))
    collated = []
    for start in range(0, num_fields, _FIELDS_PER_GROUP):
        collated.extend(_collate_group(*columns[start:start + _FIELDS_PER_GROUP]))
    return tuple(collated)

def data_provider(args, flag):
    """Create the dataset and its DataLoader for a single split.

    Args:
        args: experiment config. Attributes read here: ``data``, ``embed``,
            ``freq``, ``batch_size`` (not for ``'pred'``), ``root_path``,
            ``data_path``, ``seq_len``, ``label_len``, ``pred_len``,
            ``features``, ``target`` and ``num_workers``.
        flag: split name, also forwarded to the dataset constructor.
            ``'test'`` and ``'val'`` disable shuffling. ``'pred'`` uses
            ``Dataset_Pred`` with batch size 1 and keeps the last partial
            batch. Any other value (e.g. ``'train'``) shuffles.

    Returns:
        ``(data_set, data_loader)``.

    Raises:
        KeyError: if ``args.data`` is not a key of ``data_dict``.
    """
    dataset_cls = data_dict[args.data]
    # timeenc is 1 only for the 'timeF' embedding, 0 for anything else.
    timeenc = 1 if args.embed == 'timeF' else 0

    if flag == 'pred':
        # Prediction: dedicated dataset, batch size 1, no sample dropped.
        dataset_cls = Dataset_Pred
        shuffle_flag, drop_last, batch_size = False, False, 1
    else:
        shuffle_flag = flag not in ('test', 'val')
        # NOTE(review): drop_last=True also applies to 'val' and 'test', so the
        # final partial batch of those splits is never evaluated. Confirm this
        # is intended.
        drop_last = True
        batch_size = args.batch_size
    freq = args.freq

    data_set = dataset_cls(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        timeenc=timeenc,
        freq=freq,
    )
    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=drop_last,
    )
    return data_set, data_loader
