import torch
import numpy as np
import random


class ContrastiveLoader(torch.utils.data.IterableDataset):
    """Iterable dataset yielding randomly sampled contrastive batches.

    Each batch draws ``batch_size`` anchor indices without replacement from
    ``choice_idx``. Each anchor ``r`` is paired with a positive index
    ``r + k``, where ``k`` is drawn uniformly from the inclusive range
    ``[pair_range[r, 0], pair_range[r, 1]]``. With ``neg=True``, a further
    ``batch_size`` indices are drawn as negatives in the same
    without-replacement draw, so they are distinct from the anchors.

    When ``seq_len`` is given, every index becomes the start of a window of
    ``seq_len`` consecutive rows. Anchor starts are restricted to offsets
    ``0 .. time_step - seq_len`` within each block of ``time_step`` rows.
    This presumably assumes ``x`` is made of back-to-back ``time_step``-long
    trajectories (TODO confirm).

    Args:
        x: Row-indexable data gathered with fancy indexing (``x[idx, ...]``).
        u: Stored on the instance but not used by this class.
        pair_range: Per-row offset bounds, read as ``pair_range[idx, 0]`` and
            ``pair_range[idx, 1]``.
        batch_size: Number of anchors per batch.
        neg: If true, each batch also carries negatives.
        time_step: Block length; required when ``seq_len`` is given.
        seq_len: Window length, or ``None`` to sample single rows.
    """

    def __init__(self, x, u, pair_range, batch_size, neg=False, time_step=None, seq_len=None):
        super().__init__()
        self.x, self.u = x, u
        self.pair_range = pair_range
        self.batch_size = batch_size
        self.num_sample = len(x)
        self.num_iteration = self.num_sample // self.batch_size

        self.neg = neg
        self.time_step = time_step
        self.seq_len = seq_len

        if seq_len is not None:
            # Block starts combined with every in-block offset whose window still fits.
            block_starts = np.arange(0, self.num_sample, time_step)
            in_block = np.arange(time_step - seq_len + 1)
            self.choice_idx = (block_starts[:, np.newaxis] + in_block).flatten()
        else:
            self.choice_idx = np.arange(self.num_sample)
        # Population handed to random.sample (positions into choice_idx).
        self.choice_seq = range(len(self.choice_idx))

    def _sample_pair(self):
        """Draw one batch.

        Returns:
            ``(anchor, positive)``, or ``(anchor, positive, negative)`` when
            ``neg`` is set. Each element is ``x`` gathered at those indices
            (or index windows when ``seq_len`` is set).
        """
        bs = self.batch_size
        draws = bs * 2 if self.neg else bs
        picked = self.choice_idx[random.sample(self.choice_seq, draws)]
        anchor = picked[:bs]

        # One offset per anchor, from that anchor's own inclusive [lo, hi] bounds.
        lo = self.pair_range[anchor, 0]
        hi = self.pair_range[anchor, 1]
        positive = anchor + np.random.randint(lo, hi + 1)

        index_sets = [anchor, positive]
        if self.neg:
            index_sets.append(picked[bs:])

        if self.seq_len is not None:
            window = np.arange(self.seq_len)
            index_sets = [starts[:, np.newaxis] + window for starts in index_sets]

        return tuple(self.x[sel, ...] for sel in index_sets)

    def __iter__(self):
        """Yield ``num_iteration`` independently sampled batches."""
        for _ in range(self.num_iteration):
            yield self._sample_pair()

    def __len__(self):
        """Return the number of anchors per pass (``num_iteration * batch_size``)."""
        return self.batch_size * self.num_iteration


