# Compare with baselines

import sys

import numpy as np

sys.path.append("../")

from fedlab.contrib.compressor.quantization import QSGDCompressor
from fedlab.contrib.compressor.topk import TopkCompressor
# from fedlab.contrib.algorithm.basic_client import SGDClientTrainer, SGDSerialClientTrainer
# from fedlab.contrib.algorithm.basic_server import SyncServerHandler
# from fedlab.contrib.algorithm.proxskip import ProxSkipSerialClientTrainer, ProxSkipServerHandler
from fedlab.contrib.algorithm.proxskip_baselines import ProxSkipSerialClientTrainer, ProxSkipServerHandler
import torch
import argparse
import os
from opcode import cmp_op
from fedlab.models.mlp import MLP
from torchvision import transforms
from fedlab.contrib.dataset.partitioned_mnist import PartitionedMNIST
from fedlab.contrib.dataset.partitioned_cifar10 import PartitionedCIFAR10
from fedlab.utils.functional import evaluate
from fedlab.core.standalone import StandalonePipeline
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from fedlab.models.mlp import MLP
from fedlab.models.cnn import CNN_MNIST
from fedlab.models.cnn import CNN_CIFAR10, AlexNet_CIFAR10


def args_parser():
    """Parse the command-line options of this experiment script.

    The options are read from ``sys.argv``. ``--compressor`` and ``--method``
    are restricted to the values the rest of the script handles, so a typo is
    rejected by argparse instead of surfacing later as an unbound name.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--compressor', type=str, default='topk', choices=['topk', 'qsgd'],
                        help='choose from topk and qsgd')
    # Use a real float default rather than relying on argparse converting '0.5'.
    parser.add_argument('--k_ratio', type=float, default=0.5, help='TopK k ratio')
    parser.add_argument('--n_bit', type=int, default=8, help='number of bits for quantization')
    parser.add_argument('--total_client', type=int, default=10)
    parser.add_argument('--alpha', type=float, default=0.7, help="Dirichlet alpha")
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--batch_size', '-bs', type=int, default=1024)
    parser.add_argument('--com_round', type=int, default=50000)
    parser.add_argument('--sample_ratio', default=0.1, type=float)
    parser.add_argument('--model', default='MLP', type=str, help="MLP?")
    parser.add_argument('--dataset', default='PartitionedMNIST', type=str, help="PartitionedMNIST | PartitionedCIFAR10?")
    parser.add_argument('--method', default='SpProxSkip_Com', type=str,
                        choices=['FedAvg', 'SpProxSkip_Com', 'SpProxSkip_Local', 'SpProxSkip_Global'],
                        help="FedAvg, SpProxSkip_Com, SpProxSkip_Local, SpProxSkip_Global")
    return parser.parse_args()


# Resolve the run configuration once, at import time.
args = args_parser()
use_cuda = torch.cuda.is_available()

# Pick the partition scheme and the on-disk location of the partitioned data.
if args.dataset == "PartitionedMNIST":
    partition = "noniid-labeldir"
    data_path = f"../datasets/mnist/{partition}_{args.alpha}_{args.total_client}"
elif args.dataset == "PartitionedCIFAR10":
    partition = "dirichlet"
    data_path = f"../datasets/cifar10/{partition}_{args.alpha}_{args.total_client}"

# Request preprocessing only when data_path does not exist yet.
preprocess = not os.path.exists(data_path)

# Tag that identifies this configuration in the output file names.
constraint = f"{args.method}_{args.model}_{args.compressor}_{args.k_ratio}_{args.n_bit}_{args.total_client}_{args.alpha}_{args.epochs}_{args.lr}_{args.com_round}"


if args.method == "SpProxSkip_Com":
    class CompressSerialClientTrainer(ProxSkipSerialClientTrainer):
        """ProxSkip serial trainer that additionally holds a compressor."""

        def setup_compressor(self, compressor):
            """Attach the compressor object to this trainer."""
            self.compressor = compressor


    class CompressServerHandler(ProxSkipServerHandler):
        """ProxSkip server handler that decompresses uploads before loading them."""

        def setup_compressor(self, compressor, type):
            """Attach the compressor and remember its kind ('topk' or 'qsgd')."""
            self.compressor = compressor
            self.type = type

        def load(self, payload) -> bool:
            """Decompress one client upload and hand it to the parent handler.

            Args:
                payload: Pair ``(compressed, data_size)``. ``compressed`` is
                    ``(values, indices)`` for 'topk' and ``(n, s, l)`` for 'qsgd'.

            Returns:
                The result of the parent class's ``load``.

            Raises:
                ValueError: If ``self.type`` is neither 'topk' nor 'qsgd'.
            """
            if self.type == 'topk':
                (values, indices), data_size = payload
                # Decompress on the same device as the global parameters.
                values = values.to(self.model_parameters.device)
                indices = indices.to(self.model_parameters.device)
                decompressed_payload = self.compressor.decompress(values, indices, self.model_parameters.shape)
            elif self.type == 'qsgd':
                (n, s, l), data_size = payload
                decompressed_payload = self.compressor.decompress((n, s, l))
            else:
                # Previously fell through to an UnboundLocalError below.
                raise ValueError(f"Unsupported compressor type: {self.type!r}")

            return super().load([(decompressed_payload, data_size)])


if args.method == "SpProxSkip_Local":
    # TODO: work in progress, starting from here
    class CompressSerialClientTrainer(ProxSkipSerialClientTrainer):
        """ProxSkip serial trainer that compresses the model locally before training.

        Uploads are sent uncompressed; compression is applied only to the
        parameters a client starts its local training from.
        """

        def setup_compressor(self, compressor):
            """Attach the compressor object to this trainer."""
            self.compressor = compressor

        def _compress_roundtrip(self, flat_parameters):
            """Return ``flat_parameters`` after a compress/decompress round trip.

            The call signatures mirror how the server handlers in this file use
            the two compressors.

            Raises:
                TypeError: If the attached compressor is neither a
                    ``TopkCompressor`` nor a ``QSGDCompressor``.
            """
            # NOTE(review): the commented-out code this replaces moved the tensor
            # to CPU before compressing; confirm whether that is required.
            if isinstance(self.compressor, TopkCompressor):
                values, indices = self.compressor.compress(flat_parameters)
                return self.compressor.decompress(values, indices, flat_parameters.shape)
            if isinstance(self.compressor, QSGDCompressor):
                return self.compressor.decompress(self.compressor.compress(flat_parameters))
            raise TypeError(f"Unsupported compressor: {type(self.compressor).__name__}")

        @property
        def uplink_package(self):
            """Return the cached ``(model_parameters, data_size)`` pairs unchanged.

            This must be a property: the pipeline reads ``trainer.uplink_package``
            as an attribute and iterates over the result.
            """
            # NOTE(review): the cache is not cleared here; confirm that the
            # parent's local_process resets it between rounds.
            return [(model_parameters, data_size) for model_parameters, data_size in self.cache]

        def train(self, id, model_parameters, train_loader, p=0.1):
            """Train one client starting from locally compressed parameters.

            Args:
                id: Client index used to select the control variate ``self.h[id]``.
                model_parameters: Parameters passed to ``set_model`` after the
                    compression round trip (assumed to be a flat tensor, as in
                    the value this method returns; TODO confirm).
                train_loader: Iterable of ``(data, target)`` batches.
                p: Parameter of the geometric distribution that draws the number
                    of local passes; also scales the update of ``self.h[id]``.

            Returns:
                list: ``[updated_model_parameters, data_size]`` where the first
                item is the flattened model and the second the number of samples
                processed.
            """
            # Local compression is the only difference with SpProxSkip_Com. The
            # round trip is applied to the incoming tensor so that the model's
            # own ``parameters`` method is never overwritten.
            model_parameters = self._compress_roundtrip(model_parameters)
            self.set_model(model_parameters)
            self._model.train()

            tmp_model_parameters = [torch.clone(param.data) for param in self._model.parameters()]

            if self.h[id] is None:
                self.h[id] = [torch.zeros_like(param) for param in self._model.parameters()]

            it_local = self.rng_skip.geometric(p=p)
            data_size = 0
            for _ in range(it_local):
                for data, target in train_loader:
                    if self.cuda:
                        data = data.cuda(self.device)
                        target = target.cuda(self.device)

                    output = self._model(data)
                    loss = self.criterion(output, target)

                    data_size += len(target)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Update h once per local pass over train_loader, using the
                # parameter change accumulated during that pass.
                with torch.no_grad():
                    for i, param in enumerate(self._model.parameters()):
                        self.h[id][i] += p / self.lr * (tmp_model_parameters[i] - param.data)
                        tmp_model_parameters[i] = torch.clone(param.data)

            # Convert updated model parameters to a single tensor
            updated_model_parameters = torch.cat([param.data.view(-1) for param in self._model.parameters()])
            return [updated_model_parameters, data_size]


    class CompressServerHandler(ProxSkipServerHandler):
        """ProxSkip server handler for uploads that arrive uncompressed."""

        def setup_compressor(self, compressor, type):
            """Attach the compressor and remember its kind ('topk' or 'qsgd')."""
            self.compressor = compressor
            self.type = type

        def load(self, payload) -> bool:
            """Forward one client upload to the parent handler.

            Clients in this mode upload their parameters uncompressed (see
            ``uplink_package`` above), so there is nothing to decompress for
            either compressor type.

            Args:
                payload: Pair ``(model_parameters, data_size)``.

            Returns:
                The result of the parent class's ``load``.
            """
            model_parameters, data_size = payload
            return super().load([(model_parameters, data_size)])


if args.method == "SpProxSkip_Global":
    # TODO: not implemented
    class CompressSerialClientTrainer(ProxSkipSerialClientTrainer):
        """ProxSkip serial trainer that additionally holds a compressor."""

        def setup_compressor(self, compressor):
            """Attach the compressor object to this trainer."""
            self.compressor = compressor


    class CompressServerHandler(ProxSkipServerHandler):
        """ProxSkip server handler that decompresses uploads before loading them."""

        def setup_compressor(self, compressor, type):
            """Attach the compressor and remember its kind ('topk' or 'qsgd')."""
            self.compressor = compressor
            self.type = type

        def load(self, payload) -> bool:
            """Decompress one client upload and hand it to the parent handler.

            Args:
                payload: Pair ``(compressed, data_size)``. ``compressed`` is
                    ``(values, indices)`` for 'topk' and ``(n, s, l)`` for 'qsgd'.

            Returns:
                The result of the parent class's ``load``.

            Raises:
                ValueError: If ``self.type`` is neither 'topk' nor 'qsgd'.
            """
            if self.type == 'topk':
                (values, indices), data_size = payload
                # Decompress on the same device as the global parameters.
                values = values.to(self.model_parameters.device)
                indices = indices.to(self.model_parameters.device)
                decompressed_payload = self.compressor.decompress(values, indices, self.model_parameters.shape)
            elif self.type == 'qsgd':
                (n, s, l), data_size = payload
                decompressed_payload = self.compressor.decompress((n, s, l))
            else:
                # Previously fell through to an UnboundLocalError below.
                raise ValueError(f"Unsupported compressor type: {self.type!r}")

            return super().load([(decompressed_payload, data_size)])

elif args.method == "FedAvg":
    from fedlab.contrib.algorithm.basic_client import SGDClientTrainer, SGDSerialClientTrainer
    from fedlab.contrib.algorithm.basic_server import SyncServerHandler

    class CompressSerialClientTrainer(SGDSerialClientTrainer):
        """SGD serial trainer that compresses what it uploads."""

        def setup_compressor(self, compressor):
            """Attach the compressor object to this trainer."""
            self.compressor = compressor

        @property
        def uplink_package(self):
            """Return the parent's packages with their first element compressed.

            Each returned package is a one-element list holding the compressed
            form of ``content[0]``; any further elements of ``content`` are
            not forwarded.
            """
            return [[self.compressor.compress(content[0])] for content in super().uplink_package]


    class CompressServerHandler(SyncServerHandler):
        """Synchronous server handler that decompresses uploads before loading them."""

        def setup_compressor(self, compressor, type):
            """Attach the compressor and remember its kind ('topk' or 'qsgd')."""
            self.compressor = compressor
            self.type = type

        def load(self, payload) -> bool:
            """Decompress ``payload[0]`` and hand it to the parent handler.

            Args:
                payload: Sequence whose first element is ``(values, indices)``
                    for 'topk' or ``(n, s, l)`` for 'qsgd'.

            Returns:
                The result of the parent class's ``load``.

            Raises:
                ValueError: If ``self.type`` is neither 'topk' nor 'qsgd'.
            """
            if self.type == 'topk':
                values, indices = payload[0]
                decompressed_payload = self.compressor.decompress(values, indices, self.model_parameters.shape)
            elif self.type == 'qsgd':
                n, s, l = payload[0]
                decompressed_payload = self.compressor.decompress((n, s, l))
            else:
                # Previously fell through to an UnboundLocalError below.
                raise ValueError(f"Unsupported compressor type: {self.type!r}")

            return super().load([decompressed_payload])


class EvalPipeline(StandalonePipeline):
    """Standalone pipeline that evaluates the server model after every round."""

    def __init__(self, handler, trainer, test_loader):
        """Store the test loader next to the handler/trainer kept by the parent."""
        super().__init__(handler, trainer)
        self.test_loader = test_loader

    def main(self):
        """Run rounds until the handler reports stop.

        Returns:
            tuple: ``(losses, accuracies)``, one entry per completed round.
        """
        losses = []
        accuracies = []
        while self.handler.if_stop is False:
            # Server side: choose this round's clients and the broadcast payload.
            clients = self.handler.sample_clients()
            global_payload = self.handler.downlink_package

            # Client side: local training, then feed every upload to the server.
            self.trainer.local_process(global_payload, clients, method=args.method)
            for upload in self.trainer.uplink_package:
                self.handler.load(upload)

            # Centralized evaluation of the server model on the test set.
            loss, acc = evaluate(self.handler.model, nn.CrossEntropyLoss(), self.test_loader)
            print(f"Centralized Evaluation round {self.handler.round}: loss {loss:.4f}, test accuracy {acc:.4f}")
            losses.append(loss)
            accuracies.append(acc)

        return losses, accuracies


# Build the partitioned training data, the centralized test set and the model
# for the selected dataset. An unsupported --model value is rejected here
# instead of leaving `model` undefined until main() runs.
if args.dataset == "PartitionedMNIST":
    dataset = PartitionedMNIST(root="../datasets",
                               path=data_path,
                               num_clients=args.total_client,
                               partition=partition,
                               dir_alpha=args.alpha,
                               seed=args.seed,
                               preprocess=preprocess,
                               download=True,
                               verbose=True,
                               transform=transforms.Compose([
                                   transforms.ToPILImage(),
                                   transforms.ToTensor()
                               ]))
    test_data = torchvision.datasets.MNIST(root="../datasets/mnist",
                                           train=False,
                                           download=True,
                                           transform=transforms.ToTensor())
    if args.model == "MLP":
        model = MLP(784, 10)
    elif args.model == "CNN_MNIST":
        model = CNN_MNIST()
    else:
        raise ValueError(f"Unsupported model {args.model!r} for dataset {args.dataset!r}")

elif args.dataset == "PartitionedCIFAR10":
    dataset = PartitionedCIFAR10(root="../datasets",
                                 path=data_path,
                                 num_clients=args.total_client,
                                 partition=partition,
                                 dir_alpha=args.alpha,
                                 seed=args.seed,
                                 dataname="cifar10",
                                 preprocess=preprocess,
                                 download=True,
                                 verbose=True,
                                 transform=transforms.Compose([
                                     # transforms.ToPILImage(),
                                     transforms.ToTensor()
                                 ]))
    # NOTE(review): this root is "../dataset/..." while every other path uses
    # "../datasets/..."; confirm whether that is intentional.
    test_data = torchvision.datasets.CIFAR10(root="../dataset/CIFAR10/",
                                             train=False,
                                             download=True,
                                             transform=transforms.ToTensor())

    if args.model == "CNN_CIFAR10":
        model = CNN_CIFAR10()
    elif args.model == "AlexNet_CIFAR10":
        model = AlexNet_CIFAR10()
    elif args.model == "ResNet18":
        import torchvision.models as models
        from torchvision.models.resnet import ResNet18_Weights
        # Load the pre-trained ResNet-18 model with the 'weights' parameter
        model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        # The ImageNet head has 1000 outputs; CIFAR-10 has 10 classes.
        model.fc = nn.Linear(model.fc.in_features, 10)
    else:
        raise ValueError(f"Unsupported model {args.model!r} for dataset {args.dataset!r}")


elif args.dataset == "FedMNIST":
    # Placeholder branch: no dataset or model is built for FedMNIST yet.
    print()



def main(k_ratio):
    """Run one federated experiment and return its per-round test metrics.

    Args:
        k_ratio: Compression ratio given to ``TopkCompressor``; used only
            when ``args.compressor`` is 'topk'.

    Returns:
        tuple: ``(res_loss, res_acc)`` as produced by ``EvalPipeline.main``.

    Raises:
        ValueError: If ``args.compressor`` is neither 'topk' nor 'qsgd'.
    """
    # Build only the compressor that was requested.
    if args.compressor == 'topk':
        compressor = TopkCompressor(compress_ratio=k_ratio)
    elif args.compressor == 'qsgd':
        compressor = QSGDCompressor(n_bit=args.n_bit)
    else:
        # Previously left `compressor` unbound and failed further down.
        raise ValueError(f"Unsupported compressor: {args.compressor!r}")

    trainer = CompressSerialClientTrainer(model, args.total_client, cuda=use_cuda)
    trainer.setup_dataset(dataset)
    trainer.setup_optim(args.epochs, args.batch_size, args.lr)
    trainer.setup_compressor(compressor)

    handler = CompressServerHandler(model=model,
                                    global_round=args.com_round,
                                    num_clients=args.total_client,
                                    sample_ratio=args.sample_ratio,
                                    cuda=use_cuda)
    handler.setup_compressor(compressor, args.compressor)

    test_loader = DataLoader(test_data, batch_size=1024)
    standalone_eval = EvalPipeline(handler=handler,
                                   trainer=trainer,
                                   test_loader=test_loader)
    res_loss, res_acc = standalone_eval.main()
    return res_loss, res_acc


import matplotlib as mpl
import matplotlib.pyplot as plt
# Figure styling shared by both subplots.
plt.style.use('fast')
mpl.rcParams.update({
    'mathtext.fontset': 'cm',
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'lines.linewidth': 2.0,
    'legend.fontsize': 'large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'x-large',
    'ytick.labelsize': 'x-large',
    'axes.labelsize': 'xx-large',
})
# Ratios, colours and markers kept from an earlier sweep over several k ratios;
# the current run uses only args.k_ratio and the first colour.
ratios = [0.5, 0.9, 1.0]

color_ar_1 = [u'#1f77b4', u'#ff7f0e', u'#2ca02c', u'#d62728', u'#9467bd', u'#8c564b', u'#e377c2', u'#7f7f7f', u'#bcbd22', u'#17becf']
markers = ['x', '.', '+', '1', 'p','*', 'D' , '.',  's']
fig, axs = plt.subplots(2, figsize=(7, 10), constrained_layout=True)
i = 0

loss_file = f"./losses_{constraint}.txt"
acc_file = f"./acc_{constraint}.txt"

# Run the experiment once with the ratio given on the command line.
k_ratio = args.k_ratio
res_loss, res_acc = main(k_ratio)
rounds = np.arange(len(res_loss))

# Top subplot: loss per round; bottom subplot: accuracy per round.
for axis, series in zip((axs[0], axs[1]), (res_loss, res_acc)):
    axis.plot(rounds, series, color=color_ar_1[i], label=k_ratio)

# Overwrite each result file with the space-separated series.
for path, series in ((loss_file, res_loss), (acc_file, res_acc)):
    with open(path, 'w') as file:
        file.write(' '.join(map(str, series)))
i += 1

for axis, y_label in ((axs[0], 'Loss'), (axs[1], 'Accuracy')):
    axis.legend()
    axis.set_ylabel(y_label)
    axis.set_xlabel('Communication rounds')

plt.savefig(f"{constraint}.pdf")
plt.close()
