#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

from ast import arg
import os
import copy
import torch
import random
from torchvision import datasets, transforms
from sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid
from numpy.testing import assert_array_almost_equal
import numpy as np
from collections import Counter
import pickle as pkl
import math


class ClothFolder_train(datasets.ImageFolder):
    """Clothing1M (noisy) training split loaded through ``ImageFolder``.

    Each item is ``(sample, target)``. ``sample`` is the image returned by
    ``self.loader`` after ``transform``. ``target`` is the label stored in
    ``self.samples``, after ``target_transform`` when one is set. This mirrors
    ``ClothFolder_test.__getitem__``.
    """

    def __init__(self, root, transform, target_transform=None):
        # target_transform is optional so existing two-argument callers are unaffected.
        super(ClothFolder_train, self).__init__(
            root, transform, target_transform=target_transform)

    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        # Previously the label transform was silently ignored for the train split.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.samples)


class ClothFolder_test(datasets.ImageFolder):
    """Clothing1M (clean) test split loaded through ``ImageFolder``.

    ``__getitem__`` returns ``(image, label)``, where both the image transform
    and the label transform are applied when they are set.
    """

    def __init__(self, root, transform):
        super().__init__(root, transform)

    def __getitem__(self, index):
        img_path, label = self.samples[index]
        image = self.loader(img_path)
        image = image if self.transform is None else self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


def build_uniform_P(size, noise):
    """Build a symmetric (uniform) label-noise transition matrix.

    Row ``i`` keeps class ``i`` with probability ``1 - noise``. It moves to
    each of the other ``size - 1`` classes with probability
    ``noise / (size - 1)``.

    Args:
        size: number of classes (must be > 1).
        noise: total flip probability in ``[0, 1]``.

    Returns:
        ``(size, size)`` float array whose rows sum to one.
    """
    assert 0. <= noise <= 1.

    P = np.ones((size, size)) * (noise / (size - 1))
    P[np.diag_indices(size)] = 1 - noise

    # Sanity check (loose, 1 decimal) that every row is a distribution.
    assert_array_almost_equal(P.sum(axis=1), 1, 1)
    return P


def generate_noise_matrix_from_diagonal(diag):
    """Sample a random transition matrix with a prescribed diagonal.

    For each row ``r``:
    - ``T[r, r] = diag[r]``.
    - The remaining mass ``1 - diag[r]`` is split over the other classes with
      a flat Dirichlet draw from the global NumPy RNG.
    - The draw is repeated until no off-diagonal entry exceeds
      ``0.9 * diag[r]``, so the true class stays the most likely one.

    Note: the rejection loop does not terminate if ``0.9 * diag[r]`` is too
    small to accommodate ``1 - diag[r]`` spread over ``len(diag) - 1`` entries.

    Args:
        diag: 1-D NumPy array of per-class keep probabilities.

    Returns:
        ``(K, K)`` float array with ``K = len(diag)``; rows sum to one.
    """
    n = diag.shape[0]
    T = np.zeros((n, n))
    alpha = np.ones(n - 1)
    for row in range(n):
        T[row, row] = diag[row]
        ceiling = 0.9 * T[row, row]
        while True:
            off = np.random.dirichlet(alpha) * (1 - diag[row])
            if not np.any(off > ceiling):
                break
        others = np.ones(n, dtype=bool)
        others[row] = False
        T[row, others] = off
    return T


