import numpy as np
import os
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from ml_sampler import MultilabelBalancedRandomSampler
from ml_sampler import MultiClassBalancedRandomSampler
from ml_sampler import BinBalancedRandomSampler
from dataset import ECGDataset


def class_imbalance_sampler(labels):
    """
    Builds a WeightedRandomSampler in which every sample is weighted by the inverse of the
    number of samples sharing its class, so that rare classes are drawn more often.
    Args:
        labels: (torch.Tensor) non-negative integer class indices, one per sample; may be 1-D
            or carry extra singleton dimensions such as (N, 1)
    Returns:
        sampler: (WeightedRandomSampler) draws as many indices per pass as there are labels
    """
    # Flatten once and reuse the result everywhere: squeeze() would turn a single-element
    # tensor into a 0-d tensor (rejected by bincount), and indexing the class weights with
    # an (N, 1) tensor would produce 2-D sample weights instead of one weight per sample.
    flat_labels = labels.reshape(-1)
    class_count = torch.bincount(flat_labels)
    # Class indices that never occur get an infinite weight here, but since no label
    # refers to them they are never selected by the lookup below.
    class_weighting = 1. / class_count
    sample_weights = class_weighting[flat_labels]
    return WeightedRandomSampler(sample_weights, len(flat_labels))


def fetch_dataloader(mode, df, params):
    """
    Builds the ECGDataset for a single mode and wraps it in a DataLoader.

    For mode 'train' the data is either drawn through one of the balanced samplers (when
    params['train']['balanced_sample'] is truthy) or shuffled. Any other mode is loaded
    without shuffling.
    Args:
        mode: (string) 'train' selects the training behaviour; every other value gets the
            unshuffled loader. Passed on to ECGDataset unchanged.
        df: passed on to ECGDataset unchanged (presumably the table describing the samples
            of this mode -- confirm against ECGDataset)
        params: (dict-like) nested configuration. Always read: ['dataset']['waveform_dir'],
            ['train']['batch_size'], ['init_set']['num_workers']. Read for training:
            ['train']['balanced_sample'] and, when balanced, ['train']['sampler_type'] plus
            the keys the chosen sampler needs (['train']['sampler_alpha'],
            ['init_set']['task_type'], ['preproc']['task'],
            ['bin_metric_params']['binning_kwargs']).
    Returns:
        dl: (DataLoader) loader over the dataset of the requested mode
    Raises:
        ValueError: if balanced sampling is enabled with an unsupported sampler_type
    """
    waveform_dir = params['dataset']['waveform_dir']
    ecg_dataset = ECGDataset(waveform_dir, df, params, mode)

    if mode != 'train':
        # Non-training modes keep the dataset order; the batch size still comes from the
        # 'train' section of the configuration.
        return DataLoader(ecg_dataset,
                          batch_size=params['train']['batch_size'],
                          num_workers=params['init_set']['num_workers'],
                          shuffle=False)

    train_cfg = params['train']
    if not train_cfg['balanced_sample']:
        # Plain training: uniform shuffling, no custom sampler.
        return DataLoader(ecg_dataset,
                          batch_size=train_cfg['batch_size'],
                          shuffle=True,
                          num_workers=params['init_set']['num_workers'])

    sampler_type = train_cfg['sampler_type']
    if sampler_type == 'multi_label_balanced':
        balanced_sampler = MultilabelBalancedRandomSampler(ecg_dataset.get_mutil_label(),
                                                           np.arange(len(ecg_dataset)),
                                                           'cycle')
    elif sampler_type == 'class_balanced':
        # A single task in multi-label mode is handed to the sampler as two classes
        # (presumably negative/positive); otherwise there is one class per configured task.
        single_multi_label_task = (params['init_set']['task_type'] == 'multi_label'
                                   and len(params['preproc']['task']) == 1)
        num_classes = 2 if single_multi_label_task else len(params['preproc']['task'])
        balanced_sampler = MultiClassBalancedRandomSampler(ecg_dataset.get_label(),
                                                           train_cfg['sampler_alpha'],
                                                           num_classes=num_classes)
    elif sampler_type == 'bin_balanced':
        balanced_sampler = BinBalancedRandomSampler(ecg_dataset.get_raw_label(),
                                                    train_cfg['sampler_alpha'],
                                                    **params['bin_metric_params']['binning_kwargs'])
    else:
        raise ValueError(f"{sampler_type} is not supported for Sampler.")

    # Further DataLoader options (e.g. pin_memory=params['cuda']) could be added here, but
    # they should come from the per-experiment configuration. The ordering is decided by
    # the sampler, so shuffle is not passed.
    return DataLoader(ecg_dataset,
                      batch_size=train_cfg['batch_size'],
                      sampler=balanced_sampler,
                      num_workers=params['init_set']['num_workers'])
