import os
import pandas as pd
import numpy as np
from utils.dataset import Dataset_Custom
from torch.utils.data import DataLoader

# Per-split DataLoader options, keyed by split name.
#   shuffle_flag -> DataLoader(shuffle=...)
#   drop_last    -> DataLoader(drop_last=...)
#   batch_size   -> only 'pred' effectively uses this value in load_data;
#                   the other splits take args.batch_size instead.
# NOTE(review): drop_last=True on 'test' means the final partial batch of
# test windows is never evaluated — confirm this is intended.
data_type_param = {
    split: {'shuffle_flag': shuffle, 'drop_last': drop_last, 'batch_size': bs}
    for split, shuffle, drop_last, bs in (
        ('train', True, True, 32),
        ('val', False, True, 32),
        ('test', False, True, 32),
        ('pred', False, False, 1),
    )
}

# On-disk storage format of each known dataset, keyed by dataset name.
# load_data looks up args.dataset_name here and only reads the "csv" ones.
data_type_packet = {
    **dict.fromkeys(
        (
            "weather",
            "traffic",
            "ETTh1",
            "ETTh2",
            "ETTm1",
            "ETTm2",
            "ILI",
            "exchange_rate",
            "electricity",
        ),
        "csv",
    ),
    **dict.fromkeys(("METR-LA", "PEMS-BAY"), "npz"),
}


def construct_borders(length: int, seq_len: int = 336, pred_len: int = 48, dataset_type: str = 'custom',
                      train_ratio: float = 0.7, test_ratio: float = 0.2):
    """Compute start/end row indices of the train, val, test and pred splits.

    Each split ``i`` covers rows ``[border1s[i], border2s[i])``. The val and
    test starts are moved back by ``seq_len`` so their first window has a full
    look-back history. The pred split is the last ``seq_len + pred_len`` rows.

    Args:
        length: Total number of rows in the series.
        seq_len: Look-back window length.
        pred_len: Forecast horizon length.
        dataset_type: ``'custom'`` for a ratio-based split, or one of
            ``'ETTh1'``, ``'ETTh2'`` (hourly) / ``'ETTm1'``, ``'ETTm2'``
            (4 steps per hour) for the fixed 12/4/4-month ETT split.
        train_ratio: Fraction of rows used for training (``'custom'`` only).
        test_ratio: Fraction of rows used for testing (``'custom'`` only);
            validation gets the remainder.

    Returns:
        ``(border1s, border2s)``: two 4-element lists ordered
        ``[train, val, test, pred]``.

    Raises:
        ValueError: if ``dataset_type`` is not recognised. Previously this
            surfaced as an opaque ``UnboundLocalError``.
    """
    if dataset_type == 'custom':
        num_train = int(length * train_ratio)
        num_test = int(length * test_ratio)
        num_vali = length - num_train - num_test
        border1s = [0, num_train - seq_len, length - num_test - seq_len, length - seq_len - pred_len]
        border2s = [num_train, num_train + num_vali, length, length]
    elif dataset_type in ('ETTh1', 'ETTh2', 'ETTm1', 'ETTm2'):
        # A "month" is 30 days; the ETTm variants sample 4 times per hour.
        steps_per_month = 30 * 24 * (4 if dataset_type.startswith('ETTm') else 1)
        train_end = 12 * steps_per_month
        val_end = train_end + 4 * steps_per_month
        test_end = train_end + 8 * steps_per_month
        border1s = [0, train_end - seq_len, val_end - seq_len, length - seq_len - pred_len]
        border2s = [train_end, val_end, test_end, length]
    else:
        raise ValueError(
            f"Unknown dataset_type {dataset_type!r}; expected 'custom', "
            "'ETTh1', 'ETTh2', 'ETTm1' or 'ETTm2'"
        )
    return border1s, border2s

def _select_feature_columns(columns, target, enc_in):
    """Return up to ``enc_in - 1`` column names other than 'date' and ``target``, in file order."""
    cols = list(columns)
    missing = [name for name in ('date', target) if name not in cols]
    if missing:
        raise ValueError(f"Input CSV is missing required column(s): {missing}")
    cols.remove(target)
    cols.remove('date')
    return cols[0:enc_in - 1]


