from pathlib import Path
import numpy as np
import torch
from torch_geometric.loader import DataLoader
from ood_dataset import HEP_Pileup_Shift, HEP_Signal_Shift, QMOF, Drug3d


def get_data_loaders(dataset_name, batch_size, data_config, seed, shift_config=None, **kwargs):
    """Build the shifted dataset and one DataLoader per index split.

    Args:
        dataset_name: Dataset identifier. Only 'actstrack' and 'tau3mu' are
            handled by this function.
        batch_size: Batch size forwarded to ``get_ood_data_loader``.
        data_config: Passed through unchanged to the dataset constructor.
        seed: Passed through unchanged to the dataset constructor.
        shift_config: Subscriptable config (e.g. a dict) describing the shift.
            Must contain 'shift_name', 'setting' and 'restrict_TL_train';
            'pileup_shift' additionally reads 'pileup_train' and 'pileup_val',
            'signal_shift' additionally reads 'pileup' and 'target'.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        A ``(loaders, dataset)`` tuple, where ``loaders`` is the dict returned
        by ``get_ood_data_loader`` for ``dataset.idx_split``.

    Raises:
        ValueError: If ``dataset_name`` is not supported, ``shift_config`` is
            missing, or ``shift_config['shift_name']`` is not recognised.
    """
    # Fail early with a clear message instead of an UnboundLocalError at return.
    if dataset_name not in ('actstrack', 'tau3mu'):
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; expected 'actstrack' or 'tau3mu'."
        )
    if shift_config is None:
        raise ValueError(f"shift_config is required for dataset {dataset_name!r}.")

    root_dir = f'/dataset/{dataset_name}/'
    shift_name = shift_config['shift_name']

    if shift_name == 'pileup_shift':
        pileup_train = shift_config['pileup_train']
        pileup_val = shift_config['pileup_val']
        setting = shift_config['setting']
        restrict_train = shift_config['restrict_TL_train']
        dataset = HEP_Pileup_Shift(root_dir, data_config, seed, pileup_train, pileup_val, setting, restrict_train)
    elif shift_name == 'signal_shift':
        setting = shift_config['setting']
        pileup = shift_config['pileup']
        target = shift_config['target']
        restrict_train = shift_config['restrict_TL_train']
        dataset = HEP_Signal_Shift(root_dir, data_config, seed, pileup, setting, target,
                                   restrict_TL_train=restrict_train)
    else:
        # Previously this fell through and crashed on the undefined `dataset`.
        raise ValueError(
            f"Unknown shift_name {shift_name!r}; expected 'pileup_shift' or 'signal_shift'."
        )

    loaders = get_ood_data_loader(batch_size, dataset=dataset, idx_split=dataset.idx_split, setting=setting)
    return loaders, dataset


def get_ood_data_loader(batch_size, dataset, idx_split, setting):
    """Create one DataLoader per entry of ``idx_split``.

    Args:
        batch_size: Default batch size for every split.
        dataset: Indexable dataset; ``dataset[idx_split[name]]`` is handed to
            the loader. Its ``dataset_name`` attribute is read for the
            'ood_val' / 'ood_test' splits.
        idx_split: Mapping from split name to the indices of that split.
        setting: Setting identifier; training splits drop their last
            incomplete batch when it equals 'DA'.

    Returns:
        Dict mapping each split name to its DataLoader.
    """
    data_loader = {}
    for split_name, split_idx in idx_split.items():
        # A split counts as training when its name starts with 'train' before the first '_'.
        is_train = split_name.split('_')[0] == 'train'

        # Use a per-split local so the 32 override for the actstrack OOD
        # evaluation splits does not leak into splits visited afterwards
        # (the original reassigned `batch_size` itself).
        if split_name in ('ood_val', 'ood_test') and dataset.dataset_name == 'actstrack':
            split_batch_size = 32
        else:
            split_batch_size = batch_size

        data_loader[split_name] = DataLoader(
            dataset[split_idx],
            batch_size=split_batch_size,
            shuffle=is_train,
            follow_batch=None,
            drop_last=is_train and setting == 'DA',
            num_workers=0,
        )
    return data_loader


