import argparse
import copy
import os
import random

from torch_head import *
from common_head import *

import time

def FedAvg(params, weights=None):
    """Average per-client parameter lists element-wise (federated averaging).

    Args:
        params: sequence with one entry per client; ``params[k][j]`` is the
            j-th parameter of client k. Every client must list its parameters
            in the same order (``zip`` pairs them positionally and stops at
            the shortest client list).
        weights: optional per-client weights, one per entry of ``params``.
            ``None`` or an empty sequence means uniform weighting. Weights are
            normalised by their sum inside ``weighted_sum``.

    Returns:
        list whose j-th item is the weighted average of ``params[k][j]`` over
        all clients k.

    Raises:
        ValueError: if ``weights`` is given and its length differs from
            ``len(params)``.
    """
    # Explicit length test instead of `not weights`: truth-testing a numpy
    # array of weights raises "truth value is ambiguous".
    if weights is None or len(weights) == 0:
        weights = [1] * len(params)
    elif len(weights) != len(params):
        # A mismatch would otherwise be truncated silently by zip while the
        # normalisation still sums over *all* weights, giving wrong averages.
        raise ValueError(
            f"got {len(weights)} weights for {len(params)} clients")
    return [weighted_sum(layer, weights) for layer in zip(*params)]

def weighted_sum(inputs, weights):
    """Return ``sum_k inputs[k] * weights[k] / sum(weights)``.

    ``inputs`` may hold numbers or array-like objects supporting ``*`` with a
    float and ``+=``. Pairs are formed with ``zip``, so extra items on either
    side are ignored. Returns ``0`` when there is nothing to sum.
    """
    total = sum(weights)
    acc = 0
    for value, weight in zip(inputs, weights):
        share = weight / total
        acc += value * share
    return acc

def plot_scores_client(score_list, save_dir, client):
    """Plot one client's score history and save the plot and the raw data.

    Columns 0..3 of each row in ``score_list`` are read as epoch, Inception
    Score, IS standard deviation and FID. The left panel shows IS with a
    +/- one-std dashed band; the right panel shows FID.

    Writes ``score_client-{client}.png`` and ``score_client-{client}.npy``
    (the full array) into ``save_dir``.
    """
    scores = np.array(score_list)
    epochs, is_mean, is_std, fid = (scores[:, col] for col in range(4))

    fig, (is_ax, fid_ax) = plt.subplots(1, 2, sharex='all', figsize=(12, 4.8))
    for axis in (is_ax, fid_ax):
        axis.set_xlabel('Epoch')
    is_ax.set_ylabel('Inception Score')
    fid_ax.set_ylabel('FID')

    is_ax.plot(epochs, is_mean, 'r-', linewidth=3)
    for band in (is_mean - is_std, is_mean + is_std):
        is_ax.plot(epochs, band, 'r--', linewidth=1)
    fid_ax.plot(epochs, fid, 'r-', linewidth=3)

    plt.savefig(os.path.join(save_dir, f'score_client-{client}.png'))
    plt.close()
    np.save(os.path.join(save_dir, f'score_client-{client}.npy'), scores)


def print_fl_log_2(client, epoch, epoches, iteration, iters,
                   learning_rate, display, batch_time, data_time,
                   D_losses, G_losses, G_n_losses, G_ep_losses,
                   logger=None):
    """Emit a three-message progress report for one client.

    Args:
        client: client identifier shown in the header line.
        epoch, epoches: current epoch and total epochs.
        iteration, iters: current iteration and total iterations.
        learning_rate: value printed verbatim in the header.
        display: iteration count printed next to the timing totals
            (presumably iterations since the last report -- confirm at caller).
        batch_time, data_time: meters exposing ``.sum`` and ``.avg``.
        D_losses, G_losses, G_n_losses, G_ep_losses: meters exposing ``.val``
            and ``.avg``.
        logger: object with an ``info`` method; ``print`` is used if ``None``.
    """
    info = print if logger is None else logger.info
    info(f'client-{client} epoch: [{epoch}/{epoches}] '
         f'iteration: [{iteration}/{iters}]\tLearning rate: {learning_rate}')
    # `{0}` is `display`. Both timing averages use `.3f` (a bare `3f` is a
    # field width, not a precision, and printed six decimals).
    info('Time {batch_time.sum:.3f}s / {0} iters, ({batch_time.avg:.3f})\t'
         'Data load {data_time.sum:.3f}s / {0} iters, ({data_time.avg:.3f})\n'
         'Loss_D = {loss_D.val:.8f} (ave = {loss_D.avg:.8f})\n'
         'Loss_G = {loss_G.val:.8f} (ave = {loss_G.avg:.8f})\n'
         'Loss_GN = {loss_GN.val:.8f} (ave = {loss_GN.avg:.8f})\n'
         'Loss_GEP = {loss_GEP.val:.8f} (ave = {loss_GEP.avg:.8f})\n'.format(
             display, batch_time=batch_time,
             data_time=data_time, loss_D=D_losses, loss_G=G_losses,
             loss_GN=G_n_losses, loss_GEP=G_ep_losses))
    # Local timestamp followed by a separator rule.
    info(time.strftime('%Y-%m-%d %H:%M:%S ' + '-'*30 + '\n', time.localtime()))



