import json
from typing import Union

import torch

from .torch_utils import AVAILABLE_OPTIMIZERS, AVAILABLE_LOSSES, NORM_FUNCTIONS, CHANNEL_WISE_LAYERS, CROSS_CHANNEL_LAYERS, NORMALIZE_LAYERS, name_id_dict,get_norm_vector, AVAILABLE_SCHEDULERS,OPTIM_STATE_KEYS, CrossEntropyLabelSmooth
from .numpy_utils import search_threshold, search_threshold_ratio, search_threshold_GMM
def load_json(path):
    """Read the file at ``path`` and return its parsed JSON content.

    The file is decoded with the ``utf-8-sig`` codec, so a leading UTF-8
    byte-order mark is skipped when present.
    """
    with open(path, 'r', encoding='utf-8-sig') as stream:
        raw_text = stream.read()
    return json.loads(raw_text)
def write_json(res, path):
    """Serialise ``res`` as JSON and write it to ``path``.

    The output is indented by four spaces, keeps non-ASCII characters
    unescaped, and is encoded with ``utf-8-sig`` (the file starts with a
    UTF-8 byte-order mark).

    Raises:
        TypeError / ValueError: propagated from ``json.dumps`` when ``res``
            cannot be serialised. In that case ``path`` is left untouched.
    """
    # Serialise before opening the file: mode 'w' truncates on open, so
    # encoding first keeps an existing file intact if ``res`` is invalid.
    text = json.dumps(res, indent=4, ensure_ascii=False)
    with open(path, 'w', encoding='utf-8-sig') as f:
        f.write(text)
def serializable(obj):
    """Return ``True`` if ``obj`` can be encoded by ``json.dumps``.

    When encoding fails a short message is printed and ``False`` is
    returned instead of propagating the error.
    """
    try:
        json.dumps(obj)
    except (TypeError, ValueError):
        # TypeError: unsupported type (e.g. a set).
        # ValueError: circular reference detected by the encoder.
        print("Unable to serialize the object")
        return False
    return True
def search_epsilon(bypass_layer_group:dict,epsilon,epsion_kwargs=None):
    """Resolve ``epsilon`` for a layer group and split its indices by it.

    The reference vector and reference type are obtained from
    ``gather_refs``; the threshold and the index split are then produced by
    ``compute_eps``, whose return value is passed through unchanged.

    Args:
        bypass_layer_group: mapping handed to ``gather_refs``.
        epsilon: numeric threshold or string specification.
        epsion_kwargs: optional keyword arguments forwarded to
            ``compute_eps``; ``None`` means no extra arguments.
    """
    if epsion_kwargs is None:
        epsion_kwargs = {}
    resolved_eps, norms, kind = gather_refs(bypass_layer_group, epsilon)
    return compute_eps(resolved_eps, norms, kind, epsion_kwargs)
def gather_refs(bypass_layer_group:dict,epsilon):
    """Select the reference vector that ``epsilon`` is compared against.

    The choice is driven by the two-character prefix of a string
    ``epsilon``:

    * ``W_`` - ``get_norm_vector`` of the first ``'W'`` layer's weight.
    * ``A_`` - ``get_norm_vector(..., axis=1)`` of the first ``'A'`` layer's
      weight.
    * ``C_`` - norm vector of ``bypass_layer_group['D'].delta`` divided by
      the norm vector of the first ``'W'`` layer's weight.
    * anything else (including non-string values) - norm vector of
      ``bypass_layer_group['D'].delta``; a string without a ``D_`` prefix
      gets one prepended.

    Returns:
        ``(epsilon, ref_vector, ref_type)`` where ``ref_type`` is one of
        ``'W'``, ``'A'``, ``'C'`` or ``'D'``.
    """
    prefix = epsilon[:2] if isinstance(epsilon, str) else None

    if prefix == 'W_':
        W_layer = bypass_layer_group['W'][0]
        return epsilon, get_norm_vector(W_layer.weight), 'W'

    if prefix == 'A_':
        A_layer = bypass_layer_group['A'][0]
        return epsilon, get_norm_vector(A_layer.weight, axis=1), 'A'

    if prefix == 'C_':
        W_layer = bypass_layer_group['W'][0]
        W_norm = get_norm_vector(W_layer.weight)
        D_activ = bypass_layer_group['D']
        D_norm = get_norm_vector(D_activ.delta)
        return epsilon, D_norm / W_norm, 'C'

    D_activ = bypass_layer_group['D']
    ref_vector = get_norm_vector(D_activ.delta)
    # Only add the prefix when it is missing: 'D_0.5' would otherwise
    # become 'D_D_0.5', which can no longer be parsed as a number once the
    # first two characters are stripped downstream.
    if isinstance(epsilon, str) and prefix != 'D_':
        epsilon = 'D_' + epsilon
    return epsilon, ref_vector, 'D'

