import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from sys import argv
import numpy as np

# Checkpoint to convert (hard-coded path; a checkpoint written by torch.save).
path = '../model/vgg_bwn_cyc_hill_of0.05_k2_ft_ckpt.pth'
# Sub-directory name for this model's exported files.
model_name = 'vgg_bwn_cyc'
# Root directory the exporter writes into: MODEL_ROOT_DIR/<model_name>/<layer>/...
MODEL_ROOT_DIR = '../model/cpp'


def mkd(path):
    """Create directory *path*, including any missing parents, if it does not exist.

    ``exist_ok=True`` makes this a no-op for an existing directory. It also
    avoids the check-then-create race of ``os.path.exists`` + ``os.mkdir``,
    which additionally failed whenever a parent directory was missing.
    """
    os.makedirs(path, exist_ok=True)

def _load_params(ckpt_path):
    """Load the ``'net'`` entry of *ckpt_path* as a dict of NumPy arrays.

    Keys containing ``'var'`` (the BN running variance) are renamed to use
    ``'std'``, because ``_transform_params`` converts those values to a
    standard deviation.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    model = torch.load(ckpt_path, map_location=torch.device('cpu'))
    param_dict = {}
    for k, v in model['net'].items():
        if 'var' in k:
            k = k.replace('var', 'std')
        param_dict[k] = v.data.numpy()
    return param_dict


def _transform_params(param_dict):
    """Rewrite *param_dict* in place into the layout that gets exported.

    * ``<prefix>.convN.weight`` is flattened to ``(shape[0], -1)``. For N > 1
      it is also binarized to int32 signs. Its per-row mean absolute value
      (alpha) is multiplied into ``<prefix>.bnN.weight``.
    * ``*.bn*.running_std`` (renamed from the running variance) becomes
      ``sqrt(value + 1e-10)``.
    """
    # Iterate over a snapshot of the keys: the loop body rewrites values,
    # including entries (bnN.weight) other than the one being visited.
    for k in list(param_dict):
        parts = k.split('.')
        prefix, layer_name, param_name = parts[0], parts[-2], parts[-1]
        mat = param_dict[k]

        if param_name == 'weight' and 'conv' in layer_name:
            mat = mat.reshape((mat.shape[0], -1))

            layer_count = int(layer_name.replace('conv', ''))
            if layer_count > 1:
                print(k)
                alpha = np.mean(np.abs(mat), axis=1)
                target_bn_key = '.'.join([prefix, 'bn' + str(layer_count), 'weight'])
                # NOTE(review): BN(alpha*z) = gamma*alpha*(z - mean/alpha)/std + beta,
                # so folding alpha into gamma alone is exact only if the consumer
                # also divides running_mean by alpha. running_mean is exported
                # unchanged here — verify against the C++ side.
                param_dict[target_bn_key] = np.multiply(param_dict[target_bn_key], alpha)
                mat = np.sign(mat).astype(np.int32)

            param_dict[k] = mat
            print(param_dict[k][0, 0])

        if param_name == 'running_std' and 'bn' in layer_name:
            # NOTE(review): PyTorch BatchNorm's default eps is 1e-5, not 1e-10;
            # confirm which eps the model was trained with and what C++ expects.
            param_dict[k] = np.sqrt(mat + 1e-10)


def _save_params(param_dict, model_dir):
    """Write each array to ``model_dir/<layer_name>/<param_name>.<int|float>``.

    Arrays are written as comma-separated text with ``np.savetxt``:
    1-D arrays as a single row, arrays with more than 2 dimensions flattened
    to ``(shape[0], -1)``. Integer arrays use ``%d``; floating arrays use
    ``%.18e`` (the ``np.savetxt`` default). 0-d entries and arrays of any
    other dtype are skipped.
    """
    for k, v in param_dict.items():
        parts = k.split('.')
        layer_name, param_name = parts[-2], parts[-1]
        # NOTE(review): the directory is keyed only by layer_name, so keys with
        # different prefixes but the same layer/param name overwrite each other.
        model_layer_dir = os.path.join(model_dir, layer_name)
        mkd(model_layer_dir)

        # 0-d entries (presumably BN's num_batches_tracked counter) were never
        # written. Skip them before choosing a suffix so no stale or unbound
        # suffix is used.
        if v.ndim == 0:
            continue
        if v.ndim == 1:
            v = v.reshape((1, v.shape[0]))
        elif v.ndim > 2:
            # np.savetxt only accepts 1-D/2-D data; use the conv-weight layout.
            v = v.reshape((v.shape[0], -1))

        print(v.dtype)
        if np.issubdtype(v.dtype, np.integer):
            type_postfix, fmt = 'int', '%d'
        elif np.issubdtype(v.dtype, np.floating):
            type_postfix, fmt = 'float', '%.18e'
        else:
            print('skipping {}: unsupported dtype {}'.format(k, v.dtype))
            continue

        param_save_target = os.path.join(model_layer_dir, param_name + '.' + type_postfix)
        np.savetxt(param_save_target, v, delimiter=',', fmt=fmt)


def main():
    """Export the checkpoint at ``path`` into ``MODEL_ROOT_DIR/model_name``.

    Loads the checkpoint's ``'net'`` parameters and binarizes conv weights
    (folding their scale into BN). It converts BN running variance to a
    standard deviation, then writes one text file per parameter.
    """
    model_dir = os.path.join(MODEL_ROOT_DIR, model_name)
    mkd(model_dir)

    param_dict = _load_params(path)
    _transform_params(param_dict)
    _save_params(param_dict, model_dir)

if __name__ == "__main__":
    # Run the export only when executed as a script, never on import.
    main()
