import math
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import argparse

import nets
import attack
import pdb

import unidecode
import string
import torch.nn as nn

def parse_args():
    """Declare the experiment's command-line options and parse ``sys.argv``.

    Options are listed in the table below as ``(flag, add_argument kwargs)``
    and registered in table order, so ``--help`` lists them in that order.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    attack_choices = ['benign', 'full_trim', 'full_krum', 'adaptive_trim',
                      'adaptive_krum', 'shejwalkar', 'shej_agnostic']
    option_table = [
        ("--dataset", dict(help="dataset", default='mnist', type=str)),
        ("--bias", dict(help="degree of non-IID to assign data to workers", type=float, default=0.5)),
        ("--net", dict(help="net", default='dnn', type=str, choices=['mlr', 'dnn', 'resnet18', 'lstm'])),
        ("--load_net", dict(help="load a specific model", default=None, type=str)),
        ("--batch_size", dict(help="batch size", default=32, type=int)),
        ("--lr", dict(help="learning rate", default=0.01, type=float)),
        ("--nclients", dict(help="# clients", default=20, type=int)),
        ("--ngroups", dict(help="# groups", default=3, type=int)),
        ("--nrounds", dict(help="# training rounds", default=10, type=int)),
        ("--niters", dict(help="# local iterations", default=1, type=int)),
        ("--prob_select", dict(help="probability of a client participating in a round", default=1.0, type=float)),
        ("--sample_type", dict(help="local data sampling", default='round_robin', type=str,
                               choices=['round_robin', 'random', 'same'])),
        ("--gpu", dict(help="index of gpu", default=0, type=int)),
        ("--nbyz", dict(help="# byzantines", default=0, type=int)),
        ("--attack", dict(help="type of attack", default='benign', type=str, choices=attack_choices)),
        ("--aggregation", dict(help="aggregation rule", default='hcl', type=str)),
        ("--cmax", dict(help="PRISM's notion of c_max", default=0, type=int)),
        ("--decay", dict(help="Decay rate", default=0.0, type=float)),
        ("--exp", dict(help="Experiment name", default='', type=str)),
        ("--k", dict(help="Number of neighbors in P2PL", default=3, type=int)),
        ("--W", dict(help="Whether to generate new geographical simulation", default=None, type=str)),
        ("--init", dict(help="initialize models", default=None, type=str)),
        ("--save_time", dict(help="array saving frequency", default=1, type=int)),
        ("--eval_time", dict(help="evaluation frequency", default=1, type=int)),
        ("--self_wt", dict(help="Weight assigned to self in gossip averaging", default=None, type=float)),
        ("--capabilities", dict(help="attacker's access to local models", default='ben', type=str,
                                choices=['all', 'ben', 'mal'])),
        ("--is_mal", dict(help="malicious client indices", default=None, type=str)),
        ("--graph_type", dict(help="k-regular or power-law", default='k-regular', type=str)),
        ("--min_degree", dict(help="min in-degree for power law", default=None, type=int)),
        ("--max_degree", dict(help="max in-degree for power law", default=None, type=int)),
    ]

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def read_file(filename):
    """Read a text file and transliterate it to plain ASCII with unidecode.

    The file is opened in a ``with`` block so the handle is always closed, and
    it is decoded as UTF-8 rather than the platform's locale encoding. This
    makes the result independent of the machine it runs on.

    Returns:
        tuple ``(text, len(text))``.
    """
    with open(filename, encoding='utf-8') as handle:
        text = unidecode.unidecode(handle.read())
    return text, len(text)

def char_tensor(string_):
    """Encode ``string_`` as a 1-D LongTensor of indices into ``string.printable``.

    Characters not in ``string.printable`` are encoded as 0 (which is also the
    index of ``'0'``). This matches the behaviour of leaving their slot at its
    zero initial value.

    The parameter is named ``string_`` so it does not shadow the ``string``
    module used below.
    """
    # string.printable has no repeated characters, so a dict gives the same
    # index as str.index but in O(1) per character.
    lookup = {ch: idx for idx, ch in enumerate(string.printable)}
    return torch.tensor([lookup.get(ch, 0) for ch in string_], dtype=torch.long)

def compute_cdist(client_wts):
    """Return the mean Euclidean distance from each row of ``client_wts`` to their centroid.

    The centroid is the mean over dim 0, and each distance is a norm over
    dim 1. The input is presumably (clients x flattened parameters); confirm
    against callers. The result is a Python float.
    """
    centroid = client_wts.mean(dim=0)
    deviations = client_wts - centroid
    return deviations.norm(dim=1).mean().item()

def distribute_malicious(n_clients, n_mals, distr_type='first_few', n_groups=0, group_ids=0):
    """Return a float mask of length ``n_clients`` with 1 marking malicious clients.

    distr_type:
        'uniform'   -- ``n_mals`` clients picked uniformly at random without
                       replacement (numpy global RNG).
        'first_few' -- clients ``0 .. n_mals-1``.
        'group'     -- every client whose ``group_ids`` entry is below
                       floor(n_mals / n_clients * n_groups). Whole groups are
                       made malicious in proportion to the malicious fraction.

    Raises:
        ValueError: if ``distr_type`` is not one of the above.
    """
    is_mal = np.zeros(n_clients)
    if distr_type == 'uniform':
        is_mal[np.random.choice(n_clients, n_mals, replace=False)] = 1
    elif distr_type == 'first_few':
        is_mal[:n_mals] = 1
    elif distr_type == 'group':
        # The Byzantine fraction is n_mals / n_clients. Integer floor division
        # computes floor(fraction * n_groups) exactly, avoiding float rounding.
        n_mal_groups = (n_mals * n_groups) // n_clients if n_clients else 0
        is_mal[np.flatnonzero(np.asarray(group_ids) < n_mal_groups)] = 1
    else:
        raise ValueError(f"unknown distr_type: {distr_type!r}")
    return is_mal

def model_to_vec(net):
    """Flatten ``net``'s trainable parameters into one detached 1-D tensor.

    Parameters are concatenated in ``net.parameters()`` order. Those with
    ``requires_grad`` False are skipped, so the layout matches vec_to_model,
    update_model and num_params, which all consider only trainable parameters.

    The result is always 1-D, even when it holds a single element. Previously
    ``.squeeze(0)`` made such a vector 0-d, which cannot be sliced.
    """
    return torch.cat(
        [param.detach().reshape(-1) for param in net.parameters() if param.requires_grad],
        dim=0,
    )

def vec_to_model(vec, net_name, num_inp, num_out, device):
    """Build a fresh ``net_name`` model and fill its trainable parameters from ``vec``.

    ``vec`` is consumed in ``named_parameters()`` order, ``nelement()`` entries
    per parameter that has ``requires_grad`` set. Frozen parameters are skipped
    and keep the values load_net gave them.

    Each parameter's ``.data`` is rebound to a reshaped, detached slice of
    ``vec`` rather than copied into. The model may therefore alias ``vec``'s
    storage.
    """
    model = load_net(net_name, num_inp, num_out, device)
    offset = 0
    with torch.no_grad():
        for _, tensor in model.named_parameters():
            if not tensor.requires_grad:
                continue
            size = tensor.nelement()
            tensor.data = vec[offset:offset + size].reshape(tensor.shape).detach()
            offset += size
    return model

def _write_flat_params(net, flat, accumulate):
    """Write consecutive slices of ``flat`` into ``net``'s trainable parameters.

    Parameters with ``requires_grad`` False are skipped and consume no entries.
    With ``accumulate`` each slice is added in place. Otherwise ``param.data``
    is rebound to the slice: it is assigned, not copied.
    """
    with torch.no_grad():
        offset = 0
        for param in net.parameters():
            if not param.requires_grad:
                continue
            numel = param.nelement()
            chunk = flat[offset:offset + numel].reshape(param.shape).detach()
            if accumulate:
                param.data += chunk
            else:
                param.data = chunk
            offset += numel


def update_model(message, net, aggregated_grads, test_data, device):
    """Apply a flat update vector to ``net`` and/or evaluate it, as directed by ``message``.

    Substrings of ``message`` select the actions, which can be combined:
        'weights'   -- overwrite trainable parameters with ``aggregated_grads``.
        'gradients' -- add ``aggregated_grads`` to trainable parameters in place.
        'evaluate'  -- compute classification accuracy over ``test_data``, an
                       iterable of ``(images, labels)`` batches.
    If ``message`` is exactly 'shakespeare', no update is applied. The mean
    cross-entropy over 200 character positions is returned instead. Here
    ``test_data`` is ``(inputs, targets)`` and the model must provide
    ``init_hidden``.

    Returns:
        The test-loss tensor for 'shakespeare'. Otherwise the accuracy in
        [0, 1], or -1 when no evaluation was requested or ``test_data`` held
        no samples.
    """
    if 'weights' in message:
        _write_flat_params(net, aggregated_grads, accumulate=False)
    if 'gradients' in message:
        _write_flat_params(net, aggregated_grads, accumulate=True)

    if message == 'shakespeare':
        seq_len = 200  # fixed evaluation window used by this branch
        criterion = nn.CrossEntropyLoss()
        net.eval()
        # Drop the tail so the text splits evenly into rows of seq_len characters.
        usable = (len(test_data[0]) // seq_len) * seq_len
        data = test_data[0][:usable].reshape((-1, seq_len))
        labels = test_data[1][:usable].reshape((-1, seq_len))
        with torch.no_grad():
            hdn = net.init_hidden(data.shape[0])
            hidden = (hdn[0].to(device), hdn[1].to(device))
            test_loss = 0
            for c in range(seq_len):
                output, hidden = net(data[:, c], hidden)
                test_loss += criterion(output.view(data.shape[0], -1), labels[:, c])
            test_loss /= seq_len
        return test_loss

    test_acc = -1
    if 'evaluate' in message:
        correct = 0
        total = 0
        net.eval()
        with torch.no_grad():
            for images, labels in test_data:
                outputs = net(images.to(device))
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels.to(device)).sum().item()
        # An empty test set previously raised ZeroDivisionError; keep the -1 sentinel.
        if total > 0:
            test_acc = correct / total
    return test_acc

def create_batch(data_size, batch_size, rr_idx, cl, sample_type='round_robin'):
    """Return the sample indices for client ``cl``'s next local mini-batch.

    sample_type:
        'random'      -- ``batch_size`` distinct indices from ``range(data_size)``
                         (numpy global RNG).
        'same'        -- always ``0 .. batch_size-1``.
        'round_robin' -- ``batch_size`` consecutive indices starting at the
                         cursor ``rr_idx[cl]``, wrapping modulo ``data_size``.
                         The cursor is advanced in place.

    Raises:
        ValueError: unknown ``sample_type``, or round robin with ``data_size <= 0``.
    """
    if sample_type == 'random':
        return np.random.choice(data_size, batch_size, replace=False)
    if sample_type == 'same':
        return np.arange(batch_size)
    if sample_type == 'round_robin':
        if data_size <= 0:
            raise ValueError("round_robin sampling needs a non-empty local dataset")
        start = int(rr_idx[cl])
        # Modular indexing gives the same batches as a single wrap-around for
        # normal sizes. It also stays in range when batch_size > data_size,
        # where the previous two-range splice produced out-of-bounds indices.
        batch_idx = (start + np.arange(batch_size)) % data_size
        rr_idx[cl] = (start + batch_size) % data_size
        return batch_idx
    raise ValueError(f"unknown sample_type: {sample_type!r}")

def cluster_clients(n_clients, n_groups, cluster_type="uniform"):
    """Assign each of ``n_clients`` clients a group id.

    Only ``cluster_type == "uniform"`` is implemented. Ids are drawn
    independently and uniformly from ``0 .. n_groups-1`` with numpy's global
    RNG. Any other ``cluster_type`` returns None.
    """
    if cluster_type != "uniform":
        return None
    return np.random.randint(0, n_groups, n_clients)

def num_params(net):
    """Return the total element count of ``net``'s parameters that have ``requires_grad`` set."""
    return sum(param.nelement() for param in net.parameters() if param.requires_grad)

