import os
from pathlib import Path
from typing import List, Dict
import random
import warnings

import torch
import torch.nn as nn
import typing

import sys
sys.path.append('/workspace/jaeheun_MildPruning')

import bypass.core.models

from bypass.core.prune_depgraph import BypassActivationPruner
from bypass.core.activation import TrivialActivationForBypass,TrivialActivationForDx2,ActivationForDx2,ActivationForBypass
from bypass.utils import CROSS_CHANNEL_LAYERS
from bypass.configs import BypassConfig

import torch_pruning as tp
import torch_pruning.ops as ops
from torch_pruning.pruner import function
from torch_pruning.dependency import Node, Group


class GetImportance:
    """Collect per-group weights, activation deltas and their gradients, and
    turn them into per-channel importance scores.

    Typical use: call :meth:`detect` once per batch (after ``backward()`` so
    that ``.grad`` is populated) to snapshot parameters and accumulate
    gradients, then call :meth:`get_impscore` to obtain one importance tensor
    per group.  All per-group dictionaries are keyed by ``str(group_index)``.
    """

    def __init__(self) -> None:
        # Snapshots stored on the first detect() call.
        self.detected_W = {}    # output-channel weight slices, one list per group
        self.detected_D = {}    # activation `delta` slices, one list per group
        self.detected_WI = {}   # input-channel weight matrices, one list per group

        # Gradients matching the snapshots above, summed over detect() calls.
        self.detected_WG = {}
        self.detected_DG = {}
        self.detected_WIG = {}

        # root_idxs of every input-channel entry, one list per group.
        self.idx_in = {}

        # Output of get_impscore().
        self.importances = {}

        # Layer types inspected by calculate(); anything else is skipped.
        # NOTE(review): calculate() tests isinstance(layer, ActivationForBypass)
        # while this list names ActivationForDx2 -- presumably the two classes
        # are related by inheritance; confirm in bypass.core.activation.
        self.target_types = [nn.modules.conv._ConvNd, ActivationForDx2, nn.Linear]

        self.group_reduction = 'mean'
        self.normalizer = 'mean'

        self.p = 2
        self.bias = True
        # Number of completed detect() calls.
        self.cnt = 0
        # Index of the last group seen by detect(); -1 until detect() has
        # processed at least one group, so get_impscore() is a no-op before.
        self.idx = -1

    @torch.no_grad()
    def accumul_grad(self, weight, up_cand, idx):
        """Add ``up_cand[j]`` onto ``weight[str(idx)][j]`` for every ``j``.

        Args:
            weight: dict mapping ``str(group index)`` to a list of tensors.
                The list is updated in place.
            up_cand: sequence of tensors to add, aligned with that list.
            idx: group index (converted with ``str`` to form the key).

        Returns:
            The same ``weight`` dict.
        """
        slot = weight[str(idx)]
        for j in range(len(slot)):
            slot[j] = slot[j] + up_cand[j]
        return weight

    @torch.no_grad()
    def detect(self, detected_groups_all, lamda):
        """Snapshot (first call) or accumulate gradients (later calls) per group.

        On the first call the weights, deltas, their gradients and the
        input-channel root indices of every group are stored.  On every later
        call only the three gradient stores are updated, by element-wise
        addition.

        Args:
            detected_groups_all: iterable of pruning groups; each is handed
                to :meth:`calculate`.
            lamda: forwarded to :meth:`calculate`.

        Returns:
            None.
        """
        first_call = self.cnt == 0
        last_idx = None

        for idx, group in enumerate(detected_groups_all):
            (weight, weight_grad,
             weight_in, weight_in_grad,
             delta, delta_grad,
             idxs_in) = self.calculate(group, lamda)
            key = str(idx)

            if first_call:
                self.detected_W[key] = weight
                self.detected_WG[key] = weight_grad

                self.detected_WI[key] = weight_in
                self.detected_WIG[key] = weight_in_grad

                self.detected_D[key] = delta
                self.detected_DG[key] = delta_grad

                self.idx_in[key] = idxs_in
            else:
                self.detected_WG = self.accumul_grad(self.detected_WG, weight_grad, idx)
                self.detected_WIG = self.accumul_grad(self.detected_WIG, weight_in_grad, idx)
                self.detected_DG = self.accumul_grad(self.detected_DG, delta_grad, idx)
            last_idx = idx

        self.cnt += 1
        # Only move the marker when at least one group was processed; an empty
        # iterable previously crashed here on the unbound loop variable.
        if last_idx is not None:
            self.idx = last_idx
        return None

    def get_impscore(self, factor, lamda):
        """Compute and store one importance tensor per detected group.

        For every group ``0 .. self.idx``:

        1. each stored output weight ``W`` and delta ``D`` is shifted by
           ``lamda * factor`` times its accumulated gradient; the per-channel
           score is ``|W - step * dW|`` summed over dim 1, resp.
           ``|D - step * dD|``;
        2. weight scores are divided element-wise by the delta scores and
           averaged over the entries of the group;
        3. input-channel weights are scored the same way as in step 1, reduced
           with :meth:`_reduce` and normalised with :meth:`_normalize`;
        4. both results are combined by a final :meth:`_reduce`.

        Args:
            factor: multiplier applied together with ``lamda`` to gradients.
            lamda: multiplier applied together with ``factor`` to gradients.

        Returns:
            ``self.importances`` -- dict from ``str(group index)`` to tensor.
        """
        step = lamda * factor

        for i in range(self.idx + 1):
            key = str(i)
            weights = self.detected_W[key]
            weight_grads = self.detected_WG[key]
            deltas = self.detected_D[key]
            delta_grads = self.detected_DG[key]

            weight_imps = []
            delta_imps = []
            for j in range(len(weight_grads)):
                shifted_w = (weights[j] - step * weight_grads[j]).abs().sum(1)
                shifted_d = (deltas[j] - step * delta_grads[j]).abs()
                weight_imps.append(self.get_norm_vector(shifted_w, axis=0, p=2))
                delta_imps.append(self.get_norm_vector(shifted_d, axis=0, p=2))

            # With more than two entries, keep the first weight score and
            # reverse the order of the remaining ones before pairing them
            # with the delta scores.
            # NOTE(review): presumably this aligns the order in which conv and
            # activation entries appear inside a group -- confirm.
            if len(weight_imps) > 2:
                ordered_weight_imps = [weight_imps[0]] + weight_imps[:0:-1]
            else:
                ordered_weight_imps = weight_imps

            ratios = [torch.div(w_imp, d_imp)
                      for w_imp, d_imp in zip(ordered_weight_imps, delta_imps)]
            imps = torch.mean(torch.stack(ratios, 0), 0)

            weight_in_imps = [
                self.get_norm_vector((w_in - step * g_in).abs().sum(1), axis=0, p=2)
                for w_in, g_in in zip(self.detected_WI[key], self.detected_WIG[key])
            ]
            weight_in_imps = self._reduce(weight_in_imps, self.idx_in[key])
            weight_in_imps = self._normalize(weight_in_imps, self.normalizer)

            # NOTE(review): _reduce() zips its two arguments, so when the group
            # has only one input-channel entry (idx_in[key][:2] has length 1)
            # the input-side score is not included here -- confirm intent.
            local_imps = self._reduce([imps, weight_in_imps], self.idx_in[key][:2])

            self.importances[key] = local_imps

        return self.importances

    def return_score(self):
        """Return the importance dict filled by the last :meth:`get_impscore`."""
        return self.importances

    @torch.no_grad()
    def calculate(self, group: "Group", lamda):
        """Extract weights, deltas and gradients from one pruning group.

        Only layers whose type is listed in ``self.target_types`` are looked
        at.  Every parameter read here must already have ``.grad`` populated.

        Args:
            group: iterable of ``(dep, idxs)`` pairs that also supports
                ``group[i].root_idxs``.  ``dep.layer`` and ``dep.pruning_fn``
                are read from each dependency.
            lamda: unused; kept so existing callers keep working.

        Returns:
            A 7-tuple of lists::

                (weights, weight_grads,        # output channels, rows = idxs
                 weights_in, weight_in_grads,  # input channels, all rows
                 deltas, delta_grads,          # ActivationForBypass.delta[idxs]
                 root_idxs_in)                 # root_idxs per input entry
        """
        local_weight, local_weight_grad = [], []
        local_weight_in, local_weight_in_grad = [], []
        local_delta, local_delta_grad = [], []
        weight_idxs_in = []

        out_channel_fns = (
            function.prune_conv_out_channels,
            function.prune_linear_out_channels,
        )
        in_channel_fns = (
            function.prune_conv_in_channels,
            function.prune_linear_in_channels,
        )
        target_types = tuple(self.target_types)

        for i, (dep, idxs) in enumerate(group):
            layer = dep.layer
            prune_fn = dep.pruning_fn
            if not isinstance(layer, target_types):
                continue

            # Bypass activation: learnable per-channel `delta`.
            if isinstance(layer, ActivationForBypass):
                local_delta.append(layer.delta[idxs])
                local_delta_grad.append(layer.delta.grad.data[idxs])

            is_transposed = hasattr(layer, "transposed") and layer.transposed

            if prune_fn in out_channel_fns:
                # Conv/Linear output channels: one flattened row per channel.
                if is_transposed:
                    w = layer.weight.data.transpose(1, 0)[idxs].flatten(1)
                    dw = layer.weight.grad.data.transpose(1, 0)[idxs].flatten(1)
                else:
                    w = layer.weight.data[idxs].flatten(1)
                    dw = layer.weight.grad.data[idxs].flatten(1)

                local_weight.append(w)
                local_weight_grad.append(dw)

            elif prune_fn in in_channel_fns:
                # Conv/Linear input channels: one flattened row per channel.
                if is_transposed:
                    w = (layer.weight.data).flatten(1)
                    dw = (layer.weight.grad).flatten(1)
                else:
                    w = (layer.weight.data).transpose(0, 1).flatten(1)
                    dw = (layer.weight.grad).transpose(0, 1).flatten(1)

                # Grouped (non-depthwise) convolutions: tile the rows once per
                # group.  `w` is 2-D here, so repeat() needs one factor per
                # dimension; a single factor raises a RuntimeError.
                if (prune_fn == function.prune_conv_in_channels
                        and layer.groups != layer.in_channels
                        and layer.groups != 1):
                    w = w.repeat(layer.groups, 1)
                    dw = dw.repeat(layer.groups, 1)

                # NOTE(review): unlike the output branch, rows are not
                # selected with `idxs` here -- confirm groups always cover
                # every channel in order.
                local_weight_in.append(w)
                local_weight_in_grad.append(dw)
                weight_idxs_in.append(group[i].root_idxs)

        return (local_weight, local_weight_grad,
                local_weight_in, local_weight_in_grad,
                local_delta, local_delta_grad,
                weight_idxs_in)

    def get_norm_vector(self, tensor: torch.Tensor, axis=0, p=1, mean: bool = False):
        """Return one non-negative score per slice of ``tensor`` along ``axis``.

        A 1-D tensor is returned as its element-wise absolute value (``axis``
        and ``p`` are ignored).  Otherwise ``axis`` is swapped to the front,
        the remaining dimensions are flattened, and the ``p``-norm of each row
        is returned.

        Args:
            tensor: input tensor.
            axis: dimension that indexes the slices.
            p: norm order.
            mean: unused; kept for interface compatibility.
        """
        if tensor.dim() == 1:
            return tensor.abs()
        num_channels = tensor.shape[axis]
        rows = tensor.transpose(0, axis).reshape(num_channels, -1)
        return rows.norm(dim=-1, p=p)

    def _normalize(self, group_importance, normalizer):
        """Divide ``group_importance`` by its mean.

        ``normalizer`` is accepted for interface compatibility but is not
        consulted; mean normalisation is always applied.
        """
        return group_importance / group_importance.mean()

    def _reduce(self, group_imp: typing.List[torch.Tensor], group_idxs: typing.List[typing.List[int]]):
        """Scatter-add per-layer scores onto root positions and average them.

        Args:
            group_imp: list of 1-D score tensors.
            group_idxs: for each tensor, the root index every element maps to.
                Entries containing ``None`` are skipped.  Pairs are formed
                with ``zip``, so extra items on either side are ignored.

        Returns:
            ``group_imp`` itself when it is empty; otherwise a float32 tensor
            shaped like ``group_imp[0]`` holding the sum of the scattered
            scores divided by the number of entries that were not skipped.
        """
        if len(group_imp) == 0:
            return group_imp

        reduced_imp = torch.zeros_like(group_imp[0], dtype=torch.float32)
        n_imp = 0
        for imp, root_idxs in zip(group_imp, group_idxs):
            imp = imp.to(reduced_imp.device, dtype=reduced_imp.dtype)
            if any(r is None for r in root_idxs):
                continue
            index = torch.tensor(root_idxs, device=imp.device)
            reduced_imp.scatter_add_(0, index, imp)  # accumulated importance
            n_imp += 1

        # NOTE(review): if every entry was skipped n_imp is 0 and the division
        # below yields NaN values.
        reduced_imp /= n_imp

        return reduced_imp