def multiclass_noisify(y, P, random_state=0):
    """Flip labels according to a row-stochastic transition matrix ``P``.

    Each label ``y[k] = i`` is replaced by a class drawn from the categorical
    distribution ``P[i, :]``. The draws come from a private
    ``np.random.RandomState(random_state)``, so the global NumPy RNG is not
    touched.

    Args:
        y: 1-D ``torch.Tensor`` (or array-like) of integer labels in
            ``[0, P.shape[0])``. The input is not modified.
        P: square row-stochastic matrix; ``P[i, j]`` is the probability that
            class ``i`` is relabelled as ``j``.
        random_state: seed (or ``None``) for the private RandomState.

    Returns:
        ``torch.Tensor`` holding the noisy labels (same length as ``y``).

    Raises:
        AssertionError: if ``P`` is not square/row-stochastic/non-negative or
            a label is outside ``[0, P.shape[0])``.
    """
    y = y.numpy() if isinstance(y, torch.Tensor) else np.asarray(y)
    assert P.shape[0] == P.shape[1]
    if y.shape[0] > 0:
        # np.max/np.min raise on empty input, so only range-check real data.
        assert np.max(y) < P.shape[0]
        assert np.min(y) >= 0

    # row stochastic matrix
    assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
    assert (P >= 0.0).all()

    new_y = y.copy()
    flipper = np.random.RandomState(random_state)

    for idx, i in enumerate(y):
        # Draw a one-hot vector; the position of its single 1 is the new label.
        # argmax yields a scalar (np.where(...)[0] gave a 1-element array).
        flipped = flipper.multinomial(1, P[i, :], 1)[0]
        new_y[idx] = np.argmax(flipped)

    return torch.from_numpy(new_y)


def noisify_with_P(y_train, nb_classes, noise, random_state=None):
    """Corrupt labels with symmetric (uniform) noise.

    Args:
        y_train: labels to corrupt (passed to ``multiclass_noisify``).
        nb_classes: number of classes.
        noise: noisy ratio, i.e. total flip probability per label.
        random_state: seed forwarded to ``multiclass_noisify``.

    Returns:
        ``(labels, P)``. When ``noise > 0``, ``labels`` are the noisy labels
        and ``P`` is the uniform transition matrix. Otherwise the input labels
        are returned unchanged together with the identity matrix.
    """
    if noise > 0.0:
        P = build_uniform_P(nb_classes, noise)
        # seed the random numbers with #run
        y_train_noisy = multiclass_noisify(y_train,
                                           P=P,
                                           random_state=random_state)
        # Compare in NumPy: `.mean()` on the bool tensor produced by a torch
        # `!=` raises, because torch cannot average a bool dtype.
        actual_noise = float(
            np.mean(np.asarray(y_train_noisy) != np.asarray(y_train)))
        assert actual_noise > 0.0
        print('Actual noise %.2f' % actual_noise)
        y_train = y_train_noisy
    else:
        P = np.eye(nb_classes)

    return y_train, P


