from ast import mod
from functools import partial
from operator import length_hint
import os
import numpy as np
import torch
from torch.nn import functional as F
from torch.cuda.amp import autocast, GradScaler

from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import argparse
import io
from PIL import Image
from layers import get_layers
from layers import LSTMAttentionModel,MLPAttentionModel,PrunableBatchNorm2d,LSTMAttentionModelCat
import sys
sys.path.append(".")
from data import prepare_expansive_data, IMAGENETCLASSES, IMAGENETNORMALIZE,get_tinyimagenet_dataloaders
from algorithms import generate_label_mapping_by_frequency,  label_mapping_base
from tools.misc import  set_seed
from models import ExpansiveVisualPrompt,PadVisualPrompt

from cfg import *
from torch.nn import Conv2d, Linear
import torch.nn as nn
import torchvision.models as models
import wandb
#from lightning.fabric import Fabric
from nfnets import AGC

def replace_layers(model, old_layer, new_layer,channelPrune,GenerateMask,Samplingnet):
    """Recursively replace every submodule whose exact type is ``old_layer``.

    Each match is rebuilt with ``new_layer`` from its geometry (Conv2d or
    Linear), and the original weight / bias tensors are shared into the new
    layer. The exact-type check (not ``isinstance``) keeps already-replaced
    layers that subclass ``old_layer`` from being wrapped again.

    Args:
        model: module whose ``_modules`` are rewritten in place.
        old_layer: class to replace (compared with ``type(...) is``).
        new_layer: factory called as
            ``new_layer(in, out, kernel, stride, padding, dilation, groups, bias, channelPrune, GenerateMask, Samplingnet)``
            for Conv2d, or ``new_layer(in, out, bias, channelPrune, GenerateMask, Samplingnet)``
            for Linear.
        channelPrune, GenerateMask, Samplingnet: forwarded to ``new_layer``.

    Returns:
        ``model`` itself, modified in place.

    Raises:
        TypeError: a module of type ``old_layer`` is neither Conv2d nor Linear.
            Previously such a layer silently received a stale replacement
            built for an earlier layer, or hit UnboundLocalError.
    """
    for name, module in reversed(model._modules.items()):
        if module is None:
            # Optional submodules can be registered as None; nothing to replace.
            continue
        if len(list(module.children())) > 0:
            model._modules[name] = replace_layers(module, old_layer, new_layer,channelPrune,GenerateMask,Samplingnet)
        if type(module) is not old_layer:
            continue
        if isinstance(module, nn.Conv2d):
            layer_new = new_layer(module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding, module.dilation, module.groups, module.bias,channelPrune,GenerateMask,Samplingnet)
        elif isinstance(module, nn.Linear):
            layer_new = new_layer(module.in_features, module.out_features, module.bias,channelPrune,GenerateMask,Samplingnet)
        else:
            raise TypeError(
                f"replace_layers only supports Conv2d/Linear, got {type(module).__name__} at '{name}'"
            )
        # Share (not copy) the pretrained tensors with the replacement layer.
        layer_new.weight.data=module.weight.data
        if module.bias is not None:
            layer_new.bias.data=module.bias.data
        model._modules[name] = layer_new
    return model

def create_conv(m,new_layer,channelPrune,GenerateMask,Samplingnet):
    """Build a ``new_layer`` conv with the same geometry as ``m`` and share its tensors.

    ``new_layer`` receives ``m``'s in/out channels, kernel size, stride,
    padding, dilation, groups and bias object, followed by the three pruning
    options. The returned layer's weight (and bias, when ``m`` has one) point
    at ``m``'s existing data rather than copies.
    """
    geometry = (m.in_channels, m.out_channels, m.kernel_size, m.stride,
                m.padding, m.dilation, m.groups, m.bias)
    replacement = new_layer(*geometry, channelPrune, GenerateMask, Samplingnet)
    replacement.weight.data = m.weight.data
    if m.bias is not None:
        replacement.bias.data = m.bias.data
    return replacement
def infinite_loader(dataloader):
    """Yield batches from ``dataloader`` forever.

    A fresh iterator is requested each time the previous one is exhausted,
    so a shuffling loader reshuffles every pass.
    """
    epoch_iter = iter(dataloader)
    while True:
        try:
            batch = next(epoch_iter)
        except StopIteration:
            epoch_iter = iter(dataloader)
            continue
        yield batch
def getThreLSTM(network,cl,num_times,args=None,thres=None):
    """Recompute a global score threshold after the first layers were fixed.

    The first ``num_times`` trainable ``cl`` layers are treated as already
    decided: their zero entries in ``module.adj`` count as removed. The
    remaining layers' ``popup_scores`` are pooled, and the percentile needed to
    reach an overall pruned fraction of ``1 - args.current_k`` is returned.

    Args:
        network: model whose modules are scanned in ``named_modules()`` order.
        cl: prunable layer class; only instances with ``is_train`` truthy count.
        num_times: number of leading trainable layers whose mask is kept.
        args: must provide ``current_k`` whenever there are remaining layers.
        thres: value returned unchanged when no layers remain.

    Returns:
        The threshold (float), or ``thres`` if no layer lies past ``num_times``.

    Raises:
        Re-raises any error from the threshold computation after printing
        debug info. The original fell through to an unbound ``threshold``
        and masked the real error.
    """
    remaining_scores=[]
    num=0
    n_removed=0
    n_removed_total=0
    n_left_total=0
    for _, module in network.named_modules():
        if not (isinstance(module,cl) and module.is_train):
            continue
        num+=1
        if num<=num_times:
            n_removed+=(module.adj==0).sum().item()
            n_removed_total+=module.adj.numel()
        else:
            # reshape(-1) rather than squeeze(): a single-channel score would
            # otherwise become 0-d and could not be concatenated.
            temp=module.popup_scores.clone().detach().reshape(-1)
            n_left_total+=temp.numel()
            remaining_scores.append(temp)
    if not remaining_scores:
        return thres

    total=torch.concat(remaining_scores)
    s=1-args.current_k
    s_new=None
    try:
        # Fraction of the *remaining* scores to cut so that the overall pruned
        # count equals (all channels) * s, given n_removed already cut.
        s_new=int(((n_removed_total+n_left_total)*s-n_removed))*1.0/total.shape[0]
        if s_new >=1.0:
            s_new=0.99
        if s_new<=0:
            s_new=0
        threshold=percentile(total,q=s_new*100)
    except Exception as e:
        print("S_New",s_new)
        print("args.current_k",args.current_k)
        print("Except occur",e)
        raise
    return threshold

def set_layerwise(network,module,args,n_times,cur_adj,pre_adj,thres):
    """Fall back to a layer-wise prune rate when a layer's mask is too sparse.

    The kept fraction of ``cur_adj`` is compared against a target: ``args.current_k``
    for the first three layers (``n_times <= 3``), else ``args.current_k * args.scale_k``.
    If the kept fraction is below the target or below 0.5:

    * the layer's prune rate is set to 0.5 (early layer below 0.5) or ``args.current_k``;
    * its mask is recomputed non-globally;
    * the global threshold is refreshed via ``getThreLSTM``.

    Returns:
        ``(thres, cur_adj)``, both unchanged when no fallback is needed.
    """
    early_layer = n_times <= 3
    if early_layer:
        target_k = args.current_k
    else:
        target_k = args.current_k * args.scale_k
    density = cur_adj.sum().item() / cur_adj.numel()
    if not (density < target_k or density < 0.5):
        return thres, cur_adj

    fallback_k = 0.5 if (early_layer and density < 0.5) else args.current_k
    module.set_prune_rate(fallback_k)
    new_adj = module.calculate_mask(pre_adj, None, glob=False)
    new_thres = getThreLSTM(network, args.cl, num_times=n_times, args=args, thres=thres)
    return new_thres, new_adj

