import numpy as np
import pandas as pd
from torch.nn import functional as F
import torch
import torch.optim as optim
import torch.nn as nn
from scipy.spatial.distance import cdist

def to_onehot(label, num_class):
    """Convert a 1-D tensor of class indices into one-hot rows.

    Args:
        label: 1-D integer index tensor (as required by
            ``torch.index_select``); entry ``i`` selects the class of row ``i``.
        num_class: Total number of classes, i.e. the width of each row.

    Returns:
        A float tensor of shape ``(len(label), num_class)`` located on the
        same device as ``label``.
    """
    # Build the identity on the label's own device instead of forcing CUDA,
    # so the helper works on CPU-only hosts and on any GPU index.
    identity = torch.eye(num_class, device=label.device)
    # Row ``k`` of the identity matrix is the one-hot vector for class ``k``.
    return torch.index_select(identity, 0, label)

def evaluation(loader, model):
    """Compute the top-1 accuracy of ``model`` over every batch in ``loader``.

    Args:
        loader: Iterable of batches; ``batch[0]`` is the input tensor and
            ``batch[1]`` the label tensor.
        model: Callable returning a 4-tuple whose first element holds the
            per-class scores (the remaining three elements are ignored here).

    Returns:
        Fraction of samples whose highest-scoring class equals the label.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    # Feed inputs to the device the model lives on; when that cannot be
    # determined, fall back to CUDA if present (the historical behaviour).
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    score_chunks = []
    label_chunks = []
    with torch.no_grad():
        # Iterate the loader directly: iterator objects no longer provide the
        # Python 2 style ``.next()`` method.
        for data in loader:
            inputs = data[0].to(device)
            labels = data[1]
            outputs, _feats, _, _ = model(inputs)
            score_chunks.append(outputs.float().cpu())
            label_chunks.append(labels.float().cpu())

    if not score_chunks:
        raise ValueError("evaluation() received an empty loader")

    # Concatenate once at the end rather than re-copying on every batch.
    all_output = torch.cat(score_chunks, 0)
    all_label = torch.cat(label_chunks, 0)

    _, predict = torch.max(all_output, 1)
    correct = torch.sum(torch.squeeze(predict).float() == all_label).item()
    accuracy = correct / float(all_label.size()[0])
    return accuracy

def lr_schedule(optimizer, iter_num, max_iter, gamma=10, power=0.75,
                *, weight_decay=1e-3, momentum=0.9, nesterov=True):
    """Apply a polynomial learning-rate decay to every parameter group.

    Each group's ``lr`` is set to ``lr0 * (1 + gamma * iter_num / max_iter)
    ** (-power)``, where ``lr0`` is the base rate already stored in the group.
    The group's ``weight_decay``, ``momentum`` and ``nesterov`` entries are
    overwritten on every call.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts that each
            contain an ``'lr0'`` entry.
        iter_num: Current iteration index.
        max_iter: Total number of iterations; must be non-zero.
        gamma: Multiplier applied to the training progress ratio.
        power: Exponent of the decay.
        weight_decay: Value written to each group's ``'weight_decay'``.
        momentum: Value written to each group's ``'momentum'``.
        nesterov: Value written to each group's ``'nesterov'``.

    Returns:
        The same ``optimizer`` object, modified in place.

    Raises:
        KeyError: If a parameter group has no ``'lr0'`` entry.
        ZeroDivisionError: If ``max_iter`` is zero.
    """
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        # Always scale from the stored base rate so repeated calls do not
        # compound the decay.
        param_group['lr'] = param_group['lr0'] * decay
        param_group['weight_decay'] = weight_decay
        param_group['momentum'] = momentum
        param_group['nesterov'] = nesterov
    return optimizer

def client_optimizer(modelI, modelS, modelD, lr):
    """Build one SGD optimizer covering the parameters of three models.

    Every parameter of ``modelI``, ``modelS`` and ``modelD`` (in that order)
    is placed in its own parameter group with learning rate ``lr``. The
    optimizer is then passed through ``op_copy`` before being returned.

    Args:
        modelI: First model; must provide ``named_parameters()``.
        modelS: Second model; must provide ``named_parameters()``.
        modelD: Third model; must provide ``named_parameters()``.
        lr: Learning rate assigned to every parameter group.

    Returns:
        The ``optim.SGD`` instance returned by ``op_copy``.
    """
    groups = [
        {'params': tensor, 'lr': lr}
        for net in (modelI, modelS, modelD)
        for _name, tensor in net.named_parameters()
    ]
    return op_copy(optim.SGD(groups))

def server_optimizer(model, lr):
    """Build an SGD optimizer with one parameter group per model parameter.

    Args:
        model: Model providing ``named_parameters()``.
        lr: Learning rate assigned to every parameter group.

    Returns:
        The ``optim.SGD`` instance after it has been passed through
        ``op_copy``.
    """
    groups = [
        {'params': tensor, 'lr': lr}
        for _name, tensor in model.named_parameters()
    ]
    sgd = optim.SGD(groups)
    return op_copy(sgd)

def op_copy(optimizer):
    """Store each parameter group's current ``'lr'`` under the key ``'lr0'``.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts that each
            contain an ``'lr'`` entry.

    Returns:
        The same ``optimizer`` object, modified in place.
    """
    for group in optimizer.param_groups:
        group.update(lr0=group['lr'])
    return optimizer