import sys

import numpy as np

sys.path.append("../")

from fedlab.contrib.compressor.quantization import QSGDCompressor
from fedlab.contrib.compressor.topk import TopkCompressor
from fedlab.contrib.algorithm.basic_client import SGDClientTrainer, SGDSerialClientTrainer
from fedlab.contrib.algorithm.basic_server import SyncServerHandler
import torch
import argparse
import os
from opcode import cmp_op
from fedlab.models.mlp import MLP
from torchvision import transforms
from fedlab.contrib.dataset.partitioned_mnist import PartitionedMNIST
from fedlab.contrib.dataset.partitioned_cifar10 import PartitionedCIFAR10
from fedlab.utils.functional import evaluate
from fedlab.core.standalone import StandalonePipeline
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from fedlab.models.mlp import MLP
from fedlab.models.cnn import CNN_MNIST
from fedlab.models.cnn import CNN_CIFAR10, AlexNet_CIFAR10
from fedlab.models.convex import LinearModel, PiecewiseLinearModel, LinearModelEnsemble


def args_parser(argv=None):
    """Parse the command-line options of the experiment.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which is what
            the script did before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``compressor``, ``k_ratio``,
        ``n_bit``, ``total_client``, ``alpha``, ``seed``, ``epochs``, ``lr``,
        ``batch_size``, ``com_round``, ``sample_ratio``, ``model`` and
        ``dataset``.
    """
    parser = argparse.ArgumentParser()
    # Only these two schemes are handled by the rest of the script, so reject
    # anything else here instead of failing later on an undefined name.
    parser.add_argument('--compressor', type=str, default='topk', choices=('topk', 'qsgd'),
                        help='choose from topk and qsgd')
    parser.add_argument('--k_ratio', type=float, default=1.0, help='TopK k ratio')
    parser.add_argument('--n_bit', type=int, default=8, help='number of bits for quantization')
    parser.add_argument('--total_client', type=int, default=1)
    parser.add_argument('--alpha', type=float, default=0.7, help="Dirichlet alpha")
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--batch_size', '-bs', type=int, default=1024)
    parser.add_argument('--com_round', type=int, default=1)
    parser.add_argument('--sample_ratio', default=0.1, type=float)
    parser.add_argument('--model', default='LinearModel', type=str,
                        help="LinearModel | PiecewiseLinearModel | LinearModelEnsemble")
    parser.add_argument('--dataset', default='PartitionedMNIST', type=str,
                        help="PartitionedMNIST | PartitionedCIFAR10 | PartitionedCIFAR100")
    return parser.parse_args(argv)


# Parse the command line once; the remainder of the script reads this namespace.
args = args_parser()

# Run on the first GPU ('cuda:0') when CUDA is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')


class CompressSerialClientTrainer(SGDSerialClientTrainer):
    """Serial SGD client trainer whose uploads go through a compressor."""

    def setup_compressor(self, compressor):
        """Remember the compressor that is applied to every upload."""
        self.compressor = compressor

    def update_model(self, server_model_state_dict):
        """Load ``server_model_state_dict`` into the trainer's model."""
        self.model.load_state_dict(server_model_state_dict)

    @property
    def uplink_package(self):
        """Return the parent's upload packages with their first entry compressed.

        Every package produced by the parent class is replaced by a
        one-element list holding ``self.compressor.compress(package[0])``;
        any further entries of the package are not forwarded.
        """
        raw_packages = super().uplink_package
        return [[self.compressor.compress(entry[0])] for entry in raw_packages]


class CompressServerHandler(SyncServerHandler):
    """Synchronous server handler that decompresses uploads before loading them."""

    def setup_compressor(self, compressor, type):
        """Register the compressor and the name of its compression scheme.

        Args:
            compressor: Object whose ``decompress`` method is used by ``load``.
            type: Scheme name; ``load`` understands ``'topk'`` and ``'qsgd'``.
                The parameter shadows the builtin ``type`` but keeps its name
                so that existing keyword callers continue to work.
        """
        self.compressor = compressor
        self.type = type

    def load(self, payload) -> bool:
        """Decompress one client package and pass it to the parent ``load``.

        Args:
            payload: Sequence whose first element is the compressed update:
                a ``(values, indices)`` pair for ``'topk'`` or a three-element
                tuple for ``'qsgd'``.

        Returns:
            The result of the parent class' ``load`` for the decompressed
            payload.

        Raises:
            ValueError: If the configured scheme is neither ``'topk'`` nor
                ``'qsgd'`` (previously this surfaced as an UnboundLocalError).
        """
        if self.type == 'topk':
            values, indices = payload[0]
            # Read the parameters once; device and shape both come from them.
            reference = self.model_parameters
            decompressed_payload = self.compressor.decompress(
                values.to(reference.device),
                indices.to(reference.device),
                reference.shape,
            )
        elif self.type == 'qsgd':
            # Unpack to enforce the three-element layout before decompressing.
            part_n, part_s, part_l = payload[0]
            decompressed_payload = self.compressor.decompress((part_n, part_s, part_l))
        else:
            raise ValueError(f"Unsupported compressor type: {self.type!r}")

        return super().load([decompressed_payload])


