import copy
import logging
import os
import re

import torch

from config import RESULTS_DIR
from src.utils.common import create_experiment


def natural_key(string_):
    """Build a sort key that orders embedded numbers numerically.

    The string is split into alternating non-digit and digit runs, and each
    digit run is converted to ``int`` so that e.g. ``"epoch_10"`` sorts after
    ``"epoch_9"`` instead of before ``"epoch_2"``.
    See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/

    Args:
        string_: The string to derive a key from.

    Returns:
        A list of ``str`` and ``int`` chunks, suitable as a ``key=`` function
        result for ``sorted`` / ``list.sort``.
    """
    # str.isdecimal() (unlike str.isdigit()) rejects characters such as '²'
    # that int() cannot parse, so only the runs captured by \d+ get converted.
    return [int(s) if s.isdecimal() else s for s in re.split(r'(\d+)', string_)]


def load_ckp(experiment_dir, model, optimizer, device):
    """Restore model and optimizer state from an experiment's latest checkpoint.

    Checkpoints are looked up in
    ``<RESULTS_DIR>/<experiment_dir>/models/checkpoints``; the entry that comes
    last in natural sort order (see ``natural_key``) is the one loaded.

    Args:
        experiment_dir: Experiment directory, relative to ``RESULTS_DIR``.
        model: Model whose ``load_state_dict`` receives ``'model_state_dict'``.
        optimizer: Optimizer whose ``load_state_dict`` receives
            ``'optimizer_state_dict'``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        Tuple ``(model, optimizer, epoch, val_loss)`` where ``epoch`` and
        ``val_loss`` are read from the checkpoint.
    """
    ckp_dir = os.path.join(RESULTS_DIR, experiment_dir, 'models', 'checkpoints')

    # Natural ordering keeps e.g. "epoch_10.pt" after "epoch_9.pt".
    ckp_names = os.listdir(ckp_dir)
    ckp_names.sort(key=natural_key)
    latest_name = ckp_names[-1]

    state = torch.load(os.path.join(ckp_dir, latest_name), map_location=device)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])

    epoch = state['epoch']
    val_loss = state['val_loss']
    logging.info(f"Loading model from checkpoint {latest_name}, "
                 f"starting from epoch {epoch}, loss {val_loss}. \n")

    return model, optimizer, epoch, val_loss


def load_best(experiment_dir, device, args):
    """Build a fresh model and load the saved best weights into it.

    The weights file is taken from ``<RESULTS_DIR>/<experiment_dir>/models/final``.
    If that directory holds more than one entry, the one that comes last in
    natural sort order is used.

    Args:
        experiment_dir: Experiment directory, relative to ``RESULTS_DIR``.
        device: Passed to ``torch.load`` as ``map_location``.
        args: Forwarded to ``create_experiment`` to construct the model.

    Returns:
        The ``model`` attribute of the created experiment, with the stored
        state dict loaded.
    """
    final_dir = os.path.join(RESULTS_DIR, experiment_dir, 'models', 'final')
    # os.listdir() returns entries in arbitrary order, so sort before picking
    # the last one to make the choice deterministic (same rule as load_ckp).
    best_model = sorted(os.listdir(final_dir), key=natural_key)[-1]
    best_model_path = os.path.join(final_dir, best_model)

    experiment = create_experiment(args)
    model = experiment.model
    model.load_state_dict(torch.load(best_model_path, map_location=device))

    return model


def save_ckp(model, epoch, val_loss, optimizer, path):
    """Write a training checkpoint to ``<path>/models/checkpoints/epoch_<epoch>.pt``.

    The checkpoint is a dict holding the model object itself, its state dict,
    the optimizer state dict, the validation loss and ``epoch + 1``.

    Args:
        model: Model to checkpoint.
        epoch: Epoch index used in the file name; ``epoch + 1`` is stored
            under ``'epoch'`` (presumably the epoch to resume from -- load_ckp
            logs it as "starting from epoch").
        val_loss: Validation loss stored under ``'val_loss'``.
        optimizer: Optimizer whose state dict is stored.
        path: Experiment root directory.
    """
    checkpoint_dir = os.path.join(path, 'models', 'checkpoints')
    # torch.save does not create missing parent directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    torch.save(
        {
            'model': model,
            'epoch': epoch + 1,
            'val_loss': val_loss,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        },
        os.path.join(checkpoint_dir, f"epoch_{epoch}.pt"),
    )


def save_best(model, epoch, val_loss, best_loss, path):
    """Persist the current model as the best one and return a snapshot of it.

    Logs the improvement, writes the model's state dict to
    ``<path>/models/final/best_model.pt`` and returns a deep copy of the model.

    Args:
        model: Model whose state dict is saved.
        epoch: Epoch number, used only in the log message.
        val_loss: New validation loss, used only in the log message.
        best_loss: Previous best loss, used only in the log message.
        path: Experiment root directory.

    Returns:
        A ``copy.deepcopy`` of ``model``, independent of later changes to it.
    """
    logging.info("====================================================")
    logging.info(f'Validation loss decreased on epoch {epoch}: ({best_loss} --> {val_loss}). Saving model...')
    logging.info("====================================================")

    final_dir = os.path.join(path, 'models', 'final')
    # torch.save does not create missing parent directories.
    os.makedirs(final_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(final_dir, 'best_model.pt'))

    return copy.deepcopy(model)
