# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/data.ipynb.

# %% auto 0
# Names exported by `from <module> import *`; this list is generated from ../nbs/data.ipynb.
__all__ = ["Interval", "Transcript", "RefseqDataset"]

# %% ../nbs/data.ipynb 2
import pandas as pd
import numpy as np
from .genome import Genome
from typing import List, Sequence, Tuple
import os

# %% ../nbs/data.ipynb 3
class Interval(object):
    """
    Interval for representing genomic coordinates.

    Uses the 0-start, half-open (0-based) coordinate system, the same convention as the
    UCSC genome browser's internal representation: ``start`` is the first base included
    and ``end`` is one past the last base, so ``len(interval) == end - start``.
    For more information:
        - http://genome.ucsc.edu/blog/the-ucsc-genome-browser-coordinate-counting-systems/
        - https://www.biostars.org/p/84686/

    Sequence access (``sequence``, ``encode``, ``one_hot_encode``) is delegated to the
    ``genome`` object supplied at construction time.
    """

    def __init__(
        self, chromosome: str, start: int, end: int, strand: str, genome: Genome
    ):
        """
        Args:
            chromosome: chromosome/contig name, e.g. ``"chr1"``. Not validated.
            start: 0-based inclusive start coordinate; must satisfy 0 <= start <= end.
            end: 0-based exclusive end coordinate.
            strand: ``"+"`` or ``"-"``.
            genome: object used by the sequence/encoding methods to fetch bases.

        Raises:
            AssertionError: if the strand is not "+"/"-" or the coordinates are invalid.
        """
        assert strand in ["+", "-"], strand
        self.strand = strand

        # No whitelist check is performed: any contig name is accepted.
        self.chromosome = chromosome
        self.chrom = chromosome  # short alias used by the comparison methods

        assert start <= end, (start, end)
        assert start >= 0, (start, end)
        self.start = start
        self.end = end

        self.genome = genome
        # Integer code per nucleotide used by `encode`; "N" marks an unknown base.
        self.alphabet_map = {"A": 0, "C": 1, "T": 2, "G": 3, "N": 4}

    def __len__(self):
        """Number of bases covered by the interval (``end - start``)."""
        return self.end - self.start

    def __repr__(self):
        """
        Return string representation of the interval
        """
        return "Interval {}:{}-{}:{}".format(
            self.chromosome, self.start, self.end, self.strand
        )

    def overlaps(self, interval: "Interval") -> int:
        """
        Return the number of bases shared with ``interval``.

        Intervals on a different chromosome or strand never overlap, and intervals that
        merely touch (one ends where the other starts) share 0 bases.

        Raises:
            AssertionError: if ``interval`` is not an ``Interval``.
        """
        assert isinstance(interval, Interval), type(interval)
        if interval.chrom != self.chrom or interval.strand != self.strand:
            return 0

        overlap = min(self.end, interval.end) - max(self.start, interval.start)
        # Disjoint intervals give a negative raw overlap; clamp it to 0.
        return max(overlap, 0)

    def within(self, interval: "Interval") -> bool:
        """
        Return True if this interval lies entirely inside ``interval``.

        Both endpoints may coincide with the endpoints of ``interval``. Intervals on a
        different chromosome or strand are never considered contained.

        Raises:
            AssertionError: if ``interval`` is not an ``Interval``.
        """
        assert isinstance(interval, Interval), type(interval)
        if interval.chrom != self.chrom or interval.strand != self.strand:
            return False

        return interval.start <= self.start and self.end <= interval.end

    def one_hot_encode(self, zero_mean: bool = True):
        """
        Return the one-hot encoding of the interval produced by
        ``genome.get_encoding_from_coords``.

        Args:
            zero_mean: if True, subtract 0.25 from every entry of the encoding.
        """
        seq = self.genome.get_encoding_from_coords(
            self.chrom, self.start, self.end, self.strand
        )
        if zero_mean:
            seq = seq - 0.25
        return seq

    def sequence(self):
        """Return the sequence of the interval as provided by the genome object."""
        return self.genome.get_sequence_from_coords(
            self.chrom, self.start, self.end, self.strand
        )

    def encode(self) -> np.ndarray:
        """
        Return the sequence as a 1-D integer array using ``alphabet_map``.

        Lower-case bases are upper-cased first. A base outside ``alphabet_map`` raises
        ``KeyError``. An empty interval yields an empty integer array.
        """
        seq = self.sequence().upper()
        return np.array([self.alphabet_map[base] for base in seq], dtype=int)

