import torch
import random
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset
import torchvision
import numpy as np
from copy import deepcopy
import argparse
from torchvision import transforms
from optimizer.fnnow import FNNOW
from optimizer.fagh import FAGH
from optimizer.fedsophia import FedSophia
from optimizer.scaffold import SCAFFOLD
from optimizer.fedavg import FedAVG
from model import MLP, CNN, ResNet18


def getDirichletData(data, targets, psizes, alpha):
    """Divide ``data`` among ``psizes`` clients using per-class Dirichlet shares.

    For each class ``k`` in ``range(K)`` (``K`` is the number of distinct
    values in ``targets``) the indices of that class are shuffled and cut
    into ``psizes`` chunks whose relative sizes are drawn from
    ``Dirichlet(alpha, ..., alpha)``. A client that already holds at least
    ``N / psizes`` samples receives a zero share for the current class.
    The whole assignment is redrawn until every client holds at least ``K``
    samples. All randomness comes from NumPy's global RNG.

    Args:
        data: dataset that the returned ``Subset`` objects wrap.
        targets: label tensor, compared against ``range(K)`` and indexed with
            lists of sample indices (assumes labels are exactly ``0 .. K-1``
            -- TODO confirm for new datasets).
        psizes: number of clients.
        alpha: Dirichlet concentration parameter.

    Returns:
        ``(partitions, weights)`` -- one ``Subset`` per client and a NumPy
        array holding each client's fraction of the assigned samples.
    """
    num_clients = psizes
    labels = targets
    num_classes = len(torch.unique(labels))
    num_samples = len(labels)

    smallest = 0
    while smallest < num_classes:
        per_client = [[] for _ in range(num_clients)]
        for cls in range(num_classes):
            cls_indices = np.where(labels == cls)[0]
            np.random.shuffle(cls_indices)
            shares = np.random.dirichlet(np.repeat(alpha, num_clients))
            # Balance: clients that reached their fair share get nothing more.
            has_room = np.array(
                [len(held) < num_samples / num_clients for held in per_client]
            )
            shares = shares * has_room
            shares = shares / shares.sum()
            cut_points = (np.cumsum(shares) * len(cls_indices)).astype(int)[:-1]
            chunks = np.split(cls_indices, cut_points)
            per_client = [
                held + chunk.tolist() for held, chunk in zip(per_client, chunks)
            ]
        smallest = min(len(held) for held in per_client)

    client_index_map = {}
    for client in range(num_clients):
        np.random.shuffle(per_client[client])
        client_index_map[client] = per_client[client]

    # Per-client label histogram, printed for inspection only.
    class_counts = {}
    for client, held in client_index_map.items():
        values, counts = np.unique(labels[held], return_counts=True)
        class_counts[client] = dict(zip(values, counts))
    print('Data statistics: %s' % str(class_counts))

    sizes = np.array([len(client_index_map[c]) for c in range(num_clients)])
    weights = sizes / np.sum(sizes)
    print(weights)

    partitions = [Subset(data, per_client[c]) for c in range(num_clients)]
    return partitions, weights

def data_tf(x):
    """Turn one image into a normalised ``(1, 28, 28)`` float32 tensor.

    The input is converted to a float32 array, divided by 255, mapped with
    ``(v - 0.5) / 0.5`` and reshaped, so it must hold exactly 784 values.
    """
    pixels = np.array(x, dtype='float32')
    scaled = pixels / 255
    centred = (scaled - 0.5) / 0.5
    # Flat (784,) layout would be: centred.reshape((-1,))
    return torch.tensor(centred.reshape((1, 28, 28)))