def load_net(net_name, num_inputs, num_outputs, device):
    """Instantiate the model called ``net_name`` and move it to ``device``.

    ``num_inputs`` and ``num_outputs`` are passed only to ``nets.MLR``. The
    other constructors are called without them here.

    Raises:
        ValueError: if ``net_name`` is not 'mlr', 'resnet18', 'dnn' or 'lstm'.
            Previously this surfaced as an UnboundLocalError on ``net``.
    """
    if net_name == 'mlr':
        net = nets.MLR(num_inputs, num_outputs)
    elif net_name == 'resnet18':
        net = nets.ResNet18()
    elif net_name == 'dnn':
        net = nets.DNN()
    elif net_name == 'lstm':
        # Character-level model whose vocabulary is string.printable.
        n_characters = len(string.printable)
        net = nets.CharRNN(n_characters, 128, n_characters, 'lstm', 2)
    else:
        raise ValueError(f"unknown net_name: {net_name!r}")
    net.to(device)
    return net

def load_byz(byz_name):
    """Return the attack function in the ``attack`` module named ``byz_name``.

    Raises:
        ValueError: if ``byz_name`` is not a supported attack. Previously an
            unknown name surfaced as an UnboundLocalError.
    """
    supported = ('benign', 'full_trim', 'full_krum', 'adaptive_trim',
                 'adaptive_krum', 'shejwalkar', 'shej_agnostic')
    if byz_name not in supported:
        raise ValueError(f"unknown attack: {byz_name!r}")
    # Each supported name is exactly the attribute name in ``attack``. Only the
    # requested attribute is looked up, as before.
    return getattr(attack, byz_name)

