# Standard Library Imports
import json

# Third-Party Library Imports
import numpy as np
import torch
import scipy.signal

# Local Imports
import utils
import sim.motif_util as motif_util

class SeqSimulator:
    def __init__(
        self, motif_config_path, input_length, seq_alphabet, bg_seq_freqs,
        motif_center_dist_bound, data_seed
    ):
        """
        Samples simulated sequences given a motif configuration file.
        Arguments:
            `motif_config_path`: path to JSON file containing motif
                configurations to sample from; its "motif_files" entry is
                passed to `motif_util.import_meme_motifs` and its "configs"
                entry is used as the default configuration
            `input_length`: length of input sequences to generate
            `seq_alphabet`: string containing alphabet of sequences (e.g.
                "ACGT"); must have one character per entry of `bg_seq_freqs`
            `bg_seq_freqs`: NumPy array of background frequency for each
                sequence token; it is normalized by its sum before use
            `motif_center_dist_bound`: maximum random distance of motif-
                configuration center from sequence center
            `data_seed`: seed for sampling sequences
        """

        assert len(seq_alphabet) == len(bg_seq_freqs)
        self.input_length = input_length
        self.seq_alphabet = seq_alphabet
        self.seq_alphabet_arr = np.array(list(seq_alphabet))
        self.bg_seq_freqs = bg_seq_freqs
        self.bg_seq_freqs_cumsum = \
            np.cumsum(bg_seq_freqs / np.sum(bg_seq_freqs))
        self.motif_center_dist_bound = motif_center_dist_bound

        # Import motif configs
        with open(motif_config_path, "r") as f:
            motif_configs = json.load(f)
        self.motif_dict = motif_util.import_meme_motifs(
            motif_configs["motif_files"]
        )
        self.configs = motif_configs["configs"]

        self.rng = np.random.default_rng(data_seed)

    def set_seed(self, data_seed):
        """
        Sets seed of RNG by replacing it with a freshly seeded generator.
        Arguments:
            `data_seed`: integer-valued seed
        """
        self.rng = np.random.default_rng(data_seed)

    def _get_possible_motifs(self, configs=None):
        """
        Computes the set of motif keys which could be called for by the motif
        configuration with non-zero probability.
        Arguments:
            `configs`: a list of configurations; default is to use
                `self.configs`
        `configs` must be one of the following 3 forms:
        1) A list of strings and/or spacings specifying the motif IDs and spaces
            between them (e.g. ["GATA1", 8, "TAL1"])
        2) A list of lists, where each inner list is a set of configurations to
            be chosen from uniformly at random
        3) A list of dictionaries, where each dictionary specifies a probability
            and a configuration; dictionaries are selected according to their
            (normalized) probabilities
        Returns a set of motif keys. Raises a ValueError if `configs` mixes the
        forms above.
        """
        # Motif configurations are defined recursively
        if configs is None:
            configs = self.configs

        # Note that an empty list satisfies this first check, and yields an
        # empty set
        if all(type(c) in (str, int) for c in configs):
            # This is a single config containing a motif specification
            # Return the set of all motif keys in it
            return {c for c in configs if type(c) is str}
        elif all(type(c) is list for c in configs):
            # Everything is a list, so take union
            return set().union(*[
                self._get_possible_motifs(c) for c in configs
            ])
        elif all(type(c) is dict for c in configs):
            # Everything is a dictionary, so take union if the probability is
            # non-zero
            return set().union(*[
                self._get_possible_motifs(c["configs"]) for c in configs
                if c["p"] > 0
            ])
        else:
            raise ValueError("Unknown configuration type")

    @staticmethod
    def _find_clean_window_start(match_bools, length):
        """
        Finds the first window of a given length which contains no motif-match
        start positions.
        Arguments:
            `match_bools`: a 1D boolean NumPy array which is True wherever a
                motif match starts
            `length`: length of the window required
        Returns the start index of the first window of `length` consecutive
        False entries, or None if there is no such window.
        """
        # Get indices of matches, and tack on sentinels of -1 and the array
        # length so the stretches before the first match and after the last
        # match are considered too
        match_inds = np.pad(
            np.where(match_bools)[0], (1, 1), "constant",
            constant_values=(-1, len(match_bools))
        )
        # Two matches at indices i < j have j - i - 1 clean positions strictly
        # between them, so a window of `length` fits only if j - i is at least
        # `length` + 1; the same threshold must be used for deciding whether a
        # window exists and for picking it
        fitting_gaps = np.where(np.diff(match_inds) >= length + 1)[0]
        if not len(fitting_gaps):
            return None
        return int(match_inds[fitting_gaps[0]]) + 1

    def sample_random_seq(
        self, length=None, motifs_blacklist=None, match_reject_prob=0,
        match_score_thresh=0.9, match_revcomp=True
    ):
        """
        Samples a random sequence as defined by `self.seq_alphabet` and
        `self.bg_seq_freqs`.
        Arguments:
            `length`: length of sequence to sample; defaults to
                `self.input_length`
            `motifs_blacklist`: an iterable of NumPy arrays of motifs to avoid
                in the sampled sequence; if this is None or empty, there is
                nothing to avoid and the sequence is never filtered
            `match_reject_prob`: if the sampled sequence contains any motif in
                `motifs_blacklist`, then it is rejected with this probability;
                this probability may be 0
            `match_score_thresh`: match-score threshold for considering a motif
                to be present in the sequence; the score at a position is the
                cross-correlation of the one-hot sequence with the motif,
                divided by the motif length
            `match_revcomp`: if True, also check motif reverse complements for
                matches (each motif flipped along both axes)
        Returns a random string sequence.
        """
        if length is None:
            length = self.input_length

        # Exactly one draw is always spent on deciding whether to filter
        filter_draw = self.rng.random()
        motifs_to_match = \
            [] if motifs_blacklist is None else list(motifs_blacklist)

        # The draw lies in [0, 1), so filtering when it is strictly below
        # `match_reject_prob` happens with exactly that probability (and never
        # when the probability is 0). Filtering is also pointless when there
        # are no motifs to avoid, or when the sequence is empty.
        if filter_draw >= match_reject_prob or not motifs_to_match \
            or length == 0:
            # Just sample a sequence and return it
            inds = np.searchsorted(
                self.bg_seq_freqs_cumsum, self.rng.random(length)
            )
            return "".join(self.seq_alphabet_arr[inds])

        if match_revcomp:
            motifs_to_match.extend([np.flip(m) for m in motifs_to_match])

        # Keep searching until we find a sequence we like
        while True:
            # Sample a long sequence (10 * length arbitrarily)
            inds = np.searchsorted(
                self.bg_seq_freqs_cumsum, self.rng.random(length * 10)
            )
            long_seq = utils.seqs_to_one_hot(
                ["".join(self.seq_alphabet_arr[inds])]
            )[0]  # Shape: L x 4

            # Scan for matches to get a boolean array for each motif, which
            # marks the positions where a match to that motif starts
            all_match_bools = [
                (
                    scipy.signal.correlate(
                        long_seq, m, mode="valid"
                    )[:, 0] / len(m)
                ) >= match_score_thresh
                for m in motifs_to_match
            ]
            # Left-justify and cut off; we won't care about different
            # lengths arising from different motif sizes
            min_length = min(len(arr) for arr in all_match_bools)
            # Take logical or
            match_bools = np.any(
                np.stack([arr[:min_length] for arr in all_match_bools]),
                axis=0
            )

            start = self._find_clean_window_start(match_bools, length)
            if start is None:
                # No stretch between matches is long enough; try again
                continue

            return utils.one_hot_to_seqs(
                long_seq[start : start + length][None]
            )[0]

    def _sample_motif_seq(self, motif):
        """
        Samples a random sequence from a motif probability matrix. The alphabet
        is assumed to be `self.seq_alphabet`.
        Arguments:
            `motif`: an L x D array of probabilities
        Returns a string of length L.
        """
        # Renormalize to ensure it sums to 1
        motif = motif / np.sum(motif, axis=1, keepdims=True)
        prob_cumsums = np.cumsum(motif, axis=1)
        # One independent draw per motif position
        return "".join([
            self.seq_alphabet_arr[np.searchsorted(probs, self.rng.random())]
            for probs in prob_cumsums
        ])

    def sample_config_seq(
        self, configs=None, motifs_blacklist=None, match_reject_prob=0,
        match_score_thresh=0.9, match_revcomp=True, return_config=False
    ):
        """
        Samples a single sequence as defined by the motif configuration.
        Arguments:
            `configs`: a list of configurations; default is to use
                `self.configs`
            `motifs_blacklist`: an iterable of NumPy arrays of motifs to avoid
                in the sampled background (both the flanks and the spacers
                between motifs)
            `match_reject_prob`: if the sampled sequence contains any motif in
                `motifs_blacklist`, then it is rejected with this probability;
                this probability may be 0
            `match_score_thresh`: match-score threshold for considering a motif
                to be present in the sequence.
            `match_revcomp`: if True, also check motif reverse complements for
                matches
            `return_config`: if True, also return the specific configuration
                which was sampled (a pair of the start position and a list which
                may contain strings and/or integers)
        `configs` must be one of the following 3 forms:
        1) A list of strings and/or spacings specifying the motif IDs and spaces
            between them (e.g. ["GATA1", 8, "TAL1"])
        2) A list of lists, where each inner list is a set of configurations to
            be chosen from uniformly at random
        3) A list of dictionaries, where each dictionary specifies a probability
            and a configuration; dictionaries are selected according to their
            (normalized) probabilities
        Returns a string of length `self.input_length`, and perhaps a pair of
        the exact configuration chosen. Raises a ValueError if `configs` mixes
        the forms above, or if the sampled motif arrangement does not fit in
        `self.input_length` at the sampled offset.
        """
        # Motif configurations are defined recursively
        if configs is None:
            configs = self.configs

        if all(type(c) in (str, int) for c in configs):
            # This is a single config containing a motif specification; this
            # also includes if the config is an empty list, in which case there
            # is no motif to insert
            pieces = []
            for c in configs:
                if type(c) is str:
                    pieces.append(self._sample_motif_seq(self.motif_dict[c]))
                else:
                    # An integer is a spacer of random background
                    pieces.append(self.sample_random_seq(
                        c, motifs_blacklist, match_reject_prob,
                        match_score_thresh, match_revcomp
                    ))
            motif_string = "".join(pieces)

            # Pick random offset from center
            offset = self.rng.integers(
                -self.motif_center_dist_bound, self.motif_center_dist_bound + 1
            )

            left_pad = (self.input_length // 2) - (len(motif_string) // 2) + \
                offset
            right_pad = self.input_length - left_pad - len(motif_string)
            if left_pad < 0 or right_pad < 0:
                # Fail with an explanation, rather than with an error about
                # negative dimensions when sampling the flanks
                raise ValueError(
                    "Motif configuration of length {} at offset {} does not "
                    "fit in a sequence of length {}".format(
                        len(motif_string), offset, self.input_length
                    )
                )

            left_seq = self.sample_random_seq(
                left_pad, motifs_blacklist, match_reject_prob,
                match_score_thresh, match_revcomp
            )
            right_seq = self.sample_random_seq(
                right_pad, motifs_blacklist, match_reject_prob,
                match_score_thresh, match_revcomp
            )
            result = left_seq + motif_string + right_seq

            if return_config:
                return result, (left_pad, configs)
            else:
                return result

        elif all(type(c) is list for c in configs):
            # Everything is a list, so pick uniformly
            index = self.rng.integers(len(configs))
            return self.sample_config_seq(
                configs[index], motifs_blacklist, match_reject_prob,
                match_score_thresh, match_revcomp, return_config
            )

        elif all(type(c) is dict for c in configs):
            # Everything is a dictionary, so pick according to the probabilities
            probs = np.array([c["p"] for c in configs])
            probs = probs / np.sum(probs)
            index = np.searchsorted(np.cumsum(probs), self.rng.random())
            return self.sample_config_seq(
                configs[index]["configs"], motifs_blacklist, match_reject_prob,
                match_score_thresh, match_revcomp, return_config
            )

        else:
            raise ValueError("Unknown configuration type")

    def seqs_to_one_hot(self, seqs):
        """
        Converts a list of strings to one-hot encodings, where the encoding is
        specified by `self.seq_alphabet`.
        Arguments:
            `seqs`: a list of N strings, where every string is the same length L
        Returns an N x L x D NumPy array of one-hot encodings, in the same order
        as the input sequences.
        """
        return utils.seqs_to_one_hot(
            seqs, alphabet=self.seq_alphabet, to_upper=False
        )


class SimulatedSeqDataset(torch.utils.data.IterableDataset):
    def __init__(
        self, pos_seq_simulator, batch_size, num_batches, negative_ratio,
        neg_seq_simulator=None, revcomp=False, background_match_reject_prob=0,
        background_match_score_thresh=0.9, return_configs=False
    ):
        """
        Generates batches of one-hot-encoded sequences and binary labels.
        Arguments:
            `pos_seq_simulator` (SeqSimulator): generates simulated sequences
                with motif configurations for the positive label
            `batch_size`: number of sequences per batch, B
            `num_batches`: number of batches in an epoch
            `negative_ratio`: generate this many negative sequences per batch as
                positive ones
            `neg_seq_simulator` (SeqSimulator): generates simulated sequences
                with motif configurations for the negative label; by default,
                this is None, and negative sequences are randomized backgrounds
            `revcomp`: whether or not to perform revcomp to the batch; this will
                not change the batch size, but halve the number of unique
                objects per batch; if True, `batch_size` must be even
            `background_match_reject_prob`: probability of rejecting a random
                background sequence if it has a motif match; this probability
                may be 0
            `background_match_score_thresh`: match-score threshold for
                considering a background sequence to have a motif
            `return_configs`: if True, also return the specific start positions
                and configurations used to generate the sequences
        In each batch, generates a B x L x D tensor of one-hot encodings and a
        B-tensor of binary labels. May also return a B-list of configurations
        for each sequence (each is a triplet of a start position index, a list
        containing motif keys and spacings, and whether or not the
        sequence/configuration is reverse complemented). For a negative which is
        purely random background, the start position is None and the list is
        empty.
        """
        self.pos_seq_simulator = pos_seq_simulator
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.negative_ratio = negative_ratio
        self.neg_seq_simulator = neg_seq_simulator
        self.revcomp = revcomp
        self.background_match_reject_prob = background_match_reject_prob
        self.background_match_score_thresh = background_match_score_thresh
        self.return_configs = return_configs

        # The motifs to keep out of backgrounds are only needed if backgrounds
        # can be rejected at all
        if background_match_reject_prob > 0:
            self.possible_motifs = {
                key : pos_seq_simulator.motif_dict[key] for key in
                pos_seq_simulator._get_possible_motifs()
            }
            if neg_seq_simulator is not None:
                self.possible_motifs.update({
                    key : neg_seq_simulator.motif_dict[key] for key in
                    neg_seq_simulator._get_possible_motifs()
                })
        else:
            self.possible_motifs = {}

        if revcomp:
            assert batch_size % 2 == 0
            revcomp_factor = 2
        else:
            revcomp_factor = 1

        # With revcomp, only half of the batch is sampled; the other half is
        # the reverse complement of the first
        self.num_pos_per_batch = int(
            np.ceil((batch_size // revcomp_factor) / (1 + negative_ratio))
        )
        self.num_neg_per_batch = (batch_size // revcomp_factor) - \
            self.num_pos_per_batch

    def get_batch(self, index):
        """
        Returns a batch, which consists of a B x L x D tensor of 1-hot encoded
        sequences and a B-tensor of labels. Positives come first, followed by
        negatives; with revcomp, this is followed by the reverse complements in
        the same order. If `self.return_configs` is True, also returns a B-list
        of configuration triplets.
        Arguments:
            `index`: unused argument which normally would specify the index of a
                batch
        """
        labels = np.concatenate([
            np.ones(self.num_pos_per_batch), np.zeros(self.num_neg_per_batch)
        ])

        # Sample positive sequences
        pos_samples = [
            self.pos_seq_simulator.sample_config_seq(
                motifs_blacklist=self.possible_motifs.values(),
                match_reject_prob=self.background_match_reject_prob,
                match_score_thresh=self.background_match_score_thresh,
                match_revcomp=self.revcomp,
                return_config=self.return_configs
            ) for _ in range(self.num_pos_per_batch)
        ]
        if self.return_configs:
            pos_seqs, pos_configs = zip(*pos_samples)
            pos_seqs, pos_configs = list(pos_seqs), list(pos_configs)
        else:
            pos_seqs = pos_samples
        pos_one_hots = self.pos_seq_simulator.seqs_to_one_hot(pos_seqs)

        # Sample negative sequences
        if not self.num_neg_per_batch:
            # No negatives in the batch
            neg_one_hots = np.empty((0,) + pos_one_hots.shape[1:])
            if self.return_configs:
                neg_configs = []
        else:
            if self.neg_seq_simulator is None:
                neg_seqs = [
                    self.pos_seq_simulator.sample_random_seq(
                        motifs_blacklist=self.possible_motifs.values(),
                        match_reject_prob=self.background_match_reject_prob,
                        match_score_thresh=self.background_match_score_thresh,
                        match_revcomp=self.revcomp
                    ) for _ in range(self.num_neg_per_batch)
                ]
                if self.return_configs:
                    # Configurations are just empty here: no start position
                    # and no motifs. These must still be pairs, because the
                    # orientation flag is appended to each of them below.
                    neg_configs = [(None, []) for _ in neg_seqs]
            else:
                neg_samples = [
                    self.neg_seq_simulator.sample_config_seq(
                        motifs_blacklist=self.possible_motifs.values(),
                        match_reject_prob=self.background_match_reject_prob,
                        match_score_thresh=self.background_match_score_thresh,
                        match_revcomp=self.revcomp,
                        return_config=self.return_configs
                    ) for _ in range(self.num_neg_per_batch)
                ]
                if self.return_configs:
                    neg_seqs, neg_configs = zip(*neg_samples)
                    neg_seqs, neg_configs = list(neg_seqs), list(neg_configs)
                else:
                    neg_seqs = neg_samples
            # Convert to one-hot sequences, just use the positive simulator
            neg_one_hots = self.pos_seq_simulator.seqs_to_one_hot(neg_seqs)

        one_hots = np.concatenate([pos_one_hots, neg_one_hots])
        if self.return_configs:
            configs_no_orient = pos_configs + neg_configs

        if self.revcomp:
            # Only support reverse-complement augmentation for ACGT sequences
            assert self.pos_seq_simulator.seq_alphabet == "ACGT"
            # Flipping both the position and the base axes gives the reverse
            # complement, since the complement of ACGT is its reverse
            one_hots = np.concatenate([
                one_hots, np.flip(one_hots, axis=(1, 2))
            ])
            labels = np.concatenate([labels, labels])
            if self.return_configs:
                configs = [tup + (False,) for tup in configs_no_orient] + \
                    [tup + (True,) for tup in configs_no_orient]
        else:
            if self.return_configs:
                # Nothing is reverse complemented without revcomp
                configs = [tup + (False,) for tup in configs_no_orient]

        if self.return_configs:
            return torch.tensor(one_hots), torch.tensor(labels), configs
        else:
            return torch.tensor(one_hots), torch.tensor(labels)

    def __iter__(self):
        """
        Returns an iterator over the batches. If the dataset iterator is called
        from multiple workers, each worker will be given a shard of the full
        range.
        """
        worker_info = torch.utils.data.get_worker_info()
        num_batches = self.num_batches
        if worker_info is None:
            # In single-processing mode
            start, end = 0, num_batches
        else:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
            shard_size = int(np.ceil(num_batches / num_workers))
            start = shard_size * worker_id
            # If the start is past the end, the range below is simply empty
            end = min(start + shard_size, num_batches)
        return (self.get_batch(i) for i in range(start, end))

    def __len__(self):
        """
        Returns the number of batches in an epoch.
        """
        return self.num_batches

    def on_epoch_start(self):
        """
        Placeholder function that does nothing
        """
        pass

def _identity_collate(batch):
    """
    Returns `batch` unchanged. The dataset already yields whole batches, so
    there is nothing to collate. This is a module-level function rather than a
    lambda so that it can be pickled when workers are started as new processes.
    """
    return batch


def _seed_worker_simulators(worker_id):
    """
    Reseeds the sequence simulators of the dataset copy held by a data-loading
    worker, so that the positive and negative simulators of every worker all
    get different seeds. This is a module-level function rather than a nested
    one so that it can be pickled when workers are started as new processes.
    Arguments:
        `worker_id`: index of the worker being initialized
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        seed = torch.initial_seed() % (2 ** 32)
        worker_info.dataset.pos_seq_simulator.set_seed(
            seed + (2 * worker_id)
        )
        if worker_info.dataset.neg_seq_simulator is not None:
            worker_info.dataset.neg_seq_simulator.set_seed(
                seed + (2 * worker_id) + 1
            )


def create_data_loader(
    motif_config_path, batch_size=128, num_batches=100, input_length=500,
    seq_alphabet="ACGT", bg_seq_freqs=np.array([0.25, 0.25, 0.25, 0.25]),
    motif_center_dist_bound=50, negative_ratio=1, revcomp=True,
    background_match_reject_prob=1, background_match_score_thresh=0.9,
    num_workers=10, data_seed=None, neg_motif_config_path=None,
    return_configs=False
):
    """
    Creates a PyTorch DataLoader object which iterates through batches of data.
    Arguments:
        `motif_config_path`: path to JSON file containing motif configurations
            to sample from
        `batch_size`: number of sequences per batch
        `num_batches`: number of batches in an epoch
        `input_length`: length of input sequences to generate
        `seq_alphabet`: string containing alphabet of sequences
        `bg_seq_freqs`: NumPy array of background frequency for each sequence
            token
        `motif_center_dist_bound`: maximum random distance of motif-
            configuration center from sequence center
        `negative_ratio`: generate this many negative sequences per batch as
            positive ones
        `revcomp`: whether or not to add reverse complements to each batch; if
            True, `batch_size` must be even
        `background_match_reject_prob`: probability of rejecting a random
            background sequence if it has a motif match
        `background_match_score_thresh`: match-score threshold for considering
            a background sequence to have a motif
        `num_workers`: number of worker processes for the DataLoader
        `data_seed`: seed for sampling sequences; if None, sampling is seeded
            randomly
        `neg_motif_config_path`: path to JSON file containing motif
            configurations for negative examples; defaults to having negatives
            just be random backgrounds
        `return_configs`: if True, each batch also returns the offsets and
            configurations used to create simulated sequences
    Returns a DataLoader whose items are the batches generated by a
    SimulatedSeqDataset.
    """
    pos_seq_simulator = SeqSimulator(
        motif_config_path, input_length, seq_alphabet, bg_seq_freqs,
        motif_center_dist_bound, data_seed
    )

    if neg_motif_config_path:
        # Give the negative simulator a different seed from the positive one
        # (as is done when seeding workers), so the two do not draw identical
        # random streams when no workers are used
        neg_seed = None if data_seed is None else data_seed + 1
        neg_seq_simulator = SeqSimulator(
            neg_motif_config_path, input_length, seq_alphabet, bg_seq_freqs,
            motif_center_dist_bound, neg_seed
        )
    else:
        neg_seq_simulator = None

    dataset = SimulatedSeqDataset(
        pos_seq_simulator, batch_size, num_batches, negative_ratio,
        neg_seq_simulator, revcomp, background_match_reject_prob,
        background_match_score_thresh, return_configs=return_configs
    )

    generator = torch.Generator()
    if data_seed is not None:
        # This sets the initial state of torch.initial_seed(), making future
        # calls to it deterministic
        generator.manual_seed(data_seed)
    else:
        # This makes sure that when there is no seed provided, the generator is
        # seeded randomly
        generator.seed()

    # Dataset loader: dataset is iterable and already returns batches, so
    # automatic batching is disabled with a batch size of None
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=None, num_workers=num_workers,
        collate_fn=_identity_collate, worker_init_fn=_seed_worker_simulators,
        generator=generator
    )

    return loader
