#!/usr/bin/env python
# coding: utf-8

import os
import os.path as osp
import json
from datetime import datetime 
import torch
import torch.nn as nn
import time
import random
import torchvision.models as models

import copy

from .dyn_model import Conv_DyN, Linear_DyN, Linear_DyN_NoMat
from .dyn_fitting import update_DyN
from .utils import build_opt_sch

def dnn2dyn(configs, target_model, vars_list=None, logger=None):
    """Convert selected entries of a model's state dict to DyN models.

    Iterates over ``target_model.state_dict()`` in its native order and, for
    every entry whose name is in ``vars_list``, calls ``mat2dyn`` on the
    stored tensor.

    Args:
        configs (dict): Keyword arguments forwarded verbatim to ``mat2dyn``
            (DyN and training settings). Unknown keys make ``mat2dyn`` raise
            ``TypeError``.
        target_model (nn.Module): Model whose ``state_dict()`` supplies the
            target tensors.
        vars_list (list[str], optional): Names of the state-dict entries to
            convert. ``None`` (the default) or an empty list converts nothing.
        logger (logging.Logger, optional): Logger for progress messages.
            Defaults to the logger of this module.

    Returns:
        dict: Maps each converted variable name to the value returned by
        ``mat2dyn`` for it. Names in ``vars_list`` that are absent from the
        state dict are silently ignored.
    """
    import logging

    # The signature allows logger=None, so never dereference it blindly.
    if logger is None:
        logger = logging.getLogger(__name__)
    # None replaces the former mutable ``[]`` default; both mean "nothing".
    targets = () if vars_list is None else vars_list

    model_dyn = {}
    for var_name, tar_mat in target_model.state_dict().items():
        if var_name not in targets:
            continue
        logger.info("tar_mat shape %s, layer %s", tar_mat.shape, var_name)
        model_dyn[var_name] = mat2dyn(tar_mat, var_name, logger, **configs)
    return model_dyn




