import argparse
import os
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
import torch
import math
import yaml

def add_to_config(mydict, cfl):
    """Append ``mydict`` as YAML to the file at ``cfl``.

    The file is opened in append mode, so existing content is kept and the
    new document is written after it with a 4-space indent. A confirmation
    line is printed once the dump has returned.

    Args:
        mydict: Object to serialise with ``yaml.dump``.
        cfl: Path of the config file to append to.
    """
    with open(cfl, 'a') as stream:
        yaml.dump(mydict, stream, indent=4)
        print("Write successful")

def load_config(file_path):
    """Read the YAML file at ``file_path`` and return the parsed content.

    Parsing uses ``yaml.safe_load``.
    """
    with open(file_path, "r") as handle:
        parsed = yaml.safe_load(handle)
    return parsed

def set_state_dict(std, weights):
    """Overwrite the entries of a state dict with slices of a flat vector.

    Entries are visited in the mapping's iteration order. Each one consumes
    as many consecutive elements of ``weights`` as it holds and receives them
    reshaped to its own shape. Keys ending in ``num_batches_tracked`` are
    left untouched and consume nothing.

    Args:
        std: Mapping from names to values exposing ``.shape`` (for example
            the result of ``model.state_dict()``). Modified in place.
        weights: Flat vector that supports slicing and ``reshape``.

    Returns:
        The same ``std`` object after modification.
    """
    st = 0
    for params in std:
        if params.endswith('num_batches_tracked'):
            continue
        shape = std[params].shape
        # np.prod of an empty shape is the float 1.0, which is not a valid
        # slice bound; casting to int keeps 0-dim entries working.
        ed = st + int(np.prod(shape))
        std[params] = weights[st:ed].reshape(shape)
        st = ed
    return std

def gets_weights(std):
    """Flatten selected entries of a state dict into one 1-D tensor.

    Entries are skipped when their name ends with ``num_batches_tracked`` or
    contains ``mean`` or ``var``. The remaining entries are flattened and
    concatenated in the mapping's iteration order.

    Args:
        std: Mapping from names to tensors.

    Returns:
        A single tensor holding all kept values.
    """
    def _keep(name):
        # Same filter as the setters that skip mean/var entries.
        if name.endswith('num_batches_tracked'):
            return False
        return 'mean' not in name and 'var' not in name

    flat_parts = [std[name].reshape(-1) for name in std if _keep(name)]
    return torch.cat(flat_parts, -1)


def get_model_norm(std):
    """Min-max normalise every entry of a state dict independently.

    Each entry (except names ending in ``num_batches_tracked``) is flattened,
    detached and moved to CPU, then rescaled with its own minimum and
    maximum. When minimum and maximum coincide the exponential form is used
    instead of the division, which avoids dividing by zero.

    Args:
        std: Mapping from names to tensors.

    Returns:
        A tuple ``(weights, normalize_factor, layers_index)``:
        ``weights`` is the list of normalised flat tensors,
        ``normalize_factor`` maps each name to ``[x_max, x_min]``, and
        ``layers_index`` maps each name to its ``idx_start``/``idx_end``
        offsets in the concatenation of all flattened entries.
    """
    normalized = []
    bounds = {}
    spans = {}
    cursor = 0
    for name, tensor in std.items():
        if name.endswith('num_batches_tracked'):
            continue
        flat = tensor.reshape(-1).detach().cpu()
        stop = cursor + flat.shape[-1]
        spans[name] = {'idx_start': cursor, 'idx_end': stop}
        lo = flat.min()
        hi = flat.max()
        width = hi - lo
        if width == 0.0:
            # Constant entry: the exponential form stands in for the division.
            scaled = torch.exp(flat - lo) / torch.exp(width)
        else:
            scaled = (flat - lo) / width
        normalized.append(scaled)
        bounds[name] = [hi, lo]
        cursor = stop
    return normalized, bounds, spans


def set_norm(model, weights, normalize_factor):
    """Load min-max normalised weights into ``model`` after de-normalising.

    State-dict entries are visited in order (names ending in
    ``num_batches_tracked`` are skipped). Each one takes the next
    ``numel()`` values of ``weights``, reshaped to its shape and mapped back
    with ``value * (x_max - x_min) + x_min``.

    Args:
        model: Module exposing ``state_dict()`` / ``load_state_dict()``.
        weights: Flat tensor of normalised values, laid out in state-dict
            order.
        normalize_factor: Mapping from entry name to ``[x_max, x_min]``.

    Raises:
        KeyError: If ``normalize_factor`` lacks a visited entry name.
    """
    std = model.state_dict()
    st = 0
    for param in std:
        if param.endswith('num_batches_tracked'):
            continue
        target = std[param]
        # numel() is always an int, so 0-dim entries get a valid slice bound.
        ed = st + target.numel()
        x_max, x_min = normalize_factor[param][0], normalize_factor[param][1]
        std[param] = weights[st:ed].reshape(target.shape) * (x_max - x_min) + x_min
        st = ed
    # Load once: reloading inside the loop re-copies every tensor each time.
    model.load_state_dict(std)



