import logging
import math

import numpy as np
import torch
import torch.nn as nn

from args import *
from data.ExeclWrite import *
from models.resnet import resnet20

# Batch-norm configuration taken from the command-line args: weight decay is
# set to 5e-4 when BN is used and 1e-4 otherwise; bn_name labels the setting.
use_bn = args.usebn
if use_bn:
    args.wd = 5e-4
    bn_name = 'wBN'
else:
    args.wd = 1e-4
    bn_name = 'woBN'

logger = logging.getLogger()  # root logger (no name passed)



@torch.no_grad()
def eval(snn, eval_loader, device):
    """Return the top-1 accuracy (in percent) of ``snn`` on ``eval_loader``.

    The model is switched to spiking mode with
    ``snn.set_spike_state(use_spike=True)``, wrapped in ``nn.DataParallel``,
    moved to ``device`` and put in eval mode. ``.to`` and ``.eval`` act on the
    wrapped module, so ``snn`` stays on ``device`` in eval/spiking mode after
    the call.

    Args:
        snn: model exposing ``set_spike_state``; its output is assumed to be
            (batch, num_classes) scores -- TODO confirm.
        eval_loader: iterable of ``(inputs, targets)`` batches.
        device: device the model and inputs are moved to.

    Returns:
        Accuracy as a float in ``[0, 100]``; ``0.0`` if the loader yields no
        samples (instead of raising ZeroDivisionError).
    """
    snn.set_spike_state(use_spike=True)
    model = nn.DataParallel(snn)
    model.to(device)
    model.eval()

    correct = 0.
    total = 0.
    for inputs, targets in eval_loader:
        outputs = model(inputs.to(device))
        # Compare on the CPU so predictions and targets are always co-located.
        _, predicted = outputs.cpu().max(1)
        total += float(targets.size(0))
        correct += float(predicted.eq(targets.cpu()).sum().item())

    if total == 0:
        return 0.
    return 100. * correct / total


