#!/usr/bin/python
# -*- coding: utf-8 -*-
import re
import re
import os
import shutil
import copy
import datetime
import numpy as np
import torch
import scipy
import datetime
import sys
import time
import math
import random
import statistics

from typing import List
from collections import OrderedDict
from scipy.stats import beta as beta_distribution
from .logger import *

# Raise the interpreter's recursion limit process-wide, as a side effect of
# importing this module.
# NOTE(review): nothing visible in this module recurses; presumably code
# elsewhere relies on the deeper limit — confirm before changing or removing.
sys.setrecursionlimit(10000)


def get_specific_time():
    """Return the current local time as ``Y_M_D_HhMmSs``.

    Fields are not zero-padded, e.g. ``2024_1_5_3h4m5s``.
    """
    stamp = time.localtime()
    date_part = "_".join(str(field) for field in (stamp.tm_year, stamp.tm_mon, stamp.tm_mday))
    return f"{date_part}_{stamp.tm_hour}h{stamp.tm_min}m{stamp.tm_sec}s"


# Bracket/quote escape tokens and the literal characters they stand for.
REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}",
         "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'}

_REMAP_PATTERN = re.compile(r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''")


def clean(x):
    """Lower-case *x* and replace every token listed in ``REMAP`` with its character.

    Lower-casing happens first, so upper-case tokens such as ``-LRB-`` are
    also replaced.
    """
    def _restore(found):
        return REMAP.get(found.group())

    return _REMAP_PATTERN.sub(_restore, x.lower())


def check_and_make_the_path(path):
    """Ensure directory *path* exists, creating any missing parents.

    Uses ``exist_ok=True`` instead of a separate ``os.path.exists`` check.
    The old check-then-create sequence could fail with ``FileExistsError``
    when another process or thread created the same directory in between.
    If *path* already exists as a regular file, ``os.makedirs`` raises
    ``FileExistsError``.
    """
    os.makedirs(path, exist_ok=True)


# Compute the cosine similarity between a and b (both NumPy arrays).
def cos_sim(a, b):
    """Return the cosine similarity of vectors *a* and *b*.

    Computed as ``1 - scipy.spatial.distance.cosine(a, b)``: 1 for parallel
    vectors and 0 for orthogonal ones.
    """
    # A bare `import scipy` does not load the `spatial` subpackage on older
    # SciPy releases, so the original only worked if some other import had
    # already pulled it in. Import it explicitly.
    from scipy.spatial import distance

    return 1 - distance.cosine(a, b)


def eval_label(match_true, pred, true, total, match):
    """Compute accuracy, precision, recall and F1 from match counts.

    Args:
        match_true: numerator of precision and recall (tensor; converted with ``.float()``).
        pred: precision denominator (tensor).
        true: recall denominator (tensor).
        total: accuracy denominator (number or tensor).
        match: accuracy numerator (tensor).
        NOTE(review): these are presumably scalar count tensors — confirm against callers.

    Returns:
        ``(accu, precision, recall, F)`` as tensors. If any result is non-finite
        because a denominator was zero (including ``precision + recall == 0``),
        logs an error and returns ``(0.0, 0.0, 0.0, 0.0)`` instead.
    """
    match_true, pred, true, match = match_true.float(), pred.float(), true.float(), match.float()
    print("match_true:", match_true.data, " ;pred:", pred.data, " ;true:", true.data, " ;match:", match.data,
          " ;total:", total)
    accu = match / total
    precision = match_true / pred
    recall = match_true / true
    F = 2 * precision * recall / (precision + recall)
    # Tensor division by zero never raises ZeroDivisionError; it produces inf
    # or nan. The original try/except therefore never fired, so detect the
    # failure explicitly from the results.
    if not all(bool(torch.isfinite(metric).all()) for metric in (accu, precision, recall, F)):
        logger.error("[Error] float division by zero")
        return 0.0, 0.0, 0.0, 0.0
    return accu, precision, recall, F


def normalization(x):
    """Min-max normalize *x* into the closed interval [0, 1].

    Returns a new array; *x* is not modified. A constant input (max == min)
    would divide by zero and produce NaNs, so it is mapped to all zeros instead.
    """
    x_min = np.min(x)
    value_range = np.max(x) - x_min
    if value_range == 0:
        return np.zeros_like(x, dtype=float)
    return (x - x_min) / value_range


def get_parameters(model) -> List[np.ndarray]:
    """Return every ``state_dict`` entry of *model*, in order, as a NumPy array.

    For tensors already on the CPU, ``.cpu()`` returns the same tensor and
    ``.numpy()`` shares its storage, so the arrays alias the model's data.
    """
    state = model.state_dict()
    return [tensor.cpu().numpy() for tensor in state.values()]


def set_parameters(model, parameters: List[np.ndarray]):
    """Load *parameters* into *model*, pairing them with ``state_dict`` keys by position.

    Args:
        model: torch module to update in place.
        parameters: one array per ``state_dict`` entry, in ``state_dict`` order
            (e.g. the output of ``get_parameters``).

    Returns:
        The same *model*.

    Raises:
        RuntimeError: from ``load_state_dict(strict=True)`` when keys are missing
            or shapes do not match.
    """
    # torch.as_tensor keeps each array's dtype. torch.Tensor(v) always built
    # float32, which silently rounded integer buffers (e.g. BatchNorm's
    # num_batches_tracked) and float64 weights.
    state_dict = OrderedDict(
        (key, torch.as_tensor(value))
        for key, value in zip(model.state_dict().keys(), parameters)
    )
    model.load_state_dict(state_dict, strict=True)
    return model


def get_gradient(loss, model, create_graph=True):
    """Return the gradients of *loss* with respect to every parameter of *model*.

    Args:
        loss: tensor to differentiate. ``torch.autograd.grad`` without
            ``grad_outputs`` requires it to be a scalar.
        model: module whose ``parameters()`` are the differentiation inputs.
        create_graph: if True, build a graph of the derivative (enabling
            higher-order gradients). ``retain_graph`` defaults to this value,
            so with True the graph is also kept and a later ``loss.backward()``
            still works. With False the graph is freed after this call, and a
            later ``loss.backward()`` fails.

    Returns:
        Tuple of gradient tensors, one per parameter, in ``parameters()`` order.
    """
    # Must be passed by keyword: the third positional parameter of
    # torch.autograd.grad is `grad_outputs`, so passing `create_graph`
    # positionally fed a bool in as the output gradient and raised TypeError.
    return torch.autograd.grad(loss, model.parameters(), create_graph=create_graph)


def get_gradient_weighted_sum(weighted_list, gradient_list):
    """Scale each gradient by its matching weight.

    Args:
        weighted_list: weights, each convertible with ``float()``.
        gradient_list: gradients indexed in parallel with *weighted_list*;
            entries past ``len(weighted_list)`` are ignored.

    Returns:
        Tuple whose i-th element is ``float(weighted_list[i]) * gradient_list[i]``.
        NOTE(review): despite the name, the weighted terms are not summed —
        confirm what callers expect.

    Raises:
        IndexError: if *gradient_list* is shorter than *weighted_list*.
    """
    # The original assigned by index into an empty list (IndexError on any
    # non-empty input); build the tuple directly instead.
    return tuple(float(weight) * gradient_list[i] for i, weight in enumerate(weighted_list))


def get_the_number_of_model_parameters(model):
    """Return how many scalar elements *model*'s parameters hold in total."""
    total = 0
    for param in model.parameters():
        total += torch.numel(param)
    return total


def communication_cost_simulated_by_beta_distribution(client_number, alpha=0.3, beta=1):
    """Simulate one communication-cost value per client from a Beta density.

    The Beta(alpha, beta) pdf is evaluated at the ``client_number`` points
    ``0, step, 2*step, ...`` in [0, 1), where ``step = 1 / client_number``.
    Each density is then mapped: ``inf`` -> 16, ``nan`` -> 0.001, otherwise
    ``round(v, 4) + 1``.

    Args:
        client_number: number of clients; must be non-zero (0 raises ZeroDivisionError).
        alpha: first Beta shape parameter.
        beta: second Beta shape parameter.

    Returns:
        ``(y, descending_order_list)``: a NumPy array of ``client_number`` costs,
        and a permutation of ``range(client_number)`` shuffled in place with the
        global ``random`` module.
        NOTE(review): despite its name, the permutation is random, not sorted
        in descending order.
    """
    # np.arange with a float step sizes its output as ceil(1 / step). Because
    # 1 / client_number is rounded, that can come out one larger than
    # client_number. Trim so exactly one cost is produced per client,
    # matching the permutation length below.
    x = np.arange(0, 1, 1 / client_number)[:client_number]
    y = beta_distribution.pdf(x, alpha, beta)
    for index, density in enumerate(y):
        if math.isinf(density):
            y[index] = 16
        elif math.isnan(density):
            y[index] = 0.001
        else:
            y[index] = round(density, 4) + 1
    descending_order_list = list(range(client_number))
    random.shuffle(descending_order_list)
    return y, descending_order_list



def save_model(param_dict, updated_global_model, client_model_list, iter_t, optim):
    """Checkpoint the global model and every client model for one round.

    The global model's ``state_dict`` is saved together with *param_dict*
    (key ``'opt'``) and *optim* (key ``'optims'``) to
    ``<model_path>/step_<iter_t>_global_model.pt``. Each client model is saved
    whole to ``<model_path>/client_<k>/step_<iter_t>_model.pkl``, with ``k``
    starting at 1. Log messages report the round as ``iter_t + 1``, while the
    filenames use ``iter_t``.

    Args:
        param_dict: configuration; must contain ``'model_path'``.
        updated_global_model: module whose ``state_dict()`` is checkpointed.
        client_model_list: client models, saved in list order.
        iter_t: round index used in filenames.
        optim: stored in the global checkpoint under ``'optims'``.
    """
    logger.info("Communication Round %d Global Models Saving" % (iter_t + 1))
    checkpoint = {
        'model': updated_global_model.state_dict(),
        'opt': param_dict,
        'optims': optim,
    }
    model_dir = param_dict['model_path']
    check_and_make_the_path(model_dir)
    torch.save(checkpoint, os.path.join(model_dir, "step_%d_" % iter_t + "global_model.pt"))

    logger.info("Communication Round %d Client Models Saving" % (iter_t + 1))
    for client_number, client_model in enumerate(client_model_list, start=1):
        client_dir = os.path.join(param_dict['model_path'], "client_" + str(client_number))
        client_file = os.path.join(client_dir, "step_%d_" % iter_t + "model.pkl")
        check_and_make_the_path(client_dir)
        torch.save(client_model, client_file)


def get_HM_by_two_value(acc, FR):
    """Return the harmonic mean of *acc* and *FR* after coercing both to float.

    Delegates to ``statistics.harmonic_mean``: the result is 0 if either value
    is 0, and a negative value raises ``statistics.StatisticsError``.
    """
    pair = (float(acc), float(FR))
    return statistics.harmonic_mean(pair)