def _make_loaders(dataset_cls, train_transform, test_transform, batch_size):
    """Download ``dataset_cls`` under ./data and wrap both splits in unshuffled DataLoaders."""
    trainset = dataset_cls(root='./data', train=True, download=True, transform=train_transform)
    train_data = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False)
    testset = dataset_cls(root='./data', train=False, download=True, transform=test_transform)
    test_data = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
    return train_data, test_data


def load_data(dataset_name, batch_size):
    """Load ``dataset_name`` and return its train/test data plus input/output sizes.

    Supported: 'mnist', 'fmnist', 'cifar10' (unshuffled torchvision
    DataLoaders with ``batch_size``), and 'shakespeare' (raw text read from
    ./shakespeare.txt).

    Returns:
        ``(train_data, test_data, num_inputs, num_outputs)``. For
        'shakespeare' this is ``(text, None, None, None)``.

    Exits:
        Calls ``sys.exit('Not Implemented Dataset!')`` for any other name.
        ``sys`` was previously never imported, so this raised NameError.
    """
    import sys

    if dataset_name in ('mnist', 'fmnist'):
        if dataset_name == 'mnist':
            print("Loading MNIST")
            dataset_cls = torchvision.datasets.MNIST
        else:
            print("Loading FMNIST")
            dataset_cls = torchvision.datasets.FashionMNIST
        # NOTE(review): FMNIST reuses MNIST's mean/std. This is kept unchanged
        # so results stay reproducible.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        train_data, test_data = _make_loaders(dataset_cls, transform, transform, batch_size)
        return train_data, test_data, 28 * 28, 10

    if dataset_name == 'cifar10':
        print("Loading CIFAR-10")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        # Augmentation (random crop + flip) is applied to the training split only.
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        train_data, test_data = _make_loaders(
            torchvision.datasets.CIFAR10, transform_train, transform_test, batch_size)
        return train_data, test_data, 32 * 32 * 3, 10

    if dataset_name == 'shakespeare':
        print("Loading Shakespeare dataset")
        text, _ = read_file('shakespeare.txt')
        return text, None, None, None

    sys.exit('Not Implemented Dataset!')

