import sys

import numpy as np

sys.path.append("../")

from fedlab.contrib.compressor.quantization import QSGDCompressor
from fedlab.contrib.compressor.topk import TopkCompressor
# from fedlab.contrib.algorithm.basic_client import SGDSerialClientTrainer
# from fedlab.contrib.algorithm.basic_server import SyncServerHandler
from fedlab.contrib.algorithm.scaffold import ScaffoldSerialClientTrainer
from fedlab.contrib.algorithm.scaffold import ScaffoldServerHandler
import torch
import argparse
import os
from opcode import cmp_op
from fedlab.models.mlp import MLP
from torchvision import transforms
from fedlab.contrib.dataset.partitioned_mnist import PartitionedMNIST
from fedlab.utils.functional import evaluate
from fedlab.core.standalone import StandalonePipeline
from torch import nn
from torch.utils.data import DataLoader
import torchvision

def args_parser():
    """Parse this script's command-line options.

    The options select the upload compressor (``--compressor``) and its
    hyper-parameter (``--k_ratio`` for top-k, ``--n_bit`` for quantization),
    plus client count, Dirichlet alpha, seed, local epochs, learning rate,
    batch size, number of communication rounds and client sample ratio.

    Returns:
        argparse.Namespace: options parsed from ``sys.argv``.

    Raises:
        SystemExit: raised by argparse on malformed options, including a
            ``--compressor`` value other than ``topk`` or ``qsgd``.
    """
    parser = argparse.ArgumentParser()
    # Restrict the value up front so a typo fails at parse time with a clear
    # message instead of later, when the chosen compressor is looked up.
    parser.add_argument('--compressor', type=str, default='topk',
                        choices=('topk', 'qsgd'),
                        help='choose from topk and qsgd')
    # Default is a real float (the original relied on argparse converting
    # the string '0.5').
    parser.add_argument('--k_ratio', type=float, default=0.5, help='TopK k ratio')
    parser.add_argument('--n_bit', type=int, default=8, help='number of bits for quantization')
    parser.add_argument('--total_client', type=int, default=100)
    parser.add_argument('--alpha', type=float, default=0.5, help="Dirichlet alpha")
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--batch_size', '-bs', type=int, default=128)
    parser.add_argument('--com_round', type=int, default=10)
    parser.add_argument('--sample_ratio', default=0.1, type=float)
    return parser.parse_args()

# Parse CLI options once at import time; everything below is driven by them.
args = args_parser()
use_cuda = torch.cuda.is_available()
# Only re-run dataset preprocessing when the probe file is missing.
# NOTE(review): this probe looks under "../datasets/mnist/fedmnist/", while
# PartitionedMNIST below is given path="../datasets/fedmnist/". If the
# partition files are written under that `path`, this check never finds them
# and preprocessing runs on every start — confirm which location is intended.
preprocess = not os.path.exists("../datasets/mnist/fedmnist/train/data2.pkl")

# Both compressors are built eagerly; `compressor` is the one actually used.
tpk_compressor = TopkCompressor(compress_ratio=args.k_ratio)
qsgd_compressor = QSGDCompressor(n_bit=args.n_bit)
if args.compressor == 'topk':
    compressor = tpk_compressor
elif args.compressor == 'qsgd':
    compressor = qsgd_compressor
else:
    # Fail here with a clear message rather than with a NameError on
    # `compressor` further down the script.
    raise ValueError(
        f"unknown compressor {args.compressor!r}; expected 'topk' or 'qsgd'")


class CompressSerialClientTrainer(ScaffoldSerialClientTrainer):
    """Serial client trainer whose uploads have their first element compressed.

    ``setup_compressor`` must be called before ``uplink_package`` is read,
    otherwise ``self.compressor`` does not exist.
    """

    def setup_compressor(self, compressor):
        """Store the compressor applied to uploads.

        Args:
            compressor: object exposing ``compress(x)``; this script passes a
                ``TopkCompressor`` or a ``QSGDCompressor``.
        """
        self.compressor = compressor

    @property
    def uplink_package(self):
        """Return the parent's upload packages with element 0 compressed.

        For every per-client package produced by the parent class, the first
        element is replaced by ``self.compressor.compress(...)`` of it. Any
        further elements are forwarded unchanged; previously they were
        silently dropped, so a receiver expecting them would never see them.

        NOTE(review): for SCAFFOLD the second element is presumably the
        control-variate delta — confirm against
        ``ScaffoldSerialClientTrainer.train``.

        Returns:
            list[list]: one list per client package, shaped
            ``[compressed_first, *remaining_elements]``.
        """
        # Read the parent property exactly once, as the original did.
        package = super().uplink_package
        return [
            [self.compressor.compress(content[0]), *content[1:]]
            for content in package
        ]


