import copy
import functools
import os

import blobfile as bf
import torch as th
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
import numpy as np
import random
import math

from . import dist_util, logger
from .fp16_util import MixedPrecisionTrainer
from .nn import update_ema
from .resample import LossAwareSampler, UniformSampler

# Starting value of the fp16 log loss scale. It worked well as a default in
# ImageNet experiments, where lg_loss_scale rose to roughly 20-21 during the
# first ~1K training steps anyway.
INITIAL_LOG_LOSS_SCALE: float = 20.0


def get_simple_t(skip_type, t_T, t_0, N):
    """Return a time-uniform grid from ``t_T`` towards ``t_0`` without its minimum.

    ``N + 1`` evenly spaced points are generated with ``th.linspace`` and the
    first occurrence of the smallest value is removed, leaving ``N`` points in
    their original order. The result is a float32 tensor on the CPU.

    Only ``skip_type == 'time_uniform'`` is supported.
    """
    assert skip_type == 'time_uniform'
    grid = th.linspace(t_T, t_0, N + 1).cpu()
    # argmin returns the first minimal index, matching list.remove(min(...)).
    drop_at = int(th.argmin(grid))
    remaining = th.cat([grid[:drop_at], grid[drop_at + 1:]])
    return remaining.float()


def get_model_input_time(t, total_N):
    """Convert continuous time ``t`` to the model's timestep input.

    Computes ``(t - 1 / total_N) * 1000``; works element-wise on tensors.
    """
    shift = 1. / total_N
    scaled = (t - shift) * 1000.
    return scaled


