from ast import Set
import os
import sys
import logging

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

# Per-channel mean / std of the CIFAR training images on the [0, 1] pixel
# scale (data_loader divides by 255 before standardizing). They are stored
# with shape (1, 3, 1, 1) so they broadcast directly against NCHW batches.
CIFAR10_TRAIN_MEAN = np.array((0.4914, 0.4822, 0.4465)).reshape(1, 3, 1, 1)
CIFAR10_TRAIN_STD = np.array((0.2470, 0.2435, 0.2616)).reshape(1, 3, 1, 1)
CIFAR100_TRAIN_MEAN = np.array((0.5071, 0.4865, 0.4409)).reshape(1, 3, 1, 1)
CIFAR100_TRAIN_STD = np.array((0.2673, 0.2564, 0.2762)).reshape(1, 3, 1, 1)


def get_logger(filename):
    """Configure the root logger to log to ``filename`` and to stdout.

    Both handlers log at DEBUG level with the same timestamped format, and the
    log file is truncated when opened (``mode='w'``).

    Handlers installed by an earlier call to this function are removed and
    closed first. Calling it several times in one process (e.g. once per
    experiment) therefore neither duplicates every log line nor leaks the
    previous log file's handle. Handlers attached to the root logger by other
    code are left untouched.

    Args:
        filename: path of the log file to (re)create.

    Returns:
        The root ``logging.Logger``.
    """
    log_formatter = logging.Formatter(fmt='%(asctime)s [%(levelname)-5.5s] %(message)s',
                                      datefmt='%Y-%b-%d %H:%M')
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    # Marker attribute identifying handlers created here, so a later call can
    # find and retire exactly those (the root logger is process-global).
    marker = '_get_logger_owned'
    for handler in list(logger.handlers):
        if getattr(handler, marker, False):
            logger.removeHandler(handler)
            handler.close()

    # File logger first, then stdout (same order as before).
    file_handler = logging.FileHandler(filename, mode='w')  # default is 'a' to append
    std_handler = logging.StreamHandler(sys.stdout)
    for handler in (file_handler, std_handler):
        handler.setFormatter(log_formatter)
        handler.setLevel(logging.DEBUG)
        setattr(handler, marker, True)
        logger.addHandler(handler)
    return logger

def get_path(args, algorithm=None, class_seed=None, seed=None):
    """Build the result path encoding every experiment hyper-parameter.

    The path starts at ``args.result_path`` and nests one directory per shared
    setting, ending with ``class_seed_{class_seed}``. When ``algorithm`` is
    given it is appended, followed by that algorithm's own settings
    (``dFedEM``, ``Federico`` and ``FedFomo`` add extra levels). When ``seed``
    is given, the final component is the file name ``seed_{seed}.npz``.
    """
    parts = [
        args.result_path,
        args.dataset,
        f'private_model_type_{args.private_model_type}',
        f'n_clients_{args.n_clients}',
        f'clusters_{args.n_clusters}_frac_{args.frac}',
        f'batch_size_{args.batch_size}',
        f'optimizer_{args.optimizer}',
        f'lr_{args.lr}',
        f'use_private_SGD_{args.use_private_SGD}',
        f'delta_{args.delta}',
        f'noise_multiplier_{args.noise_multiplier}',
        f'l2_norm_clip_{args.l2_norm_clip}',
        f'n_rounds_{args.n_rounds}',
        f'class_seed_{class_seed}',
    ]

    if algorithm is not None:
        parts.append(algorithm)

    # Algorithm-specific hyper-parameters (only read for that algorithm).
    if algorithm == 'dFedEM':
        parts.append(f'n_components_{args.n_components}')
    elif algorithm == 'Federico':
        parts.extend([f'n_neighbors_{args.n_neighbors}',
                      f'cw_momentum_{args.cw_momentum}',
                      f'epsilon_{args.greedy_eps}'])
    elif algorithm == 'FedFomo':
        parts.append(f'cw_ratio_{args.cw_ratio}')

    if seed is not None:
        parts.append(f'seed_{seed}.npz')

    return os.path.join(*parts)

