# Standard Library Imports
import os
import argparse

# Third-Party Library Imports
import numpy as np
import torch

# Specific Imports from Third-Party Libraries
import basic_cnn
import expressive_cnn
from mex import mex
from scrambler import scrambler

# Uniform background distribution (0.25 for each of the four bases). Used by
# pfm_info_content as the reference distribution in its log-ratio term.
BACKGROUND_FREQS = np.array([0.25, 0.25, 0.25, 0.25])

"""
FUNCTIONS FOR ARGUMENT PARSING
"""
def mex_parse_args(argv=None):
    """Parse the command-line options of the mex configuration.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, which is the original
            behaviour.

    Returns:
        argparse.Namespace with the attributes ``exp_name``, ``alpha``,
        ``l1_mult``, ``sm_mult``, ``batch_size``, ``lr``, ``num_epochs``,
        ``num_bkgd_samples`` and ``freeze``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--exp_name",
        type=str,
        required=True,
        help="experimental configuration"
    )

    parser.add_argument(
        "--alpha",
        type=float,
        default=1,
        help="alpha=1 for sufficiency and alpha=0 for necessity",
    )

    parser.add_argument(
        "--l1_mult",
        type=float,
        default=3,
        help="Multiplier for l1 norm on explainer outputs",
    )

    parser.add_argument(
        "--sm_mult",
        type=float,
        default=1,
        help="Multiplier for smoothness on explainer outputs",
    )

    parser.add_argument(
        "--batch_size", type=int, default=16, help="Batch size for training",
    )

    parser.add_argument(
        "--lr", type=float, default=1e-3, help="Learning rate for adam optimizer",
    )

    parser.add_argument(
        "--num_epochs", type=int, default=50, help="Number of epochs to train",
    )

    parser.add_argument(
        "--num_bkgd_samples", type=int, default=10, help="Number of background samples to draw",
    )

    # --freeze and --no-freeze are two spellings of a single boolean, so both
    # must write to the same destination. Without an explicit dest, argparse
    # would store --no-freeze in a separate `no_freeze` attribute and the flag
    # would never influence `freeze`.
    parser.add_argument(
        "--freeze", dest="freeze", action='store_true',
        help="Freeze encoder of explainer",
    )

    parser.add_argument(
        "--no-freeze", dest="freeze", action='store_false',
        help="Do not freeze encoder of explainer (default)",
    )

    # Two actions share the dest with opposite implicit defaults; pin the
    # default explicitly so it does not depend on declaration order.
    parser.set_defaults(freeze=False)

    return parser.parse_args(argv)

def scrambler_parse_args(argv=None):
    """Parse the command-line options of the scrambler configuration.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, which is the original
            behaviour.

    Returns:
        argparse.Namespace with the attributes ``exp_name``, ``alpha``,
        ``kl_mult``, ``t_bits``, ``batch_size``, ``lr``, ``num_epochs`` and
        ``num_bkgd_samples``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--exp_name",
        type=str,
        required=True,
        help="experimental configuration"
    )

    parser.add_argument(
        "--alpha",
        type=float,
        default=1,
        help="alpha=1 for sufficiency and alpha=0 for necessity",
    )

    parser.add_argument(
        "--kl_mult",
        type=float,
        default=1,
        help="Multiplier for kl divergence",
    )

    parser.add_argument(
        "--t_bits",
        type=float,
        default=0.25,
        help="bits",
    )

    parser.add_argument(
        "--batch_size", type=int, default=16, help="Batch size for training",
    )

    parser.add_argument(
        "--lr", type=float, default=1e-3, help="Learning rate for adam optimizer",
    )

    parser.add_argument(
        "--num_epochs", type=int, default=50, help="Number of epochs to train",
    )

    parser.add_argument(
        "--num_bkgd_samples", type=int, default=10, help="Number of background samples to draw",
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)

