import os

import numpy as np
import pandas as pd
import torch

from utils import get_logger, to_float, handle_labels
from torchvision import transforms


def _collect_patches(args, logger, states, split, description):
    """Return ``[(path_to_patch, state), ...]`` for every patch of *split* in *states*.

    For each state, ``<state>_extended-<split>_patches.csv`` in ``args.data_dir`` is
    preferred. When it does not exist, ``<state>-<split>_patches.csv`` is read instead.
    Each entry of the CSV's ``patch_fn`` column is joined onto ``args.data_dir``.

    Args:
        args: configuration namespace; only ``args.data_dir`` is read here.
        logger: logger used to report which state is being added.
        states: iterable of state identifiers.
        split: split name used in the CSV filename (``"train"``, ``"val"``, ``"test"``).
        description: human-readable split name used in the log message.

    Raises:
        FileNotFoundError: (from ``pd.read_csv``) if neither CSV exists for a state.
    """
    patches = []
    for state in states:
        logger.info("Adding %s patches from %s" % (description, state))
        csv_fn = os.path.join(args.data_dir, "%s_extended-%s_patches.csv" % (state, split))
        if not os.path.isfile(csv_fn):
            csv_fn = os.path.join(args.data_dir, "%s-%s_patches.csv" % (state, split))
        df = pd.read_csv(csv_fn)
        patches.extend(
            (os.path.join(args.data_dir, patch_fn), state)
            for patch_fn in df["patch_fn"].values
        )
    return patches


def get_datasets(args):
    """Build the training and validation datasets from per-state patch CSVs.

    Args:
        args: configuration namespace. Reads ``training_states``,
            ``validation_states``, ``data_dir`` and ``seed`` here. It is also handed
            to ``get_logger`` and ``Dataset``.

    Returns:
        ``(training_dataset, validation_dataset)``. The validation dataset holds both
        the validation and the test patches of ``args.validation_states``, and is
        shuffled with ``args.seed``.
    """
    logger = get_logger(__name__, args)

    training_patches = _collect_patches(
        args, logger, args.training_states, "train", "training"
    )
    validation_patches = _collect_patches(
        args, logger, args.validation_states, "val", "validation"
    )
    # There is no separate list of test states: test patches are taken from the
    # validation states and merged into the validation dataset below.
    test_patches = _collect_patches(
        args, logger, args.validation_states, "test", "test"
    )

    logger.info(
        "Loaded %d training patches and %d validation patches"
        % (len(training_patches), len(validation_patches) + len(test_patches))
    )

    training_dataset = Dataset(args, training_patches)
    validation_dataset = Dataset(args, validation_patches + test_patches, args.seed)

    return training_dataset, validation_dataset