class EvalPipeline(StandalonePipeline):
    """Standalone pipeline that evaluates the server model after every round."""

    def __init__(self, handler, trainer, test_loader):
        """Store the test loader next to the handler/trainer pair."""
        super().__init__(handler, trainer)
        self.test_loader = test_loader

    def main(self):
        """Run communication rounds until the handler reports it should stop.

        Returns:
            Tuple ``(losses, accuracies)`` holding one entry per round, as
            returned by ``evaluate`` on ``self.test_loader``.
        """
        losses = []
        accuracies = []
        while self.handler.if_stop is False:
            # Server side: pick this round's clients and fetch the broadcast.
            selected = self.handler.sample_clients()
            downlink = self.handler.downlink_package

            # Client side: local training on the selected clients.
            self.trainer.local_process(downlink, selected)

            # Server side: feed every client package to the handler.
            for client_pack in self.trainer.uplink_package:
                self.handler.load(client_pack)

            loss, acc = evaluate(self.handler.model, nn.CrossEntropyLoss(), self.test_loader)
            print(f"Centralized Evaluation round {self.handler.round}: loss {loss:.4f}, test accuracy {acc:.4f}")
            losses.append(loss)
            accuracies.append(acc)

        return losses, accuracies


# First stage: train with a single client to find the global optimum, which
# is then kept in ``opt_model_parameters``.
args.total_client = 1

# Pick the partition scheme for the requested dataset; fail with a clear
# message instead of an undefined-name error further down.
if args.dataset == "PartitionedMNIST":
    partition = "noniid-labeldir"
elif args.dataset in ("PartitionedCIFAR10", "PartitionedCIFAR100"):
    partition = "dirichlet"
else:
    raise ValueError(f"Unsupported dataset: {args.dataset!r}")

# makedirs also creates a missing "../datasets" parent and tolerates an
# already existing directory.
os.makedirs(f"../datasets/{args.dataset}", exist_ok=True)
data_path = f"../datasets/{args.dataset}/{partition}_{args.alpha}_{args.total_client}"

# Only (re)build the partition when it has not been written to disk yet.
preprocess = not os.path.exists(data_path)

constraint = f"check_sigma_{args.model}_{args.compressor}_{args.k_ratio}_{args.n_bit}_{args.total_client}_{args.alpha}_{args.epochs}_{args.lr}_{args.com_round}"

if args.dataset == "PartitionedMNIST":
    dataset = PartitionedMNIST(root="../datasets",
                               path=data_path,
                               num_clients=args.total_client,
                               partition=partition,
                               dir_alpha=args.alpha,
                               seed=args.seed,
                               preprocess=preprocess,
                               download=True,
                               verbose=True,
                               transform=transforms.Compose([
                                   transforms.ToPILImage(),
                                   transforms.ToTensor()
                               ]))
    test_data = torchvision.datasets.MNIST(root=f"../datasets/{args.dataset}",
                                           train=False,
                                           download=True,
                                           transform=transforms.ToTensor())
    input_dim, num_classes = 784, 10

elif args.dataset == "PartitionedCIFAR10":
    dataset = PartitionedCIFAR10(root="../datasets",
                                 path=data_path,
                                 num_clients=args.total_client,
                                 partition=partition,
                                 dir_alpha=args.alpha,
                                 seed=args.seed,
                                 dataname="cifar10",
                                 preprocess=preprocess,
                                 download=True,
                                 verbose=True,
                                 transform=transforms.Compose([
                                     transforms.ToTensor()
                                 ]))
    test_data = torchvision.datasets.CIFAR10(root=f"../datasets/{args.dataset}/",
                                             train=False,
                                             download=True,
                                             transform=transforms.ToTensor())
    input_dim, num_classes = 3072, 10