class CompressServerHandler(ScaffoldServerHandler):
    """Server handler that decompresses a client upload before loading it.

    ``setup_compressor`` must be called before ``load``; it provides both
    the compressor object and the name selecting the payload layout.
    """

    def setup_compressor(self, compressor, type):
        """Store the compressor and the name of the compression scheme.

        Args:
            compressor: object exposing ``decompress``; must match the
                scheme named by ``type``.
            type: ``'topk'`` or ``'qsgd'``. The parameter name shadows the
                builtin but is kept for backward compatibility with callers.
        """
        self.compressor = compressor
        self.type = type

    def load(self, payload) -> bool:
        """Decompress ``payload[0]`` and pass the result to the parent ``load``.

        Args:
            payload: sequence whose first element is the compressed upload:
                a ``(values, indices)`` pair for ``'topk'`` or a 3-element
                ``(n, s, l)`` group for ``'qsgd'``. Any further elements are
                forwarded to the parent unchanged.

        Returns:
            bool: whatever the parent class's ``load`` returns.

        Raises:
            ValueError: if ``self.type`` is neither ``'topk'`` nor ``'qsgd'``
                (previously this surfaced as an UnboundLocalError).
        """
        if self.type == 'topk':
            values, indices = payload[0]
            # Top-k decompression needs the target shape to scatter into.
            decompressed_payload = self.compressor.decompress(
                values, indices, self.model_parameters.shape)
        elif self.type == 'qsgd':
            n, s, l = payload[0]
            decompressed_payload = self.compressor.decompress((n, s, l))
        else:
            raise ValueError(
                f"unsupported compressor type {self.type!r}; "
                "expected 'topk' or 'qsgd'")

        # Keep any trailing payload elements so the parent handler receives
        # the full package; with a single-element payload this is identical
        # to passing [decompressed_payload].
        return super().load([decompressed_payload, *payload[1:]])


# Federated MNIST: data is split across `args.total_client` clients using the
# 'noniid-labeldir' partition controlled by the Dirichlet alpha option.
fed_mnist = PartitionedMNIST(
    root="../datasets",
    path="../datasets/fedmnist/",
    num_clients=args.total_client,
    partition='noniid-labeldir',
    dir_alpha=args.alpha,
    seed=args.seed,
    preprocess=preprocess,
    download=True,
    verbose=True,
    transform=transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
    ]),
)

# Fetch client 0's data; these two names are not used again in this script.
dataset = fed_mnist.get_dataset(0)  # get the 0-th client's dataset
dataloader = fed_mnist.get_dataloader(0, batch_size=args.batch_size)

from fedlab.models.mlp import MLP
model = MLP(784, 10)

# Client side: one serial trainer for all clients, with compression of
# its uploads.
trainer = CompressSerialClientTrainer(model, args.total_client, cuda=use_cuda)
trainer.setup_dataset(fed_mnist)
trainer.setup_optim(args.epochs, args.batch_size, args.lr)
trainer.setup_compressor(compressor)

# Server side: handler that decompresses uploads with the same compressor.
handler = CompressServerHandler(
    model=model,
    global_round=args.com_round,  # TODO change this to iteration, or align Scaffold
    num_clients=args.total_client,
    sample_ratio=args.sample_ratio,
    cuda=use_cuda,
)
handler.setup_optim(lr=args.lr)
handler.setup_compressor(compressor, args.compressor)


class EvalPipeline(StandalonePipeline):
    """Standalone pipeline that evaluates the server model after every loop pass."""

    def __init__(self, handler, trainer, test_loader):
        """Keep the test loader next to the handler/trainer pair.

        Args:
            handler: server-side handler, forwarded to the parent pipeline.
            trainer: client-side trainer, forwarded to the parent pipeline.
            test_loader: data loader handed to ``evaluate`` in ``main``.
        """
        super().__init__(handler, trainer)
        self.test_loader = test_loader

    def main(self):
        """Run the sample/train/upload loop until the handler reports stop.

        After all uploads of a pass have been loaded, the handler's model is
        evaluated on ``self.test_loader`` and loss/accuracy are printed.
        """
        # The criterion is stateless, so a single instance serves every pass.
        criterion = nn.CrossEntropyLoss()

        while self.handler.if_stop is False:
            # server side: pick participants and build the broadcast
            client_ids = self.handler.sample_clients()
            downlink = self.handler.downlink_package

            # client side: local training on the sampled clients
            self.trainer.local_process(downlink, client_ids)

            # server side: feed every client package to the handler
            for client_pack in self.trainer.uplink_package:
                self.handler.load(client_pack)

            loss, acc = evaluate(self.handler.model, criterion, self.test_loader)
            print(f"Centralized Evaluation round {self.handler.round}: loss {loss:.4f}, test accuracy {acc:.4f}")


# MNIST test split (train=False), used for the centralized evaluation.
test_data = torchvision.datasets.MNIST(
    root="../datasets/mnist",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
test_loader = DataLoader(test_data, batch_size=1024)

# Wire handler, trainer and test loader together and start the run.
standalone_eval = EvalPipeline(
    handler=handler,
    trainer=trainer,
    test_loader=test_loader,
)
standalone_eval.main()

