import pickle
import numpy as np
import torch
import feature.util as util

# Device string for tensor placement: CUDA when torch reports it, else CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Legal characters for each sequence type, as used for one-hot encoding
DNA_ALPHABET = "ACGT"
PROTEIN_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"

class AMPSeqLoader:
    """
    In-memory, index-addressable collection of sequences read from pickle
    file(s). Supports `loader[i]` and `len(loader)`.
    """
    def __init__(self, seq_paths):
        """
        Loads every pickle file and concatenates the contents into a single
        array stored as `self.seqs`.
        Arguments:
            `seq_paths`: path or list of paths to pickle file(s) containing raw
                sequences; a single path may be a string, bytes, or path-like
                object (e.g. `pathlib.Path`)
        """
        super().__init__()

        # A lone path must be wrapped in a list. `isinstance` also covers
        # string subclasses, and `__fspath__` covers path-like objects, neither
        # of which an exact `type(...) is str` comparison would recognize.
        is_single_path = (
            isinstance(seq_paths, (str, bytes))
            or hasattr(seq_paths, "__fspath__")
        )
        if is_single_path:
            seq_paths = [seq_paths]

        seqs = []
        for path in seq_paths:
            # NOTE: unpickling can execute arbitrary code; only load files
            # from trusted sources.
            with open(path, "rb") as f:
                seqs.append(pickle.load(f))
        # Join the per-file collections end to end
        self.seqs = np.concatenate(seqs)

    def __getitem__(self, index):
        """
        Returns the sequence at index `index` in `self.seqs`.
        """
        return self.seqs[index]

    def __len__(self):
        """
        Returns the total number of sequences loaded across all files.
        """
        return len(self.seqs)


class TFBindSeqLoader:
    """
    In-memory, index-addressable collection of DNA strings decoded from
    pickled arrays of nucleotide indices. Supports `loader[i]` and
    `len(loader)`.
    """
    def __init__(self, seq_paths):
        """
        Loads every pickle file, concatenates the index arrays, and decodes each
        row into a string stored in the list `self.seqs`.
        Arguments:
            `seq_paths`: path or list of paths to pickle file(s) containing
                integer index arrays, where 0, 1, 2, 3 stand for A, C, G, T;
                a single path may be a string, bytes, or path-like object
                (e.g. `pathlib.Path`)
        """
        super().__init__()

        # A lone path must be wrapped in a list. `isinstance` also covers
        # string subclasses, and `__fspath__` covers path-like objects, neither
        # of which an exact `type(...) is str` comparison would recognize.
        is_single_path = (
            isinstance(seq_paths, (str, bytes))
            or hasattr(seq_paths, "__fspath__")
        )
        if is_single_path:
            seq_paths = [seq_paths]

        inds = []
        for path in seq_paths:
            # NOTE: unpickling can execute arbitrary code; only load files
            # from trusted sources.
            with open(path, "rb") as f:
                inds.append(pickle.load(f))
        inds = np.concatenate(inds)

        # Fancy-index the nucleotide table to map every index to its letter,
        # then join each row of letters into one string
        seqs = np.array(["A", "C", "G", "T"])[inds]
        self.seqs = ["".join(s) for s in seqs]

    def __getitem__(self, index):
        """
        Returns the sequence at index `index` in `self.seqs`.
        """
        return self.seqs[index]

    def __len__(self):
        """
        Returns the total number of sequences loaded across all files.
        """
        return len(self.seqs)


class MPRASeqLoader:
    """
    In-memory, index-addressable collection of sequences read from pickle
    file(s). Supports `loader[i]` and `len(loader)`.
    """
    def __init__(self, seq_paths):
        """
        Loads every pickle file and concatenates the contents into a single
        array stored as `self.seqs`.
        Arguments:
            `seq_paths`: path or list of paths to pickle file(s) containing raw
                sequences; a single path may be a string, bytes, or path-like
                object (e.g. `pathlib.Path`)
        """
        super().__init__()

        # A lone path must be wrapped in a list. `isinstance` also covers
        # string subclasses, and `__fspath__` covers path-like objects, neither
        # of which an exact `type(...) is str` comparison would recognize.
        is_single_path = (
            isinstance(seq_paths, (str, bytes))
            or hasattr(seq_paths, "__fspath__")
        )
        if is_single_path:
            seq_paths = [seq_paths]

        seqs = []
        for path in seq_paths:
            # NOTE: unpickling can execute arbitrary code; only load files
            # from trusted sources.
            with open(path, "rb") as f:
                seqs.append(pickle.load(f))
        # Join the per-file collections end to end
        self.seqs = np.concatenate(seqs)

    def __getitem__(self, index):
        """
        Returns the sequence at index `index` in `self.seqs`.
        """
        return self.seqs[index]

    def __len__(self):
        """
        Returns the total number of sequences loaded across all files.
        """
        return len(self.seqs)