def get_dataset(args):
    """Build the train/test datasets and the per-client data partition.

    Dataset selection (``args.dataset``):

    * any name containing ``'cifar10'`` but not ``'cifar100'``: CIFAR-10;
    * any name containing ``'cifar100'``: CIFAR-100;
    * ``'mnist'``: MNIST;
    * ``'clothing1m'``: ImageFolder trees under ``./data``.

    Label noise applied to the training set:

    * ``clothing1m``: labels are used as stored on disk;
    * names containing ``'cifar10-N'``: CIFAR-10N human labels from
      ``./CIFAR-N/CIFAR-10_human.pt`` (``worst``/``aggre``/``random`` variant
      chosen by substring);
    * ``'cifar100-N'``: CIFAR-100N human labels;
    * otherwise: synthetic noise with rate ``args.noise_ratio``. The transition
      matrix is random per class when ``args.random == 1`` and uniform
      otherwise.

    ``train_dataset.noise_or_not`` marks samples whose label was left intact.

    Side effects: torchvision datasets are created with ``download=True`` under
    ``./data``. The noisy labels and user groups are pickled to
    ``{dataset}-{noise_ratio}-{seed}-{random}.pkl`` in the working directory.

    Returns:
        ``(train_dataset, test_dataset, user_groups, num_classes,
        client_openset)``; ``client_openset`` is ``None`` for IID sampling.

    Raises:
        ValueError: if ``args.dataset`` matches none of the supported names.
    """

    if 'cifar10' in args.dataset and 'cifar100' not in args.dataset:
        data_dir = './data'
        apply_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)

        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=transform_test)
        num_classes = 10

    elif 'cifar100' in args.dataset:
        data_dir = './data'
        # NOTE(review): these normalisation constants are the same ones used
        # for CIFAR-10 above; confirm they are intended for CIFAR-100.
        apply_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        train_dataset = datasets.CIFAR100(data_dir, train=True, download=True,
                                          transform=apply_transform)

        test_dataset = datasets.CIFAR100(data_dir, train=False, download=True,
                                         transform=transform_test)
        num_classes = 100

    elif args.dataset == 'mnist':

        data_dir = './data'

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                      transform=apply_transform)
        num_classes = 10

    elif args.dataset != 'clothing1m':
        # Fail early: otherwise an unknown name surfaces later as an
        # UnboundLocalError on `train_dataset`.
        raise ValueError(f'Unsupported dataset: {args.dataset!r}')

    if args.dataset == 'clothing1m':
        # DMI original
        data_root = './data'
        train_folder = os.path.join(data_root, "noisy_train")
        test_folder = os.path.join(data_root, "clean_test")
        train_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.6959, 0.6537, 0.6371], std=[
                                 0.3113, 0.3192, 0.3214]),
        ])
        # NOTE(review): RandomCrop at test time makes evaluation
        # non-deterministic; CenterCrop may be intended — confirm before changing.
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.6959, 0.6537, 0.6371], std=[
                                 0.3113, 0.3192, 0.3214]),
        ])
        train_dataset = ClothFolder_train(
            root=train_folder, transform=train_transform)
        test_dataset = ClothFolder_test(
            root=test_folder, transform=test_transform)
        train_dataset.noise_or_not = [True] * len(train_dataset)
        num_classes = 14
    elif 'cifar10-N' in args.dataset:
        noise_label = torch.load('./CIFAR-N/CIFAR-10_human.pt')
        train_dataset.ori_targets = train_dataset.targets
        if 'worst' in args.dataset:
            train_dataset.targets = noise_label['worse_label']
        elif 'aggre' in args.dataset:
            train_dataset.targets = noise_label['aggre_label']
        elif 'random' in args.dataset:
            train_dataset.targets = noise_label['random_label1']
        train_dataset.noise_or_not = np.array(
            train_dataset.ori_targets) == np.array(train_dataset.targets)
        train_dataset.targets = list(train_dataset.targets)
    elif args.dataset == 'cifar100-N':
        noise_label = torch.load('./CIFAR-N/CIFAR-100_human.pt')
        train_dataset.ori_targets = train_dataset.targets
        train_dataset.targets = noise_label['noisy_label']
        train_dataset.noise_or_not = np.array(
            train_dataset.ori_targets) == np.array(train_dataset.targets)
        train_dataset.targets = list(train_dataset.targets)
    else:
        train_dataset.ori_targets = train_dataset.targets
        if args.noise_ratio > 0.0:
            if args.random == 1:
                # Per-class keep probability jittered by up to +/- std_acc.
                acc = 1 - args.noise_ratio
                std_acc = 0.05
                P_diag = acc + std_acc*2*(np.random.rand(num_classes) - 0.5)
                T = generate_noise_matrix_from_diagonal(diag=P_diag)
                T = np.array(T)
            else:
                T = build_uniform_P(num_classes, args.noise_ratio)
            labels = np.array(train_dataset.targets)
            noisy_label = multiclass_noisify(
                torch.from_numpy(labels), P=T, random_state=args.seed
            ).numpy()
            train_dataset.targets = list(noisy_label)
            print('Data corrupted')
        train_dataset.noise_or_not = np.array(
            train_dataset.ori_targets) == np.array(train_dataset.targets)

    # sample training data amongst users
    if args.iid:
        # IID clients have no per-client label subset; returning None here
        # avoids the UnboundLocalError the return statement used to raise.
        user_groups = cifar_iid(train_dataset, args.num_users)
        client_openset = None
    else:
        user_groups, client_openset = openset_sampling(
            train_dataset, args, num_classes)

    # Save the noisy labels and the client partition to disk.
    pkl_file = {
        'noisy_label': np.array(train_dataset.targets),
        'user_groups': user_groups,
    }
    with open(f'{args.dataset}-{args.noise_ratio}-{args.seed}-{args.random}.pkl', 'wb') as f:
        pkl.dump(pkl_file, f)

    return train_dataset, test_dataset, user_groups, num_classes, client_openset


