import argparse
import numpy as np
import os
import torch
import random
from torchvision import datasets
import torch.nn.functional as F
import torch.nn as nn

# torch.set_printoptions(precision=20)

# Utility functions
from codes.tasks.toy_2d import toy_data, ToyNet, CustomLoss
from codes.tasks.nonconvex_2d import SimpleModel, Ackley, CombinedGaussian, TwoMinima, ToyDataset

# Attacks
# from codes.attacks_toy.labelflipping import LableFlippingWorker
from codes.attacks_toy.bitflipping import BitFlippingWorker
from codes.attacks_toy.mimic import MimicVariantAttacker
from codes.attacks_toy.IPM import IPMAttack
from codes.attacks_toy.alittle import ALittleIsEnoughAttack
from codes.attacks_toy.MinMax import MinMaxAttack
from codes.attacks_toy.MinSum import MinSumAttack
from codes.attacks_toy.nonlinear import FullNLPAttack, NLPAttack, NOBLEAttack, TAWFOEAttack

# Main Modules
from codes.components.utils_toy import top1_accuracy, initialize_logger
from codes.components.worker_toy import MomentumWorker
from codes.components.server import TorchServer
from codes.components.simulator_toy import ParallelTrainer, DistributedEvaluator

# IID vs Non-IID
from codes.components.sampler_toy import (
    DistributedSampler,
    NONIIDLTSampler,
)

# Aggregators
from codes.aggregators.base import Mean
from codes.aggregators.coordinatewise_median import CM
from codes.aggregators.clipping import Clipping
from codes.aggregators.rfa import RFA
from codes.aggregators.trimmed_mean import TM
from codes.aggregators.krum import Krum


def get_args():
    """Parse and validate the command-line options of the toy experiment.

    Returns:
        argparse.Namespace: the parsed options.

    Raises:
        RuntimeError: if ``-n`` or ``-f`` is missing, or unless
            ``n > 0`` and ``0 <= f < n``.
        AssertionError: if ``--bucketing`` or ``--momentum`` is negative, or
            ``--identifier`` is empty.
    """
    parser = argparse.ArgumentParser(description="")

    # Utility
    parser.add_argument("--use-cuda", action="store_true", default=False)
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=10)
    parser.add_argument("--dry-run", action="store_true", default=False)
    parser.add_argument("--identifier", type=str, default="debug", help="")
    parser.add_argument(
        "--plot",
        action="store_true",
        default=False,
        help="If plot is enabled, then ignore all other options.",
    )

    # Experiment configuration
    parser.add_argument("-n", type=int, help="Number of workers")
    parser.add_argument("-f", type=int, help="Number of Byzantine workers.")
    parser.add_argument("--attack", type=str, default="NA", help="Type of attacks.")
    parser.add_argument("--agg", type=str, default="avg", help="")
    parser.add_argument(
        "--noniid",
        action="store_true",
        default=False,
        help="[HP] noniidness.",
    )
    parser.add_argument("--LT", action="store_true", default=False, help="Long tail")
    parser.add_argument("--loss", type=str, default="mse", help="")

    # Key hyperparameter
    parser.add_argument('--initp', type=float, nargs='+', help='initial point')
    parser.add_argument("--bucketing", type=int, default=0, help="[HP] s")
    parser.add_argument("--momentum", type=float, default=0.0, help="[HP] momentum")
    parser.add_argument("--LR", type=float, default=0.1, help="[HP] Learning Rate")
    parser.add_argument("--EPOCH", type=int, default=0, help="[HP] training epochs")

    parser.add_argument(
        "--clip-tau", type=float, default=10.0, help="[HP] clipping radius tau"
    )
    parser.add_argument(
        "--clip-scaling",
        type=str,
        default=None,
        help="[HP] scaling of tau with momentum: 'linear' or 'sqrt'",
    )

    parser.add_argument(
        "--mimic-warmup", type=int, default=1, help="the warmup phase in iterations."
    )

    parser.add_argument("--nlpobj", type=float, default=1.0, help="Coefficient of NLP objection")
    parser.add_argument("--nlpsize", type=int, default=0, help="Segment length of NLP solvers")
    parser.add_argument("--dev_type", type=str, default="unit_vec", help="Perturbation vector type of MinMax/MinSum")

    args = parser.parse_args()

    # -n / -f have no default; check for None first so a missing option is
    # reported as the same RuntimeError instead of a TypeError from `None <= 0`.
    if (
        args.n is None
        or args.f is None
        or args.n <= 0
        or args.f < 0
        or args.f >= args.n
    ):
        raise RuntimeError(f"n={args.n} f={args.f}")

    assert args.bucketing >= 0, args.bucketing
    assert args.momentum >= 0, args.momentum
    assert len(args.identifier) > 0
    return args


