import argparse
import numpy as np
import random
import os
import torch
from torchvision import datasets
import torch.nn.functional as F
import torch.nn as nn
# torch.set_default_dtype(torch.float64)

# Utility functions
from codes.tasks.cifar10 import cifar10, Net, get_resnet20

# Attacks
from codes.attacks.labelflipping import LabelFlippingWorker
from codes.attacks.bitflipping import BitFlippingWorker
from codes.attacks.mimic import MimicVariantAttacker
from codes.attacks.IPM import IPMAttack, OptimIPMAttack
from codes.attacks.alittle import ALittleIsEnoughAttack, OptimALittleIsEnoughAttack
from codes.attacks.MinMax import MinMaxAttack
from codes.attacks.MinSum import MinSumAttack
from codes.attacks.nonlinear import SSNLPAttack#, NLP1Attack, NLP1AttackLC, EstNLP1Attack
from codes.attacks.stabreaking import StabilityBreakingAttack, StabilityBreakingAttackLC

# Aggregators
from codes.aggregators.base import Mean
from codes.aggregators.coordinatewise_median import CM
from codes.aggregators.clipping import Clipping
from codes.aggregators.rfa import RFA
from codes.aggregators.trimmed_mean import TM
from codes.aggregators.krum import Krum

# FL framework
from codes.components.utils import top1_accuracy, initialize_logger
from codes.components.worker import MomentumWorker
from codes.components.server import TorchServer
from codes.components.simulator import ParallelTrainer, DistributedEvaluator

# IID vs Non-IID
from codes.components.sampler import (
    DistributedSampler,
    NONIIDLTSampler,
    DirichletSampler,
)


def get_args():
    """Parse and validate the command-line options of the experiment.

    Returns:
        argparse.Namespace: the parsed options.

    Raises:
        RuntimeError: if ``-n`` or ``-f`` is missing, or if they do not
            satisfy ``n > 0`` and ``0 <= f < n``.
        AssertionError: if ``--bucketing`` or ``--momentum`` is negative, or
            ``--identifier`` is empty.
    """
    parser = argparse.ArgumentParser(description="")

    # Utility
    parser.add_argument("--use-cuda", action="store_true", default=False)
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=10)
    parser.add_argument("--dry-run", action="store_true", default=False)
    parser.add_argument("--identifier", type=str, default="debug", help="")
    parser.add_argument(
        "--plot",
        action="store_true",
        default=False,
        help="If plot is enabled, then ignore all other options.",
    )

    # Experiment configuration
    parser.add_argument("-n", type=int, help="Number of workers")
    parser.add_argument("-f", type=int, help="Number of Byzantine workers.")
    parser.add_argument("--attack", type=str, default="NA", help="Type of attacks.")
    parser.add_argument("--agg", type=str, default="avg", help="Type of aggregation")
    parser.add_argument(
        "--noniid",
        action="store_true",
        default=False,
        help="[HP] noniidness.",
    )
    parser.add_argument("--LT", action="store_true", default=False, help="Long tail")
    parser.add_argument("--dirichlet", type=float, default=0.0, help="[HP] Alpha of Dirichlet sampler")

    # Key hyperparameter
    parser.add_argument("--bucketing", type=int, default=0, help="[HP] s")
    parser.add_argument("--mixing", action="store_true", default=False, help="[HP] mixing")
    parser.add_argument("--momentum", type=float, default=0.0, help="[HP] momentum")

    # Options of the "cp" (clipping) aggregator.
    parser.add_argument("--clip-tau", type=float, default=10.0, help="[HP] clipping radius tau")
    parser.add_argument(
        "--clip-scaling",
        type=str,
        default=None,
        help="[HP] scale tau by the momentum: 'linear' or 'sqrt' (default: no scaling)",
    )

    parser.add_argument(
        "--mimic-warmup", type=int, default=1, help="the warmup phase in iterations."
    )
    parser.add_argument("--nlpobj", type=float, default=1.0, help="Coefficient of NLP objection")
    parser.add_argument("--nlpsize", type=int, default=0, help="Size of search space of NLP solvers")

    # NOTE: argparse only applies `type` to string defaults, so when the flag
    # is omitted the value stays the float 1e31 rather than an int.
    parser.add_argument("--stab_len", type=int, default=1e31, help="Attack length of STAB attack")

    parser.add_argument("--grad_clip", type=float, default=0.0, help="[HP] apply gradient clipping when training.")

    args = parser.parse_args()

    # `-n` / `-f` have no default: test for None first so that a missing
    # option is reported like any other invalid n/f combination instead of
    # failing with a TypeError on the comparisons below.
    if (
        args.n is None
        or args.f is None
        or args.n <= 0
        or args.f < 0
        or args.f >= args.n
    ):
        raise RuntimeError(f"n={args.n} f={args.f}")

    assert args.bucketing >= 0, args.bucketing
    assert args.momentum >= 0, args.momentum
    assert len(args.identifier) > 0
    return args