# %% ../nbs/data.ipynb 5
class Transcript(object):
    """
    An object representing an RNA transcript, allowing one to query its RNA sequence and
    convert it to one-hot or integer encodings.

    Genomic coordinates use the same 0-based, half-open convention as ``Interval``.
    The transcript is modelled as a list of ``Interval`` objects
    (``transcript_intervals``) in ascending genomic order:
      - an optional flank before ``tx_start``;
      - the (optionally expanded) exons;
      - an optional flank after ``tx_end``.
    Methods that return sequence concatenate these intervals left-to-right for "+"
    transcripts and right-to-left for "-" transcripts.
    """

    # Standard genetic code used by `translate_dna`; stop codons map to "*".
    _CODON_MAP = {
        'TTT': 'F', 'CTT': 'L', 'ATT': 'I', 'GTT': 'V',
        'TTC': 'F', 'CTC': 'L', 'ATC': 'I', 'GTC': 'V',
        'TTA': 'L', 'CTA': 'L', 'ATA': 'I', 'GTA': 'V',
        'TTG': 'L', 'CTG': 'L', 'ATG': 'M', 'GTG': 'V',
        'TCT': 'S', 'CCT': 'P', 'ACT': 'T', 'GCT': 'A',
        'TCC': 'S', 'CCC': 'P', 'ACC': 'T', 'GCC': 'A',
        'TCA': 'S', 'CCA': 'P', 'ACA': 'T', 'GCA': 'A',
        'TCG': 'S', 'CCG': 'P', 'ACG': 'T', 'GCG': 'A',
        'TAT': 'Y', 'CAT': 'H', 'AAT': 'N', 'GAT': 'D',
        'TAC': 'Y', 'CAC': 'H', 'AAC': 'N', 'GAC': 'D',
        'TAA': '*', 'CAA': 'Q', 'AAA': 'K', 'GAA': 'E',
        'TAG': '*', 'CAG': 'Q', 'AAG': 'K', 'GAG': 'E',
        'TGT': 'C', 'CGT': 'R', 'AGT': 'S', 'GGT': 'G',
        'TGC': 'C', 'CGC': 'R', 'AGC': 'S', 'GGC': 'G',
        'TGA': '*', 'CGA': 'R', 'AGA': 'R', 'GGA': 'G',
        'TGG': 'W', 'CGG': 'R', 'AGG': 'R', 'GGG': 'G'
    }

    def __init__(
        self,
        transcript_id: str,
        gene: str,
        exon_starts: List[int],
        exon_ends: List[int],
        genome: Genome,
        exon_count: int,
        strand: str,
        chromosome: str,
        tx_start: int,
        tx_end: int,
        cds_start: int,
        cds_end: int,
        expand_transcript_distance: int = 0,
        expand_exon_distance: int = 0,
    ) -> None:
        """
        Args:
            transcript_id: transcript identifier.
            gene: gene name.
            exon_starts: 0-based genomic start of every exon. Exons are expected in
                ascending genomic order and non-overlapping (checked below).
            exon_ends: 0-based exclusive genomic end of every exon, same order.
            genome: object used by every ``Interval`` to fetch sequence.
            exon_count: number of exons; must equal ``len(exon_starts)`` and
                ``len(exon_ends)``.
            strand: ``"+"`` or ``"-"``.
            chromosome: chromosome name.
            tx_start: genomic start of the transcript.
            tx_end: genomic end of the transcript.
            cds_start: genomic start of the coding sequence.
            cds_end: genomic end of the coding sequence.
            expand_transcript_distance: number of bases added as a separate interval
                before ``tx_start`` and after ``tx_end`` (0 disables the flanks).
            expand_exon_distance: maximum number of bases each exon is extended into
                its neighbouring introns. The gap between two exons is split in half, so
                expanded exons never overlap each other or cross ``tx_start``/``tx_end``.

        Raises:
            AssertionError: on an invalid strand, mismatched exon lists, or exons that
            overlap each other or lie outside ``[tx_start, tx_end]``.
        """
        self.transcript_id = transcript_id
        self.gene = gene
        self.chromosome = chromosome
        self.chrom = chromosome
        self.strand = strand
        self.tx_start = tx_start
        self.tx_end = tx_end
        self.cds_start = cds_start
        self.cds_end = cds_end
        self.exon_starts = exon_starts
        self.exon_ends = exon_ends
        self.exon_count = exon_count

        self.expand_transcript_distance = expand_transcript_distance
        self.expand_exon_distance = expand_exon_distance
        self.genome = genome

        assert self.strand in ["+", "-"]
        assert len(self.exon_ends) == len(self.exon_starts), (
            len(self.exon_ends),
            len(self.exon_starts),
            transcript_id,
        )
        assert len(self.exon_ends) == exon_count, (
            len(self.exon_ends),
            exon_count,
            transcript_id,
        )

        # Per-exon (distance_to_next, distance_to_previous) gaps, then the per-exon
        # (upstream, downstream) expansion that fits inside those gaps.
        self.exon_distance_list = self.generate_inter_exon_distances()
        self.exon_expand_distances = self.calculate_expand_distance()

        # Intervals actually read when building sequences: flanks + expanded exons.
        self.transcript_intervals = self.construct_transcript_intervals()

        # Unexpanded exon intervals.
        self.exons = [
            Interval(self.chromosome, start, end, self.strand, self.genome)
            for start, end in zip(exon_starts, exon_ends)
        ]

    def __len__(self):
        """
        Length returns the total length of the generated sequence including exon expand distance
        and transcript expand distance.
        """
        # Builtin sum keeps the result a Python int (also 0 for no intervals).
        return sum(len(interval) for interval in self.transcript_intervals)

    def __repr__(self):
        """
        Return string representation of a transcript
        """
        return "Transcript {} {}:{}-{}:{}".format(
            self.transcript_id, self.chromosome, self.tx_start, self.tx_end, self.strand
        )

    def _ordered_intervals(self) -> List[Interval]:
        """Return ``transcript_intervals`` in concatenation order: genomic on "+", reversed on "-"."""
        if self.strand == "+":
            return self.transcript_intervals
        if self.strand == "-":
            return self.transcript_intervals[::-1]
        raise ValueError("Invalid strand {!r}".format(self.strand))

    def _padding_length(self, pad_length_to: int) -> int:
        """
        Return how many padding positions are needed to reach ``pad_length_to``.

        A falsy ``pad_length_to`` disables padding (returns 0).

        Raises:
            AssertionError: if the transcript is longer than ``pad_length_to``.
        """
        if not pad_length_to:
            return 0
        assert (
            len(self) <= pad_length_to
        ), "Length of transcript {} greater than padding specified {}".format(
            len(self), pad_length_to
        )
        return pad_length_to - len(self)

    def generate_inter_exon_distances(self):
        """
        Compute, for every exon, the number of bases separating it from its neighbours.

        Returns:
            A list with one ``(distance_to_next, distance_to_previous)`` tuple per exon,
            in genomic order. For the last exon, "next" is measured up to ``tx_end``.
            For the first exon, "previous" is measured from ``tx_start``.

        Raises:
            AssertionError: if any distance is negative, i.e. exons overlap or extend
            beyond the transcript bounds.
        """
        # Intron length between consecutive exons: start of exon i+1 minus end of exon i.
        intron_lengths = [
            next_start - end
            for end, next_start in zip(self.exon_ends[:-1], self.exon_starts[1:])
        ]

        distances_to_next_exon = intron_lengths + [self.tx_end - self.exon_ends[-1]]
        distances_to_previous_exon = [
            self.exon_starts[0] - self.tx_start
        ] + intron_lengths
        exon_distance_list = list(
            zip(distances_to_next_exon, distances_to_previous_exon)
        )

        # Require all the elements to not be intersecting
        assert all(dist >= 0 for pair in exon_distance_list for dist in pair), (
            self.transcript_id,
            self.chromosome,
            self.tx_start,
            self.tx_end,
            self.strand,
            self.exon_starts,
            self.exon_ends,
            exon_distance_list,
        )
        return exon_distance_list

    def _exonic_offset(self, position: int) -> int:
        """
        Return the number of transcript bases lying left of genomic ``position``.

        ``position`` must fall inside (or on the boundary of) exactly one transcript
        interval.
        """
        point = Interval(self.chromosome, position, position, self.strand, self.genome)
        matches = np.argwhere(
            [point.within(interval) for interval in self.transcript_intervals]
        )
        assert len(matches) == 1
        index = matches[0][0]

        # Bases in all intervals up to and including the containing one ...
        length_through_interval = sum(
            len(interval) for interval in self.transcript_intervals[: index + 1]
        )
        # ... minus the part of the containing interval that lies right of `position`.
        distance_to_interval_end = self.transcript_intervals[index].end - position
        assert distance_to_interval_end >= 0
        return length_through_interval - distance_to_interval_end

    def calculate_relative_cds_start_end(self):
        """
        Return ``(start, end)`` of the coding sequence in transcript coordinates.

        The offsets index into the sequence returned by ``get_sequence``, with ``end``
        exclusive, so ``get_sequence()[start:end]`` is the CDS. Only available when no
        exon or transcript expansion is used.

        Raises:
            AssertionError: if expansion is enabled, or if ``cds_start``/``cds_end`` do
            not fall within exactly one exon.
        """
        assert self.expand_exon_distance == 0, "Doesn't work with introns"
        assert self.expand_transcript_distance == 0, "Doesn't work with promoters"
        assert self.transcript_intervals

        # Offsets counted left-to-right in genomic order.
        offset_cds_start = self._exonic_offset(self.cds_start)
        offset_cds_end = self._exonic_offset(self.cds_end)

        if self.strand == "+":
            return offset_cds_start, offset_cds_end

        # On "-" the sequence is concatenated right-to-left, so both offsets are
        # mirrored and the genomic start becomes the transcript end.
        total_length = len(self)
        return total_length - offset_cds_end, total_length - offset_cds_start

    def calculate_relative_splice_sites(self):
        """
        Return the index of the last base of every exon in transcript coordinates.

        Indices refer to the sequence returned by ``get_sequence``. They are listed in
        transcript order, and the final entry is always ``len(self) - 1``. Only available
        when no exon or transcript expansion is used.
        """
        assert self.expand_exon_distance == 0, "Doesn't work with introns"
        assert self.expand_transcript_distance == 0, "Doesn't work with promoters"
        assert self.transcript_intervals

        indices = []
        running_length = 0
        for interval in self._ordered_intervals():
            running_length += len(interval)
            # 0-based index of the exon's last base
            indices.append(running_length - 1)
        return indices

    def calculate_expand_distance(self):
        """
        Calculate how far every exon may be expanded into its flanking sequence.

        The gap to a neighbouring exon is shared, so each side may use at most half of
        it. The gap before the first exon (to ``tx_start``) and after the last exon (to
        ``tx_end``) is not shared and may be used entirely. The expansion never exceeds
        ``expand_exon_distance``.

        Returns:
            A list with one ``(upstream, downstream)`` tuple of ints per exon, in
            genomic order.
        """
        last_index = self.exon_count - 1
        exon_expand_distances = []
        for i in range(self.exon_count):
            distance_to_next, distance_to_previous = self.exon_distance_list[i]

            free_upstream = distance_to_previous if i == 0 else distance_to_previous / 2
            free_downstream = (
                distance_to_next if i == last_index else distance_to_next / 2
            )

            # int() floors the halved gap so two neighbouring expansions never meet.
            expand_up = min(self.expand_exon_distance, int(free_upstream))
            expand_down = min(self.expand_exon_distance, int(free_downstream))

            exon_expand_distances.append((expand_up, expand_down))
        return exon_expand_distances

    def construct_transcript_intervals(self):
        """
        Build ``transcript_intervals`` in genomic order.

        The list contains:
          - the flank before ``tx_start`` (only if ``expand_transcript_distance``);
          - every exon, expanded by ``exon_expand_distances``;
          - the flank after ``tx_end`` (only if ``expand_transcript_distance``).
        """
        transcript_intervals = []

        if self.expand_transcript_distance:
            transcript_intervals.append(
                Interval(
                    self.chromosome,
                    self.tx_start - self.expand_transcript_distance,
                    self.tx_start,
                    self.strand,
                    self.genome,
                )
            )

        for start, end, (expand_up, expand_down) in zip(
            self.exon_starts, self.exon_ends, self.exon_expand_distances
        ):
            transcript_intervals.append(
                Interval(
                    self.chromosome,
                    start - expand_up,
                    end + expand_down,
                    self.strand,
                    self.genome,
                )
            )

        if self.expand_transcript_distance:
            transcript_intervals.append(
                Interval(
                    self.chromosome,
                    self.tx_end,
                    self.tx_end + self.expand_transcript_distance,
                    self.strand,
                    self.genome,
                )
            )

        return transcript_intervals

    def one_hot_encode_transcript(
        self, pad_length_to: int = 0, zero_mean: bool = True, zero_pad: bool = False
    ):
        """
        Return the one-hot encoding of the whole transcript, optionally padded.

        Args:
            pad_length_to: if non-zero, append padding rows up to this length.
            zero_mean: forwarded to ``Interval.one_hot_encode``.
            zero_pad: when ``zero_mean`` is False, pad with zeros instead of 0.25.

        Raises:
            AssertionError: if the transcript is longer than ``pad_length_to``.
        """
        one_hot_list = [
            interval.one_hot_encode(zero_mean) for interval in self._ordered_intervals()
        ]

        n_pad = self._padding_length(pad_length_to)
        if n_pad:
            # N is represented as [0.25, 0.25, 0.25, 0.25], which is all zeros once
            # zero-meaned; without zero_mean pad with 0.25 unless zero_pad is requested.
            pad_sequence = np.zeros((n_pad, 4), dtype=np.float32)
            if not zero_mean and not zero_pad:
                pad_sequence = pad_sequence + 0.25
            one_hot_list.append(pad_sequence)

        return np.concatenate(one_hot_list)

    def get_sequence(self, pad_length_to: int = 0):
        """
        Return the transcript sequence as a string, optionally padded with "N".

        Raises:
            AssertionError: if the transcript is longer than ``pad_length_to``.
        """
        seqs = [interval.sequence() for interval in self._ordered_intervals()]
        seqs.append("N" * self._padding_length(pad_length_to))
        return "".join(seqs)

    @classmethod
    def translate_dna(cls, dna_sequence):
        """
        Translates a DNA sequence into a protein sequence
        using the standard genetic code.

        Args:
            dna_sequence (str): A string representing the
            DNA sequence. Characters are upper-cased before lookup.

        Returns:
            str: The corresponding protein sequence. Unrecognized codons, including a
            trailing partial codon, are represented by '?'. Stop codons are represented
            by '*'.

        Examples:
            >>> Transcript.translate_dna('ATGGCCATGGCGCCCAGAACTGAGATCAATAGTACCCGTATTAACGGGTGA')
            'MAMAPRTEINSTRING*'
            >>> Transcript.translate_dna('ATGTTTCAA')
            'MFQ'
        """
        codons = (
            dna_sequence[i : i + 3].upper() for i in range(0, len(dna_sequence), 3)
        )
        return "".join(cls._CODON_MAP.get(codon, "?") for codon in codons)

    def get_amino_acid_sequence(self):
        """Translate the coding sequence (see ``calculate_relative_cds_start_end``)."""
        rel_cds_start, rel_cds_end = self.calculate_relative_cds_start_end()
        coding_sequence = self.get_sequence()[rel_cds_start:rel_cds_end]
        return Transcript.translate_dna(coding_sequence)

    def encode(self, pad_length_to: int = 0):
        """
        Return the transcript as a 1-D integer array (see ``Interval.encode``).

        Padding positions are encoded as 4, the code used for "N".

        Raises:
            AssertionError: if the transcript is longer than ``pad_length_to``.
        """
        seqs = np.concatenate(
            [interval.encode() for interval in self._ordered_intervals()]
        ).flatten()

        n_pad = self._padding_length(pad_length_to)
        if n_pad:
            seqs = np.concatenate([seqs, np.repeat(4, n_pad)])
        return seqs

    def encode_splice_track(self, pad_length_to: int = 0):
        """
        Return a ``(length, 1)`` float track with 1 at every index from
        ``calculate_relative_splice_sites`` and 0 elsewhere.

        The length is ``pad_length_to`` if non-zero, otherwise ``len(self)``.
        """
        rel_ss = self.calculate_relative_splice_sites()
        encoding_length = pad_length_to if pad_length_to != 0 else len(self)

        ss_encoded = np.zeros(encoding_length)
        ss_encoded[rel_ss] = 1
        return ss_encoded.reshape(-1, 1)

    def encode_coding_sequence_track(self, pad_length_to: int = 0):
        """
        Return a ``(length, 1)`` float track with 1 at the first base of every codon of
        the CDS and 0 elsewhere.

        The track is all zeros when the CDS is empty. The length is ``pad_length_to`` if
        non-zero, otherwise ``len(self)``.
        """
        rel_cds_start, rel_cds_end = self.calculate_relative_cds_start_end()
        encoding_length = pad_length_to if pad_length_to != 0 else len(self)

        # if no coding sequence return empty track
        if rel_cds_start == rel_cds_end:
            return np.zeros((encoding_length, 1))

        first_nuc_index = np.arange(rel_cds_start, rel_cds_end, 3)
        cds_encoded = np.zeros(encoding_length)
        cds_encoded[first_nuc_index] = 1
        return cds_encoded.reshape(-1, 1)