# Filesystem layout, resolved relative to the directory containing this file.
# ROOT_DIR always ends with "/" so the paths below are plain concatenations.
ROOT_DIR = f"{os.path.dirname(os.path.abspath(__file__))}/"
DATA_DIR = ROOT_DIR + "datasets/toydata.txt"
EXP_DIR = ROOT_DIR + "outputs/"

# Fixed hyperparameters: loaders use single-sample batches.
# (Earlier hard-coded settings, now disabled: LR = 20, EPOCHS = 70.)
BATCH_SIZE = 1
TEST_BATCH_SIZE = 1


def _get_aggregator(args):
    if args.agg == "avg":
        return Mean()

    if args.agg == "cm":
        return CM()

    if args.agg == "cp":
        if args.clip_scaling is None:
            tau = args.clip_tau
        elif args.clip_scaling == "linear":
            tau = args.clip_tau / (1 - args.momentum)
        elif args.clip_scaling == "sqrt":
            tau = args.clip_tau / np.sqrt(1 - args.momentum)
        else:
            raise NotImplementedError(args.clip_scaling)
        return Clipping(tau=tau, n_iter=3)

    if args.agg == "rfa":
        return RFA(T=8)

    if args.agg == "tm":
        return TM(b=args.f)

    if args.agg == "krum":
        T = int(np.ceil(args.n / args.bucketing)) if args.bucketing > 0 else args.n
        return Krum(n=T, f=args.f, m=1)

    raise NotImplementedError(args.agg)


def bucketing_wrapper(args, aggregator, s):
    """
    Key functionality: wrap ``aggregator`` with random bucketing.

    The returned callable shuffles its inputs (using NumPy's global RNG),
    splits them into consecutive buckets of at most ``s`` elements, averages
    each bucket, and hands the list of bucket means to ``aggregator``.

    Args:
        args: experiment options. No longer read here; kept so existing
            callers do not need to change.
        aggregator: callable taking a list of bucket means.
        s: bucket size.

    Returns:
        A callable ``aggr(inputs)`` returning whatever ``aggregator`` returns.
    """
    print("Using bucketing wrapper.")

    def aggr(inputs):
        order = list(range(len(inputs)))
        np.random.shuffle(order)

        # The bucket count must follow the number of inputs actually received.
        # Deriving it from args.n produced empty buckets (division by zero)
        # when fewer inputs arrived and dropped inputs when more arrived.
        num_buckets = int(np.ceil(len(inputs) / s))

        bucket_means = []
        for t in range(num_buckets):
            members = order[t * s : (t + 1) * s]
            bucket_means.append(sum(inputs[i] for i in members) / len(members))
        return aggregator(bucket_means)

    return aggr


def get_aggregator(args):
    """Return the aggregator for ``args.agg``, bucketed if requested.

    When ``args.bucketing`` is 0 the plain aggregator is returned; otherwise
    it is wrapped by ``bucketing_wrapper`` with bucket size ``args.bucketing``.
    """
    base = _get_aggregator(args)
    bucket_size = args.bucketing
    return base if bucket_size == 0 else bucketing_wrapper(args, base, bucket_size)


def get_sampler_callback(args, rank):
    """
    Return a factory that builds the training sampler for worker ``rank``.

    The factory takes a dataset and returns a shuffling ``NONIIDLTSampler``
    over ``args.n`` replicas. Every rank gets the same configuration; a
    former special case for Byzantine ranks (a ``DistributedSampler`` shared
    among the good workers) is disabled.
    """

    def build_sampler(dataset):
        # Options are read when the sampler is built, not when the
        # callback is created.
        return NONIIDLTSampler(
            alpha=not args.noniid,
            beta=0.5 if args.LT else 1.0,
            num_replicas=args.n,
            rank=rank,
            shuffle=True,
            dataset=dataset,
        )

    return build_sampler


