import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import sys
import pdb
from copy import deepcopy

import aggregation
import attack
import nets
import utils
import client
import group 

import torch.multiprocessing as mp

import psutil
import os
import gc
import signal
import random

#Learning rate scheduler used to train CIFAR-10
def get_lr(epoch, num_epochs, max_lr):
    """Return the learning rate for 0-based ``epoch`` of a ``num_epochs`` run.

    The schedule is evaluated on the 1-based epoch ``t = epoch + 1``:

    * while ``t < num_epochs / 2``: exponential warm-up
      ``max_lr * (1 - exp(-25 * t / num_epochs))``;
    * afterwards: Gaussian decay centred at ``num_epochs / 2`` with standard
      deviation ``num_epochs / 4``, i.e. ``max_lr`` exactly at the centre.
    """
    t = epoch + 1
    centre = 2 * num_epochs / 4
    if t >= centre:
        width = num_epochs / 4
        return max_lr * np.exp(-0.5 * (((t - centre) / width) ** 2))
    return max_lr * (1 - np.exp(-25 * (t / num_epochs)))
           
def main(args):
    """Run one federated / decentralized learning simulation configured by ``args``.

    ``args.aggregation`` selects the scheme:

    * ``'hcl'``: clients are split into ``args.ngroups`` groups; one model per
      group, mixed through the ``n_groups x n_groups`` matrix ``W`` by
      ``aggregation.hcl``.
    * any name containing ``'p2p'`` (``'p2p'``, ``'p2prism'``, ``'p2p_trmean'``):
      one model per client, mixed through the ``n_clients x n_clients``
      matrix ``W``.
    * the ``'fl_*'`` names (``'fl_fedsgd'``, ``'fl_mean'``, ``'fl_mean_prism'``):
      a single global model updated with aggregated client gradients.

    Test accuracy is recorded on every ``args.eval_time``-th round (slot
    ``rnd // args.eval_time``); metric arrays are written to files prefixed by
    ``args.exp`` every ``args.save_time`` rounds.

    Configuration problems (an empty group, an even ``k`` for hcl, a client
    count that is not a power of two for the power-law graph, an unknown
    init or aggregation name) are reported with ``print`` and make the
    function return ``None`` early.
    """
    start_time = time.time()

    # NOTE(review): any args.gpu other than -1 selects the default CUDA device;
    # the numeric index itself is not used here.
    device = torch.device('cpu') if args.gpu == -1 else torch.device('cuda')

    n_clients = args.nclients
    n_groups = args.ngroups
    n_rounds = args.nrounds
    prob_select = args.prob_select
    n_iters = args.niters
    eval_time = args.eval_time
    k = args.k

    is_hcl = args.aggregation == 'hcl'
    is_p2p = 'p2p' in args.aggregation  # also matches 'p2prism', 'p2p_trmean'

    if is_hcl:
        group_ids = utils.cluster_clients(n_clients, n_groups)
        print("%d clients distributed into %d groups." % (n_clients, n_groups))
        if len(np.unique(group_ids)) < n_groups:
            print("Some group has 0 clients, restart and redistribute")
            return
        for gid in np.unique(group_ids):
            g_size = len(np.where(group_ids == gid)[0])
            print("Group %d has %d clients" % (gid, g_size))
        if args.is_mal is None:
            is_mal = utils.distribute_malicious(n_clients, args.nbyz, distr_type='group', n_groups=n_groups, group_ids=group_ids)
        else:
            is_mal = np.load(args.is_mal)
        print(np.where(is_mal))
        w_size = n_groups
        # [left, right] neighbour counts, same layout as the p2p case so the
        # k-regular construction below can index it (a bare int raised TypeError).
        k_size = [int((k - 1) / 2), int((k - 1) / 2)]
        if k % 2 == 0 and k < n_groups:
            print("Restart and enter an odd k")
            return
    else:
        # p2p and fl schemes: malicious clients are spread uniformly. The fl
        # schemes read is_mal as well (training loop, medians) and previously
        # crashed with NameError because it was never defined for them.
        if args.is_mal is None:
            is_mal = utils.distribute_malicious(n_clients, args.nbyz, distr_type='uniform')
        else:
            is_mal = np.load(args.is_mal)
        print(np.where(is_mal))

    if is_p2p:
        w_size = n_clients
        # k = k_size[0] + k_size[1] + 1 nodes in each row of W (including self).
        if k % 2 == 1:
            k_size = [int((k - 1) / 2), int((k - 1) / 2)]
        else:
            k_size = [int(k / 2), int(k / 2) - 1]
        rpt = torch.zeros((n_rounds, n_clients, n_clients)).to(device)

    if is_p2p or is_hcl:
        print("Simulating a mixing matrix")
        nbrs = {}  # node -> list of neighbour indices (filled for power-law only)

        if args.W is None:
            W = torch.zeros((w_size, w_size)).to(device)
            if args.graph_type == 'k-regular':
                # Ring lattice: node i mixes with k_size[0] nodes before it and
                # k_size[1] nodes after it, indices wrapping around. Rows sum to 1
                # (and so do columns, assuming k <= w_size).
                for i in range(w_size):
                    for j in range(-k_size[0], k_size[1] + 1):
                        idx = i + j
                        if idx < 0:
                            idx += w_size
                        if idx >= w_size:
                            idx -= w_size

                        if args.self_wt is None:
                            W[i][idx] = 1.0 / k
                        elif i == idx:
                            W[i][idx] = args.self_wt
                        else:
                            W[i][idx] = (1 - args.self_wt) / (k - 1)
            elif args.graph_type == 'power-law':
                # NOTE(review): this construction assumes w_size == n_clients (p2p).
                deg = np.zeros(n_clients)
                # math.log2 is exact for powers of two; log(n)/log(2) can be off
                # by an ulp and wrongly reject a valid client count.
                exponent = math.log2(args.nclients)
                if not exponent.is_integer():
                    print("Restart and enter the number of clients in an exponent of 2.")
                    return
                exponent = int(exponent)
                # Tier sizes 2^(e-1), ..., 2, 1, 1 -- they sum to n_clients.
                y = [2**i for i in range(exponent - 1, 0, -1)] + [1, 1]
                xmin = k if args.min_degree is None else args.min_degree
                # Previously referenced an undefined name ``max_degree``.
                xmax = n_clients if args.max_degree is None else args.max_degree
                # Tier degrees grow geometrically from xmin to xmax.
                ratio = (xmax / xmin) ** (1.0 / (len(y) - 1))
                x = [xmin]
                for i in range(len(y) - 2):
                    x.append(int(xmin * ratio**(i + 1)))
                x.append(xmax)
                idx = 0
                y_next_sum = y[0]
                for i in range(n_clients):
                    if i >= y_next_sum:
                        idx += 1
                        y_next_sum += y[idx]
                    deg[i] = x[idx]

                    # Node i plus deg[i]-1 distinct random peers, uniformly weighted.
                    sampled = [i] + random.sample([j for j in range(n_clients) if j != i], int(deg[i] - 1))
                    W[i][sampled] = 1.0 / len(sampled)
                    nbrs[i] = sampled[1:]
        else:
            # SECURITY: torch.load unpickles its input -- only load trusted files.
            W = torch.load(args.W).to(device)
            if args.graph_type == 'power-law':
                # Rebuild the neighbour lists from W's sparsity pattern (without
                # this, p2prism failed with NameError on ``nbrs``).
                # NOTE(review): neighbours come out sorted, not in sampling order.
                nbrs = {i: [j for j in torch.nonzero(W[i]).flatten().tolist() if j != i]
                        for i in range(W.shape[0])}

    sample_type = args.sample_type
    batch_size = args.batch_size
    lr = args.lr
    train_data, test_data, n_inputs, n_outputs = utils.load_data(args.dataset, args.batch_size)
    torch.manual_seed(0)
    net = utils.load_net(args.net, n_inputs, n_outputs, device)
    if args.load_net is not None:
        # SECURITY: torch.load unpickles its input -- only load trusted files.
        net.load_state_dict(torch.load(args.load_net))
    net_vec = utils.model_to_vec(net)

    # One model per group (hcl) or per client (p2p); named so it does not
    # shadow the imported ``nets`` module.
    models = {}
    if is_p2p or is_hcl:
        print("Initializing nets")
        for i in range(w_size):
            if args.init == 'diff':
                torch.manual_seed(i)
                node_net = utils.load_net(args.net, n_inputs, n_outputs, device)
            elif args.init == 'same':
                node_net = utils.vec_to_model(net_vec, args.net, n_inputs, n_outputs, device)
            else:
                print("Restart and enter same or diff initialization for the clients/groups in p2p/hcl")
                return
            models[i] = utils.vec_to_model(utils.model_to_vec(node_net), args.net, n_inputs, n_outputs, device)

    byz = utils.load_byz(args.attack)
    each_worker_data, each_worker_label, wts = utils.distribute_data_fang(device, args.batch_size, args.bias, train_data, n_clients, n_inputs, n_outputs, args.net)
    print("Data loaded and distributed among clients.")
    if args.dataset != 'shakespeare':
        ddist = utils.data_distribution(each_worker_label)
    criterion = nn.CrossEntropyLoss()
    P = utils.num_params(net)

    rr_idx = torch.zeros(n_clients).to(device)  # per-client data index for round-robin sampling

    n_evals = n_rounds // eval_time
    test_acc = np.zeros(n_evals)
    p = np.zeros(n_rounds)
    local_test_acc = np.zeros((n_evals, n_clients))
    cdist = np.zeros(n_rounds)
    direction = torch.zeros(P).to(device)
    aggregated_grads = torch.zeros((n_clients, P)).to(device)
    FS_record = {}
    medians = {}
    for cl in range(n_clients):
        # Only p2p has one model per client; hcl/fl overwrite or ignore this.
        if is_p2p:
            aggregated_grads[cl] = utils.model_to_vec(models[cl])
        if is_mal[cl] == 0:
            medians[cl] = 0

    def _sync_and_evaluate(rnd, node_wts, message, eval_data):
        """Load each client's mixed weights into ``models`` and record metrics.

        ``node_wts`` holds one flattened weight vector per client (row ``cl``).
        ``cdist[rnd]`` receives the mean L2 distance of the benign clients'
        weights from their average. On evaluation rounds every client model
        and the benign-average model are evaluated on ``eval_data``.
        """
        eval_now = rnd % eval_time == eval_time - 1
        for cl in range(n_clients):
            models[cl] = utils.vec_to_model(node_wts[cl], args.net, n_inputs, n_outputs, device)
            if eval_now:
                local_test_acc[rnd // eval_time, cl] = utils.update_model(message, models[cl], None, eval_data, device)
        benign_wts = node_wts[is_mal == 0]
        avg_wts = torch.mean(benign_wts, dim=0)
        avg_net = utils.vec_to_model(avg_wts, args.net, n_inputs, n_outputs, device)
        cdist[rnd] = torch.mean(torch.norm(benign_wts - avg_wts, dim=1)).item()
        if eval_now:
            print(rnd, local_test_acc[rnd // eval_time])
            test_acc[rnd // eval_time] = utils.update_model(message, avg_net, None, eval_data, device)

    for rnd in range(n_rounds):
        is_eval_round = rnd % eval_time == eval_time - 1
        slot = rnd // eval_time  # only a valid index on evaluation rounds

        if prob_select < 1:
            # Bernoulli(prob_select) participation. ``p`` must be passed by
            # keyword: the third positional parameter of np.random.choice is
            # ``replace``, which made participation a fair coin flip.
            client_participated = np.random.choice(2, n_clients, p=[1 - prob_select, prob_select])
        else:
            client_participated = np.ones(n_clients)

        client_grads = torch.zeros((n_clients, P)).to(device)
        if is_p2p:
            # Reference weights for the Shejwalkar attack: mean of the stored
            # client weights over malicious / benign / all clients.
            # NOTE(review): only plain 'p2p' refreshes aggregated_grads each
            # round; for p2prism / p2p_trmean this stays the initial models.
            if args.capabilities == 'mal':
                past_weights = torch.mean(aggregated_grads[is_mal == 1], 0)
            elif args.capabilities == 'ben':
                past_weights = torch.mean(aggregated_grads[is_mal == 0], 0)
            elif args.capabilities == 'all':
                past_weights = torch.mean(aggregated_grads, 0)

        if args.dataset == 'cifar10':
            lr = get_lr(rnd, n_rounds, args.lr)
        for cl in range(n_clients):
            if client_participated[cl] and is_mal[cl] == 0:
                if is_p2p:
                    msg = 'shakespeare' if args.dataset == 'shakespeare' else 'p2p'
                    client.local_comp(msg, cl, models[cl], each_worker_data[cl], each_worker_label[cl], sample_type, lr, device, criterion, batch_size, n_iters, client_grads, rr_idx)
                elif is_hcl:
                    client.local_comp('hcl', cl, models[group_ids[cl]], each_worker_data[cl], each_worker_label[cl], sample_type, lr, device, criterion, batch_size, n_iters, client_grads, rr_idx)
                elif 'fl' in args.aggregation:
                    client.local_comp('fl', cl, net, each_worker_data[cl], each_worker_label[cl], sample_type, lr, device, criterion, batch_size, n_iters, client_grads, rr_idx)
                else:
                    print("Aggregation technique not recognized, restart")
                    return

        if args.attack == 'shejwalkar':
            # NOTE(review): past_weights is only defined for p2p schemes; the
            # returned values are currently unused (see the p2prism call).
            model_re, lamda, deviation = attack.shejwalkar(device, past_weights, client_grads, client_participated, is_mal, dev_type='unit_vec', capabilities=args.capabilities, agr=args.aggregation)
        if args.aggregation == 'fl_fedsgd':
            # Weighted sum of client gradients with the per-client weights ``wts``.
            aggregated_grads = torch.matmul(torch.transpose(client_grads, 0, 1), wts.reshape(-1, 1))
        if args.aggregation == 'fl_mean':
            aggregated_grads = aggregation.fl_mean(client_grads)
        if args.aggregation == 'fl_mean_prism':
            aggregated_grads = aggregation.fl_mean_prism(client_grads, client_participated, direction, args.nbyz, device)
            direction = torch.sign(aggregated_grads)

        if is_hcl:
            group_participated = torch.zeros(n_groups).to(torch.int)
            p[rnd], aggregated_wts = aggregation.hcl(client_grads, group_ids, client_participated, group_participated, models, W, device)
            # update_model('weights_and_evaluate', ...) loads the new weights and
            # returns an accuracy, so it runs every round; the mean accuracy is
            # stored on evaluation rounds only (test_acc has n_evals slots).
            acc_sum = 0.0
            for gr in range(n_groups):
                acc = utils.update_model('weights_and_evaluate', models[gr], aggregated_wts[gr], test_data, device)
                acc_sum += acc
                print("Group %d has test accuracy %.2f" % (gr, acc))
            if is_eval_round:
                test_acc[slot] = acc_sum / n_groups

        if is_p2p:
            if args.dataset == 'shakespeare':
                message = 'shakespeare'
                # The held-out split is stored after the n_clients client shards.
                test_data = each_worker_data[n_clients], each_worker_label[n_clients]
            else:
                message = 'evaluate'

        if args.aggregation == 'p2p':
            # aggregation.p2p returns mixed model weights despite the variable name.
            p[rnd], aggregated_grads = aggregation.p2p(W, client_grads, device)
            _sync_and_evaluate(rnd, aggregated_grads, message, test_data)

        if args.aggregation == 'p2prism':
            # The attack-aware variant passes True, model_re, lamda, deviation
            # instead of False, None, None, None.
            rpt[rnd], FS, aggregated_wts, medians = aggregation.p2prism(medians, W, models, client_grads, k_size, nbrs, args.graph_type, is_mal, device, False, None, None, None)
            FS_record[rnd] = FS
            _sync_and_evaluate(rnd, aggregated_wts, message, test_data)

        if args.aggregation == 'p2p_trmean':
            aggregated_wts = aggregation.p2p_trmean(W, client_grads, models, args.nbyz, is_mal, device)
            _sync_and_evaluate(rnd, aggregated_wts, message, test_data)

        if 'fl_' in args.aggregation:
            # Applies the aggregated gradient every round; accuracy is stored on
            # evaluation rounds only (test_acc has n_evals slots).
            acc = utils.update_model('update_with_gradients_and_evaluate', net, aggregated_grads, test_data, device)
            if is_eval_round:
                test_acc[slot] = acc

        # Report on the same rounds the metrics above are written; cdist is
        # indexed per round, the accuracy arrays per evaluation slot.
        if is_eval_round:
            print("Rnd %d | Acc -> [%.2f, %.2f] : %.2f | Cons dist %f" % (rnd, min(local_test_acc[slot]), max(local_test_acc[slot]), test_acc[slot], cdist[rnd]))

        if rnd % args.save_time == args.save_time - 1:
            print("Saving")
            if is_hcl:
                print(group_ids)
                np.save(args.exp + "_group_ids.npy", group_ids)
                torch.save(W, args.exp + "_W.pt")
                np.save(args.exp + "_p.npy", p)
            np.save(args.exp + "_test_acc.npy", test_acc)
            np.save(args.exp + "_local_acc.npy", local_test_acc)
            np.save(args.exp + "_cdist.npy", cdist)
            if args.dataset != 'shakespeare':
                np.save(args.exp + "_ddist.npy", ddist)
            if args.aggregation == 'p2prism':
                torch.save(rpt, args.exp + "_rpt.pt")
                np.save(args.exp + "_FS_dict.npy", FS_record)
            if is_p2p:
                torch.save(W, args.exp + "_W.pt")
            if args.attack != 'benign':
                np.save(args.exp + "_is_mal.npy", is_mal)
    print("Total time taken = ", time.time() - start_time)

if __name__ == "__main__":
    # Parse the command-line options (via utils.parse_args) and run one simulation.
    main(utils.parse_args())
