
import os
from pathlib import Path
from typing import List, Dict
import random
import warnings

import torch

import bypass.core.models

from bypass.core.prune_depgraph import BypassActivationPruner
from bypass.core.activation import ActivationForDx2, ActivationForBypass
from bypass.utils import CROSS_CHANNEL_LAYERS
#from bypass.configs import BypassConfig

import torch_pruning as tp
import torch_pruning.ops as ops
from torch_pruning.dependency import Node, Group

from timm.models.vision_transformer import Attention
from bypass.core.models.attn_dev.pruned_attn import PrunedAttention_ind

# from bypass.core.v2.dependency import Node,Group,DependencyGraph

# Module classes that are handled individually instead of through the groups
# returned by the dependency graph. Each class maps to the attribute name of
# the sub-layer that is inspected on matching modules.
DG_NON_DETECTABLE = {
    Attention: 'qkv',
    PrunedAttention_ind: 'qkv',
}



def is_leaf_module(module):
    """Return True when ``module.children()`` yields nothing."""
    # Stop at the first child instead of materializing the whole list.
    return not any(True for _ in module.children())

def detect_attn_cell(layer_name:str,layer:Attention):
    """Describe an attention layer as a bypass cell.

    Args:
        layer_name: qualified name of ``layer``; used as the key of the
            ``'dep'`` entry and as the prefix of the reg-target names.
        layer: attention module whose ``qkv`` sub-layer exposes ``weight``,
            ``activation`` and, optionally, ``bias``.

    Returns:
        A dict with the keys ``'D'`` (``layer.qkv``), ``'W'``
        (``layer.qkv.activation``), ``'dep'`` and ``'reg_target'``. Each
        reg-target entry is ``[prune_axis, parameter_name, parameter]``.
    """
    qkv = layer.qkv
    reg_target = [[0, layer_name + '.qkv.weight', qkv.weight]]
    # A layer built without a bias can still carry a ``bias`` attribute set to
    # None, so test the value rather than the attribute's existence.
    if getattr(qkv, 'bias', None) is not None:
        reg_target.append([0, layer_name + '.qkv.bias', qkv.bias])
    return {
        'D': qkv,
        'W': qkv.activation,
        # prune_io 0:out, 1:in, 2:mid, 3:head(notimplemented)
        'dep': {layer_name: {'inputs': [], 'outputs': [], 'prune_io': 2}},
        'reg_target': reg_target,
    }
# Maps a module class to the function that builds its bypass-cell dict; each
# function is called as ``fn(layer_name, layer)``.
INDIVIDUAL_DETECT_FUNCTIONS = {
    Attention: detect_attn_cell,
}

