# Standard Library Imports
import os
import argparse

# Third-Party Library Imports
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.utils
import torch.utils.data
from torch.utils.data import DataLoader

# Specific Imports from Third-Party Libraries
from tqdm import tqdm

"""
FUNCTIONS FOR ARGUMENT PARSING
"""


def clf_parse_args():
    """Build the ``clf`` argument parser and parse ``sys.argv``.

    Options (in registration order):
        ``--batch_size`` (int, default 16), ``--lr`` (float, default 1e-3),
        ``--num_epochs`` (int, default 10) and the required ``--data_path``.

    Returns:
        argparse.Namespace holding the parsed values.
    """
    option_specs = (
        ("--batch_size", dict(type=int, default=16, help="Batch size for training")),
        ("--lr", dict(type=float, default=1e-3, help="Learning rate for adam optimizer")),
        ("--num_epochs", dict(type=int, default=10, help="Number of epochs to train")),
        ("--data_path", dict(type=str, required=True, help="path to folder with data")),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def mex_parse_args():
    """Build the ``mex`` argument parser and parse ``sys.argv``.

    Options (in registration order):
        ``--alpha`` (1 = sufficiency, 0 = necessity, 0.5 = both),
        ``--l1_mult`` / ``--sm_mult`` (multipliers for the L1 and smoothness
        terms on the explainer outputs), ``--batch_size``, ``--lr``,
        ``--num_epochs``, ``--num_bkgd_samples`` and the required
        ``--data_path``.

    Note: ``--alpha``, ``--l1_mult`` and ``--sm_mult`` are declared
    ``type=float`` but have integer defaults; argparse does not convert
    non-string defaults, so an omitted flag yields an ``int`` attribute.

    Returns:
        argparse.Namespace holding the parsed values.
    """
    option_specs = (
        (
            "--alpha",
            dict(
                type=float,
                default=1,
                help="alpha=1 for sufficiency, alpha=0 for necessity, alpha=0.5 for both",
            ),
        ),
        (
            "--l1_mult",
            dict(type=float, default=4, help="Multiplier for l1 norm on explainer outputs"),
        ),
        (
            "--sm_mult",
            dict(type=float, default=2, help="Multiplier for smoothness on explainer outputs"),
        ),
        ("--batch_size", dict(type=int, default=16, help="Batch size for training")),
        ("--lr", dict(type=float, default=1e-3, help="Learning rate for adam optimizer")),
        ("--num_epochs", dict(type=int, default=50, help="Number of epochs to train")),
        (
            "--num_bkgd_samples",
            dict(type=int, default=10, help="Number of background samples to draw"),
        ),
        ("--data_path", dict(type=str, required=True, help="path to folder with data")),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


"""
FUNCTIONS FOR CREATING AND SAVING FILES
"""


def get_gzfile(data_path):
    """Return the name (not the full path) of a ``.gz`` file in ``data_path``.

    The first match in ``os.listdir`` order is returned; that order is
    arbitrary, so with several ``.gz`` files the choice is not guaranteed.

    Raises:
        FileNotFoundError: if ``data_path`` contains no file ending in ``.gz``
            (previously this surfaced as an opaque ``UnboundLocalError``).
    """
    gz_file = next(
        (name for name in os.listdir(data_path) if name.endswith(".gz")), None
    )
    if gz_file is None:
        raise FileNotFoundError(f"No .gz file found in {data_path!r}")
    return gz_file


def create_folder(folder_name):
    """Create ``folder_name`` (including missing parents) unless it exists.

    Prints whether the folder was created or was already present. Any
    existing path, even a regular file, counts as "already exists".
    """
    if os.path.exists(folder_name):
        print(f"Folder already exists: {folder_name}")
        return
    os.makedirs(folder_name)
    print(f"Created folder: {folder_name}")


def save_dict(dict, save_dir, file_name):
    """Write ``dict`` to ``save_dir/file_name`` as one ``key: value`` line per item.

    The parameter name ``dict`` (which shadows the builtin inside this
    function) is kept so keyword callers keep working. The file is
    overwritten, and a confirmation message is printed afterwards.
    """
    path = os.path.join(save_dir, file_name)
    with open(path, "w") as f:
        f.writelines(f"{key}: {value}\n" for key, value in dict.items())
    print(f"{file_name} saved to {path}")


"""
HELPER FUNCTIONS TO RUN EXPERIMENTS
"""
def get_dist_matrix(seq_len=500):
    """Return the matrix of pairwise positional distances ``|i - j|``.

    Args:
        seq_len: number of positions. Defaults to 500, the sequence length
            hard-coded elsewhere in this module, so existing callers are
            unaffected.

    Returns:
        Integer tensor of shape ``(1, seq_len, seq_len)`` whose entry
        ``[0, i, j]`` equals ``|i - j|``.
    """
    positions = torch.arange(seq_len)
    distances = (positions.unsqueeze(1) - positions.unsqueeze(0)).abs()
    return distances.unsqueeze(0)

def get_zero_seq_idx(seqs):
    """Return the indices of sequences whose encoding is zero everywhere.

    Args:
        seqs: array-like of shape ``(N, L, D)``. Any ``L``/``D`` is accepted
            (the previous version only worked for ``L=500, D=4`` because it
            compared against a fixed ``(500, 4)`` zero array).

    Returns:
        1-D integer NumPy array of the ``i`` for which ``seqs[i]`` is all 0.
    """
    seqs = np.asarray(seqs)
    return np.flatnonzero((seqs == 0).all(axis=(1, 2)))

def remove_zero_seqs(original_dataloader):
    """Gather all batches from ``original_dataloader`` and drop all-zero sequences.

    Args:
        original_dataloader: iterable of ``(x, y)`` batches (e.g. a
            ``DataLoader``); each ``x`` / ``y`` must be accepted by
            ``np.concatenate``.

    Returns:
        ``TensorDataset`` of the remaining sequences and labels. Labels are
        removed with ``np.delete`` without an axis, i.e. on the flattened
        label array, so they are presumably 1-D — TODO confirm.
    """
    seq_batches, label_batches = [], []
    for x, y in tqdm(original_dataloader):
        seq_batches.append(x)
        label_batches.append(y)
    all_seqs = np.concatenate(seq_batches)
    all_labels = np.concatenate(label_batches)

    zero_idx = get_zero_seq_idx(all_seqs)
    kept_seqs = np.delete(all_seqs, zero_idx, axis=0)
    kept_labels = np.delete(all_labels, zero_idx)
    return torch.utils.data.TensorDataset(
        torch.tensor(kept_seqs), torch.tensor(kept_labels)
    )

def filter_dataset(original_dataloader, clf, thresh=None, true_pos=False):
    """Keep only the examples that ``clf`` classifies correctly.

    ``clf`` is run over every batch and its outputs are cut at 0.5 to obtain
    hard predictions; an example is kept when that prediction equals its
    label. If ``true_pos`` is truthy, the output must additionally be
    ``>= thresh`` (e.g. with ``thresh >= 0.5`` only confidently-predicted
    correct positives remain).

    Args:
        original_dataloader: iterable of ``(x, y)`` batches.
        clf: torch module whose outputs are compared against 0.5 —
            presumably probabilities (sigmoid outputs); TODO confirm.
        thresh: minimum output required when ``true_pos`` is set.
        true_pos: if truthy, also filter on ``output >= thresh``.

    Returns:
        ``TensorDataset`` of the retained sequences and labels.

    Raises:
        ValueError: if ``true_pos`` is requested without a ``thresh``
            (checked up front, before running the model).
    """
    if true_pos and thresh is None:
        raise ValueError("thresh must be provided when true_pos=True")

    # Generate predictions
    device = next(clf.parameters()).device
    seq_batches, label_batches, prob_batches = [], [], []
    # Inference only: no_grad stops autograd from recording a graph for every
    # batch, which would be thrown away by .detach() anyway and costs memory.
    with torch.no_grad():
        for x, y in tqdm(original_dataloader):
            seq_batches.append(x)
            label_batches.append(y)
            outputs = clf(x.float().to(device))
            prob_batches.append(outputs.squeeze(-1).detach().cpu())
    seqs = np.concatenate(seq_batches)
    labels = np.concatenate(label_batches)
    probs = np.concatenate(prob_batches)
    # Hard predictions at a fixed 0.5 cut-off.
    y_hat = (probs >= 0.5) * 1.0

    # Get correct predictions, optionally restricted to confident ones
    keep = y_hat == labels
    if true_pos:
        keep = keep & (probs >= thresh)
    print("Filtered data!")
    return torch.utils.data.TensorDataset(
        torch.tensor(seqs[keep]), torch.tensor(labels[keep])
    )

def sample_masked_X(x, masks, b_samples):
    """Draw hard one-hot samples that keep ``x`` inside / outside a mask.

    ``masks`` is tiled 4 times along its last dimension, so it is presumably
    ``(batch, L, 1)`` with ``x`` ``(batch, L, 4)`` — TODO confirm. The
    masked inputs get a new dim 1 (``unsqueeze(1)``) that broadcasts against
    ``b_samples``, presumably ``(batch, num_samples, L, 4)``.

    Returns:
        ``(x_S, x_Sc)``, float tensors sampled with hard Gumbel-softmax
        (``tau=1``) from the logits of:
          * ``x_S``: ``masks * x + (1 - masks) * b_samples`` (keep ``x`` on
            the mask, background elsewhere);
          * ``x_Sc``: ``(1 - masks) * x + masks * b_samples`` (the complement).
        Probabilities are clamped to ``[0.01, 0.99]`` by ``torch.logit``.
    """
    keep = masks.repeat(1, 1, 4)
    drop = 1 - keep
    logits_S = torch.logit(
        (keep * x).unsqueeze(1) + b_samples * drop.unsqueeze(1), eps=0.01
    )
    logits_Sc = torch.logit(
        (drop * x).unsqueeze(1) + b_samples * keep.unsqueeze(1), eps=0.01
    )
    # Sample S before S^c so the RNG stream is consumed in a fixed order.
    x_S = F.gumbel_softmax(logits_S, tau=1, hard=True)
    x_Sc = F.gumbel_softmax(logits_Sc, tau=1, hard=True)
    return x_S.float(), x_Sc.float()

def generate_background(N, bkgd_probs=None, seq_len=500, num_chars=4):
    """Draw ``N`` random hard one-hot background sequences.

    Args:
        N: number of sequences to sample.
        bkgd_probs: optional per-position probabilities, array-like of shape
            ``(L, D)``, shared by all ``N`` samples. When ``None``, every
            character is equally likely (``1 / num_chars``).
        seq_len: sequence length used when ``bkgd_probs`` is ``None``
            (default 500, matching the previous hard-coded value).
        num_chars: alphabet size used when ``bkgd_probs`` is ``None``
            (default 4, matching the previous hard-coded value).

    Returns:
        NumPy array of shape ``(N, L, D)`` sampled with hard Gumbel-softmax
        (``tau=1``) from the clamped logits of the probabilities.
    """
    if bkgd_probs is None:
        probs = torch.full((N, seq_len, num_chars), 1.0 / num_chars)
    else:
        probs = torch.tensor(bkgd_probs).unsqueeze(0).repeat(N, 1, 1)
    # eps clamps probabilities to [0.01, 0.99] so logit stays finite.
    logits = torch.logit(probs, eps=0.01)
    samples = F.gumbel_softmax(logits, tau=1, hard=True)
    return samples.numpy()


"""
FUNCTIONS FOR LOADING/FORMATTING DATA
"""


def string_to_char_array(seq):
    """
    Returns the UTF-8 bytes of `seq` as a read-only NumPy `int8` array.
    For ASCII input this is one code per character, e.g. "ACGT" becomes
    [65, 67, 71, 84].
    """
    encoded = seq.encode("utf8")
    return np.frombuffer(encoded, dtype=np.int8)


def char_array_to_string(arr):
    """
    Converts a NumPy array of byte-long ASCII codes into an ASCII string.
    e.g. [65, 67, 71, 84] becomes "ACGT".
    Uses `ndarray.tobytes`; the old `tostring` alias was deprecated in
    NumPy 1.19 and removed in NumPy 2.0.
    """
    return arr.tobytes().decode("ascii")


def one_hot_to_tokens(one_hot):
    """
    Converts an L x D one-hot encoding into an L-vector of integers in the
    range [0, D]. Position i gets the index of the 1 in row i, or D when row
    i is all 0s. Assumes each row contains at most one 1 (and 0s elsewhere).
    """
    length, depth = one_hot.shape[0], one_hot.shape[1]
    # Default every position to the "empty row" token D.
    tokens = np.full(length, depth)
    rows, cols = np.nonzero(one_hot)
    tokens[rows] = cols
    return tokens


def tokens_to_one_hot(tokens, one_hot_dim):
    """
    Converts an L-vector of integers in the range [0, D] to an L x D float
    one-hot encoding, where D is given as `one_hot_dim`. A token equal to D
    is encoded as a row of all 0s.
    """
    # (D+1) x D lookup table: row k is the k-th unit vector, row D is all 0s.
    lookup = np.eye(one_hot_dim + 1)[:, :one_hot_dim]
    return lookup[tokens]


def dinuc_shuffle(seq, num_shufs, rng=None):
    """
    Creates shuffles of the given sequence, in which dinucleotide frequencies
    are preserved.

    The sequence is converted to compact integer tokens; for every token the
    list of positions that follow its occurrences is randomly permuted
    (keeping each token's final successor fixed), and a new sequence is built
    by walking from the first token through those successor lists. Each
    adjacent pair of the original is consumed once, so dinucleotide counts
    and the first token are preserved.

    Arguments:
        `seq`: either a string of length L, or an L x D NumPy array of one-hot
            encodings (an all-zero row is treated as an extra token D)
        `num_shufs`: the number of shuffles to create, N
        `rng`: a NumPy RandomState object, to use for performing shuffles; if
            falsy (e.g. None) a fresh, unseeded `np.random.RandomState()` is
            created
    If `seq` is a string, returns a list of N strings of length L, each one
    being a shuffled version of `seq`. If `seq` is a 2D NumPy array, then the
    result is an N x L x D NumPy array of shuffled versions of `seq`, also
    one-hot encoded (same dtype as `seq`).
    Raises ValueError if `seq` is neither a `str` nor a 2D `np.ndarray`
    (exact `type(...) is` checks, so subclasses are rejected).
    NOTE(review): an empty `seq` with `num_shufs > 0` fails with IndexError at
    `result[0] = tokens[ind]`.
    """
    if type(seq) is str:
        arr = string_to_char_array(seq)
    elif type(seq) is np.ndarray and len(seq.shape) == 2:
        seq_len, one_hot_dim = seq.shape
        arr = one_hot_to_tokens(seq)
    else:
        raise ValueError("Expected string or one-hot encoded array")

    if not rng:
        rng = np.random.RandomState()

    # Get the set of all characters, and a mapping of which positions have which
    # characters; use `tokens`, which are integer representations of the
    # original characters (0..K-1; `chars` maps them back to the original codes)
    chars, tokens = np.unique(arr, return_inverse=True)

    # For each token, get a list of indices of all the tokens that come after it
    shuf_next_inds = []
    for t in range(len(chars)):
        mask = tokens[:-1] == t  # Excluding last char
        inds = np.where(mask)[0]
        shuf_next_inds.append(inds + 1)  # Add 1 for next token

    if type(seq) is str:
        all_results = []
    else:
        all_results = np.empty((num_shufs, seq_len, one_hot_dim), dtype=seq.dtype)

    for i in range(num_shufs):
        # Shuffle the next indices. `shuf_next_inds[t]` is overwritten, so each
        # iteration permutes the previous iteration's order, not the original.
        for t in range(len(chars)):
            inds = np.arange(len(shuf_next_inds[t]))
            # Only the last successor of each token stays in place; fixing it
            # is what keeps the walk below from running out of successors early.
            inds[:-1] = rng.permutation(len(inds) - 1)  # Keep last index same
            shuf_next_inds[t] = shuf_next_inds[t][inds]

        # counters[t] = how many successors of token t the walk has used so far
        counters = [0] * len(chars)

        # Build the resulting array: start at position 0 and repeatedly jump to
        # the next unused successor position of the current token.
        ind = 0
        result = np.empty_like(tokens)
        result[0] = tokens[ind]
        for j in range(1, len(tokens)):
            t = tokens[ind]
            ind = shuf_next_inds[t][counters[t]]
            counters[t] += 1
            result[j] = tokens[ind]

        # Map compact tokens back to the original codes before decoding.
        if type(seq) is str:
            all_results.append(char_array_to_string(chars[result]))
        else:
            all_results[i] = tokens_to_one_hot(chars[result], one_hot_dim)
    return all_results


def seqs_to_one_hot(seqs, alphabet="ACGT", to_upper=True, out_dtype=np.float64):
    """
    Converts a list of strings to one-hot encodings, where the position of 1s is
    ordered by the given alphabet.
    Arguments:
        `seqs`: a list of N strings, where every string is the same length L
        `alphabet`: string of length D containing the alphabet used to do
            the encoding; character `alphabet[k]` gets its 1 in position k
            (defaults to "ACGT")
        `to_upper`: if True, convert all bases to upper-case prior to performing
            the encoding; if False, characters are matched case-sensitively
        `out_dtype`: NumPy datatype of the output one-hot sequences; defaults
            to `np.float64` but can be changed (e.g. `np.int8` drastically
            reduces memory usage)
    Returns an N x L x D NumPy array of one-hot encodings, in the same order as
    the input sequences. Any character that is not in the alphabet is given an
    encoding of all 0s. An empty `seqs` gives an array of shape (0, 0, D).
    Raises ValueError if the strings do not all have the same length.
    """
    num_chars = len(alphabet)
    if len(seqs) == 0:
        return np.zeros((0, 0, num_chars), dtype=out_dtype)

    seq_len = len(seqs[0])
    if any(len(s) != seq_len for s in seqs):
        raise ValueError("All sequences must have the same length")

    # Join all sequences into one long string (upper-cased only if requested)
    seq_concat = "".join(seqs)
    if to_upper:
        seq_concat = seq_concat.upper()

    # One code point per character (UTF-32 is fixed-width), so positions line
    # up with characters even for non-ASCII input.
    codes = np.frombuffer(seq_concat.encode("utf-32-le"), dtype=np.uint32)

    # Index k for alphabet[k]; everything else gets D, the "unknown" index.
    # Assigning in alphabet order (not sorted code order) is what makes the
    # encoding follow the alphabet for non-alphabetical alphabets.
    token_ids = np.full(codes.shape, num_chars, dtype=np.intp)
    for k, letter in enumerate(alphabet):
        token_ids[codes == ord(letter)] = k

    # Row D of this table is all 0s, which encodes unknown characters
    one_hot_map = np.identity(num_chars + 1, dtype=out_dtype)[:, :num_chars]
    return one_hot_map[token_ids].reshape((len(seqs), seq_len, num_chars))


def one_hot_to_seqs(one_hot, alphabet="ACGT", unk_token="N"):
    """
    Decodes an N x L x D one-hot array into a list of N strings of length L.
    Arguments:
        `one_hot`: an N x L x D array of one-hot encodings
        `alphabet`: string of length D; a 1 in position k decodes to
            `alphabet[k]` (defaults to "ACGT")
        `unk_token`: single character emitted for positions that are all 0s
    The strings are returned in input order and contain only characters from
    `alphabet` or `unk_token`. Mismatched `alphabet`/D or a multi-character
    `unk_token` trigger an AssertionError.
    """
    assert len(alphabet) == one_hot.shape[2]
    assert len(unk_token) == 1
    lookup = np.array([*alphabet, unk_token])

    # Start every position at the unknown index D, then overwrite with the
    # index of the 1 wherever one is present.
    char_idx = np.full((one_hot.shape[0], one_hot.shape[1]), one_hot.shape[2])
    seq_pos, base_pos, hot_pos = np.nonzero(one_hot)
    char_idx[seq_pos, base_pos] = hot_pos

    return ["".join(row) for row in lookup[char_idx]]


def import_peaks_bed(peaks_bed):
    """
    Reads a peaks BED file in ENCODE NarrowPeak format into a Pandas DataFrame.
    Arguments:
        `peaks_bed`: path to a tab-separated BED file without a header row,
            gzipped or not (pandas' default `compression="infer"` picks the
            decompressor from the file extension)
    Returns a DataFrame with one row per peak and the ten NarrowPeak columns
    named below.
    """
    narrowpeak_columns = [
        "chrom",
        "peak_start",
        "peak_end",
        "name",
        "score",
        "strand",
        "signal",
        "pval",
        "qval",
        "summit_offset",
    ]
    # BED files have no header line, so supply the column names explicitly.
    return pd.read_csv(peaks_bed, sep="\t", header=None, names=narrowpeak_columns)


"""
def filter_indices(data_frame, t1, t2, pos_only=False, neg_only=False):
    # Filters data based off confidence of predictions
    pos_indices = np.where((data_frame.y == 1) & (data_frame.probs >= t2))[0]
    neg_indices = np.where((data_frame.y == 0) & (data_frame.probs < t1))[0]
    
    if pos_only==True:
        return pos_indices
    elif neg_only==True:
        return neg_indices
    else: 
        return np.concatenate([pos_indices, neg_indices])
"""