def get_weights(std):
    """Flatten a state dict into one 1-D tensor.

    Every entry whose name does not end with ``num_batches_tracked`` is
    flattened; the pieces are concatenated in the mapping's iteration order.

    Args:
        std: Mapping from names to tensors.

    Returns:
        A single tensor containing all kept values.
    """
    flat_parts = [
        tensor.reshape(-1)
        for name, tensor in std.items()
        if not name.endswith('num_batches_tracked')
    ]
    return torch.cat(flat_parts, -1)

def set_mymodel_weights(model, weights):
    """Load a flat weight vector into ``model``, skipping mean/var entries.

    State-dict entries are visited in order. Names ending in
    ``num_batches_tracked`` and names containing ``mean`` or ``var`` keep
    their current value and consume nothing. Every other entry takes the
    next ``numel()`` values of ``weights``, reshaped to its shape.

    Args:
        model: Module exposing ``state_dict()`` / ``load_state_dict()``.
        weights: Flat tensor laid out in state-dict order.
    """
    std = model.state_dict()
    st = 0
    for params in std:
        if params.endswith('num_batches_tracked'):
            continue
        if 'mean' in params or 'var' in params:
            continue
        target = std[params]
        # numel() is always an int, so 0-dim entries get a valid slice bound.
        ed = st + target.numel()
        std[params] = weights[st:ed].reshape(target.shape)
        st = ed
    # Load once: reloading inside the loop re-copies every tensor each time.
    model.load_state_dict(std)

def set_model_weights(model, weights):
    """Load a flat weight vector into ``model`` and return the model.

    State-dict entries are visited in order. Names ending in
    ``num_batches_tracked``, ``running_var`` or ``running_mean`` keep their
    current value and consume nothing. Every other entry takes the next
    ``numel()`` values of ``weights``, reshaped to its shape.

    Args:
        model: Module exposing ``state_dict()`` / ``load_state_dict()``.
        weights: Flat tensor laid out in state-dict order.

    Returns:
        The same ``model`` object, updated in place.
    """
    skipped_suffixes = ('num_batches_tracked', 'running_var', 'running_mean')
    std = model.state_dict()
    st = 0
    for params in std:
        if params.endswith(skipped_suffixes):
            continue
        target = std[params]
        # numel() is always an int, so 0-dim entries get a valid slice bound.
        ed = st + target.numel()
        std[params] = weights[st:ed].reshape(target.shape)
        st = ed
    # Load once: reloading inside the loop re-copies every tensor each time.
    model.load_state_dict(std)
    return model


def set_bnmodel_weights(model, weights, cls=20):
    """Load a flat weight vector into ``model`` and reset mean/var entries.

    State-dict entries are visited in order (names ending in
    ``num_batches_tracked`` are skipped):

    * names containing ``mean`` are set to zeros,
    * names containing ``var`` are set to ones,
    * every other entry takes the next ``numel()`` values of ``weights``,
      reshaped to its shape.

    The state dict is loaded once after the loop. Loading only in the
    weight branch would drop the resets of any mean/var entries that come
    after the last weight entry, such as a model ending in BatchNorm.

    Args:
        model: Module exposing ``state_dict()`` / ``load_state_dict()``.
        weights: Flat tensor laid out in state-dict order, without values
            for the mean/var entries.
        cls: Unused; kept so existing callers keep working.
    """
    std = model.state_dict()
    st = 0
    for params in std:
        if params.endswith('num_batches_tracked'):
            continue
        target = std[params]
        if 'mean' in params:
            # zeros_like also clears NaN/inf, which `x * 0.0` would keep.
            std[params] = torch.zeros_like(target)
        elif 'var' in params:
            std[params] = torch.ones_like(target)
        else:
            ed = st + target.numel()
            std[params] = weights[st:ed].reshape(target.shape)
            st = ed
    model.load_state_dict(std)