else:  # "PartitionedCIFAR100", guaranteed by the validation above
    # NOTE(review): assumes PartitionedCIFAR10 with dataname="cifar100" loads
    # the CIFAR-100 training data -- confirm against the dataset class.
    dataset = PartitionedCIFAR10(root="../datasets",
                                 path=data_path,
                                 num_clients=args.total_client,
                                 partition=partition,
                                 dir_alpha=args.alpha,
                                 seed=args.seed,
                                 dataname="cifar100",
                                 preprocess=preprocess,
                                 download=True,
                                 verbose=True,
                                 transform=transforms.Compose([
                                     transforms.ToTensor()
                                 ]))
    # The model below has 100 outputs, so it must be evaluated on the
    # CIFAR-100 test split (the CIFAR-10 split was used here by mistake).
    test_data = torchvision.datasets.CIFAR100(root=f"../datasets/{args.dataset}/",
                                              train=False,
                                              download=True,
                                              transform=transforms.ToTensor())
    input_dim, num_classes = 3072, 100

# Build the requested model for the dataset's input size and class count.
if args.model == "LinearModel":
    model = LinearModel(input_dim, num_classes)
elif args.model == "PiecewiseLinearModel":
    model = PiecewiseLinearModel(input_dim, num_classes, 3)
elif args.model == "LinearModelEnsemble":
    model = LinearModelEnsemble(input_dim, num_classes, 3)
else:
    raise ValueError(f"Unsupported model: {args.model!r}")

# Both compressors are instantiated; the CLI option selects the active one.
tpk_compressor = TopkCompressor(compress_ratio=args.k_ratio)
qsgd_compressor = QSGDCompressor(n_bit=args.n_bit)
if args.compressor == 'topk':
    compressor = tpk_compressor
elif args.compressor == 'qsgd':
    compressor = qsgd_compressor
else:
    raise ValueError(f"Unsupported compressor: {args.compressor!r}")

trainer = CompressSerialClientTrainer(model, args.total_client, cuda=use_cuda)
trainer.setup_dataset(dataset)
trainer.setup_optim(args.epochs, args.batch_size, args.lr)
trainer.setup_compressor(compressor)

handler = CompressServerHandler(model=model,
                                global_round=args.com_round,
                                num_clients=args.total_client,
                                sample_ratio=args.sample_ratio,
                                cuda=use_cuda)
handler.setup_compressor(compressor, args.compressor)

test_loader = DataLoader(test_data, batch_size=1024)
standalone_eval = EvalPipeline(handler=handler,
                               trainer=trainer,
                               test_loader=test_loader)
res_loss, res_acc = standalone_eval.main()
# Snapshot of the trained server model, used as the reference point later on.
opt_model_parameters = handler.model.state_dict()


# Additional function to compute the gradient norm for a single client
def compute_gradient_norm(model, data_loader, criterion, device):
    """Sum the squared gradient norms of the loss over all batches of a loader.

    For every batch the gradient of ``criterion(model(inputs), labels)`` with
    respect to the model parameters is computed, and the squared L2 norms of
    all parameter gradients are added to a running total. The result is a sum
    of per-batch values, not the squared norm of a full-dataset gradient.

    Args:
        model: Module to evaluate; it is switched to eval mode and moved to
            ``device``.
        data_loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
        device: Device on which the computation runs.

    Returns:
        A 0-dim tensor on ``device``. It is zero when the loader yields no
        batches, so callers can always treat the result as a tensor (the
        previous implementation returned the Python float ``0.0`` then).
    """
    model.eval()  # Set model to evaluation mode
    model.to(device)  # Move model to the specified device
    total_grad_norm = torch.zeros((), device=device)

    for inputs, labels in data_loader:
        # Keep the batch on the same device as the model.
        inputs, labels = inputs.to(device), labels.to(device)
        loss = criterion(model(inputs), labels)
        model.zero_grad()
        loss.backward()

        # Squared L2 norm of this batch's gradient over all parameters.
        batch_sq_norm = sum(p.grad.norm() ** 2 for p in model.parameters() if p.grad is not None)
        # Out-of-place add so the accumulator follows the gradients' dtype.
        total_grad_norm = total_grad_norm + batch_sq_norm

    return total_grad_norm


# Second stage: evaluate the client gradient norms at the model found above
# for several client counts and Dirichlet alphas, then plot them.
constraint = f"check_sigma_{args.dataset}_{args.model}_{args.epochs}_{args.lr}_{args.com_round}"
model.load_state_dict(opt_model_parameters)
# Criterion for loss computation
criterion = nn.CrossEntropyLoss()