# Filesystem layout, anchored at the directory that contains this file.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + "/"
DATA_DIR = f"{ROOT_DIR}datasets/"
EXP_DIR = f"{ROOT_DIR}outputs/"

# Optimisation hyperparameters.
LR = 0.25  # SGD learning rate
DECAY = 0.1  # multiplicative factor of the step learning-rate schedule
DECAY_ITERS = 1000  # number of iterations between two learning-rate decays
REGU = 0.01  # SGD weight decay

# Fixed HPs
BATCH_SIZE = 64
TEST_BATCH_SIZE = 128
input_dim = 32 * 32 * 3
output_dim = 10


def _get_aggregator(args):
    if args.agg == "avg":
        return Mean()

    if args.agg == "cm":
        return CM()

    if args.agg == "cp":
        if args.clip_scaling is None:
            tau = args.clip_tau
        elif args.clip_scaling == "linear":
            tau = args.clip_tau / (1 - args.momentum)
        elif args.clip_scaling == "sqrt":
            tau = args.clip_tau / np.sqrt(1 - args.momentum)
        else:
            raise NotImplementedError(args.clip_scaling)
        return Clipping(tau=tau, n_iter=3)

    if args.agg == "rfa":
        return RFA(T=8)

    if args.agg == "tm":
        return TM(b=args.f)

    if args.agg == "krum":
        T = int(np.ceil(args.n / args.bucketing)) if args.bucketing > 0 else args.n
        return Krum(n=T, f=args.f, m=1)

    raise NotImplementedError(args.agg)


def bucketing_wrapper(args, aggregator, s):
    """
    Key functionality.

    Wrap ``aggregator`` so that its inputs are first randomly partitioned
    into buckets of (at most) ``s`` vectors; each bucket is replaced by the
    mean of its members and ``aggregator`` is applied to the bucket means.

    Args:
        args: parsed options; kept for interface compatibility, not read here.
        aggregator: callable taking a list of vectors.
        s: bucket size (positive integer).

    Returns:
        A callable ``aggr(inputs)`` with the same calling convention as
        ``aggregator``.
    """
    print("Using bucketing wrapper.")

    def aggr(inputs):
        # Derive the bucket count from the inputs actually received rather
        # than from args.n: with fewer inputs the trailing buckets would be
        # empty (division by zero), with more inputs some would be dropped.
        n_inputs = len(inputs)
        indices = list(range(n_inputs))
        np.random.shuffle(indices)

        n_buckets = int(np.ceil(n_inputs / s))

        bucket_means = []
        for t in range(n_buckets):
            members = indices[t * s : (t + 1) * s]
            bucket_means.append(sum(inputs[i] for i in members) / len(members))
        return aggregator(bucket_means)

    return aggr

def nearest_neighbor_mixing(n, f, inputs):
    """Replace each input by the mean of its ``n - f`` nearest inputs.

    Args:
        n: number of inputs.
        f: number of Byzantine inputs.
        inputs: list of input vectors (stacked along a new first dimension;
            distances are taken along dimension 1 of the stacked tensor).

    Returns:
        A list of ``n`` tensors; entry ``i`` is the average of the ``n - f``
        inputs closest (Euclidean distance) to ``inputs[i]``, itself included.
    """
    stacked = torch.stack(inputs)
    keep = n - f

    def mix(row):
        # Rank every input by its distance to the current one ...
        order = torch.sort(torch.norm(stacked - stacked[row], dim=1)).indices
        # ... and average the closest `keep` of them.
        return torch.mean(stacked[order[:keep]], dim=0)

    return [mix(row) for row in range(n)]