def openset_sampling(train_dataset, args, num_classes):
    """Partition the training set across clients that each hold a class subset.

    Every client draws a random "open set" of classes. Each class is included
    independently with probability 0.5. Draws that are empty, or that contain
    every class (when there is more than one), are re-drawn. The samples of
    each class are then split among the clients that selected it:

    * ``args.random == 1``: proportions come from a flat Dirichlet draw. Every
      participating client gets at least one sample, and any over-allocation
      is taken from the largest share.
    * otherwise: each participating client gets an equal share
      ``n_class // n_clients_with_class``.

    Samples are handed out in increasing index order with a moving pointer, so
    client shares are disjoint. Leftover samples stay unassigned.

    Args:
        train_dataset: object exposing a ``targets`` sequence of labels.
        args: namespace providing ``seed``, ``num_users`` and ``random``.
        num_classes: number of label classes (must be >= 1).

    Returns:
        ``(client_idx, client_openset)``. ``client_idx[u]`` is the list of
        sample indices assigned to client ``u``. ``client_openset[u]`` is the
        array of class ids held by client ``u``.

    Note: reseeds the global NumPy RNG with ``args.seed``.
    """
    if num_classes < 1:
        raise ValueError('num_classes must be at least 1')

    # sampling every clients openset
    np.random.seed(args.seed)
    client_openset = dict()
    for i in range(args.num_users):
        cur_user_comb = np.random.binomial(1, 0.5, num_classes)
        # Re-draw empty subsets and (if more than one class exists) full ones.
        while (not cur_user_comb.any()
               or (num_classes > 1 and cur_user_comb.all())):
            cur_user_comb = np.random.binomial(1, 0.5, num_classes)
        client_openset[i] = np.where(cur_user_comb == 1)[0]

    # cand_stat[c] = number of clients that selected class c.
    cand_stat = Counter(
        c for u in range(args.num_users) for c in client_openset[u])

    class_dict = dict()
    for i in range(num_classes):
        class_dict[i] = np.random.dirichlet(np.ones(cand_stat[i]))

    # Normalise targets to a NumPy array so Counter keys are plain ints
    # (iterating a torch tensor yields 0-d tensors hashed by identity).
    y_noisy = np.asarray(train_dataset.targets)
    y_noisy_dist = Counter(y_noisy.tolist())

    unit_cls_num = dict()
    if args.random == 1:
        for i in range(num_classes):
            shares = [1 if s < 1 else int(s)
                      for s in class_dict[i] * y_noisy_dist[i]]
            unit_cls_num[i] = np.array(shares, dtype=int)
            excess = unit_cls_num[i].sum() - y_noisy_dist[i]
            if excess > 0:
                # Rounding small shares up to 1 can over-allocate; take the
                # surplus from this class's own largest share.
                max_idx = np.argmax(unit_cls_num[i])
                unit_cls_num[i][max_idx] -= excess
                assert unit_cls_num[i].sum() <= y_noisy_dist[i], \
                    "index out of range"
                assert unit_cls_num[i].min() >= 0, \
                    "too few samples of class %d for its clients" % i
    else:
        for i in range(num_classes):
            # Classes nobody selected are never indexed; avoid dividing by 0.
            unit_cls_num[i] = (y_noisy_dist[i] // cand_stat[i]
                               if cand_stat[i] else 0)

    cls_idx = {i: np.where(y_noisy == i)[0] for i in range(num_classes)}

    client_idx = {u: [] for u in range(args.num_users)}
    cls_pointer = dict.fromkeys(range(num_classes), 0)  # next unused sample
    pos_pointer = dict.fromkeys(range(num_classes), 0)  # next share (random mode)

    for u in range(args.num_users):
        for label in client_openset[u]:
            if args.random == 1:
                cls_length = unit_cls_num[label][pos_pointer[label]]
            else:
                cls_length = unit_cls_num[label]
            start_pointer = cls_pointer[label]
            end_pointer = start_pointer + cls_length
            selected = cls_idx[label][start_pointer:end_pointer]
            assert (y_noisy[selected] == label).all()
            client_idx[u].extend(selected)
            cls_pointer[label] = end_pointer
            pos_pointer[label] += 1
    print('Openset Finished')
    return client_idx, client_openset