def Calculate_mask(model,bn_detach=False,thres=None,glob=False,args=None,pp=False):
    """Propagate channel masks through a ResNet whose convs/BNs were made prunable.

    Visits ``model.conv1``/``model.bn1`` and then every torchvision
    ``BasicBlock`` / ``Bottleneck`` yielded by ``model.named_modules()``.
    For each conv it:

    1. sets ``global_thre``;
    2. calls ``calculate_mask`` with the previous conv's mask (``pre_adj``);
    3. lets ``set_layerwise`` possibly recompute the mask with a layer-wise rate;
    4. passes the resulting mask to the following BatchNorm via ``set_mask``.

    Args:
        model: ResNet prepared by ``replace_layers_resnet``.
        bn_detach: when True, tensor masks are ``.detach()``-ed before being
            given to BatchNorm layers (float masks are passed as-is).
        thres: initial global threshold; ``set_layerwise`` may replace it as
            layers are visited.
        glob: forwarded to every ``calculate_mask`` call.
        args: required in practice despite the ``None`` default. This
            function reads ``args.shortcut``, and ``set_layerwise`` reads
            ``args.current_k``, ``args.scale_k`` and ``args.cl``.
        pp: if True, print the threshold after the stem conv.

    Shortcut handling for blocks with a downsample (``args.shortcut``):

    * ``'depend'``: the downsample conv's mask is overwritten with the main
      branch's final mask.
    * ``'intersect'``: ``((cur_adj+pre_adj)>0)`` is applied to both branches.
      For non-negative masks this is the UNION of the two masks, despite the
      option name.
    * any other value (e.g. ``'none'``): the downsample conv keeps its own
      mask and ``downsample[1]`` is not re-masked here.

    Returns:
        None. Results are written onto the modules.
    """
    # Mask of the channels produced by the previous conv. Starts as the float
    # 1.0 before any conv has run; the ``type(...) is not float`` checks below
    # handle both the float and tensor forms.
    pre_adj=1.0
    pre_scores=1.0
    # Number of prunable convs visited so far; set_layerwise() treats the
    # first three (n_times<=3) differently.
    n_times=0
    n_times+=1
    # --- Stem: model.conv1 -> model.bn1 ---
    model.conv1.global_thre=thres
    cur_adj=model.conv1.calculate_mask(pre_adj,pre_scores,glob=glob)

    thres,cur_adj=set_layerwise(model,model.conv1,args,n_times,cur_adj,pre_adj,thres)
    # if cur_adj.sum().item()/cur_adj.numel()<args.first_k:
    #     model.conv1.set_prune_rate(args.first_k)
    #     cur_adj=model.conv1.calculate_mask(pre_adj,pre_scores,glob=False)
    #     thres=getThreLSTM(model,args.cl,skip_first=True,args=args)

    cur_scores=model.conv1.popup_scores

    if bn_detach and type(cur_adj) is not float:
        model.bn1.set_mask(cur_adj.detach())
    else:
        model.bn1.set_mask(cur_adj)
    #model.bn1.set_mask(cur_adj)
    pre_adj=cur_adj
    pre_scores=cur_scores
    if pp:
        print("pp thres",thres)
    for name,module in model.named_modules():
        if isinstance(module, models.resnet.BasicBlock):
            if module.downsample is not None:
                # Save the block-input mask: downsample[0] reads the same
                # input as conv1, not conv1's output.
                # NOTE(review): copy_scores is computed but never used below.
                if type(pre_adj) is float:
                    copy_adj=pre_adj
                    copy_scores=pre_scores
                else:
                    copy_adj=pre_adj.clone()
                    copy_scores=pre_scores.clone()
                n_times+=1
                module.conv1.global_thre=thres
                cur_adj=module.conv1.calculate_mask(pre_adj,pre_scores,glob=glob)
                thres,cur_adj=set_layerwise(model,module.conv1,args,n_times,cur_adj,pre_adj,thres)
                cur_scores=module.conv1.popup_scores

                #print("name",name,thres)
                if bn_detach and type(cur_adj) is not float:
                    module.bn1.set_mask(cur_adj.detach())
                else:
                    module.bn1.set_mask(cur_adj)
                pre_adj=cur_adj
                pre_scores=cur_scores
                n_times+=1
                module.conv2.global_thre=thres
                cur_adj=module.conv2.calculate_mask(pre_adj,pre_scores,glob=glob)
                thres,cur_adj=set_layerwise(model,module.conv2,args,n_times,cur_adj,pre_adj,thres)


                cur_scores=module.conv2.popup_scores
                if bn_detach and type(cur_adj) is not float:
                    module.bn2.set_mask(cur_adj.detach())
                else:
                    module.bn2.set_mask(cur_adj)


                pre_adj=cur_adj
                pre_scores=cur_scores
                # Shortcut conv: mask computed from the saved block-input mask
                # with a constant 1.0 score input.
                module.downsample[0].global_thre=thres
                cur_adj=module.downsample[0].calculate_mask(copy_adj,1.0,glob=glob)
                cur_scores=module.downsample[0].popup_scores
                if args.shortcut=='depend':
                    # Force the shortcut to the main branch's output mask
                    # (presumably so both summands of the residual add share
                    # one channel set).
                    module.downsample[0].adj=pre_adj
                    module.downsample[0].popup_scores=1.0

                    #module.downsample[1].set_mask(pre_adj)
                    if bn_detach and type(pre_adj) is not float:
                        module.downsample[1].set_mask(pre_adj.detach())
                    else:
                        module.downsample[1].set_mask(pre_adj)
                elif args.shortcut=='intersect':
                    # Union (logical OR for non-negative masks) of the shortcut
                    # and main-branch masks, applied to both branches.
                    # NOTE(review): conv2.adj is overwritten, but bn2 keeps the
                    # mask set above.
                    adj=((cur_adj+pre_adj)>0).long()
                    module.conv2.adj=adj
                    if bn_detach and type(adj) is not float:
                        module.downsample[1].set_mask(adj.detach())
                    else:
                        module.downsample[1].set_mask(adj)
                    module.downsample[0].adj=adj
                    #module.downsample[1].set_mask(adj)
                    pre_adj=adj


                ##no need for bn2
            else:
                #copy_adj=pre_adj.clone()
                n_times+=1
                module.conv1.global_thre=thres
                cur_adj=module.conv1.calculate_mask(pre_adj,pre_scores,glob=glob)

                thres,cur_adj=set_layerwise(model,module.conv1,args,n_times,cur_adj,pre_adj,thres)

                cur_scores=module.conv1.popup_scores
                #module.bn1.set_mask(cur_adj)
                if bn_detach and type(cur_adj) is not float:
                    module.bn1.set_mask(cur_adj.detach())
                else:
                    module.bn1.set_mask(cur_adj)
                pre_adj=cur_adj
                pre_scores=cur_scores
                n_times+=1
                module.conv2.global_thre=thres
                cur_adj=module.conv2.calculate_mask(pre_adj,pre_scores,glob=glob)

                thres,cur_adj=set_layerwise(model,module.conv2,args,n_times,cur_adj,pre_adj,thres)


                cur_scores=module.conv2.popup_scores
                #module.bn2.set_mask(cur_adj)
                if bn_detach and type(cur_adj) is not float:
                    module.bn2.set_mask(cur_adj.detach())
                else:
                    module.bn2.set_mask(cur_adj)
                pre_adj=cur_adj
                pre_scores=cur_scores

        elif  isinstance(module, models.resnet.Bottleneck):
            # NOTE(review): unlike the BasicBlock path, calculate_mask is called
            # without pre_scores here, and pre_scores is never updated.
            if module.downsample is not None:
                # Save the block-input mask for the shortcut conv.
                if type(pre_adj) is float:
                    copy_adj=pre_adj
                else:
                    copy_adj=pre_adj.clone()
                n_times+=1
                module.conv1.global_thre=thres
                cur_adj=module.conv1.calculate_mask(pre_adj,glob=glob)
                thres,cur_adj=set_layerwise(model,module.conv1,args,n_times,cur_adj,pre_adj,thres)

                #module.bn1.set_mask(cur_adj)
                if bn_detach and type(cur_adj) is not float:
                    module.bn1.set_mask(cur_adj.detach())
                else:
                    module.bn1.set_mask(cur_adj)
                pre_adj=cur_adj
                n_times+=1
                module.conv2.global_thre=thres
                cur_adj=module.conv2.calculate_mask(pre_adj,glob=glob)
                thres,cur_adj=set_layerwise(model,module.conv2,args,n_times,cur_adj,pre_adj,thres)
                #module.bn2.set_mask(cur_adj)
                if bn_detach and type(cur_adj) is not float:
                    module.bn2.set_mask(cur_adj.detach())
                else:
                    module.bn2.set_mask(cur_adj)

                pre_adj=cur_adj
                n_times+=1
                module.conv3.global_thre=thres
                cur_adj=module.conv3.calculate_mask(pre_adj,glob=glob)
                thres,cur_adj=set_layerwise(model,module.conv3,args,n_times,cur_adj,pre_adj,thres)

                if bn_detach and type(cur_adj) is not float:
                    module.bn3.set_mask(cur_adj.detach())
                else:
                    module.bn3.set_mask(cur_adj)

                pre_adj=cur_adj
                module.downsample[0].global_thre=thres
                cur_adj=module.downsample[0].calculate_mask(copy_adj,glob=glob)
                if args.shortcut=='depend':
                    # Shortcut follows the main branch's output mask.
                    module.downsample[0].adj=pre_adj
                    if bn_detach and type(pre_adj) is not float:
                        module.downsample[1].set_mask(pre_adj.detach())
                    else:
                        module.downsample[1].set_mask(pre_adj)
                elif args.shortcut=='intersect':
                    # Union of both masks (see docstring); bn3 keeps the mask
                    # set above.
                    adj=((cur_adj+pre_adj)>0).long()
                    module.conv3.adj=adj
                    module.downsample[0].adj=adj
                    if bn_detach and type(adj) is not float:
                        module.downsample[1].set_mask(adj.detach())
                    else:
                        module.downsample[1].set_mask(adj)
                    pre_adj=adj
            else:
                #copy_adj=pre_adj.clone()
                n_times+=1
                module.conv1.global_thre=thres
                cur_adj=module.conv1.calculate_mask(pre_adj,glob=glob)
                thres,cur_adj=set_layerwise(model,module.conv1,args,n_times,cur_adj,pre_adj,thres)

                #module.bn1.set_mask(cur_adj)
                if bn_detach and type(cur_adj) is not float:
                    module.bn1.set_mask(cur_adj.detach())
                else:
                    module.bn1.set_mask(cur_adj)
                pre_adj=cur_adj
                n_times+=1
                module.conv2.global_thre=thres
                cur_adj=module.conv2.calculate_mask(pre_adj,glob=glob)
                thres,cur_adj=set_layerwise(model,module.conv2,args,n_times,cur_adj,pre_adj,thres)

                #module.bn2.set_mask(cur_adj)
                if bn_detach and type(cur_adj) is not float:
                    module.bn2.set_mask(cur_adj.detach())
                else:
                    module.bn2.set_mask(cur_adj)
                pre_adj=cur_adj
                n_times+=1
                module.conv3.global_thre=thres
                cur_adj=module.conv3.calculate_mask(pre_adj,glob=glob)
                thres,cur_adj=set_layerwise(model,module.conv3,args,n_times,cur_adj,pre_adj,thres)

                if bn_detach and type(cur_adj) is not float:
                    module.bn3.set_mask(cur_adj.detach())
                else:
                    module.bn3.set_mask(cur_adj)
                #module.bn3.set_mask(cur_adj)
                pre_adj=cur_adj


def replace_layers_resnet(model,new_layer,channelPrune,GenerateMask,Samplingnet,args):
    """Swap a torchvision ResNet's convs/BNs for prunable versions, in place.

    The stem ``conv1``/``bn1`` and every conv/BN pair inside each
    ``BasicBlock`` or ``Bottleneck`` are rebuilt with ``create_conv`` and
    ``PrunableBatchNorm2d``. So are the ``downsample[0]`` conv and the
    ``downsample[1]`` BN when a downsample exists. ``args.shortcut`` controls
    which convs get ``is_train = False``:

    * ``'none'``: the downsample conv and the block's last conv.
    * ``'depend'``: the downsample conv only.

    Returns None; ``model`` is modified in place.
    """
    def _prunable(conv):
        return create_conv(conv,new_layer,channelPrune,GenerateMask,Samplingnet)

    model.conv1=_prunable(model.conv1)
    model.bn1=PrunableBatchNorm2d(model.bn1)

    for _, block in model.named_modules():
        if isinstance(block, models.resnet.BasicBlock):
            layer_pairs=(("conv1","bn1"),("conv2","bn2"))
        elif isinstance(block, models.resnet.Bottleneck):
            layer_pairs=(("conv1","bn1"),("conv2","bn2"),("conv3","bn3"))
        else:
            continue

        # Rebuild conv then BN for each stage, in the same order as before.
        for conv_name, bn_name in layer_pairs:
            setattr(block, conv_name, _prunable(getattr(block, conv_name)))
            setattr(block, bn_name, PrunableBatchNorm2d(getattr(block, bn_name)))

        if block.downsample is None:
            continue
        block.downsample[0]=_prunable(block.downsample[0])
        if args.shortcut=='none':
            block.downsample[0].is_train=False
            last_conv_name=layer_pairs[-1][0]
            getattr(block, last_conv_name).is_train=False
        elif args.shortcut=='depend':
            block.downsample[0].is_train=False
        block.downsample[1]=PrunableBatchNorm2d(block.downsample[1])