def _load_mnist_style(data_path, name):
    """Load ``{data_path}/{name}.npz`` and return single-channel NCHW arrays."""
    # Use the NpzFile as a context manager so the archive's file handle is
    # closed once the arrays have been read into memory.
    with np.load(f"{data_path}/{name}.npz") as archive:
        train_X = archive['x_train']
        train_y = archive['y_train'].astype(np.int64)
        test_X = archive['x_test']
        test_y = archive['y_test'].astype(np.int64)

    if name == 'fashion-mnist':
        # Force a (N, 1, 28, 28) layout regardless of how the images are stored.
        train_X = np.reshape(train_X, (-1, 1, 28, 28))
        test_X = np.reshape(test_X, (-1, 1, 28, 28))
    else:
        # Insert the channel axis (images presumably stored as (N, H, W)).
        train_X = np.expand_dims(train_X, 1)
        test_X = np.expand_dims(test_X, 1)
    return train_X, train_y, test_X, test_y


def _load_cifar(dataset_cls, data_path, name):
    """Return raw train/test arrays of a torchvision CIFAR dataset, images as NCHW."""
    root = f"{data_path}/{name}/"
    arrays = []
    for train in (True, False):
        # No transform: only the raw `.data` / `.targets` are used, and
        # normalization happens per batch in data_loader. No download either,
        # so the dataset must already be present under `root`.
        split = dataset_cls(root=root, train=train)
        # `.data` is channel-last; transpose it to channel-first.
        arrays.append(split.data.transpose([0, 3, 1, 2]))
        arrays.append(np.array(split.targets))
    return tuple(arrays)


def get_data(args):
    """Load ``args.dataset`` and return ``(train_X, train_y, test_X, test_y)``.

    Images are returned unnormalized and channel-first ``(N, C, H, W)``, with
    labels as integer numpy arrays. Normalization is applied later, per batch,
    by ``data_loader``.

    * ``'mnist'`` / ``'fashion-mnist'``: read from
      ``{args.data_path}/{args.dataset}.npz`` (keys ``x_train``, ``y_train``,
      ``x_test``, ``y_test``); labels are cast to int64.
    * ``'cifar10'`` / ``'cifar100'``: read through torchvision from
      ``{args.data_path}/{args.dataset}/``.

    Raises:
        ValueError: for any other dataset name.
    """
    if args.dataset in ('mnist', 'fashion-mnist'):
        return _load_mnist_style(args.data_path, args.dataset)
    if args.dataset == 'cifar10':
        return _load_cifar(torchvision.datasets.CIFAR10, args.data_path, args.dataset)
    if args.dataset == 'cifar100':
        return _load_cifar(torchvision.datasets.CIFAR100, args.data_path, args.dataset)
    raise ValueError("Unknown dataset")