def dataload(dataset, size, isNonIID = False, alpha = 1):
    """Load a torchvision dataset and split its training set among clients.

    Args:
        dataset: one of ``'mnist'``, ``'fmnist'`` or ``'cifar10'``.
        size: number of clients; must be at least 1.
        isNonIID: when true, split with ``getDirichletData``; otherwise
            shuffle the indices with ``random.shuffle`` and hand every client
            an equally sized slice (the ``len(trainset) % size`` leftover
            samples are not assigned to anyone).
        alpha: Dirichlet concentration, only used when ``isNonIID`` is true.

    Returns:
        ``(partitions, ratio, testset, classnum)`` -- the per-client
        ``Subset`` list, the aggregation weight of every client, the test
        dataset and the number of classes.

    Raises:
        ValueError: if ``dataset`` is not a supported name or ``size < 1``.
    """
    # Fail early with a clear message; previously an unknown name surfaced
    # later as an UnboundLocalError and size == 0 as a ZeroDivisionError.
    if dataset not in ('mnist', 'fmnist', 'cifar10'):
        raise ValueError(
            "unknown dataset %r (expected 'mnist', 'fmnist' or 'cifar10')" % (dataset,)
        )
    if size < 1:
        raise ValueError("size must be at least 1, got %r" % (size,))

    # Datasets are expected to be on disk already (download=False).
    if dataset == 'mnist':
        trainset = torchvision.datasets.MNIST(root='./../../data', train=True, download=False, transform=data_tf)
        testset = torchvision.datasets.MNIST(root='./../../data', train=False, download=False, transform=data_tf)
        targets, classnum = trainset.targets, 10
    elif dataset == 'fmnist':
        trainset = torchvision.datasets.FashionMNIST(root='./../data', train=True, download=False, transform=data_tf)
        testset = torchvision.datasets.FashionMNIST(root='./../data', train=False, download=False, transform=data_tf)
        targets, classnum = trainset.targets, 10
    else:  # 'cifar10'
        transform = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        trainset = torchvision.datasets.CIFAR10(root='./../data', train=True, download=False, transform=transform)
        testset = torchvision.datasets.CIFAR10(root='./../data', train=False, download=False, transform=transform)
        targets, classnum = trainset.targets, 10

    if isNonIID:
        # as_tensor accepts both a list and an existing tensor without the
        # copy-construct warning torch.tensor() emits for the latter.
        partitions, ratio = getDirichletData(trainset, torch.as_tensor(targets), size, alpha)
    else:
        datalen = len(trainset)
        indexes = list(range(datalen))
        random.shuffle(indexes)
        partlen = datalen // size
        partitions = [
            Subset(trainset, indexes[i * partlen:(i + 1) * partlen])
            for i in range(size)
        ]
        ratio = [1 / size] * size
    return partitions, ratio, testset, classnum


def fnnow_aggregate(Hg_sum_list, ratio, model_g, old_Hg_dict, lr, b):
    """Apply the ratio-weighted client updates to ``model_g`` with momentum.

    Args:
        Hg_sum_list: one dict per client mapping parameter name to its update.
        ratio: per-client weights, indexed by client position.
        model_g: global model, updated in place.
        old_Hg_dict: previous smoothed update per parameter name; mutated so
            it holds the newly smoothed update afterwards.
        lr: step size of the global update.
        b: momentum factor; the applied step is ``b * old + (1 - b) * new``.
    """
    weighted = {}
    for client_idx, client_hg in enumerate(Hg_sum_list):
        for key, tensor in client_hg.items():
            if key not in weighted:
                weighted[key] = torch.zeros_like(tensor)
            weighted[key] += ratio[client_idx] * tensor

    for key, weight in model_g.named_parameters():
        if weight.requires_grad:
            # First time a parameter is seen its own aggregate acts as history.
            if key not in old_Hg_dict:
                old_Hg_dict[key] = weighted[key].clone()
            step = b * old_Hg_dict[key] + (1 - b) * weighted[key]
            old_Hg_dict[key] = step.clone()
            weight.data -= lr * step