class TPDetectorForBypass:
    """Locate bypass cells in a model with a torch_pruning dependency graph.

    ``detect`` builds the dependency graph, keeps every pruning group that
    contains exactly one ``ActivationForBypass`` node (the "D" node) and
    summarizes it as a cell dict with the keys ``'D'``, ``'W'``, ``'dep'`` and
    ``'reg_target'``. Modules whose class is listed in ``DG_NON_DETECTABLE``
    are passed to the graph as ignored layers and are described by the
    matching function in ``INDIVIDUAL_DETECT_FUNCTIONS`` instead.

    Attributes:
        DG: the ``tp.DependencyGraph`` used for detection.
        reg_target_type: suffix selecting the ``find_reg_target_<type>``
            method used to collect regularization targets.
        detected_groups: index -> detected group, or the module itself for
            individually handled layers.
        detected_bypass_cells: index -> cell dict for the same index.
        module2name: per-instance dict; not written by this class.
    """

    # Annotations only: the dicts are created per instance so that two
    # detectors never share (and overwrite) each other's results.
    detected_groups: Dict[int, object]
    detected_bypass_cells: Dict[int, dict]
    module2name: dict

    def __init__(self,reg_target_type = 'multi') -> None:
        self.DG = tp.DependencyGraph()
        self.reg_target_type = reg_target_type
        self.module2name = {}
        self._reset_detection_state()

    def _reset_detection_state(self):
        """Start from empty result dicts owned by this instance."""
        self.detected_groups = {}
        self.detected_bypass_cells = {}

    @staticmethod
    def _node_name(node):
        """Return ``node._name`` when set, otherwise ``node.name``."""
        return node.name if node._name is None else node._name

    @staticmethod
    def _prune_axis(dep):
        """Return 1 when the dependency handler prunes input channels, else 0."""
        return 1 if dep.handler.__name__ == 'prune_in_channels' else 0

    @staticmethod
    def _collect_non_detectable(model):
        """List ``(name, module, head_attr_key)`` for DG_NON_DETECTABLE modules.

        A module matching several listed classes (e.g. through inheritance) is
        reported once, for the first class it matches.
        """
        collected = []
        seen = set()
        for non_detect_cls, head_attr_key in DG_NON_DETECTABLE.items():
            for name, mod in model.named_modules():
                if isinstance(mod, non_detect_cls) and id(mod) not in seen:
                    seen.add(id(mod))
                    collected.append((name, mod, head_attr_key))
        return collected

    @staticmethod
    def _individual_detector(mod):
        """Find the detect function registered for ``mod``'s class or a base.

        Raises:
            KeyError: if no class in the MRO of ``type(mod)`` is registered in
                ``INDIVIDUAL_DETECT_FUNCTIONS``.
        """
        for cls in type(mod).__mro__:
            detect_fn = INDIVIDUAL_DETECT_FUNCTIONS.get(cls)
            if detect_fn is not None:
                return detect_fn
        raise KeyError(
            f'No individual detect function registered for {type(mod).__name__}')

    def detect(self,model:torch.nn.Module,input_shape=(3,32,32), device='cpu'):
        """Build the dependency graph of ``model`` and record its bypass cells.

        Results of a previous call on this instance are discarded. Groups are
        indexed by their position in the list returned by the graph; groups
        without a D node are skipped, so indices may have gaps. Individually
        handled layers are indexed after all groups.

        Args:
            model: module to analyze.
            input_shape: shape of one example input without the batch axis.
            device: device on which the zero example input is placed.

        Raises:
            ValueError: propagated from ``find_D`` / ``find_W``.
            KeyError: if an individually handled layer has no detect function.
        """
        self._reset_detection_state()
        self.DG.build_dependency(
            model,
            ignored_params=[],
            customized_pruners={ActivationForDx2: BypassActivationPruner()},
            example_inputs=torch.zeros(1, *input_shape).to(device=device),
        )

        ignored_layers = self._collect_non_detectable(model)
        ignored_for_graph = []
        for _, mod, head_attr_key in ignored_layers:
            head = getattr(mod, head_attr_key)
            # When the head layer is itself a bypass activation, ignore the
            # layer it wraps; otherwise ignore the whole module.
            ignored_for_graph.append(
                head.activation if isinstance(head, ActivationForBypass) else mod)
        detected_groups_all = list(
            self.DG.get_all_groups(ignored_layers=ignored_for_graph))

        for idx, group in enumerate(detected_groups_all):
            D = self.find_D(group)
            if not D:
                continue
            self.detected_groups[idx] = group
            self.detected_bypass_cells[idx] = self.detect_bypass_cell(group, D)

        first_individual_idx = len(detected_groups_all)
        for idx, (name, mod, _) in enumerate(ignored_layers, start=first_individual_idx):
            self.detected_groups[idx] = mod
            self.detected_bypass_cells[idx] = self._individual_detector(mod)(name, mod)

    def detect_bypass_cell(self,group:Group,D:Node):
        """Summarize ``group`` around its D node as a cell dict.

        Returns:
            ``{'D': D, 'W': find_W(D), 'dep': summarize_group(group),
            'reg_target': find_reg_target_<reg_target_type>(group, D)}``.
        """
        find_reg_target = getattr(self, f'find_reg_target_{self.reg_target_type}')
        return {
            'D': D,
            'W': self.find_W(D),
            'dep': self.summarize_group(group),
            'reg_target': find_reg_target(group, D),
        }

    def find_post_modules(self,group:Group,D:Node):
        """Return every path from ``D`` to the first cross-channel layer.

        Each path is a list of nodes that starts at ``D`` and ends with the
        first node whose module is one of ``CROSS_CHANNEL_LAYERS``, following
        ``node.outputs``. ``group`` is not used.
        """
        def path_search(node:Node, path=()):
            if isinstance(node.module, CROSS_CHANNEL_LAYERS):
                return [[*path, node]]
            found = []
            for next_node in node.outputs:
                found.extend(path_search(next_node, (*path, node)))
            return found

        return path_search(D)

    def summarize_group(self,group:Group):
        """Map each target node name in ``group`` to its I/O summary.

        Returns:
            ``{node_name: {'inputs': [...], 'outputs': [...], 'prune_io': axis}}``
            where ``axis`` is 1 for ``prune_in_channels`` handlers and 0
            otherwise. Only the first occurrence of a node name is kept.
        """
        ret = {}
        for edge in group._group:
            node = edge.dep.target
            prune_axis = self._prune_axis(edge.dep)
            node_name = self._node_name(node)
            if node_name in ret:
                continue
            ret[node_name] = {
                'inputs': [self._node_name(x) for x in node.inputs],
                'outputs': [self._node_name(x) for x in node.outputs],
                'prune_io': prune_axis,
            }
        return ret

    def find_reg_target_multi(self,group:Group,D:Node):
        """Collect reg targets from every node connected to the group.

        Returns:
            A list of ``[prune_axis, qualified_parameter_name, parameter]``.
        """
        ret = []
        for edge in group._group:
            module:torch.nn.Module = edge.dep.target.module
            prune_axis = self._prune_axis(edge.dep)
            module_name = self._node_name(edge.dep.target)
            if isinstance(module, CROSS_CHANNEL_LAYERS) and prune_axis == 1:
                ret.append([prune_axis, '.'.join([module_name, 'weight']), module.weight])
                continue

            for w_name, weight in module.named_parameters():
                # Skip parameters that are never reg targets before the shape
                # check, matching find_reg_target_single; a skipped parameter
                # must not trip the assertion below.
                if w_name == 'bias':
                    continue
                if isinstance(module, ActivationForBypass) and w_name.startswith('delta'):
                    continue
                assert weight.shape[prune_axis] == D.module.num_parameters
                ret.append([prune_axis, '.'.join([module_name, w_name]), weight])
        return ret

    def find_reg_target_single(self,group:Group,D:Node):
        """Collect reg targets from the single W node feeding ``D``.

        ``group`` is not used. Returns a list of
        ``[0, qualified_parameter_name, parameter]`` for every non-bias
        parameter of the W module.
        """
        ret = []
        W = self.find_W(D)[0]
        W_name = self._node_name(W)
        for w_name, weight in W.module.named_parameters():
            if w_name == 'bias':
                continue
            assert weight.shape[0] == D.module.num_parameters
            ret.append([0, '.'.join([W_name, w_name]), weight])
        return ret

    def find_D(self,group:Group):
        """Return the single ``ActivationForBypass`` node of ``group``.

        Returns:
            The D node, or ``False`` (after a warning) when the group has none.

        Raises:
            ValueError: if the group contains more than one D node.
        """
        source_all = [x[0].source for x in group._group]
        target_all = [x[0].target for x in group._group]
        nodes_all = set([*source_all, *target_all])
        D_cands = [node for node in nodes_all if isinstance(node.module, ActivationForBypass)]

        if len(D_cands) == 0:
            warnings.warn(f'No D module found in group {group}')
            return False
        if len(D_cands) > 1:
            raise ValueError(f'Multiple D modules found in group {group}')
        return D_cands[0]

    def find_W(self,D_node:Node) -> List[Node]:
        """Return the W node connected to ``D_node`` as a one-element list.

        The search goes from the inputs of ``D_node`` whose ``grad_fn`` class
        name starts with ``'MulBackward'``, to their inputs starting with
        ``'PermuteBackward'``, to those nodes' inputs whose module has a
        ``weight`` attribute.

        Raises:
            AssertionError: if ``D_node`` does not wrap an
                ``ActivationForBypass`` or does not have exactly two inputs.
            ValueError: if zero or more than one W node is found.
        """
        assert isinstance(D_node.module,ActivationForBypass), f'D_node must be ActivationForBypass. Current input: {D_node}'
        D_inputs = D_node.inputs
        assert len(D_inputs) == 2, 'D_node must have 2 inputs'
        Mul_node_cands = [
            node for node in D_inputs
            if node.grad_fn.__class__.__name__.startswith('MulBackward')
        ]
        Permute_node_cands = [
            node
            for mul_node in Mul_node_cands
            for node in mul_node.inputs
            if node.grad_fn.__class__.__name__.startswith('PermuteBackward')
        ]
        W_node_cands = [
            node
            for permute_node in Permute_node_cands
            for node in permute_node.inputs
            if hasattr(node.module, 'weight')
        ]
        # DECISION 250103: only the case where the W layer sits directly in
        # front of the D layer is considered. Inputs merged through an add
        # (or similar) are left for a later implementation.
        if len(W_node_cands) > 1:
            raise ValueError(f'Detected multiple W_nodes: {W_node_cands}')
        if len(W_node_cands) == 0:
            raise ValueError(f'No W_node found for D_node {D_node}')
        return W_node_cands

    def register_module_to_DG(self,module):
        """Register custom layers of ``module`` with the dependency graph.

        Every distinct ``ActivationForDx2`` class found gets a
        ``BypassActivationPruner``. Every class returned by
        ``detect_unregistered`` is registered with a ``None`` pruner, assuming
        those leaf modules are element-wise operations, and a warning is
        emitted for each.
        """
        visited = set()
        for _, child in module.named_modules():
            child_cls = child.__class__
            if isinstance(child, ActivationForDx2) and child_cls not in visited:
                self.DG.register_customized_layer(child_cls, BypassActivationPruner())
                visited.add(child_cls)

        for mod in self.detect_unregistered(module):
            self.DG.register_customized_layer(mod, None)
            warnings.warn(f'{mod} is not registered in DependencyGraph. Registering as a leaf module.')

    def detect_unregistered(self,model:torch.nn.Module):
        """Return the classes of leaf modules typed as ELEMENTWISE.

        A leaf module is one without children; its class is included when
        ``ops.module2type`` maps the module to ``ops.OPTYPE.ELEMENTWISE``.
        """
        return {
            module.__class__
            for _, module in model.named_modules()
            if is_leaf_module(module) and ops.module2type(module) == ops.OPTYPE.ELEMENTWISE
        }

    def make_cell(self,D_module:ActivationForBypass):
        """Wrap ``D_module`` in a ``Node`` and return it."""
        return Node(D_module)

    def __getitem__(self, idx):
        """Return the detected group (or module) stored under ``idx``."""
        return self.detected_groups[idx]


