import argparse
import os, sys
sys.path.append("../")
from fedlab.contrib.dataset.partitioned_mnist import PartitionedMNIST
from fedlab.contrib.dataset.partitioned_cifar10 import PartitionedCIFAR10
from torchvision import transforms
from fedlab.models.mlp import MLP
from fedlab.models.cnn import CNN_MNIST
from fedlab.utils.functional import evaluate
from fedlab.core.standalone import StandalonePipeline
from torch import nn
import torchvision
from fedlab.contrib.compressor.quantization import QSGDCompressor
from fedlab.contrib.compressor.topk import TopkCompressor
from fedlab.models.cnn import CNN_CIFAR10, AlexNet_CIFAR10


def args_parser(argv=None):
    """Parse the command-line options for a compressed federated-learning run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets the parser be driven from tests or notebooks.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    # Restrict to the compressors SelectModelClasses.get_compressor can build,
    # so a typo fails at parse time instead of later during setup.
    parser.add_argument('--compressor', type=str, default='topk', choices=['topk', 'qsgd'],
                        help='choose from topk and qsgd')
    # Numeric literal default: argparse would convert the string '0.5' through
    # `type`, but a float literal does not rely on that conversion.
    parser.add_argument('--k_ratio', type=float, default=0.5, help='TopK k ratio')
    parser.add_argument('--n_bit', type=int, default=8, help='number of bits for quantization')
    parser.add_argument('--total_client', type=int, default=100)
    parser.add_argument('--alpha', type=float, default=0.7, help="Dirichlet alpha")
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--batch_size', '-bs', type=int, default=1024)
    parser.add_argument('--com_round', type=int, default=50000)
    parser.add_argument('--sample_ratio', default=0.1, type=float)
    parser.add_argument('--model', default='MLP', type=str, help="MLP?")
    parser.add_argument('--dataset', default='PartitionedMNIST', type=str, help="PartitionedMNIST?")
    parser.add_argument('--method', default="SpProxSkip_Com", type=str, help="Local, Com, Global?")
    return parser.parse_args(argv)


class SelectModelClasses:
    """Build the dataset, model and gradient compressor selected by parsed CLI args.

    ``args`` is the namespace produced by ``args_parser``. The attributes read
    here are ``dataset``, ``model``, ``alpha``, ``total_client``, ``seed``,
    ``compressor``, ``k_ratio``, ``n_bit``, ``epochs``, ``lr`` and ``com_round``.
    """

    def __init__(self, args):
        self.args = args

    def basic_config(self, partition):
        """Set ``self.data_path`` for ``partition`` and the ``self.preprocess`` flag.

        ``self.preprocess`` is True only when the partition directory does not
        exist yet. It is forwarded as ``preprocess=`` to the partitioned-dataset
        constructors below.
        """
        self.data_path = f"../datasets/{self.args.dataset}/{partition}_{self.args.alpha}_{self.args.total_client}"
        self.preprocess = not os.path.exists(self.data_path)

    def dataset_MNIST(self):
        """Return ``(partitioned_train, test_set)`` for MNIST.

        The training set uses the ``"noniid-labeldir"`` partition with
        ``dir_alpha=args.alpha`` across ``args.total_client`` clients.
        """
        partition = "noniid-labeldir"
        self.basic_config(partition)
        dataset = PartitionedMNIST(root="../datasets", path=self.data_path, num_clients=self.args.total_client,
                                   partition=partition, dir_alpha=self.args.alpha, seed=self.args.seed,
                                   preprocess=self.preprocess, download=True, verbose=True,
                                   transform=transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()]))
        test_data = torchvision.datasets.MNIST(root="../datasets/mnist", train=False, download=True,
                                               transform=transforms.ToTensor())
        return dataset, test_data

    def dataset_CIFAR10(self):
        """Return ``(partitioned_train, test_set)`` for CIFAR-10.

        The training set uses the ``"dirichlet"`` partition with
        ``dir_alpha=args.alpha`` across ``args.total_client`` clients.
        """
        partition = "dirichlet"
        self.basic_config(partition)
        dataset = PartitionedCIFAR10(root="../datasets", path=self.data_path, num_clients=self.args.total_client,
                                     partition=partition, dir_alpha=self.args.alpha, seed=self.args.seed,
                                     dataname="cifar10", preprocess=self.preprocess, download=True, verbose=True,
                                     transform=transforms.Compose([transforms.ToTensor()]))
        # NOTE(review): "../dataset/" (singular) differs from the "../datasets"
        # root used everywhere else; likely a typo, left as-is so existing
        # downloads are not orphaned — confirm before changing.
        test_data = torchvision.datasets.CIFAR10(root="../dataset/CIFAR10/", train=False, download=True,
                                                 transform=transforms.ToTensor())
        return dataset, test_data

    def dataset_CIFAR100(self):
        """Placeholder for CIFAR-100.

        Raises:
            NotImplementedError: always. The previous body silently returned
                ``None``, which only failed later when a caller unpacked it.
        """
        raise NotImplementedError("CIFAR-100 dataset loading is not implemented yet")

    def model_MLP_MNIST(self):
        """Return an ``MLP(784, 10)`` for flattened 28x28 MNIST inputs."""
        model = MLP(784, 10)
        return model

    def model_CNN_MNIST(self):
        """Return fedlab's ``CNN_MNIST``."""
        model = CNN_MNIST()
        return model

    def model_CNN_CIFAR10(self):
        """Return fedlab's ``CNN_CIFAR10``."""
        model = CNN_CIFAR10()
        return model

    def model_ResNet18(self):
        """Return torchvision's ResNet-18 loaded with ``IMAGENET1K_V1`` weights.

        The classifier head is left as shipped with those weights; presumably
        callers must adapt it to their number of classes — confirm.
        """
        import torchvision.models as models
        from torchvision.models.resnet import ResNet18_Weights
        # Use the 'weights' parameter (the replacement for the deprecated 'pretrained' flag).
        model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        return model

    def mselect(self):
        """Return ``((train_dataset, test_dataset), model)`` for the configured names.

        The lookup tables map names to bound methods, so only the selected
        dataset and model are built. Unrelated datasets are never downloaded
        or partitioned.

        Raises:
            KeyError: if ``args.dataset`` or ``args.model`` is not a known name.
        """
        datasets = {
            "PartitionedMNIST": self.dataset_MNIST,
            "PartitionedCIFAR10": self.dataset_CIFAR10,
        }
        models = {
            # "MLP" is the --model default in args_parser; it selects the
            # MNIST-sized MLP so that default arguments work.
            "MLP": self.model_MLP_MNIST,
            "MLP_MNIST": self.model_MLP_MNIST,
            "CNN_MNIST": self.model_CNN_MNIST,
            "CNN_CIFAR10": self.model_CNN_CIFAR10,
            "ResNet18": self.model_ResNet18,
        }
        # Validate both names up front so a typo in --model does not first
        # trigger a dataset download/partition.
        if self.args.dataset not in datasets:
            raise KeyError(f"unknown dataset {self.args.dataset!r}; choose from {sorted(datasets)}")
        if self.args.model not in models:
            raise KeyError(f"unknown model {self.args.model!r}; choose from {sorted(models)}")
        return datasets[self.args.dataset](), models[self.args.model]()

    def print_name(self):
        """Return (not print) a run identifier built from the main hyper-parameters."""
        constraint = f"{self.args.dataset}_{self.args.model}_{self.args.compressor}_{self.args.k_ratio}_{self.args.n_bit}_" \
                     f"{self.args.total_client}_{self.args.alpha}_{self.args.epochs}_{self.args.lr}_{self.args.com_round}"
        return constraint

    def get_compressor(self):
        """Return the compressor named by ``args.compressor``.

        Only the selected compressor is constructed.

        Raises:
            ValueError: if ``args.compressor`` is neither ``'topk'`` nor
                ``'qsgd'``. The previous code raised ``UnboundLocalError`` here.
        """
        name = self.args.compressor
        if name == 'topk':
            return TopkCompressor(compress_ratio=self.args.k_ratio)
        if name == 'qsgd':
            return QSGDCompressor(n_bit=self.args.n_bit)
        raise ValueError(f"unknown compressor {name!r}; choose from 'topk' and 'qsgd'")


