import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import h5py

from utils import get_logger, to_float, handle_labels


def _open_patch_dataset(data_dir, state, split):
    """Open ``<data_dir>/<state>_<split>.hdf5`` read-only and return its "dataset" entry.

    The file is intentionally left open: h5py invalidates every object that
    belongs to a file once that file is closed, and the returned handle is
    indexed lazily later on. Call ``.file.close()`` on the returned handle to
    release it.
    """
    fn = str(Path(data_dir) / f"{state}_{split}.hdf5")
    h5_file = h5py.File(fn, "r")
    try:
        return h5_file["dataset"]
    except Exception:
        # Do not leak the file handle when the expected key is missing.
        h5_file.close()
        raise


def get_hdf5_datasets(args):
    """Build the training and validation datasets from per-state HDF5 files.

    Reads ``<state>_train.hdf5`` and ``<state>_val.hdf5`` from ``args.data_dir``
    and wraps the "dataset" entry of each file in a ``Dataset``. Patches are
    not loaded into memory here; the HDF5 handles are kept open and indexed
    on demand.

    Args:
        args: Configuration object providing ``data_dir``, ``training_states``
            (exactly one entry), ``validation_states`` (at least one entry)
            and ``seed``.

    Returns:
        Tuple ``(training_dataset, validation_dataset)``.

    Raises:
        AssertionError: If ``args.training_states`` does not hold exactly one
            state.
        ValueError: If ``args.validation_states`` is empty.
    """
    logger = get_logger(__name__, args)

    # Lazy HDF5 handles cannot be concatenated without loading them into
    # memory, so only a single training state is supported.
    assert len(args.training_states) == 1, "exactly one training state is supported"
    training_patches = None
    for state in args.training_states:
        logger.info("Adding training patches from %s" % state)
        training_patches = _open_patch_dataset(args.data_dir, state, "train")

    # NOTE: as with training, patches of several states are not concatenated;
    # when more than one validation state is given only the last one is used.
    val_patches = None
    for state in args.validation_states:
        logger.info("Adding validation patches from %s" % state)
        if val_patches is not None:
            # Release the file of the superseded state.
            val_patches.file.close()
        val_patches = _open_patch_dataset(args.data_dir, state, "val")

    if val_patches is None:
        raise ValueError("args.validation_states must contain at least one state")

    logger.info(
        "Loaded %d training patches and %d validation patches"
        % (training_patches.shape[0], val_patches.shape[0])
    )

    training_dataset = Dataset(args, training_patches)
    validation_dataset = Dataset(args, val_patches, args.seed)

    return training_dataset, validation_dataset


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset serving ``(x, y_hr, y_lr)`` samples from a patch array.

    ``patches`` is indexed as ``patches[idx, channel, :, :]``: channels 0-3
    are returned as the input and channel 4 as the label. Any object that
    supports this indexing and ``len()`` works (e.g. a NumPy array).
    """

    def __init__(self, args, patches, seed=None):
        """Store the configuration and the patch container.

        Args:
            args: Configuration object; stored as ``self.args``, not read here.
            patches: Indexable patch container (see class docstring).
            seed: Accepted for interface compatibility; currently unused
                because patches are not shuffled here.
        """
        super().__init__()
        self.args = args
        self.patches = patches

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            Tuple ``(x, y_hr, y_lr)`` where ``x`` is channels 0-3 of the patch,
            ``y_hr`` is channel 4 cast to int64 and ``y_lr`` is the constant
            placeholder ``array([0])`` of dtype int64.
        """
        x = self.patches[idx, 0:4, :, :]
        # Explicit int64: `np.long` is missing from NumPy 1.24+ and
        # platform-dependent in NumPy 2.x.
        y_hr = self.patches[idx, 4, :, :].astype(np.int64)

        # Placeholder for the low-resolution label; no such label is
        # generated here, so a constant [0] is returned in its place.
        y_lr = np.array([0]).astype(np.int64)

        return x, y_hr, y_lr

    def __len__(self):
        """Return the number of patches (length of the first axis)."""
        return len(self.patches)
