import numpy as np
import torch

from ATAC.util import torchify
from mbrl.types import TransitionBatch

# Module-wide default torch device: CUDA when torch reports it is available,
# otherwise the CPU.
DEFAULT_DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def sample_ensemble_batch(dataset, batch_size, ensemble_size):
    """Draw one random minibatch per ensemble member from a dict-of-arrays dataset.

    Row indices are sampled uniformly *with replacement* by
    ``np.random.randint`` into an integer array of shape
    ``(ensemble_size, batch_size)``, and every value of ``dataset`` is indexed
    with that same index array, so rows stay aligned across keys.

    Args:
        dataset: Mapping from key to an array supporting integer-array
            indexing (presumably ``np.ndarray`` -- confirm against callers);
            every value must have the same length.
        batch_size: Number of rows sampled per ensemble member.
        ensemble_size: Number of ensemble members (leading index dimension).

    Returns:
        dict with the same keys as ``dataset``, each mapped to
        ``torchify(value[indices])``.

    Raises:
        AssertionError: if the dataset values differ in length.
    """
    reference_key = list(dataset)[0]
    num_rows = len(dataset[reference_key])
    assert all(len(column) == num_rows for column in dataset.values()), \
        "Dataset values must have same length"

    index_grid = np.random.randint(low=0, high=num_rows, size=(ensemble_size, batch_size))

    sampled = {}
    for key, column in dataset.items():
        sampled[key] = torchify(column[index_grid])
    return sampled

def batch_iterator(dataset, batch_size, permute_indices=False):
    """Yield minibatches that together make one pass over a dict-of-arrays dataset.

    A length-``n`` index sequence is drawn once per pass:

    * ``permute_indices=True``: a random permutation of ``range(n)``, so every
      row is visited exactly once.
    * ``permute_indices=False``: ``n`` indices drawn uniformly *with
      replacement* (a bootstrap resample), so some rows may repeat while
      others are skipped.

    That sequence is cut into consecutive chunks of ``batch_size``. The final
    chunk is shorter when ``batch_size`` does not divide ``n``.

    Args:
        dataset: Mapping from key to an array supporting integer-array
            indexing; every value must have the same length ``n``.
        batch_size: Maximum number of rows per yielded minibatch; must be >= 1.
        permute_indices: Select permutation (True) or bootstrap (False)
            sampling, as described above.

    Yields:
        dict with the same keys as ``dataset``, each mapped to
        ``torchify(value[indices])`` for this chunk's indices.

    Raises:
        ValueError: if ``batch_size`` < 1. Because this is a generator, the
            check runs on the first ``next()`` call.
        AssertionError: if the dataset values differ in length.
    """
    if batch_size < 1:
        # Previously 0 raised ZeroDivisionError and negatives silently yielded nothing.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    first_key = list(dataset.keys())[0]
    n = len(dataset[first_key])
    for v in dataset.values():
        assert len(v) == n, "Dataset values must have same length"

    if permute_indices:
        order = np.random.permutation(n)
    else:
        # NOTE(review): this samples WITH replacement even in the non-ensemble
        # iterator (mirroring ensemble_batch_iterator). If plain sequential
        # order was intended here, np.arange(n) would be the fix -- confirm
        # with callers before changing.
        order = np.random.choice(n, size=n, replace=True)

    # Slicing past the end is clamped, so the final short chunk needs no special case.
    for start in range(0, n, batch_size):
        indices = order[start:start + batch_size]
        yield {k: torchify(v[indices]) for k, v in dataset.items()}


