import sys
sys.path.append("../")
from fedlab.contrib.algorithm.spproxskip import SpProxSkipServerHandler, SpProxSkipSerialClientTrainer
from config import args_parser, SelectModelClasses, EvalPipeline
from plot import single_plot

import torch
from torch.utils.data import DataLoader

# Run one SpProxSkip federated-learning experiment end to end:
# parse CLI arguments, build the data/model/compressor, configure the
# serial client trainer and the server handler, evaluate, then plot.
args = args_parser()
# torch.cuda.is_available() already returns a bool; no ternary needed.
use_cuda = torch.cuda.is_available()

# Datasets, model and compressor are all chosen from the parsed arguments.
select_model_classes = SelectModelClasses(args)
train_dataset, test_dataset, model = select_model_classes.mselect()
print_name = select_model_classes.print_name()
compressor = select_model_classes.get_compressor()

# Client side: one serial trainer simulating `args.total_client` clients.
# NOTE(review): the same `model` object is handed to both the trainer and the
# handler below; presumably each copies its parameters internally — confirm.
trainer = SpProxSkipSerialClientTrainer(model, args.total_client, cuda=use_cuda)
trainer.setup_method(args.method)
trainer.setup_dataset(train_dataset)
trainer.setup_optim(args.epochs, args.batch_size, args.lr)
trainer.setup_compressor(compressor)

# Server side: runs `args.com_round` global rounds, sampling a
# `sample_ratio` fraction of the `total_client` clients.
handler = SpProxSkipServerHandler(model=model, global_round=args.com_round, num_clients=args.total_client,
                                  sample_ratio=args.sample_ratio, cuda=use_cuda)
handler.setup_compressor(compressor, args.compressor)
handler.setup_method(args.method)

# Evaluation on the held-out test set (large fixed batch; evaluation only).
test_loader = DataLoader(test_dataset, batch_size=1024)
standalone_eval = EvalPipeline(handler=handler, trainer=trainer, test_loader=test_loader)
# EvalPipeline.main() returns a (loss, accuracy) pair; presumably these are
# per-round histories — TODO confirm against config.EvalPipeline.
res_loss, res_acc = standalone_eval.main()
# Not used below; kept as a module-level name for interactive inspection.
results = (res_loss, res_acc)

# single_plot is not given the results directly.
# NOTE(review): it presumably reads whatever EvalPipeline persisted, keyed by
# args/print_name — verify in plot.py.
single_plot(args, print_name)