def get_test_sampler_callback(args):
    """Return a factory that builds the (non-shuffling) evaluation sampler.

    The sampler covers the dataset with a single replica, so the value of
    ``alpha`` is not important here.
    """

    def build_sampler(dataset):
        return NONIIDLTSampler(
            alpha=True,
            beta=0.5 if args.LT else 1.0,
            num_replicas=1,
            rank=0,
            shuffle=False,
            dataset=dataset,
        )

    return build_sampler


def initialize_worker(
    args,
    trainer,
    worker_rank,
    model,
    optimizer,
    loss_func,
    device,
    max_batches_per_epoch,
    kwargs,
):
    """Create the worker object for ``worker_rank``.

    Ranks below ``args.n - args.f`` become ``MomentumWorker`` instances. The
    remaining ranks become the attacker named by ``args.attack``, which is
    configured with ``trainer`` before being returned. For attacker ranks the
    per-run output directory under ``EXP_DIR`` is created if missing, and the
    path of ``byz_grads.txt`` inside it is passed on as ``save_dir``.

    ``kwargs`` is forwarded both to the data loader and to the worker
    constructor.

    Raises:
        NotImplementedError: if ``args.attack`` is not a known attack name.
    """
    train_loader = toy_data(
        data_dir=DATA_DIR,
        train=True,
        download=True,
        batch_size=BATCH_SIZE,
        sampler_callback=get_sampler_callback(args, worker_rank),
        dataset_cls=ToyDataset,
        drop_last=True,  # Exclude the influence of non-full batch.
        **kwargs,
    )

    # Constructor arguments shared by every worker type.
    common = dict(
        data_loader=train_loader,
        model=model,
        loss_func=loss_func,
        device=device,
        optimizer=optimizer,
    )

    n_good = args.n - args.f
    if worker_rank < n_good:
        return MomentumWorker(momentum=args.momentum, **common, **kwargs)

    # The run name only mentions dev_type when it differs from the default.
    dev_tag = "" if args.dev_type == 'unit_vec' else f"_{args.dev_type}"
    run_name = (
        f"toy2d_{args.loss}_{args.agg}_{args.attack}{dev_tag}"
        f"_niid{args.noniid}_n{args.n}_f{args.f}"
        f"_nlpobj{args.nlpobj}_nlpsize{args.nlpsize}"
        f"_initp{args.initp[0]}_{args.initp[1]}"
        f"_lr{args.LR}_iter{args.EPOCH}_seed{args.seed}"
    )
    out_dir = EXP_DIR + 'images_2dnonconvex/' + run_name
    if not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)
    grad_log = out_dir + '/byz_grads.txt'

    def nlp_options():
        # Options shared by the NLP-based attackers.
        return dict(
            save_dir=grad_log,
            momentum=args.momentum,
            n=args.n,
            f=args.f,
            T=args.EPOCH,
            max_batches_per_epoch=max_batches_per_epoch,
            search_size=args.nlpsize,
            agg=get_aggregator(args),
            args1=args,
        )

    # Lazy factories: only the selected attacker (and its aggregator, where
    # one is needed) is ever constructed.
    builders = {
        "BF": lambda: BitFlippingWorker(save_dir=grad_log, **common, **kwargs),
        "mimic": lambda: MimicVariantAttacker(
            save_dir=grad_log,
            warmup=args.mimic_warmup,
            **common,
            **kwargs,
        ),
        "IPM": lambda: IPMAttack(
            epsilon=0.1,
            save_dir=grad_log,
            **common,
            **kwargs,
        ),
        "ALIE": lambda: ALittleIsEnoughAttack(
            n=args.n,
            m=args.f,
            save_dir=grad_log,
            **common,
            **kwargs,
        ),
        "MinMax": lambda: MinMaxAttack(
            dev_type=args.dev_type,
            save_dir=grad_log,
            **common,
            **kwargs,
        ),
        "MinSum": lambda: MinSumAttack(
            dev_type=args.dev_type,
            save_dir=grad_log,
            **common,
            **kwargs,
        ),
        "FullNLP": lambda: FullNLPAttack(
            worker_rank=worker_rank,
            **nlp_options(),
            **common,
            **kwargs,
        ),
        "NLP": lambda: NLPAttack(**nlp_options(), **common, **kwargs),
        "NOBLE": lambda: NOBLEAttack(**nlp_options(), **common, **kwargs),
        "TAWFOE": lambda: TAWFOEAttack(
            epsilon=0.1,
            n=args.n,
            f=args.f,
            save_dir=grad_log,
            agg=get_aggregator(args),
            **common,
            **kwargs,
        ),
    }

    build = builders.get(args.attack)
    if build is None:
        raise NotImplementedError(f"No such attack {args.attack}")

    attacker = build()
    attacker.configure(trainer)
    return attacker


