import os
from typing import List, Union, Dict
from pathlib import Path
from dataclasses import dataclass, field
import warnings
import logging


import torch
from torch_pruning.dependency import Node
import sys
sys.path.append('/workspace/IPPRO_pruning')
from bypass.core.activation import ActivationForBypass, ActivationForActivionChange, _is_channel_wise, ActivationForDx2
from bypass.core.detect import TPDetectorForBypass
from bypass.core.prune_depgraph import prune_tp, BypassActivationPruner
from bypass.utils import GammaSchedule, NORM_FUNCTIONS, load_json, write_json, get_norm_vector
from bypass.configs import BypassConfig

# Directory of bundled blueprint presets, located next to this source file.
BLUEPRINT_PRESET_ROOT = Path(__file__).parent.joinpath('blueprints')

@dataclass
class BlueprintCell:
    '''
    Names of the modules that play the W, D and A roles for one pruning group.

    Every module name must be resolvable with ``model.get_submodule``.
    One cell corresponds to one pruning group.

    Attributes:
        W: module name (or list of names) of the W module(s).
        D: module name of the D module.
        A: module name (or list of names) of the A module(s); may be None.
        dep: dependency description, passed through unchanged to the bypass unit.
        reg_target: list of ``(prune_io, parameter_name)`` pairs resolved with
            ``model.get_parameter``. When None, ``Blueprint`` fills in a
            default built from W.
    '''
    W: Union[str, List[str]]
    D: str
    A: Union[None, str, List[str]]
    dep: Dict[str, Dict[str, List[str]]]
    reg_target: Union[None, list] = None

    ## DEPRECATED:250109 -- the pre_modules / post_modules fields were removed.

    def validate(self, model: torch.nn.Module) -> bool:
        '''
        Check that every name referenced by this cell exists in ``model``.

        A/D/W entries are checked with ``model.get_submodule``. Each may be a
        single name, a list/tuple of names, a dict of names, or None (A only).
        ``reg_target`` entries are ``(prune_io, parameter_name)`` pairs and are
        checked with ``model.get_parameter``.

        Returns:
            True if every name resolves, False otherwise.
        '''
        module_names = []
        for role in ('A', 'D', 'W'):
            names = getattr(self, role)
            if names is None:
                continue  # e.g. a cell without an A module
            if isinstance(names, str):
                module_names.append(names)
            elif isinstance(names, dict):
                module_names.extend(names.values())
            else:
                module_names.extend(names)

        try:
            for module_name in module_names:
                model.get_submodule(module_name)
            for _prune_io, param_name in (self.reg_target or []):
                model.get_parameter(param_name)
        except (AttributeError, TypeError, ValueError):
            # Missing module/parameter, or a malformed entry.
            return False
        return True
class Blueprint:
    '''
    A collection of BlueprintCells keyed by group id.

    Cells may be given as BlueprintCell instances or as keyword dicts for
    BlueprintCell. A cell without ``reg_target`` gets a default one: the
    ``weight`` parameter of each of its W modules, with ``prune_io=0``.
    '''
    def __init__(self, blueprint_dict):
        self.bypass_layers = {}
        for k, item in blueprint_dict.items():
            cell = item if isinstance(item, BlueprintCell) else BlueprintCell(**item)
            if cell.reg_target is None:  # no reg_target given: use W's weight(s) only
                cell.reg_target = self._default_reg_target(cell.W)
            self.bypass_layers[k] = cell

    @staticmethod
    def _default_reg_target(W):
        '''Return ``[(0, '<name>.weight'), ...]`` for a W name, a list/tuple of names, or a dict of names.'''
        # Cells hold module *names*, so the default must also use parameter names.
        # These names are resolved later with model.get_parameter. Accessing
        # ``.weight`` on the str itself raised AttributeError.
        if isinstance(W, dict):
            names = list(W.values())
        elif isinstance(W, (tuple, list)):
            names = list(W)
        else:
            names = [W]
        return [(0, f'{name}.weight') for name in names]

    def validate(self, model: torch.nn.Module):
        '''
        Check that every cell's modules are submodules of ``model``.

        Raises:
            ValueError: naming the first inconsistent group.
        '''
        for k, cell in self.bypass_layers.items():
            if not cell.validate(model):
                raise ValueError(f'Model and Blueprint inconsistent: group {k}')
        return True

    @classmethod
    def load_from(cls, blueprint_path: os.PathLike):
        '''Load a blueprint JSON; numeric string keys are converted back to int.'''
        loaded_dict = {(int(k) if k.isnumeric() else k): v for k, v in load_json(blueprint_path).items()}
        return cls(loaded_dict)

    def save(self, savepath: os.PathLike):
        '''Write every cell's fields to ``savepath`` as JSON.'''
        to_save = {k: cell.__dict__ for k, cell in self.bypass_layers.items()}
        write_json(to_save, savepath)
    