# %% ../nbs/data.ipynb 8
class RefseqDataset:
    """
    Refseq dataset: a list of ``Transcript`` objects plus helpers that load them from a
    RefSeq table and turn the collection into padded numpy arrays.
    """

    def __init__(self, transcript_list: List[Transcript]):
        """
        Args:
            transcript_list: the transcripts making up the dataset.

        Raises:
            AssertionError: if any element is not a ``Transcript``.
        """
        assert all(
            isinstance(x, Transcript) for x in transcript_list
        ), "Not transcripts passed into dataset {}".format(
            pd.Series([type(x) for x in transcript_list]).value_counts()
        )
        self.transcripts = transcript_list
        self.max_transcript_length = self._longest_transcript_length()
        # chr1..chr22 (there is no chr0).
        self.valid_chromosomes = ["chr{}".format(i) for i in range(1, 23)]

    def __len__(self):
        return len(self.transcripts)

    def __getitem__(self, idx):
        return self.transcripts[idx]

    def _longest_transcript_length(self) -> int:
        """Return the length of the longest transcript, or 0 for an empty dataset."""
        return max((len(t) for t in self.transcripts), default=0)

    def _resolve_pad_length(self, pad_length_to: int) -> int:
        """
        Return the padding length to use, defaulting to the longest transcript.

        Raises:
            AssertionError: if ``pad_length_to`` is shorter than the longest transcript.
        """
        if not pad_length_to:
            pad_length_to = self.max_transcript_length

        assert (
            pad_length_to >= self.max_transcript_length
        ), "Maximum transcript length in dataset:{} greater than pad_length_to:{}".format(
            self.max_transcript_length, pad_length_to
        )
        return pad_length_to

    @staticmethod
    def _pad_and_split(array: np.ndarray, chunk_length: int, pad_value, pad_dtype):
        """
        Pad ``array`` along axis 0 up to a multiple of ``chunk_length`` and split it.

        Only the minimum padding is added, so an array whose length is already a
        multiple of ``chunk_length`` gets no extra chunk. An empty array yields no chunks.
        """
        n_pad = (-array.shape[0]) % chunk_length
        padding = np.full((n_pad,) + array.shape[1:], pad_value, dtype=pad_dtype)
        padded = np.concatenate([array, padding])
        if padded.shape[0] == 0:
            return []
        return np.split(padded, padded.shape[0] // chunk_length)

    @classmethod
    def load_refseq_as_df(
        cls,
        refseq_df_path,
        mini,
        chromosomes_to_use,
        drop_non_nm=False,
    ):
        """
        Read a tab-separated RefSeq table into a DataFrame.

        Args:
            refseq_df_path: path to the table (compression inferred from extension).
            mini: if True, read only the first 1000 rows.
            chromosomes_to_use: keep only rows whose ``chrom`` is in this collection.
            drop_non_nm: if True, keep only rows whose ``name`` starts with "NM".
        """
        df = pd.read_csv(
            refseq_df_path, compression="infer", sep="\t", nrows=1000 if mini else None
        )
        # Subset to just transcript coding transcripts
        if drop_non_nm:
            df = df[df["name"].str.startswith("NM")]

        return df[df["chrom"].isin(chromosomes_to_use)]

    @classmethod
    def refseq_df_to_transcripts(
        cls, df, expand_transcript_distance, expand_exon_distance, genome
    ):
        """
        Convert every row of a RefSeq DataFrame into a ``Transcript``.

        Expects the columns name, name2, chrom, strand, txStart, txEnd, cdsStart,
        cdsEnd, exonCount, exonStarts and exonEnds.
        """

        def parse_coordinates(field: str) -> List[int]:
            # Comma-separated integers; empty fields (e.g. after a trailing comma) are skipped.
            return [int(x) for x in field.split(",") if x]

        transcripts = []
        for _, row in df.iterrows():
            transcripts.append(
                Transcript(
                    genome=genome,
                    transcript_id=row["name"],
                    gene=row["name2"],
                    exon_starts=parse_coordinates(row["exonStarts"]),
                    exon_ends=parse_coordinates(row["exonEnds"]),
                    exon_count=row["exonCount"],
                    strand=row["strand"],
                    chromosome=row["chrom"],
                    tx_start=row["txStart"],
                    tx_end=row["txEnd"],
                    cds_start=row["cdsStart"],
                    cds_end=row["cdsEnd"],
                    expand_transcript_distance=expand_transcript_distance,
                    expand_exon_distance=expand_exon_distance,
                )
            )
        return transcripts

    @classmethod
    def load_refseq(
        cls,
        refseq_df_path,
        genome,
        expand_transcript_distance=0,
        expand_exon_distance=0,
        mini=False,
        drop_non_nm=False,
        valid_chromosomes=True
    ):
        """
        Load a RefSeq table and return a list of ``Transcript`` objects.

        Args:
            valid_chromosomes: if True, keep only chr1..chr22. Otherwise keep every
                chromosome returned by ``genome.get_chrs()``.
            Other arguments are forwarded to ``load_refseq_as_df`` and
            ``refseq_df_to_transcripts``.

        Returns:
            A list of ``Transcript`` objects (not a ``RefseqDataset``).
        """
        if valid_chromosomes:
            chromosomes_to_use = ["chr{}".format(i) for i in range(1, 23)]
        else:
            chromosomes_to_use = genome.get_chrs()

        df = cls.load_refseq_as_df(
            refseq_df_path, mini, drop_non_nm=drop_non_nm,
            chromosomes_to_use=chromosomes_to_use,
        )

        return cls.refseq_df_to_transcripts(
            df, expand_transcript_distance, expand_exon_distance, genome
        )

    def one_hot_encode_dataset(
        self,
        pad_length_to: int = 0,
        zero_mean: bool = True,
        split_transcript: int = 0,
    ) -> np.ndarray:
        """
        One-hot encode every transcript into a single array.

        Args:
            pad_length_to: target length when not splitting; defaults to the longest
                transcript.
            zero_mean: forwarded to ``Transcript.one_hot_encode_transcript``.
            split_transcript: if non-zero, cut every transcript into chunks of this many
                positions, padding the last chunk with zeros. Chunks from all transcripts
                are stacked, and ``pad_length_to`` is then only validated.

        Raises:
            AssertionError: if ``pad_length_to`` is shorter than the longest transcript.
        """
        pad_length_to = self._resolve_pad_length(pad_length_to)

        if not split_transcript:
            return np.array(
                [
                    t.one_hot_encode_transcript(pad_length_to, zero_mean=zero_mean)
                    for t in self.transcripts
                ]
            )

        chunks = []
        for t in self.transcripts:
            # NOTE(review): chunk padding is zeros even when zero_mean=False; confirm intended.
            chunks += self._pad_and_split(
                t.one_hot_encode_transcript(zero_mean=zero_mean),
                split_transcript,
                pad_value=0,
                pad_dtype=np.float32,
            )
        return np.array(chunks)

    def get_sequence_dataset(self, pad_length_to: int = 0) -> np.ndarray:
        """
        Return an array of "N"-padded transcript sequences.

        Raises:
            AssertionError: if ``pad_length_to`` is shorter than the longest transcript.
        """
        pad_length_to = self._resolve_pad_length(pad_length_to)
        return np.array([t.get_sequence(pad_length_to) for t in self.transcripts])

    def get_encoded_dataset(
        self, pad_length_to: int = 0, split_transcript: int = 0
    ) -> np.ndarray:
        """
        Integer-encode every transcript into a single array (padding/N encoded as 4).

        Args:
            pad_length_to: target length when not splitting; defaults to the longest
                transcript.
            split_transcript: if non-zero, cut every transcript into chunks of this many
                positions, padding the last chunk with 4s. Chunks from all transcripts
                are stacked.

        Raises:
            AssertionError: if ``pad_length_to`` is shorter than the longest transcript.
        """
        pad_length_to = self._resolve_pad_length(pad_length_to)

        if not split_transcript:
            return np.array([t.encode(pad_length_to) for t in self.transcripts])

        chunks = []
        for t in self.transcripts:
            # Ns are indicated by 4s
            chunks += self._pad_and_split(
                t.encode(), split_transcript, pad_value=4, pad_dtype=int
            )
        return np.array(chunks)

    def drop_long_transcripts(self, max_length: int):
        """Remove transcripts longer than ``max_length`` and print how many were dropped."""
        number_of_long_transcripts = sum(
            1 for x in self.transcripts if len(x) > max_length
        )
        self.transcripts = [x for x in self.transcripts if len(x) <= max_length]
        # Safe even when every transcript was dropped (max length becomes 0).
        self.max_transcript_length = self._longest_transcript_length()

        print(
            "Dropped {} exceeding length {}".format(
                number_of_long_transcripts, max_length
            )
        )

    def transcript_lengths(self):
        """Return summary statistics (pandas ``describe``) of transcript lengths as a dict."""
        return dict(pd.Series([len(x) for x in self.transcripts]).describe())