def mixing_wrapper(args, aggregator):
    """
    Key functionality.

    Wrap ``aggregator`` so that its inputs first go through nearest-neighbor
    mixing (each vector is replaced by the mean of its closest vectors).

    Args:
        args: parsed options; ``args.f`` is the number of Byzantine workers.
        aggregator: callable taking a list of vectors.

    Returns:
        A callable ``aggr(inputs)`` with the same calling convention as
        ``aggregator``.

    Raises:
        ValueError: (from ``aggr``) if it receives ``args.f`` inputs or fewer,
            as no neighbor would be left to average.
    """
    print("Using mixing wrapper.")

    def aggr(inputs):
        # Use the number of vectors actually received, not args.n: when this
        # wrapper sits below the bucketing wrapper it is handed one vector
        # per bucket, and indexing with args.n would run past the end.
        n_inputs = len(inputs)
        if n_inputs <= args.f:
            raise ValueError(
                f"Mixing needs more than f={args.f} inputs, got {n_inputs}"
            )
        mixed_inputs = nearest_neighbor_mixing(n_inputs, args.f, inputs)
        return aggregator(mixed_inputs)

    return aggr


def get_aggregator(args):
    """Build the full aggregation pipeline described by ``args``.

    Starts from the base rule and optionally wraps it, first with mixing
    (``args.mixing``) and then with bucketing (``args.bucketing`` non-zero),
    so bucketing is the outermost stage.
    """
    pipeline = _get_aggregator(args)

    if args.mixing:
        pipeline = mixing_wrapper(args, pipeline)

    bucket_size = args.bucketing
    if bucket_size != 0:
        pipeline = bucketing_wrapper(args, pipeline, bucket_size)

    return pipeline


def get_sampler_callback(args, rank):
    """
    Return a one-argument factory that builds the training sampler of the
    worker with the given rank from a dataset.

    The first `n-f` workers are good, and the rest are Byzantine.
    """
    honest_count = args.n - args.f

    if rank < honest_count:
        # Good workers: the data split is governed by the non-iid options.
        if args.dirichlet == 0.0:

            def make_sampler(x):
                return NONIIDLTSampler(
                    alpha=not args.noniid,
                    beta=0.5 if args.LT else 1.0,
                    num_replicas=honest_count,
                    rank=rank,
                    shuffle=True,
                    dataset=x,
                )

        else:

            def make_sampler(x):
                return DirichletSampler(
                    alpha=args.dirichlet,
                    num_replicas=honest_count,
                    rank=rank,
                    shuffle=True,
                    batchsize=BATCH_SIZE,
                    dataset=x,
                )

        return make_sampler

    # Byzantine workers: reuse the shard index of a good worker.
    def make_byzantine_sampler(x):
        return DistributedSampler(
            num_replicas=honest_count,
            rank=rank % (honest_count),
            shuffle=True,
            dataset=x,
        )

    return make_byzantine_sampler


def get_test_sampler_callback(args):
    """Return a one-argument factory building the (unshuffled) test sampler.

    The long-tail option ``args.LT`` is only honoured when the Dirichlet
    sampler is disabled (``args.dirichlet == 0.0``); otherwise ``beta`` is
    fixed to 1.0.
    """
    honour_long_tail = args.dirichlet == 0.0

    def make_sampler(x):
        beta = 0.5 if (honour_long_tail and args.LT) else 1.0
        # This alpha argument is not important as there is
        # only 1 replica
        return NONIIDLTSampler(
            alpha=True,
            beta=beta,
            num_replicas=1,
            rank=0,
            shuffle=False,
            dataset=x,
        )

    return make_sampler


