import os

from train import *
from methods import *

# ---------------------------------------------------------------------------
# Command-line arguments
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    choices=["CIFAR10", "CIFAR100", "tinyimagenet"],
    type=str,
    default="CIFAR10",
)
# Client data split. ``--non-iid`` used to be a ``store_true`` flag with
# ``default=True``, so it could never be switched off and the IID branch of
# the dataset set-up was unreachable. ``--non-iid`` is kept (and remains the
# default) so existing command lines still work; ``--iid`` writes False to
# the same ``non_iid`` attribute.
split_group = parser.add_mutually_exclusive_group()
split_group.add_argument(
    "--non-iid",
    dest="non_iid",
    action="store_true",
    default=True,
    help="non-IID client split (rule 'Drichlet'); this is the default",
)
split_group.add_argument(
    "--iid",
    dest="non_iid",
    action="store_false",
    help="IID client split",
)
# Forwarded as ``rule_arg`` to DatasetObject for the non-IID split only.
parser.add_argument("--rule_arg", default=0.6, type=float)
parser.add_argument("--n_client", default=50, type=int)
# Used as ``com_amount`` for every train_* call.
parser.add_argument("--epochs", default=100, type=int)
# Used as ``epoch`` for every train_* call.
parser.add_argument("--local_epochs", default=1, type=int)
parser.add_argument("--learning_rate", default=1e-3, type=float)
# NOTE(review): not read anywhere in this script.
parser.add_argument("--sch_gamma", default=1.0, type=float)
parser.add_argument("--test_per", default=5, type=int)
parser.add_argument("--batchsize", default=32, type=int)
# Passed to DatasetObject only; torch is seeded separately with a fixed 0.
parser.add_argument("--seed", default=20, type=int)
# Fraction of clients turned into attackers: ceil(n_client * attack_number).
parser.add_argument("--attack_number", default=0, type=float)
parser.add_argument(
    "--attack_name",
    choices=["Gaussian", "Sign-flip", "LIE", "FoE", "Our-attack"],
    type=str,
    default="Gaussian",
)
# NOTE(review): --convex and --pretrain are not read anywhere in this script.
parser.add_argument("--convex", default=False, action="store_true")
parser.add_argument("--pretrain", default=False, action="store_true")
args = parser.parse_args()
print(args)

# Dataset initialization: split ``args.dataset`` across ``n_client`` clients,
# either IID or with the non-IID "Drichlet" rule controlled by --rule_arg.
data_path = "./"

n_client = args.n_client
if not args.non_iid:
    data_nature = "iid"
    partition_kwargs = {"rule": "iid"}
else:
    data_nature = "noniid"
    # "Drichlet" (sic) is the exact rule string handed to DatasetObject;
    # NOTE(review): presumably DatasetObject matches on this spelling, so
    # verify there before correcting it here.
    partition_kwargs = {"rule": "Drichlet", "rule_arg": args.rule_arg}

data_obj = DatasetObject(
    dataset=args.dataset,
    n_client=n_client,
    seed=args.seed,
    unbalanced_sgm=0,
    data_path=data_path,
    **partition_kwargs,
)

# Pick the model architecture for the selected dataset. All supported
# datasets share the "nonconvex" loss label, which only feeds the run tags
# (the --convex flag is not consulted here).
loss_function = "nonconvex"
model_name = {
    "CIFAR10": "cifar10_resnet18",
    "CIFAR100": "cifar100_VGG16",
    "tinyimagenet": "tinyimagenet_MobileNet",
}[args.dataset]

# Common hyperparameters shared by every training run.
com_amount = args.epochs
batch_size = args.batchsize
attack_number = args.attack_number
attack_name = args.attack_name
epoch = args.local_epochs
learning_rate = args.learning_rate
test_per = args.test_per

# Factory for fresh models. Its result is already moved to ``device``, so
# init_model needs no second ``.to(device)``.
model_func = lambda: client_model(model_name).to(device)

# Seed torch with a fixed 0 (not --seed) right before building init_model so
# that, presumably, the initial weights are identical across runs.
torch.manual_seed(0)
init_model = model_func()

# Attacker selection: ceil(n_client * attack_number) distinct client indices
# drawn from 1..n_client-2. Clients 0 and n_client-1 are never chosen.
# NOTE(review): confirm that exclusion is intentional.
# round() strips float noise before ceil; e.g. 100 * 0.07 == 7.000000000000001
# would otherwise produce 8 attackers instead of 7.
# NOTE(review): the draw uses the global ``random`` state, which this script
# does not seed, so the attacker set is only reproducible if something else
# seeds ``random`` first.
attacker_pool = range(1, n_client - 1)
n_attackers = math.ceil(round(n_client * attack_number, 9))
if not 0 <= n_attackers <= len(attacker_pool):
    raise ValueError(
        "attack_number={} gives {} attackers, but only {} clients (1..{}) "
        "are eligible".format(
            attack_number, n_attackers, len(attacker_pool), n_client - 2
        )
    )
attackers = random.sample(attacker_pool, n_attackers)

# Run every aggregation rule in turn on the same data split, initial model and
# attacker set. For each rule, log a run tag before training, then save the
# returned test results to results/tst_<tag>.npy.


def _run_tag(prefix):
    """Return the identifier used in both the log line and the result file name.

    Layout: ``<prefix>_<dataset>_<iid|noniid>_<loss>_<attack>_<attack_number:.2f>
    _<epoch>_<rule_arg:.2f>``.
    """
    return "{:s}_{:s}_{:s}_{:s}_{:s}_{:.2f}_{:d}_{:.2f}".format(
        prefix,
        data_obj.dataset,
        data_nature,
        loss_function,
        attack_name,
        attack_number,
        epoch,
        args.rule_arg,
    )


def _run_and_save(prefix, train_fn):
    """Log the run tag, train with ``train_fn``, save and return its test results."""
    tag = _run_tag(prefix)
    logger().info(tag)
    result = train_fn(
        data_obj=data_obj,
        learning_rate=learning_rate,
        batch_size=batch_size,
        epoch=epoch,
        com_amount=com_amount,
        test_per=test_per,
        init_model=init_model,
        model_func=model_func,
        attackers=attackers,
        attack_name=attack_name,
    )
    cp.save("results/tst_{:s}.npy".format(tag), result)
    return result


# (file/log prefix, training entry point), in execution order.
AGGREGATORS = (
    ("med", train_Median),
    ("krum", train_Krum),
    ("gm", train_GM),
    ("mca", train_MCA),
    ("cclip", train_CClip),
    ("hp", train_H_public),
    ("zeno", train_Zeno),
    ("trust", train_FLTurst),  # "FLTurst" (sic) is the callable's actual name
    ("hmed", train_H_median),
    ("hkrum", train_H_Krum),
    ("hgm", train_H_GM),
    ("hmca", train_H_MCA),
    ("hcclip", train_H_CClip),
)

# Create the output directory before any training. Otherwise a missing
# results/ only shows up as a cp.save failure after the first full training
# run, and that run's results are lost.
os.makedirs("results", exist_ok=True)

for prefix, train_fn in AGGREGATORS:
    test_result = _run_and_save(prefix, train_fn)