class TrainLoop:
    """Diffusion training loop with EMA tracking, optional fp16 and DDP.

    Besides the trained ``model``, a ``standard_model`` is kept and passed to
    ``diffusion.training_losses`` as the ``std_model`` keyword for every
    microbatch. When ``use_simple_train`` is set, training timesteps are not
    taken from ``schedule_sampler`` but drawn from a fixed time-uniform grid of
    ``dpm_solver_steps`` points (see ``get_simple_t``).
    """

    def __init__(
        self,
        *,
        model,
        standard_model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
        use_simple_train=False,
        step_respacing=-1,
        dpm_solver_steps=20,
    ):
        self.model = model
        self.standard_model = standard_model
        self.diffusion = diffusion
        # Consumed with next(); each item is a (batch, cond) pair.
        self.data = data
        self.batch_size = batch_size
        # A non-positive microbatch means "process the whole batch at once".
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        # Either a single float or a comma-separated string of rates.
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        # 0 disables both LR annealing and the step limit in run_loop.
        self.lr_anneal_steps = lr_anneal_steps

        ####################################
        self.use_simple_train = use_simple_train
        if use_simple_train:
            # The respaced step count is required to map grid times to model timesteps.
            assert step_respacing != -1
            self.step_respacing = step_respacing
            self.dpm_solver_steps = dpm_solver_steps
        ####################################

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()

        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        self.mp_trainer = MixedPrecisionTrainer(
            model=self.model,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
        )

        self.opt = AdamW(
            self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
        )
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.mp_trainer.master_params)
                for _ in range(len(self.ema_rate))
            ]

        if th.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
            ####################################
            # The standard model is used as-is (it is not wrapped in DDP).
            self.ddp_std_model = self.standard_model
            ####################################
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model
            self.ddp_std_model = self.standard_model

    def _load_and_sync_parameters(self):
        """Load model weights from the resume checkpoint (if any) and sync them across ranks.

        The resume step is parsed from the checkpoint filename.
        """
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
            self.model.load_state_dict(
                dist_util.load_state_dict(
                    resume_checkpoint, map_location=dist_util.dev()
                )
            )

        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        """Return EMA parameters for ``rate``, loaded on rank 0 if a checkpoint exists.

        Falls back to a copy of the current master params. The result is
        synced across ranks either way.

        NOTE(review): ``save`` writes ``ema_{rate}_newest.pt`` while
        ``find_ema_checkpoint`` looks for ``ema_{rate}_{step:06d}.pt``, so
        EMA files written by this class are not found on resume unless renamed.
        """
        ema_params = copy.deepcopy(self.mp_trainer.master_params)

        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=dist_util.dev()
                )
                ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)

        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        """Load optimizer state from ``opt{resume_step:06}.pt`` next to the resume checkpoint.

        Does nothing if the file is missing.

        NOTE(review): ``save`` writes ``opt_newest.pt`` instead, so optimizer
        state saved by this class is not picked up here unless renamed.
        """
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)

    def run_loop(self, single_label=-1):
        """Train until ``lr_anneal_steps`` is reached (forever if it is 0).

        Args:
            single_label: if not -1, every label in ``cond['y']`` is replaced
                with this value before the step.
        """
        ####################################
        if self.use_simple_train:
            t_0 = 1. / float(self.step_respacing)
            N = self.dpm_solver_steps
            self.simple_steps = get_simple_t('time_uniform', 1., t_0, N)
            show_t = get_model_input_time(self.simple_steps, self.step_respacing)
            logger.log('The simple train step list is', list(show_t.cpu().numpy()))
        ####################################
        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            batch, cond = next(self.data)
            if single_label != -1:
                y = single_label*th.ones(cond['y'].shape, dtype=cond['y'].dtype, device=cond['y'].device)
                cond['y'] = y.type(cond['y'].dtype)
            self.run_step(batch, cond)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if (self.step % self.save_interval == 0) and (self.step != 0):
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        """Run one optimization step.

        Updates EMA only if the optimizer actually stepped, then anneals the
        learning rate and logs.
        """
        self.forward_backward(batch, cond)
        took_step = self.mp_trainer.optimize(self.opt)
        if took_step:
            self._update_ema()
        self._anneal_lr()
        self.log_step()

    def forward_backward(self, batch, cond):
        """Accumulate gradients for ``batch`` over microbatches.

        Gradient synchronisation under DDP happens only on the last microbatch.
        """
        # Unconditional models have no 'y' key; only log labels when present.
        if 'y' in cond:
            logger.log("The first 5 labels used in this step is", cond['y'][:5])
        self.mp_trainer.zero_grad()
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            ####################################
            # With simple training the sampled t is replaced by a random point
            # of the fixed grid. The sampler weights are kept as returned
            # (all ones for the uniform sampler).
            if self.use_simple_train:
                t_dtype = t.dtype
                t_device = t.device
                idx = np.random.randint(0,self.simple_steps.shape[0],(micro.shape[0],))
                simple_t = self.simple_steps[idx]
                t = get_model_input_time(simple_t, self.step_respacing)
                # Randomly truncate or round-to-nearest for the whole microbatch.
                if random.randint(0,1) == 0:
                    t = t.int()
                else:
                    t = (t+0.5).int()
                t = t.type(t_dtype).to(t_device)
            ####################################

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            if last_batch or not self.use_ddp:
                losses = compute_losses(std_model=self.ddp_std_model, device=dist_util.dev())
            else:
                # Skip the DDP all-reduce for all but the last microbatch.
                with self.ddp_model.no_sync():
                    losses = compute_losses(std_model=self.ddp_std_model, device=dist_util.dev())

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            self.mp_trainer.backward(loss)

    def _update_ema(self):
        """Update every EMA parameter set towards the current master params."""
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.mp_trainer.master_params, rate=rate)

    def _anneal_lr(self):
        """Linearly decay the LR to 0 over ``lr_anneal_steps``; no-op when that is 0."""
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        """Log the global step and the number of samples seen so far."""
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)

    def save(self):
        """Write the model, EMA and optimizer checkpoints (rank 0 only), then barrier.

        The model file is step-numbered (``model{step:06d}.pt``). EMA and
        optimizer files are overwritten each time (``ema_{rate}_newest.pt``,
        ``opt_newest.pt``).
        """
        def save_checkpoint(rate, params):
            state_dict = self.mp_trainer.master_params_to_state_dict(params)
            if dist.get_rank() == 0:
                logger.log(f"saving model {rate}...")
                if not rate:
                    filename = f"model{(self.step+self.resume_step):06d}.pt"
                else:
                    filename = f"ema_{rate}_newest.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, self.mp_trainer.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)

        if dist.get_rank() == 0:
            with bf.BlobFile(
                bf.join(get_blob_logdir(), f"opt_newest.pt"),
                "wb",
            ) as f:
                th.save(self.opt.state_dict(), f)

        dist.barrier()


def parse_resume_step_from_filename(filename):
    """
    Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
    checkpoint's number of steps.

    The text after the last occurrence of "model", up to its first ".", is
    read as an integer. Returns 0 if "model" does not occur or that text is
    not an integer.
    """
    _, marker, tail = filename.rpartition("model")
    if not marker:
        return 0
    step_text = tail.split(".")[0]
    try:
        return int(step_text)
    except ValueError:
        return 0


