import logging
import os
import shutil

import numpy as np
import torch
import wandb


def transform_list_to_tensor(model_params_list):
    """Convert every value of ``model_params_list`` to a float32 tensor, in place.

    Each value is passed through ``np.asarray`` and then
    ``torch.from_numpy(...).float()``.

    Args:
        model_params_list: dict mapping parameter names to array-like values
            (e.g. nested lists). It is mutated in place.

    Returns:
        The same dict object, now holding tensors.
    """
    for name, value in model_params_list.items():
        # Reassigning existing keys does not change the dict's size, so
        # mutating while iterating items() is safe here.
        as_array = np.asarray(value)
        model_params_list[name] = torch.from_numpy(as_array).float()
    return model_params_list


def transform_tensor_to_list(model_params):
    """Convert every tensor value of ``model_params`` to a nested list, in place.

    Each tensor is detached from autograd and copied to host memory before the
    NumPy conversion, because ``Tensor.numpy()`` only accepts CPU tensors; this
    lets GPU-resident parameters be converted too. For CPU tensors ``.cpu()``
    is a no-op.

    Args:
        model_params: dict mapping parameter names to tensors. It is mutated
            in place.

    Returns:
        The same dict object, now holding plain Python (nested) lists.
    """
    for name, tensor in model_params.items():
        model_params[name] = tensor.detach().cpu().numpy().tolist()
    return model_params


def post_complete_message_to_sweep_process(args):
    """Write a "training is finished" notice containing ``args`` to ./tmp/fedml.

    The function name suggests an external sweep process reads this path;
    presumably it may pre-create it as a named pipe (FIFO). If so, opening it
    for writing blocks until that reader opens its end. If the path does not
    exist, an empty regular file is created instead (as the previous
    ``touch``-based code effectively did), so the write never blocks.

    Args:
        args: any object; its ``str()`` form is appended to the message.
    """
    pipe_dir = "./tmp"
    pipe_path = os.path.join(pipe_dir, "fedml")  # "./tmp/fedml"
    # Replaces the shell one-liner `mkdir ./tmp/; touch ./tmp/fedml`: no shell
    # involved, and no error when the directory already exists.
    os.makedirs(pipe_dir, exist_ok=True)
    if not os.path.exists(pipe_path):
        # Create a regular file rather than a FIFO, matching the old `touch`
        # behaviour, so the open below cannot hang waiting for a reader.
        open(pipe_path, "a").close()
    # O_TRUNC discards stale bytes left by a longer previous message when the
    # path is a regular file; POSIX specifies it has no effect on a FIFO.
    pipe_fd = os.open(pipe_path, os.O_WRONLY | os.O_TRUNC)

    with os.fdopen(pipe_fd, 'w') as pipe:
        pipe.write("training is finished! \n%s\n" % (str(args)))


def save_checkpoint(model_name, round, model, acc):
    """Save ``model``'s weights and metadata as its best-accuracy checkpoint.

    The file is written to ``./checkpoint/<model_name>_fedssl_best_acc.pth``,
    overwriting any previous one, and is then registered with ``wandb.save``.
    The ``./checkpoint`` directory is created if it does not exist.

    Args:
        model_name: name stored in the checkpoint and used in the file name.
        round: training round being checkpointed. The name shadows the builtin
            but is kept for callers that pass it by keyword.
        model: object exposing ``state_dict()`` (presumably a
            ``torch.nn.Module``).
        acc: accuracy value stored alongside the weights.
    """
    checkpoint_dir = "./checkpoint"
    # torch.save does not create missing directories.
    os.makedirs(checkpoint_dir, exist_ok=True)
    state = {
        'model_name': model_name,
        'round': round,
        'state_dict': model.state_dict(),
        'acc': acc
    }
    filename = checkpoint_dir + "/" + model_name + "_fedssl_best_acc.pth"
    torch.save(state, filename)
    wandb.save(filename)


def load_checkpoint(model, model_name, map_location=None):
    """Restore ``model`` from ``./checkpoint/<model_name>_fedssl_best_acc.pth``.

    Args:
        model: object exposing ``load_state_dict`` (presumably a
            ``torch.nn.Module``). It is updated in place.
        model_name: name used to build the checkpoint file name.
        map_location: forwarded to ``torch.load``. When ``None`` (the default),
            ``'cuda:0'`` is used if CUDA is available, which matches the
            previous hard-coded behaviour. Otherwise ``'cpu'`` is used, so
            checkpoints can be restored on CPU-only machines.

    Returns:
        Tuple ``(start_round, model, acc)`` read from the checkpoint; ``model``
        is the same object that was passed in.
    """
    if map_location is None:
        map_location = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    filename = "./checkpoint/" + model_name + "_fedssl_best_acc.pth"
    checkpoint = torch.load(filename, map_location=map_location)
    # Read metadata first so a malformed checkpoint fails before the model
    # is touched.
    start_round = checkpoint['round']
    acc = checkpoint['acc']
    model.load_state_dict(checkpoint['state_dict'])

    return start_round, model, acc


def clear_cache_for_personalized_model(args):
    """Best-effort removal of ``<args.personalized_model_path>/personalized_model``.

    A missing folder is a no-op. Filesystem errors during removal are logged
    with a traceback rather than raised, so callers are never interrupted by a
    failed cleanup.

    Args:
        args: object with a ``personalized_model_path`` string attribute.
    """
    folder = args.personalized_model_path + "/personalized_model"
    try:
        if os.path.exists(folder):
            shutil.rmtree(folder)
    except OSError:
        # Stay best-effort, but record what failed and why instead of a bare
        # "failed" on stdout.
        logging.exception("failed to clear personalized model cache at %s", folder)


def save_personal_model(args, personalized_model, client_index):
    """Save one client's personalized model weights to disk.

    Writes ``<args.personalized_model_path>/personalized_model/client_<client_index>.pth``.
    The folder, and any missing parent directories, are created if needed.

    Note: ``personalized_model.cpu()`` is called before taking the state dict.
    If ``personalized_model`` is a ``torch.nn.Module``, that call moves the
    caller's module to CPU in place, so it remains on CPU after this returns.

    Args:
        args: object with a ``personalized_model_path`` string attribute.
        personalized_model: model exposing ``cpu()`` and ``state_dict()``.
        client_index: integer id of the client (formatted with ``%d`` in the log).
    """
    folder = args.personalized_model_path + "/personalized_model"
    # makedirs(exist_ok=True) also creates missing parents and avoids the
    # exists()/mkdir() race between concurrent clients.
    os.makedirs(folder, exist_ok=True)
    path = folder + '/client_' + str(client_index) + '.pth'
    torch.save(personalized_model.cpu().state_dict(), path)
    logging.info(" Personal Model of Client number %d saved ", client_index)


def load_personal_model(args, personalized_model, client_index):
    """Restore a client's saved weights into ``personalized_model``, if present.

    Looks for ``<args.personalized_model_path>/personalized_model/client_<client_index>.pth``.
    When that file is absent, the model is left untouched and nothing is
    logged. Always returns None; the model is updated in place through
    ``load_state_dict``.
    """
    cache_dir = args.personalized_model_path + '/personalized_model'
    checkpoint_path = cache_dir + '/client_' + str(client_index) + '.pth'
    if not os.path.exists(checkpoint_path):
        # No snapshot for this client yet: keep the model's current weights.
        return
    saved_weights = torch.load(checkpoint_path)
    personalized_model.load_state_dict(saved_weights)
    message = " Personal Model of Client number %d Loaded " % client_index
    logging.info(message)
