"""
Author:*** Time:2024/05/15
"""
import os

from global_update_method.FedTOGA_server import *

# Command-line configuration for a FedTOGA training run.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', choices=['CIFAR10', 'CIFAR100', 'Imagenet'], type=str, default='CIFAR10')
parser.add_argument('--split-rule', choices=['Dirichlet', 'Pathological'], type=str,
                    default='Pathological')  # select the dataset splitting rule
# `--non-iid` is store_true with default=True, so by itself it can never turn the
# flag off, and the IID branches later in this script were unreachable.
# `--iid` stores False into the same destination; `--non-iid` is still accepted
# (and, as before, is effectively a no-op).
parser.add_argument('--non-iid', dest='non_iid', action='store_true', default=True,
                    help='split client data with --split-rule (default)')
parser.add_argument('--iid', dest='non_iid', action='store_false',
                    help='split client data IID instead of with --split-rule')
parser.add_argument('--rule-arg', default=6.0, type=float)
parser.add_argument('--act_prob', default=0.1, type=float)
parser.add_argument('--n_client', default=100, type=int)
parser.add_argument('--epochs', default=800, type=int,
                    help='number of communication rounds')
parser.add_argument('--local_epochs', default=5, type=int)
parser.add_argument('--alpha', default=0.1, type=float,
                    help='passed to train_FedTOGA as alpha_coef')
parser.add_argument('--beta', default=0.8, type=float)
# NOTE(review): --alpha_coef and --global-learning-rate are not read anywhere in
# this script (alpha_coef comes from --alpha); kept so existing command lines still parse.
parser.add_argument('--alpha_coef', default=0.01, type=float)
parser.add_argument('--local-learning-rate', default=0.1, type=float)
parser.add_argument('--global-learning-rate', default=1.0, type=float)
parser.add_argument('--lr_decay', default=0.998, type=float,
                    help='passed to train_FedTOGA as lr_decay_per_round')
parser.add_argument('--sch-gamma', default=1.0, type=float)
parser.add_argument('--test-per', default=1, type=int)
parser.add_argument('--batchsize', default=50, type=int)
parser.add_argument('--seed', default=20, type=int)
parser.add_argument('--rho', default=0.1, type=float)
parser.add_argument('--kappa', default=1, type=float)
args = parser.parse_args()
print(args)

# Seed torch (CPU and all CUDA devices) and NumPy from the same --seed value,
# in that order, and ask cuDNN for deterministic kernels.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
    _seed_fn(args.seed)
torch.backends.cudnn.deterministic = True

# Dataset initialization.
# Reference settings (the second number is presumably --act_prob; confirm):
#   --n_client=200 0.05 --local_epochs=5 --batchsize=25
#   --n_client=100 0.1  --local_epochs=5 --batchsize=50
data_path = './'

# Architecture per dataset; argparse `choices` guarantees the key exists.
model_name = {
    'CIFAR10': 'Resnet18',
    'CIFAR100': 'Cifar100_Resnet18',
    'Imagenet': 'Imagenet_Resnet18',
}[args.dataset]

# NOTE(review): every non-IID run is tagged "Drichlet" here, including the default
# --split-rule=Pathological. The tag is kept unchanged so result file names stay stable.
distribution_tag = "_non iid_Drichlet" + str(args.rule_arg) if args.non_iid else "_iid"

# Result file stem encoding the main run settings.
savepath = (f"{args.dataset}{distribution_tag}"
            f"_batch{args.batchsize}"
            f"_frac{args.act_prob}"
            f"_clnt{args.n_client}"
            f"_loep{args.local_epochs}"
            f"_glep{args.epochs}"
            f"{model_name}")

n_client = args.n_client
# IID partition takes no rule argument; non-IID uses the selected split rule.
if args.non_iid:
    partition_kwargs = {'rule': args.split_rule, 'rule_arg': args.rule_arg}
else:
    partition_kwargs = {'rule': 'iid'}
data_obj = DatasetObject(dataset=args.dataset, n_client=n_client, seed=args.seed,
                         unbalanced_sgm=0, data_path=data_path, **partition_kwargs)