total_clients = [1, 2, 4, 8, 16, 32, 64, 128, 256]
alphas = [0.3, 0.5, 0.7, 0.9, 1.0, 1000]

# The partition scheme depends only on the dataset, so resolve it once.
if args.dataset == "PartitionedMNIST":
    partition = "noniid-labeldir"
elif args.dataset in ("PartitionedCIFAR10", "PartitionedCIFAR100"):
    partition = "dirichlet"
else:
    raise ValueError(f"Unsupported dataset: {args.dataset!r}")

gradient_norms_over_alphas = []
for alpha in alphas:
    gradient_norms = []
    for total_client in total_clients:
        data_path = f"../datasets/{args.dataset}/{partition}_{alpha}_{total_client}"
        # Only (re)build the partition when it is not on disk yet.
        preprocess = not os.path.exists(data_path)

        # The test split is not needed for this sweep, so it is no longer
        # re-instantiated on every iteration.
        if args.dataset == "PartitionedMNIST":
            dataset = PartitionedMNIST(root="../datasets",
                                       path=data_path,
                                       num_clients=total_client,
                                       partition=partition,
                                       dir_alpha=alpha,
                                       seed=args.seed,
                                       preprocess=preprocess,
                                       download=True,
                                       verbose=True,
                                       transform=transforms.Compose([
                                           transforms.ToPILImage(),
                                           transforms.ToTensor()
                                       ]))
        else:
            dataname = "cifar10" if args.dataset == "PartitionedCIFAR10" else "cifar100"
            dataset = PartitionedCIFAR10(root="../datasets",
                                         path=data_path,
                                         num_clients=total_client,
                                         partition=partition,
                                         dir_alpha=alpha,
                                         seed=args.seed,
                                         dataname=dataname,
                                         preprocess=preprocess,
                                         download=True,
                                         verbose=True,
                                         transform=transforms.Compose([
                                             transforms.ToTensor()
                                         ]))

        # Compute gradient norms for each client
        client_grad_norms = []
        for client_id in range(total_client):
            client_data_loader = dataset.get_dataloader(client_id, batch_size=args.batch_size)
            grad_norm = compute_gradient_norm(model, client_data_loader, criterion, device)
            # as_tensor guards against a plain-number result (e.g. for a
            # client without any batches), which has no ``.cpu()`` method.
            client_grad_norms.append(torch.as_tensor(grad_norm).cpu())

        # Average gradient norms
        avg_grad_norm = sum(client_grad_norms) / total_client
        gradient_norms.append(avg_grad_norm)
        print(f"Average gradient norm for {total_client} clients: {avg_grad_norm}")

    gradient_norms_over_alphas.append(gradient_norms)

# Plot the results
import matplotlib.pyplot as plt
import matplotlib as mpl
plt.style.use('fast')
mpl.rcParams['mathtext.fontset'] = 'cm'
# Font type 42 embeds TrueType fonts in the PDF/PS output.
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['lines.linewidth'] = 2.0
mpl.rcParams['legend.fontsize'] = 'large'
mpl.rcParams['axes.titlesize'] = 'xx-large'
mpl.rcParams['xtick.labelsize'] = 'x-large'
mpl.rcParams['ytick.labelsize'] = 'x-large'
mpl.rcParams['axes.labelsize'] = 'xx-large'

markers = ['x', '.', '+', '1', 'p','*', 'D' , '.',  's']

plt.xscale('log')
plt.ylabel(r"$\sigma_{\star}$", fontsize=12)
plt.xlabel("n", fontsize=12)
# Vertical gap between a marker and its value label, derived from the first curve.
offset = max(gradient_norms_over_alphas[0]) * 0.02
for i, curve in enumerate(gradient_norms_over_alphas):
    # Wrap around the marker list so that extra alphas cannot index past it.
    plt.plot(total_clients, curve, marker=markers[i % len(markers)], markersize=10, label=alphas[i])
    for x, y in zip(total_clients, curve):
        plt.text(x, y+offset, f'{y:.3f}', ha='center', va='bottom', fontsize=8)
# Set the x-axis ticks explicitly to the values in total_clients
plt.xticks(total_clients, labels=[str(k) for k in total_clients], fontsize=8)
plt.yticks(fontsize=8)
plt.legend()
plt.savefig(f"{constraint}.pdf")
plt.close()