from utils.DataLoader import DataLoader
import utils
import os
import torch
import torchvision
import random
from torchvision import transforms as transforms
import numpy as np
import copy
from torch.utils.data import ConcatDataset


class DataLoader_cifar10_noise(DataLoader):
    """CIFAR-10 federated data pool with symmetric label noise on local training data.

    How the pool is built:

    1. The torchvision CIFAR-10 train and test splits are merged, sorted by label
       and cut into ``split_num`` contiguous shards. Because the shards come from
       label-sorted data, this gives a label-skewed split.
    2. Each of the ``split_num // pick_num`` pool entries receives ``pick_num``
       randomly chosen shards.
    3. Each entry is re-split into local train/test parts in the global train:test
       proportion, and both parts are chunked into batches of ``self.batch_size``.
    4. In every local *training* batch, ``int(len(batch) * noise_ratio)`` labels
       are replaced by a different, uniformly random class. Test labels are left
       untouched.

    The finished object is cached as ``utils.pool_folder_path + name + '.npy'``.
    """

    def __init__(self,
                 split_num=200,
                 pick_num=2,
                 batch_size=50,
                 input_require_shape=None,
                 shuffle=True,
                 pool_size=None,
                 recreate=False,
                 params=None,
                 *args,
                 noise_ratio=0.0,
                 **kwargs):
        """Build the data pool, or load it from the on-disk cache.

        Args:
            split_num: Number of contiguous shards the label-sorted data is cut into.
            pick_num: Number of shards assigned to each pool entry; must divide
                ``split_num``.
            batch_size: Local batch size.
            input_require_shape: Forwarded to the base class; ``None`` means
                ``[3, -1, -1]``.
            shuffle: Not used by this class.
            pool_size: Only honoured together with ``params``, where it sets
                ``split_num = pool_size * params['N']``. The effective pool size is
                always ``split_num // pick_num``.
            recreate: Rebuild the pool even if a cached file exists.
            params: Optional dict. When given, ``params['N']`` becomes ``pick_num``,
                ``params['batch_size']`` becomes ``batch_size`` and
                ``params['NR']`` becomes ``noise_ratio``. When ``pool_size`` is
                None, ``int(params['C'] * params['N'])`` becomes ``split_num``.
            *args, **kwargs: Not used by this class.
            noise_ratio: Fraction of labels flipped in every local training batch.
                Keyword-only; defaults to 0.0 (no noise). ``params['NR']`` takes
                precedence.

        Raises:
            RuntimeError: If ``split_num`` is not divisible by ``pick_num``.
            ValueError: If ``noise_ratio`` lies outside ``[0, 1]``, or (when
                building) there are fewer samples than shards.
        """
        if input_require_shape is None:
            input_require_shape = [3, -1, -1]
        if params is not None:
            pick_num = params['N']
            if pool_size is not None:
                split_num = pool_size * params['N']
            else:
                split_num = int(params['C'] * params['N'])
            batch_size = params['batch_size']
            noise_ratio = params['NR']
        if split_num % pick_num != 0:
            raise RuntimeError('split_num must be divisible by the number of pick_num.')
        if not 0 <= noise_ratio <= 1:
            raise ValueError('noise_ratio must lie in [0, 1], got ' + str(noise_ratio))
        pool_size = split_num // pick_num
        name = 'CIFAR10_noise_pool_' + str(pool_size) + 'split_' + str(split_num) + 'pick' + str(
            pick_num) + '_batchsize_' + str(batch_size) + '_sort_split_input_require_shape_' + str(input_require_shape) + '_noise_ratio_' + str(noise_ratio)
        nickname = 'cifar10_noise B' + str(batch_size) + ' S' + str(split_num) + ' P' + str(pick_num) + ' N' + str(pool_size)
        super().__init__(name, nickname, pool_size, batch_size, input_require_shape)

        file_path = utils.pool_folder_path + name + '.npy'
        if os.path.exists(file_path) and not recreate:
            # NOTE: allow_pickle=True unpickles arbitrary objects; the cache is
            # written by np.save(file_path, self) below. Only trust caches this code wrote.
            cached = np.load(file_path, allow_pickle=True).item()
            for attr, value in list(cached.__dict__.items()):
                setattr(self, attr, value)
            print('Successfully Read the Data Pool.')
            return

        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
        trainset = torchvision.datasets.CIFAR10(root=utils.data_folder_path, train=True,
                                                download=True, transform=transform)
        testset = torchvision.datasets.CIFAR10(root=utils.data_folder_path, train=False,
                                               download=True, transform=transform)
        totalset = ConcatDataset([trainset, testset])
        # One batch spanning the whole merged set materializes every transformed
        # image and label in a single pass. No separate train/test passes are
        # needed, which keeps peak memory down.
        total_input_data, total_target_data = next(iter(torch.utils.data.DataLoader(
            totalset, batch_size=len(totalset), shuffle=True, num_workers=1)))

        # Keep the argument identical to a train-only tensor's shape:
        # leading dim = number of training samples, then the per-sample dims.
        self.cal_data_shape(torch.Size([len(trainset)] + list(total_input_data.shape[1:])))
        self.target_class_num = 10

        self.total_training_number = len(trainset)
        self.total_test_number = len(testset)

        self.output_size = 10
        self.model4data = 'cnn'
        self.task_name = 'cifar10_classification'

        self.data_pool = self._build_data_pool(total_input_data, total_target_data,
                                               split_num, pick_num, pool_size, noise_ratio)

        np.save(file_path, self)

    @staticmethod
    def _flip_labels(target, noise_ratio, num_classes):
        """Flip ``int(len(target) * noise_ratio)`` randomly chosen labels in place.

        Each chosen position gets a label drawn with ``np.random.randint`` until it
        differs from the current one, i.e. uniformly from the other classes.
        ``target`` is modified in place and also returned.
        """
        num_samples = len(target)
        num_to_flip = int(num_samples * noise_ratio)
        for idx in np.random.choice(num_samples, num_to_flip, replace=False):
            new_label = np.random.randint(0, num_classes)
            while new_label == target[idx]:  # guarantee a *different* label
                new_label = np.random.randint(0, num_classes)
            target[idx] = new_label
        return target

    def _build_data_pool(self, input_data, target_data, split_num, pick_num, pool_size, noise_ratio):
        """Shard the label-sorted samples and assemble per-entry batched train/test data.

        Returns a list of ``self.pool_size`` dicts with keys ``'local_training_data'``
        and ``'local_test_data'`` (lists of ``(inputs, targets)`` batches),
        ``'local_training_number'``, ``'local_test_number'`` and ``'data_name'``.
        """
        order = torch.argsort(target_data)
        input_data = input_data[order, :]
        target_data = target_data[order]

        sample_num = input_data.shape[0]
        shard_size = sample_num // split_num
        if shard_size == 0:
            raise ValueError('split_num (' + str(split_num) + ') exceeds the number of samples ('
                             + str(sample_num) + ').')
        # Contiguous shards of the sorted index range. Every shard has exactly
        # shard_size samples; the trailing sample_num % split_num samples are unused.
        shards = [list(range(k * shard_size, (k + 1) * shard_size)) for k in range(split_num)]

        total_number = self.total_training_number + self.total_test_number
        data_pool = [{} for _ in range(self.pool_size)]
        for pool_idx in range(pool_size):
            data_indices = []
            for _ in range(pick_num):
                # Draw a shard without replacement.
                data_indices += shards.pop(random.randint(0, len(shards) - 1))
            random.shuffle(data_indices)
            train_test_split_idx = int(len(data_indices) * self.total_training_number / total_number)

            train_indices = data_indices[:train_test_split_idx]
            test_indices = data_indices[train_test_split_idx:]

            local_train_data = []
            for batch_data_indices in DataLoader.separate_list(train_indices, self.batch_size):
                batch_input_data = input_data[batch_data_indices].reshape([-1] + self.input_data_shape).float()
                # Indexing with a list returns a copy, so noising it leaves target_data intact.
                batch_target_data = self._flip_labels(target_data[batch_data_indices],
                                                      noise_ratio, self.output_size)
                local_train_data.append((batch_input_data, batch_target_data))

            local_test_data = []
            for batch_data_indices in DataLoader.separate_list(test_indices, self.batch_size):
                batch_input_data = input_data[batch_data_indices].reshape([-1] + self.input_data_shape).float()
                local_test_data.append((batch_input_data, target_data[batch_data_indices]))

            data_pool[pool_idx]['local_training_data'] = local_train_data
            data_pool[pool_idx]['local_test_data'] = local_test_data
            data_pool[pool_idx]['local_training_number'] = len(train_indices)
            data_pool[pool_idx]['local_test_number'] = len(test_indices)
            data_pool[pool_idx]['data_name'] = str(pool_idx)
        return data_pool

    def allocate(self, client_list):
        """Give pool entry ``i`` to ``client_list[i]`` via ``client.update_data``.

        Entries are assigned in order, with no random selection. Raises IndexError
        if there are more clients than pool entries.
        """
        for pool_idx, client in enumerate(client_list):
            data_pool_item = self.data_pool[pool_idx]
            client.update_data(pool_idx,
                               data_pool_item['local_training_data'],
                               data_pool_item['local_training_number'],
                               data_pool_item['local_test_data'],
                               data_pool_item['local_test_number'])

# python main.py --task cifar10_classification --dataloader cifar10_noise --model cnn --algorithm fedeba --num_rounds 2000 --num_epochs 1 --learning_rate 0.1 --learning_rate_decay 0.999 --C 100 --N 2 --P 0.1 --batch_size 50 --eval_interval 1 --lr_scheduler 0 --gpu 4 --LSR