class SeqDataset(torch.utils.data.IterableDataset):
    """
    Iterable dataset that serves shuffled batches of one-hot-encoded
    sequences, padded to a common length within each batch.
    Arguments:
        `seq_loader`: an object supporting `len()` and integer indexing which
            returns sequences; "$" is used as the padding character, so it
            must not be a legal character in the sequences
        `seq_alphabet`: string representing the legal alphabet (e.g. "ACGT")
        `batch_size`: number of sequences per batch, B
    """
    def __init__(self, seq_loader, seq_alphabet, batch_size):
        self.seq_loader = seq_loader
        self.seq_alphabet = seq_alphabet
        self.batch_size = batch_size

        num_seqs = len(seq_loader)
        # The final batch may be partial, hence the ceiling
        self.num_batches = int(np.ceil(num_seqs / batch_size))
        # Shuffled order in which sequences are assigned to batches
        self.indices = np.random.permutation(num_seqs)

    def get_batch(self, index):
        """
        Builds batch number `index` and returns it as a dictionary with the
        following keys:
            "x": tensor of one-hot encoded sequences as produced by
                `util.seqs_to_one_hot` (expected to be B x L x D, where L is
                the maximum sequence length in the batch)
            "mask": a B x L tensor of booleans denoting a padding mask, True
                when padding was added
        """
        # Slice this batch's share out of the shuffled index order
        lo = index * self.batch_size
        batch_inds = self.indices[lo:lo + self.batch_size]

        raw_seqs = [self.seq_loader[i] for i in batch_inds]

        # Right-pad every sequence with "$" up to the longest in the batch
        lengths = np.array([len(seq) for seq in raw_seqs])
        longest = np.max(lengths)
        padded = [seq + "$" * (longest - len(seq)) for seq in raw_seqs]

        encoded = util.seqs_to_one_hot(padded, alphabet=self.seq_alphabet)

        # A position is padding when it lies at or beyond the sequence's own
        # length; broadcasting compares every position against every length
        positions = np.arange(longest)
        pad_mask = positions[None, :] >= lengths[:, None]

        batch = {}
        batch["x"] = torch.tensor(encoded)
        batch["mask"] = torch.tensor(pad_mask)
        return batch

    def __iter__(self):
        """
        Returns an iterator over the batches. When called from within a
        data-loading worker, only that worker's contiguous shard of the batch
        range is iterated.
        """
        info = torch.utils.data.get_worker_info()
        total = self.num_batches
        if info is not None:
            # Split the batch range into equal-sized contiguous shards; the
            # last shard(s) may be short or empty
            per_worker = int(np.ceil(total / info.num_workers))
            first = per_worker * info.id
            last = min(first + per_worker, total)
        else:
            # Single-process loading: this iterator covers every batch
            first, last = 0, total
        return (self.get_batch(i) for i in range(first, last))

    def __len__(self):
        """
        Returns the number of batches in a full pass over the data.
        """
        return self.num_batches

    def on_epoch_start(self):
        """
        Draws a fresh random ordering of the sequences.
        """
        num_seqs = len(self.seq_loader)
        self.indices = np.random.permutation(num_seqs)


if __name__ == "__main__":
    # Smoke test: build a dataset over the TF-binding training sequences and
    # pull one batch through a multi-worker DataLoader.
    # The loader was previously an undefined `SeqLoader`; `TFBindSeqLoader` is
    # used because the file is the TFbind pickle and is paired with the DNA
    # alphabet.
    dataset = SeqDataset(
        TFBindSeqLoader(
            "/projects/site/gred/resbioai/tsenga5/seq_diff/data/train_TFbind.pkl"
        ), DNA_ALPHABET, 32
    )
    dataset.on_epoch_start()
    # `batch_size=None` disables automatic batching, since the dataset already
    # yields complete batches; the identity collate passes them through as-is
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=None, num_workers=4, collate_fn=(lambda x: x)
    )
    batch = next(iter(data_loader))