def ensemble_batch_iterator(dataset, batch_size, ensemble_size, permute_indices=False):
    """Yield minibatches for every ensemble member over one pass of the data.

    Each ensemble member gets its own length-``n`` index sequence:

    * ``permute_indices=True``: an independent random permutation of
      ``range(n)`` per member, so each member sees every row exactly once.
    * ``permute_indices=False``: ``n`` indices per member drawn uniformly
      *with replacement* (an independent bootstrap resample per member).

    The resulting ``(ensemble_size, n)`` index matrix is cut along its second
    axis into consecutive column chunks of ``batch_size``. The final chunk is
    shorter when ``batch_size`` does not divide ``n``.

    Args:
        dataset: Mapping from key to an array supporting integer-array
            indexing; every value must have the same length ``n``.
        batch_size: Maximum rows per member per minibatch; must be >= 1.
        ensemble_size: Number of ensemble members; must be >= 1.
        permute_indices: Select permutation (True) or bootstrap (False)
            sampling, as described above.

    Yields:
        dict with the same keys as ``dataset``, each mapped to
        ``torchify(value[indices])``, where ``indices`` has shape
        ``(ensemble_size, chunk_len)``. For NumPy values the indexed array has
        shape ``(ensemble_size, chunk_len, *value.shape[1:])``.

    Raises:
        ValueError: if ``batch_size`` or ``ensemble_size`` is < 1. Because
            this is a generator, the checks run on the first ``next()`` call.
        AssertionError: if the dataset values differ in length.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if ensemble_size < 1:
        # With 0 members the permutation branch used to build a 1-D float
        # array and then fail on the 2-D slice below with IndexError.
        raise ValueError(f"ensemble_size must be >= 1, got {ensemble_size}")

    first_key = list(dataset.keys())[0]
    n = len(dataset[first_key])
    for v in dataset.values():
        assert len(v) == n, "Dataset values must have same length"

    if permute_indices:
        order = np.stack([np.random.permutation(n) for _ in range(ensemble_size)])
    else:
        order = np.random.choice(n, size=(ensemble_size, n), replace=True)

    # Slicing past the end is clamped, so the final short chunk needs no special case.
    for start in range(0, n, batch_size):
        indices = order[:, start:start + batch_size]
        yield {k: torchify(v[indices]) for k, v in dataset.items()}


def to_transition_batch(batch):
    """Wrap a dict-style batch as an ``mbrl`` ``TransitionBatch``.

    Keys are mapped onto ``TransitionBatch`` keyword arguments as follows:
    ``'observations'`` -> ``obs``, ``'actions'`` -> ``act``,
    ``'next_observations'`` -> ``next_obs``, ``'rewards'`` -> ``rewards``,
    ``'terminals'`` -> ``dones``. This function does no conversion itself;
    the looked-up values are passed straight to ``TransitionBatch``.

    Raises:
        KeyError: (for a ``dict`` batch) if any of the five keys is missing.
    """
    return TransitionBatch(
        obs=batch['observations'],
        act=batch['actions'],
        next_obs=batch['next_observations'],
        rewards=batch['rewards'],
        dones=batch['terminals'],
    )

def train_val_split(dataset, val_ratio):
    """Randomly split a dict-of-arrays dataset into training and validation parts.

    A single random permutation of the row indices is drawn with
    ``np.random.permutation``. The first ``n - int(val_ratio * n)`` permuted
    indices form the training split and the rest form the validation split.
    The two splits are therefore disjoint and together cover every row once.
    The same indices are applied to every key, so rows stay aligned.

    Args:
        dataset: Mapping from key to an array supporting integer-array
            indexing; every value must have the same length ``n``.
        val_ratio: Fraction of rows for the validation split, in ``[0, 1]``.
            The validation size is ``int(val_ratio * n)`` (truncated).

    Returns:
        ``(train_dataset, val_dataset)``: two dicts with the same keys as
        ``dataset``, whose values are indexed by the train and validation
        indices respectively.

    Raises:
        ValueError: if ``val_ratio`` is outside ``[0, 1]`` (or NaN).
        AssertionError: if the dataset values differ in length.
    """
    # Out-of-range ratios used to give silently wrong splits through negative
    # slice bounds (e.g. 1.5 -> half/half, -0.2 -> empty validation set).
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be in [0, 1], got {val_ratio}")

    first_key = list(dataset.keys())[0]
    n = len(dataset[first_key])
    for v in dataset.values():
        assert len(v) == n, "Dataset values must have same length"

    order = np.random.permutation(n)
    num_samples_val = int(val_ratio * n)
    num_samples_train = n - num_samples_val

    train_idxs = order[:num_samples_train]
    val_idxs = order[num_samples_train:]

    train_dataset = {k: v[train_idxs] for k, v in dataset.items()}
    val_dataset = {k: v[val_idxs] for k, v in dataset.items()}
    return train_dataset, val_dataset