class EvalPipeline(StandalonePipeline):
    """General evaluation pipeline: evaluate the global model after every round.

    Args:
        handler: server-side handler, forwarded to ``StandalonePipeline``.
        trainer: client trainer, forwarded to ``StandalonePipeline``.
        test_loader: test batches passed to ``evaluate``. Presumably a
            ``torch.utils.data.DataLoader`` — confirm against callers.
    """

    def __init__(self, handler, trainer, test_loader):
        super().__init__(handler, trainer)
        self.test_loader = test_loader

    def main(self):
        """Run rounds until ``handler.if_stop`` and return per-round test metrics.

        Each round does the following:
        1. Sample clients.
        2. Give the server's downlink package to the trainer's local process.
        3. Load every uploaded package into the handler.
        4. Evaluate ``handler.model`` on ``test_loader`` with cross-entropy loss.

        Returns:
            (res_loss, res_acc): lists with one entry per completed round.
        """
        # Built once with default arguments and reused, instead of re-created every round.
        criterion = nn.CrossEntropyLoss()
        res_loss, res_acc = [], []
        while self.handler.if_stop is False:
            sampled_clients = self.handler.sample_clients()
            broadcast = self.handler.downlink_package
            self.trainer.local_process(broadcast, sampled_clients)
            uploads = self.trainer.uplink_package

            for pack in uploads:
                self.handler.load(pack)

            loss, acc = evaluate(self.handler.model, criterion, self.test_loader)
            print(f"Centralized Evaluation round {self.handler.round}: loss {loss:.4f}, test accuracy {acc:.4f}")
            res_loss.append(loss)
            res_acc.append(acc)

        return res_loss, res_acc



