# from domainbed
import sys
sys.path.append('..')
import csv
import statistics
import json
import logging
from utils.eval import accuracies_losses, mean_std
import torch
from utils.aggregate import norm2_model, zero_model, sub_models
import math
import numpy as np
import matplotlib.pyplot as plt
import os, pickle

def show(image):    # with color
    """Display ``image`` with matplotlib after moving its first axis to the end.

    The axes are permuted with ``(1, 2, 0)``, so a 3-D input laid out as
    (A, B, C) is handed to ``plt.imshow`` as (B, C, A).
    NOTE(review): presumably the input is channels-first (C, H, W) -- confirm
    against callers.
    """
    reordered = np.transpose(image, axes=(1, 2, 0))
    plt.imshow(reordered)

def save(datasets, dir_, name):
    """Pickle ``datasets`` to ``<dir_>/<name>.pickle``.

    The directory (including missing parents) is created when it does not
    exist yet; an existing file with the same name is overwritten.
    """
    os.makedirs(dir_, exist_ok=True)
    filename = f'{name}.pickle'
    target_path = os.path.join(dir_, filename)
    with open(target_path, 'wb') as handle:
        pickle.dump(datasets, handle)
    
def log_global_epoch(alg, args, train_loaders, test_loaders, loss_func, device, train_writer,
                     test_writer, f, t):
    """Evaluate and log metrics for global round ``t`` on logging rounds only.

    Nothing happens unless ``t`` is a multiple of ``args.log_interval``.
    On a logging round this function:

    * computes train and test metrics for ``alg.models`` via
      ``compute_metrics`` and forwards them to ``log_metrics``;
    * writes the per-row norms of ``alg.models[0].w.weight``, the norm of
      that whole matrix, and the norm of the whole first model to
      ``train_writer``;
    * appends one CSV line to the open file object ``f`` and flushes it.

    The CSV line has the columns: round, train losses, train accs, train acc
    mean, test losses, test accs, test acc mean, per-row w norms. List-valued
    columns are written as quoted Python list reprs.
    """
    # Evaluating every round is expensive, so skip rounds off the interval.
    if t % args.log_interval != 0:
        return

    logging.info(f'| global round: {t}')

    train_metrics = compute_metrics(alg.models, train_loaders, loss_func, device)
    test_metrics = compute_metrics(alg.models, test_loaders, loss_func, device)

    log_metrics('train', *train_metrics, train_writer, t)
    log_metrics('test', *test_metrics, test_writer, t)

    first_model = alg.models[0]
    w_norms = first_model.w.weight.norm(dim=1).detach().cpu().numpy()
    model_norm = math.sqrt(norm2_model(first_model))

    # One scalar per row of w, then the norm of the full matrix and model.
    for row_idx, row_norm in enumerate(w_norms):
        train_writer.add_scalar(f'w_norm/class_{row_idx}', row_norm, t)
    matrix_norm = np.sqrt((w_norms**2).sum())
    train_writer.add_scalar('w_norm/entire_matrix', matrix_norm, t)
    train_writer.add_scalar('model_norm', model_norm, t)

    # Append the same information to the CSV log; the std entries are unused.
    train_losses, train_accs, train_acc_mean, _ = train_metrics
    test_losses, test_accs, test_acc_mean, _ = test_metrics
    csv_line = (
        f'{t},"{train_losses}","{train_accs}",{train_acc_mean},'
        f'"{test_losses}","{test_accs}",{test_acc_mean},"{w_norms.tolist()}"\n'
    )
    f.write(csv_line)
    f.flush()


def log_norm_local(alg, args, old_w_update_vec, old_update_vec, old_model, train_writer, t):
    """Log the size of each client's local update relative to ``old_model``.

    For every model ``m_i`` in ``alg.models`` the update ``m_i - old_model``
    is formed with ``sub_models`` and three scalars are written to
    ``train_writer`` at step ``t``:

    * ``update_norm/client_i``   -- norm of the whole update,
    * ``w_update_norm/client_i`` -- norm of the update of ``w.weight``,
    * ``g_update_norm/client_i`` -- norm of everything except ``w.weight``,
      obtained as sqrt(update_norm^2 - w_update_norm^2).

    When ``t`` is a multiple of 50, the cosine similarity between the
    flattened ``w`` update and ``old_w_update_vec[i]`` is also logged as
    ``prev_w_update_angle/client_i``, and ``old_w_update_vec[i]`` is then
    overwritten in place with the current flattened update.

    ``args`` and ``old_update_vec`` are accepted but not used here.
    """
    with torch.no_grad():
        for i, client_model in enumerate(alg.models):
            update = sub_models(client_model, old_model)  # delta_model = m_i - m_old
            update_norm2 = float(norm2_model(update))
            w_update = update.w.weight
            w_update_norm2 = w_update.norm().item() ** 2
            # The two squared norms are computed along different paths, so
            # rounding can leave their difference slightly below zero, which
            # would make math.sqrt raise ValueError. Clamp at zero; keeping
            # the difference as the first argument of max() lets NaN through.
            g_update_norm2 = max(update_norm2 - w_update_norm2, 0.0)

            train_writer.add_scalar(f'update_norm/client_{i}', math.sqrt(update_norm2), t)
            train_writer.add_scalar(f'w_update_norm/client_{i}', math.sqrt(w_update_norm2), t)
            train_writer.add_scalar(f'g_update_norm/client_{i}', math.sqrt(g_update_norm2), t)

            # Compare the direction of this w update with the stored one.
            if t % 50 == 0:
                w_update_vec = w_update.flatten()
                w_update_angle = torch.cosine_similarity(
                    w_update_vec[None], old_w_update_vec[i][None]
                )[0].cpu().item()
                old_w_update_vec[i] = w_update_vec
                train_writer.add_scalar(f'prev_w_update_angle/client_{i}', w_update_angle, t)