# The local ANN model has batch-norm layers.
def Train(model, device, train_loader, id, model_save_name, optimizer, scheduler, test_loader, global_round,
          ):
    """Train ``model`` locally, evaluate it on ``test_loader``, log and save it.

    Runs ``args.local_epochs`` epochs of cross-entropy training through an
    ``nn.DataParallel`` wrapper, stepping ``scheduler`` once per epoch and
    printing the epoch's mean batch loss and train accuracy. Afterwards the
    test accuracy is computed and printed, and
    ``write_excel_xls_append(args.result, str(id), global_round, test_acc,
    train_acc)`` is called (presumably appending a row to a results
    spreadsheet -- defined outside this file). Finally the unwrapped model's
    ``state_dict`` is saved to ``model_save_name``.

    Args:
        model: the local model; assumed to already live on ``device``.
        device: device that batches and the loss module are moved to.
        train_loader / test_loader: iterables of ``(inputs, labels)`` batches.
        id: client identifier written to the results file.
        model_save_name: path passed to ``torch.save``.
        optimizer, scheduler: optimizer over ``model``'s parameters and its
            LR scheduler.
        global_round: federated round number written to the results file.
    """
    ann = nn.DataParallel(model)
    criterion = nn.CrossEntropyLoss().to(device)

    # Train accuracy of the last epoch; defined even when local_epochs == 0 so
    # the results write below cannot hit a NameError.
    acc = 0.
    for epoch in range(args.local_epochs):
        ann.train()
        running_loss = 0.
        num_batches = 0
        total = 0.
        correct = 0.
        for images, labels in train_loader:
            optimizer.zero_grad()
            labels = labels.to(device)
            images = images.to(device)
            outputs = ann(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            _, predicted = outputs.max(1)
            total += float(labels.size(0))
            correct += float(predicted.eq(labels).sum().item())
        acc = 100. * correct / total if total else 0.
        # Report the mean loss over the epoch, not just the last batch's loss.
        mean_loss = running_loss / num_batches if num_batches else 0.
        print(' Epoch: {}/{} Loss: {:.6f},  Acc: {}/{} ({:.2f}%)'.format(
            epoch + 1, args.local_epochs, mean_loss,
            correct, total, acc))
        scheduler.step()

    ann.eval()
    correct = 0.
    total = 0.
    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = ann(inputs.to(device))
            # Compare on the CPU; targets never need to visit the device.
            _, predicted = outputs.cpu().max(1)
            total += float(targets.size(0))
            correct += float(predicted.eq(targets.cpu()).sum().item())
    acc1 = 100. * correct / total if total else 0.

    print('\n Test Accuracy of the model on the 10000 test images: %.3f' % acc1)
    write_excel_xls_append(args.result, str(id), global_round, acc1, acc)

    torch.save(ann.module.state_dict(), model_save_name)

def transfer_state_dict(Global_dict, model_dict, server_diff1, server_diff2):
    '''
    Build a state dict for the local model from server-side parameters.

    For every key ``k`` of ``model_dict`` (in its iteration order):

    * if ``k`` contains ``'fc.bias'``: take ``Global_dict['model.fc.bias']``;
    * else the value of the first key of ``server_diff2`` that contains ``k``
      as a substring (weights);
    * else the value of the first key of ``server_diff1`` that contains ``k``
      (BN parameters);
    * else keep ``model_dict[k]`` unchanged (e.g. biases).

    NOTE(review): matching is by substring, so e.g. ``'conv1.weight'`` also
    matches ``'model.layer1.0.conv1.weight'``; the first hit in iteration
    order wins.

    :param Global_dict: global model state dict (the global model has no BN);
        must contain ``'model.fc.bias'`` when ``model_dict`` has an fc.bias key.
    :param model_dict: local model state dict (the local model has BN).
    :param server_diff1: mapping searched second (BN parameters).
    :param server_diff2: mapping searched first (weights).
    :return: a new dict with exactly the keys of ``model_dict``.
    '''
    def first_match(key, candidates):
        # Returns (found, value) so that falsy/None values are still honoured.
        for cand_key, cand_value in candidates.items():
            if key in cand_key:
                return True, cand_value
        return False, None

    state_dict = {}
    for key, value in model_dict.items():
        if 'fc.bias' in key:
            # Store under the model's own key (previously always 'fc.bias',
            # which dropped/clobbered keys like 'head.fc.bias').
            state_dict[key] = Global_dict['model.fc.bias']
            continue
        found, match = first_match(key, server_diff2)
        if not found:
            found, match = first_match(key, server_diff1)
        state_dict[key] = match if found else value

    return state_dict


def transfer_snn_to_ann_model(Global_dict, Local_dict, model, server_diff1, server_diff2):
    """Load server-side parameters into ``model`` in place and return it.

    ``transfer_state_dict`` maps the keys of ``Local_dict`` to values drawn
    from ``Global_dict`` / ``server_diff2`` / ``server_diff1`` (falling back to
    ``Local_dict`` itself). Those entries are overlaid on ``model``'s current
    state dict, which is then loaded back with ``load_state_dict`` (strict by
    default).
    """
    own_state = model.state_dict()
    transferred = transfer_state_dict(Global_dict, Local_dict, server_diff1, server_diff2)
    for key, tensor in transferred.items():
        own_state[key] = tensor
    model.load_state_dict(own_state)
    return model


# Compute the norm of the difference between two models.
def model_norm(model_1, model_2):
    """Return the Euclidean (L2) distance between two sets of model parameters.

    Every parameter of ``model_1`` (``named_parameters``; buffers such as BN
    running stats are not included) is compared with ``model_2[name]``, so
    ``model_2`` must be a mapping such as a state dict containing those names.
    The tensors are assumed to be on the same device -- TODO confirm at call
    sites.
    """
    squared_distance = sum(
        (param.data - model_2[name]).pow(2).sum()
        for name, param in model_1.named_parameters()
    )
    return math.sqrt(squared_distance)


def process_grad(grads):
    '''
    Flatten a collection of gradient/parameter arrays into one 1-D numpy array.

    Args:
        grads: iterable of array-likes (numpy arrays, scalars, or anything
            numpy can convert); each is flattened in C order and the pieces
            are concatenated in iteration order.
    Return:
        a flattened grad in numpy (1-D array); an empty float64 array when
        ``grads`` is empty.
    '''
    flat_parts = [np.ravel(g) for g in grads]
    if not flat_parts:
        return np.array([])
    # One concatenate is linear in the total size; the previous repeated
    # np.append re-copied the accumulated array on every step (quadratic),
    # and returned a single input unflattened.
    return np.concatenate(flat_parts)


def get_stdev(parameters):
    """Return the standard deviation (``np.std``, population / ddof=0) of all
    values in ``parameters`` after flattening them with ``process_grad``."""
    return np.std(process_grad(parameters))