def mat2dyn(
    tar_mat,
    var_name,
    logger=None,
    q_dim=10,
    H_num=10,
    norm_p=2,
    SCALE_FACTOR_fc=0.01,
    SCALE_FACTOR_conv=0.01,
    max_dynEpoch=200000,
    meanL_thres=1e-10,
    lr=1e-4,
    save_dyn=False,
    SAVE_ROOT='dyn_coords',
    device='cuda:0',
    opt='adam',
    sch='exp',
    weight_decay=0.0,
    opt_decay_step=1000,
    opt_decay_rate=0.99,
    opt_restart=1
    ):
    """
    Fit a DyN model to a single 2-D (linear) or 4-D (conv) weight tensor.

    A 2-D tensor is modelled with ``Linear_DyN_NoMat`` and a 4-D tensor with
    ``Conv_DyN``; the model is then trained against ``tar_mat`` by
    ``update_DyN``.

    Args:
        tar_mat (torch.Tensor): Target tensor, shaped ``(out, in)`` or
            ``(out, in, k, k)``.
        var_name (str): Param name of the target tensor; used in the save
            file name.
        logger (logging.Logger, optional): Logger for progress messages.
            Defaults to the logger of this module.
        q_dim (int, optional): Embedding dim of dyn. Defaults to 10.
        H_num (int, optional): Partial linear param. Defaults to 10.
        norm_p (int, optional): Norm used in the optimization. Defaults to 2.
        SCALE_FACTOR_fc (float, optional): Passed to ``Linear_DyN_NoMat``
            (2-D case). Defaults to 0.01.
        SCALE_FACTOR_conv (float, optional): Passed to ``Conv_DyN`` (4-D
            case); also part of the save file name. Defaults to 0.01.
        max_dynEpoch (int, optional): Epoches of dyn training. Defaults to 200000.
        meanL_thres (float, optional): threshold for early stopping, passed
            to ``update_DyN``. Defaults to 1e-10.
        lr (float, optional): learning rate for dyn training. Defaults to 1e-4.
        save_dyn (bool, optional): Whether to save the fitted model's state
            dict. Defaults to False.
        SAVE_ROOT (str, optional): the root directory of dyn embeddings;
            created if missing. Defaults to 'dyn_coords'.
        device (str, optional): Device for the model and target. Defaults to cuda:0.
        opt (str, optional): Optimizer name for ``build_opt_sch``. Defaults to 'adam'.
        sch (str, optional): Scheduler name for ``build_opt_sch``. Defaults to 'exp'.
        weight_decay (float, optional): Forwarded to ``build_opt_sch``. Defaults to 0.0.
        opt_decay_step (int, optional): Forwarded to ``build_opt_sch``. Defaults to 1000.
        opt_decay_rate (float, optional): Forwarded to ``build_opt_sch``. Defaults to 0.99.
        opt_restart (int, optional): Forwarded to ``build_opt_sch``. Defaults to 1.

    Returns:
        torch.nn.Module or None: the fitted dyn model, or ``None`` when the
        save file for this configuration already exists (fitting is skipped).

    Raises:
        ValueError: If ``tar_mat`` is neither 2-D nor 4-D.
    """
    import logging

    # The signature allows logger=None, so never dereference it blindly.
    if logger is None:
        logger = logging.getLogger(__name__)

    rand_id = round(random.random()*100000)

    logger.info('--- Rand_id:{}'.format(rand_id))
    logger.info('num_rawParams:{}'.format(tar_mat.numel()))

    n_dims = len(tar_mat.shape)
    if n_dims == 2:
        num_output, num_input = tar_mat.shape
        logger.info('num_dynParams:{}'.format(q_dim*H_num*(num_input+num_output)))
        DyMat_model = Linear_DyN_NoMat(num_input, num_output, H_num, q_dim, norm_p=norm_p, SCALE_FACTOR_fc=SCALE_FACTOR_fc)
    elif n_dims == 4:
        # NOTE(review): the last dim is discarded, i.e. a square kernel is
        # assumed -- confirm for non-square kernels.
        num_output, num_input, kernel_size, _ = tar_mat.shape
        logger.info('num_dynParams:{}'.format(q_dim*H_num*(num_input+num_output)*kernel_size))
        # NOTE(review): stride=0 is kept from the original call; presumably
        # unused while fitting -- confirm against Conv_DyN.
        DyMat_model = Conv_DyN(kernel_size, num_input, num_output, stride=0, padding=0, num_CHs=H_num, q_dim=q_dim, norm_p=norm_p, SCALE_FACTOR_conv=SCALE_FACTOR_conv)
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError(
            'mat2dyn supports 2-D or 4-D tensors only; {} has shape {}'.format(
                var_name, tuple(tar_mat.shape)
            )
        )

    DyMat_model = DyMat_model.to(device)

    # exist_ok avoids the check-then-create race between concurrent runs.
    os.makedirs(SAVE_ROOT, exist_ok=True)
    # NOTE(review): the name embeds SCALE_FACTOR_conv even for 2-D tensors,
    # which are built with SCALE_FACTOR_fc; kept as is so existing files on
    # disk are still recognised.
    FILE_NAME = '{}#Config#{}#{}#{}#{}#.dyn'.format(
        var_name, H_num, q_dim, norm_p, SCALE_FACTOR_conv
        )
    SAVE_PATH = osp.join(SAVE_ROOT, FILE_NAME)
    if osp.exists(SAVE_PATH):
        # An earlier run already saved this configuration: skip fitting.
        print(f'{SAVE_PATH} exists')
        return None

    DyMat_optim, DyMat_sche = build_opt_sch(
        opt=opt,
        sch=sch,
        params=DyMat_model.parameters(),
        weight_decay=weight_decay,
        lr=lr,
        opt_decay_step=opt_decay_step,
        opt_decay_rate=opt_decay_rate,
        opt_restart=opt_restart
        )
    print(DyMat_optim)
    print(datetime.now().time())
    # The two extra return values of update_DyN are not used here.
    DyMat_model, _max, _mean = update_DyN(DyMat_model, DyMat_optim, DyMat_sche, tar_mat.to(device), max_dynEpoch, meanL_thres=meanL_thres, logger=logger)

    if save_dyn:
        torch.save(DyMat_model.state_dict(), SAVE_PATH)
    return DyMat_model