def main(args, LOG_DIR, EPOCHS, MAX_BATCHES_PER_EPOCH, initp):
    """Run one toy 2-D experiment and record the optimisation trajectory.

    Args:
        args: parsed options; this function reads ``use_cuda``, ``seed``,
            ``LR``, ``n``, ``loss``, ``log_interval`` and ``dry_run`` and
            forwards ``args`` to the aggregator and worker factories.
        LOG_DIR: passed to ``initialize_logger``.
        EPOCHS: number of training epochs; epochs are numbered from 1.
        MAX_BATCHES_PER_EPOCH: forwarded to the trainer and to the workers.
        initp: indexable; its first two entries are the initial coordinates
            given to ``SimpleModel``.

    Returns:
        tuple: ``(points, grads)``. After every epoch, for each parameter
        group of the server optimizer, the value and the gradient of the
        group's first parameter are appended (as nested Python lists). Both
        lists are empty when ``args.dry_run`` is set.
    """
    initialize_logger(LOG_DIR)

    # torch.backends.mps is looked up defensively in case the installed
    # torch build does not provide it.
    cuda_ok = torch.cuda.is_available()
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()

    if args.use_cuda and cuda_ok:
        print(f"=> Use CUDA on {torch.cuda.get_device_name(0)}")
        device = torch.device("cuda")
    elif args.use_cuda and mps_ok:
        device = torch.device("mps")
    else:
        if args.use_cuda:
            print("=> There is no cuda/mps device!!!!")
        device = torch.device("cpu")
    if cuda_ok:
        # Seeded whenever CUDA exists, even if the run itself stays on CPU.
        torch.cuda.manual_seed_all(args.seed)
    print(f"=> Use Device {device}")

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    model = SimpleModel(init_coords=(initp[0], initp[1])).to(device)

    # Each optimizer contains a separate `state` to store info like `momentum_buffer`
    optimizers = [torch.optim.SGD(model.parameters(), lr=args.LR) for _ in range(args.n)]
    server_opt = torch.optim.SGD(model.parameters(), lr=args.LR)

    if args.loss == 'two':
        loss_func = TwoMinima()
    elif args.loss == 'gaussian':
        loss_func = CombinedGaussian()
    else:
        loss_func = CustomLoss()

    metrics = {"top1": top1_accuracy}

    server = TorchServer(optimizer=server_opt, model=model, clipping=0)
    trainer = ParallelTrainer(
        server=server,
        aggregator=get_aggregator(args),
        pre_batch_hooks=[],
        post_batch_hooks=[],
        max_batches_per_epoch=MAX_BATCHES_PER_EPOCH,
        log_interval=args.log_interval,
        metrics=metrics,
        use_cuda=args.use_cuda,
        debug=False,
    )

    # NOTE: no test loader / evaluator is built; only training is run.

    for worker_rank in range(args.n):
        worker = initialize_worker(
            args,
            trainer,
            worker_rank,
            model=model,
            optimizer=optimizers[worker_rank],
            loss_func=loss_func,
            device=device,
            max_batches_per_epoch=MAX_BATCHES_PER_EPOCH,
            # Deliberately empty: these kwargs reach both the data loader and
            # the worker constructors.
            kwargs={},
        )
        trainer.add_worker(worker)

    points = []
    grads = []
    if args.dry_run:
        # Return the same shape as a real run so callers can always unpack.
        return points, grads

    for epoch in range(1, EPOCHS + 1):
        trainer.train(epoch)
        for group in trainer.server.optimizer.param_groups:
            point = []
            grad = []
            for p in group["params"]:
                point.append(p.data.clone().detach().cpu().tolist())
                grad.append(p.grad.data.clone().detach().cpu().tolist())
            # Only the first parameter of each group is recorded.
            points.append(point[0])
            print(point)
            grads.append(grad[0])
        trainer.parallel_call(lambda w: w.data_loader.sampler.set_epoch(epoch))
    return points, grads