class BaseBypassUnit(torch.nn.Module):
    '''
    One pruning group: a W module, a D module exposing a per-channel ``delta``,
    and an optional A module.

    ``reg_target`` holds ``(prune_io, param_name, param)`` entries (see
    ``from_blueprint`` / ``re_init``). ``status`` selects the training stage:
    0 = train1/train3, 1 = opt1, 2 = opt2.
    '''
    bypassing_status = (0, 1, 2)  # default: meaningless; subclasses narrow it

    def __init__(self, D, W, dep, reg_target=None, A=None, status=0):
        super().__init__()
        # D / A / W may be torch_pruning Nodes (presumably from the detector); unwrap to modules.
        self.D = D.module if isinstance(D, Node) else D
        self.A = A[0].module if isinstance(A, list) and A and isinstance(A[0], Node) else A
        self.W = W[0].module if isinstance(W, list) and W and isinstance(W[0], Node) else W
        ## DEPRECATED:250109 -- pre_modules / post_modules were removed.
        self.dep = dep
        self.reg_target = reg_target
        self.status = status  # 0: train1, train3 / 1: opt1 / 2: opt2
        self.active = True
        return None

    @staticmethod
    def gather_submodules(model: torch.nn.Module, items):
        '''Resolve a name, list/tuple of names, or dict of names to submodules of ``model``.'''
        if isinstance(items, str):
            return model.get_submodule(items)
        if isinstance(items, (tuple, list)):
            return [model.get_submodule(x) for x in items]
        if isinstance(items, dict):
            return {k: model.get_submodule(v) for k, v in items.items()}

    @classmethod
    def from_blueprint(cls, model: torch.nn.Module, blueprint_cell: BlueprintCell):
        '''
        Build a unit from a BlueprintCell (use this for multi-layer groups).

        Module names are resolved against ``model``. Each reg_target pair
        ``(io, name)`` becomes ``[io, name, model.get_parameter(name)]``.
        The originating cell is kept on ``ret.blueprint``.
        '''
        D = cls.gather_submodules(model, blueprint_cell.D)
        W = cls.gather_submodules(model, blueprint_cell.W)
        A = cls.gather_submodules(model, blueprint_cell.A) if blueprint_cell.A is not None else None

        reg_target = [[io, name, model.get_parameter(name)] for (io, name) in blueprint_cell.reg_target] if blueprint_cell.reg_target is not None else None
        dep = blueprint_cell.dep  # kept as names
        ret = cls(D=D, W=W, A=A, dep=dep, reg_target=reg_target, status=0)
        ret.blueprint = blueprint_cell
        return ret

    def re_init(self, model: torch.nn.Module):
        '''Re-fetch every reg_target parameter from ``model`` by its stored name.'''
        self.reg_target = [
            (prune_io, param_name, model.get_parameter(param_name))
            for prune_io, param_name, _old_param in self.reg_target
        ]

    # Plain methods, not @property: torch.nn.Module.__getattr__ has an issue with
    # properties (https://github.com/pytorch/pytorch/issues/49726).
    def bypass(self):
        '''True when the unit is active and its status needs the ADW loss.'''
        return self.status in self.bypassing_status and self.active

    def num_channels(self):
        '''Number of channels of the group (length of ``D.delta``; remaining channels if pruned).'''
        return self.D.delta.shape[0]

    def numel_target(self):
        '''
        Number of regularized parameters per channel (N in the paper).

        The parameter tensor is the last element of each reg_target entry.
        ``entry[-1]`` therefore works for both the ``(prune_io, param)`` and
        ``(prune_io, param_name, param)`` layouts (``entry[1]`` is a str in
        the latter).
        '''
        total_cnt = sum(entry[-1].numel() for entry in self.reg_target)
        num_channels = self.num_channels()
        assert total_cnt % num_channels == 0, f'number of parameters should be divided by num_channels: ({total_cnt}/{num_channels})'

        return total_cnt // num_channels

    def set_hparams(self, configs: BypassConfig, global_steps=0):
        '''
        Set gamma schedule, epsilon, pruning epsilon and the |ADW| norm from ``configs``.

        The gamma spec is divided by the number of reg_target entries. In opt2
        (status 2) the ``opt2_*`` variants are used.
        '''
        if self.status != 2:
            gamma_spec, epsilon_spec = configs.gamma, configs.epsilon
        else:
            gamma_spec, epsilon_spec = configs.opt2_gamma, configs.opt2_epsilon
        self.gamma = GammaSchedule(gamma_spec + f'/{len(self.reg_target)}')
        self.gamma.global_step = global_steps
        self.epsilon = self._process_str_hparam(epsilon_spec)
        self.pruning_epsilon = self._process_str_hparam(configs.prune_epsilon)

        # Norm function for |ADW|, reduced over all dimensions.
        norm_func = NORM_FUNCTIONS[configs.adw_norm_type]
        self.norm_function = lambda x: norm_func(x, list(range(x.dim())))

    @staticmethod
    def _process_str_hparam(item: str):
        '''
        Convert a hyper-parameter to a number.

        Non-str values go through float(). Strings are evaluated as Python
        expressions (e.g. "1e-3"); a string that names an undefined symbol is
        returned unchanged.
        '''
        if not isinstance(item, str):
            return float(item)
        # SECURITY: eval() runs arbitrary code -- only feed trusted config values.
        try:
            return eval(item)
        except NameError:
            return item

    def ADW_loss(self):
        return NotImplemented  # defined by subclasses

    def to_blueprint(self, module2name):
        '''
        Describe this unit by names.

        Returns the stored ``blueprint`` if the unit was built from one.
        Otherwise returns a dict with A/D/W names, ``(prune_io, param_name)``
        pairs and ``dep``.
        '''
        if hasattr(self, 'blueprint'):
            return self.blueprint
        tmp = {}
        for k in 'ADW':
            v = getattr(self, k)
            if v is None:
                tmp[k] = None
            elif isinstance(v, (tuple, list, torch.nn.ModuleList)):
                tmp[k] = [module2name[x] for x in v]
            elif isinstance(v, dict):
                tmp[k] = {kk: module2name[vv] for kk, vv in v.items()}
            else:
                tmp[k] = module2name[v]

        tmp['reg_target'] = [(prune_io, weight_name) for prune_io, weight_name, weight in self.reg_target]
        tmp['dep'] = self.dep
        return tmp

    def ref_tensor(self, norm_p=1):
        return NotImplemented

