import argparse
import os, sys
import shutil
import time
import copy

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models
import random
import numpy as np
from scipy.spatial import distance
from collections import OrderedDict
from tqdm import tqdm
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import matplotlib.pyplot as plt

from utils import convert_secs2time, time_string, time_file_str, timing, AverageMeter, print_log, get_balanced_subset
from metric import accuracy
import models
from wrapper import Wrapper
from trainer import train_ssp_head, train_knowledge_distillation
from visualizer import visualize_resnet50_feature_maps

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
# model_names.append('inception_v3')

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--sampled_data', type=str,
                    help='ImageFolder directory with the samples used for init optimization')
parser.add_argument('--save_dir', type=str, default='./snapshots', help='Folder to save checkpoints and log.')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=100, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
# The help text previously claimed a default of 256; derive it from the real default instead.
parser.add_argument('-b', '--batch-size', default=64, type=int, metavar='N',
                    help='mini-batch size (default: %(default)s)')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W',
                    help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=200, type=int, metavar='N', help='print frequency (default: 200)')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
parser.add_argument('--use_pretrain', dest='use_pretrain', action='store_true', help='use pre-trained model or not')
parser.add_argument('--val_freq', type=int, default=1, help='validate every N epochs (default: %(default)s)')

# compress rate
parser.add_argument('--rate_norm', type=float, default=1, help='the remaining ratio of pruning based on Norm')
# parser.add_argument('--rate_dist', type=float, default=0.1, help='the reducing ratio of pruning based on Distance')
parser.add_argument('--min_rate_dist', type=float, default=0.15,
                    help='lower bound of the per-layer distance pruning ratio (sampled uniformly per layer)')
parser.add_argument('--max_rate_dist', type=float, default=0.3,
                    help='upper bound of the per-layer distance pruning ratio (sampled uniformly per layer)')

parser.add_argument('--layer_begin', type=int, default=3,
                    help='index (in backbone.parameters()) of the first parameter to prune')
parser.add_argument('--layer_end', type=int, default=3,
                    help='index (in backbone.parameters()) of the last parameter to prune, inclusive')
parser.add_argument('--layer_inter', type=int, default=1,
                    help='step between pruned parameter indices')
parser.add_argument('--epoch_prune', type=int, default=1, help='epoch interval of pruning')
parser.add_argument('--skip_downsample', type=int, default=1,
                    help='1 to exclude the architecture-specific skip list (e.g. downsample convs) from pruning')
parser.add_argument('--use_sparse', dest='use_sparse', action='store_true', help='use sparse model as initial or not')
parser.add_argument('--sparse',
                    type=str, metavar='PATH', help='path of sparse model')
parser.add_argument('--lr_adjust', type=int, default=30,
                    help='divide the learning rate by 10 every this many epochs (default: %(default)s)')
parser.add_argument('--VGG_pruned_style', choices=["CP_5x", "Thinet_conv"],
                    help='pruned configuration style for VGG models')

# optimization
parser.add_argument('--init_optimize', action='store_true')
parser.add_argument('--num_sample_for_init_optimize', type=int, default=2048)
parser.add_argument('--init_optimize_algorithm', type=str, choices=['default','default_GD','SSKD'], default='default_GD')
parser.add_argument('--init_optimize_lr', type=float, default=0.005)
parser.add_argument('--init_optimize_epoch', type=int, default=50)

# for SSKD
parser.add_argument('--transform', type=str, choices=["rotation", "adv_attack", "random_mask"], default="rotation")
parser.add_argument('--t_epoch', type=int, default=100)
parser.add_argument('--s_epoch', type=int, default=130)
parser.add_argument('--reg_weight', type=float, default=0.0)
parser.add_argument('--reg_layer_idx', type=int, default=0)

# visualization
parser.add_argument("--vis_feature_maps", action="store_true")

args = parser.parse_args()
args.use_cuda = torch.cuda.is_available()
# An explicit exit instead of `assert`: asserts are stripped when Python runs with -O.
if not args.use_cuda:
    sys.exit("This script requires CUDA, but torch.cuda.is_available() returned False.")

args.prefix = time_file_str()

def main():
    """Entry point: prune ``args.arch``, optionally run init optimization, then fine-tune.

    All settings come from the module-level ``args``. Messages go to stdout and to
    ``<save_dir>/<arch>.<prefix>.log``. The log file is opened in a ``with`` block, so it is
    closed on every exit path, including the ``--evaluate`` early return and exceptions.
    """
    os.makedirs(args.save_dir, exist_ok=True)
    log_path = os.path.join(args.save_dir, '{}.{}.log'.format(args.arch, args.prefix))
    with open(log_path, 'w') as log:
        _run(log)


