
import argparse
import pickle
from argparse import ArgumentParser
from pathlib import Path
import pickle
import os
import numpy as np
from torch.utils.data import DataLoader, Subset, TensorDataset
from torchvision import transforms
from torchvision.datasets import CIFAR100, EMNIST
from torchvision.transforms import Compose, Normalize
import ATTACKS
import random
import wandb
import utils
import torch
from data_util.FMNIST.fashionmnist_data_loader import _data_transforms_FashionMNIST
from data_util.constants import MEAN, STD

import sys

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../../../")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../../")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

PROJECT_DIR = Path(__file__).absolute().parent


def get_dsfl_argparser(args) -> ArgumentParser:
    """Build an argument parser carrying the DSFL epoch options.

    Every attribute found on ``args`` is registered as a parser-level default,
    then the four integer epoch flags below are added with their own defaults.

    Args:
        args: an object supporting ``vars()`` (e.g. an ``argparse.Namespace``)
            whose attributes are forwarded to ``set_defaults``.

    Returns:
        The configured ``ArgumentParser``; the caller is responsible for
        invoking ``parse_args`` on it.
    """
    dsfl_parser = argparse.ArgumentParser()
    dsfl_parser.set_defaults(**vars(args))

    # (flag, default) pairs; all four options are plain integers.
    epoch_options = (
        ("--digest_epoch", 1),
        ("--local_epoch", 2),
        ("--public_epoch", 5),
        ("--communicat_epoch", 100),
    )
    for flag, default_value in epoch_options:
        dsfl_parser.add_argument(flag, type=int, default=default_value)

    return dsfl_parser