def color_aug(colors, r=0.05):
    """Randomly jitter the per-channel contrast and brightness of *colors*.

    Each channel's deviation from its spatial mean is scaled by a contrast factor.
    The mean itself is scaled by a brightness factor. Both factors are drawn
    independently per channel, uniformly from ``[1 - r, 1 + r)``. The result is
    clipped to ``[0, 1]``.

    Args:
        colors: array whose first axis is channels. The mean is taken over the last
            two axes.
        r: half-width of the multiplicative jitter range (contrast and brightness).

    Returns:
        The augmented array, clipped to ``[0, 1]``.
    """
    jitter_shape = (colors.shape[0], 1, 1)
    low, high = 1 - r, 1 + r

    mean = np.mean(colors, axis=(-1, -2), keepdims=True).astype(np.float32)

    # Draw order matters for reproducibility under a fixed seed:
    # contrast factors first, then brightness factors.
    contrast = np.random.uniform(low, high, jitter_shape).astype(np.float32)
    brightness = np.random.uniform(low, high, jitter_shape).astype(np.float32)

    jittered = mean * brightness + (colors - mean) * contrast
    return np.clip(jittered, 0, 1)


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over image patches stored as ``.npy`` / ``.npz`` arrays.

    Each item is ``(x, y_hr)``:

    * ``x`` is a slice of up to ``args.input_nchannels`` input channels, converted
      by ``to_float`` and optionally colour-augmented.
    * ``y_hr`` is the label channel at ``args.hr_label_index``, optionally remapped
      by ``handle_labels``, as ``int64``.

    Patches must be channels-first ``(C, H, W)`` after squeezing, with ``H == W``.

    Args:
        args: configuration namespace. The attributes read are
            ``do_label_overloading``, ``input_size``, ``input_nchannels``,
            ``data_type``, ``do_color``, ``color_augmentation_intensity``,
            ``hr_label_key`` and ``hr_label_index``.
        patch_list: list of ``(path_to_patch, state)`` tuples. It is copied and
            never modified.
        seed: if given, seeds the global NumPy RNG before shuffling.
    """

    # First input channel used for the duplicated second half of the dataset when
    # label overloading is on. NOTE(review): value kept from the original code;
    # presumably an alternate set of input bands starts at channel 4 in the patch
    # files — confirm against the data layout.
    OVERLOAD_CHANNEL_OFFSET = 4

    def __init__(self, args, patch_list, seed=None):
        super().__init__()
        self.args = args
        # NOTE: this seeds the *global* NumPy RNG, which also affects later
        # np.random calls in this process (random crops, colour augmentation).
        if seed is not None:
            np.random.seed(seed)
        # Work on a copy so the caller's list is neither shuffled nor doubled.
        patches = list(patch_list)
        np.random.shuffle(patches)

        # Label overloading: the second half repeats the first half, and
        # __getitem__ reads a different channel range for those indices.
        if args.do_label_overloading:
            patches = patches + patches
        self.patches = patches

    @staticmethod
    def _load(fn):
        """Load a patch array (``arr_0`` for ``.npz``), dropping singleton dims."""
        if fn.endswith(".npz"):
            with np.load(fn) as archive:
                return archive["arr_0"].squeeze()
        return np.load(fn).squeeze()

    @staticmethod
    def _label(data, index, key):
        """Take label channel *index* of *data*, remap it when *key* is set, cast to int64."""
        label = data[index, :, :]
        if key:
            label = handle_labels(label, key)
        return label.astype(np.int64)

    def __getitem__(self, index):
        fn, _state = self.patches[index]
        data = self._load(fn)

        # PyTorch expects channels first (C, H, W); the array is used as stored.
        if data.ndim != 3 or data.shape[1] != data.shape[2]:
            raise ValueError(
                "expected a square (C, H, W) patch in %s, got shape %s"
                % (fn, data.shape)
            )

        size = self.args.input_size
        data_size = data.shape[1]
        if size < data_size:
            # Random square crop. randint's upper bound is exclusive, so +1 lets
            # the crop reach the bottom/right edge of the patch.
            max_offset = data_size - size
            x_idx = np.random.randint(0, max_offset + 1)
            y_idx = np.random.randint(0, max_offset + 1)
            data = data[:, y_idx:y_idx + size, x_idx:x_idx + size]

        # With label overloading the second half of self.patches repeats the first
        # half; those indices read the alternate input channels instead.
        overloaded = self.args.do_label_overloading and index >= len(self.patches) / 2
        first_ch = self.OVERLOAD_CHANNEL_OFFSET if overloaded else 0
        x = to_float(
            data[first_ch:first_ch + self.args.input_nchannels, :, :],
            self.args.data_type,
        )

        if self.args.do_color:
            x = color_aug(x, self.args.color_augmentation_intensity)

        y_hr = self._label(data, self.args.hr_label_index, self.args.hr_label_key)

        # NOTE: only the high-resolution label is returned. A superres loss would
        # also need self._label(data, self.args.lr_label_index, self.args.lr_label_key).
        return x, y_hr

    def __len__(self):
        return len(self.patches)
