import copy
import torch

from torch.utils.data import Dataset

# Public names exported by `from <module> import *`: the dataset wrappers below.
__all__ = ['ObservationDataset', 'TupleDataset', 'ConditionTupleDataset',
           'MetaTupleDataset']


class ObservationDataset(Dataset):
    """Dataset wrapping a single indexable collection of observations.

    Args:
        y: Observations. Must expose ``shape`` (its first entry is used as
            the dataset length) and support ``y[idx]`` indexing.
    """

    def __init__(self, y):
        super().__init__()
        self.y = y

    def __len__(self):
        # The leading dimension counts the observations.
        return self.y.shape[0]

    def __getitem__(self, idx):
        return self.y[idx]

    def dataset(self):
        """Return every observation at once by indexing with all positions."""
        all_positions = [pos for pos in range(len(self))]
        return self.__getitem__(all_positions)


class TupleDataset(Dataset):
    """Paired ``(x, y)`` dataset with optional NaN masking of the targets.

    Args:
        x: Input tensor. A 1-D tensor is reshaped to ``(N, 1)``.
        y: Targets, indexed along their first dimension alongside ``x``.
        contains_nan: When True, NaN entries of ``y`` are replaced by 0 in a
            private deep copy, and a mask ``m`` (created with
            ``torch.ones_like(y)``, so same shape/dtype as ``y``; 1 where
            observed, 0 where NaN) is stored and returned by ``__getitem__``.

    ``__getitem__`` returns ``(x, y, m, idx)`` when ``contains_nan`` is True,
    otherwise ``(x, y, idx)``.
    """

    def __init__(self, x, y, contains_nan=False):
        super().__init__()

        assert len(x) == len(y), 'x and y must be the same length.'

        # Promote 1-D inputs to a column so every sample is a vector.
        self.x = x.unsqueeze(1) if len(x.shape) == 1 else x
        self.contains_nan = contains_nan

        if not contains_nan:
            self.y = y
            return

        # Mask and zero-fill a copy so the caller's tensor keeps its NaNs.
        nan_positions = torch.isnan(y)
        observed = torch.ones_like(y).fill_(True)
        observed[nan_positions] = False
        targets = copy.deepcopy(y)
        targets[nan_positions] = 0.

        self.y = targets
        self.m = observed

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        sample = (self.x[idx], self.y[idx])
        if self.contains_nan:
            sample += (self.m[idx],)
        return sample + (idx,)

    def dataset(self):
        """Return the whole dataset as one batch (indices ``0..len-1``)."""
        return self.__getitem__([pos for pos in range(len(self))])


class ConditionTupleDataset(torch.utils.data.Dataset):
    """``(x, y)`` dataset carrying an extra conditioning target ``y_c``.

    Args:
        x: Input tensor. A 1-D tensor is reshaped to ``(N, 1)``.
        y: Targets, indexed along their first dimension alongside ``x``.
        y_c: Conditioning targets, indexed with the same indices as ``x``.
            NaNs in ``y_c`` are ALWAYS replaced by 0 in a private deep copy,
            with a mask ``m_c`` (1 observed, 0 NaN) kept alongside.
        contains_nan: When True, ``y`` receives the same treatment, producing
            a mask ``m``; otherwise ``y`` is stored as given.

    ``__getitem__`` returns ``(x, y, y_c, m, m_c, idx)`` when ``contains_nan``
    is True, otherwise ``(x, y, y_c, m_c, idx)``.
    """

    def __init__(self, x, y, y_c, contains_nan=False):
        super().__init__()

        assert len(x) == len(y), 'x and y must be the same length.'

        # Promote 1-D inputs to a column so every sample is a vector.
        self.x = x.unsqueeze(1) if len(x.shape) == 1 else x

        if contains_nan:
            y_nan = torch.isnan(y)
            self.y = copy.deepcopy(y)
            self.m = torch.ones_like(y).fill_(True)
            self.m[y_nan] = False
            self.y[y_nan] = 0.
        else:
            self.y = y

        # The conditioning targets are masked regardless of `contains_nan`.
        y_c_nan = torch.isnan(y_c)
        self.y_c = copy.deepcopy(y_c)
        self.m_c = torch.ones_like(y_c).fill_(True)
        self.m_c[y_c_nan] = False
        self.y_c[y_c_nan] = 0.

        self.contains_nan = contains_nan

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x, y, y_c, m_c = (
            source[idx] for source in (self.x, self.y, self.y_c, self.m_c)
        )

        if not self.contains_nan:
            return x, y, y_c, m_c, idx
        return x, y, y_c, self.m[idx], m_c, idx

    def dataset(self):
        """Return the whole dataset as one batch (indices ``0..len-1``)."""
        return self.__getitem__([pos for pos in range(len(self))])


class MetaTupleDataset(Dataset):
    """Dataset of ``(x, y)`` items where each item may be its own tensor.

    ``x`` and ``y`` are indexed in lockstep along their first axis. They may
    be tensors or Python lists (e.g. lists of variable-size tensors). No
    reshaping is applied to ``x``.

    Args:
        x: Sequence of inputs (list or tensor).
        y: Sequence of targets with the same length as ``x``. When
            ``contains_nan`` is True every ``y[i]`` must be a tensor.
        contains_nan: When True, a deep copy of ``y`` is made with NaNs
            replaced by 0, and a Python list ``m`` holding one mask per item
            (1 where observed, 0 where NaN) is stored and returned by
            ``__getitem__``.

    ``__getitem__`` returns ``(x, y, m, idx)`` when ``contains_nan`` is True,
    otherwise ``(x, y, idx)``. A list ``idx`` on a list-valued attribute is
    gathered item-wise (yielding a list); tensors use regular indexing.
    """

    def __init__(self, x, y, contains_nan=False):
        super().__init__()

        assert len(x) == len(y), 'x and y must be the same length.'

        self.x = x

        if contains_nan:
            self.y = copy.deepcopy(y)

            # One mask per item, since items may differ in shape.
            self.m = [torch.ones_like(y_).fill_(True) for y_ in y]

            # Identify nan values (in the untouched original) and replace
            # them with 0 in the copy.
            for i, y_ in enumerate(y):
                nan_idx = torch.isnan(y_)
                self.m[i][nan_idx] = False
                self.y[i][nan_idx] = 0.
        else:
            self.y = y

        self.contains_nan = contains_nan

    def __len__(self):
        return len(self.x)

    @staticmethod
    def _take(seq, idx):
        """Index ``seq`` by ``idx``, gathering item-wise for list indices.

        Python lists/tuples cannot be indexed by a list of positions (which
        ``dataset()`` passes), so those are collected one by one; every other
        combination is delegated unchanged to ``seq[idx]``.
        """
        if isinstance(seq, (list, tuple)) and isinstance(idx, (list, tuple)):
            return [seq[i] for i in idx]
        return seq[idx]

    def __getitem__(self, idx):
        x = self._take(self.x, idx)
        y = self._take(self.y, idx)

        if self.contains_nan:
            # `self.m` is always a Python list, so it needs `_take` too.
            m = self._take(self.m, idx)
            return x, y, m, idx
        return x, y, idx

    def dataset(self):
        """Return the whole dataset as one batch (indices ``0..len-1``)."""
        idx = list(range(len(self)))

        return self.__getitem__(idx)