def fagh_aggregate(v_sum_list, g_sum_list, ratio, model_g, optimizer, lr, b1, b2):
    """Aggregate client ``v``/``g`` statistics and step the global model.

    Args:
        v_sum_list: per-client dicts of ``v`` tensors keyed by parameter name.
        g_sum_list: per-client dicts of ``g`` tensors with the same keys.
        ratio: per-client weights, indexed by client position.
        model_g: global model, updated in place.
        optimizer: object exposing ``old_v_dict``, ``old_g_dict`` (both are
            mutated here) and ``sherman(v, g)``, whose result is the step.
        lr: step size of the global update.
        b1: smoothing factor applied to ``v``.
        b2: smoothing factor applied to ``g``.
    """
    history_v = optimizer.old_v_dict
    history_g = optimizer.old_g_dict

    mean_v, mean_g = {}, {}
    for client_idx, (client_v, client_g) in enumerate(zip(v_sum_list, g_sum_list)):
        for key in client_v:
            if key not in mean_v:
                mean_v[key] = torch.zeros_like(client_v[key])
                mean_g[key] = torch.zeros_like(client_g[key])
            mean_v[key] += ratio[client_idx] * client_v[key]
            mean_g[key] += ratio[client_idx] * client_g[key]

    for key, weight in model_g.named_parameters():
        if weight.requires_grad:
            # Without history the current aggregate seeds both running values.
            if key not in history_v:
                history_v[key] = mean_v[key].clone()
                history_g[key] = mean_g[key].clone()
            smooth_v = b1 * history_v[key] + (1 - b1) * mean_v[key]
            smooth_g = b2 * history_g[key] + (1 - b2) * mean_g[key]
            history_v[key] = smooth_v.clone()
            history_g[key] = smooth_g.clone()
            weight.data -= lr * optimizer.sherman(smooth_v, smooth_g)

def fedsophia_aggregate(Hg_list, ratio, model_g, lr):
    """Subtract the ratio-weighted sum of client updates from ``model_g``.

    Args:
        Hg_list: one dict per client mapping parameter name to its update;
            a client may omit a parameter, which then contributes nothing.
        ratio: per-client weights, indexed by client position.
        model_g: global model, updated in place.
        lr: step size of the global update.
    """
    for key, weight in model_g.named_parameters():
        if not weight.requires_grad:
            continue
        combined = torch.zeros_like(weight)
        for client_idx, client_hg in enumerate(Hg_list):
            if key in client_hg:
                combined += ratio[client_idx] * client_hg[key]
        weight.data -= lr * combined

def scaffold_aggregate(xdel_list, cdel_list, ratios, model_g, c_global, lr=1):
    """Add the weighted client deltas to the global model and control variate.

    Args:
        xdel_list: per-client dicts of model deltas keyed by parameter name.
        cdel_list: per-client dicts of control-variate deltas, same keys.
        ratios: per-client weights, indexed by client position.
        model_g: global model; each trainable parameter gains
            ``lr * sum(ratios[i] * xdel_list[i][name])``.
        c_global: dict of tensors keyed by parameter name, incremented in
            place by the weighted sum of ``cdel_list`` entries.
        lr: global step size applied to the model delta only.
    """
    trainable = (
        (key, weight)
        for key, weight in model_g.named_parameters()
        if weight.requires_grad
    )
    for key, weight in trainable:
        model_delta = torch.zeros_like(weight)
        control_delta = torch.zeros_like(weight)
        for client_idx, client_x in enumerate(xdel_list):
            model_delta += ratios[client_idx] * client_x[key]
            control_delta += ratios[client_idx] * cdel_list[client_idx][key]
        weight.data += lr * model_delta
        c_global[key] += control_delta

def fedavg_aggregate(g_sum_list, ratio, model_g, lr, beta):
    """Step ``model_g`` against the ratio-weighted average of client updates.

    Args:
        g_sum_list: one dict per client mapping parameter name to its update.
        ratio: per-client weights, indexed by client position.
        model_g: global model, updated in place.
        lr: step size of the global update.
        beta: accepted for signature compatibility; not used here.
    """
    averaged = {}
    for client_idx, client_g in enumerate(g_sum_list):
        for key, tensor in client_g.items():
            if key not in averaged:
                averaged[key] = torch.zeros_like(tensor)
            averaged[key] += ratio[client_idx] * tensor

    for key, weight in model_g.named_parameters():
        if weight.requires_grad:
            weight.data -= lr * averaged[key]
    