def load_data(args):
    """Build the train/val/test/pred datasets and DataLoaders for ``args.dataset_name``.

    Reads the CSV at ``os.path.join(args.root_path, args.data_path)`` and
    reorders its columns to ``['date', <up to args.enc_in - 1 other columns>,
    args.target]``. It then splits the rows with ``construct_borders`` and wraps
    each split in a ``Dataset_Custom`` plus a ``DataLoader``. Prints each
    split's dataset length.

    Args:
        args: Namespace-like config. Attributes read here: dataset_name, embed,
            root_path, data_path, target, enc_in, seq_len, label_len, pred_len,
            dataset_type, features, freq, batch_size, num_workers,
            use_multi_gpu, use_gcn.

    Returns:
        ``(data, corr, high_correlated_count)``.

        - ``data`` maps each split name ('train', 'val', 'test', 'pred') to its
          dataset, and ``'<split>_loader'`` to its DataLoader.
        - When ``args.use_gcn`` is true, ``corr`` is the ``DataFrame.corr()``
          matrix of the selected feature and target columns over the
          training rows, as a numpy array.
        - ``high_correlated_count[i]`` is the number of columns whose
          correlation with column ``i`` exceeds 0.8, minus one for the column
          itself.
        - Otherwise both ``corr`` and ``high_correlated_count`` are ``None``.

    Raises:
        KeyError: if ``args.dataset_name`` is not in ``data_type_packet``.
        NotImplementedError: for non-CSV datasets (e.g. the "npz" ones). These
            used to fall through silently and return an empty ``data`` dict,
            or a ``NameError`` when ``use_gcn`` was set.
        ValueError: if the CSV lacks a 'date' or ``args.target`` column.
    """
    # Select the loader by the dataset's on-disk format.
    file_type = data_type_packet[args.dataset_name]
    if file_type != "csv":
        raise NotImplementedError(
            f"load_data only supports CSV datasets; {args.dataset_name!r} is stored as {file_type!r}"
        )
    timeenc = 1 if args.embed == 'timeF' else 0

    # Column order after selection: ['date', ...(other features), target feature]
    df_raw = pd.read_csv(os.path.join(args.root_path, args.data_path))
    cols = _select_feature_columns(df_raw.columns, args.target, args.enc_in)
    df_raw = df_raw[['date'] + cols + [args.target]]
    border1s, border2s = construct_borders(len(df_raw), args.seq_len, args.pred_len, args.dataset_type)

    data = {}
    for category in ('train', 'val', 'test', 'pred'):
        data[category] = Dataset_Custom(
            raw_data=df_raw,
            border1s=border1s,
            border2s=border2s,
            flag=category,
            size=[args.seq_len, args.label_len, args.pred_len],
            features=args.features,
            target=args.target,
            timeenc=timeenc,
            freq=args.freq
        )
        # 'pred' always runs one sample per batch; the other splits use the configured size.
        batch_size = 1 if category == 'pred' else args.batch_size
        data[category + "_loader"] = DataLoader(
            data[category],
            batch_size=batch_size,
            shuffle=data_type_param[category]['shuffle_flag'],
            num_workers=args.num_workers,
            drop_last=data_type_param[category]['drop_last'],
            pin_memory=bool(args.use_multi_gpu)
        )

        print(category, len(data[category]))

    if not args.use_gcn:
        return data, None, None

    # Correlations are computed on the training rows only, to avoid leaking val/test data.
    df_train = df_raw[cols + [args.target]].iloc[border1s[0]:border2s[0]]
    corr = df_train.corr()
    # "- 1" removes each column's self-correlation. NOTE(review): the diagonal
    # is NaN for a constant column, which would yield -1 there.
    high_correlated_count = np.sum(corr > 0.8, axis=1) - 1
    return data, np.array(corr), np.array(high_correlated_count)
