import os
import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics
from collections import Counter
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR

# Dataset identifier; LocalUpdate uses it to name the directory that stores
# per-client parameter checkpoints (f'{dataset_name}_saved_params').
dataset_name = 'cifar100'
# Turn on autograd anomaly detection globally as soon as this module is imported.
# NOTE(review): anomaly detection is a debugging aid and is documented to slow
# down backward passes — confirm it should stay enabled for regular training runs.
torch.autograd.set_detect_anomaly(True)

def get_current_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Args:
        optimizer: Object exposing an iterable ``param_groups`` whose items
            are mappings with an ``'lr'`` entry.

    Returns:
        The ``'lr'`` value of the first parameter group, or ``None`` when
        the optimizer has no parameter groups at all.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

class DatasetSplit(Dataset):
    """View over ``dataset`` restricted to the sample positions in ``idxs``.

    Position ``i`` of this dataset maps to ``dataset[idxs[i]]``; the order of
    ``idxs`` (as produced by iterating it once) is kept.
    """

    def __init__(self, dataset, idxs):
        # Keep a reference to the wrapped dataset and materialise the indices
        # into a list so they can be addressed by position.
        self.dataset = dataset
        self.idxs = [index for index in idxs]

    def __len__(self):
        """Number of samples exposed by this split."""
        return len(self.idxs)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair stored at split position ``item``."""
        source_index = self.idxs[item]
        image, label = self.dataset[source_index]
        return image, label


class LocalUpdate(object):
    """Client-side trainer used for federated learning rounds.

    An instance owns one DataLoader over ``idxs`` (used by :meth:`train`) and,
    optionally, one DataLoader per coordinated client (used by
    :meth:`train_to_mnist` / :meth:`train_to_cifar`). Per-client model
    checkpoints are stored under ``f'{dataset_name}_saved_params'``.

    Attributes set by ``__init__``:
        args: Configuration object; this class reads ``local_bs``, ``device``,
            ``momentum``, ``weight_decay``, ``local_ep`` and ``verbose``.
        selected_clients: Empty list (not used inside this class).
        ldr_train: Shuffled DataLoader over ``DatasetSplit(dataset, idxs)``.
        ldr_train_dict: ``{client_id: DataLoader}`` built from
            ``coodinator_dict``; empty when that argument is ``None``.
        clientset_id: The ``coord`` argument; iterated as the client ids.
        saved_params_dir: Directory holding ``client_<id>_params.pth`` files.
        loss_func: Plain cross-entropy loss.
        loss_func_weight: Cross-entropy weighted by normalised inverse class
            frequency of the labels found in ``ldr_train``.
    """

    # Name fragments of the parameters that get ``requires_grad = False`` in
    # each variant; every other parameter gets ``requires_grad = True``.
    _MNIST_FROZEN_KEYWORDS = ('conv1', 'conv2', 'fc2')
    _CIFAR_FROZEN_KEYWORDS = ('conv1', 'bn1', 'conv2', 'bn2', 'conv3', 'bn3', 'fc2')

    def __init__(self, args, dataset=None, idxs=None, coodinator_dict=None, coord=None):
        """Build the data loaders, the checkpoint directory and the losses.

        Args:
            args: Configuration object (see class docstring).
            dataset: Indexable dataset yielding ``(image, label)`` pairs.
            idxs: Iterable of sample indices for ``ldr_train``. Despite the
                ``None`` default it must be iterable, because
                ``DatasetSplit`` converts it to a list.
            coodinator_dict: Optional ``{client_id: indices}`` mapping used
                to build one loader per client.
            coord: Iterable of client ids handled by ``train_to_*``.
        """
        self.args = args
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

        # Always define the attribute so a missing coordinator mapping does
        # not surface later as an AttributeError.
        self.ldr_train_dict = {}
        if coodinator_dict is not None:
            self.ldr_train_dict = {
                client_id: DataLoader(DatasetSplit(dataset, idxs_set), batch_size=self.args.local_bs, shuffle=True)
                for client_id, idxs_set in coodinator_dict.items()
            }
        self.clientset_id = coord

        # Directory used to persist each client's parameters between rounds.
        self.saved_params_dir = f'{dataset_name}_saved_params'
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.saved_params_dir, exist_ok=True)

        # Count how many samples of each class this client holds.
        class_counts = Counter()
        for _, labels in self.ldr_train:
            class_counts.update(labels.tolist())

        # Hard-coded class count; matches the module's 'cifar100' dataset name.
        num_classes = 100
        # Inverse-frequency weights; classes absent from the data get weight 0.
        class_weights = torch.tensor(
            [1.0 / class_counts[c] if class_counts[c] else 0.0 for c in range(num_classes)],
            dtype=torch.float32,
        ).to(self.args.device)

        # Normalise so the weights sum to 1; skip when nothing was counted to
        # avoid producing NaN weights through 0 / 0.
        weight_total = class_weights.sum()
        if weight_total > 0:
            class_weights = class_weights / weight_total

        # Plain and class-balanced cross-entropy.
        self.loss_func = nn.CrossEntropyLoss()
        self.loss_func_weight = nn.CrossEntropyLoss(weight=class_weights)

    @staticmethod
    def _print_progress(epoch, batch_idx, batch_size, loader, loss_value):
        """Print one progress line for ``loader`` (the loader being iterated)."""
        print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * batch_size, len(loader.dataset),
            100. * batch_idx / len(loader), loss_value))

    def _client_param_path(self, client_id):
        """Return the checkpoint path used for ``client_id``."""
        return os.path.join(self.saved_params_dir, f'client_{client_id}_params.pth')

    def train(self, net, lr_glob):
        """Train every parameter of ``net`` on ``ldr_train`` with the weighted loss.

        Args:
            net: Model returning a single output tensor for a batch of images.
            lr_glob: Initial SGD learning rate; decayed by 0.99 per local epoch.

        Returns:
            Tuple ``(net.state_dict(), mean epoch loss, final learning rate)``.

        Raises:
            ZeroDivisionError: If ``args.local_ep`` is 0 or the loader yields
                no batches.
        """
        net.train()
        device = self.args.device
        loader = self.ldr_train
        optimizer = torch.optim.SGD(net.parameters(), lr=lr_glob, momentum=self.args.momentum,
                                    weight_decay=self.args.weight_decay)
        scheduler = StepLR(optimizer, step_size=1, gamma=0.99)

        epoch_loss = []
        for epoch in tqdm(range(self.args.local_ep), desc="Local Epochs", leave=False):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(loader):
                images = images.to(device)
                labels = labels.to(device).long()
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func_weight(log_probs, labels)
                loss.backward()
                optimizer.step()
                loss_value = loss.item()
                if self.args.verbose and batch_idx % 10 == 0:
                    self._print_progress(epoch, batch_idx, len(images), loader, loss_value)
                batch_loss.append(loss_value)
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
            scheduler.step()

        current_lr = get_current_lr(optimizer)
        print(current_lr)
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss), current_lr

    def _train_filter_clients(self, net, lr_glob, frozen_keywords):
        """Train each coordinated client in turn and average ``fc1`` across them.

        For every client id in ``clientset_id``: load its checkpoint (or the
        incoming weights when it has none), train ``net.fc1`` with the
        class-balanced loss of the first model output and ``net.fc3`` with the
        combined loss, then save the checkpoint. Afterwards every state-dict
        entry whose name contains ``'fc1'`` is averaged over clients, weighted
        by client dataset size, and written back into every checkpoint.

        Args:
            net: Model exposing ``fc1`` and ``fc3`` sub-modules and returning
                two output tensors per batch.
            lr_glob: Initial SGD learning rate for both optimizers.
            frozen_keywords: Name fragments; parameters whose name contains
                any of them get ``requires_grad = False``.

        Returns:
            Tuple ``(mean of per-client mean losses, learning rate of the
            last client's fc1 optimizer)``.

        Raises:
            ValueError: If no client ids were supplied via ``coord``.
        """
        client_ids = list(self.clientset_id) if self.clientset_id is not None else []
        if not client_ids:
            raise ValueError("LocalUpdate has no client ids to train (coord is empty or None)")

        net.train()
        device = self.args.device
        loaders = self.ldr_train_dict
        total_data_points = sum(len(loaders[k].dataset) for k in client_ids)

        # Snapshot of the incoming weights. Without it, a client that has no
        # checkpoint yet would silently start from the previous client's
        # freshly trained weights because ``net`` is reused across the loop.
        initial_state = {name: tensor.clone() for name, tensor in net.state_dict().items()}

        fc1_weighted_sum = {}
        client_losses = []  # mean loss of each client
        current_lr = lr_glob

        for k in client_ids:
            param_file = self._client_param_path(k)
            if os.path.exists(param_file):
                # NOTE: torch.load unpickles the file; only checkpoints written
                # by this class are expected in saved_params_dir.
                # map_location keeps loading working when the checkpoint was
                # saved from a different device.
                net.load_state_dict(torch.load(param_file, map_location=device))
            else:
                net.load_state_dict(initial_state)

            for name, param in net.named_parameters():
                param.requires_grad = not any(key in name for key in frozen_keywords)

            optimizer_fc1 = torch.optim.SGD(net.fc1.parameters(), lr=lr_glob, momentum=self.args.momentum,
                                            weight_decay=self.args.weight_decay)
            optimizer_fc3 = torch.optim.SGD(net.fc3.parameters(), lr=lr_glob, momentum=self.args.momentum,
                                            weight_decay=self.args.weight_decay)
            scheduler_fc1 = StepLR(optimizer_fc1, step_size=1, gamma=0.99)
            scheduler_fc3 = StepLR(optimizer_fc3, step_size=1, gamma=0.99)

            loader = loaders[k]
            epoch_loss = []
            for epoch in tqdm(range(self.args.local_ep), desc="Local Epochs", leave=False):
                batch_loss = []
                for batch_idx, (images, labels) in enumerate(loader):
                    images = images.to(device)
                    labels = labels.to(device).long()
                    net.zero_grad()

                    # Forward pass: the model returns two outputs.
                    log_probs_fc2, log_probs_local = net(images)

                    # Class-balanced loss on the first output, plain
                    # cross-entropy on the second one.
                    loss_fc2 = self.loss_func_weight(log_probs_fc2, labels)
                    loss_local = self.loss_func(log_probs_local, labels)
                    loss = loss_fc2 + loss_local

                    # Step 1: gradients of the combined loss drive fc3.
                    # The graph is retained because loss_fc2 is
                    # back-propagated a second time below.
                    loss.backward(retain_graph=True)
                    # Drop fc1's combined-loss gradients before they are used.
                    optimizer_fc1.zero_grad()
                    optimizer_fc3.step()

                    # Step 2: fc1 is updated from the class-balanced loss only.
                    loss_fc2.backward()
                    optimizer_fc1.step()

                    loss_value = loss.item()
                    if self.args.verbose and batch_idx % 10 == 0:
                        # Report against this client's loader, not ldr_train.
                        self._print_progress(epoch, batch_idx, len(images), loader, loss_value)
                    batch_loss.append(loss_value)
                epoch_loss.append(sum(batch_loss) / len(batch_loss))
                scheduler_fc1.step()
                scheduler_fc3.step()

            # Record this client's mean loss.
            client_losses.append(sum(epoch_loss) / len(epoch_loss))
            print(client_losses)
            current_lr = get_current_lr(optimizer_fc1)

            # Accumulate this client's fc1 entries, weighted by its data size.
            client_data_size = len(loader.dataset)
            for name, tensor in net.state_dict().items():
                if 'fc1' not in name:
                    continue
                weighted = tensor * client_data_size
                if name in fc1_weighted_sum:
                    fc1_weighted_sum[name] += weighted
                else:
                    fc1_weighted_sum[name] = weighted

            torch.save(net.state_dict(), param_file)

        # Weighted average; out-of-place so integer tensors cannot break it.
        fc1_average = {name: total / total_data_points for name, total in fc1_weighted_sum.items()}

        # Write the aggregated fc1 entries into every client's checkpoint.
        for k in client_ids:
            param_file = self._client_param_path(k)
            net.load_state_dict(torch.load(param_file, map_location=device))
            merged_state = net.state_dict()
            merged_state.update(fc1_average)
            net.load_state_dict(merged_state)
            torch.save(net.state_dict(), param_file)

        return sum(client_losses) / len(client_losses), current_lr

    def train_to_mnist(self, net, lr_glob, selected_clients):
        """Run the per-client fc1-sharing training with the MNIST freeze list.

        Parameters whose name contains ``conv1``, ``conv2`` or ``fc2`` get
        ``requires_grad = False``.

        Args:
            net: Model exposing ``fc1`` / ``fc3`` and returning two outputs.
            lr_glob: Initial SGD learning rate.
            selected_clients: Accepted for interface compatibility; unused.

        Returns:
            Tuple ``(mean client loss, last fc1 learning rate)``.
        """
        return self._train_filter_clients(net, lr_glob, self._MNIST_FROZEN_KEYWORDS)

    def train_to_cifar(self, net, lr_glob, selected_clients):
        """Run the per-client fc1-sharing training with the CIFAR freeze list.

        Parameters whose name contains ``conv1``, ``bn1``, ``conv2``, ``bn2``,
        ``conv3``, ``bn3`` or ``fc2`` get ``requires_grad = False``.

        Args:
            net: Model exposing ``fc1`` / ``fc3`` and returning two outputs.
            lr_glob: Initial SGD learning rate.
            selected_clients: Accepted for interface compatibility; unused.

        Returns:
            Tuple ``(mean client loss, last fc1 learning rate)``.
        """
        return self._train_filter_clients(net, lr_glob, self._CIFAR_FROZEN_KEYWORDS)