def test(testloader, model, rank, rnd):
    """Run ``model`` over ``testloader`` and print its top-1 accuracy.

    Args:
        testloader: iterable yielding ``(images, labels)`` batches.
        model: network to evaluate; batches are moved to the device of its
            first parameter before the forward pass.
        rank: identifier shown in the printed line.
        rnd: communication round shown in the printed line.

    Prints one line with the accuracy in percent (``nan`` when the loader
    yields no samples) and returns nothing.
    """
    # Follow the model's own device rather than a module-level ``device``
    # global, which only exists when this file is executed as a script.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    # NOTE(review): the model's train/eval mode is left exactly as the caller
    # set it (no model.eval() here) -- confirm that is intended for models
    # with BatchNorm or Dropout layers.
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            if target_device is not None:
                images = images.to(target_device)
                labels = labels.to(target_device)
            # calculate outputs by running images through the network
            outputs = model(images)
            # the class with the highest score is taken as the prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty loader used to raise ZeroDivisionError; report nan instead.
    accuracy = 100 * correct / total if total else float('nan')
    print(f'round:{rnd} Accuracy of the client:{rank} on the 10000 test images: {accuracy:.2f} %')

def set_seed(seed):
    """Seed every RNG this script draws from and make cuDNN deterministic.

    Besides torch (CPU and all CUDA devices) this also seeds Python's
    ``random`` module and NumPy's global RNG, because the data partitioning
    in this file uses ``random.shuffle`` and ``np.random``; seeding torch
    alone left the client splits different on every run.

    Args:
        seed: integer seed.
    """
    random.seed(seed)
    # np.random.seed only accepts values in [0, 2**32 - 1].
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    # Script entry point: simulate ``--size`` clients training locally for
    # ``--round`` communication rounds, aggregate their updates into the
    # global model with the method chosen by ``--opt``, and print the global
    # model's test accuracy after every round.

    # ``device`` stays a module-level name; other code in this file may look
    # it up as a global.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.empty_cache()

    # Available architectures, keyed by the value accepted for --model.
    model_dict = {
        'MLP': MLP,
        'CNN': CNN,
        'ResNet18': ResNet18,
    }
    dataset_names = ['mnist', 'fmnist', 'cifar10']
    optimizer_names = ['fnnow', 'fagh', 'fedsophia', 'scaffold', 'fedavg']

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--bs', type=int, default=64)
    parser.add_argument('--eps', type=float, default=0.01)
    parser.add_argument('--sampler', type=float, default=0.0001)
    parser.add_argument('--epoch', type=int, default=1)
    parser.add_argument('--nonIID', action='store_true')
    parser.add_argument('--alpha', type=float, default=1)
    parser.add_argument('--size', type=int, default=10)
    parser.add_argument('--round', type=int, default=10)
    # ``choices`` rejects typos up front; an unknown --opt previously ran all
    # rounds without training or aggregating anything.
    parser.add_argument('--dataset', type=str, default='mnist', choices=dataset_names)
    parser.add_argument('--opt', type=str, default='fedavg', choices=optimizer_names)
    parser.add_argument('--model', type=str, default='MLP', choices=list(model_dict))
    args = parser.parse_args()

    # parameters
    set_seed(args.seed)
    size = args.size
    lr = args.lr
    eps = args.eps
    sampler = args.sampler
    epochs = args.epoch
    batch_size = args.bs
    dataset = args.dataset
    opt = args.opt
    clp = 1e-4
    b1 = 0.9
    b2 = 0.95
    b = 0  # momentum

    # init: one global model plus an identical copy per client
    model_g = model_dict[args.model]().to(device)
    models = [deepcopy(model_g) for _ in range(size)]

    # data split
    partitions, ratio, testset, classnum = dataload(dataset, size, isNonIID=args.nonIID, alpha=args.alpha)
    # NOTE(review): drop_last=True discards the final partial test batch, so
    # accuracy covers fewer samples than the test set holds whenever its
    # length is not a multiple of --bs; shuffle=True is not needed for
    # evaluation. Left unchanged -- confirm before altering.
    testdata = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=10)
    criterion = nn.CrossEntropyLoss()

    # Smoothed update carried across rounds by fnnow_aggregate.
    old_Hg_dict = {}

    for rnd in range(args.round):
        # Per-round collections of what each client's local training returns.
        Hg_sum_list = []
        v_sum_list = []
        g_sum_list = []

        xdel_sum = []
        cdel_sum = []

        # NOTE(review): a fresh optimizer object is built at rank 0 of every
        # round and shared by all clients of that round, so anything stored
        # on it (e.g. optimizer.old_v_dict / optimizer.c_global, which the
        # aggregation step updates) starts over each round unless the class
        # keeps it elsewhere -- confirm this is intended.
        for rank in range(size):
            traindata = torch.utils.data.DataLoader(partitions[rank], batch_size=batch_size, shuffle=True, drop_last=True, num_workers=10)
            if opt == 'fnnow':
                if rank == 0:
                    optimizer = FNNOW(lr=lr, eps=eps, sampler=sampler, clp=clp, beta=b, device=device)
                Hg_dict = optimizer.train(models[rank], traindata, criterion, epochs)
                Hg_sum_list.append(Hg_dict)
            elif opt == 'fagh':
                if rank == 0:
                    optimizer = FAGH(lr=lr, eps=eps, beta1=b1, beta2=b2, device=device)
                v_dict, g_dict = optimizer.train(models[rank], traindata, criterion, epochs)
                v_sum_list.append(v_dict)
                g_sum_list.append(g_dict)
            elif opt == 'fedsophia':
                if rank == 0:
                    optimizer = FedSophia(lr=lr, eps=eps, beta1=b1, beta2=b2, device=device)
                Hg_dict = optimizer.train(models[rank], traindata, criterion, epochs, rnd, batch_size)
                Hg_sum_list.append(Hg_dict)
            elif opt == 'scaffold':
                if rank == 0:
                    optimizer = SCAFFOLD(eps=eps, lr=lr, device=device)
                    optimizer.init_global(model_g)
                x_del, c_del = optimizer.train(rank, models[rank], traindata, criterion, epochs)
                xdel_sum.append(x_del)
                cdel_sum.append(c_del)
            elif opt == 'fedavg':
                if rank == 0:
                    optimizer = FedAVG(eps, lr=lr, beta=b, device=device)
                g_dict = optimizer.train(models[rank], traindata, criterion, epochs)
                g_sum_list.append(g_dict)

        # aggregation: fold the client results into the global model
        if opt == 'fnnow':
            fnnow_aggregate(Hg_sum_list, ratio, model_g, old_Hg_dict, lr, b)
        elif opt == 'fagh':
            fagh_aggregate(v_sum_list, g_sum_list, ratio, model_g, optimizer, lr, b1, b2)
        elif opt == 'fedsophia':
            fedsophia_aggregate(Hg_sum_list, ratio, model_g, lr)
        elif opt == 'scaffold':
            scaffold_aggregate(xdel_sum, cdel_sum, ratio, model_g, optimizer.c_global, lr=1)
        elif opt == 'fedavg':
            fedavg_aggregate(g_sum_list, ratio, model_g, lr, b)

        # update: broadcast the new global state to every client model
        for rank in range(size):
            models[rank].load_state_dict(model_g.state_dict())

        # test (``size`` is passed in the rank slot, so the printed client id
        # is the number of clients)
        test(testdata, model_g, size, rnd)