"""
HELPER FUNCTIONS
"""
def load_classifier(device, clf_type="expressive", state_dict_path=None):
    """Build a classifier, optionally restore its weights, and put it in eval mode.

    Args:
        device: Device the model is moved to; also used as ``map_location``
            when a checkpoint is loaded.
        clf_type: ``"expressive"`` selects ``expressive_cnn.classifier`` and
            ``"basic"`` selects ``basic_cnn.classifier``.
        state_dict_path: Directory containing ``clf.pt``. When None the
            classifier keeps its freshly initialised weights.

    Returns:
        The classifier, on ``device`` and in eval mode.

    Raises:
        ValueError: If ``clf_type`` is not one of the supported names.
    """
    if clf_type == "expressive":
        clf = expressive_cnn.classifier().to(device)
    elif clf_type == "basic":
        clf = basic_cnn.classifier().to(device)
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError(
            f"Unknown clf_type {clf_type!r}; expected 'expressive' or 'basic'"
        )
    if state_dict_path is not None:
        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # be restored on whatever device this process is actually using.
        state_dict = torch.load(
            os.path.join(state_dict_path, "clf.pt"), map_location=device
        )
        clf.load_state_dict(state_dict)
        print("Trained classifier loaded")
    else:
        print("Untrained classifier loader")
    clf.eval()
    return clf

def load_explainer(device, clf, D, bkgd_seqs, state_dict_path=None, freeze=False):
    """Build a ``mex`` explainer, optionally restore its weights, and put it in eval mode.

    Args:
        device: Device the explainer (and its ``D`` attribute) is moved to;
            also used as ``map_location`` when a checkpoint is loaded.
        clf: Classifier handed to the ``mex`` constructor.
        D: Forwarded unchanged to the ``mex`` constructor.
        bkgd_seqs: Forwarded unchanged to the ``mex`` constructor.
        state_dict_path: Directory containing ``explainer.pt``. When None the
            explainer keeps its freshly initialised weights.
        freeze: Forwarded to the ``mex`` constructor.

    Returns:
        The explainer, on ``device`` and in eval mode.
    """
    explainer = mex(clf, D, bkgd_seqs, init_mask='ones', freeze=freeze).to(device)
    # D is moved explicitly in addition to the module-level .to(device) above.
    explainer.D = explainer.D.to(device)
    if state_dict_path is not None:
        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # be restored on whatever device this process is actually using.
        state_dict = torch.load(
            os.path.join(state_dict_path, "explainer.pt"), map_location=device
        )
        explainer.load_state_dict(state_dict)
        print("Trained explainer loaded")
    else:
        print("Untrained explainer loader")
    explainer.eval()
    return explainer

def load_scrambler(device, clf, state_dict_path=None):
    """Build a ``scrambler``, optionally restore its weights, and put it in eval mode.

    Args:
        device: Device the scrambler is moved to; also used as
            ``map_location`` when a checkpoint is loaded.
        clf: Classifier handed to the ``scrambler`` constructor.
        state_dict_path: Directory containing ``scrambler.pt``. When None the
            scrambler keeps its freshly initialised weights.

    Returns:
        The scrambler, on ``device`` and in eval mode.
    """
    scram = scrambler(clf).to(device)
    if state_dict_path is not None:
        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # be restored on whatever device this process is actually using.
        state_dict = torch.load(
            os.path.join(state_dict_path, "scrambler.pt"), map_location=device
        )
        scram.load_state_dict(state_dict)
        print("Trained scrambler loaded")
    else:
        print("Untrained scrambler loader")
    scram.eval()
    return scram

def pfm_info_content(pfm, pseudocount=0.001, background_freqs=None):
    """
    Given an L x B PFM (B = 4 with the default background), computes the
    information content at each position and returns it as an L-array.

    Each row is normalised to probabilities with `pseudocount` added to every
    entry; the result for a row is sum_b p_b * log2(p_b / q_b), where q is the
    background distribution.

    Args:
        pfm: 2-D NumPy array of per-position base frequencies or counts.
        pseudocount: Value added to every entry before normalising, which
            keeps the logarithm finite for zero entries.
        background_freqs: Length-B background distribution. Defaults to the
            module-level BACKGROUND_FREQS.

    Returns:
        1-D array of length L with one information-content value per row.
    """
    if background_freqs is None:
        background_freqs = BACKGROUND_FREQS
    background = np.expand_dims(np.asarray(background_freqs), axis=0)

    num_bases = pfm.shape[1]
    # Normalize track to probabilities along base axis
    pfm_norm = (pfm + pseudocount) / \
        (np.sum(pfm, axis=1, keepdims=True) + (num_bases * pseudocount))
    ic = pfm_norm * np.log2(pfm_norm / background)
    # Summing over the base axis leaves one value per position.
    return np.sum(ic, axis=1)