def get_blob_logdir():
    """Return the directory checkpoints are written to (the logger's directory).

    Change this to point at a blobstore or external drive if checkpoints
    should live elsewhere.
    """
    log_dir = logger.get_dir()
    return log_dir


def find_resume_checkpoint():
    """Hook for automatically discovering a checkpoint to resume from.

    Always returns None here. Override it on your infrastructure to locate
    the latest checkpoint (e.g. on blob storage).
    """
    discovered = None
    return discovered


def find_ema_checkpoint(main_checkpoint, step, rate):
    """Return the path of ``ema_{rate}_{step:06d}.pt`` next to ``main_checkpoint``.

    Returns None if ``main_checkpoint`` is None or the file does not exist.
    """
    if main_checkpoint is None:
        return None
    candidate = bf.join(bf.dirname(main_checkpoint), f"ema_{rate}_{(step):06d}.pt")
    return candidate if bf.exists(candidate) else None


def log_loss_dict(diffusion, ts, losses):
    """Log the mean of each loss term, plus its mean per timestep quartile.

    Each per-sample loss is also logged under ``{key}_q{quartile}``, where
    ``quartile = int(4 * t / diffusion.num_timesteps)``.
    """
    timesteps = ts.cpu().numpy()
    for key, values in losses.items():
        logger.logkv_mean(key, values.mean().item())
        per_sample = values.detach().cpu().numpy()
        for sub_t, sub_loss in zip(timesteps, per_sample):
            bucket = int(4 * sub_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{key}_q{bucket}", sub_loss)


def get_compactor_mask_dict(model:th.nn.Module):
    """Map each compactor's name prefix to a mask shaped like its kernel.

    Buffers whose name contains ``'mask'`` are paired with parameters whose
    name contains ``'pwc.weight'`` via the shared prefix. For a 4-D kernel
    with a 1-D mask of length n, mask[i] is broadcast across row i of an
    ``(n, n, 1, 1)`` tensor, masking whole output filters. A 4-D mask, or a
    1-D kernel, uses the mask unchanged.
    """
    masks = {
        name.replace('mask', ''): buf
        for name, buf in model.named_buffers()
        if 'mask' in name
    }
    kernels = {
        name.replace('pwc.weight', ''): param
        for name, param in model.named_parameters()
        if 'pwc.weight' in name
    }

    mask_by_prefix = {}
    for prefix, kernel in kernels.items():
        mask = masks[prefix]
        if kernel.ndimension() != 4:
            assert kernel.ndimension() == 1
            mask_by_prefix[prefix] = mask
            continue
        if mask.ndimension() != 1:
            assert mask.ndimension() == 4
            mask_by_prefix[prefix] = mask
            continue
        n = mask.nelement()
        rows = mask.reshape(-1, 1).repeat(1, n)
        mask_by_prefix[prefix] = rows.reshape(n, n, 1, 1)

    return mask_by_prefix


def resrep_mask_model(model:th.nn.Module, thresh, idx_of_scores, out_c_of_para, save_mask_score_path, num_upd_msk):
    """Choose compactor output channels to mask and apply them via ``set_mask``.

    Each compactor weight (parameter name containing ``'pwc'``) is scored per
    output channel by the L2 norm over axes (1, 2, 3). The highest-scoring
    channel of every compactor gets an infinite score, so each Conv2d keeps at
    least one channel. All scores are concatenated into one flat vector and
    saved to ``<save_mask_score_path>/newest_mask_score.npy``.

    Every compactor must directly follow a 4-D conv weight whose output
    channel count equals the compactor's in/out channels. This is checked with
    asserts using the naming convention derived below.

    Args:
        model: module whose parameters are scanned. Modules that have a ``pwc``
            attribute receive ``set_mask(indices)`` with sorted local channel
            indices.
        thresh: in threshold mode, channels scoring below this are masked.
        idx_of_scores: start offset of each compactor in the flat score vector.
        out_c_of_para: number of channels of each compactor.
        save_mask_score_path: directory for the score dump. NOTE(review): this
            function writes it on every process; in multi-process runs all
            ranks write the same file.
        num_upd_msk: -1 selects threshold mode. Otherwise it is the number of
            lowest-scoring channels to mask; reserved channels are never included.
    """
    out_ch = 0
    conv_name_ = ''
    scores = []
    last_com_shape0 = 0
    for k,p in model.named_parameters():
        if 'pwc' in k:
            # Derive the name of the conv weight this compactor should follow.
            ks = k.split('.')
            if 'skip_com' in k:
                conv_name = k.replace('skip_com.pwc','skip_connection')
            else:
                # "...<idx>.pwc.weight" follows "...<idx-1>.weight".
                ks[-3] = str(int(ks[-3]) - 1)
                ks.pop(-2)
                conv_name = '.'.join(ks)

            compactor_out_ch = int(p.shape[0])
            compactor_in_ch = int(p.shape[1])
            last_com_shape0 = int(p.shape[0])
            assert k.endswith('weight')
            assert conv_name_ == conv_name
            assert compactor_in_ch == out_ch
            assert compactor_out_ch == out_ch

            kernel_weight = p.detach().cpu().numpy()
            scores_k = np.sqrt(np.sum(kernel_weight ** 2, axis=(1,2,3)))

            # Each Conv2d reserves at least 1 channel.
            idx = np.argmax(scores_k)
            scores_k[idx] = math.inf

            scores += list(scores_k)

        else:
            # Remember the most recent 4-D (conv) weight for the checks above.
            if len(p.shape) == 4:
                out_ch = int(p.shape[0])
                conv_name_ = k

    assert len(scores) == idx_of_scores[-1] + last_com_shape0
    np.save(save_mask_score_path+'/newest_mask_score.npy', scores)

    scores = np.array(scores)

    # channel_to_layer[c] = index of the compactor owning flat channel c.
    channel_to_layer = np.zeros(scores.shape[0],dtype=np.int32)
    for j, start in enumerate(idx_of_scores):
        length = out_c_of_para[j]
        channel_to_layer[start:start + length] = j * np.ones(length, dtype=np.int32)

    if num_upd_msk == -1:
        # Reserved channels score inf and can never fall below thresh.
        channel_prune = np.where(scores < thresh)[0]
    else:
        assert num_upd_msk > 0
        channel_prune = scores.argsort()[:num_upd_msk]
        # argsort places the reserved (inf) channels last, but they would be
        # selected once num_upd_msk exceeds the finite scores; never mask them.
        channel_prune = channel_prune[np.isfinite(scores[channel_prune])]

    # Group the selected flat indices into per-compactor local channel indices.
    per_layer = [[] for _ in range(len(idx_of_scores))]
    for flat_idx in channel_prune:
        layer = channel_to_layer[flat_idx]
        per_layer[layer].append(flat_idx - idx_of_scores[layer])

    all_msk_num = 0
    for i, local_idx in enumerate(per_layer):
        if local_idx:
            logger.log("The num of parameters' gradient to be masked in Compactor:", i, "is", len(local_idx))
            all_msk_num += len(local_idx)
            per_layer[i] = list(np.sort(np.array(local_idx)))
        else:
            logger.log("The num of parameters' gradient to be masked in Compactor:", i, "is 0")
    logger.log("A total of", all_msk_num, "Conv2d channels are masked in this mask step.")
    logger.log("###########################################################")

    # Compactor modules are visited in the same order as their 'pwc' params above.
    idx = -1
    for m in model.modules():
        if hasattr(m, 'pwc'):
            idx += 1
            if per_layer[idx]:
                m.set_mask(per_layer[idx])


class RepTrain:
    """ResRep-style compactor training loop (no EMA).

    Trains ``model`` with the diffusion loss. Compactor weights (parameters
    named ``*pwc.weight``) receive a group-lasso gradient, and their
    gradients are multiplied by a mask. Starting at global step
    ``before_mask_iters`` and then every ``mask_interval`` steps, masks are
    recomputed by ``resrep_mask_model``.
    """

    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        log_interval,
        save_interval,
        resume_checkpoint,
        idx_of_scores,
        out_c_of_para,
        save_mask_score_path,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
        use_simple_train=False,
        step_respacing=-1,
        dpm_solver_steps=20,
        lasso_strength=1e-4,
        mask_interval=200,
        before_mask_iters=10000,
        rep_train_thresh=1e-5,
        num_upd_msk=-1,
        rep_thresh_decay=10,
        rep_thresh_min=1e-5,
        num_upd_msk_increment=-1,
    ):
        self.model = model
        self.diffusion = diffusion
        # Consumed with next(); each item is a (batch, cond) pair.
        self.data = data
        self.batch_size = batch_size
        # A non-positive microbatch means "process the whole batch at once".
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        # 0 disables both LR annealing and the step limit in run_loop.
        self.lr_anneal_steps = lr_anneal_steps

        ####################################
        self.use_simple_train = use_simple_train
        if use_simple_train:
            assert step_respacing != -1
            self.step_respacing = step_respacing
            self.dpm_solver_steps = dpm_solver_steps

        # Coefficient of the group-lasso gradient added to compactor weights.
        self.lasso_strength = lasso_strength
        self.mask_interval = mask_interval
        self.before_mask_iters = before_mask_iters
        # Start offset / channel count of each compactor in the flat score vector.
        self.idx_of_scores = idx_of_scores
        self.out_c_of_para = out_c_of_para
        # Threshold mode: divided by rep_thresh_decay after each mask update
        # while the result stays >= rep_thresh_min.
        self.rep_train_thresh = rep_train_thresh
        self.rep_thresh_decay = rep_thresh_decay
        self.rep_thresh_min = rep_thresh_min
        self.save_mask_score_path = save_mask_score_path
        # -1 selects threshold mode; otherwise the number of channels to mask,
        # grown by num_upd_msk_increment after each mask update.
        self.num_upd_msk = num_upd_msk
        self.num_upd_msk_increment = num_upd_msk_increment
        ####################################

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()

        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        self.mp_trainer = MixedPrecisionTrainer(
            model=self.model,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
        )

        self.opt = AdamW(
            self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
        )
        if self.resume_step:
            self._load_optimizer_state()

        if th.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model

    def _load_and_sync_parameters(self):
        """Load weights on rank 0 and sync them to all ranks.

        Loading is non-strict so checkpoints without compactor parameters
        still load. The resume step is parsed from the filename.
        """
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
                self.model.load_state_dict(
                    dist_util.load_state_dict(
                        resume_checkpoint, map_location=dist_util.dev()
                    ),
                    strict=False
                )

        dist_util.sync_params(self.model.parameters())

    def _load_optimizer_state(self):
        """Load ``opt{resume_step:06}.pt`` next to the resume checkpoint, if it exists."""
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)

    def run_loop(self, single_label=-1):
        """Train until ``lr_anneal_steps`` is reached (forever if it is 0).

        Args:
            single_label: if not -1, every label in ``cond['y']`` is replaced
                with this value before the step.
        """
        ####################################
        if self.use_simple_train:
            t_0 = 1. / float(self.step_respacing)
            N = self.dpm_solver_steps
            self.simple_steps = get_simple_t('time_uniform', 1., t_0, N)
            show_t = get_model_input_time(self.simple_steps, self.step_respacing)
            logger.log('The simple train step list is', list(show_t.cpu().numpy()))
        ####################################

        ####################################
        compactor_mask_dict = get_compactor_mask_dict(self.ddp_model)
        ####################################
        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            batch, cond = next(self.data)
            if single_label != -1:
                y = single_label*th.ones(cond['y'].shape, dtype=cond['y'].dtype, device=cond['y'].device)
                cond['y'] = y.type(cond['y'].dtype)

            if (self.step+self.resume_step) >= self.before_mask_iters:
                iters_in_compactor_phase = self.step + self.resume_step - self.before_mask_iters
                if iters_in_compactor_phase % self.mask_interval == 0:
                    logger.log("###########################################################")
                    if self.num_upd_msk == -1:
                        logger.log("Update mask at iter", self.step+self.resume_step, ". The threshhold of the filt is", self.rep_train_thresh)
                    else:
                        logger.log("Update mask at iter", self.step+self.resume_step, ". The num of the filt to mask is", self.num_upd_msk)
                    logger.log("###########################################################")
                    resrep_mask_model(self.ddp_model, self.rep_train_thresh, self.idx_of_scores, self.out_c_of_para, self.save_mask_score_path, self.num_upd_msk)
                    if (self.num_upd_msk != -1) and (self.num_upd_msk < 32000):
                        self.num_upd_msk += self.num_upd_msk_increment
                    if float(self.rep_train_thresh/self.rep_thresh_decay) >= self.rep_thresh_min:
                        self.rep_train_thresh = float(self.rep_train_thresh/self.rep_thresh_decay)
                    # Masks changed in-place on the modules; rebuild the gradient masks.
                    compactor_mask_dict = get_compactor_mask_dict(self.ddp_model)

            self.run_step(batch, cond, compactor_mask_dict)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if (self.step % self.save_interval == 0) and (self.step != 0):
                self.save()
            self.step += 1

        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

        logger.log("Rep train finished.")

    def run_step(self, batch, cond, compactor_mask_dict):
        """Run forward/backward plus optimizer step, then anneal the LR and log."""
        self.forward_backward(batch, cond, compactor_mask_dict)
        self._anneal_lr()
        self.log_step()

    def forward_backward(self, batch, cond, compactor_mask_dict):
        """Accumulate gradients over microbatches, apply the compactor transform, then step.

        The mask and lasso transform is applied once, after all microbatches
        have been accumulated (and, under DDP, all-reduced on the last one).
        Applying it per microbatch would add the lasso term once per microbatch.
        """
        # Unconditional models have no 'y' key; only log labels when present.
        if 'y' in cond:
            logger.log("The first 5 labels used in this step is", cond['y'][:5])
        self.mp_trainer.zero_grad()
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            ####################################
            # With simple training the sampled t is replaced by a random point
            # of the fixed grid. The sampler weights are kept as returned
            # (all ones for the uniform sampler).
            if self.use_simple_train:
                t_dtype = t.dtype
                t_device = t.device
                idx = np.random.randint(0,self.simple_steps.shape[0],(micro.shape[0],))
                simple_t = self.simple_steps[idx]
                t = get_model_input_time(simple_t, self.step_respacing)
                # Randomly truncate or round-to-nearest for the whole microbatch.
                if random.randint(0,1) == 0:
                    t = t.int()
                else:
                    t = (t+0.5).int()
                t = t.type(t_dtype).to(t_device)
            ####################################

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            if last_batch or not self.use_ddp:
                losses = compute_losses(device=dist_util.dev())
            else:
                # Skip the DDP all-reduce for all but the last microbatch.
                with self.ddp_model.no_sync():
                    losses = compute_losses(device=dist_util.dev())

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            self.mp_trainer.backward(loss)

        self._apply_compactor_grads(compactor_mask_dict)
        self._log_compactor_stats()
        _ = self.mp_trainer.optimize(self.opt)

    def _apply_compactor_grads(self, compactor_mask_dict):
        """Mask compactor gradients and add the group-lasso gradient w / ||w|| (per output filter).

        NOTE(review): when use_fp16 is set, mp_trainer.backward may scale the
        loss, while the lasso term added here is unscaled. Confirm against
        MixedPrecisionTrainer whether lasso_strength needs the same scale.
        """
        msk_id = 0
        for name, param in self.ddp_model.named_parameters():
            if 'pwc.weight' not in name:
                continue
            msk_id += 1
            mask = compactor_mask_dict[name.replace('pwc.weight','')]
            param.grad.data = mask * param.grad.data
            lasso_grad = param.data * ((param.data ** 2).sum(dim=(1, 2, 3), keepdim=True) ** (-0.5))

            if msk_id == 10:
                mat = param.detach().cpu().numpy()
                metric = np.sqrt(np.sum(mat ** 2, axis=(1, 2, 3)))
                logger.log("Min in the 10th compactor", metric.min())

            # Keyword `alpha` replaces the deprecated add_(Number, Tensor) overload.
            param.grad.data.add_(lasso_grad, alpha=self.lasso_strength)

    def _log_compactor_stats(self):
        """Log how many compactor filters have an L2 norm below ``rep_thresh_min``."""
        all_num = 0
        for name, param in self.ddp_model.named_parameters():
            if 'pwc.weight' in name:
                mat = param.detach().cpu().numpy()
                metric = np.sqrt(np.sum(mat ** 2, axis=(1, 2, 3)))
                all_num += int(np.count_nonzero(metric < self.rep_thresh_min))
        logger.log("The num of channels whose metric is lower than final thresh", self.rep_thresh_min, "is", all_num)

    def save(self):
        """Write ``model{step:06d}.pt`` and ``opt{step:06d}.pt`` (rank 0 only), then barrier."""
        def save_checkpoint(rate, params):
            state_dict = self.mp_trainer.master_params_to_state_dict(params)
            if dist.get_rank() == 0:
                logger.log(f"saving model...")
                if not rate:
                    filename = f"model{(self.step+self.resume_step):06d}.pt"
                else:
                    filename = f"ema_{rate}_newest.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, self.mp_trainer.master_params)

        if dist.get_rank() == 0:
            with bf.BlobFile(
                bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
                "wb",
            ) as f:
                th.save(self.opt.state_dict(), f)

        dist.barrier()

    def _anneal_lr(self):
        """Linearly decay the LR to 0 over ``lr_anneal_steps``; no-op when that is 0."""
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        """Log the global step and the number of samples seen so far."""
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