def initialize_worker(
    args,
    trainer,
    worker_rank,
    model,
    optimizer,
    loss_func,
    device,
    epochs,
    max_batches_per_epoch,
    kwargs,
):
    """Create the worker with rank ``worker_rank``.

    Ranks below ``args.n - args.f`` become honest ``MomentumWorker``s; the
    remaining ranks become the Byzantine worker selected by ``args.attack``
    ("BF", "LF", "mimic", "IPM", "ALIE", "MinMax", "MinSum", "SSNLP" or
    "STAB").

    Args:
        args: parsed command-line options.
        trainer: trainer the attacker is configured with (every attack except
            "LF").
        worker_rank: index of the worker.
        model, optimizer, loss_func, device: forwarded to the worker.
        epochs: number of epochs, forwarded to the "SSNLP" attack as ``T``.
        max_batches_per_epoch: forwarded to the "SSNLP" attack.
        kwargs: extra keyword arguments forwarded to BOTH the data loader
            factory and the worker constructor.

    Returns:
        The constructed worker.

    Raises:
        NotImplementedError: if a Byzantine rank is requested with an
            unknown ``args.attack``.
    """
    train_loader = cifar10(
        data_dir=DATA_DIR,
        train=True,
        download=True,
        batch_size=BATCH_SIZE,
        sampler_callback=get_sampler_callback(args, worker_rank),
        dataset_cls=datasets.CIFAR10,
        drop_last=True,  # Exclude the influence of non-full batch.
        **kwargs,
    )

    # Keyword arguments every worker class receives, honest or Byzantine.
    shared = dict(
        momentum=args.momentum,
        data_loader=train_loader,
        model=model,
        loss_func=loss_func,
        device=device,
        optimizer=optimizer,
        clipping=args.grad_clip,
        worker_rank=worker_rank,
        use_cuda=args.use_cuda,
    )

    if worker_rank < args.n - args.f:
        return MomentumWorker(**shared, **kwargs)

    # Output folder of the run; created for every Byzantine worker, handed
    # to the "SSNLP" attack only.
    foldername = f"{args.agg}_{args.attack}_niid{args.noniid if args.noniid else args.dirichlet if args.dirichlet else False}" \
                 f"_n{args.n}_f{args.f}_m{args.momentum}_nlpsize{args.nlpsize}_nlpobj{args.nlpobj}_mix{args.mixing}_clip{args.grad_clip}" \
                 f"_s{args.bucketing}_seed{args.seed}"
    save_dir = EXP_DIR + 'images_cifar10/' + foldername
    os.makedirs(save_dir, exist_ok=True)

    attack = args.attack

    if attack == "LF":
        # Label flipping is the only attack that is not configured with the
        # trainer.
        return LabelFlippingWorker(
            revertible_label_transformer=lambda target: 9 - target,
            **shared,
            **kwargs,
        )

    # Select the attacker class and its attack-specific keyword arguments.
    # The extras are built inside the branch so that get_aggregator() is only
    # invoked for the attacks that need their own aggregator instance.
    if attack == "BF":
        attacker_cls, extra = BitFlippingWorker, {}
    elif attack == "mimic":
        attacker_cls, extra = MimicVariantAttacker, dict(warmup=args.mimic_warmup)
    elif attack == "IPM":
        attacker_cls, extra = IPMAttack, dict(epsilon=0.1)
    elif attack == "ALIE":
        attacker_cls, extra = ALittleIsEnoughAttack, dict(n=args.n, m=args.f)
    elif attack == "MinMax":
        attacker_cls, extra = MinMaxAttack, {}
    elif attack == "MinSum":
        attacker_cls, extra = MinSumAttack, {}
    elif attack == "SSNLP":
        attacker_cls = SSNLPAttack
        extra = dict(
            n=args.n,
            f=args.f,
            T=epochs,
            save_dir=save_dir,
            max_batches_per_epoch=max_batches_per_epoch,
            search_size=args.nlpsize,
            agg=get_aggregator(args),
            args1=args,
        )
    elif attack == "STAB":
        attacker_cls = StabilityBreakingAttack
        extra = dict(
            n=args.n,
            f=args.f,
            agg=get_aggregator(args),
            length=args.stab_len,
        )
    else:
        raise NotImplementedError(f"No such attack {args.attack}")

    # Separate ** unpackings keep the original failure mode: a key present in
    # more than one mapping raises TypeError instead of being overridden.
    attacker = attacker_cls(**extra, **shared, **kwargs)
    attacker.configure(trainer)
    return attacker