def _log_versions(log):
    """Log the Python / PyTorch / cuDNN / torchvision versions."""
    print_log("PyThon  version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("PyTorch version : {}".format(torch.__version__), log)
    print_log("cuDNN   version : {}".format(torch.backends.cudnn.version()), log)
    print_log("Vision  version : {}".format(torchvision.__version__), log)


def _create_model(log):
    """Build ``args.arch`` from the local ``models`` package, wrap it in ``Wrapper`` and move it to the GPU.

    Raises:
        NotImplementedError: if ``args.arch`` is not a supported architecture.
    """
    if "resnet" in args.arch or "vgg" in args.arch:
        print_log("=> creating model '{}'".format(args.arch), log)
    elif args.arch not in ('inception_v3', 'densenet121', 'inception_v4'):
        raise NotImplementedError("unsupported architecture: {}".format(args.arch))
    model = Wrapper(models.__dict__[args.arch](pretrained=args.use_pretrain))
    model.cuda()
    return model


def _log_config(log):
    """Log the model name, the full ``args`` namespace and the main pruning/optimization settings."""
    entries = (
        ("=> Model: {}", (args.arch,)),
        ("=> parameter : {}", (args,)),
        ("=> sampled data: {}", (args.sampled_data,)),
        ("Norm Pruning Rate: {}", (args.rate_norm,)),
        ("Distance Pruning Rate: {} - {}", (args.min_rate_dist, args.max_rate_dist)),
        ("Init optimzation: {}", (args.init_optimize,)),
        ("Number of Samples for Init Optimization: {}", (args.num_sample_for_init_optimize,)),
        ("Init Optimize Algorithm: {}", (args.init_optimize_algorithm,)),
        ("Init Optimize Learning Rate: {}", (args.init_optimize_lr,)),
        ("Init Optimize Epoch: {}", (args.init_optimize_epoch,)),
        ("SSKD t epoch: {}", (args.t_epoch,)),
        ("SSKD s epoch: {}", (args.s_epoch,)),
        ("SSKD transform: {}", (args.transform,)),
        ("Reg weight: {}", (args.reg_weight,)),
        ("Reg layer index: {}", (args.reg_layer_idx,)),
        ("Layer Begin: {}", (args.layer_begin,)),
        ("Layer End: {}", (args.layer_end,)),
        ("Layer Inter: {}", (args.layer_inter,)),
        ("Epoch prune: {}", (args.epoch_prune,)),
        ("Skip downsample : {}", (args.skip_downsample,)),
        ("Workers         : {}", (args.workers,)),
        ("Batch size: {}", (args.batch_size,)),
        ("Learning-Rate   : {}", (args.lr,)),
        ("Use Pre-Trained : {}", (args.use_pretrain,)),
        ("lr adjust : {}", (args.lr_adjust,)),
        ("VGG pruned style : {}", (args.VGG_pruned_style,)),
        ("Prefix: {}", (args.prefix,)),
    )
    for template, values in entries:
        print_log(template.format(*values), log)


def _build_transforms(model):
    """Return ``(train_transform, val_transform, normalize_mean, normalize_std)`` for ``args.arch``.

    ResNet/VGG use the standard ImageNet crop/flip pipeline. Other architectures use timm's
    data config resolved from ``model.backbone``; the same transform is used for train and val.
    """
    normalize_mean = [0.485, 0.456, 0.406]
    normalize_std = [0.229, 0.224, 0.225]
    if "resnet" in args.arch or "vgg" in args.arch:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        config = resolve_data_config({}, model=model.backbone)
        normalize_mean = config['mean']
        normalize_std = config['std']
        print(config)
        train_transform = create_transform(**config)
        val_transform = create_transform(**config)
    return train_transform, val_transform, normalize_mean, normalize_std


def _run(log):
    """Body of :func:`main`; ``log`` is the already-open log file."""
    best_prec1 = 0

    _log_versions(log)
    model = _create_model(log)
    if args.use_sparse:
        raise NotImplementedError
    _log_config(log)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.backbone.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.backbone.load_state_dict(checkpoint['state_dict'])
            # Per-epoch checkpoints written by earlier versions of this script lack the projection head.
            if 'proj_head_state_dict' in checkpoint:
                model.proj_head.load_state_dict(checkpoint['proj_head_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)

    cudnn.benchmark = True

    # Data loading code
    train_transform, val_transform, normalize_mean, normalize_std = _build_transforms(model)
    train_dataset = datasets.ImageFolder(root=os.path.join(args.data, 'train'), transform=train_transform)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, sampler=None)
    val_dataset = datasets.ImageFolder(root=os.path.join(args.data, 'val'), transform=val_transform)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, log)
        return

    val_acc = validate(val_loader, model, criterion, log)
    print(">>>>> accu of the original model is: {:}".format(val_acc))

    bestname = os.path.join(args.save_dir, 'best.{:}.{:}.pth.tar'.format(args.arch, args.prefix))

    m = Mask(model, args.rate_norm, args.min_rate_dist, args.max_rate_dist, log)
    m.init_length()
    m.model.eval()
    sample_dataloader = None
    if args.init_optimize:
        if args.sampled_data is None:
            # Sampling a balanced subset from args.data is not implemented.
            raise NotImplementedError
        sample_dataset = datasets.ImageFolder(root=args.sampled_data, transform=train_transform)
        sample_dataloader = torch.utils.data.DataLoader(
            sample_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
        print(f"There are {len(sample_dataloader)} batches of samples for initial optimization.")
        if args.init_optimize_algorithm != 'SSKD':
            m.save_original_feature_maps_multi_batch(sample_dataloader)

    m.init_mask()
    m.do_mask()
    m.do_similar_mask()
    model = m.model
    if args.use_cuda:
        model = model.cuda()
    val_acc_2 = validate(val_loader, model, criterion, log)
    print_log(">>>>> accu before init opimization is: {:}".format(val_acc_2), log)

    vis_batch, vis_target = None, None
    if args.arch == 'resnet50' and args.vis_feature_maps:
        vis_batch, vis_target = next(iter(val_loader))
        m.set_vis_batch(vis_batch, vis_target)
        visualize_resnet50_feature_maps(m.original_model, vis_batch, vis_target, args.save_dir, 'original')
        visualize_resnet50_feature_maps(m.model, vis_batch, vis_target, args.save_dir, 'pruned')

    if args.init_optimize:
        track_feature_dist = args.init_optimize_algorithm != 'SSKD'
        if track_feature_dist:
            dist_before = m.get_feature_maps_dist(sample_dataloader)
            print_log(f"Before init optimization, distance between feature maps of the original network and the pruned network is {dist_before}", log)
        if args.init_optimize_algorithm == 'SSKD':
            m.optimize_parameters_with_SSKD(sample_dataloader, val_loader,
                                            t_epoch=args.t_epoch, s_epoch=args.s_epoch,
                                            mean=normalize_mean, std=normalize_std)
        elif args.init_optimize_algorithm == 'default_GD':
            m.optimize_parameters_with_GD(sample_dataloader, GD_epoch=args.init_optimize_epoch)
        else:
            m.optimize_parameters(sample_dataloader, GD_epoch=args.init_optimize_epoch)
        # Accuracy right after optimization, before the masks are re-applied below.
        val_acc_after_2 = validate(val_loader, model, criterion, log)
        print(">>>>> accu after init optimization is: {:}".format(val_acc_after_2))
        m.do_mask()
        m.do_similar_mask()
        if track_feature_dist:
            dist_after = m.get_feature_maps_dist(sample_dataloader)
            print_log(f"After init optimization, distance between feature maps of the original network and the pruned network is {dist_after}", log)
            diff = []
            for i, (before, after) in enumerate(zip(dist_before, dist_after)):
                diff.append(before - after)
                if before < after:
                    print(i, end=' ')  # layers whose distance got worse
            print("")
            print_log(f"diff: {diff}", log)

    val_acc_after_2 = validate(val_loader, model, criterion, log)
    print_log(">>>>> accu after init optimization is: {:}".format(val_acc_after_2), log)
    best_model = model

    filename = os.path.join(args.save_dir, 'checkpoint.before_finetuning.{:}.{:}.pth.tar'.format(args.arch, args.prefix))
    save_checkpoint({
        'epoch': 0,
        'arch': args.arch,
        'state_dict': best_model.backbone.state_dict(),
        'proj_head_state_dict': best_model.proj_head.state_dict(),
        'best_prec1': val_acc_after_2,
        'optimizer': optimizer.state_dict(),
    }, False, filename, bestname)

    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.val * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
        print_log(
            ' [{:s}] :: {:3d}/{:3d} ----- [{:s}] {:s}'.format(args.arch, epoch, args.epochs, time_string(), need_time),
            log)

        # train for one epoch
        train(train_loader, best_model, criterion, optimizer, epoch, log, m)

        is_last_epoch = epoch == args.epochs - 1
        # Always validate the last epoch: it is always re-pruned below, so its accuracy must be measured.
        validate_now = epoch % args.val_freq == 0 or is_last_epoch
        if validate_now:
            # Accuracy of the fine-tuned weights before re-pruning (logged only).
            validate(val_loader, best_model, criterion, log)

        if epoch % args.epoch_prune == 0 or is_last_epoch:
            m.model = best_model
            m.init_mask()
            m.do_mask()
            m.do_similar_mask()
            best_model = m.model
            if args.use_cuda:
                best_model = best_model.cuda()

        # Only a freshly measured accuracy may update the best score; a stale value
        # from an earlier epoch must not mark this checkpoint as best.
        is_best = False
        if validate_now:
            val_acc_2 = validate(val_loader, best_model, criterion, log)
            is_best = val_acc_2 > best_prec1
            best_prec1 = max(val_acc_2, best_prec1)

        filename = os.path.join(args.save_dir, 'checkpoint.epoch{}.{:}.{:}.pth.tar'.format(epoch+1, args.arch, args.prefix))
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': best_model.backbone.state_dict(),
            # Required by --resume, which loads the projection head as well.
            'proj_head_state_dict': best_model.proj_head.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best, filename, bestname)
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    if args.arch == 'resnet50' and args.vis_feature_maps:
        visualize_resnet50_feature_maps(m.model, vis_batch, vis_target, args.save_dir, 'fine_tune')

# Disabled helper kept for reference only: this string literal is never executed.
# Loading a sparse initial model is not supported here (`--use_sparse` raises
# NotImplementedError in main()). The body strips a `module.` prefix (7 chars)
# from checkpoint keys, as produced by a DataParallel-wrapped model.
'''def import_sparse(model):
    checkpoint = torch.load(args.sparse)
    new_state_dict = OrderedDict()
    for k, v in checkpoint['state_dict'].items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    print("sparse_model_loaded")
    return model'''

def train(train_loader, model, criterion, optimizer, epoch, log, m):
    """Fine-tune ``model`` for one epoch over ``train_loader``.

    Each step runs forward, backward, then ``m.do_grad_mask()`` and finally
    ``optimizer.step()``. ``m.do_grad_mask()`` presumably zeroes the gradients of pruned
    weights; it is defined on Mask outside this view. ``model(x)`` must return a 3-tuple
    whose first element is the logits. Progress is logged every ``args.print_freq`` steps.
    """
    batch_timer = AverageMeter()
    load_timer = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model.train()

    n_batches = len(train_loader)
    tic = time.time()
    for step, (images, labels) in enumerate(train_loader):
        load_timer.update(time.time() - tic)

        images = images.cuda()
        labels = labels.cuda()

        logits, _, _ = model(images)
        loss = criterion(logits, labels)

        n = images.size(0)
        prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
        loss_meter.update(loss.item(), n)
        top1_meter.update(prec1[0], n)
        top5_meter.update(prec5[0], n)

        optimizer.zero_grad()
        loss.backward()
        # Apply the pruning mask to the gradients before the update.
        m.do_grad_mask()
        optimizer.step()

        batch_timer.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq != 0:
            continue
        print_log('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, step, n_batches,
                      batch_time=batch_timer, data_time=load_timer,
                      loss=loss_meter, top1=top1_meter, top5=top5_meter), log)