class BaseAddon(torch.nn.Module):  # plays the role of the base mixin
    '''
    Base mixin that attaches bypass units to a model.

    model_factoryV2 builds ``type(name, (mixin, base), {})``, so these methods
    run on the model instance itself. Subclasses choose the unit class, the
    activation class and the reg_target type handed to the detector.
    '''
    default_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    unit_cls: BaseBypassUnit = BaseBypassUnit
    extract_fn_name: str = 'extract_single_sequential'
    bypass_activation_cls = ActivationForBypass
    # Class-level fallback only. Units are stored in a per-instance dict (see
    # _instance_bypass_units). Writing into this shared dict made models built
    # in the same process overwrite each other's units.
    bypass_units = {}
    reg_target_type = 'single'

    def _instance_bypass_units(self):
        '''Return this model's own unit dict, creating it on first use.'''
        # A plain dict assigned on an nn.Module lands in the instance __dict__,
        # shadowing the class attribute above.
        if 'bypass_units' not in self.__dict__:
            self.bypass_units = {}
        return self.bypass_units

    def build_from_blueprint(self, blueprint: Blueprint):
        '''Create one unit per blueprint cell via ``unit_cls.from_blueprint``.'''
        units = self._instance_bypass_units()
        for k, cell in blueprint.bypass_layers.items():
            units[k] = self.unit_cls.from_blueprint(self, cell)
        return None

    def build_from_detector(self, input_shape, detector=None):
        '''Detect bypass cells on this model (default detector if None) and create one unit per cell.'''
        if detector is None:
            detector = TPDetectorForBypass(reg_target_type=self.reg_target_type)
        detector.detect(self, input_shape)
        units = self._instance_bypass_units()
        for idx, detected_cell in detector.detected_bypass_cells.items():
            units[idx] = self.unit_cls(**detected_cell)

    def set_gamma_steps(self, gamma_steps: dict):
        '''Set ``gamma.global_step`` of each unit from a {unit index: step} mapping.'''
        for idx, gamma_step in gamma_steps.items():
            self.bypass_units[idx].gamma.global_step = gamma_step

    def to_blueprint(self):
        '''Return a Blueprint describing all units (requires ``self.module2name``).'''
        ret = {}
        for ind, unit in self.bypass_units.items():
            if hasattr(unit, 'blueprint'):
                ret[ind] = unit.blueprint
                continue
            ret[ind] = unit.to_blueprint(self.module2name)
        return Blueprint(ret)

    def set_norm_function(self, norm_type):
        '''Use ``NORM_FUNCTIONS[norm_type]``, reduced over dims 1.., for the model and every unit.'''
        norm_func = NORM_FUNCTIONS[norm_type]
        self.norm_function = lambda x: norm_func(x, list(range(1, x.dim())))
        for unit in self.bypass_units.values():
            unit.norm_function = self.norm_function