class DSFL_standalone_API:
    """Single-process simulation of federated distillation over shared outputs.

    Every communication round runs these phases over ``self.client_models``:

    1. optional parameter-level poisoning of the malicious clients
       (``attack_method_id`` 6 or 7);
    2. supervised training of every client on its private data (label
       attacks 8 and 9 are applied to malicious clients here);
    3. every client evaluates a random subset of the public batches and the
       per-batch outputs are averaged over all clients (output attacks 1-5
       are applied to malicious clients here);
    4. every client is trained on the same public batches to match the
       averaged outputs through ``utils.KL_Loss``.

    Per-client accuracies and the honest-client means are logged to wandb.
    """

    def __init__(self, client_models, train_data_local_num_dict, test_data_local_num_dict,
                 train_data_local_dict, test_data_local_dict, args, test_data_global,train_data_public):
        """Store the models, the public data and the loss functions.

        Args:
            client_models: sequence of client models, indexed by client id.
            train_data_local_num_dict: accepted but not stored here.
            test_data_local_num_dict: accepted but not stored here.
            train_data_local_dict: accepted but not stored here.
            test_data_local_dict: accepted but not stored here.
            args: run configuration; ``args.public_T`` is read here and all
                attributes are forwarded as defaults to the DSFL parser.
            test_data_global: stored on the instance as-is.
            train_data_public: iterable of ``(images, labels)`` public batches.
        """
        self.client_models = client_models
        self.test_data_global = test_data_global
        self.criterion_KL = utils.KL_Loss(args.public_T)
        self.criterion_CE = F.cross_entropy
        # NOTE: parse_args() reads sys.argv, so command-line flags can override
        # the DSFL epoch options.
        self.args = get_dsfl_argparser(args).parse_args()

        self.train_data_public=train_data_public
        self.mse_criterion = torch.nn.MSELoss()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def custom_cross_entropy(self,raw_output, true_labels):
        """Return the batch mean of ``-sum(true_labels * raw_output, dim=1)``.

        Presumably ``raw_output`` holds log-probabilities and ``true_labels``
        soft or one-hot targets of the same shape; neither is enforced here.
        """
        loss_per_sample = -torch.sum(true_labels * raw_output, dim=1)
        average_loss = torch.mean(loss_per_sample)
        return average_loss

    def do_dsfl_stand_alone(self, client_models, train_data_local_num_dict, test_data_local_num_dict,
                             train_data_local_dict, test_data_local_dict, args):
        """Run ``args.comm_round`` communication rounds and log metrics to wandb.

        Args:
            client_models: unused; the models given to ``__init__`` are used.
            train_data_local_num_dict: unused.
            test_data_local_num_dict: unused.
            train_data_local_dict: per-client iterable of private
                ``(images, labels)`` training batches, keyed by client index.
            test_data_local_dict: per-client iterable of ``(images, labels)``
                test batches, keyed by client index.
            args: run configuration. Attributes read: ``mal_rate``,
                ``comm_round``, ``attack_method_id``, ``lr``, ``wd``,
                ``class_num``, ``batch_num`` and ``public_T``.

        Raises:
            ValueError: from ``random.sample`` when ``args.batch_num`` exceeds
                the number of public batches.
        """
        # SECURITY: credentials must not live in source control. The key is
        # taken from the environment; with no key wandb falls back to its own
        # configured login. Any key previously committed in this file should
        # be treated as compromised and revoked.
        wandb.login(key=os.environ.get("WANDB_API_KEY"))
        wandb.init(project='DSFL', config=args)

        mal_idxs=utils.mal_select_clients(len(self.client_models),args.mal_rate)
        print("恶意客户端ID：{}".format(mal_idxs))
        # Device handed to the parameter attacks. Deliberately the bare
        # "cuda"/"cpu" device (not self.device): the ATTACKS classes are
        # defined outside this file, so what they receive is left untouched.
        attack_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # One iteration per client/server communication round.
        for global_epoch in range(args.comm_round):
            print("开始训练第" + str(global_epoch) + "轮次训练")

            self._poison_malicious_models(args, mal_idxs, attack_device)

            # Every client trains on its private data.
            self._train_on_private_data(args, mal_idxs, train_data_local_dict)

            # Accuracy after local training; only honest clients are reported,
            # but every client is evaluated.
            honest_top1_all = []
            for client_index, client_model in enumerate(self.client_models):
                _, acc_top1, _ = self._evaluate_client(client_model, test_data_local_dict[client_index])
                if client_index not in mal_idxs:
                    honest_top1_all.append(acc_top1)
            print("本地训练后客户端的平均准确率：{}".format(np.mean(honest_top1_all)))

            selected_batches = sorted(random.sample(range(len(self.train_data_public)), args.batch_num))
            print("选择的batch：{}".format(selected_batches))

            # Knowledge is rebuilt every round so tensors from batches selected
            # in earlier rounds are not kept alive.
            local_knowledge = self._collect_local_knowledge(args, mal_idxs, set(selected_batches))

            # Aggregation: average the clients' outputs per selected batch.
            global_knowledge = self._aggregate_knowledge(local_knowledge, selected_batches)

            print("开始下载知识并本地训练")
            self._distill_from_global_knowledge(args, global_knowledge)

            # Final per-round evaluation of every client.
            honest_top1_all = []
            honest_top5_all = []
            for client_index, client_model in enumerate(self.client_models):
                _, acc_top1, acc_top5 = self._evaluate_client(client_model, test_data_local_dict[client_index])
                wandb.log({str(client_index) + " Test/AccTop1": acc_top1}, step=global_epoch)
                wandb.log({str(client_index) + " Test/AccTop5": acc_top5}, step=global_epoch)

                # Aggregate statistics only cover non-malicious clients.
                if client_index not in mal_idxs:
                    honest_top1_all.append(acc_top1)
                    honest_top5_all.append(acc_top5)

            print("客户端的平均准确率：{}".format(np.mean(honest_top1_all)))

            wandb.log({"mean Test/AccTop1 on all honest clients": float(np.mean(np.array(honest_top1_all)))},
                      step=global_epoch)
            wandb.log({"mean Test/AccTop5 on all honest clients": float(np.mean(np.array(honest_top5_all)))},
                      step=global_epoch)

        wandb.finish()

    def _poison_malicious_models(self, args, mal_idxs, attack_device):
        """Apply the parameter attack (ids 6 and 7) to the malicious clients' weights."""
        if args.attack_method_id == 7:
            # Parameter-flip attack.
            attacker = ATTACKS.ParameterFlipAttack()
        elif args.attack_method_id == 6:
            # IPM attack.
            attacker = ATTACKS.IPMAttack(epsilon=0.1)
        else:
            return

        # Detached copies of every client's parameters; the attacker edits
        # these in place and only the malicious entries are loaded back.
        w_locals = [
            {name: param.clone().detach() for name, param in client_model.state_dict().items()}
            for client_model in self.client_models
        ]
        attacker.attack(w_locals, mal_idxs, attack_device)
        for mal_idx in mal_idxs:
            self.client_models[mal_idx].load_state_dict(w_locals[mal_idx])

    def _train_on_private_data(self, args, mal_idxs, train_data_local_dict):
        """Train every client for ``self.args.local_epoch`` epochs on its private data."""
        for client_index, model in enumerate(self.client_models):
            model.to(self.device)
            model.train()
            optim = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.wd)
            is_malicious = client_index in mal_idxs
            for _ in range(self.args.local_epoch):
                for images, labels in train_data_local_dict[client_index]:
                    images, labels = images.to(self.device), labels.to(self.device).long()
                    if is_malicious:
                        # Attack objects are built per batch on purpose: their
                        # constructors live outside this file and are not
                        # assumed to be stateless.
                        if args.attack_method_id == 8:
                            # Label-flip attack.
                            label_flip_attack = ATTACKS.LabelFlipAttack(args.class_num)
                            labels = label_flip_attack.attack(client_index, labels)
                        elif args.attack_method_id == 9:
                            # "Witch" (sybil) attack.
                            witch_attack = ATTACKS.WitchAttack(args.class_num)
                            labels = witch_attack.attack(client_index, labels)
                    log_probs = model(images)
                    loss = self.criterion_CE(log_probs, labels)
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

    def _evaluate_client(self, client_model, test_loader):
        """Return ``(mean loss, mean top-1, mean top-5)`` of a model over ``test_loader``."""
        client_model.eval()
        loss_avg = utils.RunningAverage()
        top1_avg = utils.RunningAverage()
        top5_avg = utils.RunningAverage()
        # Evaluation needs no gradients; skipping the autograd graph saves memory.
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(self.device), labels.to(self.device).long()
                log_probs = client_model(images)
                loss = self.criterion_CE(log_probs, labels)
                metrics = utils.accuracy(log_probs, labels, topk=(1, 5))
                # .item() converts the single-element tensors to Python scalars.
                top1_avg.update(metrics[0].item())
                top5_avg.update(metrics[1].item())
                loss_avg.update(loss.item())
        return loss_avg.value(), top1_avg.value(), top5_avg.value()

    def _collect_local_knowledge(self, args, mal_idxs, selected):
        """Return ``{client_index: [outputs per selected public batch]}``.

        Outputs are appended in the order the public loader yields the
        selected batches.
        NOTE(review): this assumes the public loader yields batches in the
        same order on every pass (i.e. is not shuffled), otherwise a batch
        index refers to different samples for different clients - confirm.
        """
        local_knowledge = {}
        for client_index, model in enumerate(self.client_models):
            is_malicious = client_index in mal_idxs
            client_outputs = []
            model.eval()
            with torch.no_grad():
                for batch_idx, (images, _) in enumerate(self.train_data_public):
                    if batch_idx not in selected:
                        continue
                    log_probs = model(images.to(self.device))
                    if is_malicious and 1 <= args.attack_method_id <= 5:
                        # Output-level attack on the shared knowledge.
                        logits_attack = ATTACKS.LogitsProcessor(attack_method_id=args.attack_method_id)
                        log_probs = logits_attack.attack(log_probs)
                    client_outputs.append(log_probs.detach())
            local_knowledge[client_index] = client_outputs
        return local_knowledge

    @staticmethod
    def _aggregate_knowledge(local_knowledge, selected_batches):
        """Average the clients' outputs, returning ``{batch_idx: mean tensor}``."""
        return {
            batch_idx: torch.mean(
                torch.stack([knowledge[position] for knowledge in local_knowledge.values()]), dim=0)
            for position, batch_idx in enumerate(selected_batches)
        }

    def _distill_from_global_knowledge(self, args, global_knowledge):
        """Train every client on the selected public batches against the averaged outputs."""
        for model in self.client_models:
            # Use the configured device instead of an unconditional .cuda(),
            # which fails on CPU-only hosts.
            model.to(self.device)
            model.train()
            optim = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                                    weight_decay=args.wd)
            for batch_idx, (images, _) in enumerate(self.train_data_public):
                if batch_idx not in global_knowledge:
                    continue
                log_probs = model(images.to(self.device))
                # NOTE(review): the target is divided by public_T while
                # criterion_KL was also built with public_T - confirm that
                # utils.KL_Loss does not scale the target a second time.
                loss=self.criterion_KL(log_probs,global_knowledge[batch_idx]/args.public_T)
                optim.zero_grad()
                loss.backward()
                optim.step()