def average_weights(w):
    """Return the element-wise mean of a list of model state dicts.

    The first state dict is deep-copied, so the inputs are left untouched.
    Keys containing ``'num_batches_tracked'`` are averaged with
    ``true_divide``; all other entries with ``torch.div``.
    """
    n_models = len(w)
    averaged = copy.deepcopy(w[0])
    for name in averaged:
        total = averaged[name]
        for state in w[1:]:
            total += state[name]
        if 'num_batches_tracked' in name:
            averaged[name] = total.true_divide(n_models)
        else:
            averaged[name] = torch.div(total, n_models)
    return averaged


def transition_matrix_generate(noise_rate=0.5, num_classes=10):
    """Return a symmetric label-noise transition matrix.

    Row ``i`` keeps class ``i`` with probability ``1 - noise_rate``. It moves
    to each of the other ``num_classes - 1`` classes with probability
    ``noise_rate / (num_classes - 1)``, so every row sums to one.

    The diagonal is now always set. Previously it was only set when
    ``noise_rate > 0``, so ``noise_rate == 0`` produced an all-zero matrix
    instead of the identity.

    Args:
        noise_rate: total flip probability per row.
        num_classes: matrix size (must be > 1).

    Returns:
        ``(num_classes, num_classes)`` float array.
    """
    P = (noise_rate / (num_classes - 1)) * np.ones((num_classes, num_classes))
    np.fill_diagonal(P, 1. - noise_rate)
    return P


def dp_process(args, label, num_cls):
    """Perturb labels with a symmetric flip matrix and debias their histogram.

    The flip probability is ``(num_cls - 1) / (e**args.power + num_cls - 1)``,
    i.e. a label is kept with probability
    ``e**args.power / (e**args.power + num_cls - 1)``. Each entry of ``label``
    is overwritten *in place* by a draw from row ``label[k]`` of the
    resulting matrix ``T``. The empirical class frequencies of the perturbed
    labels are then multiplied by ``inv(T)`` to estimate the unperturbed
    distribution.

    Note: reseeds the global NumPy RNG with ``args.seed`` and mutates
    ``label``.

    Returns:
        array of shape ``(1, num_cls)``.
    """
    np.random.seed(args.seed)
    flip_rate = (num_cls - 1) / ((math.e) ** args.power + num_cls - 1)
    T = transition_matrix_generate(flip_rate, num_cls)
    for pos, original in enumerate(label):
        label[pos] = np.random.choice(num_cls, p=T[original, :])
    counts = Counter(label)
    total = len(label)
    observed = np.array([[counts[c] / total for c in range(num_cls)]])
    return np.dot(observed, np.linalg.inv(T))


def exp_details(args):
    """Print a summary of the experiment configuration read from ``args``."""
    lines = [
        '\nExperimental details:',
        f'    Model     : {args.model}',
        f'    Optimizer : {args.optimizer}',
        f'    Learning  : {args.lr}',
        f'    Global Rounds   : {args.epochs}\n',
        '    Federated parameters:',
        '    IID' if args.iid else '    Non-IID',
        f'    Fraction of users  : {args.frac}',
        f'    Local Batch size   : {args.local_bs}',
        f'    Local Epochs       : {args.local_ep}',
        f'    Method             : {args.method}',
    ]
    for line in lines:
        print(line)
