import os
import pickle
import struct
import collections
import numpy as np
import torch
from torchvision import transforms
from scipy import special as sp

# Dataset identifier; Processor.__init__ uses it to name the
# "<dataset_name>_Management_Information" output directory.
dataset_name = 'cifar10'
# Fraction of each class that gen_local_imbalance_hsu shares out evenly
# across devices before the Dirichlet-based allocation of the rest.
split_ratio = 0.0

class Processor:
    def __init__(self):
        """Reset all data attributes and make sure the output directory exists.

        The output directory is ``"<dataset_name>_Management_Information"``,
        resolved relative to the current working directory, and is stored in
        ``self.output_dir``.
        """
        # Raw data as loaded from the dataset files.
        self.train_feature = None
        self.train_label = None
        self.test_feature = None
        self.test_label = None
        # Concatenation of all per-device data (what __len__/__getitem__ serve).
        self.global_train_feature = None
        self.global_train_label = None
        self.global_test_feature = None
        self.global_test_label = None
        # Per-device splits and the index ranges they occupy in the global arrays.
        self.local_train_feature = None
        self.local_train_label = None
        self.local_train_index = None

        self.local_test_feature = None
        self.local_test_label = None
        self.local_test_index = None

        self.size_class = None
        self.size_device = None
        self.size_feature = None

        self.train_transform = None
        self.test_transform = None

        # Selects which global split __len__/__getitem__ operate on.
        self.type = 'train'
        self.data_source = None

        output_dir = f"{dataset_name}_Management_Information"

        # exist_ok=True replaces the exists()/makedirs() pair, which could
        # fail if another process created the directory between the two calls.
        os.makedirs(output_dir, exist_ok=True)

        self.output_dir = output_dir

    def __len__(self):
        """Return the number of samples in the split selected by ``self.type``.

        ``'train'`` selects the global training labels; any other value
        selects the global test labels.
        """
        active_labels = (
            self.global_train_label if self.type == 'train'
            else self.global_test_label
        )
        return len(active_labels)

    def __getitem__(self, idx):
        if self.type == 'train':
            feature, label = self.global_train_feature[idx], self.global_train_label[idx]
        else:
            feature, label = self.global_test_feature[idx], self.global_test_label[idx]
        if self.data_source == "cifar10":
            feature = feature.reshape(32, 32, 3).astype(np.float32)
        elif self.data_source == "cifar100":
            feature = feature.reshape(32, 32, 3).astype(np.float32)
        elif self.data_source == "mnist":
            feature = feature.reshape(28, 28, 1).astype(np.float32)
        elif self.data_source == "FMNIST":
            feature = feature.reshape(28, 28, 1).astype(np.float32)
        if self.type == 'train':
            img = self.train_transform(feature)
        else:
            img = self.test_transform(feature)
        return img, label

    # region data storage
    def get_input(self, name):
        """Load a dataset from ``./data/<name>`` and configure its transforms.

        The instance is reset through ``__init__`` first. Afterwards
        ``train_feature``/``test_feature`` hold one flattened sample per row
        and ``train_label``/``test_label`` the matching labels, all cast to
        ``int``; ``size_class``, ``size_feature``, ``train_transform`` and
        ``test_transform`` are set as well.

        Args:
            name: One of ``'cifar10'``, ``'cifar100'``, ``'mnist'`` or
                ``'FMNIST'``.

        Raises:
            ValueError: If ``name`` is not one of the supported datasets
                (checked before the instance state is reset).
        """
        supported = ('cifar10', 'cifar100', 'mnist', 'FMNIST')
        if name not in supported:
            raise ValueError(
                'unsupported dataset {!r}; expected one of {}'.format(name, supported)
            )

        def load_pickle(path):
            # NOTE: unpickling can execute arbitrary code; only point this
            # at trusted dataset files.
            with open(path, 'rb') as fo:
                return pickle.load(fo, encoding='bytes')

        def to_channel_last(data):
            # View each row as (3, 32, 32), reorder to (32, 32, 3), flatten again.
            count = len(data)
            return data.reshape(count, 3, 32, 32).transpose(0, 2, 3, 1).reshape(count, -1)

        def load_idx(path, kind):
            # Read an IDX label/image file pair; images come back as 784-wide rows.
            labels_path = os.path.join(path, '{}-labels-idx1-ubyte'.format(kind))
            images_path = os.path.join(path, '{}-images-idx3-ubyte'.format(kind))
            with open(labels_path, 'rb') as lbpath:
                # Skip the 8-byte header (two big-endian uint32 values).
                struct.unpack('>II', lbpath.read(8))
                labels = np.fromfile(lbpath, dtype=np.uint8)
            with open(images_path, 'rb') as imgpath:
                # Skip the 16-byte header (four big-endian uint32 values).
                struct.unpack('>IIII', imgpath.read(16))
                images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)
            return images, labels

        def build_transforms(mean, std, augment):
            # Train and test pipelines differ only in the optional augmentation.
            train_steps = [transforms.ToPILImage()]
            if augment:
                train_steps.append(transforms.RandomCrop(32, padding=2))
                train_steps.append(transforms.RandomHorizontalFlip())
            train_steps.append(transforms.ToTensor())
            train_steps.append(transforms.Normalize(mean, std))
            test_steps = [
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
            return transforms.Compose(train_steps), transforms.Compose(test_steps)

        self.__init__()
        self.data_source = name

        if name == 'cifar10':
            batches = [
                load_pickle('./data/cifar10/data_batch_{}'.format(i))
                for i in range(1, 6)
            ]
            self.train_feature = to_channel_last(
                np.vstack([batch[b'data'] for batch in batches])
            )
            # Builtin int replaces the np.int alias removed from NumPy.
            self.train_label = np.hstack(
                [np.array(batch[b'labels'], dtype=int) for batch in batches]
            )
            test_batch = load_pickle('./data/cifar10/test_batch')
            self.test_feature = to_channel_last(test_batch[b'data'])
            self.test_label = np.array(test_batch[b'labels'], dtype=int)
            self.train_transform, self.test_transform = build_transforms(
                (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), augment=True
            )

        elif name == 'cifar100':
            train_batch = load_pickle('./data/cifar100/train')
            self.train_feature = to_channel_last(train_batch[b'data'])
            self.train_label = np.array(train_batch[b'fine_labels'], dtype=int)
            test_batch = load_pickle('./data/cifar100/test')
            self.test_feature = to_channel_last(test_batch[b'data'])
            self.test_label = np.array(test_batch[b'fine_labels'], dtype=int)
            self.train_transform, self.test_transform = build_transforms(
                (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), augment=True
            )

        else:
            # 'mnist' and 'FMNIST' share the file layout and the transforms;
            # only the directory differs.
            root = './data/{}'.format(name)
            self.train_feature, self.train_label = load_idx(root, 'train')
            self.test_feature, self.test_label = load_idx(root, 't10k')
            self.train_transform, self.test_transform = build_transforms(
                (0.1307,), (0.3081,), augment=False
            )

        self.size_class = len(set(self.train_label) | set(self.test_label))
        self.size_feature = self.train_feature.shape[1]

        self.train_feature = self.train_feature.astype(int)
        self.train_label = self.train_label.astype(int)
        self.test_feature = self.test_feature.astype(int)
        self.test_label = self.test_label.astype(int)

    

    
    @staticmethod
    def get_size_difference(arr):
        for i, a in enumerate(arr):
            print('the {}th device size: {}'.format(i, len(a)))

   
    def get_local_difference(self, arr):
        """Return the pairwise KL-divergence matrix of the entries of ``arr``.

        Entry ``[i][j]`` is ``get_kl_divergence(arr[i], arr[j])``.

        Args:
            arr: Sequence of label collections, one per device.

        Returns:
            A ``len(arr) x len(arr)`` float array.
        """
        count = len(arr)
        matrix = np.zeros((count, count))
        for row, left in enumerate(arr):
            for col, right in enumerate(arr):
                matrix[row, col] = self.get_kl_divergence(left, right)
        return matrix

    def get_global_difference(self, arr):
        """Return the per-class sample counts of ``arr`` in ascending order.

        The result is left-padded with zeros, one for each of the
        ``self.size_class`` classes that does not occur in ``arr``.
        """
        ascending_counts = sorted(collections.Counter(arr).values())
        missing_classes = self.size_class - len(ascending_counts)
        return [0] * missing_classes + ascending_counts

    @staticmethod
    def get_kl_divergence(input1, input2):
        c1, c2 = collections.Counter(input1), collections.Counter(input2)
        # print("c1:",c1)
        # print("c2:",c2)
        d1, d2 = [], []
        for key in c1.keys():
            d1.append(c1[key] / len(input1) + 0.00000001)
            d2.append(c2[key] / len(input2) + 0.00000001)
        # print("d1:",d1)
        # print("d2:",d2)
        return sum(sp.rel_entr(d1, d2))
    # endregion

    @staticmethod
    def get_js_divergence(input1, input2):
        c1, c2 = collections.Counter(input1), collections.Counter(input2)

        
        keys = set(c1.keys()).union(set(c2.keys()))

       
        p1 = np.array([c1[key] / len(input1) for key in keys])
        p2 = np.array([c2[key] / len(input2) for key in keys])

        
        m = 0.5 * (p1 + p2)

       
        p1 += 1e-10
        p2 += 1e-10
        m += 1e-10

       
        js_divergence = 0.5 * (np.sum(sp.rel_entr(p1, m)) + np.sum(sp.rel_entr(p2, m)))

        return js_divergence

    
    def gen_local_imbalance(self, num_device, device_size, alpha):
        
        self.size_device = num_device
        self.local_train_feature = []
        self.local_train_label = []
        # init local train index
        self.local_train_index = []

        # separate data set by label 0 - 9
        feature_by_class = []
        for i in range(self.size_class):
            need_idx = np.where(self.train_label == i)[0]
            feature_by_class.append(self.train_feature[need_idx])

        remain_size = int(device_size * alpha)
        sample_size = device_size - remain_size

        sample_feature_pool = np.array([], dtype=np.int)
        sample_label_pool = np.array([], dtype=np.int)

        # keep the proportion of alpha of the original data in the specific class
        for i in range(self.size_class):
            need_idx = np.arange(len(feature_by_class[i]))
            np.random.shuffle(need_idx)
            step = -1
            for j in range(i, self.size_device, self.size_class):
                step += 1
                select_idx = need_idx[step*remain_size:(step+1)*remain_size]
                self.local_train_feature.append(feature_by_class[i][select_idx])
                self.local_train_label.append(np.repeat(i, remain_size))

            # put the data that not selected into the sample pool
            select_idx = need_idx[(step + 1) * remain_size:]
            if sample_feature_pool.size:
                sample_feature_pool = np.vstack([sample_feature_pool, feature_by_class[i][select_idx]])
            else:
                sample_feature_pool = feature_by_class[i][select_idx]
            sample_label_pool = np.hstack([sample_label_pool, np.repeat(i, len(select_idx))])

        # add the data from the sample pool to each device
        need_idx = np.arange(len(sample_feature_pool))
        np.random.shuffle(need_idx)
        step = -1
        for i in range(self.size_device):
            step += 1
            select_idx = need_idx[step*sample_size:(step+1)*sample_size]
            if self.local_train_feature[i].size:
                self.local_train_feature[i] = np.vstack([self.local_train_feature[i], sample_feature_pool[select_idx]])
            else:
                self.local_train_feature[i] = sample_feature_pool[select_idx]
            self.local_train_label[i] = np.hstack([self.local_train_label[i], sample_label_pool[select_idx]])
        self.refresh_global_data()
    
    def gen_local_imbalance_hsu(self, num_device, device_size, alpha):
        """Partition train and test data across devices with a Dirichlet label skew.

        For every class, the first ``int(n * split_ratio)`` samples (``n``
        being that class's sample count, ``split_ratio`` the module-level
        setting) are divided into ``num_device`` equal shares; samples left
        over by the integer division are not assigned to any device. The rest
        of the class is split according to proportions drawn from a symmetric
        Dirichlet distribution, and the same draw is applied to the train and
        the test samples of that class. Samples are taken in stored order;
        nothing is shuffled.

        A per-device, per-class summary is written to
        ``<self.output_dir>/client_data_summary.txt`` and the global arrays
        are rebuilt through ``refresh_global_data``.

        Args:
            num_device: Number of devices to create.
            device_size: Not used by this method.
            alpha: Concentration parameter of the Dirichlet distribution.
        """
        self.size_device = num_device

        train_class_counts = np.zeros((num_device, self.size_class), dtype=int)
        test_class_counts = np.zeros((num_device, self.size_class), dtype=int)

        train_by_class = [
            self.train_feature[np.where(self.train_label == cls)[0]]
            for cls in range(self.size_class)
        ]
        test_by_class = [
            self.test_feature[np.where(self.test_label == cls)[0]]
            for cls in range(self.size_class)
        ]

        # Pieces are collected per device and concatenated once at the end
        # instead of re-stacking the growing arrays on every assignment.
        train_parts = [[] for _ in range(num_device)]
        train_labels = [[] for _ in range(num_device)]
        test_parts = [[] for _ in range(num_device)]
        test_labels = [[] for _ in range(num_device)]

        def hand_out(samples, shares, cls, parts, labels, counts):
            # Give consecutive slices of `samples` to the devices, one slice
            # per entry of `shares`; labels and counts follow the slice length.
            start = 0
            for device, share in enumerate(shares):
                chunk = samples[start:start + share]
                parts[device].append(chunk)
                labels[device].append(np.repeat(cls, len(chunk)))
                counts[device, cls] += len(chunk)
                start += len(chunk)

        def dirichlet_shares(proportions, total):
            # Round the proportional shares down and give the shortfall to
            # the device with the largest proportion.
            shares = (proportions * total).astype(int)
            shortfall = total - np.sum(shares)
            if shortfall > 0:
                shares[np.argmax(proportions)] += shortfall
            return shares

        # Phase 1: the leading split_ratio fraction of each class, shared evenly.
        for cls in range(self.size_class):
            even_train = int(len(train_by_class[cls]) * split_ratio)
            even_test = int(len(test_by_class[cls]) * split_ratio)
            hand_out(train_by_class[cls][:even_train],
                     [even_train // num_device] * num_device,
                     cls, train_parts, train_labels, train_class_counts)
            hand_out(test_by_class[cls][:even_test],
                     [even_test // num_device] * num_device,
                     cls, test_parts, test_labels, test_class_counts)

        # Phase 2: the rest of each class, split by Dirichlet proportions.
        for cls in range(self.size_class):
            # The offset uses this class's own size, not the size of the
            # last class processed in phase 1.
            remaining_train = train_by_class[cls][int(len(train_by_class[cls]) * split_ratio):]
            remaining_test = test_by_class[cls][int(len(test_by_class[cls]) * split_ratio):]

            qc = np.random.dirichlet([alpha] * num_device)

            hand_out(remaining_train, dirichlet_shares(qc, len(remaining_train)),
                     cls, train_parts, train_labels, train_class_counts)
            hand_out(remaining_test, dirichlet_shares(qc, len(remaining_test)),
                     cls, test_parts, test_labels, test_class_counts)

        self.local_train_feature = [np.vstack(parts) for parts in train_parts]
        self.local_train_label = [np.hstack(labels) for labels in train_labels]
        self.local_test_feature = [np.vstack(parts) for parts in test_parts]
        self.local_test_label = [np.hstack(labels) for labels in test_labels]

        output_file = os.path.join(self.output_dir, "client_data_summary.txt")

        # The summary contains non-ASCII text, so the encoding is fixed
        # instead of depending on the platform default.
        with open(output_file, "w", encoding="utf-8") as f:
            f.write("数据分配完成，客户端数据量汇总:\n")
            for client_id in range(num_device):
                num_train_samples = len(self.local_train_feature[client_id])
                num_test_samples = len(self.local_test_feature[client_id])

                f.write(f"Client {client_id}:\n")
                f.write(f"  Train data: {num_train_samples} samples\n")
                for cls in range(self.size_class):
                    f.write(f"    Class {cls}: {train_class_counts[client_id, cls]} samples\n")

                f.write(f"  Test data: {num_test_samples} samples\n")
                for cls in range(self.size_class):
                    f.write(f"    Class {cls}: {test_class_counts[client_id, cls]} samples\n")

        print(f"数据已保存至 {output_file}")

        self.refresh_global_data()

    def gen_size_imbalance(self, list_size):
        # generate size imbalance
        # list_size is a size indicates the size of each device
        self.size_device = len(list_size)
        self.local_train_feature = []
        self.local_train_label = []
        # init local_train_index
        self.local_train_index = []

        need_idx = np.arange(len(self.train_feature))
        np.random.shuffle(need_idx)
        cur_idx = 0
        for s in list_size:
            self.local_train_feature.append(self.train_feature[need_idx[cur_idx:cur_idx+s]])
            self.local_train_label.append(self.train_label[need_idx[cur_idx:cur_idx+s]])
            cur_idx += s
        self.refresh_global_data()

    def gen_global_imbalance(self, num_device, device_size, num_each_class):
        
        self.size_device = num_device
        self.local_train_feature = []
        self.local_train_label = []

        feature_by_class = []
        for i in range(self.size_class):
            need_idx = np.where(self.train_label == i)[0]
            feature_by_class.append(self.train_feature[need_idx])
        sample_feature_pool = np.array([], dtype=np.int)
        sample_label_pool = np.array([], dtype=np.int)

        for i in range(self.size_class):
            need_idx = np.arange(len(feature_by_class[i]))
            np.random.shuffle(need_idx)
            if sample_feature_pool.size:
                sample_feature_pool = np.vstack([sample_feature_pool,
                                                 feature_by_class[i][need_idx[:num_each_class[i]]]])
            else:
                sample_feature_pool = feature_by_class[i][need_idx[:num_each_class[i]]]
            sample_label_pool = np.hstack([sample_label_pool, np.repeat(i, num_each_class[i])])

        need_idx = np.arange(len(sample_feature_pool))
        np.random.shuffle(need_idx)
        step = -1
        for i in range(self.size_device):
            step += 1
            select_idx = need_idx[step*device_size:(step+1)*device_size]
            self.local_train_feature.append(sample_feature_pool[select_idx])
            self.local_train_label.append(sample_label_pool[select_idx])
        self.refresh_global_data()

    def refresh_global_data(self):

        # initialize global train features and labels
        self.global_train_feature = np.empty((0, self.size_feature), dtype=np.int)
        self.global_train_label = np.array([], dtype=np.int)
        self.local_train_index = []

        
        self.global_test_feature = np.empty((0, self.size_feature), dtype=np.int)
        self.global_test_label = np.array([], dtype=np.int)
        self.local_test_index = []

        idx_start_train = 0
        idx_start_test = 0

        for i in range(self.size_device):
           
            self.global_train_feature = np.vstack([self.global_train_feature, self.local_train_feature[i]])
            self.global_train_label = np.hstack([self.global_train_label, self.local_train_label[i]])
            self.local_train_index.append(np.arange(idx_start_train, idx_start_train + len(self.local_train_label[i])))
            idx_start_train += len(self.local_train_label[i])
            # print(self.local_train_index)
        
            
            self.global_test_feature = np.vstack([self.global_test_feature, self.local_test_feature[i]])
            self.global_test_label = np.hstack([self.global_test_label, self.local_test_label[i]])
            self.local_test_index.append(np.arange(idx_start_test, idx_start_test + len(self.local_test_label[i])))
            idx_start_test += len(self.local_test_label[i])
    #  endregion