def data_distribution(labelset):
    """Return a per-client label histogram.

    Args:
        labelset: sequence with one label tensor per client. Any shape is
            accepted, including 0-d, and each tensor is flattened. Labels are
            assumed to be non-negative class indices.

    Returns:
        float ndarray of shape ``(len(labelset), nclass)``. Entry ``[cl, k]``
        counts label ``k`` at client ``cl``. ``nclass`` is 1 + the largest
        label at any client, or 0 when no client has labels.
    """
    flat = [torch.as_tensor(labels).reshape(-1).long().cpu() for labels in labelset]
    # Derive the class count from every client rather than client 0's unique
    # labels. A class missing at client 0 no longer indexes past the table.
    nclass = 1 + max((int(labels.max()) for labels in flat if labels.numel() > 0), default=-1)
    ddist = np.zeros((len(flat), nclass))
    for cl, labels in enumerate(flat):
        if labels.numel() > 0:
            ddist[cl] = torch.bincount(labels, minlength=nclass).numpy()
    return ddist

def distribute_data_fang(device, batch_size, bias_weight, train_data, num_workers, num_inputs, num_outputs, net_name):
    """Split the training data across ``num_workers`` clients.

    'lstm': ``train_data`` is sliced as a sequence (presumably the raw text
    returned by load_data for 'shakespeare'; confirm). Worker ``i`` receives
    characters ``[i*6400, (i+1)*6400]``. Its inputs are that chunk without the
    last character and its labels are the chunk shifted by one, both encoded
    with char_tensor and moved to ``device``. One extra entry, at index
    ``num_workers``, holds the rest of the text. Returns
    ``(inputs_list, labels_list, None)``, each list of length ``num_workers + 1``.

    Otherwise ``train_data`` is iterated as ``(data, label)`` batches ('mlr'
    flattens each batch to ``(-1, num_inputs)``). Workers form ``num_outputs``
    groups of ``num_workers / num_outputs`` workers. A sample with label ``y``
    lands in group ``y`` with probability ``bias_weight``. Otherwise it lands
    in each other group with probability
    ``(1 - bias_weight) / (num_outputs - 1)``. A worker inside the group is
    then chosen (uniformly when ``num_workers`` is a multiple of
    ``num_outputs``). With ``bias_weight == 0`` the worker is instead drawn
    from all workers.

    Per-worker samples are stacked, and the worker order is permuted with a
    fixed seed (42). The function returns ``(data_list, label_list, wts)``,
    where ``wts[i]`` is worker ``i``'s fraction of all samples.

    NOTE(review): the global numpy RNG is reseeded before every draw with
    ``batch_size*batch_ctr + sample_ctr``. The split is therefore deterministic,
    but the caller's global RNG state is overwritten. ``sample_ctr`` is
    incremented between the two draws of a sample, so one sample's worker seed
    equals the next sample's group seed. When loader batches hold
    ``batch_size`` samples, a batch's last worker seed also equals the next
    batch's first group seed. Changing this would change the data split, so it
    is documented here rather than fixed.
    NOTE(review): ``num_outputs == 1`` divides by zero. A worker that receives
    no samples makes ``torch.stack`` raise. A worker with exactly one sample
    loses its leading sample dimension through ``squeeze(0)``.
    """

    if net_name == 'lstm':
        each_worker_data = []
        each_worker_label = []
        # 200*32 characters per worker (presumably 32 sequences of length 200;
        # confirm). Slices for i < num_workers take one extra character so the
        # labels can be the inputs shifted by one.
        # NOTE(review): the final remainder slice starts at i*6400+1, so
        # character num_workers*6400 is only ever used as a label, never as an input.
        for i in range(num_workers+1):
            if (i<num_workers):
                chunk = train_data[i*200*32: (i+1)*200*32+1]
            else: 
                chunk = train_data[i*200*32+1:]
            each_worker_data.append(char_tensor(chunk[:-1]).to(device))
            each_worker_label.append(char_tensor(chunk[1:]).to(device))
        return each_worker_data, each_worker_label, None

    # Probability mass given to each non-home group, and workers per group
    # (may be fractional).
    other_group_size = (1-bias_weight) / (num_outputs-1)
    worker_per_group = num_workers / (num_outputs)
    each_worker_data = [[] for _ in range(num_workers)]
    each_worker_label = [[] for _ in range(num_workers)] 
    batch_ctr = 0
    for _, (data, label) in enumerate(train_data):
        if net_name == 'mlr':
            data = data.reshape((-1, num_inputs))
        sample_ctr = 0
        for (x, y) in zip(data, label):
            # [lower_bound, upper_bound] is a window of width bias_weight: a
            # uniform draw inside it keeps the sample in its home group y.
            # Draws outside fall into slices of width other_group_size, one
            # per other group.
            upper_bound = (y.item()) * (1-bias_weight) / (num_outputs-1) + bias_weight
            lower_bound = (y.item()) * (1-bias_weight) / (num_outputs-1)
            np.random.seed(batch_size*batch_ctr + sample_ctr)
            rd = np.random.random_sample()
            if rd > upper_bound:
                worker_group = int(np.floor((rd - upper_bound) / other_group_size)+y.item()+1)
            elif rd < lower_bound:
                worker_group = int(np.floor(rd / other_group_size))
            else:
                worker_group = y.item()
            
            # assign a data point to a worker
            # (this seed equals the next sample's group-choice seed; see docstring)
            sample_ctr += 1
            np.random.seed(batch_size*batch_ctr + sample_ctr)
            rd = np.random.random_sample()
            selected_worker = int(worker_group*worker_per_group + int(np.floor(rd*worker_per_group)))
            if (bias_weight == 0): selected_worker = np.random.randint(num_workers)
            each_worker_data[selected_worker].append(x.to(device))
            each_worker_label[selected_worker].append(y.to(device))
        batch_ctr += 1
    # concatenate the data for each worker
    # (squeeze(0) only changes workers with exactly one sample; empty workers make stack raise)
    each_worker_data = [(torch.stack(each_worker, dim=0)).squeeze(0) for each_worker in each_worker_data] 
    each_worker_label = [(torch.stack(each_worker, dim=0)).squeeze(0) for each_worker in each_worker_label]
    
    # random shuffle the workers
    # (local RandomState with a fixed seed: the same permutation every call,
    # independent of the global RNG)
    random_order = np.random.RandomState(seed=42).permutation(num_workers)
    each_worker_data = [each_worker_data[i] for i in random_order]
    each_worker_label = [each_worker_label[i] for i in random_order]
    
    # Per-worker sample counts (len = first dim of the stacked tensor),
    # normalised to sum to 1.
    wts = torch.zeros(len(each_worker_data)).to(device)
    for i in range(len(each_worker_data)):
        wts[i] = len(each_worker_data[i])
    wts = wts/torch.sum(wts)
    return each_worker_data, each_worker_label, wts
