#!/usr/bin/env python
# coding: utf-8

import os
import os.path as osp
import json
from datetime import datetime 
import torch
import torch.nn as nn
import time
import random

import copy

from .dyn_model import dyn2DMat, dyn4DMat


def _parse_dyn_file_name(file_name):
    """Split a DyN checkpoint file name into its target key and load configs.

    The name is '#'-separated: field 0 is the state-dict key of the tensor to
    fill, fields 2-5 are integer settings for the DyN module. Field 1 is not
    read here.
    """
    fields = file_name.split('#')
    load_configs = {'H_num': int(fields[2]),
                    'q_dim': int(fields[3]),
                    'norm_p': int(fields[4]),
                    '_scale': int(fields[5])}
    return fields[0], load_configs


def dyn2dnn(target_model, dyn_root, device='cuda:0'):
    """Overwrite weights of ``target_model`` with ones reconstructed by DyN modules.

    For every file in ``dyn_root`` the file name is parsed into a state-dict
    key plus DyN settings, a ``dyn2DMat`` (2-D target) or ``dyn4DMat`` (4-D
    target) is built, its state dict is loaded from the file, and its
    ``forward()`` output is copied into the matching tensor of
    ``target_model`` in place.

    Args:
        target_model: ``nn.Module`` whose tensors are overwritten.
        dyn_root: directory containing only DyN checkpoint files.
        device: device the DyN modules are built on and checkpoints mapped to.

    Returns:
        ``target_model`` itself (modified in place).

    Raises:
        KeyError: a file names a key that is not in ``target_model``'s state dict.
        ValueError: the target tensor is not 2-D or 4-D, or is a 4-D tensor
            with a non-square kernel.
    """
    # state_dict() rebuilds an OrderedDict on every call, so fetch it once. Its
    # tensors share storage with the model's parameters/buffers, so copy_()
    # into them (or into slices of them) updates the model in place.
    state = target_model.state_dict()

    # Sorted so that load order (and the printed log) is deterministic.
    for file_name in sorted(os.listdir(dyn_root)):
        var_name, load_configs = _parse_dyn_file_name(file_name)
        tar_mat = state[var_name]
        rank = tar_mat.dim()

        if rank == 2:
            # NOTE(review): for nn.Linear the weight is (out, in), so these names
            # look swapped; they are passed positionally exactly as before —
            # confirm the expected order with dyn2DMat.
            num_input, num_output = tar_mat.shape
            load_DyN = dyn2DMat(num_input, num_output, load_configs['H_num'], load_configs['q_dim'],
                                device=device, p=load_configs['norm_p'], _scale=load_configs['_scale'])
        elif rank == 4:
            num_output, num_input, kernel_size, kernel_w = tar_mat.shape
            # dyn4DMat only receives one kernel_size and the copy loop below
            # walks kernel_size x kernel_size positions, so a non-square kernel
            # would be loaded only partially.
            if kernel_size != kernel_w:
                raise ValueError('{}: non-square kernel {}x{} for {} is not supported'.format(
                    file_name, kernel_size, kernel_w, var_name))
            load_DyN = dyn4DMat(kernel_size, num_input, num_output, load_configs['H_num'], load_configs['q_dim'],
                                device=device, p=load_configs['norm_p'], _scale=load_configs['_scale'])
        else:
            # Without this, a later file would silently reuse the previous
            # iteration's load_DyN (wrong shape) and the first would crash
            # with UnboundLocalError.
            raise ValueError('{}: cannot load DyN weights into {} with shape {}; '
                             'only 2-D and 4-D tensors are supported'.format(
                                 file_name, var_name, tuple(tar_mat.shape)))

        dyn_path = osp.join(dyn_root, file_name)
        # map_location lets checkpoints saved on another device load here.
        # NOTE: torch.load unpickles the file; only use with trusted checkpoints.
        load_DyN.load_state_dict(torch.load(dyn_path, map_location=device))
        load_DyN.eval()

        # Pure weight reconstruction/copy: no autograd graph is needed.
        with torch.no_grad():
            rec_param = load_DyN.forward()
            print('load weights {}'.format(var_name))
            if rank == 2:
                tar_mat.copy_(rec_param)
            else:
                # Assumes rec_param is indexed by flattened kernel position
                # (row * kernel_size + col), each entry matching the
                # (num_output, num_input) slice — TODO confirm against dyn4DMat.
                for row_id in range(kernel_size):
                    for col_id in range(kernel_size):
                        conv_param = rec_param[row_id * kernel_size + col_id]
                        tar_mat[:, :, row_id, col_id].copy_(conv_param)

    return target_model