def validate(val_loader, model, criterion, log):
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 precision.

    The model is put in eval mode and all forward passes run under ``torch.no_grad()``.
    Evaluation never calls ``backward``, so recording an autograd graph only wastes memory
    and time. The old commented-out ``Variable(..., volatile=True)`` lines expressed the
    same intent with the pre-0.4 API.

    ``model(x)`` must return a 3-tuple whose first element is the logits. Per-batch progress
    is logged every ``args.print_freq`` batches; a summary line is always logged.

    Returns:
        ``top1.avg`` from the top-1 ``AverageMeter``.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            images = images.cuda()
            target = target.cuda()

            # compute output
            output, _, _ = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(prec1[0], images.size(0))
            top5.update(prec5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print_log('Test: [{0}/{1}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                    i, len(val_loader), batch_time=batch_time, loss=losses,
                    top1=top1, top5=top5), log)

    print_log(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1, top5=top5,
                                                                                           error1=100 - top1.avg), log)

    return top1.avg

def save_checkpoint(state, is_best, filename, bestname):
    """Write ``state`` to ``filename`` with ``torch.save``; if ``is_best``, also copy it to ``bestname``."""
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, bestname)

def adjust_learning_rate(optimizer, epoch):
    """Apply a step-decay schedule: ``lr = args.lr * 0.1 ** (epoch // args.lr_adjust)``.

    The decay period is ``args.lr_adjust`` epochs (30 with the default ``--lr_adjust``),
    not a fixed 30. The new rate is written to every param group of ``optimizer``.
    """
    decay_steps = epoch // args.lr_adjust
    new_lr = args.lr * (0.1 ** decay_steps)
    for group in optimizer.param_groups:
        group['lr'] = new_lr

