import logging
import os

import numpy as np
import torch
import wandb


def transform_list_to_tensor(model_params_list):
    """Convert each value of ``model_params_list`` into a float32 torch tensor.

    The mapping is updated in place and also returned for convenience.
    Each value goes through ``np.asarray`` first, so nested lists (or anything
    array-like) are accepted; the resulting tensor is cast with ``.float()``.
    """
    for name, raw_value in list(model_params_list.items()):
        as_array = np.asarray(raw_value)
        model_params_list[name] = torch.from_numpy(as_array).float()
    return model_params_list


def transform_tensor_to_list(model_params):
    """Convert each tensor value of ``model_params`` into a (nested) Python list.

    The mapping is updated in place and also returned. Tensors are detached
    and copied to CPU before conversion; ``Tensor.numpy()`` only works on CPU
    tensors, so this also supports parameters that live on an accelerator.
    """
    for name in list(model_params):
        model_params[name] = model_params[name].detach().cpu().numpy().tolist()
    return model_params


def get_global_model_path(args):
    """Return the checkpoint path for the best global model, creating its folder.

    The folder is ``./checkpoint/<pssl_optimizer><run_id>`` (relative to the
    current working directory). Unlike the personalized path, there is no
    separator between optimizer and run id; this is kept so existing
    checkpoints remain discoverable.

    Args:
        args: object exposing ``pssl_optimizer``, ``run_id`` and ``model``
            (``model`` must be a string; it is concatenated directly).

    Returns:
        ``"<folder>/<model>_best_acc.pth"``.
    """
    folder = "./checkpoint/" + str(args.pssl_optimizer) + str(args.run_id)
    # makedirs also creates ./checkpoint if missing, and exist_ok avoids the
    # check-then-create race of exists()+mkdir().
    os.makedirs(folder, exist_ok=True)
    return folder + "/" + args.model + "_best_acc.pth"


def save_global_model(args, model):
    """Save ``model``'s state dict to the global checkpoint path and upload it to wandb.

    Only the state-dict tensors are copied to CPU for serialization. The
    model itself is left on its current device; the previous
    ``model.cpu()`` moved the caller's model in place, which breaks training
    that continues on a GPU after a checkpoint is written.
    """
    path = get_global_model_path(args)
    state_dict = model.state_dict()
    # Replace values in the returned OrderedDict (rather than building a new
    # dict) so its ``_metadata`` attribute used by load_state_dict is kept.
    for name, tensor in list(state_dict.items()):
        state_dict[name] = tensor.cpu()
    torch.save(state_dict, path)
    wandb.save(path)


def load_global_model(args, model):
    path = get_global_model_path(args)
    if os.path.exists(path):  # checking if there is a file with this name
        # state_dict = torch.load(path, map_location="cpu")
        state_dict = torch.load(path, map_location="cpu")['state_dict']
        model.load_state_dict(state_dict)
        logging.info("Loaded Global Model")
    else:
        logging.info("Failed")


def get_personalized_model_path(args, client_index):
    """Return the checkpoint path for one client's personalized model.

    The layout is
    ``<personalized_model_path>/<pssl_optimizer>_<run_id>/personalized_model/client_<idx>.pth``.
    All intermediate folders are created if needed.

    Args:
        args: object exposing ``personalized_model_path`` (str),
            ``pssl_optimizer`` and ``run_id``.
        client_index: client identifier; it is embedded via ``str()``.

    Returns:
        The ``.pth`` file path (the file itself is not created).
    """
    run_folder = (
        args.personalized_model_path + "/" + str(args.pssl_optimizer)
        + "_" + str(args.run_id)
    )
    folder = run_folder + "/personalized_model"
    # One recursive, race-tolerant call replaces two exists()+mkdir() pairs;
    # it also works when args.personalized_model_path itself is missing.
    os.makedirs(folder, exist_ok=True)
    return folder + "/client_" + str(client_index) + ".pth"


def save_personal_model(args, personalized_model, client_index):
    """Save a client's personalized model state dict to its per-client path.

    The state-dict tensors are copied to CPU for serialization. The model
    itself is left on its device; the previous ``.cpu()`` call moved the
    caller's model in place.
    """
    path = get_personalized_model_path(args, client_index)  # also creates folders
    state_dict = personalized_model.state_dict()
    # Update values in place to keep the OrderedDict's ``_metadata``.
    for name, tensor in list(state_dict.items()):
        state_dict[name] = tensor.cpu()
    torch.save(state_dict, path)
    logging.info(" Personal Model of Client number %d saved ", client_index)


def load_personal_model(args, personalized_model, client_index):
    """Load a client's personalized checkpoint into ``personalized_model`` if present.

    Tensors are mapped to CPU on load, so a checkpoint written on a GPU host
    can still be read on a CPU-only one. ``load_state_dict`` then copies the
    values into the model's existing parameters. If the file is missing, only
    a log line is emitted.
    """
    path = get_personalized_model_path(args, client_index)
    if not os.path.exists(path):
        logging.info(" Personal Model does not exist")
        return
    personalized_model.load_state_dict(torch.load(path, map_location="cpu"))
    logging.info(" Personal Model of Client number %d Loaded ", client_index)