class BypassUnit(BaseBypassUnit):
    '''Bypass unit whose ADW loss is active only in opt2 (status 2).'''
    bypassing_status = (2,)  # only opt2 is bypassing

    def ADW_loss(self):
        '''
        Norm of the A * D * W product.

        Computes ``D.delta * W.weight.T``. When A has a weight, that weight is
        multiplied in. A 1-D D*W vector is broadcast over A's dim 1, the input
        channel for conv/linear weights.
        '''
        W_mat = self.W.weight
        D_vec = self.D.delta
        DW_mat = D_vec * W_mat.T
        if hasattr(self.A, 'weight'):
            A_mat = self.A.weight
            if DW_mat.dim() == 1:
                # Build index (None, :, None, ...) -> shape (1, C, 1, ...).
                # The previous index used the integer 1 here, which selected the
                # single element DW[1] instead of keeping the channel axis.
                broadcast = [None] * A_mat.dim()
                broadcast[1] = slice(None)
                DW_mat = A_mat * DW_mat[tuple(broadcast)]
            else:
                DW_mat = A_mat * DW_mat
        return self.norm_function(DW_mat)

    def ref_tensor(self, norm_p=1):
        return NotImplemented
class BypassAddon(BaseAddon):
    '''Addon that builds BypassUnit units and uses ActivationForBypass activations.'''
    bypass_activation_cls = ActivationForBypass
    unit_cls = BypassUnit

class CatalystPruningUnit_W(BaseBypassUnit):
    '''Unit whose ADW loss applies D directly to each reg_target parameter.'''
    bypassing_status = (1, 2)  # opt1 and opt2 are bypassing

    def ADW_loss(self):
        '''
        Sum of per-parameter losses over reg_target.

        Entries with ``prune_io == 0`` scale dim 0 by D; other entries scale
        dim 1. The tensor is the last element of each entry, so ``entry[-1]``
        works for both the ``(prune_io, param)`` and
        ``(prune_io, param_name, param)`` layouts (``entry[1]`` is the str
        name in the latter).
        '''
        return sum(
            self.ADW_loss_individual_out(entry[-1]) if entry[0] == 0
            else self.ADW_loss_individual_in(entry[-1])
            for entry in self.reg_target
        )

    def ADW_loss_individual_out(self, w):
        '''Norm of ``w`` with dim 0 scaled by ``D.delta``.'''
        delta = self.D.delta
        DW_mat = torch.einsum('c...,c->c...', w, delta)
        return self.norm_function(DW_mat)

    def ADW_loss_individual_in(self, a):
        '''Norm of ``a`` with dim 1 scaled by ``D.delta``; that dim is moved to the front.'''
        delta = self.D.delta
        AD_mat = torch.einsum('bc...,c->cb...', a, delta)
        return self.norm_function(AD_mat)

    def ref_tensor(self, norm_p=2):
        '''
        Per-channel reference tensor for d_init.

        Returns the square root of the sum, over reg_target, of squared entries
        reduced over every dim except ``prune_io``. Note: ``norm_p`` is
        currently ignored (always the L2 form).
        '''
        tmp = []
        for prune_io, param_name, weight in self.reg_target:
            if weight.dim() == 1:
                assert prune_io == 0
                tmp.append(weight ** 2)
            else:
                tmp.append(torch.sum(weight ** 2, dim=[x for x in range(weight.dim()) if x != prune_io]))
        return torch.sqrt(sum(tmp))
class CatalystPruningUnit_W_Attn(BaseBypassUnit):
    '''Attention variant of the catalyst unit; bypassing only in opt1 (status 1).'''
    bypassing_status = (1,)  # opt1 only

    def ref_tensor(self, norm_p=1):
        # No reference tensor is defined for this unit type.
        return NotImplemented

class CatalystPruningAddon(BaseAddon):
    '''Addon using CatalystPruningUnit_W units, ActivationForDx2 and single reg_targets.'''
    reg_target_type = 'single'
    bypass_activation_cls = ActivationForDx2
    unit_cls = CatalystPruningUnit_W