def plot_fl_loss_my(client, d_loss, g_loss, g_loss_hist, g_loss_ep, num_epoch, epoches, save_dir):
    """Plot one client's per-epoch loss curves and save them as a PNG.

    Each loss sequence is plotted against epochs ``1..num_epoch``, so it must
    contain ``num_epoch`` values. The x-axis spans ``0..epoches + 1``.

    The image is written to
    ``save_dir/DCGAN_loss_client_{client}_epoch_{num_epoch}.png``.
    """
    curves = (
        (d_loss, dict(label='Discriminator', color='red', linewidth=3)),
        (g_loss, dict(label='Generator', color='mediumblue', linewidth=3, alpha=0.5)),
        (g_loss_hist, dict(label='Generator - (hist)', color='green', linewidth=3, alpha=0.5)),
        (g_loss_ep, dict(label='Generator - (ep)', color='gold', linewidth=3, alpha=0.5)),
    )
    # Y-limits must cover every plotted curve. The top margin uses abs() so
    # that a negative maximum is still raised rather than lowered.
    lo = min(np.min(series) for series, _ in curves)
    hi = max(np.max(series) for series, _ in curves)

    fig, ax = plt.subplots()
    ax.set_xlim(0, epoches + 1)
    ax.set_ylim(lo - 0.1, hi + 0.1 * abs(hi))
    ax.set_xlabel('Epoch {}'.format(num_epoch))
    ax.set_ylabel('Loss')

    epochs = list(range(1, num_epoch + 1))
    for series, style in curves:
        ax.plot(epochs, series, **style)
    ax.legend()
    fig.savefig(os.path.join(save_dir, f'DCGAN_loss_client_{client}_epoch_{num_epoch}.png'))
    plt.close(fig)


def plot_fl_result(client, G, fixed_noise, image_size, num_epoch, save_dir, fig_size=(10, 10), is_gray=False, n_side=10):
    """Draw a grid of generator samples for one client and save it as a PNG.

    Args:
        client: client id, used only in the output filename.
        G: generator called as ``G(fixed_noise)``. It is switched to eval mode
            for sampling (presumably a ``torch.nn.Module`` -- confirm).
        fixed_noise: input batch for ``G``. Only the first ``n_side**2``
            outputs are drawn, because ``zip`` stops at the number of axes.
        image_size: side length used to reshape grayscale samples.
        num_epoch: epoch number shown in the caption and filename.
        save_dir: output directory for
            ``DCGAN_client_{client}_epoch_{num_epoch}.png``.
        fig_size: NOTE(review): ignored -- the figure is always
            ``(n_side, n_side)`` inches. Kept for call compatibility.
        is_gray: if true, draw samples with a gray colormap after reshaping to
            ``(image_size, image_size)``. Otherwise treat them as
            channel-first images and min-max rescale them to 0..255.
        n_side: the grid has ``n_side`` rows and ``n_side`` columns.
    """
    import torch  # local import: needed only for no-grad sampling below

    # Restore the caller's mode afterwards, even if generation raises.
    was_training = getattr(G, 'training', True)
    G.eval()
    try:
        with torch.no_grad():  # sampling only; skip building an autograd graph
            generate_images = G(fixed_noise)
    finally:
        if was_training:
            G.train()

    n_rows = n_cols = n_side
    fig_size = (n_side, n_side)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=fig_size)

    for ax, img in zip(axes.flatten(), generate_images):
        ax.axis('off')
        if is_gray:
            img = img.cpu().view(image_size, image_size).numpy()
            ax.imshow(img, cmap='gray', aspect='equal')
        else:
            lo, hi = img.min(), img.max()
            # Clamp the range so a constant image does not divide by zero
            # (which would produce NaNs before the uint8 cast).
            span = (hi - lo).clamp(min=1e-12)
            img = (((img - lo) * 255) / span).cpu().numpy().transpose(1, 2, 0).astype(np.uint8)
            ax.imshow(img, cmap=None, aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)
    title = f'Epoch {num_epoch}'
    fig.text(0.5, 0.04, title, ha='center')

    plt.savefig(os.path.join(save_dir, f'DCGAN_client_{client}_epoch_{num_epoch}.png'))
    plt.close(fig)

def get_results(path, n):
    """Print per-metric means of the final-epoch scores of ``n`` clients.

    Loads ``score_client-{i}.npy`` for ``i`` in ``range(n)`` from ``path``
    (the files written by ``plot_scores_client``). Takes the last row of
    each file and averages column 1 (``is``), 2 (``is_std``) and 3 (``fid``)
    across clients, skipping NaN entries. A metric with no valid values
    prints ``nan``.
    """
    print(f"==== {path} ====")
    columns = {'is': 1, 'is_std': 2, 'fid': 3}
    client_scores = {key: [] for key in columns}
    for i in range(n):
        # Build each filename directly rather than str.format-ing a template
        # that contains the caller's path, which may contain braces.
        last_row = np.load(os.path.join(path, f'score_client-{i}.npy'))[-1]
        for key, col in columns.items():
            if not np.isnan(last_row[col]):
                client_scores[key].append(last_row[col])
    for key, values in client_scores.items():
        # np.mean([]) also yields nan, but with a RuntimeWarning.
        mean = np.mean(values) if values else float('nan')
        print(key, f"{mean:.4f}")

if __name__ == "__main__":
    # Summarise one run directory. Previously evaluated runs (path, n):
    #   dump/train_stl10_W_10_25_10_i/  10
    #   dump/train_stl10_W_2_25_10_i/   2
    #   dump/train_stl10_LS_10_25_10_i/ 10
    #   dump/train_stl10_LS_2_25_10_i/  2
    #   dump/train_stl10_W_10_25_10_f/  10
    # With no arguments the defaults reproduce the previous hard-coded call.
    parser = argparse.ArgumentParser(
        description="Print mean final-epoch IS / IS std / FID over clients.")
    parser.add_argument("path", nargs="?", default="dump/train_stl10_LS_10_25_10_f/",
                        help="directory containing score_client-<i>.npy files")
    parser.add_argument("n", nargs="?", type=int, default=10,
                        help="number of clients (files 0..n-1)")
    cli_args = parser.parse_args()
    get_results(cli_args.path, cli_args.n)