def setPruneRate(network,cl,ll,r):
    """Call ``set_prune_rate(r)`` on every module that is a ``cl`` or an ``ll``.

    A module that is an instance of both classes receives the call twice,
    once per matching class.
    """
    for _, layer in network.named_modules():
        for layer_cls in (cl, ll):
            if isinstance(layer, layer_cls):
                layer.set_prune_rate(r)

def setPruneScore(network,cl,maskgenerator,vp):
    """Fill ``popup_scores`` of every ``cl`` layer from one shared mask generator.

    The first layer is conditioned on ``vp.program``; each later layer on
    ``maskgenerator.out`` as left by the previous call. The second input is
    the layer's weight averaged over dims (1, 2, 3). The generator output is
    reshaped to ``(-1, 1, 1, 1)`` and truncated to the layer's out-channel count.
    """
    convs = (m for _, m in network.named_modules() if isinstance(m, cl))
    for idx, conv in enumerate(convs):
        previous = vp.program if idx == 0 else maskgenerator.out
        score = maskgenerator(previous, conv.weight.mean(dim=(3, 2, 1)))
        conv.popup_scores = score.view(-1, 1, 1, 1)[:conv.weight.shape[0]]

def percentile(t, q):
    """Return the ``q``-th percentile (``q`` in 0..100) of all elements of ``t``.

    Uses the nearest-rank style index ``k = 1 + round(q/100 * (n - 1))`` on
    the flattened tensor and returns the k-th smallest value as a Python number.

    Args:
        t: tensor of any shape; contiguity is not required (``reshape`` is
            used instead of ``view``, which fails on non-contiguous tensors).
        q: percentile in [0, 100].

    Raises:
        ValueError: ``t`` has no elements.
    """
    flat = t.reshape(-1)
    if flat.numel() == 0:
        raise ValueError("percentile() of an empty tensor is undefined")
    k = 1 + round(.01 * float(q) * (flat.numel() - 1))
    return flat.kthvalue(k).values.item()
def setPruneScoreLSTM(network,cl,results,sparsity=None,glob=False):
    """Assign per-layer LSTM outputs as ``popup_scores`` and optionally a global threshold.

    The i-th trainable ``cl`` layer (``is_train`` truthy, in ``named_modules()``
    order) gets ``results[i].view(-1, 1, 1, 1)``. A non-list ``results`` is
    squeezed first.

    Args:
        network: model to scan.
        cl: prunable layer class.
        results: list of per-layer score tensors, or a tensor indexable by layer.
        sparsity: accepted for API compatibility; not used. The percentile
            comes from ``1 - module.k`` of the LAST trainable layer.
        glob: when True, also compute a threshold over all scores.

    Returns:
        ``(threshold, reg_loss)``. ``threshold`` is None unless ``glob`` is True
        and at least one trainable layer exists. ``reg_loss`` is always 0.
    """
    reg_loss=0
    if type(results) is not list:
        results=results.squeeze()
    flat_scores=[]
    s=None
    i=0
    for _, module in network.named_modules():
        if isinstance(module,cl) and module.is_train:
            score=results[i].view(-1,1,1,1)
            s=1-module.k
            module.popup_scores=score
            # reshape(-1) keeps 1-channel layers 1-D so concat works.
            flat_scores.append(module.popup_scores.clone().detach().reshape(-1))
            i+=1
    if not glob or not flat_scores:
        # No trainable layers means there is nothing to threshold; previously
        # this crashed on torch.concat([]) / an unbound ``s``.
        return None,reg_loss
    total=torch.concat(flat_scores)
    threshold=percentile(total,q=s*100)
    return threshold,reg_loss




def setPruneScoreConv(network,cl,results):
    """Slice one flat score vector into per-layer ``popup_scores``.

    ``results`` is squeezed, then consumed left to right. Each trainable
    ``cl`` layer takes as many entries as it has output channels
    (``weight.shape[0]``), reshaped to ``(-1, 1, 1, 1)``. Returns None.
    """
    flat = results.squeeze()
    offset = 0
    for _, layer in network.named_modules():
        if not (isinstance(layer, cl) and layer.is_train):
            continue
        width = layer.weight.shape[0]
        layer.popup_scores = flat[offset:offset + width].view(-1, 1, 1, 1)
        offset += width


def setPruneScore1(network,cl,maskgenerator,vp):
    """Like ``setPruneScore`` but with one mask generator per ``cl`` layer.

    Generator ``maskgenerator[i]`` scores the i-th ``cl`` layer. It is
    conditioned on ``vp.program`` for i == 0, and otherwise on the previous
    generator's ``out`` attribute. Scores are reshaped to ``(-1, 1, 1, 1)``
    and truncated to the layer's out-channel count.
    """
    convs = (m for _, m in network.named_modules() if isinstance(m, cl))
    for idx, conv in enumerate(convs):
        generator = maskgenerator[idx]
        previous = vp.program if idx == 0 else maskgenerator[idx - 1].out
        score = generator(previous, conv.weight.mean(dim=(3, 2, 1)))
        conv.popup_scores = score.view(-1, 1, 1, 1)[:conv.weight.shape[0]]

def sparseModel(visual_prompt,mask_generator,network,args,vp_detach,cl,x=None):
    """Generate prune scores for the network's prunable layers from the visual prompt.

    When ``args.MaskGeneration`` is False nothing is done. Otherwise the prompt
    (optionally detached, sigmoid-squashed and/or masked per ``args.vp_mode``)
    is fed to ``mask_generator`` in the way selected by ``args.Masknetwork``.
    The resulting scores are written onto the ``cl`` layers.

    Args:
        visual_prompt: module exposing ``program`` (and ``mask`` for
            'vp_sigmoid_mask'); also called on ``x`` for the input-driven modes.
        mask_generator: score generator; call signature depends on the mode.
        network: model whose ``cl`` layers receive scores.
        args: run configuration (``k``, ``MaskGeneration``, ``vp_mode``,
            ``vp_method``, ``Masknetwork``, ``glob``).
        vp_detach: detach the prompt before generating scores.
        cl: prunable layer class.
        x: input batch, used only by 'lstmv1', 'lstmv1_full' and 'conv_inputs'.

    Returns:
        ``(thre, reg_loss)``. ``thre`` is the threshold from
        ``setPruneScoreLSTM`` for the LSTM modes and None otherwise.
        ``reg_loss`` is 0.0 unless the 'lstm_full' mode returns a different value.

    Raises:
        ValueError: unknown ``args.vp_mode``. This used to be ``assert False``,
            which is stripped under ``python -O``.
        NotImplementedError: unknown ``args.Masknetwork``, or 'lstm' (see below).
    """
    global feature_maps
    s=1-args.k
    reg_loss=0.0
    # Default so modes that produce no threshold ('convShared', 'conv',
    # 'conv_inputs') return None instead of hitting an unbound local.
    thre=None
    if not args.MaskGeneration:
        return thre,reg_loss

    vp1=visual_prompt.program.detach() if vp_detach else visual_prompt.program
    if args.vp_mode=='vp':
        vp=vp1
    elif args.vp_mode=='vp_sigmoid':
        vp=torch.sigmoid(vp1)
    elif args.vp_mode=='vp_sigmoid_mask':
        if args.vp_method=='Pad':
            vp=vp1*visual_prompt.mask
        else:
            vp=torch.sigmoid(vp1)*visual_prompt.mask
    else:
        raise ValueError(f"Unknown vp_mode: {args.vp_mode!r}")

    mode=args.Masknetwork
    if mode=='convShared':
        setPruneScore(network,cl,mask_generator,visual_prompt)
    elif mode=='lstm':
        # ``x_mask`` is assigned in the 'lstm_full' branch below, which makes it
        # a local name for the whole function. In this branch it was therefore
        # always unbound (UnboundLocalError). Fail with an explicit message.
        raise NotImplementedError(
            "Masknetwork='lstm' needs an x_mask input that is never built here; use 'lstm_full'"
        )
    elif mode=='lstm_full':
        x_mask=GetWeights2Inputs(network,cl)
        out=mask_generator(x_mask,vp)
        thre,reg_loss=setPruneScoreLSTM(network,cl,out,sparsity=s,glob=args.glob)
    elif mode=='lstmv1':
        # feature_maps must be a module-level list filled by forward hooks elsewhere.
        feature_maps.clear()
        _=network(visual_prompt(x))
        features=torch.cat(feature_maps,dim=0).unsqueeze(0)
        out=mask_generator.forward_v1(features)
        thre,_=setPruneScoreLSTM(network,cl,out,sparsity=s,glob=args.glob)
    elif mode=='conv':
        out=mask_generator(vp)
        thre=setPruneScoreConv(network,cl,out)
    elif mode=='conv_inputs':
        out=mask_generator(visual_prompt(x))
        setPruneScoreConv(network,cl,out)
    elif mode=='lstm_vp':
        out=mask_generator(vp)
        thre,_=setPruneScoreLSTM(network,cl,out,sparsity=s,glob=args.glob)
    elif mode=='lstmv1_full':
        feature_maps.clear()
        _=network(visual_prompt(x))
        features=feature_maps
        out=mask_generator.forward_v1(features)
        thre,_=setPruneScoreLSTM(network,cl,out,sparsity=s,glob=args.glob)
    else:
        # Previously printed "Not implemented!" and then crashed on an unbound ``thre``.
        raise NotImplementedError(f"Masknetwork {mode!r} is not implemented")
    return thre,reg_loss

def gradual_pruning_rate(
        step: int,
        initial_threshold: float,
        final_threshold: float,
        initial_time: int,
        final_time: int,
):
    """Cubic schedule from ``initial_threshold`` to ``final_threshold``.

    Returns ``initial_threshold`` up to and including ``initial_time`` and
    ``final_threshold`` after ``final_time``. In between it returns
    ``final + (initial - final) * (1 - progress) ** 3``, where
    ``progress = (step - initial_time) / (final_time - initial_time)``.
    """
    if initial_time >= step:
        return initial_threshold
    if final_time < step:
        return final_threshold
    remaining = 1 - (step - initial_time) / (final_time - initial_time)
    return final_threshold + (initial_threshold - final_threshold) * remaining ** 3