def main(args, LOG_DIR, EPOCHS, MAX_BATCHES_PER_EPOCH):
    """Run one experiment: build model, server, workers, then train/evaluate.

    Args:
        args: parsed command-line options.
        LOG_DIR: directory handed to ``initialize_logger``.
        EPOCHS: number of training epochs.
        MAX_BATCHES_PER_EPOCH: number of batches per epoch; also used to
            convert ``DECAY_ITERS`` into a decay period measured in epochs.
    """
    initialize_logger(LOG_DIR)

    # Device selection: CUDA first, then MPS, both only on explicit request.
    cuda_available = torch.cuda.is_available()
    # Builds of torch without the MPS backend module are treated as "no MPS".
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = mps_backend is not None and mps_backend.is_available()

    if cuda_available:
        torch.cuda.manual_seed_all(args.seed)

    if args.use_cuda and cuda_available:
        print(f"=> Use CUDA on {torch.cuda.get_device_name(0)}")
        device = torch.device("cuda")
    elif args.use_cuda and mps_available:
        device = torch.device("mps")
    else:
        if args.use_cuda:
            print("=> There is no cuda/mps device!!!!")
        device = torch.device("cpu")
    kwargs = {"pin_memory": True} if args.use_cuda else {}
    print(f"=> Use Device {device}")

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    model = Net().to(device)

    # Each optimizer contains a separate `state` to store info like `momentum_buffer`
    optimizers = [torch.optim.SGD(model.parameters(), lr=LR, weight_decay=REGU) for _ in range(args.n)]
    server_opt = torch.optim.SGD(model.parameters(), lr=LR, weight_decay=REGU)

    # Schedulers are stepped once per epoch, so the decay period cannot be
    # shorter than one epoch; the clamp also avoids a division by zero when
    # MAX_BATCHES_PER_EPOCH exceeds DECAY_ITERS.
    decay_epochs = max(1, DECAY_ITERS // MAX_BATCHES_PER_EPOCH)
    decay_points = [i * decay_epochs for i in range(1, EPOCHS // decay_epochs + 1)]
    print(decay_points)
    schedulers = [torch.optim.lr_scheduler.MultiStepLR(optim, milestones=decay_points, gamma=DECAY) for optim in optimizers]
    server_sched = torch.optim.lr_scheduler.MultiStepLR(server_opt, milestones=decay_points, gamma=DECAY)

    loss_func = nn.CrossEntropyLoss()

    metrics = {"top1": top1_accuracy}

    server = TorchServer(optimizer=server_opt, model=model, clipping=args.grad_clip)
    trainer = ParallelTrainer(
        server=server,
        aggregator=get_aggregator(args),
        pre_batch_hooks=[],
        post_batch_hooks=[],
        max_batches_per_epoch=MAX_BATCHES_PER_EPOCH,
        log_interval=args.log_interval,
        metrics=metrics,
        use_cuda=args.use_cuda,
        debug=False,
    )

    test_loader = cifar10(
        data_dir=DATA_DIR,
        train=False,
        download=True,
        batch_size=TEST_BATCH_SIZE,
        shuffle=False,
        sampler_callback=get_test_sampler_callback(args),
        **kwargs,
    )

    evaluator = DistributedEvaluator(
        model=model,
        data_loader=test_loader,
        loss_func=loss_func,
        device=device,
        metrics=metrics,
        use_cuda=args.use_cuda,
        debug=False,
    )

    for worker_rank in range(args.n):
        worker = initialize_worker(
            args,
            trainer,
            worker_rank,
            model=model,
            optimizer=optimizers[worker_rank],
            loss_func=loss_func,
            device=device,
            epochs=EPOCHS,
            max_batches_per_epoch=MAX_BATCHES_PER_EPOCH,
            # Deliberately empty: initialize_worker forwards these to the
            # worker constructor as well as to the data loader.
            kwargs={},
        )
        trainer.add_worker(worker)

    if not args.dry_run:
        evaluator.evaluate(0)
        for epoch in range(1, EPOCHS + 1):
            trainer.train(epoch)
            print(args.agg, args.attack, args.f, args.dirichlet, args.momentum)
            evaluator.evaluate(epoch)
            trainer.parallel_call(lambda w: w.data_loader.sampler.set_epoch(epoch))

            for scheduler in schedulers:
                scheduler.step()

            server_sched.step()