def set_lnmodel_weights(model, weights, cls=20):
    """Load flat weights into ``model``, resetting mean/var and skipping linear.

    State-dict entries are visited in order (names ending in
    ``num_batches_tracked`` are skipped). The checks apply in this order:

    * names containing ``mean`` are set to zeros,
    * names containing ``var`` are set to ones,
    * names containing ``linear`` keep their value and consume nothing,
    * every other entry takes the next ``numel()`` values of ``weights``,
      reshaped to its shape.

    The state dict is loaded once after the loop. Loading only in the
    weight branch would drop the resets of any mean/var entries that come
    after the last weight entry.

    Args:
        model: Module exposing ``state_dict()`` / ``load_state_dict()``.
        weights: Flat tensor laid out in state-dict order, covering only the
            entries that are actually overwritten.
        cls: Unused; kept so existing callers keep working.
    """
    std = model.state_dict()
    st = 0
    for params in std:
        if params.endswith('num_batches_tracked'):
            continue
        target = std[params]
        if 'mean' in params:
            # zeros_like also clears NaN/inf, which `x * 0.0` would keep.
            std[params] = torch.zeros_like(target)
        elif 'var' in params:
            std[params] = torch.ones_like(target)
        elif 'linear' in params:
            continue
        else:
            ed = st + target.numel()
            std[params] = weights[st:ed].reshape(target.shape)
            st = ed
    model.load_state_dict(std)


def set_fbnmodel_weights(model, weights):
    """Load as many flat weights as fit into ``model`` and reset mean/var.

    State-dict entries are visited in order (names ending in
    ``num_batches_tracked`` are skipped):

    * names containing ``mean`` are set to zeros,
    * names containing ``var`` are set to ones,
    * every other entry is overwritten only if at least ``numel()`` values
      remain in ``weights``; otherwise it keeps its value and the read
      position does not advance.

    The state dict is loaded once after the loop. Loading only when a
    weight entry is written would drop the resets of any mean/var entries
    that come after the last written entry.

    Args:
        model: Module exposing ``state_dict()`` / ``load_state_dict()``.
        weights: Flat tensor laid out in state-dict order; it may be shorter
            than the model needs.
    """
    std = model.state_dict()
    st = 0
    for params in std:
        if params.endswith('num_batches_tracked'):
            continue
        target = std[params]
        if 'mean' in params:
            # zeros_like also clears NaN/inf, which `x * 0.0` would keep.
            std[params] = torch.zeros_like(target)
        elif 'var' in params:
            std[params] = torch.ones_like(target)
        else:
            n = target.numel()
            if len(weights[st:]) >= n:
                ed = st + n
                std[params] = weights[st:ed].reshape(target.shape)
                st = ed
    model.load_state_dict(std)


def set_weights(model, weights):
    """Load a flat weight vector into every entry of ``model``'s state dict.

    State-dict entries are visited in order (names ending in
    ``num_batches_tracked`` are skipped). Each one takes the next
    ``numel()`` values of ``weights``, reshaped to its shape.

    Args:
        model: Module exposing ``state_dict()`` / ``load_state_dict()``.
        weights: Flat tensor laid out in state-dict order.
    """
    std = model.state_dict()
    st = 0
    for params in std:
        if params.endswith('num_batches_tracked'):
            continue
        target = std[params]
        # numel() is always an int, so 0-dim entries get a valid slice bound.
        ed = st + target.numel()
        std[params] = weights[st:ed].reshape(target.shape)
        st = ed
    # Load once: reloading inside the loop re-copies every tensor each time.
    model.load_state_dict(std)


def set_theweights(model, weights):
    """Load a possibly too-short flat weight vector into ``model``.

    Every state-dict entry is visited in order, including names ending in
    ``num_batches_tracked``. An entry that can be filled completely takes
    the next ``numel()`` values of ``weights``, reshaped to its shape. When
    fewer values remain, the leftover values overwrite the leading elements
    of the flattened entry and the rest of it is kept. Entries reached after
    ``weights`` is exhausted are left as they are.

    Args:
        model: Module exposing ``state_dict()`` / ``load_state_dict()``.
        weights: Flat tensor laid out in state-dict order (assumed 1-D).
    """
    std = model.state_dict()
    st = 0
    for params in std:
        target = std[params]
        count = target.numel()
        ed = st + count
        chunk = weights[st:ed]
        if chunk.numel() == count:
            std[params] = chunk.reshape(target.shape)
        else:
            # Explicit size check instead of a bare `except`, so unrelated
            # errors are no longer swallowed. Work on a copy so the model is
            # only changed by the final load.
            flat = target.detach().clone().reshape(-1)
            flat[:chunk.numel()] = chunk
            std[params] = flat.reshape(target.shape)
        st = ed
    # Load once: reloading inside the loop re-copies every tensor each time.
    model.load_state_dict(std)

def set_fweights(model, weights):
    """Load a flat weight vector into ``model``, leaving ``fc3*`` untouched.

    Every state-dict entry whose name does not start with ``fc3`` is visited
    in order, including names ending in ``num_batches_tracked``. Each one
    takes the next ``numel()`` values of ``weights``, reshaped to its shape.
    Entries starting with ``fc3`` keep their value and consume nothing.

    Args:
        model: Module exposing ``state_dict()`` / ``load_state_dict()``.
        weights: Flat tensor laid out in state-dict order, without values
            for the ``fc3*`` entries.
    """
    std = model.state_dict()
    st = 0
    for params in std:
        if params.startswith('fc3'):
            continue
        target = std[params]
        # numel() is always an int; np.prod of an empty shape is the float
        # 1.0, which is not a valid slice bound for 0-dim entries such as
        # num_batches_tracked.
        ed = st + target.numel()
        std[params] = weights[st:ed].reshape(target.shape)
        st = ed
    # Load once: reloading inside the loop re-copies every tensor each time.
    model.load_state_dict(std)


