from io import BytesIO
import torch as th
import os
import io
import glob
import time


def create_models(state_dim,  action_dim, use_multi_head=False):
    """Build a ``PolicyNet`` for the given state and action dimensions.

    The model class is imported lazily, so only the selected variant's
    module is loaded.

    Args:
        state_dim: forwarded as the first argument to ``PolicyNet``.
        action_dim: forwarded as the second argument to ``PolicyNet``.
        use_multi_head: selects ``learner.models.model_float_multi_head``
            only when it is the ``True`` singleton (identity check). Any
            other value, including truthy ones such as ``1``, selects
            ``learner.models.model_float``.

    Returns:
        A freshly constructed ``PolicyNet`` instance.
    """
    if use_multi_head is not True:
        from learner.models.model_float import PolicyNet
    else:
        from learner.models.model_float_multi_head import PolicyNet
    return PolicyNet(state_dim, action_dim)


def convert_to_pt(args, net, graph_buffer):
    """Serialize ``net`` as a TorchScript archive into ``graph_buffer[0]``.

    Args:
        args: accepted for interface compatibility; not used here.
        net: a module that ``th.jit.script`` can compile.
        graph_buffer: an indexable, mutable container. Slot 0 is overwritten
            with the serialized archive as ``bytes``.
    """
    scripted = th.jit.script(net)
    with BytesIO() as sink:
        th.jit.save(scripted, sink)
        # getvalue() returns the whole buffer regardless of stream position,
        # so no explicit rewind is needed.
        graph_buffer[0] = sink.getvalue()


def load_model(path):
    """Load a TorchScript module from ``path`` with tensors mapped to CPU.

    The file is read fully into memory first and then deserialized from an
    in-memory buffer, so the file handle is closed before ``th.jit.load``
    runs.

    Args:
        path: filesystem path of a TorchScript archive.

    Returns:
        The loaded ``torch.jit.ScriptModule``.
    """
    with open(path, 'rb') as handle:
        payload = handle.read()
    return th.jit.load(io.BytesIO(payload), map_location=th.device('cpu'))


class ModelFileNotFoundError(FileNotFoundError, IndexError):
    """Raised when no snapshot file matches ``<path>_*``.

    It also subclasses ``IndexError`` so that callers which caught the
    ``IndexError`` previously raised by indexing an empty file list keep
    working.
    """


def load_model_from_p2p(path):
    """Load the newest TorchScript snapshot saved under ``path``.

    Snapshots are files matching the glob ``<path>_*``. The file whose name
    sorts last lexicographically is chosen. With a fixed-width timestamp
    suffix, that is the most recent one.

    Args:
        path: snapshot path prefix. It is used verbatim as a glob pattern,
            so glob metacharacters in it are interpreted.

    Returns:
        A ``(net, file)`` tuple: the loaded module and the path it came from.

    Raises:
        ModelFileNotFoundError: if no file matches ``<path>_*``.
    """
    pattern = path + '_*'
    candidates = glob.glob(pattern)
    if not candidates:
        raise ModelFileNotFoundError(f"no model files match {pattern!r}")
    # The lexicographic maximum is the same element sorted(...)[-1] would
    # pick, without sorting the whole list.
    latest = max(candidates)
    return load_model(latest), latest


def deserialize_model(model, path, device=th.device("cpu")):
    """Load a state dict stored at ``path`` into ``model`` in place.

    Args:
        model: module whose ``load_state_dict`` receives the loaded mapping.
        path: file previously written with ``th.save``.
        device: ``map_location`` passed to ``th.load``; defaults to CPU.
    """
    state = th.load(path, map_location=device)
    model.load_state_dict(state)


def serialize_model(saved_path_prefix, url_path_prefix, net, p2p_cache_size, logger):
    """Script ``net``, save it as a timestamped snapshot, and prune old ones.

    The snapshot is written to ``<saved_path_prefix>_<stamp>``, where
    ``stamp`` is local time formatted as ``%y%m%d%H%M%S``. Afterwards
    ``remove_file`` is called with ``saved_path_prefix`` and
    ``p2p_cache_size``.

    NOTE(review): the stamp has one-second resolution, so two calls within
    the same second produce the same name and the later save overwrites the
    earlier one.

    Returns:
        A ``(url_path, file_name)`` tuple: ``<url_path_prefix>_<stamp>`` and
        the local path that was written.
    """
    stamp = time.strftime("%y%m%d%H%M%S", time.localtime())
    file_name = '_'.join((saved_path_prefix, stamp))
    url_path = '_'.join((url_path_prefix, stamp))

    th.jit.save(th.jit.script(net), file_name)
    remove_file(saved_path_prefix, p2p_cache_size, logger)

    return url_path, file_name


def remove_file(path_prefix, p2p_cache_size, logger):
    """Delete the oldest ``<path_prefix>_*`` files so at most ``p2p_cache_size`` remain.

    Files are ordered by name (lexicographic sort). The ones sorting first
    are treated as oldest and removed first.

    Args:
        path_prefix: prefix used to build the glob pattern ``<path_prefix>_*``.
        p2p_cache_size: maximum number of matching files to keep; expected to
            be non-negative.
        logger: object with an ``info`` method; one message is logged per
            deleted file.
    """
    snapshots = sorted(glob.glob(path_prefix + '_*'))
    # Trim every surplus file, not just one. Otherwise a cache that is
    # already over the limit by more than one (e.g. leftovers from an earlier
    # run, or a reduced cache size) never shrinks back to the limit.
    surplus = max(len(snapshots) - p2p_cache_size, 0)
    for stale in snapshots[:surplus]:
        os.remove(stale)
        logger.info("remove p2p file %s", stale)


def save_model(pnet, vnet, filename='model.pth'):
    """Write the state dicts of ``pnet`` and ``vnet`` into one checkpoint file.

    The checkpoint is a dict with keys ``'pnet'`` and ``'vnet'``, saved with
    ``th.save`` to ``filename``.
    """
    checkpoint = {'pnet': pnet.state_dict(), 'vnet': vnet.state_dict()}
    th.save(checkpoint, filename)
