#!/usr/bin/env python
# coding: utf-8

import os
import os.path as osp
from datetime import datetime 
import torch
import random


from .dyn_model import dyn2DMat
from .dyn_fitting import update_DyN
from .utils import build_opt_sch

def dnn2dyn(configs, target_model, vars_list=None):
    """Convert selected parameters of a model into dyn models.

    Every entry of ``target_model.state_dict()`` whose name appears in
    ``vars_list`` is passed to ``mat2dyn`` together with ``configs``.

    Args:
        configs (dict): keyword arguments forwarded verbatim to ``mat2dyn``
            (dyn and training settings); keys must match its parameters.
        target_model (nn.Module): model whose ``state_dict()`` entries are
            converted.
        vars_list (list[str], optional): names of the state_dict entries to
            convert. Names not present in the state_dict are silently ignored.
            Defaults to None, which converts nothing.

    Returns:
        dict: maps each converted variable name to the dyn model returned by
        ``mat2dyn``, in state_dict iteration order.
    """
    # None instead of a mutable `[]` default, so no list object is shared
    # between calls.
    if vars_list is None:
        vars_list = []

    model_dyn = {}
    for var_name, tar_mat in target_model.state_dict().items():
        if var_name not in vars_list:
            continue
        print("tar_mat shape {}, layer {}".format(tar_mat.shape, var_name))
        model_dyn[var_name] = mat2dyn(tar_mat, var_name, **configs)
    return model_dyn




def mat2dyn(
    tar_mat,
    var_name,
    q_dim=10,
    H_num=10,
    norm_p=1,
    _scale=1,
    max_dynEpoch=200000,
    meanL_thres=1e-10,
    lr=1e-4,
    save_dyn=False,
    SAVE_ROOT='dyn_coords',
    device='cuda:0',
    opt='adam',
    sch='exp',
    weight_decay=0.0,
    opt_decay_step=1000,
    opt_decay_rate=0.99,
    opt_restart=1,
    resume_lr=1e-4,
    ):
    """
    Convert a 2D matrix to a dyn model.

    A ``dyn2DMat`` model is built for ``tar_mat`` and fitted with
    ``update_DyN``. If a checkpoint for the same ``var_name`` and dyn
    configuration (``H_num``, ``q_dim``, ``norm_p``, ``_scale``) already exists
    under ``SAVE_ROOT``, its weights are loaded first. Fitting then resumes with
    learning rate ``resume_lr`` and ``sch='None'`` instead of ``lr`` / ``sch``.

    Args:
        tar_mat (torch.Tensor): Target matrix to be converted to dyn. Must be 2D.
        var_name (str): Param name of target matrix; also used in the checkpoint
            file name.
        q_dim (int, optional): Embedding dim of dyn. Defaults to 10.
        H_num (int, optional): Partial linear param. Defaults to 10.
        norm_p (int, optional): Norm used in the optimization. Defaults to 1.
        _scale (int, optional): [Coming soon]. Defaults to 1.
        max_dynEpoch (int, optional): Epochs of dyn training. Defaults to 200000.
        meanL_thres (float, optional): Threshold for early stopping. Defaults to 1e-10.
        lr (float, optional): Learning rate for dyn training from scratch.
            Defaults to 1e-4.
        save_dyn (bool, optional): Whether to save dyn weights after fitting.
            Defaults to False.
        SAVE_ROOT (str, optional): Directory holding dyn checkpoints; created
            only when saving. Defaults to 'dyn_coords'.
        device (str, optional): Device for the dyn model and for loading a
            checkpoint. Defaults to 'cuda:0'.
        opt (str, optional): Optimizer name passed to ``build_opt_sch``.
            Defaults to 'adam'.
        sch (str, optional): Scheduler name passed to ``build_opt_sch`` when
            training from scratch. Defaults to 'exp'.
        weight_decay (float, optional): Passed to ``build_opt_sch``. Defaults to 0.0.
        opt_decay_step (int, optional): Passed to ``build_opt_sch``. Defaults to 1000.
        opt_decay_rate (float, optional): Passed to ``build_opt_sch``. Defaults to 0.99.
        opt_restart (int, optional): Passed to ``build_opt_sch``. Defaults to 1.
        resume_lr (float, optional): Learning rate used when resuming from an
            existing checkpoint. Defaults to 1e-4, the value previously
            hard-coded for that case.

    Returns:
        torch.nn.Module: dyn model of the target matrix, as returned by
        ``update_DyN``.

    Raises:
        ValueError: if ``tar_mat`` is not 2-dimensional.
    """
    # Draws once from the global `random` RNG; the id is only printed for logs.
    rand_id = round(random.random() * 100000)

    print('--- Rand_id:', rand_id)
    print('num_rawParams:', tar_mat.numel())

    # Only 2D matrices are supported. Previously other ranks fell through and
    # crashed later with an UnboundLocalError on DyMat_model.
    if len(tar_mat.shape) != 2:
        raise ValueError(
            'mat2dyn only supports 2D matrices, got shape {} for {}'.format(
                tuple(tar_mat.shape), var_name
            )
        )
    num_rowQ, num_colQ = tar_mat.shape
    print('num_dynParams:', q_dim * H_num * (num_rowQ + num_colQ))
    DyMat_model = dyn2DMat(
        num_rowQ, num_colQ, H_num, q_dim, device=device, p=norm_p, _scale=_scale
    )

    FILE_NAME = '{}#Config#{}#{}#{}#{}#.dyn'.format(
        var_name, H_num, q_dim, norm_p, _scale
        )
    SAVE_PATH = osp.join(SAVE_ROOT, FILE_NAME)

    if osp.exists(SAVE_PATH):
        print('load dyn weights from {}'.format(SAVE_PATH))
        # map_location lets a checkpoint written on another device load here.
        checkpoint = torch.load(SAVE_PATH, map_location=device)
        DyMat_model.load_state_dict(checkpoint)
        # Resuming: use resume_lr and pass sch='None' to build_opt_sch.
        run_sch, run_lr = 'None', resume_lr
    else:
        run_sch, run_lr = sch, lr

    DyMat_optim, DyMat_sche = build_opt_sch(
        opt=opt,
        sch=run_sch,
        params=DyMat_model.parameters(),
        weight_decay=weight_decay,
        lr=run_lr,
        opt_decay_step=opt_decay_step,
        opt_decay_rate=opt_decay_rate,
        opt_restart=opt_restart
        )
    print(DyMat_optim)
    print(datetime.now().time())
    # update_DyN also returns two loss statistics that are not used here.
    DyMat_model, _max, _mean = update_DyN(
        DyMat_model, DyMat_optim, DyMat_sche, tar_mat, max_dynEpoch,
        meanL_thres=meanL_thres
    )

    if save_dyn:
        # Create the directory only when something is actually written to it.
        os.makedirs(SAVE_ROOT, exist_ok=True)
        torch.save(DyMat_model.state_dict(), SAVE_PATH)
    return DyMat_model