class Mask:
    def __init__(self, model, rate_norm_per_layer, min_rate_dist_per_layer, max_rate_dist_per_layer, log):
        """Set up pruning state for ``model`` and initialise per-parameter pruning rates.

        Args:
            model: wrapped network exposing ``.backbone``; it is pruned in place.
            rate_norm_per_layer: fraction of filters kept by the norm criterion for the
                selected layers.
            min_rate_dist_per_layer, max_rate_dist_per_layer: bounds of the uniformly sampled
                per-layer distance-criterion pruning ratio.
            log: open log file handle.
        """
        self.log = log
        self.model = model
        # Snapshot of the model as it was when the Mask was created.
        self.original_model = copy.deepcopy(model)

        # Dicts below are keyed by position in model.backbone.parameters().
        self.model_size = {}
        self.model_length = {}
        self.compress_rate = {}
        self.distance_rate = {}
        self.mat = {}               # norm-criterion masks
        self.similar_matrix = {}    # distance-criterion masks
        self.filter_small_index = {}
        self.filter_large_index = {}
        self.mask_index = []        # parameter indices that get masked

        # Buffers filled by original_hook / pruned_hook.
        self.original_feature_in = None
        self.original_feature_out = None
        self.pruned_feature_in = None
        self.pruned_feature_out = None

        # Optional fixed batch used for feature-map visualisation.
        self.vis_batch = None
        self.vis_target = None

        self.init_rate(rate_norm_per_layer, min_rate_dist_per_layer, max_rate_dist_per_layer)

    def set_vis_batch(self, vis_batch, vis_target):
        self.vis_batch = vis_batch.clone()
        self.vis_target = vis_target.clone()

    def original_hook(self, module, feature_in, feature_out):
        self.original_feature_in.append(feature_in[0].detach().clone())
        self.original_feature_out.append(feature_out.detach().clone())

    def pruned_hook(self, module, feature_in, feature_out):
        self.pruned_feature_in.append(feature_in[0].detach().clone())
        self.pruned_feature_out.append(feature_out.detach().clone())

    def get_codebook(self, weight_torch, compress_rate, length):
        weight_vec = weight_torch.view(length)
        weight_np = weight_vec.cpu().numpy()

        weight_abs = np.abs(weight_np)
        weight_sort = np.sort(weight_abs)

        threshold = weight_sort[int(length * (1 - compress_rate))]
        weight_np[weight_np <= -threshold] = 1
        weight_np[weight_np >= threshold] = 1
        weight_np[weight_np != 1] = 0

        print("codebook done")
        return weight_np

    def get_filter_codebook(self, weight_torch, compress_rate, length):
        codebook = np.ones(length)
        if len(weight_torch.size()) == 4:
            filter_pruned_num = int(weight_torch.size()[0] * (1 - compress_rate))
            weight_vec = weight_torch.view(weight_torch.size()[0], -1) # (out_channels, kernel_size*kernel_size*in_channels)
            
            # norm1 = torch.norm(weight_vec, 1, 1)
            # norm1_np = norm1.cpu().numpy()
            norm2 = torch.norm(weight_vec, p=2, dim=1) # (out_channels,)
            norm2_np = norm2.cpu().numpy()
            filter_index = norm2_np.argsort()[:filter_pruned_num]
            kernel_length = weight_torch.size()[1] * weight_torch.size()[2] * weight_torch.size()[3]
            for x in range(0, len(filter_index)):
                codebook[filter_index[x] * kernel_length: (filter_index[x] + 1) * kernel_length] = 0
            # print("filter codebook done")
        elif len(weight_torch.size()) == 2:
            weight_torch = weight_torch.view(weight_torch.size()[0], weight_torch.size()[1], 1, 1)
            codebook = self.get_filter_codebook(weight_torch, compress_rate, length)
            # print("filter codebook for fc done")
        else:
            pass
        return codebook

    # optimized for fast calculation (pairwise distances via scipy cdist)
    def get_filter_similar(self, weight_torch, compress_rate, distance_rate, length):
        codebook = np.ones(length)
        if len(weight_torch.size()) == 4:
            filter_pruned_num = int(weight_torch.size()[0] * (1 - compress_rate))
            similar_pruned_num = int(weight_torch.size()[0] * distance_rate)
            weight_vec = weight_torch.view(weight_torch.size()[0], -1)

            # norm1 = torch.norm(weight_vec, 1, 1)
            # norm1_np = norm1.cpu().numpy()
            norm2 = torch.norm(weight_vec, p=2, dim=1)
            norm2_np = norm2.cpu().numpy()
            filter_small_index = []
            filter_large_index = []
            filter_large_index = norm2_np.argsort()[filter_pruned_num:]
            filter_small_index = norm2_np.argsort()[:filter_pruned_num]

            # # distance using pytorch function
            # similar_matrix = torch.zeros((len(filter_large_index), len(filter_large_index)))
            # for x1, x2 in enumerate(filter_large_index):
            #     for y1, y2 in enumerate(filter_large_index):
            #         # cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
            #         # similar_matrix[x1, y1] = cos(weight_vec[x2].view(1, -1), weight_vec[y2].view(1, -1))[0]
            #         pdist = torch.nn.PairwiseDistance(p=2)
            #         similar_matrix[x1, y1] = pdist(weight_vec[x2].view(1, -1), weight_vec[y2].view(1, -1))[0][0]
            # # more similar with other filter indicates large in the sum of row
            # similar_sum = torch.sum(torch.abs(similar_matrix), 0).numpy()

            # distance using numpy function
            indices = torch.LongTensor(filter_large_index).cuda()
            weight_vec_after_norm = torch.index_select(weight_vec, 0, indices).cpu().numpy()
            # for euclidean distance
            similar_matrix = distance.cdist(weight_vec_after_norm, weight_vec_after_norm, metric='euclidean')
            # for cos similarity
            # similar_matrix = 1 - distance.cdist(weight_vec_after_norm, weight_vec_after_norm, 'cosine')
            similar_sum = np.sum(np.abs(similar_matrix), axis=0)

            # for distance similar: get the filter index with largest similarity == small distance
            similar_large_index = similar_sum.argsort()[similar_pruned_num:]
            similar_small_index = similar_sum.argsort()[:  similar_pruned_num]
            similar_index_for_filter = [filter_large_index[i] for i in similar_small_index]

            kernel_length = weight_torch.size()[1] * weight_torch.size()[2] * weight_torch.size()[3]
            for x in range(0, len(similar_index_for_filter)):
                codebook[
                similar_index_for_filter[x] * kernel_length: (similar_index_for_filter[x] + 1) * kernel_length] = 0
            # print("similar index done")
        else:
            pass
        return codebook

    def convert2tensor(self, x):
        """Return ``x`` (e.g. a NumPy mask) as a ``torch.FloatTensor``."""
        return torch.FloatTensor(x)

    def init_length(self):
        for index, item in enumerate(self.model.backbone.parameters()):
            self.model_size[index] = item.size()

        for index1 in self.model_size:
            for index2 in range(0, len(self.model_size[index1])):
                if index2 == 0:
                    self.model_length[index1] = self.model_size[index1][0]
                else:
                    self.model_length[index1] *= self.model_size[index1][index2]

    def init_rate(self, rate_norm_per_layer, min_rate_dist_per_layer, max_rate_dist_per_layer):
        """Initialise ``compress_rate`` / ``distance_rate`` and choose ``mask_index``.

        Every backbone parameter starts with rate 1 (unpruned). The indices
        ``range(args.layer_begin, args.layer_end + 1, args.layer_inter)`` get
        ``rate_norm_per_layer`` and a distance rate drawn from
        ``np.random.uniform(min_rate_dist_per_layer, max_rate_dist_per_layer)``.

        ``mask_index`` depends on the architecture:

        * ResNet / Inception-v3: every 3rd parameter index below an architecture-specific
          bound. With ``args.skip_downsample == 1``, the listed skip indices are removed and
          forced to rate 1.
        * VGG / DenseNet-121 / Inception-v4: exactly the ``layer_begin..layer_end`` range.

        Raises:
            NotImplementedError: for architectures (including unlisted ResNet depths) with no
                known layout. Previously an unknown ResNet depth crashed with UnboundLocalError.
        """
        # (exclusive end of the every-3rd-index mask range, downsample skip list).
        # The ResNet ranges include the last fc layer.
        resnet_layouts = {
            'resnet18': (60, [21, 36, 51]),
            'resnet34': (108, [27, 54, 93]),
            'resnet50': (159, [12, 42, 81, 138]),
            'resnet101': (312, [12, 42, 81, 291]),
            'resnet152': (465, [12, 42, 117, 444]),
        }

        if "resnet" in args.arch:
            if args.arch not in resnet_layouts:
                raise NotImplementedError("no pruning layout for {}".format(args.arch))
            last_index, skip_list = resnet_layouts[args.arch]
        elif args.arch == "inception_v3":
            last_index = 282
            # skip_list = [33, 54, 75, 117, 147, 177, 207, 252, 279]
            skip_list = list(range(0, 118, 3)) + [147, 177, 207, 252, 279]
        elif "vgg" in args.arch or args.arch == 'densenet121' or args.arch == 'inception_v4':
            last_index, skip_list = None, None
        else:
            raise NotImplementedError

        # Defaults: leave every parameter unpruned, then override the selected range.
        for index, _ in enumerate(self.model.backbone.parameters()):
            self.compress_rate[index] = 1
            self.distance_rate[index] = 1
        for key in range(args.layer_begin, args.layer_end + 1, args.layer_inter):
            self.compress_rate[key] = rate_norm_per_layer
            self.distance_rate[key] = np.random.uniform(low=min_rate_dist_per_layer, high=max_rate_dist_per_layer)

        if last_index is None:
            self.mask_index = [x for x in range(args.layer_begin, args.layer_end + 1, args.layer_inter)]
            return

        self.mask_index = [x for x in range(0, last_index, 3)]
        # skip downsample layer
        if args.skip_downsample == 1:
            for x in skip_list:
                self.compress_rate[x] = 1
                self.mask_index.remove(x)
            print(self.mask_index)

    def init_mask(self):
        """Build the norm- and distance-criterion masks for every masked parameter.

        For each backbone parameter whose position appears in ``self.mask_index``:

        * ``self.mat[index]`` is produced by ``self.get_filter_codebook`` (norm
          criterion) from the weights, ``self.compress_rate[index]`` and
          ``self.model_length[index]``;
        * ``self.similar_matrix[index]`` is produced by ``self.get_filter_similar``
          (distance criterion), additionally using ``self.distance_rate[index]``.

        Both results go through ``self.convert2tensor`` and are moved to the GPU
        when ``args.use_cuda`` is set.
        """
        for index, param in enumerate(self.model.backbone.parameters()):
            if index not in self.mask_index:
                continue
            weights = param.data
            length = self.model_length[index]

            # mask for norm criterion
            codebook = self.get_filter_codebook(weights, self.compress_rate[index], length)
            self.mat[index] = self.convert2tensor(codebook)
            if args.use_cuda:
                self.mat[index] = self.mat[index].cuda()

            # mask for distance criterion
            similar = self.get_filter_similar(weights, self.compress_rate[index],
                                              self.distance_rate[index], length)
            self.similar_matrix[index] = self.convert2tensor(similar)
            if args.use_cuda:
                self.similar_matrix[index] = self.similar_matrix[index].cuda()
        print("mask Ready")

    def do_mask(self):
        """Apply the norm-criterion masks in ``self.mat`` to the selected weights.

        Every parameter whose position is in ``self.mask_index`` is flattened to
        ``self.model_length[index]`` elements and multiplied element-wise by
        ``self.mat[index]``. Its ``.data`` is then replaced by the product,
        reshaped back to ``self.model_size[index]``.
        """
        selected = self.mask_index
        for position, param in enumerate(self.model.backbone.parameters()):
            if position not in selected:
                continue
            flat = param.data.reshape(self.model_length[position])
            param.data = (flat * self.mat[position]).reshape(self.model_size[position])
        print("mask Done")

    def do_similar_mask(self):
        """Apply the distance-criterion masks in ``self.similar_matrix`` to the selected weights.

        Every parameter whose position is in ``self.mask_index`` is flattened to
        ``self.model_length[index]`` elements and multiplied element-wise by
        ``self.similar_matrix[index]``. Its ``.data`` is then replaced by the
        product, reshaped back to ``self.model_size[index]``.
        """
        for index, item in enumerate(self.model.backbone.parameters()):
            if index in self.mask_index:
                # reshape rather than view: view raises on non-contiguous parameters
                a = item.data.reshape(self.model_length[index])
                b = a * self.similar_matrix[index]
                item.data = b.reshape(self.model_size[index])
        print("mask similar Done")

    def do_grad_mask(self):
        """Multiply the gradients of masked parameters by both pruning masks.

        For each parameter whose position is in ``self.mask_index``, its gradient
        is flattened to ``self.model_length[index]`` elements, multiplied by
        ``self.mat[index]`` and ``self.similar_matrix[index]``, and written back
        reshaped to ``self.model_size[index]``.

        Parameters whose ``.grad`` is ``None`` (e.g. they did not take part in
        the last backward pass) are skipped instead of raising AttributeError.
        """
        for index, item in enumerate(self.model.backbone.parameters()):
            if index not in self.mask_index or item.grad is None:
                continue
            a = item.grad.data.reshape(self.model_length[index])
            b = a * self.mat[index] * self.similar_matrix[index]
            item.grad.data = b.reshape(self.model_size[index])

    def if_zero(self):
        """Print the number of non-zero and zero weights of every masked parameter.

        Only parameters whose position is in ``self.mask_index`` are reported;
        each is flattened to ``self.model_length[index]`` elements first.
        """
        for index, item in enumerate(self.model.backbone.parameters()):
            if index not in self.mask_index:
                continue
            # reshape rather than view so non-contiguous parameters work too
            flat = item.data.reshape(self.model_length[index]).cpu().numpy()
            nonzero = np.count_nonzero(flat)
            print("layer: %d, number of nonzero weight is %d, zero is %d" % (
                index, nonzero, len(flat) - nonzero))
    
    @torch.no_grad()
    def save_original_feature_maps_multi_batch(self, loader):
        """Dump the inputs/outputs of every Conv2d of the backbone, batch by batch.

        ``self.original_hook`` is registered as a forward hook on every
        ``nn.Conv2d`` of ``self.model.backbone``. For each batch ``i`` of
        ``loader`` (moved to the GPU), ``self.original_feature_in`` and
        ``self.original_feature_out`` are reset to empty lists, the backbone is
        run, and both lists are saved to
        ``/tmp2/edward0530/tmp/batch{i}_feature_{in,out}.{arch}.{prefix}.pt``.
        The hook is not shown in this chunk; presumably it appends to those two
        lists (TODO confirm).

        The hooks are always removed, even if a batch raises.
        """
        handles = [
            m.register_forward_hook(self.original_hook)
            for m in self.model.backbone.modules()
            if isinstance(m, nn.Conv2d)
        ]
        try:
            for i, (batch, _) in enumerate(tqdm(loader)):
                batch = batch.cuda()
                self.original_feature_in = []
                self.original_feature_out = []
                self.model.backbone(batch)
                torch.save(self.original_feature_in, f"/tmp2/edward0530/tmp/batch{i}_feature_in.{args.arch}.{args.prefix}.pt")
                torch.save(self.original_feature_out, f"/tmp2/edward0530/tmp/batch{i}_feature_out.{args.arch}.{args.prefix}.pt")
                # free the per-batch features before the next batch
                del self.original_feature_in
                del self.original_feature_out
        finally:
            for handle in handles:
                handle.remove()
    
    def save_pruned_feature_maps_single_layer(self, loader, conv_layer_idx):
        """Record the input/output of one Conv2d layer of the current backbone.

        ``self.pruned_feature_in`` / ``self.pruned_feature_out`` are reset to
        empty lists. ``self.pruned_hook`` is then registered on the
        ``conv_layer_idx``-th ``nn.Conv2d`` (counted in ``modules()`` order), and
        every batch of ``loader`` is moved to the GPU and run through the
        backbone. The hook is always removed afterwards, even if a batch raises.
        The hook itself is not shown in this chunk; callers read
        ``self.pruned_feature_in[batch_idx]`` and ``self.pruned_feature_out[0]``,
        so presumably it appends one entry per batch (TODO confirm).

        Args:
            loader: iterable of ``(batch, target)`` pairs.
            conv_layer_idx: zero-based index of the Conv2d layer to record.

        Raises:
            IndexError: if the backbone has no Conv2d layer at ``conv_layer_idx``.
        """
        self.pruned_feature_in = []
        self.pruned_feature_out = []
        conv_layers = [m for m in self.model.backbone.modules() if isinstance(m, nn.Conv2d)]
        if not 0 <= conv_layer_idx < len(conv_layers):
            raise IndexError(
                f"conv_layer_idx {conv_layer_idx} out of range: backbone has {len(conv_layers)} Conv2d layers")

        handle = conv_layers[conv_layer_idx].register_forward_hook(self.pruned_hook)
        try:
            for batch, _ in loader:
                self.model.backbone(batch.cuda())
        finally:
            handle.remove()

    @torch.no_grad()
    def get_feature_maps_dist(self, loader):
        """Measure, per conv layer, how far the current outputs are from the original ones.

        Batch 0's original conv outputs are loaded from
        ``/tmp2/edward0530/tmp/batch0_feature_out.{arch}.{prefix}.pt``. For each
        layer index, that layer's current output is re-recorded over ``loader``
        and its first batch is compared with the original. The comparison uses
        only output channels whose summed current response is at least 1e-3 in
        magnitude.

        Returns:
            list[float]: per layer, the L2 norm of the difference on the kept
            channels divided by the number of compared elements.
        """
        reference_outputs = torch.load(f"/tmp2/edward0530/tmp/batch0_feature_out.{args.arch}.{args.prefix}.pt")
        distances = []
        for layer_idx, reference in enumerate(reference_outputs):
            reference_flat = reference.flatten(start_dim=2).transpose(1, 2)
            self.save_pruned_feature_maps_single_layer(loader, layer_idx)
            current_flat = self.pruned_feature_out[0].flatten(start_dim=2).transpose(1, 2)
            # channels whose summed response is ~0 are treated as pruned and ignored
            kept = torch.logical_not(torch.abs(torch.sum(current_flat, (0, 1))) < 1e-3)
            current_kept = current_flat[:, :, kept]
            diff_norm = torch.norm(current_kept - reference_flat[:, :, kept]).item()
            distances.append(diff_norm / torch.numel(current_kept))
        return distances

    @staticmethod
    def _unfold_rows(unfold, input_feat, with_bias):
        """Return the im2col matrix of ``input_feat`` as (B*H'*W', C*k*k [+1]) rows.

        When ``with_bias`` is true, a column of ones (same dtype/device as the
        features) is appended so the bias can be solved as an extra weight row.
        """
        cols = unfold(input_feat).transpose(1, 2)  # (B, H'*W', C*k*k)
        if with_bias:
            ones = torch.ones(cols.size(0), cols.size(1), 1, dtype=cols.dtype, device=cols.device)
            cols = torch.cat((cols, ones), dim=2)
        return cols.reshape(-1, cols.size(-1))

    def optimize_parameters(self, train_loader, GD_epoch=5, plot_figure=True):
        """Reconstruct each pruned conv layer's weights to match the original outputs.

        For every ``nn.Conv2d`` of the backbone except the first:

        1. its input/output on the current model is recorded over
           ``train_loader`` (``save_pruned_feature_maps_single_layer``);
        2. output channels whose summed current response is < 1e-3 in magnitude
           are treated as removed;
        3. per batch, ``pinv(unfold(input)) @ original_output[:, kept]``
           (``rcond=1e-3``) is computed, and the per-batch solutions are averaged;
        4. if ``GD_epoch > 0``, the average is refined with SGD
           (``lr=args.init_optimize_lr``) on the MSE to the original outputs;
        5. the solution is written into ``m.weight`` (and ``m.bias``), with the
           removed output channels left at zero.

        Original outputs are read from
        ``/tmp2/edward0530/tmp/batch{i}_feature_out.{arch}.{prefix}.pt``.

        NOTE(review): the reshape into ``(C', C, kh, kw)`` uses the full input
        channel count, so this assumes ungrouped convolutions (groups=1).

        Args:
            train_loader: data loader; its length is the number of recorded batches.
            GD_epoch: number of SGD refinement epochs per layer (0 disables refinement).
            plot_figure: when true and ``GD_epoch > 0``, write ``args`` and one loss
                curve per layer under ``figure/{arch}_{prefix}/``.
        """
        figure_dir = f"figure/{args.arch}_{args.prefix}"
        if GD_epoch > 0 and plot_figure:
            # makedirs(exist_ok=True): re-running with the same prefix (or calling
            # another optimize_* method first) must not fail on an existing directory
            os.makedirs(figure_dir, exist_ok=True)
            with open(f"{figure_dir}/parameter", 'w') as f:
                f.write(str(args))

        idx = 0
        self.model.eval()
        for m in tqdm(list(self.model.backbone.modules())):
            if not isinstance(m, nn.Conv2d):
                continue
            if idx == 0:  # skip the first layer
                idx += 1
                continue
            self.save_pruned_feature_maps_single_layer(train_loader, idx)

            unfold = nn.Unfold(m.kernel_size, m.dilation, m.padding, m.stride)
            has_bias = m.bias is not None

            pruned_flatten_output_feat = self.pruned_feature_out[0].flatten(start_dim=2).transpose(1, 2)  # (B, H'*W', C')
            mask_channel = torch.abs(torch.sum(pruned_flatten_output_feat, (0, 1))) < 1e-3  # (C',)
            remain_channel = torch.logical_not(mask_channel)  # (C',)

            batch_idx_list = list(range(len(train_loader)))

            # closed-form least squares per batch, averaged over batches
            optimal_param_list = []
            for batch_idx in batch_idx_list:
                input_feat = self.pruned_feature_in[batch_idx]  # (B, C, H, W)
                input_channel = input_feat.size(1)
                unfold_input_feat = self._unfold_rows(unfold, input_feat, has_bias)  # (B*H'*W', C*k*k [+1])
                num_param_per_outchannel = unfold_input_feat.size(1)

                original_feature_out = torch.load(f"/tmp2/edward0530/tmp/batch{batch_idx}_feature_out.{args.arch}.{args.prefix}.pt")
                output_feat = original_feature_out[idx]  # (B, C', H', W')
                num_outchannel = output_feat.size(1)
                flatten_output_feat = output_feat.flatten(start_dim=2).transpose(1, 2).reshape(-1, num_outchannel)  # (B*H'*W', C')

                optimal_param_list.append(
                    torch.linalg.pinv(unfold_input_feat, rcond=1e-3) @ flatten_output_feat[:, remain_channel])  # (C*k*k [+1], kept C')

            avg_optimal_param_with_hole = torch.mean(torch.stack(optimal_param_list), dim=0).detach().clone().requires_grad_(True)

            # optional SGD refinement of the averaged solution
            loss_sequence = []
            if GD_epoch > 0:
                loss_fn = nn.MSELoss()
                optimizer = torch.optim.SGD([avg_optimal_param_with_hole], lr=args.init_optimize_lr)
                for _ in range(GD_epoch):
                    loss_sequence.append(0)
                    random.shuffle(batch_idx_list)
                    for batch_idx in batch_idx_list:
                        unfold_input_feat = self._unfold_rows(unfold, self.pruned_feature_in[batch_idx], has_bias).detach()

                        original_feature_out = torch.load(f"/tmp2/edward0530/tmp/batch{batch_idx}_feature_out.{args.arch}.{args.prefix}.pt")
                        target_output_feat = original_feature_out[idx][:, remain_channel, :, :]
                        target_output_feat = target_output_feat.flatten(start_dim=2).transpose(1, 2)  # (B, H'*W', kept C')
                        target_output_feat = target_output_feat.reshape(-1, target_output_feat.size(-1)).detach()  # (B*H'*W', kept C')

                        output_feat = unfold_input_feat @ avg_optimal_param_with_hole
                        loss = loss_fn(output_feat, target_output_feat)
                        loss_sequence[-1] += loss.item() / len(batch_idx_list)
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                if plot_figure:
                    plt.clf()
                    plt.plot(list(range(GD_epoch)), loss_sequence)
                    plt.savefig(f"{figure_dir}/{idx}_loss_curve.png")

            # scatter the solved columns back; removed output channels stay all-zero
            avg_optimal_param = torch.zeros(num_param_per_outchannel, num_outchannel, device=m.weight.device)
            keep = remain_channel.to(avg_optimal_param.device)
            avg_optimal_param[:, keep] = avg_optimal_param_with_hole.detach().to(avg_optimal_param)

            if has_bias:
                avg_optimal_bias = avg_optimal_param[-1]
                avg_optimal_weight = avg_optimal_param[:-1]
            else:
                avg_optimal_weight = avg_optimal_param
            avg_optimal_weight = avg_optimal_weight.t().reshape(
                (-1, input_channel, m.kernel_size[0], m.kernel_size[1]))  # (C', C, k, k)

            with torch.no_grad():
                m.weight.copy_(avg_optimal_weight)
                if has_bias:
                    m.bias.copy_(avg_optimal_bias)

            del self.pruned_feature_in
            del self.pruned_feature_out

            idx += 1

    def optimize_parameters_with_GD(self, train_loader, GD_epoch=20, plot_figure=True):
        """Train each pruned conv layer with SGD so it reproduces the original outputs.

        Skipped layers:

        * the first ``nn.Conv2d`` of the backbone;
        * when ``args.arch == "inception_v3"``, every conv with index <= 39.

        For every other conv layer, its input/output on the current model is
        recorded over ``train_loader``. The layer's own parameters are then
        trained with SGD (``lr=args.init_optimize_lr``) for ``GD_epoch`` epochs.
        The loss is the MSE between the layer's output and the original output
        loaded from
        ``/tmp2/edward0530/tmp/batch{i}_feature_out.{arch}.{prefix}.pt``.
        Output channels whose summed current response is < 1e-3 in magnitude
        are multiplied out of both sides.

        Args:
            train_loader: data loader; its length is the number of recorded batches.
            GD_epoch: number of SGD epochs per layer.
            plot_figure: when true and ``GD_epoch > 0``, write ``args`` and one loss
                curve per layer under ``figure/{arch}_{prefix}/``.
        """
        figure_dir = f"figure/{args.arch}_{args.prefix}"
        if GD_epoch > 0 and plot_figure:
            # makedirs(exist_ok=True): re-running with the same prefix (or calling
            # another optimize_* method first) must not fail on an existing directory
            os.makedirs(figure_dir, exist_ok=True)
            with open(f"{figure_dir}/parameter", 'w') as f:
                f.write(str(args))

        idx = 0
        # NOTE(review): unlike optimize_parameters (eval mode), the model is in
        # train mode here, so the full-backbone forward passes used to record
        # features run any normalization/dropout layers in training mode —
        # confirm this is intended.
        self.model.train()
        for m in tqdm(list(self.model.backbone.modules())):
            if not isinstance(m, nn.Conv2d):
                continue
            # skip the first conv; for inception_v3 also convs 0..39 (hard-coded)
            if idx == 0 or (args.arch == "inception_v3" and idx <= 39):
                idx += 1
                continue
            self.save_pruned_feature_maps_single_layer(train_loader, idx)

            loss_fn = nn.MSELoss()
            optimizer = torch.optim.SGD(m.parameters(), lr=args.init_optimize_lr)

            pruned_flatten_output_feat = self.pruned_feature_out[0].flatten(start_dim=2).transpose(1, 2)  # (B, H'*W', C')
            mask_channel = torch.abs(torch.sum(pruned_flatten_output_feat, (0, 1))) < 1e-3  # (C',)
            remain_channel = torch.logical_not(mask_channel)  # (C',)
            num_outchannel = remain_channel.size()[0]
            mask = remain_channel.reshape((1, num_outchannel, 1, 1))

            batch_idx_list = list(range(len(train_loader)))
            loss_sequence = []
            for _ in range(GD_epoch):
                random.shuffle(batch_idx_list)
                loss_sequence.append(0)
                for batch_idx in batch_idx_list:
                    # detach: gradients must only reach this layer's parameters, never
                    # flow back into a graph recorded alongside the feature (it could be
                    # freed after the first backward)
                    input_feat = self.pruned_feature_in[batch_idx].detach()  # (B, C, H, W)
                    output_feat = m(input_feat)

                    original_feature_out = torch.load(f"/tmp2/edward0530/tmp/batch{batch_idx}_feature_out.{args.arch}.{args.prefix}.pt")
                    target_output_feat = original_feature_out[idx]  # (B, C', H', W')

                    loss = loss_fn(output_feat * mask, target_output_feat * mask)
                    loss_sequence[-1] += loss.item() / len(batch_idx_list)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            if GD_epoch > 0 and plot_figure:
                plt.clf()
                plt.plot(list(range(GD_epoch)), loss_sequence)
                plt.savefig(f"{figure_dir}/{idx}_loss_curve.png")

            del self.pruned_feature_in
            del self.pruned_feature_out
            idx += 1

    def optimize_parameters_with_SSKD(self, train_loader, val_loader, t_epoch=90, s_epoch=240, mean=None, std=None):
        """Recover the pruned model with the SSKD pipeline from ``trainer``.

        Steps:

        1. Call ``train_ssp_head`` on ``self.original_model`` for ``t_epoch``
           epochs.
        2. Call ``train_knowledge_distillation`` with this object for ``s_epoch``
           epochs, using ``args.reg_layer_idx`` and ``args.reg_weight``.

        Both steps use ``lr=args.init_optimize_lr``.

        When ``args.arch == "resnet50"`` and ``args.vis_feature_maps`` is set, an
        extra run happens before step 2:

        * a deep copy of ``self.model`` is taken;
        * distillation is run with ``reg_weight=0``;
        * the result is visualized under the ``'init_optimize'`` tag;
        * ``self.model`` is restored from the copy.

        The final model is then visualized under ``'init_optimize_reg'``.
        """
        train_ssp_head(train_loader, self.original_model, self.log, t_epoch, lr=args.init_optimize_lr,
                       transform=args.transform, mean=mean, std=std)

        reg_layer_idx = args.reg_layer_idx
        distill_kwargs = dict(arch=args.arch, prefix=args.prefix, transform=args.transform,
                              mean=mean, std=std, reg_layer_idx=reg_layer_idx)
        visualize = args.arch == "resnet50" and args.vis_feature_maps

        if visualize:
            # baseline run without the regularizer, then roll the model back
            snapshot = copy.deepcopy(self.model)
            train_knowledge_distillation(train_loader, val_loader, self, self.log, s_epoch,
                                         lr=args.init_optimize_lr, reg_weight=0, **distill_kwargs)
            visualize_resnet50_feature_maps(self.model, self.vis_batch, self.vis_target, args.save_dir, 'init_optimize')
            self.model = snapshot

        train_knowledge_distillation(train_loader, val_loader, self, self.log, s_epoch,
                                     lr=args.init_optimize_lr, reg_weight=args.reg_weight, **distill_kwargs)
        if visualize:
            visualize_resnet50_feature_maps(self.model, self.vis_batch, self.vis_target, args.save_dir, 'init_optimize_reg')
    
    # NOTE(review): the triple-quoted string below is disabled legacy code for an
    # optimize_parameters_with_FKSD method. It is a bare string expression in the
    # class body, so it is never defined or executed. It refers to names such as
    # save_pruned_feature_maps_multi_batch and self.original_feature_out[idx] that
    # are not shown in this chunk and may no longer exist — check before reviving it.
    '''@torch.no_grad()
    def optimize_parameters_with_FKSD(self, batches):
        idx = 0
        for m in tqdm(list(self.model.modules())):
            if isinstance(m, nn.Conv2d):
                if idx == 0: # skip the first layer
                    idx += 1 
                    continue
                self.save_pruned_feature_maps_multi_batch(batches)

                num_out_channel = self.original_feature_out[idx].size()[1]
                originial_output_feat = self.original_feature_out[idx].permute((0, 2, 3, 1)).reshape((-1, num_out_channel)) # (N*H'*W', C')
                new_output_feat = self.pruned_feature_out[idx].permute((0, 2, 3, 1)).reshape((-1, num_out_channel)) # (N*H'*W', C')
                mask_channel = torch.abs(torch.sum(new_output_feat, 0)) < 1e-3 # (C',)
                remain_channel = torch.logical_not(mask_channel)
                
                A = new_output_feat[:, remain_channel].numpy()
                B = originial_output_feat[:, remain_channel].numpy()
                pw_conv_weight = np.linalg.lstsq(A, B, rcond=1e-3)[0] # (C', C')
                pw_conv_weight= torch.from_numpy(pw_conv_weight)
                # pw_conv_weight = torch.linalg.lstsq(originial_output_feat[:, remain_channel], new_output_feat[:, remain_channel]).solution # (C'-num_mask_channel, C'-num_mask_channel)
                #print("original distance:", torch.norm(new_output_feat[:, remain_channel] - originial_output_feat[:, remain_channel]))
                #print("recovered distance:", torch.norm(new_output_feat[:, remain_channel] @ pw_conv_weight - originial_output_feat[:, remain_channel]))
                #breakpoint()

                weight = m.weight.data[remain_channel,:,:,:].detach().cpu()
                recovered_weight = torch.zeros_like(m.weight.data[remain_channel,:,:,:])
                for i in range(recovered_weight.size()[2]):
                    for j in range(recovered_weight.size()[3]):
                        recovered_weight[:,:,i,j] = torch.matmul(pw_conv_weight.t(), weight[:,:,i,j]) 
                #recovered_weight: (C'-num_mask_channel, C, k, k)
                recovered_weight_2 = torch.zeros_like(m.weight.data) # (C', C, k, k)
                fill_idx = 0
                for i in range(m.weight.data.size()[0]):
                    if remain_channel[i]:
                        recovered_weight_2[i] = recovered_weight[fill_idx].clone()
                        fill_idx += 1
                recovered_weight_2 = recovered_weight_2.cuda()
                m.weight.copy_(recovered_weight_2)
                #print("original distance:", torch.norm(self.original_feature_out[idx][:,remain_channel,:,:] - self.pruned_feature_out[idx][:,remain_channel,:,:]))
                #print("recovered distance:", torch.norm(m(self.pruned_feature_in[idx].cuda()).cpu()[:,remain_channel,:,:] - self.original_feature_out[idx][:,remain_channel,:,:]))
                #breakpoint()

                idx += 1'''

if __name__ == '__main__':
    # Script entry point. main() is not part of this chunk; presumably it is
    # defined earlier in this file.
    main()