

import copy

import numpy as np
import torch
from torch.utils.data import DataLoader

# from Data import DatasetSplit
from datasets import DatasetSplit
from utils import init_model
from utils import init_optimizer, model_parameter_vector

class Node(object):
    """One participant in the federated setup: a client or the server.

    A node owns a train/validation ``DataLoader`` pair carved out of
    ``train_set``, a local model with its optimizer, a ``cluster_id``, and
    whatever extra state the configured client/server method needs
    (FedDyn gradient/server state, FedAdam moment buffers).
    """

    def __init__(self, num_id, local_data, train_set, args):
        """Build the node's data loaders, model, optimizer and method state.

        Args:
            num_id: Node identifier. ``-1`` is treated as the server.
            local_data: The node's share of ``train_set``. When
                ``args.iid == 1`` or the node is the server, its ``.indices``
                attribute is used; otherwise the object itself is used as
                the index sequence. Either way the sequence is shuffled
                in place.
            train_set: Dataset the indices refer to.
            args: Run configuration. Attributes read here: ``node_num``,
                ``server_valid_ratio``, ``client_valid_ratio``, ``dataset``,
                ``iid``, ``local_model``, ``num_cluster``, ``client_method``,
                ``server_method`` (plus those read by the split methods).

        Raises:
            ValueError: If ``args.num_cluster`` is neither 2 nor 4, or if
                the server-style split is given an unknown
                ``args.longtail_proxyset``.
            IndexError: If the server-style split cannot find enough
                samples of some class to fill the proxy set.
        """
        self.num_id = num_id
        self.args = args
        self.node_num = self.args.node_num
        self.valid_ratio = (
            args.server_valid_ratio if num_id == -1 else args.client_valid_ratio
        )

        # NOTE(review): any other dataset name leaves ``num_classes`` unset,
        # which only surfaces later as an AttributeError -- confirm whether
        # more datasets should be listed here.
        if self.args.dataset in ('cifar10', 'fmnist'):
            self.num_classes = 10
        elif self.args.dataset == 'cifar100':
            self.num_classes = 100

        if args.iid == 1 or num_id == -1:
            # for the server, use the validate_set as the training data,
            # and use local_data for testing
            self.local_data, self.validate_set = self.train_val_split_forServer(
                local_data.indices, train_set, self.valid_ratio, self.num_classes)
        else:
            self.local_data, self.validate_set = self.train_val_split(
                local_data, train_set, self.valid_ratio)

        # The model is moved to the default CUDA device unconditionally.
        self.model = init_model(self.args.local_model, self.args).cuda()
        self.optimizer = init_optimizer(self.num_id, self.model, args)

        # cluster_id assignment for swapping label
        self.cluster_id = self._cluster_for(
            self.num_id, self.node_num, self.args.num_cluster)

        # node init for feddyn
        if args.client_method == 'feddyn':
            # Zero vector with the same shape as the flattened parameters.
            self.old_grad = torch.zeros_like(
                model_parameter_vector(args, copy.deepcopy(self.model)))
        if 'feddyn' in args.server_method:
            self.server_state = copy.deepcopy(self.model)
            self.zero_weights(self.server_state)

        # node init for fedadam's server
        if args.server_method == 'fedadam' and num_id == -1:
            self.m = copy.deepcopy(self.model)
            self.zero_weights(self.m)
            self.v = copy.deepcopy(self.model)
            self.zero_weights(self.v)

    @staticmethod
    def _cluster_for(num_id, node_num, num_cluster):
        """Return the cluster id for a node, splitting ids into equal ranges.

        Raises:
            ValueError: If ``num_cluster`` is neither 2 nor 4 (only checked
                when ``num_id != node_num``).
        """
        # NOTE(review): the server is identified by num_id == -1 elsewhere
        # in this class, so this guard may never match -- confirm intent.
        # An id of -1 still lands in cluster 0 through the branches below.
        if num_id == node_num:
            return 0
        if num_cluster == 4:
            if num_id < node_num // 4:
                return 0
            if num_id < 2 * node_num // 4:
                return 1
            if num_id < 3 * node_num // 4:
                return 2
            return 3
        if num_cluster == 2:
            return 1 if num_id + 1 > node_num // 2 else 0
        raise ValueError('The number of clusters is not well-defined...')

    def zero_weights(self, model):
        """Set every parameter of ``model`` to zero, in place."""
        for param in model.parameters():
            param.data.zero_()

    def _make_loaders(self, train_set, idxs_train, idxs_test):
        """Wrap the two index lists in shuffling ``DataLoader`` objects."""
        train_loader = DataLoader(DatasetSplit(train_set, idxs_train),
                                  batch_size=self.args.batchsize,
                                  num_workers=0, shuffle=True)
        test_loader = DataLoader(DatasetSplit(train_set, idxs_test),
                                 batch_size=self.args.validate_batchsize,
                                 num_workers=0, shuffle=True)
        return train_loader, test_loader

    def train_val_split(self, idxs, train_set, valid_ratio):
        """Randomly split ``idxs`` into training and validation loaders.

        Args:
            idxs: Mutable index sequence; shuffled in place with
                ``np.random.shuffle``.
            train_set: Dataset the indices refer to.
            valid_ratio: Fraction of ``idxs`` placed in the validation part
                (truncated to a whole number of samples).

        Returns:
            ``(train_loader, validation_loader)``.
        """
        np.random.shuffle(idxs)

        validate_size = int(valid_ratio * len(idxs))
        idxs_test = idxs[:validate_size]
        idxs_train = idxs[validate_size:]

        return self._make_loaders(train_set, idxs_train, idxs_test)

    @staticmethod
    def _proxy_class_quotas(validate_size, num_classes, longtail_proxyset):
        """Return ``(per_class_caps, target_total)`` for the proxy set.

        ``'none'`` asks for ``validate_size`` samples with every class capped
        at ``ceil(validate_size / num_classes)``. ``'LT'`` gives a long-tailed
        profile whose per-class cap decays geometrically from
        ``validate_size / num_classes`` down to a tenth of it.

        Raises:
            ValueError: If ``longtail_proxyset`` is not ``'none'`` or ``'LT'``.
        """
        if longtail_proxyset == 'none':
            # Integer ceiling: a float quota made the old stop condition
            # (sum of remaining quotas == 0) unreliable under rounding.
            cap = (validate_size + num_classes - 1) // num_classes
            return [cap] * num_classes, validate_size
        if longtail_proxyset == 'LT':
            imb_factor = 0.1
            caps = [
                int(validate_size / num_classes
                    * (imb_factor ** (class_idx / (num_classes - 1.0))))
                for class_idx in range(num_classes)
            ]
            return caps, sum(caps)
        raise ValueError(
            "Unknown longtail_proxyset %r (expected 'none' or 'LT')"
            % (longtail_proxyset,))

    @staticmethod
    def _pick_proxy_indices(idxs, label_of, class_caps, target_total):
        """Pick indices in order until ``target_total`` are taken.

        An index is accepted only while its class (``label_of(idx)``) is
        still below that class's cap.

        Raises:
            IndexError: If ``idxs`` runs out before ``target_total`` samples
                could be selected.
        """
        remaining = list(class_caps)
        picked = []
        for idx in idxs:
            if len(picked) >= target_total:
                break
            label = label_of(idx)  # read once; dataset access may be costly
            if remaining[label] > 0:
                remaining[label] -= 1
                picked.append(idx)
        if len(picked) < target_total:
            raise IndexError(
                'Not enough samples to build the proxy set: needed %d, '
                'found %d within the per-class limits'
                % (target_total, len(picked)))
        return picked

    def train_val_split_forServer(self, idxs, train_set, valid_ratio, num_classes=10):
        """Split ``idxs`` into a class-controlled proxy set and the remainder.

        Args:
            idxs: Mutable index sequence (local data indices); shuffled in
                place with ``np.random.shuffle``.
            train_set: Dataset the indices refer to; ``train_set[i][1]`` is
                used as the integer class label of sample ``i``.
            valid_ratio: Fraction of ``idxs`` requested for the proxy set.
            num_classes: Number of distinct class labels.

        Returns:
            ``(train_loader, proxy_loader)``: the first covers every index
            not selected for the proxy set, the second covers the proxy set.

        Raises:
            ValueError: If ``args.longtail_proxyset`` is not ``'none'`` or
                ``'LT'``.
            IndexError: If some class has too few samples to fill the
                proxy set.
        """
        np.random.shuffle(idxs)

        validate_size = int(valid_ratio * len(idxs))

        # generate proxy dataset with balanced (or long-tailed) classes
        class_caps, target_total = self._proxy_class_quotas(
            validate_size, num_classes, self.args.longtail_proxyset)
        idxs_test = self._pick_proxy_indices(
            idxs, lambda sample_idx: train_set[sample_idx][1],
            class_caps, target_total)

        # Set membership keeps the complement linear in len(idxs).
        held_out = {int(idx) for idx in idxs_test}
        idxs_train = [idx for idx in idxs if int(idx) not in held_out]

        return self._make_loaders(train_set, idxs_train, idxs_test)


