from argparse import ArgumentParser
import os

def get_args():
    psr = ArgumentParser()
    #  file arguments
    psr.add_argument("--base-path", type=str, default=os.path.expanduser("~"))
    psr.add_argument("--img-folder", type=str, default="shot_scale/images/")


    # dataset arguments
    psr.add_argument("--dataset", required=True, type=str, help="One of ['transition', 'scale', 'conversation', 'hmdb51', 'ucf101'")
    psr.add_argument("--mode", type=str, default='sequence', help="'image' or 'sequence'; only used if --dataset is 'scale'")
    psr.add_argument("--max-frames", type=int, default=16) # standard for action recognition
    psr.add_argument("--num-workers", type=int, default=32)
    psr.add_argument("--n-classes", type=int)
    psr.add_argument("--load-width", type=int, default=112) 
    psr.add_argument("--load-height", type=int, default=112)  
    psr.add_argument("--load-first", type=int, default=None)
    psr.add_argument("--normalize", type=int, default=255)
    psr.add_argument("--load-classes", nargs='+', default=None, type=str, help="A space-separated list of the classes in the dataset to load. Only supported on HMDB51 and UCF101.")
    psr.add_argument("--keep-class-order", action='store_true', default=False)

    # Augmentations
    psr.add_argument("--transform", nargs='+', default=None, type=str, help="A space-separated list of the transformation class names to compose.")
    psr.add_argument("--val-transform", nargs='+', default=None, type=str)

    # when applicable
    psr.add_argument("--apply-prob", default=0.5, type=float)

    # File corruption
    psr.add_argument("--corrupt-prob", type=float, default=1.)
    psr.add_argument("--corrupt-mode", type=str, choices=['random','contiguous'], default='random')
    psr.add_argument("--bit-corrupt-levels", default=None, choices=['low', 'high', 'all', None])
    psr.add_argument("--contiguous-probs", default=[], nargs='+', type=float)
    psr.add_argument("--random-probs", default=[], nargs='+', type=float)
    psr.add_argument("--network-probs", default=[], nargs='+', type=float)
    psr.add_argument("--targeted", action='store_true')

    # Network corruption
    psr.add_argument("--network-error-mode", type=str, choices=['loss', 'onoff'], default='loss')
    psr.add_argument("--packet-loss-rate", type=float, default=0.2)
    psr.add_argument("--onofftime", nargs=2, default=(0, 0), type=float)
    psr.add_argument("--link-mode", type=str, choices=['uplink', 'downlink'], default='uplink')
    psr.add_argument("--port", type=int, default=12345)
    psr.add_argument("--ffmpeg-log-level", type=str, choices=['quiet', 'panic', 'fatal', 'error', 'warning', 'info', 'verbose', 'debug'], default='quiet')
    psr.add_argument("--stream-loops", type=int, default=0)
    psr.add_argument("--rebuild-filename-cache", action='store_true')
    psr.add_argument("--corruption-version", type=int, default=0, choices=list(range(5)))
    psr.add_argument("--enforce-readability", action='store_true')
    psr.add_argument("--stream-framerate", default=None)

    # checkpointing
    psr.add_argument("--model-save-path", type=str, default="model.pth")
    psr.add_argument("--load-pretrained", type=str, default=None)
    psr.add_argument("--model-module", type=str, default=None)
    psr.add_argument("--lightweight", action='store_true', default=False)

    psr.add_argument("--model-classname", type=str, default='resnet18')

    # training shenanigans
    psr.add_argument("--bs", type=int, default=16)
    psr.add_argument("--epochs", type=int, default=5)
    psr.add_argument("--val-prop", type=float, default=0.3)
    psr.add_argument("--seed", type=int, default=42)

    # optimization settings
    psr.add_argument("--optimizer", type=str, default='sgd')
    psr.add_argument("--lr", type=float, default=0.1)
    psr.add_argument("--dampening", type=float, default=0)
    psr.add_argument("--momentum", type=float, default=0.9)
    psr.add_argument("--weight-decay", type=float, default=1e-3)

    # training callbacks
    psr.add_argument("--early-stopping", type=int, default=9999) # don't
    psr.add_argument("--lr-patience", type=int, default=10)

    # adversarial training (PGD)
    psr.add_argument("--num-steps", type=int, default=40)
    psr.add_argument("--pgd-eps", type=float, default=8/255)
    psr.add_argument("--pgd-step-size", type=float, default=0.01)
    psr.add_argument("--pgd-norm", default='inf')

    psr.add_argument("--adversarial", type=int, default=None)
    psr.add_argument("--corruption-augmented", action='store_true')

    # ood
    psr.add_argument("--temperature", type=float, default=1.)
    psr.add_argument("--percentile", type=float, default=95)
    psr.add_argument("--force-recalculate-scores", action='store_true')
    psr.add_argument("--force-recalculate-threshold", action='store_true')

    # evaluation hooks
    psr.add_argument("--hooks", type=str, nargs='+', default=None)
    psr.add_argument("--limit", type=int, default=None)
    psr.add_argument("--experiment-mode", type=str, choices=['file', 'network'])

    args = psr.parse_args()
    return args


def get_file_corrupter_args():
    psr = ArgumentParser()
    psr.add_argument("--source", type=str, required=True)
    psr.add_argument("--dest", type=str)
    psr.add_argument("--corruption", type=str, choices=['fill', 'flip', 'whack'], required=True)
    psr.add_argument("--start", type=int)
    psr.add_argument("--length", type=int)
    psr.add_argument("--fill-val", type=int, default=0)

    #  whack options
    psr.add_argument("-p", type=float, default=1.)
    psr.add_argument("--mode", type=str, choices=['random','contiguous'], default='random')
    psr.add_argument("--eps", type=float, default=1e-2)
    psr.add_argument("--attempts", type=int, default=1)
    psr.add_argument("--stop-on-success", action='store_true')
    return psr.parse_args()


def get_network_corrupter_args():
    psr = ArgumentParser()
    psr.add_argument("--source", type=str, required=True)
    psr.add_argument("--dest", type=str, required=True)
    psr.add_argument("--error-mode", type=str, choices=['loss', 'onoff'], default='loss')
    psr.add_argument("--link", type=str, choices=['uplink', 'downlink'])
    psr.add_argument("--rate", type=float, default=0.2)
    psr.add_argument("--onofftime", nargs=2, default=(0, 0), type=float)
    psr.add_argument("--port", type=int, default=12345)
    psr.add_argument("-v", "--verbose", action='store_true')
    return psr.parse_args()


def get_mot15_args():
    psr = ArgumentParser()
    psr.add_argument("--base-path", type=str, default=os.path.expanduser("~"))
    psr.add_argument("--corrupt-mode", type=str, choices=['random', 'contiguous'])
    psr.add_argument("--experiment-mode", type=str, choices=['file','network'])
    psr.add_argument("-p", type=float, required=True)
    psr.add_argument("--verbose", "-v", action='store_true')
    psr.add_argument("--version", type=int)
    return psr.parse_args()


def prettyprint_args(args, banner_width=20):
    return "\n" + "="*banner_width +  "COMMAND LINE ARGS" + "=" * banner_width \
        + "\n" + "\n".join(["{}={}".format(k, v) for k, v in sorted(vars(args).items())]) \
        + "\n" + "="*(2 * banner_width + len("COMMAND LINE ARGS") + 2) + "\n"