def compute_metrics(models, loaders, loss_func, device):
    """Evaluate ``models`` on ``loaders`` and return rounded summary metrics.

    Returns a 4-tuple ``(losses, accs, acc_mean, acc_std)``:

    * ``losses`` / ``accs`` -- the values returned by ``accuracies_losses``,
      each rounded to 3 decimals;
    * ``acc_mean`` / ``acc_std`` -- the two values returned by ``mean_std``
      for the already-rounded accuracies, again rounded to 3 decimals.

    No loss mean/std is returned, so none is computed.
    """
    accs, loss = accuracies_losses(models, loaders, loss_func, device)
    losses_ = [round(x, 3) for x in loss]
    accs = [round(x, 3) for x in accs]
    acc_mean, acc_std = (round(x, 3) for x in mean_std(accs))

    return losses_, accs, acc_mean, acc_std

def log_metrics(prefix, losses_, accs, acc_mean, acc_std, writer, t):
    """Report one set of metrics to the Python logger and to ``writer``.

    ``prefix`` only labels the three logger lines. ``writer`` receives, at
    step ``t``: the mean accuracy (``acc``), the mean of ``losses_``
    (``loss``), and for every index ``i`` of ``losses_`` the pair
    ``acc/client_dist_i`` / ``loss/client_dist_i``.
    """
    summary_lines = (
        f'{prefix} losses: {losses_}',
        f'{prefix} accs: {accs}',
        f'{prefix} acc mean (std): {acc_mean} ({acc_std})',
    )
    for text in summary_lines:
        logging.info(text)

    client_count = len(losses_)

    # Aggregate scalars first, then one accuracy/loss pair per client.
    writer.add_scalar('acc', acc_mean, t)
    mean_loss = statistics.mean(losses_)
    writer.add_scalar('loss', mean_loss, t)
    for client in range(client_count):
        writer.add_scalar(f'acc/client_dist_{client}', accs[client], t)
        writer.add_scalar(f'loss/client_dist_{client}', losses_[client], t)

def log_aggregate_model_diagnostics(alg, t, train_writer, old_model):
    """Log diagnostics about the global aggregate model AFTER aggregation.

    ``alg.models[0]`` is treated as the aggregate model. Written to
    ``train_writer`` at step ``t``:

    * norms of the update ``alg.models[0] - old_model``: whole update,
      its ``w.weight`` part, and the remainder ("g"), under
      ``update_norm/aggregate``, ``w_update_norm/aggregate`` and
      ``g_update_norm/aggregate``;
    * norms of the model itself: ``aggregate_model_norm``,
      ``aggregate_w_norm/entire_matrix``, ``aggregate_g_norm``, plus one
      ``aggregate_w_norm/class_i`` per row of ``w.weight``.

    The "g" norms are derived as sqrt(total^2 - w^2).
    """
    with torch.no_grad():
        aggregate = alg.models[0]

        # Norms of the update produced by this aggregation step.
        update = sub_models(aggregate, old_model)  # delta_model = m - m_old
        update_norm2 = norm2_model(update)
        w_update = update.w.weight
        w_update_norm2 = w_update.norm().item() ** 2
        # total^2 and w^2 are computed along different paths, so rounding can
        # push the difference slightly below zero, which would make math.sqrt
        # raise ValueError. Clamp at zero; keeping the difference as the first
        # argument of max() lets NaN through unchanged.
        g_update_norm2 = max(float(update_norm2 - w_update_norm2), 0.0)

        train_writer.add_scalar('update_norm/aggregate', math.sqrt(update_norm2), t)
        train_writer.add_scalar('w_update_norm/aggregate', math.sqrt(w_update_norm2), t)
        train_writer.add_scalar('g_update_norm/aggregate', math.sqrt(g_update_norm2), t)

        # Norms of the aggregate model itself.
        model_norm2 = norm2_model(aggregate)
        w_norm2 = aggregate.w.weight.norm().item() ** 2
        g_norm2 = max(float(model_norm2 - w_norm2), 0.0)  # same clamp as above

        w_norms_per_class = aggregate.w.weight.norm(dim=1).detach().cpu().numpy()

        train_writer.add_scalar('aggregate_model_norm', math.sqrt(model_norm2), t)  # entire model norm
        train_writer.add_scalar('aggregate_w_norm/entire_matrix', math.sqrt(w_norm2), t)  # entire w matrix norm
        train_writer.add_scalar('aggregate_g_norm', math.sqrt(g_norm2), t)  # g norm
        # One scalar per row of w.
        for i, class_w_norm in enumerate(w_norms_per_class):
            train_writer.add_scalar(f'aggregate_w_norm/class_{i}', class_w_norm, t)



