#!/usr/bin/env python
# coding: utf-8

import os
import os.path as osp
import torch


from .dyn_model import dyn2DMat


def _parse_dyn_filename(file_name):
    """Split a DyN file name into its target state-dict key and dyn2DMat configs.

    The name is '#'-separated with at least six fields:
    ``<var_name>#<field1>#<H_num>#<q_dim>#<norm_p>#<_scale>``.
    Field 1 (and any field after the sixth) is not used by this loader.

    Returns:
        ``(var_name, load_configs)`` where ``load_configs`` maps ``'H_num'``,
        ``'q_dim'``, ``'norm_p'`` and ``'_scale'`` to ints.

    Raises:
        ValueError: if the name has fewer than six fields or a config field
            is not an integer.
    """
    fields = file_name.split('#')
    if len(fields) < 6:
        raise ValueError(
            f"DyN file name {file_name!r} has {len(fields)} '#'-separated "
            "fields; expected at least 6")
    load_configs = {'H_num': int(fields[2]),
                    'q_dim': int(fields[3]),
                    'norm_p': int(fields[4]),
                    '_scale': int(fields[5])}
    return fields[0], load_configs


def dyn2dnn(target_model, dyn_root, device='cuda:0'):
    """Overwrite tensors of ``target_model`` with matrices reconstructed by dyn2DMat.

    Each regular file in ``dyn_root`` is a saved ``dyn2DMat`` state dict.
    Its name encodes the target state-dict key and the ``dyn2DMat``
    configuration (see ``_parse_dyn_filename``).

    For each file, this function:

    1. builds a ``dyn2DMat`` sized to the (2-D) target tensor on ``device``;
    2. loads the saved state;
    3. copies the output of ``forward()`` in place into the target tensor.

    Args:
        target_model: module whose state-dict tensors are overwritten in place.
        dyn_root: directory containing the DyN files. Entries that are not
            regular files (e.g. sub-directories) are skipped.
        device: device for the dyn2DMat modules; also used as
            ``torch.load``'s ``map_location``.

    Returns:
        ``target_model`` (the same object, modified in place).

    Raises:
        KeyError: if a file's variable name is not a key of the state dict.
        ValueError: if a file name is malformed or its target tensor is not 2-D.
    """
    # state_dict() tensors share storage with the model's parameters/buffers,
    # so copying into them updates the model. Build it once, not per file.
    state = target_model.state_dict()

    # Sorted so files are processed in a deterministic order.
    for file_name in sorted(os.listdir(dyn_root)):
        dyn_path = osp.join(dyn_root, file_name)
        if not osp.isfile(dyn_path):
            continue

        var_name, load_configs = _parse_dyn_filename(file_name)
        tar_mat = state[var_name]

        # Only 2-D targets can be rebuilt by dyn2DMat. Fail loudly instead of
        # falling through with an undefined or stale model from an earlier file.
        if tar_mat.dim() != 2:
            raise ValueError(
                f"DyN file {file_name!r} targets {var_name!r} with shape "
                f"{tuple(tar_mat.shape)}; only 2-D tensors are supported")

        num_rowQ, num_colQ = tar_mat.shape
        load_DyN = dyn2DMat(num_rowQ, num_colQ,
                            load_configs['H_num'], load_configs['q_dim'],
                            device=device,
                            p=load_configs['norm_p'],
                            _scale=load_configs['_scale'])

        # map_location lets a checkpoint saved on another device (e.g. a GPU
        # file on a CPU-only host) load onto ``device``.
        # NOTE(review): torch.load unpickles its input; only load trusted files.
        load_DyN.load_state_dict(torch.load(dyn_path, map_location=device))
        load_DyN.eval()

        # This is a weight overwrite, not a training step: keep it out of the
        # autograd graph.
        with torch.no_grad():
            rec_param = load_DyN.forward()
            tar_mat.copy_(rec_param)

    return target_model