if __name__ == '__main__':
    # Manual smoke test: build a preset model, wrap its bypass activations in
    # ActivationForDx2, run the detector and poke at one detected group.

    # BypassConfig is not imported at module level in this file, so bring it
    # into scope here; without it the first statement raises NameError.
    from bypass.configs import BypassConfig

    configs = BypassConfig.from_preset('cifar100_BypassVGG19',bypass_type = 'mild_pruning_WD',gamma = '1e-4*t',optimizer_args={'weight_decay': 0, 'momentum': 0},delta_init= 'ConstantMultipleNorm(1.2,norm_p=1)',adw_norm_type = 'l21mean')
    model_name = configs.model_name

    model = getattr(bypass.core.models,model_name)()

    # Iterate over a snapshot because submodules are replaced inside the loop.
    for name, module in list(model.named_modules()):
        # 'a.b.c' -> parent 'a.b', attribute 'c'; a top-level name has parent ''.
        parent_module_name, _, node_name = name.rpartition('.')

        if isinstance(module,ActivationForBypass) and not isinstance(module,ActivationForDx2):
            parent_module:torch.nn.Module = model.get_submodule(parent_module_name)
            new_activ = ActivationForDx2(module.num_parameters,module.activation)
            parent_module.register_module(node_name,new_activ)

        if 'downsample' in name and isinstance(module,torch.nn.BatchNorm2d):
            parent_module:torch.nn.Module = model.get_submodule(parent_module_name)
            # The width is read from the weight of the first layer in the
            # parent container.
            new_activ = ActivationForDx2(parent_module[0].weight.shape[0],module)
            parent_module.register_module(node_name,new_activ)

    detector = TPDetectorForBypass()
    detector.detect(model,configs.input_shape)
    # Debug probe: take the source node of the second-to-last entry of group 1
    # and refresh the graph's reshape index mapping for it.
    flatten_node = detector.detected_groups[1]._group[-2][0].source
    detector.DG._update_reshape_index_mapping(flatten_node)
    print(1)