def vecpadder(x, max_in=3728761 * 3):
    """Pad ``x`` at the end of its last dimension up to ``max_in``.

    The amount of padding is ``max_in - x.shape[0]`` and is handed straight
    to ``F.pad``.

    Args:
        x: Tensor to pad.
        max_in: Target length.

    Returns:
        The padded tensor.
    """
    missing = max_in - x.shape[0]
    return F.pad(x, (0, missing))


def get_model_normweights(model):
    """Min-max normalise every entry of ``model``'s state dict independently.

    Each entry (except names ending in ``num_batches_tracked``) is flattened
    and rescaled with ``(w - x_min) / (x_max - x_min)``. A constant entry
    (``x_max == x_min``) is returned as all ones rather than dividing by
    zero, matching what ``get_model_norm`` produces. It still de-normalises
    correctly, because ``1 * (x_max - x_min) + x_min`` equals ``x_min``.

    Args:
        model: Module exposing ``state_dict()``.

    Returns:
        A tuple ``(weights, normalize_factor)``: the list of normalised flat
        tensors in state-dict order, and a dict mapping each entry name to
        ``[x_max, x_min]`` as Python numbers.
    """
    weights = []
    normalize_factor = {}
    std = model.state_dict()
    for param in std:
        if param.endswith('num_batches_tracked'):
            continue
        w = std[param].reshape(-1)
        x_min = w.min().item()
        x_max = w.max().item()
        span = x_max - x_min
        if span == 0:
            # 0/0 would fill the entry with NaN (e.g. a zero-initialised bias).
            w = torch.ones_like(w)
        else:
            w = (w - x_min) / span
        weights.append(w)
        normalize_factor[param] = [x_max, x_min]
    return weights, normalize_factor


def set_normweights(model, weights, normalize_factor):
    """Load min-max normalised weights into ``model`` after de-normalising.

    State-dict entries are visited in order (names ending in
    ``num_batches_tracked`` are skipped). Each one takes the next
    ``numel()`` values of ``weights``, reshaped to its shape and mapped back
    with ``value * (x_max - x_min) + x_min``.

    Args:
        model: Module exposing ``state_dict()`` / ``load_state_dict()``.
        weights: Flat tensor of normalised values, laid out in state-dict
            order.
        normalize_factor: Mapping from entry name to ``[x_max, x_min]``.

    Raises:
        KeyError: If ``normalize_factor`` lacks a visited entry name.
    """
    std = model.state_dict()
    st = 0
    for param in std:
        if param.endswith('num_batches_tracked'):
            continue
        target = std[param]
        # numel() is always an int, so 0-dim entries get a valid slice bound.
        ed = st + target.numel()
        x_max, x_min = normalize_factor[param][0], normalize_factor[param][1]
        std[param] = weights[st:ed].reshape(target.shape) * (x_max - x_min) + x_min
        st = ed
    # Load once: reloading inside the loop re-copies every tensor each time.
    model.load_state_dict(std)

# def pad_to_chunk_multiple(x, chunk_size):
#     shape = x.shape
#     max_in = chunk_size*math.ceil(shape[0]/chunk_size)
#     delta1 = max_in - shape[0]
#     x = F.pad(x, (0, delta1))
#     return x
def pad_to_chunk_multiple(x, chunk_size):
    """Zero-pad ``x`` so that ``x.shape[1]`` becomes a multiple of ``chunk_size``.

    An input with fewer than two dimensions first gets a leading dimension
    of size 1. Padding is appended at the end of the last dimension; the
    second-to-last dimension is not padded.

    Args:
        x: Tensor to pad.
        chunk_size: Chunk length the size of dimension 1 is rounded up to.

    Returns:
        The padded tensor (at least 2-D).
    """
    if x.dim() < 2:
        x = x.unsqueeze(0)
    width = x.shape[1]
    target = chunk_size * math.ceil(width / chunk_size)
    return F.pad(x, (0, target - width, 0, 0), "constant", 0)

def matpadder(x, max_in=512):
    """Zero-pad the last dimension of ``x`` by ``max_in - x.shape[1]``.

    The second-to-last dimension is not padded.

    Args:
        x: Tensor with at least two dimensions.
        max_in: Target size used to compute the padding amount.

    Returns:
        The padded tensor.
    """
    extra_cols = max_in - x.shape[1]
    return F.pad(x, (0, extra_cols, 0, 0), "constant", 0)