import copy
import datetime
import os
import random
import subprocess
from collections import OrderedDict
from pathlib import Path
from types import SimpleNamespace
from typing import Union, List, Dict

#from bypass.core.Bypass import AVAILABLE_MIXIN
import bypass.core.models as bp_models
from bypass.utils import load_json, write_json, serializable

# Folder next to this module holding the JSON presets that
# BypassConfig.from_preset reads and BypassConfig.save writes by default.
PRESET_ROOT = Path(__file__).parent.joinpath('presets')
class BypassConfig(SimpleNamespace):
    """Experiment configuration for bypass training runs.

    Every class attribute below is a default. Any of them can be overridden
    per instance with keyword arguments (``BypassConfig(model_name=...)``) or
    by plain attribute assignment. ``save`` serializes every non-dunder
    attribute, including the ``git_hash``, ``summary_path`` and ``ckpt_path``
    properties, to JSON. ``from_preset`` rebuilds an instance from such a file.

    NOTE: ``dateinfo`` and the three seeds are evaluated once, when this
    module is imported. Every instance created in the same process therefore
    shares them unless they are overridden. (``from_preset`` refreshes
    ``dateinfo`` by default.)
    """

    # Instance specifier
    # Default output root: <grandparent of this module's directory>/save
    saveroot=Path(__file__).parents[2]/'save'
    dateinfo=datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    # random seeds (drawn once at import time)
    torch_seed=random.randint(0,1000)
    np_seed=random.randint(0,1000)
    random_seed=random.randint(0,1000)

    # model configs
    bypass_type:str ='bypass'
    model_build_verbose:bool=True

    model_name='cifar10_3c3d'
    model_args:dict={}

    delta_init:str = 'ConstantMultipleNorm(1,norm_p=2)'
    large_group_delta_c0:float = 0.5
    # Dataset information
    dataset_name:str = 'cifar10'
    input_shape=[3,32,32]
    output_dims=10
    workers:int =1
    # training configs

    ## Common

    # batch size
    train_batch_size:int=128
    eval_batch_size:int=128
    accumulation_steps:int = 1

    # loss and metric
    metrics:Union[str,List[str]] =["accuracy", "categorical_crossentropy"]
    loss:str = "categorical_crossentropy"
    loss_args:dict={}
    mixup_args:Union[None,dict] = None # {'mixup_preset':'snp'}

    # ADW_loss control
    adw_norm_type:str = 'l2mean' # 'l1mean','l2','l2mean','mse'

    # optimizer configs
    optimizer_type:str = "adam"
    optimizer_args:dict={'weight_decay': 0.002}
    optimizer_norm_args:dict={'weight_decay': 0}
    optimizer_bias_args:dict={'weight_decay': 0}
    init_momentum:bool = False # If True, initialize optimizer momentum when phase changes
    optimizer_delta_args:dict={'weight_decay':0}

    ## Train1
    train1_epoch:int = 1500
    train1_lr:float=0.002
    train1_scheduler_args:Union[dict,None] = None
    ## Train2
    opt1_epoch:int = 1850
    opt1_lr:float = 0.002
    opt1_scheduler_args:Union[dict,None] = None
    opt1_preserve_ratio:Union[float,None] = None
    opt1_force_prune_end:bool = True
    opt1_early_prune:Union[bool,List[int]] = False

    opt2_lr:float = 0.002
    gamma:Union[str,List[str],Dict[str,str]] = '0.03*t'
    opt2_gamma:Union[str,None,List[str],Dict[str,str]] = None
    opt2_max:int  = 3500
    opt2_prune:Union[bool,None]=False
    opt2_scheduler_args:Union[dict,None] = None
    opt2_early_prune:Union[bool,List[int]] = False
    ## Projection
    epsilon:Union[float,List[float],Dict[str,float]] = 2e-6 # bypassing epsilon
    opt2_epsilon:Union[float,None,List[float],Dict[str,float]]=None
    prune_epsilon:Union[None,str,float,List[Union[str,float]],Dict[str,float]] = None # pruning epsilon
    prune_epsilon_args:Union[None,dict] = None

    ## Train3
    train3_epoch:int = 4000
    train3_lr:float = 0.002
    train3_scheduler_args:Union[dict,None] = {'milestones':[],'gamma':0.1}

    # Set by from_preset when loading from an explicit file path; when not
    # None, run directories are placed directly under it (see _run_dir).
    _loaded_saveroot:Union[None,os.PathLike]=None

    # checkpointing
    save_period_epoch:Union[None,int] = None
    save_train3_best:Union[str,None] = 'acc'
    wandb_project:Union[str,None] = None
    workstation:str = 'aimlk'

    def __init__(self, *args, **kwargs):
        """Create a config; keyword arguments override the class defaults."""
        super().__init__(*args, **kwargs)
        # Class-level dict/list defaults would otherwise be one object shared
        # by every instance and by the class itself. Then e.g.
        # ``cfg.optimizer_args['momentum'] = 0.9`` would silently change the
        # default for all later configs. Give each instance a private copy of
        # every mutable default it did not receive explicitly.
        cls = type(self)
        for name in dir(cls):
            if name.startswith('__') or name in self.__dict__:
                continue
            default = getattr(cls, name)
            if isinstance(default, (dict, list)):
                setattr(self, name, copy.deepcopy(default))

    @classmethod
    def from_preset(cls,name_or_path='cifar10_3c3d',reset_dateinfo=True,**config_modif):
        """Build a config from a saved JSON preset.

        ``name_or_path`` is first tried as a preset name, i.e.
        ``PRESET_ROOT/<name_or_path>.json``. If loading that raises
        ``FileNotFoundError``, it is treated as a path to a JSON file.
        In that case ``_loaded_saveroot`` is set to the file's grandparent
        directory, so ``summary_path``/``ckpt_path`` resolve under it.

        Args:
            name_or_path: preset name or path to a config JSON file.
            reset_dateinfo: if True, replace the stored ``dateinfo`` with the
                current time.
            **config_modif: overrides applied on top of the loaded values.

        Raises:
            FileNotFoundError: presumably, when neither the preset nor the
                path can be loaded (depends on ``load_json``).
        """
        try:
            loaded_dict = load_json(PRESET_ROOT/f'{name_or_path}.json')
        except FileNotFoundError:
            preset_path = Path(name_or_path)
            loaded_dict = load_json(preset_path)
            loaded_dict['_loaded_saveroot'] = preset_path.parents[1]
        if reset_dateinfo:
            loaded_dict['dateinfo'] = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        loaded_dict.update(config_modif)
        # NOTE: files written by save() also carry 'git_hash', 'summary_path'
        # and 'ckpt_path'. They land in the instance dict, but the read-only
        # properties of the same names take precedence on attribute access.
        return cls(**loaded_dict)

    def gather_items(self):
        """Collect the JSON-serializable configuration values.

        Walks every attribute listed by ``dir(self)``: class defaults,
        instance overrides and properties (evaluating ``git_hash``, which runs
        git). Dunder names and methods are skipped, path-like values are
        converted to ``str``, and values rejected by ``serializable`` are
        dropped.

        Returns:
            dict mapping attribute name to value, in ``dir`` (sorted) order.
        """
        ret = {}
        for name in dir(self):
            if name.startswith('__'):
                continue
            value = getattr(self, name)
            if callable(value):
                continue
            if isinstance(value, os.PathLike):
                ret[name] = str(value)
            elif serializable(value):
                ret[name] = value
        return ret

    def save(self,savepath=None):
        """Write the configuration to a JSON file.

        Args:
            savepath: destination file. It defaults to
                ``PRESET_ROOT/<model_name>.json``, which makes it loadable with
                ``from_preset(model_name)``. Missing parent directories are
                created.
        """
        if savepath is None:
            savepath = PRESET_ROOT/f'{self.model_name}.json'
        # 'git_hash' is already included: gather_items reads the property.
        save_items = self.gather_items()
        os.makedirs(Path(savepath).parent, exist_ok=True)
        write_json(save_items, savepath)

    # Disabled together with the AVAILABLE_MIXIN import at the top of the file.
    #def validate(self):
    #    if not self.bypass_type in AVAILABLE_MIXIN.keys():
    #        raise KeyError(f"{self.__class__.__name__}.bypass_)type must be one of {list(AVAILABLE_MIXIN.keys())}")

    #    if not hasattr(bp_models,self.model_name):
    #        raise KeyError(f'there is no model named {self.model_name} in bypass.core.model')

    #    return None

    @property
    def git_hash(self):
        """HEAD commit hash of the repository containing this file.

        Returns '' when git prints nothing (e.g. not inside a repository) or
        when git cannot be executed at all.
        """
        try:
            proc = subprocess.run(
                ['git', 'rev-parse', '--verify', 'HEAD'],
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
                cwd=str(Path(__file__).parent),
                check=False,
            )
        except OSError:
            # git is missing or not executable: treat like "no repository"
            # instead of making save()/gather_items() crash.
            return ''
        return proc.stdout.decode("utf-8").strip('\n')

    def _run_dir(self):
        """Root directory for this run's outputs."""
        if self._loaded_saveroot is None:
            # saveroot is a str after a JSON round trip (gather_items stores
            # path-likes as str), so coerce it back to a Path before joining.
            return Path(self.saveroot)/self.bypass_type/self.model_name/self.dateinfo
        return Path(self._loaded_saveroot)/self.dateinfo

    @property
    def summary_path(self):
        """TensorBoard log directory: ``<run dir>/tb_logs``."""
        return self._run_dir()/'tb_logs'

    @property
    def ckpt_path(self):
        """Checkpoint directory: ``<run dir>/save``."""
        return self._run_dir()/'save'

if __name__ == '__main__':
    # Build the CIFAR-10 ResNet-56 'mild_pruning_W' configuration and write it
    # with BypassConfig.save() (default target: PRESET_ROOT/<model_name>.json).
    configs = BypassConfig(model_name='simplecnn')
    preset_overrides = {
        'bypass_type': 'mild_pruning_W',
        'model_name': 'cifar10_Bypassresnet56',
        'dataset_name': 'benchmark_cifar10',
        'loss_args': {'reduction': 'mean'},
        'adw_norm_type': 'l1mean',
        'optimizer_type': 'sgd',
        'optimizer_args': {'weight_decay': 5e-4, 'momentum': 0.9},
        'train1_epoch': 0,
        'train1_lr': 1e-4,
        'opt1_lr': 1e-4,
        'opt2_lr': 1e-4,
        'opt1_epoch': 1000,
        'opt2_max': 1300,
        'epsilon': 7e-4,
        'prune_epsilon': 1e-2,
        'train3_epoch': 1500,
        'train3_lr': 1e-4,
    }
    # Apply the overrides in the listed order, exactly like sequential assignment.
    for attr_name, attr_value in preset_overrides.items():
        setattr(configs, attr_name, attr_value)
    configs.save()

    print(1)