file_path = f"Result/{savepath}.txt"

# Common hyperparameters
com_amount = args.epochs            # number of communication rounds
save_period = 10                    # NOTE(review): not read elsewhere in this file
weight_decay = 1e-3
batch_size = args.batchsize
act_prob = args.act_prob
suffix = model_name                 # NOTE(review): not read elsewhere in this file
lr_decay_per_round = args.lr_decay


def model_func():
    """Build a fresh model of the architecture selected by --dataset."""
    return client_model(model_name)


# Initial global model. Its weights depend on the torch seed set earlier in this script.
init_model = model_func()
# Results collected from the training run(s), keyed by a run label.
save_dict = {}


print("Train FedTOGA")

# Settings forwarded to train_FedTOGA.
epoch = args.local_epochs              # passed as `epoch`
alpha_coef = args.alpha                # NOTE(review): taken from --alpha, not --alpha_coef
beta = args.beta
lr = args.lr_decay                     # per-round LR decay factor (not a learning rate)
Np = True
learning_rate = args.local_learning_rate
test_per = args.test_per

# Create the output directory before training. Otherwise a missing directory
# only fails at the open() below, after the whole run, and the results are lost.
os.makedirs(os.path.dirname(file_path) or '.', exist_ok=True)

test_perf, train_perf, divergence = train_FedTOGA(data_obj=data_obj, act_prob=act_prob, learning_rate=learning_rate,
                                                  batch_size=batch_size, epoch=epoch, com_amount=com_amount,
                                                  test_per=test_per, weight_decay=weight_decay,
                                                  model_func=model_func, init_model=init_model,
                                                  alpha_coef=alpha_coef, beta=beta, sch_step=1,
                                                  sch_gamma=args.sch_gamma, rho=args.rho, kappa=args.kappa,
                                                  Np=Np, rand_seed=args.seed, lr_decay_per_round=lr)
# Stored per run: [column 1 of test_perf, column 0 of train_perf, divergence],
# each truncated to com_amount rounds where sliced.
save_dict['FedTOGA' + str(alpha_coef) + ':' + str(beta) + ':' + str(Np) + str(lr)] = [
    test_perf[:com_amount, 1].tolist(),
    train_perf[:com_amount, 0].tolist(), divergence.tolist()]
with open(file_path, mode='w', encoding='utf-8') as file_obj:
    file_obj.write(str(save_dict))



# Plot the first stored series (test_perf column 1) of every run against
# communication round and save it as a PDF next to the text results.
plt.figure(figsize=(6, 5))
for label, series in save_dict.items():
    test_acc = series[0]
    # Derive the x-axis from the series itself. A curve shorter than com_amount
    # then cannot raise a length mismatch in plt.plot; the output is identical
    # when the lengths agree.
    plt.plot(np.arange(1, len(test_acc) + 1), test_acc, label=label)

plt.ylabel('Test Accuracy', fontsize=16)
plt.xlabel('Communication Rounds', fontsize=16)
plt.legend(fontsize=16, loc='lower right', bbox_to_anchor=(1.015, -0.02))
plt.grid()
plt.xlim([0, com_amount + 2])
# plt.title(data_obj.name, fontsize=16)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.savefig("Result/" + savepath + "_avg_acc" + ".pdf", dpi=1000, bbox_inches='tight')

# Optional train-loss figure (series index 1), currently disabled:
# plt.clf()

# plt.figure(figsize=(6, 5))
# for item in save_dict:
#     plt.plot(np.arange(com_amount) + 1, save_dict[item][1], label=item)
#
# plt.ylabel('Train Loss', fontsize=16)
# plt.xlabel('Communication Rounds', fontsize=16)
# plt.legend(fontsize=16, loc='upper right')
# plt.grid()
# plt.xlim([0, com_amount + 2])
# # plt.title(data_obj.name, fontsize=16)
# plt.xticks(fontsize=16)
# plt.yticks(fontsize=16)
# plt.savefig("Result/" + savepath + "_loss" + ".pdf", dpi=1000, bbox_inches='tight')