def to_csv(csv_file, row, mode='w'):
    """Write a single ``row`` to ``csv_file`` using ``csv.writer``.

    ``mode`` is passed to ``open``: the default ``'w'`` truncates the file,
    ``'a'`` appends the row to existing content.
    """
    # newline='' is required by the csv module: the writer emits its own
    # '\r\n' terminators, and without it platforms that translate newlines
    # would turn them into '\r\r\n' (blank lines between rows).
    with open(csv_file, mode, newline='') as f:
        writer = csv.writer(f)
        writer.writerow(row)
        
# Helpers for presenting results in a readable format.

def print_acc(list_):
    """Print each value of ``list_`` as a percentage with two decimals.

    Every value is multiplied by 100 and followed by a tab; the row is then
    closed with two newline characters (the ``'\\n'`` argument plus the one
    ``print`` appends itself).
    """
    for fraction in list_:
        percent_text = format(fraction * 100, '.2f') + '%'
        print(percent_text, end='\t')
    print('\n')
    
def round_list(list_, dec=4):
    """Return a new list with every element of ``list_`` rounded to ``dec`` decimals."""
    return list(map(lambda value: round(value, dec), list_))


def read_json(file):
    """Read a JSON-lines file and return the decoded objects as a list.

    Each non-blank line of ``file`` is parsed with ``json.loads``. Lines that
    are empty or contain only whitespace (for example a trailing blank line)
    are skipped instead of raising ``json.JSONDecodeError``, in the same way
    ``read_csv`` skips blank rows.
    """
    with open(file) as f:
        return [json.loads(line) for line in f if line.strip()]


def read_csv(csvfilename):
    """Read the first seven columns of every non-blank row of a CSV file.

    Returns a list with one entry per non-blank row, each of the form
    ``[row_number, col0, ..., col6]`` where ``row_number`` is the 1-based
    count of non-blank rows as a string. The file is decoded as UTF-8 and
    undecodable bytes are dropped (``errors="ignore"``).

    Raises:
        IndexError: if a non-blank row has fewer than seven fields.
    """
    num_columns = 7
    data = []
    # newline="" hands line splitting to the csv module, as its documentation
    # requires; otherwise '\r\n' inside quoted fields is rewritten to '\n'.
    with open(csvfilename, "r", encoding="utf-8", errors="ignore", newline="") as scraped:
        reader = csv.reader(scraped, delimiter=',')
        non_blank_rows = (row for row in reader if row)  # avoid blank lines
        for row_index, row in enumerate(non_blank_rows, start=1):
            if len(row) < num_columns:
                raise IndexError(
                    f'row {row_index} of {csvfilename} has {len(row)} fields, '
                    f'expected at least {num_columns}'
                )
            data.append([str(row_index), *row[:num_columns]])
    return data

def print_mean_std(mean, std):
    """Return ``mean`` with ``std`` as a LaTeX subscript: ``<mean>$_{\\pm <std>}$``.

    Despite the name nothing is printed; the formatted string is returned.
    """
    # Raw f-string: '\p' is not a valid escape sequence and triggers a
    # SyntaxWarning in a normal string literal. '{{' and '}}' are literal braces.
    return rf'{mean}$_{{\pm {std}}}$'

def plus_minus(A_means, G_means, stds, worsts, bests):
    """Print one LaTeX table row summarising five series of results.

    For each series, in argument order, the two values returned by
    ``mean_std`` are rounded to 2 decimals and formatted with
    ``print_mean_std``; the five cells are printed joined by `` & ``.
    """
    cells = []
    for series in (A_means, G_means, stds, worsts, bests):
        series_mean, series_std = round_list(list(mean_std(series)), dec=2)
        cells.append(print_mean_std(series_mean, series_std))
    print(' & '.join(cells))



def save_acc_loss(json_file, t, acc, loss):
    """Append one JSON line with the epoch, accuracies and losses to ``json_file``.

    ``acc`` and ``loss`` are converted to lists; the record is serialised
    with sorted keys (``accs``, ``epoch``, ``losses``) and terminated by a
    newline, so repeated calls produce a JSON-lines file.
    """
    record = {
        'epoch': t,
        'accs': list(acc),
        'losses': list(loss),
    }
    with open(json_file, 'a') as out:
        serialized = json.dumps(record, sort_keys=True)
        out.write(serialized + '\n')


