import errno
import os
import uuid
from pathlib import Path

import colorful
import torch

from .classes import ModelInterface
from .training.trainer import Trainer
from .training.trainer import Trainer


def save_model(path: str, model: ModelInterface):
    """Persist ``model``'s parameters together with its version.

    The checkpoint written to ``path`` is a dict with the keys ``"version"``
    (``model.version``) and ``"parameters"`` (``model.state_dict()``), the
    layout that ``load_model`` reads back.
    """
    version = model.version
    parameters = model.state_dict()
    save_state_dict(path, {"version": version, "parameters": parameters})


def load_model(path: str, model: ModelInterface):
    """Restore ``model``'s parameters and version from ``path`` if it exists.

    A missing checkpoint is tolerated: ``load_state_dict`` prints a notice,
    and the model is left unchanged.

    Args:
        path: Checkpoint file, as written by ``save_model``.
        model: Model whose parameters and ``version`` are updated in place.
    """
    # Only the file lookup is allowed to fail silently; errors while applying
    # the loaded parameters must propagate rather than be mistaken for a
    # missing checkpoint.
    try:
        state_dict = load_state_dict(path)
    except FileNotFoundError:
        return
    model.load_state_dict(state_dict["parameters"])
    model.version = state_dict["version"]


def save_trainer(path: str, trainer: Trainer):
    """Write ``trainer.state_dict()`` to ``path`` via ``save_state_dict``."""
    save_state_dict(path, trainer.state_dict())


def load_trainer(path: str, trainer: Trainer):
    """Restore ``trainer``'s state from ``path`` if the file exists.

    A missing file is tolerated: ``load_state_dict`` prints a notice, and the
    trainer is left unchanged.

    Args:
        path: File previously written by ``save_trainer``.
        trainer: Trainer whose state is restored in place.
    """
    # Keep the try body minimal so a FileNotFoundError raised inside
    # trainer.load_state_dict is not mistaken for a missing checkpoint.
    try:
        state = load_state_dict(path)
    except FileNotFoundError:
        return
    trainer.load_state_dict(state)


def load_state_dict(path: str):
    """Load an object previously written with ``torch.save``.

    Tensors are mapped onto the CPU (``map_location="cpu"``) regardless of
    the device they were saved from.

    Args:
        path: File to read.

    Returns:
        The object deserialized by ``torch.load``.

    Raises:
        FileNotFoundError: If ``path`` is not an existing regular file. A
            "not found" notice is printed before raising, and the exception
            carries the missing path in its ``filename`` attribute.
    """
    path = Path(path)
    if not path.is_file():
        print(colorful.bold_orange("{} not found".format(path)))
        # Populate errno/strerror/filename so the error is self-describing
        # if a caller lets it propagate.
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), str(path))
    print(colorful.bold_green("Loading {} ...".format(path)))
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    return torch.load(path, map_location="cpu")


def save_state_dict(path: str, state_dict):
    """Write ``state_dict`` to ``path`` with ``torch.save`` via a temp file.

    The data is first written to a uniquely named temporary file in the same
    directory as ``path`` and then moved over ``path``. The existing file at
    ``path`` is therefore never partially overwritten. If saving fails, the
    temporary file is removed and the error is re-raised.

    Args:
        path: Destination file; replaced if it already exists.
        state_dict: Any object ``torch.save`` can serialize.
    """
    path = Path(path)
    # Same directory as the destination so the final move stays on one
    # filesystem (a rename, not a copy).
    tmp_path = path.parent / "{}.tmp".format(uuid.uuid4())
    try:
        torch.save(state_dict, tmp_path)
        # os.replace, unlike os.rename, also overwrites an existing
        # destination on Windows.
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave an orphaned, possibly truncated temp file behind;
        # never let cleanup errors mask the original exception.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