class SequentialLoader(torch.utils.data.IterableDataset):
    """Iterable dataset yielding consecutive, unshuffled batches of windows.

    Each batch holds windows of ``seq_len`` consecutive rows of ``x``. Window
    starts within a batch are spaced ``interval`` rows apart, and batch ``i``
    is shifted by ``i * batch_size * interval`` rows. Each yielded item is
    ``(x[windows], u[last row of each window])``.

    When ``interval == 1`` and ``cutoff=True``, only windows that start at
    offsets ``0 .. time_step - seq_len`` of each ``time_step``-row block are
    kept. This presumably assumes ``batch_size`` is a multiple of
    ``time_step`` and ``x`` is back-to-back ``time_step``-long trajectories
    (TODO confirm). ``cutoff`` has no effect when ``interval != 1``.

    Trailing batches whose windows would index past ``len(x)`` are not
    produced, so iteration never runs off the end of ``x``.

    Args:
        x: Row-indexable data gathered with fancy indexing.
        u: Row-indexable data; indexed at the last row of each window.
        batch_size: Number of window starts per batch (before ``cutoff``).
        time_step: Block length used by ``cutoff``.
        seq_len: Window length.
        interval: Row spacing between consecutive window starts.
        cutoff: Drop windows that would cross a block boundary (``interval == 1`` only).
    """

    def __init__(self, x, u, batch_size, time_step, seq_len, interval=1, cutoff=False):
        super().__init__()
        self.x = x
        self.u = u
        self.batch_size = batch_size
        self.num_sample = len(x)
        self.num_iteration = self.num_sample // self.batch_size

        self.time_step = time_step
        self.seq_len = seq_len
        self.interval = interval
        self.cutoff = cutoff

        # Window start offsets within one batch, spaced `interval` rows apart.
        self.choice_idx = np.arange(0, self.batch_size * self.interval, self.interval)
        if self.interval == 1 and self.cutoff:
            # Keep only starts whose window fits inside its time_step-long block.
            idx = np.arange(0, self.batch_size * self.interval, self.time_step)[:, np.newaxis] + np.arange(0, self.time_step - self.seq_len + 1)
            idx = idx.flatten()
            self.choice_idx = self.choice_idx[idx]
        else:
            # Each batch spans batch_size * interval rows, so fewer batches fit.
            self.num_iteration //= self.interval
        # (num_windows, seq_len) row offsets relative to the batch start.
        self.choice_idx = self.choice_idx[:, np.newaxis] + np.arange(self.seq_len)

        # Windows reach seq_len - 1 rows past their start, so when
        # seq_len > interval the last batch could index beyond len(x) and raise
        # IndexError. Keep only batches that lie fully in range. This never
        # lowers the count in configurations that previously worked.
        stride = self.batch_size * self.interval
        if self.choice_idx.size and stride > 0:
            last_offset = int(self.choice_idx.max())
            num_fit = max(0, (self.num_sample - 1 - last_offset) // stride + 1)
            self.num_iteration = min(self.num_iteration, num_fit)

    def __iter__(self):
        """Yield ``(x[windows], u[windows[:, -1]])`` for each batch, in order."""
        stride = self.batch_size * self.interval
        for i in range(self.num_iteration):
            idx = self.choice_idx + i * stride
            yield self.x[idx, ...], self.u[idx[:, -1], ...]

    def __len__(self):
        """Return ``num_iteration * batch_size``.

        NOTE(review): with ``cutoff=True``, each batch has ``len(choice_idx)``
        rows, which is fewer than ``batch_size``. This count therefore
        overstates the rows actually yielded. It is kept as-is for callers
        that may depend on it.
        """
        return self.num_iteration * self.batch_size


class ShuffleSequentialLoader(torch.utils.data.IterableDataset):
    """Iterable dataset yielding randomly sampled batches of in-block windows.

    Candidate window starts are every offset ``0 .. time_step - seq_len`` inside
    each ``time_step``-row block of ``x``. This presumably assumes ``x`` is
    back-to-back ``time_step``-long trajectories (TODO confirm). Each batch
    draws ``batch_size`` distinct starts and yields
    ``(x[windows], u[last row of each window])``.

    Args:
        x: Row-indexable data gathered with fancy indexing.
        u: Row-indexable data; indexed at the last row of each window.
        batch_size: Number of windows per batch.
        time_step: Block length.
        seq_len: Window length.
    """

    def __init__(self, x, u, batch_size, time_step, seq_len):
        super().__init__()
        self.x, self.u = x, u
        self.batch_size = batch_size
        self.num_sample = len(x)
        self.num_iteration = self.num_sample // batch_size

        self.time_step, self.seq_len = time_step, seq_len

        # Outer sum: every block start combined with every in-block offset that fits.
        block_starts = np.arange(0, self.num_sample, time_step)
        self.choice_idx = np.add.outer(block_starts, np.arange(time_step - seq_len + 1)).ravel()
        # Population handed to random.sample (positions into choice_idx).
        self.choice_seq = range(self.choice_idx.shape[0])

    def __iter__(self):
        """Yield ``num_iteration`` batches of randomly chosen windows."""
        span = np.arange(self.seq_len)
        for _ in range(self.num_iteration):
            starts = self.choice_idx[random.sample(self.choice_seq, self.batch_size)]
            windows = np.add.outer(starts, span)
            last_rows = windows[:, -1]
            yield self.x[windows, ...], self.u[last_rows, ...]

    def __len__(self):
        """Return the number of windows per pass (``num_iteration * batch_size``)."""
        return self.batch_size * self.num_iteration
