import argparse
import logging
import os
import sys

import numpy as np
import torch

import aggregation.experiments

from aggregation.tasks.cifar10 import Cifar10Task
from aggregation.tasks.cifar100 import Cifar100Task
from aggregation.tasks.emnist import EmnistTask
from aggregation.tasks.emnist62 import Emnist62Task
from aggregation.tasks.mnist import MnistTask


def get_args(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behaviour of reading the real command line; passing a list lets
            the driver be configured programmatically or from tests.

    Returns:
        argparse.Namespace holding every option declared below. When
        ``--debug`` is given, ``epochs`` is overridden to 3.

    Raises:
        ValueError: if ``--agg`` is missing, or if the worker counts are
            inconsistent (``n`` must be positive and ``0 <= f < n``).
        SystemExit: raised by argparse itself on unknown options or on
            values that fail type conversion.
    """
    parser = argparse.ArgumentParser()

    # Utility
    parser.add_argument("--debug", action="store_true",
                        help="only train for 3 rounds")

    parser.add_argument("--use_cuda", action="store_true")

    parser.add_argument("--gpu_id", type=int, help="gpu # to use.")

    parser.add_argument("--overwrite", action="store_true",
                        help=("!!! if results for the experiment already "
                              "exist, overwrite them (rather than exiting) "
                              "!!!"))

    # Experiment Config
    parser.add_argument("--agg", type=str, default=None,
                        help=("aggregation method to test: 'avg', 'cclip', "
                              "'krum', 'rsa' or 'fltrust'"))

    parser.add_argument("--attack", type=str, default="sf",
                        help="type of byzantine attacks (default:sf=signflip)")

    parser.add_argument("--batch_size", type=int, default=32,
                        help="training batch size")

    parser.add_argument("--task", type=str, default='mnist',
                        help=("'mnist', 'emnist', 'emnist62', 'cifar10' "
                              "or 'cifar100'"))

    parser.add_argument("--epochs", type=int, default=600,
                        help="# of epochs to train for")

    parser.add_argument("--f", type=int, default=0,
                        help="# of Byzantine workers.")

    parser.add_argument("--use_linear_model", action="store_true",
                        help=("whether to use a simple linear model (rather "
                              "than a neural net); currently supported by "
                              "mnist"))

    parser.add_argument("--max_batches_per_epoch", type=int, default=1,
                        help=("controls max # of batches per worker per round; "
                              "1=single batch"))

    parser.add_argument("--lr", type=float, default=0.01,
                        help="learning rate")

    parser.add_argument("--n", type=int, default=20,
                        help="# of workers")

    parser.add_argument("--noniid", action="store_true",
                        help="whether to use non-iid data")

    parser.add_argument("--p_norm", type=str, default='1',
                        help="p norm for aggregation methods")

    parser.add_argument("--examples_per_worker", type=int, default=None,
                        help=("how many samples should be assigned to each worker; "
                              "if None, the entire dataset will be split evenly"))

    parser.add_argument("--test_batch_size", type=int, default=128)

    parser.add_argument("--seed", type=int, default=0,
                        help="random seed value for torch and np")

    parser.add_argument("--fixed_point", action="store_true",
                        help="whether to use fixed-point numbers for server operations")

    # Agg:(Multi-)Krum Parameters
    parser.add_argument("--krum_m", type=int, default=1,
                        help="Multi-Krum m value")

    parser.add_argument("--n_reduced_dims", type=int, default=None,
                        help="# of dimensions to reduce to")

    parser.add_argument("--dim_reduce_method", type=str, default=None,
                        help="dimensionality reduction method to use")

    # Agg:RSA Parameters
    parser.add_argument("--rsa_lambda", type=float, default=None,
                        help="RSA bias lambda value")

    parser.add_argument("--rsa_weight_decay", type=float, default=(0.01 / 2),
                        help="RSA weight decay value")

    # Agg: FLTrust Parameters
    parser.add_argument("--fltrust_param1", type=float, default=None,
                        help="placeholder for now. ")

    # Agg:CClip Parameters
    parser.add_argument("--cclip_tau", type=float, default=None,
                        help="tau value")

    parser.add_argument("--cclip_momentum", type=float, default=None,
                        help="momentum value")

    parser.add_argument("--cclip_iters", type=int, default=1,
                        help="# of inner iterations")

    parser.add_argument("--cclip_scaling", type=str, default=None,
                        help="scaling type (optional); None, linear, or sqrt")

    # argv=None makes argparse fall back to sys.argv[1:]
    args = parser.parse_args(argv)

    if args.agg is None:
        raise ValueError(f"agg: {args.agg}. Please provide a valid aggregation method.")

    # validate worker numbers: at least one worker, and fewer Byzantine
    # workers than workers in total
    if args.n <= 0 or args.f < 0 or args.f >= args.n:
        raise ValueError(f"n={args.n} f={args.f}")

    # limit the number of training rounds
    if args.debug:
        args.epochs = 3

    return args


def log_args(args):
    """Log every CLI setting as an aligned, name-sorted table.

    All lines go to the "debug" logger at INFO level, framed by dashed
    rules; an extra reminder line is emitted when ``args.debug`` is set.

    Args:
        args: namespace returned by ``get_args`` (anything accepted by
            ``vars`` that also has a ``debug`` attribute).
    """
    log = logging.getLogger("debug")
    rule = '-' * 15
    log.info(rule + "Config" + rule)

    settings = vars(args).copy()
    width = max(len(key) for key in settings)
    # names are unique, so sorting the keys matches sorting the items
    for key in sorted(settings):
        log.info(" : ".join((key.ljust(width), str(settings[key]))))
    log.info('-' * 30)

    if args.debug:
        log.info("\n=> !!! DEBUG: epochs=3 !!!\n")


def init_loggers(results_path, overwrite, debug):
    """Set up the "stats" and "debug" file loggers for one experiment.

    Two files are used: ``<results_path>_stats`` and ``<results_path>_debug``.
    Both loggers are set to INFO and format records as the bare message.
    The "debug" logger also echoes to the console through a ``StreamHandler``.

    If the stats file already exists, the process exits with status 0 unless
    ``overwrite`` or ``debug`` is true. In that case the previous stats and
    debug files are deleted first; a debug file that is missing is tolerated.

    Args:
        results_path: path prefix for the two log files. A missing parent
            directory is created. A bare file name resolves against the
            current working directory.
        overwrite: replace existing results instead of exiting.
        debug: debug run; also allowed to replace existing results.
    """
    stats_path = results_path + '_stats'
    debug_path = results_path + '_debug'

    # check if existing results should be overwritten
    if os.path.isfile(stats_path):
        if not (overwrite or debug):
            print(f"Results for this experiment already exist: {stats_path}")
            sys.exit(0)
        for stale_path in (stats_path, debug_path):
            try:
                os.remove(stale_path)
            except FileNotFoundError:
                # e.g. the debug file was deleted by hand; nothing to remove
                pass

    # dirname() is '' for a bare file name and os.makedirs('') raises, so
    # only create a real directory; exist_ok avoids a check-then-create race
    parent_dir = os.path.dirname(stats_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Initialize loggers to display training and debug information.
    # Files hold bare messages, one record per line, so they are easy to process.
    formatter = logging.Formatter("%(message)s")
    for logger_name, filename in (('stats', stats_path), ('debug', debug_path)):
        logger = logging.getLogger(logger_name)
        logger.setLevel(logging.INFO)

        if logger_name == 'debug':
            # the debug log is mirrored to the console
            stream_handler = logging.StreamHandler()
            stream_handler.setLevel(logging.INFO)
            stream_handler.setFormatter(formatter)
            logger.addHandler(stream_handler)

        file_handler = logging.FileHandler(filename)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # name the target logger explicitly instead of relying on the loop variable
    logging.getLogger('debug').info(f"Logging files to {stats_path}")


def main(args=None):
    """Run a single aggregation experiment end to end.

    Steps:
      1. Parse or receive the CLI args.
      2. Pick the torch device and seed torch and numpy.
      3. Build the Experiment for ``args.agg`` and set up the results loggers.
      4. Attach the Task for ``args.task``.
      5. Alternate training and evaluation for ``args.epochs`` rounds.

    Args:
        args: namespace as produced by ``get_args``; when ``None`` the
            arguments are parsed from the command line.

    Raises:
        ValueError: if ``args.use_cuda`` is set but CUDA is unavailable.
        KeyError: if ``args.agg`` or ``args.task`` is not a supported name.
    """
    # parse CLI args
    if args is None:
        args = get_args()

    # init cuda/random seeds
    if args.use_cuda:
        if not torch.cuda.is_available():
            raise ValueError("=> There is no cuda device!!!")
        # extra options forwarded to experiment.get_evaluator()
        kwargs = {"pin_memory": True}

        # compare against None: an explicit `--gpu_id 0` is falsy and would
        # otherwise be silently replaced by the generic "cuda" device
        if args.gpu_id is not None:
            device = torch.device(f"cuda:{args.gpu_id}")
        else:
            device = torch.device("cuda")
    else:
        kwargs = {}
        device = torch.device("cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # create experiment; an error will be raised if any args are invalid
    root_dir = os.path.dirname(os.path.abspath(__file__))
    # NOTE(review): hard-coded, cluster-specific dataset location. It becomes
    # ".../None/data" when $USER is unset. A repo-local alternative is
    # os.path.join(root_dir, 'datasets').
    dataset_dir = f"/ssd003/home/{os.getenv('USER')}/data"

    experiments = {'avg': aggregation.experiments.AverageExperiment,
                   'cclip': aggregation.experiments.CClipExperiment,
                   'krum': aggregation.experiments.KrumExperiment,
                   'rsa': aggregation.experiments.RSAExperiment,
                   'fltrust': aggregation.experiments.FLTrustExperiment}
    experiment = experiments[args.agg](args=args,
                                       device=device,
                                       use_cuda=args.use_cuda,
                                       dataset_dir=dataset_dir)
    results_path = experiment.get_results_path(root_dir=root_dir)

    # Resolve the Task class *before* any results files are created. An
    # unknown --task then fails without leaving an empty "_stats" file
    # behind. init_loggers would treat such a file as existing results if
    # the same results path is used again.
    tasks = {'cifar10': Cifar10Task,
             'cifar100': Cifar100Task,
             'mnist': MnistTask,
             'emnist': EmnistTask,
             'emnist62': Emnist62Task}
    task_cls = tasks[args.task]

    # initialize loggers/results file
    init_loggers(results_path=results_path, overwrite=args.overwrite, debug=args.debug)
    log_args(args=args)

    # assign the Task to the Experiment
    experiment.set_task(task_cls)

    # initialize trainer (which contains both the server and workers' data)
    trainer = experiment.get_trainer()

    # initialize evaluator
    evaluator = experiment.get_evaluator(**kwargs)

    # training loop: one train + evaluate pass per round (1-based rounds)
    for epoch in range(1, args.epochs + 1):
        trainer.train(epoch)
        evaluator.evaluate(epoch)
        # Presumably re-seeds each worker's sampler shuffle for the next
        # round (cf. DistributedSampler.set_epoch). The lambda reads `epoch`
        # when it is called, which assumes parallel_call runs it immediately
        # -- confirm.
        trainer.parallel_call(lambda w: w.data_loader.sampler.set_epoch(epoch))


# Script entry point: with no explicit args, main() parses them from the CLI.
if __name__ == "__main__":
    main(args=None)