def data_loader(dataset, inputs, targets, batch_size, is_train=True, include_idx=False):
    """Generate float32 mini-batches ``(x, y)`` from in-memory arrays.

    Each batch of ``inputs`` is cast to float32 and divided by 255. For
    ``'cifar10'`` and ``'cifar100'`` it is then standardized in place with
    *that* dataset's per-channel training mean/std. Those are shaped
    ``(1, 3, 1, 1)``, so ``inputs`` must be NCHW. Any other ``dataset`` is only
    rescaled.

    Args:
        dataset: dataset name; selects the normalization (see above).
        inputs: examples stacked along axis 0 (pixel values presumably in
            0..255, since they are divided by 255).
        targets: labels aligned with ``inputs`` along axis 0.
        batch_size: (expected) number of examples per batch.
        is_train: if True, draw ``int(n / batch_size)`` Poisson-sampled
            batches. Every example is kept independently with probability
            ``batch_size / n``, so batch sizes vary and empty draws are
            skipped. NOTE(review): this yields nothing when
            ``n < batch_size``; confirm that is intended. If False, walk the
            data in order in chunks of ``batch_size``, with the last chunk
            holding the remainder.
        include_idx: also yield the positions of the batch's examples.

    Yields:
        ``(x, y)``, or ``((x, y), idx)`` when ``include_idx`` is set. ``idx``
        is the tuple returned by ``mask.nonzero()`` in training mode and a
        ``list`` of positions in evaluation mode.

    Empty ``inputs`` yield no batches.
    """
    if dataset == 'cifar10':
        stats = (CIFAR10_TRAIN_MEAN, CIFAR10_TRAIN_STD)
    elif dataset == 'cifar100':
        # CIFAR-100 must use its own statistics, not CIFAR-10's.
        stats = (CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    else:
        stats = None

    def prepare(batch_inputs):
        # astype() makes a fresh float32 copy, so in-place ops are safe here.
        x = batch_inputs.astype(np.float32) / 255.
        if stats is not None:
            x -= stats[0]
            x /= stats[1]
        return x

    assert inputs.shape[0] == targets.shape[0]
    n_examples = inputs.shape[0]
    if n_examples == 0:
        # Nothing to batch (e.g. a client that received no data); this also
        # avoids dividing by zero when computing the sampling rate.
        return

    num_blocks = int(n_examples / batch_size)
    if is_train:
        sample_rate = batch_size / n_examples
        for _ in range(num_blocks):
            mask = np.random.rand(n_examples) < sample_rate
            if not mask.any():
                continue
            batch = (prepare(inputs[mask]), targets[mask])
            yield (batch, mask.nonzero()) if include_idx else batch
    else:
        # Consecutive chunks; the final one is shorter when n is not a
        # multiple of batch_size.
        for start in range(0, n_examples, batch_size):
            stop = min(start + batch_size, n_examples)
            batch = (prepare(inputs[start:stop]), targets[start:stop])
            yield (batch, list(range(start, stop))) if include_idx else batch


def partition_data(train_X, train_y, args):
    """Seed NumPy's global RNG with ``args.seed`` and partition by label clusters.

    Returns the ``(client_data_list, client_classes)`` pair produced by
    ``partition_data_by_clusters``.
    """
    # NOTE: partition_data_by_clusters reseeds the global RNG (with
    # args.class_seed, or args.seed when class_seed == -1) before drawing, so
    # this seed does not by itself influence the partition.
    np.random.seed(args.seed)
    data_per_client, classes_per_client = partition_data_by_clusters(train_X, train_y, args)
    return data_per_client, classes_per_client

def partition_data_by_clusters(train_X, train_y, args):
    """Split a random subset of the training set among clients by label cluster.

    1. The ``args.n_class`` labels are shuffled and divided evenly into
       ``n_clusters`` groups (``args.n_clusters``, or one per class when it
       is -1).
    2. ``int(len(train_y) * args.frac)`` examples are drawn *without
       replacement*, so ``frac`` must lie in [0, 1]. Each is bucketed by the
       cluster of its label, and each bucket is shuffled.
    3. Every client is assigned a uniformly random cluster. Each cluster's
       examples are divided evenly (``iid_divide``) among its clients.
       Examples of a cluster that received no client are dropped.

    All randomness uses NumPy's global RNG, reseeded here with
    ``args.class_seed`` (or ``args.seed`` when ``class_seed == -1``).

    Returns:
        client_data_list: per-client ``(X, y)`` tuples indexed from
            ``train_X`` / ``train_y``.
        client_classes: per-client list of the labels in its cluster.
    """
    n_clusters = args.n_class if args.n_clusters == -1 else args.n_clusters

    partition_seed = args.class_seed if args.class_seed != -1 else args.seed
    np.random.seed(partition_seed)

    all_labels = list(range(args.n_class))
    np.random.shuffle(all_labels)
    clusters_labels = iid_divide(all_labels, n_clusters)

    # Map each label to the index of the cluster containing it.
    label2cluster = {label: group_idx
                     for group_idx, labels in enumerate(clusters_labels)
                     for label in labels}

    # Draw the subset without replacement. np.random.choice samples *with*
    # replacement by default, which would let the same example appear
    # several times and be handed to more than one client.
    n_samples = int(len(train_y) * args.frac)
    selected_indices = np.random.choice(len(train_y), n_samples, replace=False)

    clusters = {k: [] for k in range(n_clusters)}
    for idx in selected_indices:
        clusters[label2cluster[train_y[idx]]].append(idx)

    for cluster in clusters.values():
        np.random.shuffle(cluster)

    client_clusters = np.random.randint(n_clusters, size=args.n_clients)
    client_classes = [clusters_labels[c] for c in client_clusters]

    cluster_client_dict = {i: [] for i in range(n_clusters)}
    for client_id, cl in enumerate(client_clusters):
        cluster_client_dict[cl].append(client_id)

    client_data_list = [[] for _ in range(args.n_clients)]
    for cluster_id, client_list in cluster_client_dict.items():
        if not client_list:
            # No client drew this cluster: its examples go unused.
            continue
        client_data_ids_list = iid_divide(clusters[cluster_id], len(client_list))
        for client_id, client_data_idx in zip(client_list, client_data_ids_list):
            client_data_list[client_id] = (train_X[client_data_idx],
                                           train_y[client_data_idx])
    return client_data_list, client_classes

def iid_divide(l, g):
    """Split the sequence ``l`` into ``g`` contiguous, nearly equal slices.

    Adapted from
    https://github.com/TalwalkarLab/leaf/blob/master/data/utils/sample.py

    The first slices hold ``int(len(l) / g)`` elements. The trailing
    ``len(l) - g * int(len(l) / g)`` slices hold one element more. Slices are
    taken with ``l[a:b]``, so they have the same type as ``l``.
    """
    total = len(l)
    small_size = int(total / g)
    n_big = total - g * small_size
    n_small = g - n_big

    groups = [l[k * small_size:(k + 1) * small_size] for k in range(n_small)]

    offset = small_size * n_small
    big_size = small_size + 1
    groups.extend(l[offset + k * big_size:offset + (k + 1) * big_size]
                  for k in range(n_big))
    return groups

def split_list_by_indices(l, indices):
    """Cut ``l`` at the given (increasing) ``indices`` into ``len(indices)`` slices.

    Slice ``i`` is ``l[indices[i-1]:indices[i]]``; the first slice starts at 0.
    Elements after the last index are *not* included, so pass ``len(l)`` as
    the final index to keep everything.
    Adapted from https://github.com/echoyi/FedEM/blob/main/data/cifar10/utils.py
    """
    stops = list(indices)
    starts = [0] + stops[:-1]
    return [l[start:stop] for start, stop in zip(starts, stops)]


def evaluate_model(model, data, args):
    """Return the classification accuracy of ``model`` on ``data``.

    Switches ``model`` to eval mode (and leaves it there). It then runs the
    model over ``data`` in ordered batches of 1000 from ``data_loader`` on
    ``args.device``, and compares the arg-max of the outputs along dim 1 with
    the labels.

    Args:
        model: a ``torch.nn.Module`` producing per-class scores.
        data: ``(x, y)`` numpy arrays.
        args: must provide ``dataset`` (selects data_loader's normalization)
            and ``device``.

    Returns:
        Fraction of correctly classified examples, as a float.
    """
    model.eval()
    x, y = data

    loader = data_loader(args.dataset, x, y, batch_size=1000, is_train=False)
    correct = 0.
    # Inference only: without no_grad every forward pass would record an
    # autograd graph and keep its activations alive.
    with torch.no_grad():
        for xt, yt in loader:
            xt = torch.tensor(xt, dtype=torch.float32).to(args.device)
            yt = torch.tensor(yt, dtype=torch.int64).to(args.device)
            preds_labels = torch.squeeze(torch.max(model(xt), 1)[1])
            correct += torch.sum(preds_labels == yt).item()

    return correct / x.shape[0]


def evaluate_component_model(client, data, args):
    """Return the accuracy of ``client``'s weighted mixture of components on ``data``.

    For each ordered batch of 1000 from ``data_loader``, the prediction is the
    arg-max (dim 1) of ``sum_k w_k * component_k(x)``, pairing
    ``client.private_components`` with ``client.private_component_weights``.
    Each component is put in eval mode right before its forward pass. No
    gradients are tracked.
    """
    features, labels = data
    batches = data_loader(args.dataset, features, labels, batch_size=1000, is_train=False)
    n_correct = 0.
    with torch.no_grad():
        for batch_x, batch_y in batches:
            inputs = torch.tensor(batch_x, dtype=torch.float32).to(args.device)
            expected = torch.tensor(batch_y, dtype=torch.int64).to(args.device)

            # Weighted sum of the component outputs.
            mixture = 0.
            for model_k, weight_k in zip(client.private_components,
                                         client.private_component_weights):
                model_k.eval()
                mixture += weight_k * model_k(inputs)

            predicted = torch.squeeze(torch.max(mixture, 1)[1])
            n_correct += torch.sum(predicted == expected).item()
    return n_correct / features.shape[0]


def extract_numpy_params(model):
    """Snapshot ``model.state_dict()`` as a dict of numpy arrays (same keys)."""
    state = model.state_dict()
    return convert_tensor_params_to_numpy(state)


def convert_tensor_params_to_numpy(tensor_params):
    """Map each tensor in ``tensor_params`` to ``tensor.detach().cpu().numpy()``.

    Keys and their order are preserved; the result is a plain ``dict``.
    """
    return {name: tensor.detach().cpu().numpy()
            for name, tensor in tensor_params.items()}

def convert_np_params_to_tensor(params, device=None):
    """Convert a mapping of numpy arrays into a dict of torch tensors.

    Each array goes through ``torch.from_numpy`` (so the CPU tensor shares
    memory with the array). When ``device`` is given, each tensor is then
    moved there with ``.to(device)``.
    """
    def to_tensor(array):
        tensor = torch.from_numpy(array)
        if device is None:
            return tensor
        return tensor.to(device)

    return {name: to_tensor(params[name]) for name in params.keys()}


def calculate_losses(model, train_data, args):
    """Return the per-example cross-entropy losses of ``model`` on ``train_data``.

    ``model`` is switched to eval mode. ``train_data`` (``(x, y)``) is
    traversed in ordered batches of 1000 via ``data_loader``, and the
    unreduced losses are concatenated along dim 0. Gradient tracking follows
    the caller's context; this function does not disable it.
    """
    model.eval()
    inputs, targets = train_data
    batches = data_loader(args.dataset, inputs, targets, batch_size=1000, is_train=False)
    # A single unreduced criterion serves every batch.
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    per_batch = []
    for batch_x, batch_y in batches:
        features = torch.tensor(batch_x, dtype=torch.float32).to(args.device)
        labels = torch.tensor(batch_y, dtype=torch.int64).to(args.device)
        per_batch.append(criterion(model(features), labels))

    return torch.cat(per_batch, dim=0)


def calculate_all_components_losses(components, train_data, args):
    """Return a ``(len(components), len(train_data[0]))`` tensor of per-example losses.

    Row ``k`` holds ``calculate_losses(components[k], train_data, args)``. All
    rows are computed under ``torch.no_grad()`` and stored on ``args.device``.
    """
    loss_matrix = torch.zeros(len(components), len(train_data[0]), device=args.device)

    with torch.no_grad():
        for row, component in enumerate(components):
            loss_matrix[row] = calculate_losses(component, train_data, args)

    return loss_matrix

def get_param_tensor(model):
    """Flatten every entry of ``model.state_dict()`` into a single 1-D tensor.

    Entries are concatenated in ``state_dict`` order. Note that this covers
    whatever ``state_dict`` returns (for standard modules that includes
    registered buffers), not only ``model.parameters()``.
    """
    flat_pieces = [tensor.data.view(-1) for tensor in model.state_dict().values()]
    return torch.cat(flat_pieces)

def convert_param_tensor_to_np_params(param_value_list, model):
    """Split a flat 1-D tensor back into a dict of numpy arrays keyed like ``model.state_dict()``.

    Consecutive slices of ``param_value_list`` are reshaped to each
    ``state_dict`` entry's shape, in ``state_dict`` order (the layout produced
    by ``get_param_tensor``), and converted with ``detach().cpu().numpy()``.
    An ``AssertionError`` is raised if the flat tensor has leftover elements.
    """
    template = model.state_dict()
    result = {}
    offset = 0
    for key, reference in template.items():
        count = torch.numel(reference)
        chunk = param_value_list[offset:offset + count]
        result[key] = chunk.view(reference.shape).detach().cpu().numpy()
        offset += count
    assert offset == len(param_value_list)
    return result