import torch.nn.utils as utils
def Training(epoch,label_mapping, cl,network,visual_prompt,mask_generator,args,loaders,optimizer,logger,optimizer_vp=None,loader_score=None,wandb=None,name=None):
    """Run one training epoch over ``loaders['train']`` zipped with ``loader_score['train']``.

    Each step does the following:

    * Unless ``name == 'FinetuneTrain'``, regenerate prune scores with
      ``sparseModel``. With ``args.ChannelPrune == 'channel'``, also
      re-propagate masks with ``Calculate_mask``.
    * Take an optimizer step on the cross-entropy of the visually-prompted,
      label-mapped output.
    * With ``args.alternate``, make a second update every ``args.interval``
      global steps.

    The ``wandb`` argument is accepted but not used here. ``info_dict`` is
    filled but neither returned nor logged.

    Returns:
        ``(averaged_grads, averaged_grads_before)``: per-parameter mean
        gradients of the mask generator's non-bias trainable parameters over
        ``len(loaders['train'])`` batches. They are taken after / before
        gradient clipping respectively. Both are None unless
        ``args.gradient_flow``.
    """
    visual_prompt.train()
    mask_generator.train()
    grad_accumulators = {n: torch.zeros_like(p, device=p.device) for n, p in mask_generator.named_parameters() if p.requires_grad and "bias" not in n}
    averaged_grads=None
    grad_accumulators_before = {n: torch.zeros_like(p, device=p.device) for n, p in mask_generator.named_parameters() if p.requires_grad and "bias" not in n}
    averaged_grads_before=None

    def _accumulate_grads(accumulators):
        # Add the current mask-generator gradients into ``accumulators``.
        # Parameters that got no gradient this step (p.grad is None) are
        # skipped instead of crashing on ``None.detach()``.
        for n, p in mask_generator.named_parameters():
            if p.requires_grad and "bias" not in n and p.grad is not None:
                accumulators[n] += p.grad.detach()

    if args.OpenBatchnorm:
        network.train()
    else:
        network.eval()
    total_num = 0
    true_num = 0
    loss_sum = 0
    lr_str=f"Epoch: {epoch}"
    for p in optimizer.param_groups:
        lr_str+=f"Training Lr {p['lr']:.1e}"
        lr_str+=f"Training WD {p['weight_decay']:.1e}"
    if optimizer_vp is not None:
        for p in optimizer_vp.param_groups:
            lr_str+=f"Training Lr {p['lr']:.1e}"
            lr_str+=f"Training WD {p['weight_decay']:.1e}"
    print(lr_str)

    density_level=[]
    info_dict={}
    total_steps=0
    for i,((x, y),(x1,y1)) in enumerate(zip(loaders['train'],loader_score['train'])):
        total_steps=epoch*len(loaders['train'])+i
        if x.get_device() == -1:
            x, y = x.to(device), y.to(device)
            x1,y1=x1.to(device),y1.to(device)
        reg_loss=0.0
        if name !='FinetuneTrain':
            thre,reg_loss=sparseModel(visual_prompt,mask_generator,network,args,args.vp_detach,cl)
            if args.ChannelPrune=='channel':
                Calculate_mask(network,args.bn_detach,thres=thre,glob=args.glob,args=args)

        with autocast():
            fx = label_mapping(network(visual_prompt(x)))
            loss = F.cross_entropy(fx, y, reduction='mean')

        zero_grad(optimizer,optimizer_vp)
        loss.backward()
        if args.gradient_flow:
            _accumulate_grads(grad_accumulators_before)

        if args.glob and args.gradient_clip:
            utils.clip_grad_value_(mask_generator.parameters(), args.grad_clip)
        optimizer.step()

        if args.gradient_flow:
            _accumulate_grads(grad_accumulators)

        if optimizer_vp:
            optimizer_vp.step()
        if args.alternate:
            if total_steps%args.interval==0:
                if name !='FinetuneTrain':
                    thre,reg_loss=sparseModel(visual_prompt,mask_generator,network,args,args.vp_detach,cl)
                    if args.ChannelPrune=='channel':
                        Calculate_mask(network,args.bn_detach,thres=thre,glob=args.glob,args=args)
                # NOTE(review): without args.two_loader, ``loss`` below is still
                # the first loss, whose graph was already freed by the earlier
                # backward(). The second backward() then appears to fail; confirm
                # the intended objective.
                with autocast():
                    if args.two_loader:
                        if x1.get_device() == -1:
                            x1, y1 = x1.to(device), y1.to(device)
                        fx1 = label_mapping(network(args.normalize(x1)))
                        loss = F.cross_entropy(fx1, y1, reduction='mean')

                zero_grad(optimizer,optimizer_vp)
                loss.backward()
                if args.glob and args.gradient_clip:
                    utils.clip_grad_value_(mask_generator.parameters(), args.grad_clip)
                optimizer.step()
                if optimizer_vp is not None:
                    optimizer_vp.step()

        total_num += y.size(0)
        true_num += torch.argmax(fx, 1).eq(y).float().sum().item()
        loss_sum += loss.item() * fx.size(0)

    if args.Samplingnet or args.glob:
        num_ones=0
        num_all=0
        # ``layer_name`` (not ``name``) so the phase label argument is not
        # clobbered before it is used for the info_dict keys below.
        for layer_name, module in network.named_modules():
            if isinstance(module,cl) and (hasattr(module,'is_train') and module.is_train) :
                temp_num=module.adj.clone().detach().sum().item()
                num_ones+=temp_num
                num_all+=module.adj.numel()
                print(f"layer density at {layer_name}:",temp_num/module.adj.numel())
        density_level.append(num_ones*1.0/num_all)
        density_level=np.mean(density_level)
        print(f"Acc {100*true_num/total_num:.2f}% Train Loss {loss_sum/total_num:.3f}  density level {density_level:.3f}")
    else:
        print(f"Acc {100*true_num/total_num:.2f}% Train Loss {loss_sum/total_num:.3f}  ")

    logger.add_scalar("train/acc", true_num/total_num, epoch)
    logger.add_scalar("train/loss", loss_sum/total_num, epoch)
    # f-strings so a ``name`` of None does not raise TypeError.
    info_dict[f"{name}-Train Loss"]=loss_sum/total_num
    info_dict[f"{name}-Train Acc"]=true_num/total_num

    if args.gradient_flow:
        num_batches=len(loaders['train'])
        averaged_grads = {n: grad_accumulators[n] / num_batches for n in grad_accumulators}
        averaged_grads_before = {n: grad_accumulators_before[n] / num_batches for n in grad_accumulators_before}
    return averaged_grads,averaged_grads_before
def Test(epoch,label_mapping, cl,network,visual_prompt,mask_generator,args,loaders,optimizer,logger,wandb=None,name=None):
    """Evaluate on ``loaders['test']`` and return top-1 accuracy as a fraction.

    Unless ``name == 'FinetuneTest'``, prune scores are regenerated from the
    detached visual prompt. With ``args.ChannelPrune == 'channel'``, masks are
    then re-propagated with ``Calculate_mask`` before evaluation.

    Raises:
        ValueError: unknown ``args.vp_mode``, or the test loader yields no samples.
        NotImplementedError: ``args.Masknetwork == 'lstm'``. Its ``x_mask``
            input was an always-unbound local.
    """
    global best_acc
    s=1-args.k
    mask_generator.eval()
    visual_prompt.eval()
    network.eval()
    if name!='FinetuneTest':
        # Modes that compute no threshold leave it as None rather than unbound
        # for the print / Calculate_mask below.
        thres=None
        if args.MaskGeneration:
            vp1=visual_prompt.program.detach()
            if args.vp_mode=='vp':
                vp=vp1
            elif args.vp_mode=='vp_sigmoid':
                vp=torch.sigmoid(vp1)
            elif args.vp_mode=='vp_sigmoid_mask':
                if args.vp_method=='Pad':
                    vp=vp1*visual_prompt.mask
                else:
                    vp=torch.sigmoid(vp1)*visual_prompt.mask
            else:
                raise ValueError(f"Unknown vp_mode: {args.vp_mode!r}")
            if args.Masknetwork=='convShared':
                setPruneScore(network,cl,mask_generator,visual_prompt)
            elif args.Masknetwork=='lstm':
                raise NotImplementedError(
                    "Masknetwork='lstm' needs an x_mask input that is never built here; use 'lstm_full'"
                )
            elif args.Masknetwork=='conv':
                out=mask_generator(vp)
                setPruneScoreConv(network,cl,out)
            elif args.Masknetwork=='lstm_vp':
                out=mask_generator(vp)
                thres,_=setPruneScoreLSTM(network,cl,out,sparsity=s,glob=args.glob)
            elif args.Masknetwork=='lstm_full':
                x_mask=GetWeights2Inputs(network,cl)
                out=mask_generator(x_mask,vp)
                thres,_=setPruneScoreLSTM(network,cl,out,sparsity=s,glob=args.glob)
            elif args.Masknetwork in ('lstmv1_full','lstmv1','conv_inputs'):
                # Input-driven modes: scores are not regenerated here (presumably
                # those left by the last training step are reused).
                pass
            else:
                print("Not implemented!")
        if args.ChannelPrune=='channel':
            print("thres:",thres)
            Calculate_mask(network,args.bn_detach,thres=thres,glob=args.glob,args=args,pp=True)

    network.eval()
    total_num = 0
    true_num = 0
    loss_sum=0
    info_dict={}
    for x, y in loaders['test']:
        if x.get_device() == -1:
            x, y = x.to(device), y.to(device)
        with torch.no_grad():
            fx = label_mapping(network(visual_prompt(x)))
            loss = F.cross_entropy(fx, y, reduction='mean')
            loss_sum+=loss.item()*fx.size(0)
        total_num += y.size(0)
        true_num += torch.argmax(fx, 1).eq(y).float().sum().item()

    if total_num == 0:
        raise ValueError("Test loader yielded no samples")
    acc = true_num/total_num
    print(f"Test Acc {100*acc:.2f}%")
    logger.add_scalar("test/acc", acc, epoch)

    info_dict[f"{name}-Test Loss"]=loss_sum/total_num
    info_dict[f"{name}-Test Acc"]=acc

    density_level=[]
    if args.Samplingnet or args.glob:
        num_ones=0
        num_all=0
        for layer_name, module in network.named_modules():
            if isinstance(module,cl) and (hasattr(module,'is_train') and module.is_train) :
                temp_num=module.adj.clone().detach().sum().item()
                num_ones+=temp_num
                num_all+=module.adj.numel()
                print(f"layer density at {layer_name}:",temp_num/module.adj.numel())
        density_level.append(num_ones*1.0/num_all)
        density_level=np.mean(density_level)
        print(f"density level {density_level:.3f}")
    return acc