def compute_eps(epsilon,ref_vector,ref_type,eps_kwargs):
    """Resolve ``epsilon`` to a threshold and split ``ref_vector`` by it.

    Args:
        epsilon: either a numeric threshold (used as-is) or a string whose
            first two characters are a reference-type prefix, followed by

            * a plain non-negative decimal number (``'D_0.5'``),
            * ``...auto``   -> ``search_threshold``,
            * ``...ratio<p>`` with ``p`` in [0, 1] -> ``search_threshold_ratio``
              called with ``1 - p``,
            * ``...GMM``    -> ``search_threshold_GMM``.
        ref_vector: 1-D torch tensor of reference values.
        ref_type: ``'C'`` / ``'D'`` prune entries ``>= epsilon``; any other
            value prunes entries ``< epsilon``.
        eps_kwargs: keyword arguments forwarded to the search helpers.

    Returns:
        ``(ref_type, prune_indices, preserve_indices, epsilon)`` with the
        threshold that was finally applied.

    Raises:
        ValueError: for an unrecognised string specification, or when a
            ``'C'``/``'D'`` split preserves nothing and the threshold cannot
            be escalated (empty vector, or a threshold that is not a
            positive finite number).
    """
    if isinstance(epsilon, str):
        spec = epsilon
        # Drop the two-character reference-type prefix (e.g. 'W_').
        literal = spec[2:]
        if literal.replace('.', '', 1).isdigit():
            epsilon = float(literal)
        elif spec.endswith('auto'):
            epsilon = search_threshold(ref_vector.cpu().detach().numpy(), **eps_kwargs)
        elif 'ratio' in spec:
            prune_ratio = float(spec.split('ratio')[-1])
            assert 0 <= prune_ratio <= 1, f'pruning ratio must lie in [0,1]. Current: {prune_ratio}'
            epsilon = search_threshold_ratio(ref_vector.cpu().detach().numpy(), 1 - prune_ratio, **eps_kwargs)
        elif 'GMM' in spec:
            epsilon = search_threshold_GMM(ref_vector.cpu().detach().numpy(), **eps_kwargs)
        else:
            raise ValueError(f'unrecognised epsilon specification: {spec!r}')
    # Any non-string value (int, float, numpy scalar, ...) is treated as an
    # already-resolved numeric threshold.

    if ref_type in ('C', 'D'):
        prune_indices = torch.where(ref_vector >= epsilon)[0]
        preserve_indices = torch.where(ref_vector < epsilon)[0]
        # Nothing preserved: raise the threshold tenfold until something
        # falls below it.
        while len(preserve_indices) == 0:
            # Scaling only helps for a positive, finite threshold on a
            # non-empty vector; otherwise the loop could never terminate.
            if ref_vector.numel() == 0 or not 0 < float(epsilon) < float('inf'):
                raise ValueError(
                    f'cannot find a threshold that preserves any index '
                    f'(epsilon={epsilon!r}, ref_type={ref_type!r})'
                )
            epsilon = epsilon * 10
            prune_indices = torch.where(ref_vector >= epsilon)[0]
            preserve_indices = torch.where(ref_vector < epsilon)[0]
    else:
        prune_indices = torch.where(ref_vector < epsilon)[0]
        preserve_indices = torch.where(ref_vector >= epsilon)[0]
        if len(preserve_indices) == 0:
            # NOTE(review): the fallback scales the threshold by 100 and
            # flips the comparison direction relative to the first pass, so
            # for a positive threshold everything ends up preserved. Kept
            # as-is to preserve behaviour - confirm this is intended.
            epsilon = epsilon * 100
            prune_indices = torch.where(ref_vector >= epsilon)[0]
            preserve_indices = torch.where(ref_vector < epsilon)[0]
    return ref_type, prune_indices, preserve_indices, epsilon

class GammaSchedule():
    """Gamma value computed from a formula string in the variable ``t``.

    The formula is a Python expression such as ``'0.9**t'`` or
    ``'math.exp(-0.1*t)'``. Calling the schedule substitutes the step number
    for every standalone ``t`` and evaluates the expression. ``np`` and
    ``math`` are made available to the expression when the formula
    mentions ``np.`` / ``math.``.

    SECURITY: the formula is executed with ``eval``; it must only come from
    trusted configuration, never from untrusted input.
    """

    def __init__(self,gamma_formula:str):
        assert isinstance(gamma_formula,str)
        assert 't' in gamma_formula or gamma_formula.isnumeric(), '[ERROR]gamma formula should be given with equation of variable t'
        self.gamma_formula=gamma_formula
        # Internal counter used when __call__ is invoked without a step.
        self.global_step=0

    def _evaluate(self, step):
        """Evaluate the formula with ``step`` substituted for ``t``."""
        import re

        # Replace only standalone ``t`` tokens so names that merely contain
        # the letter (``math.sqrt``, ``np.tanh``, ``int``) stay intact.
        expression = re.sub(r'\bt\b', str(step), self.gamma_formula)

        # Build the evaluation scope explicitly: module globals plus the
        # optional numeric modules the formula refers to.
        scope = dict(globals())
        if 'np.' in self.gamma_formula:
            import numpy as np
            scope['np'] = np
        if 'math.' in self.gamma_formula:
            import math
            scope['math'] = math
        return eval(expression, scope)

    def __call__(self,step:Union[None,int]=None):
        """Return the gamma value for ``step``.

        With ``step=None`` the internal counter is used and then advanced by
        one; an explicit ``step`` leaves the counter untouched.
        """
        if step is not None:
            return self._evaluate(step)
        ret = self._evaluate(self.global_step)
        self.global_step+=1
        return ret

    def __str__(self):
        return self.gamma_formula

    def __repr__(self):
        return f'GammaSchedule({self.gamma_formula})'
