import torch
import numpy as np

from domainbed.datasets import datasets
from domainbed.lib import misc
from domainbed.datasets import transforms as DBT


def set_transfroms(dset, data_type, hparams, algorithm_class=None):
    """Attach the per-sample transform table to ``dset`` for a given split role.

    ``dset.transforms`` is replaced by a fresh dict whose ``"x"`` entry depends
    on ``data_type``:

    * ``"train"`` -> ``DBT.aug``; additionally, when ``algorithm_class`` is
      given, every entry of ``algorithm_class.transforms`` is copied in.
    * ``"valid"`` -> ``DBT.basic`` when ``hparams["val_augment"] is False``,
      otherwise ``DBT.aug`` (the original DomainBed behaviour, kept as an
      option for reproducibility).
    * ``"test"``  -> ``DBT.basic``.
    * ``"mnist"`` -> identity (no augmentation).

    Args:
        dset: object that receives the ``transforms`` attribute.
        data_type: one of ``'train'``, ``'valid'``, ``'test'``, ``'mnist'``.
        hparams: mapping; ``"data_augmentation"`` must be truthy and
            ``"val_augment"`` is read for the ``"valid"`` role.
        algorithm_class: optional object exposing a ``transforms`` mapping;
            only consulted for the ``"train"`` role.

    Raises:
        ValueError: if ``data_type`` is not one of the names above (the value
            is also printed before raising).
    """
    assert hparams["data_augmentation"]

    # Roles that never receive algorithm-specific transforms return early.
    if data_type == "mnist":
        dset.transforms = {"x": lambda x: x}
        return
    if data_type == "test":
        dset.transforms = {"x": DBT.basic}
        return
    if data_type == "valid":
        # Identity check on purpose: only a literal False disables augmentation.
        dset.transforms = {"x": DBT.basic if hparams["val_augment"] is False else DBT.aug}
        return
    if data_type != "train":
        print(data_type)
        raise ValueError(data_type)

    # Training role: augmentation plus whatever extra views the algorithm wants.
    dset.transforms = {"x": DBT.aug}
    if algorithm_class is not None:
        dset.transforms.update(algorithm_class.transforms.items())


def get_dataset(test_envs, args, hparams, algorithm_class=None, infer=False, check_teacher=None):
    """Instantiate ``args.dataset`` and build its in/out splits.

    Args:
        test_envs: indices of the environments treated as test domains.
        args: namespace providing ``dataset``, ``data_dir1``, ``data_dir2``
            (DomED datasets only), ``holdout_fraction`` and ``trial_seed``.
        hparams: mapping read for ``"class_balanced"`` here and forwarded to
            ``set_transfroms``.
        algorithm_class: optional; forwarded to ``set_transfroms``.
        infer: when true, only ``dataset[test_envs[0]]`` is wrapped, in its
            original order (no shuffling, no hold-out).
        check_teacher: accepted but not used by this function.

    Returns:
        ``(dataset, in_splits, out_splits)`` where each split list holds
        ``(split, weights)`` pairs; ``weights`` is ``None`` unless
        ``hparams["class_balanced"]`` is truthy. In infer mode the lists are
        ``[(test_split, None)]`` and ``[([], None)]``.
    """
    dataset_name = args.dataset
    is_mnist = "MNIST" in dataset_name
    is_DomED = 'DomED' in dataset_name

    dataset_cls = vars(datasets)[dataset_name]
    if is_DomED:
        dataset = dataset_cls(args.data_dir1, args.data_dir2, test_envs)
    else:
        dataset = dataset_cls(args.data_dir1)

    if infer:
        # Inference: expose the first test environment as-is with test transforms.
        test_dataset = dataset[test_envs[0]]
        ordered_keys = list(range(len(test_dataset)))
        in_ = _SplitDataset(test_dataset, ordered_keys, return_z=False)
        set_transfroms(in_, "test", hparams, algorithm_class)
        return dataset, [(in_, None)], [([], None)]

    in_splits, out_splits = [], []
    for env_i, env in enumerate(dataset):
        is_test_env = env_i in test_envs

        # The split depends only on seed_hash(trial_seed, env_i), so the same
        # trial_seed always reproduces the same partition.
        # Only non-test environments of a DomED dataset yield (x, y, z) items.
        out, in_ = split_dataset(
            env,
            int(len(env) * args.holdout_fraction),
            misc.seed_hash(args.trial_seed, env_i),
            return_z=is_DomED and not is_test_env,
        )

        if is_mnist and not is_DomED:
            # The MNIST override applies to non-DomED datasets only.
            in_type = out_type = "mnist"
        elif is_test_env:
            in_type = out_type = "test"
        else:
            in_type, out_type = "train", "valid"

        set_transfroms(in_, in_type, hparams, algorithm_class)
        set_transfroms(out, out_type, hparams, algorithm_class)

        in_weights = out_weights = None
        if hparams["class_balanced"]:
            in_weights = misc.make_weights_for_balanced_classes(in_)
            out_weights = misc.make_weights_for_balanced_classes(out)

        in_splits.append((in_, in_weights))
        out_splits.append((out, out_weights))

    return dataset, in_splits, out_splits


class _SplitDataset(torch.utils.data.Dataset):
    """Used by split_dataset"""

    def __init__(self, underlying_dataset, keys, return_z=False):
        super(_SplitDataset, self).__init__()
        self.underlying_dataset = underlying_dataset
        self.keys = keys
        self.transforms = {}
        
        self.return_z = return_z
        self.direct_return = isinstance(underlying_dataset, _SplitDataset)

    def __getitem__(self, key):
        if self.direct_return:
            return self.underlying_dataset[self.keys[key]]

        
        if self.return_z:
            x, y, z = self.underlying_dataset[self.keys[key]]
            ret = {"x":x, "y": y, "z": z}
        else:
            x, y = self.underlying_dataset[self.keys[key]]
            ret = {"x":x, "y": y}
            

        for key, transform in self.transforms.items():
            ret[key] = transform(x)

        return ret

    def __len__(self):
        return len(self.keys)



def split_dataset(dataset, n, seed=0, return_z=False, return_w=False):
    """Randomly partition ``dataset`` into two ``_SplitDataset`` views.

    The indices ``0 .. len(dataset) - 1`` are shuffled with a
    ``np.random.RandomState`` seeded by ``seed``; the first ``n`` shuffled
    indices form the first view and the remainder the second.

    Args:
        dataset: any object supporting ``len`` (and indexing, once the
            returned views are read).
        n: size of the first view; must not exceed ``len(dataset)``.
        seed: seed for the shuffle.
        return_z: forwarded to both ``_SplitDataset`` views.
        return_w: accepted but not used by this function.

    Returns:
        A ``(first, second)`` tuple of ``_SplitDataset`` objects.
    """
    total = len(dataset)
    assert n <= total

    order = list(range(total))
    np.random.RandomState(seed).shuffle(order)

    first = _SplitDataset(dataset, order[:n], return_z=return_z)
    second = _SplitDataset(dataset, order[n:], return_z=return_z)
    return first, second