def get_lr(lr_schedule,epochs,lr_max,lr_drop_epoch=None,lr_one_drop=None):
    """Return a callable ``t -> learning rate`` for the named schedule.

    Args:
        lr_schedule: one of 'superconverge', 'piecewise', 'linear', 'onedrop',
            'multipledecay', 'cosine'. Any other value is returned unchanged.
        epochs: number of epochs the schedule spans.
        lr_max: peak / initial learning rate.
        lr_drop_epoch, lr_one_drop: for 'onedrop', the epoch at which the rate
            switches and the rate it switches to. When omitted they are read
            at call time from the module-level ``args`` (``args.lr_drop_epoch``
            / ``args.lr_one_drop``), as before.
    """
    if lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [0, epochs * 2 // 5, epochs], [0, lr_max, 0])[0]
    elif lr_schedule == 'piecewise':
        def lr_schedule(t):
            if t / epochs < 0.8:
                return lr_max
            elif t / epochs < 0.9:
                return lr_max / 10.
            else:
                return lr_max / 100.
    elif lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [0, epochs // 3, epochs * 2 // 3, epochs], [lr_max, lr_max, lr_max / 10, lr_max / 100])[0]
    elif lr_schedule == 'onedrop':
        def lr_schedule(t):
            drop_epoch = args.lr_drop_epoch if lr_drop_epoch is None else lr_drop_epoch
            if t < drop_epoch:
                return lr_max
            return args.lr_one_drop if lr_one_drop is None else lr_one_drop
    elif lr_schedule == 'multipledecay':
        # Subtract lr_max/10 every epochs//10 epochs. The step is at least one
        # epoch so that epochs < 10 no longer divides by zero.
        decay_every = max(1, epochs // 10)
        def lr_schedule(t):
            return lr_max - (t // decay_every) * (lr_max / 10)
    elif lr_schedule == 'cosine':
        def lr_schedule(t):
            return lr_max * 0.5 * (1 + np.cos(t / epochs * np.pi))
    return lr_schedule

def get_schedular(optimizer,scheduler,args,epochs,T_max=None):
    """Build a torch LR scheduler for ``optimizer``.

    Args:
        optimizer: torch optimizer to schedule.
        scheduler: 'cosine' or 'multistep'.
        args: namespace. Only ``args.decreasing_step`` (fractions of ``epochs``)
            is read, and only for 'multistep'.
        epochs: total number of epochs. Used for the multistep milestones, and
            as the cosine period when ``T_max`` is None.
        T_max: cosine annealing period. Defaults to ``epochs``.

    Returns:
        A ``CosineAnnealingLR`` or a ``MultiStepLR`` (gamma 0.1).

    Raises:
        ValueError: for an unknown scheduler name.
    """
    if scheduler == 'cosine':
        # CosineAnnealingLR accepts T_max=None at construction but raises a
        # TypeError on its first step(); fall back to the full run length.
        period = epochs if T_max is None else T_max
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=period)
    if scheduler == 'multistep':
        milestones = [int(epochs * fraction) for fraction in args.decreasing_step]
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    raise ValueError('scheduler should be one of [cosine, multistep]')


def get_sparsity(sparsity,current_epoches,start_epoches,end_epoches):
    """Linearly ramp sparsity from 0 at ``start_epoches`` to ``sparsity`` at ``end_epoches``.

    ``current_epoches`` is clamped into ``[start_epoches, end_epoches]``, so the
    result always lies between 0 and ``sparsity``. When the ramp is empty
    (``end_epoches <= start_epoches``, e.g. zero warm-up epochs), the target
    sparsity is returned immediately instead of dividing by zero.

    Args:
        sparsity: target (final) sparsity.
        current_epoches: current epoch.
        start_epoches: epoch at which the ramp starts (sparsity 0).
        end_epoches: epoch at which the target sparsity is reached.

    Returns:
        The sparsity for ``current_epoches``. Also printed.
    """
    if end_epoches <= start_epoches:
        current_epoches = end_epoches
        progress = 1.0
    else:
        current_epoches = min(max(current_epoches, start_epoches), end_epoches)
        progress = (current_epoches-start_epoches)*1.0/(end_epoches-start_epoches)
    # Same arithmetic as before (s - s*(1-r)), so in-range values are bit-identical.
    sparsity=sparsity-sparsity*(1-progress)
    print("Sparsity at epochs:",current_epoches,"  ",sparsity)
    return sparsity

import random
def set_seed(seed):
    """Seed every RNG the script uses and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA
    devices). It then disables cuDNN autotuning and forces deterministic
    kernels. This definition shadows the ``set_seed`` imported from
    ``tools.misc``.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True


# Module-level best-accuracy value, initialised to 0. (presumably read/updated
# elsewhere via `global best_acc` -- confirm).
best_acc = 0.
def GetWeights2Inputs(network,cl,pad_to=512):
    """Summarise each trainable prunable layer's weights into a fixed-length vector.

    For every module that is an instance of ``cl`` and has ``is_train`` set, the
    absolute weights are multiplied by the module's ``pre_adj``. A tensor
    ``pre_adj`` is viewed as ``(1, -1, 1, 1)``, i.e. one factor per weight dim 1.
    The result is then averaged, per output channel, over its non-zero entries.
    Each per-channel vector is right-padded with zeros to ``pad_to`` entries.

    Args:
        network: module whose ``named_modules()`` are scanned.
        cl: prunable layer class to select.
        pad_to: output length of every vector (default 512, as before).

    Returns:
        List of detached 1-D tensors, one per selected layer, in
        ``named_modules()`` order.
    """
    x_mask=[]
    for _, module in network.named_modules():
        if not (isinstance(module,cl) and module.is_train):
            continue
        pre_adj=module.pre_adj
        if type(pre_adj) is not float:
            pre_adj=pre_adj.view(1,-1,1,1)
        masked=module.weight.abs()*pre_adj
        total=masked.sum(dim=(3,2,1))
        # Count of surviving (non-zero) weights per output channel. Clamp so a
        # fully-masked channel yields 0 rather than 0/0 = NaN, which would then
        # be fed to the mask generator.
        count=(masked>0).sum(dim=(3,2,1)).clamp(min=1)
        temp=total/count
        # NOTE(review): F.pad with a negative amount crops, so layers with more
        # than `pad_to` output channels are silently truncated.
        temp=F.pad(temp,(0,pad_to-temp.shape[0]),mode='constant',value=0)
        x_mask.append(temp.clone().detach())
    return x_mask

def zero_grad(optimizer1,optimzer2):
    """Clear gradients on both optimizers, skipping whichever is None."""
    for opt in (optimizer1, optimzer2):
        if opt is None:
            continue
        opt.zero_grad()

import numpy as np
import matplotlib.pyplot as plt

def plot_grad_flow_avg(named_parameters):
    """Draw per-entry bars of max |value| and mean value on the current figure.

    ``named_parameters`` is an iterable of ``(name, tensor)`` pairs. In this
    script, callers pass dicts of averaged gradients via ``.items()``. Nothing
    is saved or closed here; the caller does ``plt.savefig`` / ``plt.close``.
    """
    names, means, peaks = [], [], []
    for entry_name, values in named_parameters:
        #if(values.requires_grad) and ("bias" not in entry_name):
        names.append(entry_name)
        means.append(values.mean().item())
        peaks.append(values.abs().max().item())
    count = len(peaks)
    positions = np.arange(count)
    plt.bar(positions, peaks, alpha=0.3, lw=1, label='max')
    plt.bar(positions, means, alpha=0.3, lw=1, label='avg')
    plt.hlines(0, 0, count + 1, lw=2, color="k")
    plt.xticks(range(count), names, rotation="vertical")
    plt.xlim(left=0, right=count)
    plt.legend()
    #plt.ylim(bottom = -0.1, top=0.2)  # Adjust if needed
    plt.title("Average Gradient Flow for Epoch")
    plt.tight_layout()
import os
if __name__ == '__main__':

    # ------------------------------------------------------------------
    # Command-line configuration
    # ------------------------------------------------------------------
    p = argparse.ArgumentParser()
    p.add_argument('--network', choices=["resnet18", "resnet34","resnet50", "instagram"], required=True)
    p.add_argument('--seed', type=int, default=7)
    p.add_argument('--alg', type=str,default='ours')
    p.add_argument('--datadir',type=str,default='./')
    p.add_argument('--save', type=str,default='ckpoints')
    p.add_argument('--restore', action='store_true')
    p.add_argument('--dataset', choices=["cifar10", "cifar100", "abide", "dtd", "flowers102", "ucf101", "food101", "gtsrb", "svhn", "eurosat", "oxfordpets", "stanfordcars", "sun397",'tinyimagenet'], required=True)
    p.add_argument('--epoch', type=int, default=50)
    p.add_argument('--input_dim', type=int, default=32)
    p.add_argument('--hidden_dim', type=int, default=128)
    p.add_argument('--output_dim', type=int, default=512)
    #visual prompt
    p.add_argument('--vp_method', type=str, choices=['Pad','Expand'],default='Expand')
    p.add_argument('--input_size', type=int, default=224)
    p.add_argument('--pad_size', type=int, default=16)
    p.add_argument('--output_size', type=int, default=224)

    p.add_argument('--heads', type=int, default=8)
    p.add_argument('--alpha', type=float, default=1e-4)
    p.add_argument('--resize', type=int, default=32)
    p.add_argument('--k', type=float, default=1.0)
    p.add_argument('--first_k', type=float, default=0.5)
    p.add_argument('--scale_k', type=float, default=0.6)
    p.add_argument('--warmup', type=int, default=5)
    p.add_argument('--Aug', action='store_true')
    p.add_argument('--opt_ft_vp',action='store_true')
    p.add_argument('--glob',action='store_true')
    p.add_argument('--vp_detach', action='store_true')
    p.add_argument('--vp_mode',choices=["vp_sigmoid",'vp','vp_sigmoid_mask'],default='vp')
    p.add_argument('--MaskGeneration', action='store_true')
    p.add_argument('--Masknetwork',choices=["lstm",'lstm_full', "convShared", "conv",'lstmv1','lstmv1_full','lstm_vp','conv_inputs'],default='lstm')
    p.add_argument('--ScaledInitialize', action='store_true')
    p.add_argument('--OpenBatchnorm', action='store_true')
    p.add_argument('--Samplingnet', action='store_true')
    p.add_argument('--Twostage', action='store_true')
    p.add_argument('--TestGM', action='store_true')
    p.add_argument('--PruneLinear', action='store_true')
    p.add_argument('--Test_val', action='store_true')
    p.add_argument('--Admm',choices=["admm",'admmfintune','none'],default='none')
    p.add_argument('--ft_vp_opt',choices=["adamw",'SGD'],default='adamw')
    p.add_argument('--ft_weight_opt',choices=["adamw",'SGD'],default='adamw')
    p.add_argument('--prune_vp_opt',choices=["adamw",'SGD'],default='adamw')
    p.add_argument('--prune_score_opt',choices=["adamw",'SGD'],default='adamw')

    p.add_argument('--hydra_scheduler', default='cosine', help='decreasing strategy.', choices=['cosine', 'multistep'])
    p.add_argument('--vp_scheduler', default='multistep', help='decreasing strategy.', choices=['cosine', 'multistep'])
    p.add_argument('--weight_scheduler', default='cosine', help='decreasing strategy.', choices=['cosine', 'multistep'])
    p.add_argument('--score_scheduler', default='multistep', help='decreasing strategy.', choices=['cosine', 'multistep'])

    p.add_argument('--beidi', action='store_true')
    p.add_argument('--DifferentLR', action='store_true')
    p.add_argument('--finetune', action='store_true')
    p.add_argument('--alternate',action='store_true')
    p.add_argument('--lstm_catinput',action='store_true')
    p.add_argument('--lstm_novp',action='store_true')
    p.add_argument('--SepOptimizer', action='store_true')
    p.add_argument('--Testalternate', action='store_true')
    p.add_argument('--momentum', default=0.9, type=float, help='momentum')
    p.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
    p.add_argument('--weight_decay_FF', default=5e-4, type=float, help='weight decay')
    p.add_argument('--shortcut', default='none',  choices=['none', 'intersect','depend'])

    # Fractions of the run at which the 'multistep' scheduler decays the LR,
    # e.g. `--decreasing_step 0.6 0.8`. (type=list split a CLI value into
    # single characters, so only the default ever worked.)
    p.add_argument('--decreasing_step', default=[0.6,0.8], type=float, nargs='+', help='decreasing strategy')
    p.add_argument('--Gradual_Pruning', action='store_true')
    p.add_argument('--epoch_1', type=int, default=50)
    p.add_argument('--epoch_2', type=int, default=50)
    p.add_argument('--T_max', type=int, default=50)
    p.add_argument('--T_max_finetune', type=int, default=20)
    p.add_argument('--total_admm',type=int,default=5)
    p.add_argument('--ChannelPrune', choices=['kernel','channel','weight','inputchannel'],default='kernel')
    p.add_argument('--layer_type',choices=['subnet','dense'],default='subnet')
    p.add_argument('--optimizer',choices=['Adam','SGD','AdamW','AdamwAll'],default='Adam')
    p.add_argument('--lr', type=float, default=0.01)
    p.add_argument('--lr_vp', type=float, default=0.001)
    p.add_argument('--ft_lr_vp', type=float, default=0.001)
    p.add_argument('--ft_lr', type=float, default=0.001)
    p.add_argument('--gmp_T',type=int,default=1)
    p.add_argument('--finetuneEpoch',type=int,default=60)
    p.add_argument('--interval',type=int,default=1)
    p.add_argument('--lr_scale', type=float, default=1.0)
    p.add_argument('--grad_clip', type=float, default=1e-4)
    p.add_argument('--lr_scale_decay', action='store_true')
    p.add_argument('--bn_detach', action='store_true')
    p.add_argument('--gradient_flow', action='store_true')
    p.add_argument('--gradient_clip', action='store_true')
    p.add_argument('--AGC', action='store_true')
    p.add_argument('--MLP', action='store_true')
    p.add_argument('--AutoRestore', action='store_true')

    p.add_argument('--co_ft', action='store_true')
    p.add_argument('--save_dir', type=str,default='./checkpoints')
    p.add_argument('--ft_init_vp', action='store_true')
    p.add_argument('--two_loader', action='store_true')
    p.add_argument('--normalize',choices=['original','none','downstream'], default='original')
    # The default must be one of the choices. The old default 'Adam' was not,
    # so `--finetune` without an explicit --clean_finetune hit `assert False`.
    # 'VP' keeps the loaders and prompt from the pruning stage.
    p.add_argument('--clean_finetune',choices=['Original','VP','Aug'],default='VP')

    p.add_argument('--ft_WD',type=float,default=5e-4)
    ##Wandb
    p.add_argument('--with-wandb', default=False, action='store_true', help='Enables Weights and Biases')
    p.add_argument('--wandb-entity', default='', type=str, help='WandB username (entity).')
    p.add_argument('--env', default='cifar10', type=str, help='Environment name.')
    p.add_argument('--wandb-project', default='VPForSparse', type=str, help='WandB "Project"')
    p.add_argument('--wandb-group', default=None, type=str, help='WandB "Group". Name of the env by default.')
    p.add_argument('--wandb-job_type', default='train', type=str, help='WandB job type')
    p.add_argument('--wandb-tags', default=[], type=str, nargs='*', help='Tags can help finding experiments')
    p.add_argument('--wandb-key', default=None, type=str, help='API key for authorizing WandB')
    p.add_argument('--wandb-dir', default=None, type=str, help='the place to save WandB files')
    p.add_argument('--wandb-experiment', default='', type=str, help='Identifier to specify the experiment')

    args = p.parse_args()
    print(args)
    # NOTE(review): --first_k is always overwritten by --k here, so its own value is never used.
    args.first_k=args.k
    # Seed with a fixed value first (module-level set_seed defined above); re-seeded with --seed below.
    set_seed(100)
    # Misc
    # save_path = os.path.join(args.save_dir, args.network, args.dataset, 'Prune_VP'+str(args.vp_method),
    #         'SIZE'+str(args.output_size)+'_'+str(args.input_size)+'_'+str(args.pad_size),
    #         args.prune_vp_opt+'_'+args.prune_score_opt+'_'+args.ft_weight_opt+'_'+args.ft_vp_opt, 
    #         'LR'+str(args.lr)+'_'+str(args.lr_vp)+'_'+str(args.ft_lr)+'_'+str(args.ft_lr_vp),  
    #         'Masknetwork'+str(args.Masknetwork)+"_Hiddensize_"+str(args.hidden_dim), 'clean_finetune'+str(args.clean_finetune)+"_"+'normalize'+str(args.normalize),'glob'+str(args.glob),'k'+str(args.k))
    # Pruning-stage output directory (tensorboard logs, Prune.pth). Every path
    # component is derived from CLI options; args.normalize is still the CLI string here.
    save_path = os.path.join(args.save_dir, args.network, args.dataset, 'Prune_VP'+str(args.vp_method),
            'SIZE'+str(args.output_size)+'_'+str(args.input_size)+'_'+str(args.pad_size),
            args.prune_vp_opt+'_'+args.prune_score_opt, 
            'LR'+str(args.lr)+'_'+str(args.lr_vp),  
            'Masknetwork'+str(args.Masknetwork)+"_Hiddensize_"+str(args.hidden_dim), 'normalize'+str(args.normalize)+'Aug'+str(args.Aug),'glob'+str(args.glob),'k'+str(args.k))
    print(save_path)
    os.makedirs(save_path, exist_ok=True)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    #fabric=Fabric(accelerator='cuda')
    ##fabric.launch()
    set_seed(args.seed)
    # NOTE(review): `exp` is only referenced by the commented-out line below.
    exp = f"cnn/flm_vp"
    #save_path = os.path.join(results_path, exp, gen_folder_name(args))

    # Data
    # Two loader sets: `loaders` follows --Aug, while `loaders_score` always uses
    # augmentation (dataAug=True). `data_path` is not defined in the code visible
    # here -- presumably it comes from `from cfg import *`; confirm.
    if args.dataset=='tinyimagenet':
        loaders, configs=get_tinyimagenet_dataloaders(args,resize=args.resize, dataAug=args.Aug)
        loaders_score, configs_score=get_tinyimagenet_dataloaders(args,resize=args.resize, dataAug=True)
    else:
        loaders, configs = prepare_expansive_data(args.dataset, data_path=data_path, resize=args.resize, dataAug=args.Aug)
        loaders_score, configs_score = prepare_expansive_data(args.dataset, data_path=data_path, resize=args.resize, dataAug=True)
    #loaders_score['train']=infinite_loader(loaders_score['train'])
    print("train dataset size",len(loaders['train']),"test dataset size",len(loaders['test']))
    # ImageNet normalisation; may be replaced in the visual-prompt section below.
    normalize = transforms.Normalize(IMAGENETNORMALIZE['mean'], IMAGENETNORMALIZE['std'])
    #loaders['train']=fabric.setup_dataloaders(loaders['train'])
    #loaders['test']=fabric.setup_dataloaders(loaders['test'])
    #print("loaders['test']",loaders['test'].data.shape)

    # Network
    # Pretrained backbone selected by --network, moved to `device` on creation.
    if args.network == "resnet18":
        from torchvision.models import resnet18, ResNet18_Weights
        network = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).to(device)
    elif args.network == "resnet50":
        from torchvision.models import resnet50, ResNet50_Weights
        network = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).to(device)
    elif args.network == "resnet34":
        from torchvision.models import resnet34, ResNet34_Weights
        network = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1).to(device)
    elif args.network == "instagram":
        from torch import hub
        network = hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl').to(device)
    else:
        raise NotImplementedError(f"{args.network} is not supported")

    # Prunable layer classes for --layer_type (cl is used for conv replacement;
    # ll presumably the linear counterpart -- see get_layers in `layers`).
    cl, ll = get_layers(args.layer_type)
    args.cl=cl
    #print(network)
    # Swap Conv2d layers for the prunable class `cl`. ResNets go through
    # replace_layers_resnet (defined elsewhere in this file); other backbones use
    # the recursive replace_layers, which copies the original weight/bias.
    if 'resnet' in args.network:
        replace_layers_resnet(network,cl,args.ChannelPrune,args.MaskGeneration,args.Samplingnet,args)
    else:
        network=replace_layers(network,Conv2d,cl,args.ChannelPrune,args.MaskGeneration,args.Samplingnet)
    # Linear-layer pruning is disabled; the flag only prints a message.
    if args.PruneLinear:
        print("Not Supported!")
        #network=replace_layers(network,Linear,ll,args.ChannelPrune,args.MaskGeneration,args.Samplingnet)
    print(network)
    network=network.to(device)
    
    # Visual Prompt
    # Build the prompt according to --normalize and --vp_method. `normalize`
    # starts as the ImageNet Normalize from the data section. It is replaced by
    # dataset statistics only for 'downstream' on cifar10/cifar100/tinyimagenet;
    # other datasets keep the ImageNet values.
    if args.normalize=='none':
        if args.vp_method=='Expand':
            visual_prompt = ExpansiveVisualPrompt(224, mask=configs['mask']).to(device)
        else:
            visual_prompt = PadVisualPrompt(args).to(device)
    elif args.normalize=='original':
        if args.vp_method=='Expand':
            print(configs['mask'].shape)
            visual_prompt = ExpansiveVisualPrompt(224, mask=configs['mask'], normalize=normalize).to(device)
        else:
            visual_prompt = PadVisualPrompt(args,normalize).to(device)
    elif args.normalize=='downstream':
        if args.dataset=='cifar100':
            cifar_mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
            cifar_std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
            normalize = transforms.Normalize(cifar_mean, cifar_std)
        elif args.dataset=='cifar10':
            mean=(0.4914, 0.4822, 0.4465)
            std=(0.2023, 0.1994, 0.2010)
            normalize = transforms.Normalize(mean, std)
        elif args.dataset=='tinyimagenet':
            mean=[0.485, 0.456, 0.406]
            std=[0.229, 0.224, 0.225]
            normalize = transforms.Normalize(mean, std)
        if args.vp_method=='Expand':
            visual_prompt = ExpansiveVisualPrompt(224, mask=configs['mask'], normalize=normalize).to(device)
        else:
            visual_prompt = PadVisualPrompt(args,normalize).to(device)
    # From here on args.normalize holds the Normalize transform, not the CLI string.
    # NOTE(review): in 'none' mode this is still the ImageNet Normalize, and the
    # fine-tuning stage passes `normalize` to newly built prompts regardless of
    # mode -- confirm that is intended.
    args.normalize=normalize

    # Get the channel_list
    # Output-channel count (weight.shape[0]) of every trainable prunable layer,
    # in named_modules() order; passed to the mask generator constructor.
    channel_list=[]
    for name, module in network.named_modules():
        if isinstance(module,cl) and module.is_train:
            channel_list.append(module.weight.shape[0])
    mask_generator=None
    #Get Score Parameters
    if args.MaskGeneration:
        if args.Masknetwork=='lstm_full':
            if args.lstm_catinput:
                mask_generator=LSTMAttentionModelCat(args.input_dim,args.hidden_dim,args.output_dim,args.heads,channel_list,args=args).to(device)
            else:
                # NOTE(review): with --MLP the LSTM model is still built and then
                # discarded. Kept so the RNG stream, and hence the MLP's
                # initialisation, is unchanged.
                mask_generator=LSTMAttentionModel(args.input_dim,args.hidden_dim,args.output_dim,args.heads,channel_list,args=args).to(device)
                if args.MLP:
                    mask_generator=MLPAttentionModel(args.input_dim,args.hidden_dim,args.output_dim,args.heads,channel_list,args=args).to(device)
        else:
            # Only 'lstm_full' is wired up. Previously this printed a message and
            # then crashed on `mask_generator.parameters()` with mask_generator=None.
            raise NotImplementedError(f"Masknetwork '{args.Masknetwork}' is not implemented yet; only 'lstm_full' is supported")
        print("Mask Generator params:",sum(p.numel() for p in mask_generator.parameters()))
        print(mask_generator)
        # Detached prompt passed to the mask generator together with the weight
        # summaries; its transformation is selected by --vp_mode.
        vp1=visual_prompt.program.detach()
        if args.vp_mode=='vp':
            vp=vp1
        elif args.vp_mode=='vp_sigmoid':
            vp=torch.sigmoid(vp1)
        elif args.vp_mode=='vp_sigmoid_mask':
            # NOTE(review): the Pad branch applies no sigmoid, unlike Expand -- confirm intended.
            if args.vp_method=='Pad':
                vp=vp1*visual_prompt.mask
            else:
                vp=torch.sigmoid(vp1)*visual_prompt.mask
        else:
            assert False

        if args.Masknetwork=='lstm' or args.Masknetwork=='lstm_full':
            # Initial scores from the (untrained) generator. setPruneScoreLSTM is
            # defined elsewhere and presumably writes them into the network's layers.
            x_mask=GetWeights2Inputs(network,cl)
            out=mask_generator(x_mask,vp)
            thres,_=setPruneScoreLSTM(network,cl,out,sparsity=1-args.k,glob=args.glob)
            params=list(mask_generator.parameters())
    else:
        # Without a generator, the trainable parameters are the network's own
        # parameters flagged with `is_score`.
        params=[param for param in network.parameters() if hasattr(param, 'is_score') and param.is_score]
        print(params[0].shape)
    
    optimizer_vp=None
    # Default to "no scheduler" so the `is not None` checks in the training loop
    # hold on every path. Previously scheduler_vp existed only with
    # --SepOptimizer, and the loop raised NameError otherwise.
    scheduler_vp=None
    scheduler_score=None

    if args.Admm=='admm':
        print("Use Adamm Training Stragety!")
        # NOTE(review): this path never defines `optimizer_score`, which the training loop uses.
        optimizer1 = torch.optim.AdamW(params, lr=args.lr,weight_decay=args.weight_decay)
        optimizer2 = torch.optim.AdamW(visual_prompt.parameters(), lr=args.lr_vp,weight_decay=args.weight_decay)
        #scheduler2=get_lr(args,args.epoch_2,args.lr_vp)
    else:
        print("One stage training!")
        if args.MaskGeneration:
            print("Mask Scores are Generated by Neural Networks!")
            if args.TestGM:
                print("Only optimize the parameters of neural network for Mask score!")
                optimizer1 = torch.optim.Adam(params, lr=args.lr)
                # The rest of the script reads `optimizer_score`; without this
                # alias the print below raised NameError.
                optimizer_score = optimizer1
            else:
                print("Optimize the parameters of neural network for Mask score and VP!")
                if args.DifferentLR:
                    print("Set different Lr for VP and params of NN")
                    if args.optimizer=='Adam':
                        optimizer1 = torch.optim.Adam([{'params':params,'lr':args.lr},{"params":list(visual_prompt.parameters()),'lr':args.lr_vp}])
                    elif args.optimizer=='SGD':
                        optimizer1 = torch.optim.SGD([{'params':params,'lr':args.lr},{"params":list(visual_prompt.parameters()),'lr':args.lr_vp}],weight_decay=args.weight_decay,momentum=0.9)
                    elif args.optimizer=='AdamW':
                        optimizer1 = torch.optim.AdamW([{'params':params,'lr':args.lr,'weight_decay':args.weight_decay},{"params":list(visual_prompt.parameters()),'lr':args.lr_vp,'weight_decay':args.weight_decay}])
                        #optimizer_vp=torch.optim.AdamW([{'params':mask_generator.parameters(),'lr':0.001,'weight_decay':0.0}])
                    elif args.optimizer=='AdamwAll':
                        p_network=[param for param in network.parameters() if not (hasattr(param, 'is_score') and param.is_score)]
                        optimizer1 = torch.optim.AdamW([{'params':params,'lr':args.lr,'weight_decay':args.weight_decay},{'params':p_network,'lr':0.001,'weight_decay':0.01},{"params":list(visual_prompt.parameters()),'lr':args.lr_vp,'weight_decay':args.weight_decay}])
                    # One optimizer holds every parameter group; expose it under
                    # the name the rest of the script uses.
                    optimizer_score = optimizer1
                else:
                    print("Set same Lr for both VP and params of NN")
                    if args.SepOptimizer:
                        print("Set two different optimizers!")
                        # NOTE(review): the 'adamw' choice builds torch.optim.Adam
                        # (L2-coupled weight decay), not AdamW.
                        if args.prune_vp_opt=='adamw':
                            optimizer_vp=torch.optim.Adam([{"params":list(visual_prompt.parameters()),'lr':args.lr_vp,'weight_decay':args.weight_decay}])
                        else:
                            optimizer_vp=torch.optim.SGD([{"params":list(visual_prompt.parameters()),'lr':args.lr_vp,'weight_decay':1e-4}],momentum=0.9)
                        if args.prune_score_opt=='adamw':
                            optimizer_score=torch.optim.Adam([{'params':params,'lr':args.lr,'weight_decay':args.weight_decay}])
                        else:
                            optimizer_score=torch.optim.SGD([{'params':params,'lr':args.lr}],momentum=0.9,weight_decay=1e-4)

                        if args.AGC:
                            optimizer_score = AGC(params, optimizer_score)

                        # NOTE(review): the VP scheduler uses --hydra_scheduler
                        # (not --vp_scheduler). The scheduler_score built here is
                        # overwritten right after this branch, before it is ever
                        # stepped, so --score_scheduler has no lasting effect.
                        scheduler_vp=get_schedular(optimizer_vp,args.hydra_scheduler,args,args.epoch_1,args.T_max)
                        scheduler_score=get_schedular(optimizer_score,args.score_scheduler,args,args.epoch_1,args.T_max)

                    else:
                        if args.optimizer=='Adam':
                            optimizer_score = torch.optim.Adam(list(visual_prompt.parameters())+params, lr=args.lr)
                        elif args.optimizer=='SGD':
                            optimizer_score = torch.optim.SGD(list(visual_prompt.parameters())+params, lr=args.lr,weight_decay=args.weight_decay,momentum=0.9)
                        elif args.optimizer=='AdamW':
                            optimizer_score = torch.optim.AdamW(list(visual_prompt.parameters())+params, lr=args.lr,weight_decay=args.weight_decay)
                        optimizer_vp=None

            print("OPtimizer:", optimizer_score,optimizer_vp)
            scheduler_score=get_schedular(optimizer_score,args.hydra_scheduler,args,args.epoch_1,args.T_max)
            #scheduler_vp=None
        
    # Make dir
    os.makedirs(save_path, exist_ok=True)
    # TensorBoard writer under <save_path>/tensorboard; passed to Training/Test.
    logger = SummaryWriter(os.path.join(save_path, 'tensorboard'))
       
    # FLM
    # Frequency-based label mapping: mapping_sequence is computed from the
    # training loader by generate_label_mapping_by_frequency (from `algorithms`)
    # and bound into label_mapping via partial().
    mapping_sequence = generate_label_mapping_by_frequency(visual_prompt, network, loaders['train'])
    label_mapping = partial(label_mapping_base, mapping_sequence=mapping_sequence)
    # For Warmup    
    # Train    
    # NOTE(review): `scaler` is not referenced in this section; check whether Training uses it as a global.
    scaler = GradScaler()
    # Pruning stage: train, or restore <save_path>/Prune.pth. With --AutoRestore,
    # restore automatically whenever that checkpoint already exists.
    if args.AutoRestore:
        if os.path.exists(os.path.join(save_path, 'Prune.pth')):
            args.restore=True
        else:
            args.restore=False
    if args.restore is False:
        print("Starting Standard Training Stragety Alternative Way!")
        if args.Twostage is False:
            epoch_1=args.epoch
            epoch_2=0
        else:
            epoch_1=args.epoch_1
            epoch_2=args.epoch_2

        total_step = args.epoch * len(loaders['train'])
        final_prune_time = int(total_step * 0)
        initial_prune_time = int(total_step * args.warmup)
        # Defined up front so the checkpoint below is well-formed even when
        # epoch_1 == 0 and Test() never runs.
        acc = None

        for epoch in range(epoch_1):
            # Per-epoch keep ratio. With --Gradual_Pruning, sparsity ramps
            # linearly from 0 to 1-k over the first --warmup epochs (get_sparsity).
            if args.beidi or args.Twostage is False:
                if args.Gradual_Pruning:
                    k=get_sparsity(1-args.k,epoch+1,0,args.warmup)
                    print("sparsity:",k)
                    setPruneRate(network,cl,ll,1-k)
                    args.current_k=1-k
                else:
                    setPruneRate(network,cl,ll,args.k)
                    args.current_k=args.k
            averaged_grads,averaged_grads_before=Training(epoch,label_mapping,cl,network,visual_prompt,mask_generator,args,loaders,optimizer_score,logger,optimizer_vp,loaders_score,wandb=wandb,name="P1")
            if args.gradient_flow:
                plot_grad_flow_avg(averaged_grads.items())
                plt.savefig(f'average_gradient_flow_epoch{epoch}.jpg')
                plt.close()
                plot_grad_flow_avg(averaged_grads_before.items())
                plt.savefig(f'average_gradient_flow_before_epoch{epoch}.jpg')
                plt.close()

            if scheduler_score is not None:
                scheduler_score.step()
            if scheduler_vp is not None:
                scheduler_vp.step()
            acc=Test(epoch,label_mapping,cl,network,visual_prompt,mask_generator,args,loaders_score,optimizer_score,logger,wandb=wandb,name="P1")
        if args.beidi:
            # NOTE(review): this rebuilds the loaders and a *fresh* visual prompt,
            # and that untrained prompt is what gets saved below.
            # Data
            if args.dataset=='tinyimagenet':
                loaders, configs=get_tinyimagenet_dataloaders(args,resize=args.resize, dataAug=args.Aug)
            else:
                loaders, configs = prepare_expansive_data(args.dataset, data_path=data_path, resize=args.resize, dataAug=args.Aug)

            #loaders, configs = prepare_expansive_data(args.dataset, data_path=data_path,resize=128,dataAug=args.Aug)
            normalize = transforms.Normalize(IMAGENETNORMALIZE['mean'], IMAGENETNORMALIZE['std'])
            visual_prompt = ExpansiveVisualPrompt(224, mask=configs['mask'], normalize=normalize).to(device)
            optimizer2 = torch.optim.Adam(list(visual_prompt.parameters()), lr=args.lr)
        # Presumably resets the keep ratio to dense (1.0) before saving -- confirm with setPruneRate.
        if args.k<1.0:
            setPruneRate(network,cl,ll,1.0)
        state_dict = {
            "visual_prompt_dict": visual_prompt.state_dict(),
            "network":network.state_dict(),
            "mask_generator":mask_generator.state_dict(),
            "best_acc": acc,
            "mapping_sequence": mapping_sequence,
        }
        torch.save(state_dict, os.path.join(save_path, 'Prune.pth'))
    else:
        print("restore!")

        best_ckpt = torch.load(os.path.join(save_path,'Prune.pth'))
        saved_network = best_ckpt['network']
        if type(saved_network) is type(network):
            # Checkpoint holds a whole module object.
            network=saved_network
        elif isinstance(saved_network, dict):
            # The pruning stage above saves network.state_dict(). Previously the
            # type check fell through to `pass` and these weights/buffers were
            # silently never loaded.
            incompatible = network.load_state_dict(saved_network, strict=False)
            if incompatible.missing_keys or incompatible.unexpected_keys:
                print("restore: missing keys", incompatible.missing_keys, "unexpected keys", incompatible.unexpected_keys)

        visual_prompt.load_state_dict(best_ckpt['visual_prompt_dict'])
        mask_generator.load_state_dict(best_ckpt['mask_generator'])
        # Use the checkpoint's label mapping, and keep `mapping_sequence` in sync
        # so a later FT.pth records the mapping actually used.
        mapping_sequence = best_ckpt['mapping_sequence']
        label_mapping = partial(label_mapping_base, mapping_sequence=mapping_sequence)
        print("K",args.k)
        #args.k=1.0
        setPruneRate(network,cl,ll,args.k)
        mask_generator.eval()
        thre,reg_loss=sparseModel(visual_prompt,mask_generator,network,args,args.vp_detach,cl)
        if args.ChannelPrune=='channel':
            Calculate_mask(network,args.bn_detach,thres=thre,glob=args.glob,args=args)
        acc=Test(0,label_mapping,cl,network,visual_prompt,mask_generator,args,loaders,None,logger,wandb=wandb,name="restored")
        print("saved ACC",best_ckpt['best_acc'],"loaded ACC",acc)

    if args.finetune:
        # Fine-tuning output goes in a sub-directory of the pruning-stage
        # save_path. It is built from save_path rather than re-derived from args:
        # by now args.normalize holds a Normalize transform, so
        # str(args.normalize) would embed the transform's repr in the path
        # instead of the CLI value.
        save_path = os.path.join(save_path, str(args.ft_lr)+'_'+str(args.ft_lr_vp)+'_'+args.ft_weight_opt+'_'+args.ft_vp_opt+'clean_finetune'+str(args.clean_finetune)+"_"+args.weight_scheduler)
        print(save_path)
        os.makedirs(save_path, exist_ok=True)

        args.gradient_flow=False
        # Choose fine-tuning data and (optionally) a fresh visual prompt:
        #   Original: 224-px augmented loaders and a new Expansive prompt.
        #   VP:       reuse the pruning loaders; new prompt only with --ft_init_vp.
        #   Aug:      augmented loaders at --resize; new prompt only with --ft_init_vp.
        if args.clean_finetune=='Original':
            #loaders1, configs = prepare_expansive_data(args.dataset, data_path=data_path, resize=224, dataAug=True)
            if args.dataset=='tinyimagenet':
                loaders1, configs=get_tinyimagenet_dataloaders(args,resize=224, dataAug=True)
            else:
                loaders1, configs = prepare_expansive_data(args.dataset, data_path=data_path, resize=224, dataAug=True)
            visual_prompt = ExpansiveVisualPrompt(224, mask=configs['mask'], normalize=normalize).to(device)
        elif args.clean_finetune=='VP':
            if args.ft_init_vp:
                if args.vp_method=='Expand':
                    visual_prompt = ExpansiveVisualPrompt(224, mask=configs['mask'], normalize=normalize).to(device)
                else:
                    visual_prompt = PadVisualPrompt(args,normalize).to(device)
            #else:
            loaders1=loaders
        elif args.clean_finetune=='Aug':
            #loaders1, configs = prepare_expansive_data(args.dataset, data_path=data_path, resize=args.resize, dataAug=True)
            if args.dataset=='tinyimagenet':
                loaders1, configs=get_tinyimagenet_dataloaders(args,resize=args.resize, dataAug=True)
            else:
                loaders1, configs = prepare_expansive_data(args.dataset, data_path=data_path, resize=args.resize, dataAug=True)

            if args.ft_init_vp:
                if args.vp_method=='Expand':
                    visual_prompt = ExpansiveVisualPrompt(224, mask=configs['mask'], normalize=normalize).to(device)
                else:
                    visual_prompt = PadVisualPrompt(args,normalize).to(device)

        else:
            # `assert` is stripped under -O, which would leave loaders1 undefined.
            raise ValueError(f"Unknown --clean_finetune {args.clean_finetune!r}; expected one of Original, VP, Aug")

        # Detach the current masks from any autograd graph so they stay fixed
        # during fine-tuning.
        for name, module in network.named_modules():
            if isinstance(module,cl):
                if type(module.adj) is not float:
                    module.adj=module.adj.detach()
                if type(module.pre_adj) is not float:
                    module.pre_adj=module.pre_adj.detach()
            if isinstance(module,PrunableBatchNorm2d):
                if type(module.weight_mask) is not float:
                    module.weight_mask=module.weight_mask.detach()
        # Fine-tune the network's ordinary (non-score) parameters.
        params=[param for param in network.parameters() if  not (hasattr(param, 'is_score') and param.is_score) ]
        if args.opt_ft_vp:
            args.Test_val=True

            if args.Testalternate:
                # Separate optimizers/schedulers for weights and the visual prompt.
                args.alternate=True
                print("Testalternate")
                params=[param for param in network.parameters() if  not (hasattr(param, 'is_score') and param.is_score) ]
                if args.ft_vp_opt =='adamw':
                    optimizer_vp = torch.optim.AdamW([{'params':list(visual_prompt.parameters()),'lr':args.ft_lr_vp}])
                else:
                    optimizer_vp = torch.optim.SGD([{'params':list(visual_prompt.parameters()),'lr':args.ft_lr_vp}],momentum=0.9,weight_decay=1e-4)
                if args.ft_weight_opt=='adamw':
                    optimizer_weight=torch.optim.AdamW([{"params":params,'lr':args.ft_lr}])
                else:
                    #optimizer_weight=torch.optim.AdamW([{"params":params,'lr':args.ft_lr}])
                    optimizer_weight=torch.optim.SGD([{"params":params,'lr':args.ft_lr}],momentum=0.9,weight_decay=args.ft_WD)


                scheduler_score=get_schedular(optimizer_weight,args.weight_scheduler,args,args.finetuneEpoch,args.finetuneEpoch)

                scheduler_vp=get_schedular(optimizer_vp,'cosine',args,args.finetuneEpoch,args.finetuneEpoch)
                print("prefinetuning",optimizer_weight,optimizer_vp)
            else:
                # One AdamW over weights and prompt (prompt without weight decay).
                # NOTE(review): `optimizer_vp` still refers to the pruning-stage
                # optimizer (if one was created) and is passed to Training below,
                # although optimizer_weight already holds the prompt parameters -- confirm.
                args.alternate=False
                params=[param for param in network.parameters() if  not (hasattr(param, 'is_score') and param.is_score) ]
                optimizer_weight = torch.optim.AdamW([{'params':params,'lr':args.ft_lr},{"params":list(visual_prompt.parameters()),'lr':args.ft_lr_vp,'weight_decay':0.0}])
                print("prefinetuning",optimizer_weight)
                scheduler_score=get_schedular(optimizer_weight,args.weight_scheduler,args,args.finetuneEpoch,args.finetuneEpoch)
                scheduler_vp=None
            for epoch in range(args.finetuneEpoch):
                Training(epoch,label_mapping,cl,network,visual_prompt,mask_generator,args,loaders1,optimizer_weight,logger,optimizer_vp=optimizer_vp,wandb=wandb,loader_score=loaders_score,name="FinetuneTrain")
                scheduler_score.step()
                if scheduler_vp is not None:
                    scheduler_vp.step()
                acc=Test(epoch,label_mapping,cl,network,visual_prompt,mask_generator,args,loaders_score,optimizer_weight,logger,wandb=wandb,name="FinetuneTest")

        # Unlike Prune.pth, FT.pth stores the whole network module, not its state_dict.
        state_dict = {
            "visual_prompt_dict": visual_prompt.state_dict(),
            "network":network,
            "mask_generator":mask_generator.state_dict(),
            "best_acc": acc,
            "mapping_sequence": mapping_sequence,
        }
        torch.save(state_dict, os.path.join(save_path, 'FT.pth'))