class CatalystMultiPruningAddon(BaseAddon):
    '''Addon using CatalystPruningUnit_W units, ActivationForDx2 and multi reg_targets.'''
    reg_target_type = 'multi'
    bypass_activation_cls = ActivationForDx2
    unit_cls = CatalystPruningUnit_W

# Registry of addon mixins selectable by the ``bypass_type`` argument of model_factoryV2.
AVAILABLE_ADDONS = dict(
    base=BaseAddon,
    bypass=BypassAddon,
    mild_pruning_W=CatalystPruningAddon,
    mild_pruning_WD=CatalystMultiPruningAddon,
)
def model_factoryV2(base,bypass_type,input_shape,verbose=False,norm_type='l2mean',blueprint:Union[None,Blueprint]=None,*model_args,**model_kwargs):
    '''
    Instantiate ``base`` combined with the addon registered under ``bypass_type``.

    Args:
        base: model class; instantiated with ``*model_args, **model_kwargs``.
        bypass_type: key of AVAILABLE_ADDONS.
        input_shape: per-sample input shape (no batch dim). Used for the dummy
            forward or for detection.
        verbose: if True, log each BatchNorm2d -> ActivationForDx2 replacement.
        norm_type: key of NORM_FUNCTIONS passed to ``set_norm_function``.
        blueprint: when given, units are built from it; otherwise they are detected.

    Returns:
        The model in eval mode, with ``shape_info``, bypass units and
        ``module2name`` set.

    Raises:
        KeyError: if ``bypass_type`` is not registered.
    '''
    try:
        mixin = AVAILABLE_ADDONS[bypass_type]
    except KeyError:
        raise KeyError(f'unknown bypass_type {bypass_type!r}; expected one of {sorted(AVAILABLE_ADDONS)}') from None
    cls_name = f'{base.__name__}_{bypass_type}'
    model: BypassAddon = type(cls_name, (mixin, base), {})(*model_args, **model_kwargs)
    model.eval()
    model.shape_info = ShapeWatcher()

    # Snapshot the module list: the loop replaces some children in place.
    for mod_name, module in list(model.named_modules()):
        if bypass_type == 'mild_pruning_WD' and 'downsample' in mod_name and isinstance(module, torch.nn.BatchNorm2d):
            parent_module_name, _, node_name = mod_name.rpartition('.')
            parent_module: torch.nn.Module = model.get_submodule(parent_module_name)
            if isinstance(parent_module, ActivationForBypass):
                continue  # already wrapped
            # Assumes the parent is indexable with the preceding layer at index 0
            # (e.g. Sequential(conv, bn)) -- TODO confirm for every supported base.
            new_activ = ActivationForDx2(parent_module[0].weight.shape[0], module)
            parent_module.register_module(node_name, new_activ)
            if verbose:
                logging.info('replaced %s with %s', mod_name, type(new_activ).__name__)
        if isinstance(module, torch.nn.Flatten):
            module.register_forward_hook(model.shape_info)

    if blueprint is not None:
        model.build_from_blueprint(blueprint)
        # One dummy forward so forward hooks (e.g. shape_info) see real shapes.
        with torch.no_grad():
            dummy = torch.rand([1, *input_shape])
            model.forward(dummy)
    else:
        # The detector returns module instances directly; going through a
        # blueprint would convert them to names and back, which is wasteful.
        model.build_from_detector(input_shape)
    model.set_norm_function(norm_type)
    model.module2name = {mod: name for name, mod in model.named_modules()}
    return model

class ShapeWatcher:
    '''
    Forward-hook callable recording, per module, the shape of its first input
    and of its output as ``[input_shape, output_shape]`` in ``self.data``.
    '''

    def __init__(self):
        self.data = {}

    def __call__(self, module, inp, output, name=None):
        first_input = inp[0]
        self.data[module] = [first_input.shape, output.shape]

if __name__ == '__main__':
    # Build DeiT-Base with the 'mild_pruning_WD' addon (units found by the
    # detector) and export its blueprint as a JSON preset named after the class.
    from bypass.core.models import (
        cifar10_BypassBNresnet56,
        cifar10_NaiveBypassBNresnet56,
        cifar10_Bypassresnet56,
        cifar100_BypassVGG19,
        imagenet_BypassResnet50,
        imagenet_DeiTBase,
    )
    model = model_factoryV2(
        imagenet_DeiTBase,
        'mild_pruning_WD',
        [3, 224, 224],
        blueprint=None,
        verbose=False,
        norm_type='l21',
    ).cpu()

    preset_path = (BLUEPRINT_PRESET_ROOT / str(model.__class__.__name__)).with_suffix('.json')
    model.to_blueprint().save(preset_path)
    print(1)