import argparse
import ruamel_yaml as yaml
from pathlib import Path
import os
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset


class MMTSFMDataset(Dataset):
    """Map-style dataset over a pre-saved annotation file.

    The annotation object for the requested split is loaded in full with
    ``torch.load`` at construction time. Each record is read through the keys
    ``'ts_x'``, ``'ts_y'``, ``'history_text'`` and ``'future_text'``.
    """

    def __init__(self, config, dataset_type):
        """Load the annotations for one split and cache patching settings.

        Args:
            config: Configuration object subscripted with the keys
                ``'train_file'`` or ``'val_file'`` (depending on the split),
                ``'max_patchnum'``, ``'token_size'``, ``'patch_overlap'`` and
                ``'output_size'``.
            dataset_type: Split selector, either ``'train'`` or ``'val'``.

        Raises:
            ValueError: If ``dataset_type`` is not ``'train'`` or ``'val'``.
        """
        if dataset_type == 'train':
            ann_file = config['train_file']
        elif dataset_type == 'val':
            ann_file = config['val_file']
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # ``ann_file`` further down.
            raise ValueError(
                f"dataset_type must be 'train' or 'val', got {dataset_type!r}"
            )
        # NOTE(review): torch.load is pickle-based; only point the config at
        # annotation files from a trusted source.
        self.ann = torch.load(ann_file)
        # These settings are stored but not used by any method visible in this
        # class; presumably they are read by external code — verify.
        self.max_patchnum = config['max_patchnum']
        self.patch_size = config['token_size']
        self.patch_overlap = config['patch_overlap']
        self.output_size = config['output_size']

    def __len__(self):
        """Return the number of loaded annotation records."""
        return len(self.ann)

    def __getitem__(self, index):
        """Return ``(ts_x, ts_y, history_text, future_text)`` for one record.

        Args:
            index: Position of the record inside the loaded annotations.
        """
        ann = self.ann[index]
        ts_x = ann['ts_x']
        ts_y = ann['ts_y']
        text_h = ann['history_text']
        text_f = ann['future_text']

        return ts_x, ts_y, text_h, text_f


class MMTSFMDatasetModule(LightningDataModule):
    """Lightning data module exposing train and validation loaders.

    Both splits are instances of ``MMTSFMDataset`` built from the same config
    object; batch sizes come from the ``'batch_size_train'`` and
    ``'batch_size_val'`` config entries.
    """

    def __init__(self, config):
        """Keep the config and read the per-split batch sizes from it."""
        super().__init__()
        self.config = config
        self.batch_size_train = self.config['batch_size_train']
        self.batch_size_val = self.config['batch_size_val']

    def setup(self, stage=None):
        """Build the train and validation datasets; ``stage`` is not used."""
        # Set up the datasets
        print("Creating forecast dataset")
        for split in ('train', 'val'):
            setattr(self, f'{split}_dataset', MMTSFMDataset(self.config, split))

    def _make_loader(self, dataset, batch_size, training):
        """Create a DataLoader with the settings shared by both splits.

        ``training`` switches shuffling and dropping of the last incomplete
        batch on together; everything else is the same for both splits.
        """
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=4,
            pin_memory=True,
            sampler=None,
            shuffle=training,
            collate_fn=None,
            drop_last=training,
        )

    def train_dataloader(self):
        """Return the shuffled training loader (last partial batch dropped)."""
        return self._make_loader(
            self.train_dataset, self.batch_size_train, True
        )

    def val_dataloader(self):
        """Return the validation loader (in order, every sample kept)."""
        return self._make_loader(
            self.val_dataset, self.batch_size_val, False
        )


if __name__ == '__main__':
    # Smoke test: load the config, snapshot it into the output directory,
    # build both loaders and print the first training batch.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/forecast.yaml', help='Configuration of MMTSFM')
    parser.add_argument('--output_dir', default='./output/forecast')
    args = parser.parse_args()

    # NOTE(review): the full ``yaml.Loader`` is presumably able to construct
    # arbitrary Python objects; only use it on trusted config files.
    # The context manager closes the handle that was previously left open.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Closing the file explicitly makes sure the snapshot is flushed to disk.
    with open(output_dir / 'config.yaml', 'w') as snapshot_file:
        yaml.dump(config, snapshot_file)

    data_module = MMTSFMDatasetModule(config)
    data_module.setup()

    train_loader = data_module.train_dataloader()
    val_loader = data_module.val_dataloader()

    # Inspect only the first batch.
    for ts_x, ts_y, text_h, text_f in train_loader:
        print(ts_x.shape, ts_y.shape)
        print(text_h, text_f)
        break
