import torch
import torch.nn as nn
from net.mix_normal import MixNormal
from net.mix_neuro_sync import Single_NEW_MIX_Sample_Based
from net.mix_normal_resnet import MixNormalResNet
from net.resnet_sparse_neuro_sync import Single_NEW_MIX_Sample_Based_Sparse_Res
from net.new_vit import ViTForClassfication
import copy
from torch.nn.utils import parameters_to_vector
from cbp_utils.convGnT import ConvGnT
import torch.optim as optim
from cbp_utils.gnt import GnT, AdamGnT
from torch.autograd import Variable
import torch.autograd as autograd
import torch.nn.functional as F
import wandb
import numpy as np
from er_utils.buffer import Buffer
from er_utils.gem_utils import *
# from hat import HATPayload, HATConfig
from torchvision import transforms

class ERWrapper(nn.Module):
    """Experience-replay (ER) wrapper around a training agent.

    Every ``step`` concatenates a minibatch sampled from a replay ``Buffer`` onto the
    incoming batch, trains the wrapped agent on the combined batch, and then hands the
    incoming (non-replayed) samples to the buffer.

    Args:
        agent: Agent exposing ``step``, ``predict``, ``get_optimizer``,
            ``compute_activation_statistics`` and a ``model`` attribute.
        buffer_size: Capacity passed to ``Buffer``.
        minibatch_size: Number of replayed samples requested from the buffer per step.
        device: Device passed to ``Buffer``.
    """

    def __init__(self, agent, buffer_size, minibatch_size, device):
        super().__init__()
        self.agent = agent
        self.buffer = Buffer(buffer_size, device)
        # Forwarded to ``Buffer.get_data`` when sampling; no transform by default.
        self.transform = None
        self.minibatch_size = minibatch_size
        # Mirror the replay minibatch size onto the wrapped agent.
        self.agent.buffer_minibatch_size = minibatch_size

    def get_optimizer(self):
        """Delegate to the wrapped agent's ``get_optimizer``."""
        self.agent.get_optimizer()

    def predict(self, x):
        """Return the wrapped agent's predictions for ``x``."""
        return self.agent.predict(x)

    def step(self, x, y, raw_image=None):
        """Train the agent on ``(x, y)`` augmented with a replayed minibatch.

        Args:
            x: Incoming input batch (first dimension is the batch).
            y: Labels for ``x``.
            raw_image: Unused; accepted so existing call sites keep working.

        Returns:
            ``(logits, metrics)``. ``logits`` are the detached predictions for the
            incoming batch only: replayed rows are dropped so the result lines up with
            ``y``. ``metrics`` is whatever the agent's ``step`` reported for the
            combined batch.
        """
        inputs = x
        labels = y
        real_batch_size = inputs.shape[0]
        if not self.buffer.is_empty():
            buf_inputs, buf_labels = self.buffer.get_data(self.minibatch_size, transform=self.transform)  # type: ignore
            # Replayed samples are appended after the incoming ones, so the first
            # ``real_batch_size`` rows always belong to the current batch.
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        logits, metrics = self.agent.step(x=inputs, y=labels)

        # Only fresh samples are offered to the buffer, never the replayed ones.
        self.buffer.add_data(examples=inputs[:real_batch_size],
                             labels=labels[:real_batch_size])

        return logits[:real_batch_size].detach(), metrics

    def compute_activation_statistics(self, batch):
        """Delegate to the wrapped agent's ``compute_activation_statistics``."""
        return self.agent.compute_activation_statistics(batch)

    def load_consolidated_weights(self, wc_weights):
        """Forward ``wc_weights`` to the agent if it is a ``NeuroSyncAgent``; otherwise do nothing."""
        if isinstance(self.agent, NeuroSyncAgent):
            self.agent.load_consolidated_weights(wc_weights=wc_weights)

    def forward(self, x):
        """Same as ``predict``."""
        return self.predict(x)

    @property
    def model(self):
        """The wrapped agent's ``model``."""
        return self.agent.model
    
class AGemWrapper(nn.Module):
    """A-GEM wrapper around a training agent.

    Each ``step`` computes the gradient of the current batch. When the replay buffer is
    non-empty, it also computes a reference gradient on a replayed minibatch. If the two
    gradients have a negative dot product, the current gradient is replaced by
    ``project(gxy, ger)`` before the optimizer step. ``end_task`` stores a random subset
    of the finished task's data in the buffer.

    Args:
        agent: Agent exposing ``optimizer``, ``compute_loss``, ``predict``,
            ``get_optimizer``, ``compute_activation_statistics`` and ``model``.
        buffer_size: Total replay buffer capacity.
        minibatch_size: Number of replayed samples used for the reference gradient.
        num_tasks: Number of tasks; each task gets ``buffer_size // num_tasks`` slots.
        device: Device for the buffer and the flattened gradient vectors.
    """

    def __init__(self, agent, buffer_size, minibatch_size, num_tasks, device):
        super().__init__()
        self.agent = agent
        self.device = device

        self.buffer = Buffer(buffer_size, device)
        # Element count of each agent parameter, in ``agent.parameters()`` order; used by
        # store_grad / overwrite_grad to flatten and restore gradients.
        self.grad_dims = [param.data.numel() for param in self.agent.parameters()]
        total_numel = int(sum(self.grad_dims))
        # Flattened gradient of the current batch and of the replayed (reference) batch.
        self.grad_xy = torch.zeros(total_numel, device=self.device)
        self.grad_er = torch.zeros(total_numel, device=self.device)

        self.transform = None
        self.minibatch_size = minibatch_size
        self.buffer_size = buffer_size
        self.num_tasks = num_tasks

    def get_optimizer(self):
        """Delegate to the wrapped agent's ``get_optimizer``."""
        self.agent.get_optimizer()

    def predict(self, x):
        """Return the wrapped agent's predictions for ``x``."""
        return self.agent.predict(x)

    def _compute_loss(self, x, y):
        """Return ``(loss, logits, info)`` from the agent.

        A ``NeuroSyncAgent`` returns the triple itself; other agents return
        ``(loss, logits)`` and get an empty ``info`` dict.
        """
        if isinstance(self.agent, NeuroSyncAgent):
            loss, logits, info = self.agent.compute_loss(x, y)
            return loss, logits, info
        loss, logits = self.agent.compute_loss(x, y)
        return loss, logits, {}

    def _gradient_metrics(self):
        """Collect per-parameter gradient/weight statistics of ``agent.model``.

        Considers every ``requires_grad`` parameter whose name does not contain
        ``'layer_norm'``, ``'init_params'`` or ``'original_last_layer_params'``. For each
        one it records the weight norm and, when a gradient exists, the gradient norm and
        the fraction of exactly-zero gradient entries. Parameters with ``grad is None``
        get no gradient entries, where the old code crashed in ``torch.norm``.

        For non-bias parameters whose top-level module name is a key of
        ``model.activations``, it also logs a ``wandb.Histogram`` of the absolute
        activations divided by their mean (raw absolute values when the mean is 0).
        """
        model = self.agent.model
        metrics = {}
        for name, param in model.named_parameters():
            if not param.requires_grad or 'layer_norm' in name or \
                    'init_params' in name or 'original_last_layer_params' in name:
                continue
            metrics[f'agent/{name}-weight-magnitude'] = torch.norm(param.data)
            if param.grad is not None:
                metrics[f'agent/{name}-grad-magnitude'] = torch.norm(param.grad)
                metrics[f'agent/{name}-grad-frac-zero'] = torch.mean((param.grad == 0).float()).item()
            if 'bias' in name:
                continue
            layer_name = name.split(".")[0]
            if layer_name not in model.activations:
                continue
            abs_act = model.activations[layer_name].abs()
            mean_act = torch.mean(abs_act)
            scaled = abs_act / mean_act if mean_act > 0 else abs_act
            metrics[f'agent/{layer_name}-dormant'] = wandb.Histogram(scaled.detach().cpu().numpy())
        return metrics

    def step(self, x, y):
        """Take one A-GEM step on ``(x, y)``.

        Returns:
            ``(logits.detach(), metrics)``. ``metrics`` holds ``'curr_train_loss'`` (the
            loss on ``(x, y)``), the per-parameter gradient metrics, and any ``info``
            returned by a ``NeuroSyncAgent``.
        """
        self.agent.optimizer.zero_grad()
        loss, logits, info = self._compute_loss(x, y)
        loss.backward()

        if not self.buffer.is_empty():
            # Save the current-batch gradient before it is overwritten below.
            store_grad(self.agent.parameters, self.grad_xy, self.grad_dims)

            buf_inputs, buf_labels = self.buffer.get_data(self.minibatch_size)
            self.agent.optimizer.zero_grad()
            penalty, _, _ = self._compute_loss(buf_inputs, buf_labels)
            penalty.backward()
            store_grad(self.agent.parameters, self.grad_er, self.grad_dims)

            # Project only when the update would increase the replay loss to first
            # order (negative dot product); otherwise restore the original gradient.
            if torch.dot(self.grad_xy, self.grad_er).item() < 0:
                g_tilde = project(gxy=self.grad_xy, ger=self.grad_er)
                overwrite_grad(self.agent.parameters, g_tilde, self.grad_dims)
            else:
                overwrite_grad(self.agent.parameters, self.grad_xy, self.grad_dims)

        self.agent.optimizer.step()

        metrics = {'curr_train_loss': loss.detach(),
                   **self._gradient_metrics(),
                   **info}
        return logits.detach(), metrics

    def end_task(self, not_aug_cur_x, cur_y):
        """Add a random subset of the finished task's samples to the replay buffer.

        Keeps ``buffer_size // num_tasks`` randomly chosen (input, label) pairs, or all
        of them if fewer are given, and adds them to the buffer on ``self.device``.
        """
        samples_per_task = self.buffer_size // self.num_tasks
        perm = torch.randperm(not_aug_cur_x.size(0))
        # The same permutation is applied to inputs and labels so pairs stay aligned.
        cur_x = not_aug_cur_x[perm][:samples_per_task]
        cur_y = cur_y[perm][:samples_per_task]
        self.buffer.add_data(
            examples=cur_x.to(self.device),
            labels=cur_y.to(self.device)
        )

    def compute_activation_statistics(self, batch):
        """Delegate to the wrapped agent's ``compute_activation_statistics``."""
        return self.agent.compute_activation_statistics(batch)

    def load_consolidated_weights(self, wc_weights):
        """Forward ``wc_weights`` to the agent if it is a ``NeuroSyncAgent``; otherwise do nothing."""
        if isinstance(self.agent, NeuroSyncAgent):
            self.agent.load_consolidated_weights(wc_weights=wc_weights)

    def forward(self, x):
        """Same as ``predict``."""
        return self.predict(x)

    @property
    def model(self):
        """The wrapped agent's ``model``."""
        return self.agent.model

class BaseAgent(nn.Module):
    """Plain supervised agent: a classifier trained with cross-entropy.

    Builds a ``MixNormalResNet`` when ``arch_config.arch_name == 'RESNET'`` and a
    ``MixNormal`` network otherwise. It then creates an optimizer over
    ``self.model.parameters()``. Subclasses that replace ``self.model`` afterwards must
    call ``get_optimizer()`` again.
    """

    def __init__(self, agent_config, arch_config, task_config):
        super().__init__()
        if arch_config.arch_name == 'RESNET':
            assert agent_config.agent_type not in ['LayerNormAgent', 'HATAgent']
            self.model = MixNormalResNet(
                    input_shape=task_config.input_shape,
                    num_classes=task_config.num_classes,
                    activation=arch_config.activation,
                    load_pretrained=arch_config.load_pretrained,
                    dropout_percentage=arch_config.dropout_percentage,
                    disable_bn=arch_config.disable_bn,
                    agent_type=agent_config.agent_type)
        else:
            self.model = MixNormal(
                input_type=arch_config.input_type,
                input_shape=task_config.input_shape,
                num_classes=task_config.num_classes,
                cnn_channels=arch_config.cnn_channels,
                kernel_size=arch_config.kernel_size,
                padding=arch_config.padding,
                stride=arch_config.stride,
                pooling_type=arch_config.pooling_type,
                pooling_kernel=arch_config.pooling_kernel,
                fc_channels=arch_config.fc_channels,
                activation=arch_config.activation,
                layer_norm=False
            )
        self.agent_config = agent_config
        self.arch_config = arch_config
        self.task_config = task_config
        self.loss_fn = nn.CrossEntropyLoss()
        self.get_optimizer()
        self.buffer_minibatch_size = None

    def get_optimizer(self):
        """(Re)create ``self.optimizer`` over ``self.model.parameters()``.

        Raises:
            ValueError: If ``agent_config.optimizer`` is neither ``'adam'`` nor ``'sgd'``.
        """
        if self.agent_config.optimizer == 'adam':
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.agent_config.lr)
        elif self.agent_config.optimizer == 'sgd':
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.agent_config.lr)
        else:
            raise ValueError(f"Unsupported optimizer: {self.agent_config.optimizer}")

    def compute_loss(self, x, y):
        """Return ``(cross_entropy_loss, logits)`` for inputs ``x`` and labels ``y``."""
        logits = self.model(x)
        loss = self.loss_fn(logits, y)
        return loss, logits

    def predict(self, x):
        """Return the model's logits for ``x``."""
        return self.model(x)

    def forward(self, x):
        """Same as ``predict``."""
        return self.predict(x)

    @staticmethod
    def _is_tracked_parameter(name, param):
        """Return True for trainable parameters that the gradient metrics should cover."""
        return (param.requires_grad
                and 'layer_norm' not in name
                and 'init_params' not in name
                and 'original_last_layer_params' not in name)

    def _gradient_metrics(self):
        """Collect per-parameter gradient/weight statistics after a backward pass.

        For every tracked parameter this records the weight norm and, when a gradient
        exists, the gradient norm and the fraction of exactly-zero gradient entries.
        Parameters with ``grad is None`` (not reached by the loss) get no gradient
        entries, where the old code crashed in ``torch.norm``.

        For non-bias parameters whose top-level module name is a key of
        ``self.model.activations``, it also logs a ``wandb.Histogram`` of the absolute
        activations divided by their mean (raw absolute values when the mean is 0).
        """
        metrics = {}
        for name, param in self.model.named_parameters():
            if not self._is_tracked_parameter(name, param):
                continue
            metrics[f'agent/{name}-weight-magnitude'] = torch.norm(param.data)
            if param.grad is not None:
                metrics[f'agent/{name}-grad-magnitude'] = torch.norm(param.grad)
                metrics[f'agent/{name}-grad-frac-zero'] = torch.mean((param.grad == 0).float()).item()
            if 'bias' in name:
                continue
            layer_name = name.split(".")[0]
            if layer_name not in self.model.activations:
                continue
            abs_act = self.model.activations[layer_name].abs()
            mean_act = torch.mean(abs_act)
            scaled = abs_act / mean_act if mean_act > 0 else abs_act
            metrics[f'agent/{layer_name}-dormant'] = wandb.Histogram(scaled.detach().cpu().numpy())
        return metrics

    def step(self, x, y):
        """Take one optimizer step on ``(x, y)``.

        Returns:
            ``(logits.detach(), metrics)``. ``metrics`` holds ``'curr_train_loss'`` and
            the per-parameter entries from ``_gradient_metrics``.
        """
        loss, logits = self.compute_loss(x, y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        metrics = {'curr_train_loss': loss.detach(),
                   **self._gradient_metrics()}
        return logits.detach(), metrics

    def compute_activation_statistics(self, batch):
        """Compute feature-rank, dead-neuron and weight-norm statistics on ``batch``.

        Runs a no-grad forward pass, then reads ``self.model.activations``. Entries
        whose name contains ``'conv'`` are flattened to ``(len(batch), -1)``; other
        entries are used as-is, presumably already ``(batch, features)`` (TODO confirm).
        For each layer it reports:

        * ``srank``: the number of singular values needed to reach 99% of their sum
          (delta = 0.01, page 3 of https://arxiv.org/pdf/2010.14498.pdf);
        * effective rank: ``exp`` of the entropy of the normalised singular values;
        * the fraction of columns that are zero for every example ("dead" neurons).

        The averages over layers are added only when at least one layer was recorded,
        which avoids a ZeroDivisionError. The model's ``compute_l1_norm`` and
        ``compute_l2_norm`` results are merged in, plus ``input_layer_norms()`` when the
        model provides it.
        """
        with torch.no_grad():
            self.model(batch)

        srank_dict = {}
        effective_rank_dict = {}
        dead_neurons_dict = {}

        delta = 0.01
        num_layers = 0
        total_effective_rank = 0.
        total_srank = 0.
        num_dead_neurons = 0.
        total_neurons = 0.
        for layer_name, activations in self.model.activations.items():
            if 'conv' in layer_name:
                activation_matrix = activations.reshape(len(batch), -1)
            else:
                activation_matrix = activations

            singular_values = torch.linalg.svdvals(activation_matrix)
            sv_total = torch.sum(singular_values)

            cumulative_fraction = torch.cumsum(singular_values, dim=-1) / sv_total
            srank = len(cumulative_fraction[cumulative_fraction < 1 - delta])
            srank_dict[f'model_feature_srank/{layer_name}'] = srank
            total_srank += srank

            dist = singular_values / sv_total
            dist = dist[dist > 0]
            entropy = -1. * torch.sum(dist * torch.log(dist))
            effective_rank = torch.exp(entropy).detach().item()
            effective_rank_dict[f'model_effective_feature_rank/{layer_name}'] = effective_rank
            total_effective_rank += effective_rank
            num_layers += 1

            # A column that is zero for every example in the batch is a dead neuron.
            num_neurons = activation_matrix.shape[1]
            total_neurons += num_neurons
            num_zero_columns = torch.sum(torch.all(activation_matrix == 0, dim=0)).detach().item()
            num_dead_neurons += num_zero_columns
            dead_neurons_dict[f'model_dead_neurons_fraction/{layer_name}'] = num_zero_columns / float(num_neurons)

        if num_layers > 0:
            srank_dict['model_feature_srank/avg_srank'] = total_srank / float(num_layers)
            effective_rank_dict['model_effective_feature_rank/avg_effective_rank'
                                ] = total_effective_rank / float(num_layers)
        if total_neurons > 0:
            dead_neurons_dict['model_dead_neurons_fraction/fraction_dead_neurons'
                              ] = num_dead_neurons / float(total_neurons)

        l1_norm_dict = self.model.compute_l1_norm()
        l2_norm_dict = self.model.compute_l2_norm()
        # Best effort: not every model implements input_layer_norms().
        try:
            input_layer_norm_dict = self.model.input_layer_norms()
        except Exception:
            input_layer_norm_dict = None

        activation_statistics_dict = {
            **srank_dict,
            **effective_rank_dict,
            **dead_neurons_dict,
            **l1_norm_dict,
            **l2_norm_dict,
        }
        if input_layer_norm_dict is not None:
            activation_statistics_dict.update(input_layer_norm_dict)
        return activation_statistics_dict


class L2InitAgent(BaseAgent):
    """Agent regularised toward its initial parameter values (L2-Init).

    The loss is cross-entropy plus ``0.5 * l2_weight * sum ||theta - theta_init||^2``
    over the regularised parameters (see ``_is_regularized``). ``theta_init`` is the
    snapshot taken at construction, or, when ``agent_config.sample_init_values`` is set,
    freshly sampled values from ``self.model_constructor()`` on every call.
    """

    def __init__(self, agent_config, arch_config, task_config, device):
        super(L2InitAgent, self).__init__(agent_config, arch_config, task_config)
        self.l2_weight = agent_config.l2_weight
        self.sample_init_values = agent_config.sample_init_values
        self.device = device

        # Snapshot of the initial values of every regularised parameter, keyed by name.
        self.init_params_dict = {
            name: param.data.clone().detach()
            for name, param in self.model.named_parameters()
            if self._is_regularized(name, param)
        }

    @staticmethod
    def _is_regularized(name, param):
        """Return True for trainable parameters included in the L2-Init penalty."""
        return (param.requires_grad
                and 'layer_norm' not in name
                and 'init_params' not in name
                and 'original_last_layer_params' not in name)

    def compute_loss(self, x, y):
        """Return ``(cross_entropy + L2-Init penalty, logits)``."""
        if self.sample_init_values:
            # NOTE(review): ``model_constructor`` is not defined on this class or on
            # BaseAgent in the visible code; this branch raises AttributeError unless
            # it is provided elsewhere.
            init_model_named_params_resampled = self.model_constructor().state_dict()

        logits = self.model(x)
        loss = self.loss_fn(logits, y)

        l2_loss = 0
        for name, param in self.model.named_parameters():
            if not self._is_regularized(name, param):
                continue

            if self.sample_init_values:
                init_param = init_model_named_params_resampled[name].detach()
            else:
                init_param = self.init_params_dict[name].detach()
            # Move both kinds of reference value onto the parameter's own device. The
            # resampled values were previously left on whatever device the fresh model
            # used, which caused a device mismatch.
            init_param = init_param.to(param.device)

            l2_loss += torch.sum((param - init_param) ** 2)

        loss = loss + self.l2_weight * 0.5 * l2_loss
        return loss, logits
    
class L2Agent(BaseAgent):
    """Agent whose loss adds an L2 penalty on the parameters to cross-entropy.

    The penalty is ``l2_weight * 0.5 * sum(theta ** 2)``. It covers every trainable
    parameter whose name does not contain ``'layer_norm'``, ``'init_params'`` or
    ``'original_last_layer_params'``.
    """

    def __init__(self, agent_config, arch_config, task_config):
        super(L2Agent, self).__init__(agent_config, arch_config, task_config)
        # Coefficient applied to 0.5 * (sum of squared parameters).
        self.l2_weight = agent_config.l2_weight

    def compute_loss(self, x, y):
        """Return ``(cross_entropy + L2 penalty, logits)``."""
        logits = self.model(x)
        loss = self.loss_fn(logits, y)

        excluded_tags = ('layer_norm', 'init_params', 'original_last_layer_params')
        squared_norm = 0.0
        for param_name, weight in self.model.named_parameters():
            is_excluded = any(tag in param_name for tag in excluded_tags)
            if weight.requires_grad and not is_excluded:
                squared_norm += weight.pow(2).sum()

        loss += self.l2_weight * 0.5 * squared_norm
        return loss, logits

class LayerNormAgent(BaseAgent):
    """Agent whose ``MixNormal`` model is built with ``layer_norm=True``.

    The model built by ``BaseAgent.__init__`` is replaced. The optimizer is then
    rebuilt, because the one created in ``BaseAgent.__init__`` tracks the discarded
    model's parameters.
    """

    def __init__(self, agent_config, arch_config, task_config):
        super(LayerNormAgent, self).__init__(agent_config, arch_config, task_config)
        self.model = MixNormal(
            input_type=arch_config.input_type,
            input_shape=task_config.input_shape,
            num_classes=task_config.num_classes,
            cnn_channels=arch_config.cnn_channels,
            kernel_size=arch_config.kernel_size,
            padding=arch_config.padding,
            stride=arch_config.stride,
            pooling_type=arch_config.pooling_type,
            pooling_kernel=arch_config.pooling_kernel,
            fc_channels=arch_config.fc_channels,
            activation=arch_config.activation,
            layer_norm=True
        )

        # Copy of the model as initialised.
        self.init_model = copy.deepcopy(self.model)
        # Point the optimizer at the replacement model's parameters (as HATAgent does).
        self.get_optimizer()

    def forward(self, x):
        """Return ``(logits, info)`` with an empty ``info`` dict.

        Unlike ``BaseAgent.forward``, which returns the logits alone.
        """
        logits = self.model(x)
        info = {}
        return logits, info


class HATAgent(BaseAgent):
    """Agent using HAT task masks on a ``MixNormal`` backbone, with one linear head per task.

    Every forward pass wraps the input in a ``HATPayload`` that carries
    ``current_task_id`` and a mask scale: ``current_mask_scale`` in training mode,
    ``max_mask_scale`` otherwise. The backbone output is then routed through
    ``heads[current_task_id]``.

    NOTE(review): ``HATConfig`` and ``HATPayload`` are used below, but their import
    (``from hat import HATPayload, HATConfig``) is commented out at the top of this
    file. Constructing or running this agent raises ``NameError`` unless those names
    are provided some other way.
    """

    def __init__(self, agent_config, arch_config, task_config):
        super(HATAgent, self).__init__(agent_config, arch_config, task_config)
        hat_config = HATConfig(num_tasks=task_config.num_tasks)
        # Replace the backbone built by BaseAgent with a HAT-enabled MixNormal.
        self.model = MixNormal(
            input_type=arch_config.input_type,
            input_shape=task_config.input_shape,
            num_classes=task_config.num_classes,
            cnn_channels=arch_config.cnn_channels,
            kernel_size=arch_config.kernel_size,
            padding=arch_config.padding,
            stride=arch_config.stride,
            pooling_type=arch_config.pooling_type,
            pooling_kernel=arch_config.pooling_kernel,
            fc_channels=arch_config.fc_channels,
            activation=arch_config.activation,
            layer_norm=False,
            use_hat=True,
            hat_config=hat_config,
            num_tasks=task_config.num_tasks,
            class_inc=task_config.class_inc
        )
        self.max_mask_scale = agent_config.s_max
        self.current_mask_scale = agent_config.s_max
        self.current_task_id = 0
        self.init_model = copy.deepcopy(self.model)

        # Head input width: ``model.last_filter_output`` when the model has a single fc
        # layer, otherwise the width of the penultimate fc layer.
        head_in_features = (self.model.last_filter_output
                            if len(self.model.fc_channels) == 1
                            else self.model.fc_channels[-2])
        # Class-incremental: each task head covers num_classes // num_tasks classes;
        # otherwise every head predicts over all classes.
        head_out_features = (task_config.num_classes // task_config.num_tasks
                             if task_config.class_inc
                             else task_config.num_classes)
        self.heads = nn.ModuleList(
            [nn.Linear(head_in_features, head_out_features)
             for _ in range(task_config.num_tasks)])
        # Rebuild the optimizer now that the new backbone and the heads exist.
        self.get_optimizer()

    def get_optimizer(self):
        """(Re)create ``self.optimizer`` over ``self.parameters()``.

        Unlike ``BaseAgent.get_optimizer`` this covers every parameter registered on
        the agent, including the per-task heads, not just ``self.model``.

        Raises:
            ValueError: If ``agent_config.optimizer`` is neither ``'adam'`` nor ``'sgd'``.
        """
        if self.agent_config.optimizer == 'adam':
            self.optimizer = torch.optim.Adam(self.parameters(), lr=self.agent_config.lr)
        elif self.agent_config.optimizer == 'sgd':
            self.optimizer = torch.optim.SGD(self.parameters(), lr=self.agent_config.lr)
        else:
            raise ValueError(f"Unsupported optimizer: {self.agent_config.optimizer}")

    def compute_activation_statistics(self, batch):
        """Activation statistics are disabled for this agent; always returns ``{}``."""
        return {}

    def step(self, x, y):
        """Take one optimizer step on ``(x, y)``.

        Per-parameter gradient metrics are not collected for this agent.

        Returns:
            ``(logits.detach(), {'curr_train_loss': loss.detach()})``.
        """
        loss, logits = self.compute_loss(x, y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        metrics = {'curr_train_loss': loss.detach()}
        return logits.detach(), metrics

    def _payload(self, x):
        """Wrap ``x`` in a ``HATPayload`` for the current task and the mode-dependent mask scale."""
        mask_scale = self.current_mask_scale if self.training else self.max_mask_scale
        return HATPayload(x, task_id=self.current_task_id, mask_scale=mask_scale)

    def _task_logits(self, x):
        """Run the masked backbone and the current task's head, and unwrap a ``HATPayload`` result.

        TODO: the original code sketched routing through the head only when
        ``task_config.class_inc`` is set; currently the head is always used.
        """
        logits = self.heads[self.current_task_id](self.model(self._payload(x)))
        if isinstance(logits, HATPayload):
            logits = logits.data
        return logits

    def forward(self, x):
        """Return ``(logits, info)`` for the current task, with an empty ``info`` dict."""
        logits = self._task_logits(x)
        info = {}
        return logits, info

    def compute_loss(self, x, y):
        """Return ``(cross_entropy_loss, logits)`` for the current task."""
        logits = self._task_logits(x)
        loss = self.loss_fn(logits, y)
        return loss, logits

    def predict(self, x):
        """Return the logits for the current task."""
        return self._task_logits(x)

class CReLUAgent(BaseAgent):
    """Agent whose model uses the ``'crelu'`` activation.

    The model built by ``BaseAgent.__init__`` is replaced: a ``MixNormalResNet`` for
    ``arch_name == 'RESNET'``, otherwise a ``MixNormal``, both with
    ``activation='crelu'``. The optimizer is then rebuilt so it tracks the
    replacement's parameters.
    """

    def __init__(self, agent_config, arch_config, task_config):
        super(CReLUAgent, self).__init__(agent_config, arch_config, task_config)
        if arch_config.arch_name == 'RESNET':
            assert agent_config.agent_type not in ['LayerNormAgent', 'HATAgent']
            self.model = MixNormalResNet(
                    input_shape=task_config.input_shape,
                    num_classes=task_config.num_classes,
                    activation='crelu',
                    load_pretrained=arch_config.load_pretrained,
                    dropout_percentage=arch_config.dropout_percentage,
                    disable_bn=arch_config.disable_bn,
                    agent_type=agent_config.agent_type)
        else:
            self.model = MixNormal(
                input_type=arch_config.input_type,
                input_shape=task_config.input_shape,
                num_classes=task_config.num_classes,
                cnn_channels=arch_config.cnn_channels,
                kernel_size=arch_config.kernel_size,
                padding=arch_config.padding,
                stride=arch_config.stride,
                pooling_type=arch_config.pooling_type,
                pooling_kernel=arch_config.pooling_kernel,
                fc_channels=arch_config.fc_channels,
                activation='crelu',
                layer_norm=False
            )

        # Copy of the model as initialised.
        self.init_model = copy.deepcopy(self.model)
        # The optimizer from BaseAgent.__init__ tracks the discarded model; rebuild it.
        self.get_optimizer()

    def forward(self, x):
        """Return ``(logits, info)`` with an empty ``info`` dict.

        Unlike ``BaseAgent.forward``, which returns the logits alone.
        """
        logits = self.model(x)
        info = {}
        return logits, info

class PReLUAgent(BaseAgent):
    """Agent whose model uses the ``'prelu'`` activation.

    The model built by ``BaseAgent.__init__`` is replaced: a ``MixNormalResNet`` for
    ``arch_name == 'RESNET'``, otherwise a ``MixNormal``, both with
    ``activation='prelu'``. The optimizer is then rebuilt so it tracks the
    replacement's parameters.
    """

    def __init__(self, agent_config, arch_config, task_config):
        super(PReLUAgent, self).__init__(agent_config, arch_config, task_config)
        if arch_config.arch_name == 'RESNET':
            assert agent_config.agent_type not in ['LayerNormAgent', 'HATAgent']
            # NOTE(review): unlike BaseAgent and CReLUAgent, ``agent_type`` is not passed
            # to MixNormalResNet here; confirm whether that is intentional.
            self.model = MixNormalResNet(
                    input_shape=task_config.input_shape,
                    num_classes=task_config.num_classes,
                    activation='prelu',
                    load_pretrained=arch_config.load_pretrained,
                    dropout_percentage=arch_config.dropout_percentage,
                    disable_bn=arch_config.disable_bn)
        else:
            self.model = MixNormal(
                input_type=arch_config.input_type,
                input_shape=task_config.input_shape,
                num_classes=task_config.num_classes,
                cnn_channels=arch_config.cnn_channels,
                kernel_size=arch_config.kernel_size,
                padding=arch_config.padding,
                stride=arch_config.stride,
                pooling_type=arch_config.pooling_type,
                pooling_kernel=arch_config.pooling_kernel,
                fc_channels=arch_config.fc_channels,
                activation='prelu',
                layer_norm=False
            )

        # Copy of the model as initialised.
        self.init_model = copy.deepcopy(self.model)
        # The optimizer from BaseAgent.__init__ tracks the discarded model; rebuild it.
        # This matters doubly here: PReLU slopes are learnable parameters of the new model.
        self.get_optimizer()

    def forward(self, x):
        """Return ``(logits, info)`` with an empty ``info`` dict.

        Unlike ``BaseAgent.forward``, which returns the logits alone.
        """
        logits = self.model(x)
        info = {}
        return logits, info
    
class ViTAgent(BaseAgent):
    """Agent whose model is a ``ViTForClassfication`` vision transformer.

    The model built by ``BaseAgent.__init__`` is replaced by a ViT configured from
    ``task_config``, and the optimizer is rebuilt so it trains the ViT's parameters.
    ``input_shape[0]`` is used as the channel count and ``input_shape[1]`` as the image
    size, i.e. this assumes (channels, height, width) inputs — TODO confirm.
    """

    def __init__(self, agent_config, arch_config, task_config, device):
        # NOTE(review): ``device`` is accepted but not used by this agent.
        super(ViTAgent, self).__init__(agent_config, arch_config, task_config)
        hidden_size = 512
        config = {
            "patch_size": 4,  # e.g. a 32x32 image -> 8x8 grid of patches
            "hidden_size": hidden_size,
            "num_hidden_layers": 4,
            "num_attention_heads": 4,
            "intermediate_size": 4 * hidden_size,
            "hidden_dropout_prob": 0.0,
            "attention_probs_dropout_prob": 0.0,
            "initializer_range": 0.02,
            "image_size": task_config.input_shape[1],
            "num_classes": task_config.num_classes,
            "num_channels": task_config.input_shape[0],
            "qkv_bias": True,
            "use_faster_attention": True,
        }
        self.model = ViTForClassfication(config=config)
        # The optimizer from BaseAgent.__init__ tracks the discarded model; rebuild it.
        self.get_optimizer()

    def forward(self, x):
        """Return ``(logits, info)`` with an empty ``info`` dict."""
        logits = self.model(x)
        info = {}
        return logits, info

    def compute_activation_statistics(self, batch):
        """Activation statistics are not computed for the ViT; always returns ``{}``."""
        return {}

    def step(self, x, y):
        """Take one optimizer step on ``(x, y)``.

        Per-parameter gradient metrics are not collected for this agent.

        Returns:
            ``(logits.detach(), {'curr_train_loss': loss.detach()})``.
        """
        loss, logits = self.compute_loss(x, y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        metrics = {'curr_train_loss': loss.detach()}
        return logits.detach(), metrics

class DeepFourierAgent(BaseAgent):
    """Agent whose network is a ``MixNormal`` using the ``'deepfourier'`` activation.

    A deep copy of the freshly initialised network is kept in
    ``self.init_model``.
    """

    def __init__(self, agent_config, arch_config, task_config):
        super().__init__(agent_config, arch_config, task_config)
        # Replace the model built by the base class with a Deep Fourier MixNormal
        # that takes its architecture settings from ``arch_config``.
        # NOTE(review): unlike NeuroSyncAgent/CBPAgent, no ``get_optimizer()`` call
        # follows this replacement; confirm the optimizer is (re)built elsewhere.
        network_kwargs = dict(
            input_type=arch_config.input_type,
            input_shape=task_config.input_shape,
            num_classes=task_config.num_classes,
            cnn_channels=arch_config.cnn_channels,
            kernel_size=arch_config.kernel_size,
            padding=arch_config.padding,
            stride=arch_config.stride,
            pooling_type=arch_config.pooling_type,
            pooling_kernel=arch_config.pooling_kernel,
            fc_channels=arch_config.fc_channels,
            activation='deepfourier',
            layer_norm=False,
        )
        self.model = MixNormal(**network_kwargs)

        # Snapshot of the network as initialised.
        self.init_model = copy.deepcopy(self.model)

    def forward(self, x):
        """Return ``(logits, info)``; ``info`` is always an empty dict."""
        info = dict()
        return self.model(x), info


class ReDoAgent(BaseAgent):
    """Agent that periodically recycles dormant neurons (ReDo).

    Every ``recycle_period`` steps (never at step 0) each layer tracked in
    ``self.model.activations_for_redo`` is scored. A neuron whose mean absolute
    activation, divided by the layer average of that quantity, is
    ``<= recycle_threshold`` has its incoming weights and bias reset to the
    values captured at construction and its outgoing weights set to zero.
    """

    def __init__(self, agent_config, arch_config, task_config):
        super().__init__(agent_config, arch_config, task_config)

        self.recycle_period = agent_config.recycle_period
        self.recycle_threshold = agent_config.recycle_threshold
        self.step_count = 0
        # Snapshot of the initial parameters, used as reset targets for incoming
        # weights. Parameters not requiring grad and layer-norm/bookkeeping
        # parameters are skipped.
        self.init_params = []
        self.init_params_dict = {}
        for name, param in self.model.named_parameters():
            if not param.requires_grad or 'layer_norm' in name or \
                                          'init_params' in name or \
                                          'original_last_layer_params' in name:
                continue
            self.init_params.append(param.data.clone().detach())
            self.init_params_dict[name] = param.data.clone().detach()

    def recycle_neurons(self):
        """Reset dormant neurons in every layer tracked for ReDo.

        Each value of ``self.model.activations_for_redo`` is read as
        ``(activations, current_layer_type, next_layer_type)``. The i-th tracked
        layer is taken to be ``self.model.layer_names[i]`` and its consumer
        ``self.model.layer_names[i + 1]``.
        """
        layer_names = self.model.layer_names

        for layer_id, activation_tuple in enumerate(self.model.activations_for_redo.values()):
            activation_set = activation_tuple[0]
            current_layer_type = activation_tuple[1]
            next_layer_type = activation_tuple[2]

            with torch.no_grad():
                # Expected absolute activation over the batch (dim 0) ...
                expected_activation = torch.mean(torch.abs(activation_set), dim=0)
                if current_layer_type == 'conv':
                    # ... and, for conv layers, over the two trailing feature-map dims.
                    expected_activation = torch.mean(expected_activation, dim=(-2, -1))

                average_expected_activation = torch.mean(expected_activation)
                if average_expected_activation == 0:
                    # Every neuron is silent. 0/0 would give NaN scores, which never
                    # compare <= threshold, so the layer would never be recycled.
                    # Treat every neuron as dormant instead.
                    neuron_scores = torch.zeros_like(expected_activation)
                else:
                    neuron_scores = expected_activation / average_expected_activation

                dormant_indices = torch.nonzero(
                    neuron_scores <= self.recycle_threshold, as_tuple=True
                )[0].tolist()

            if not dormant_indices:
                continue

            # Loop-invariant lookups for this layer.
            layer_name = layer_names[layer_id]
            incoming_layer = getattr(self.model, layer_name)
            outgoing_layer = getattr(self.model, layer_names[layer_id + 1])
            initial_weights = self.init_params_dict[f'{layer_name}.weight']
            initial_biases = self.init_params_dict[f'{layer_name}.bias']

            if current_layer_type == 'conv' and next_layer_type == 'fc':
                # The fc layer consumes the flattened conv output: each channel owns
                # a contiguous run of ``features_per_channel`` input columns
                # (assumes channel-major flattening — TODO confirm in the model).
                num_channels = activation_set.shape[1]
                features_per_channel = int(outgoing_layer.weight.shape[-1] / num_channels)
            else:
                features_per_channel = None

            with torch.no_grad():
                for neuron_index in dormant_indices:
                    # Reset incoming weights and bias to their initial values.
                    incoming_layer.weight.data[neuron_index].copy_(initial_weights.data[neuron_index])
                    incoming_layer.bias.data[neuron_index].copy_(initial_biases.data[neuron_index])

                    # Zero the outgoing weights.
                    if features_per_channel is None:
                        outgoing_layer.weight.data[:, neuron_index] = 0.
                    else:
                        start = features_per_channel * neuron_index
                        outgoing_layer.weight.data[:, start:start + features_per_channel] = 0.

    def step(self, x, y):
        """Recycle dormant neurons when due, then take one base-agent step."""
        if self.step_count % self.recycle_period == 0 and self.step_count > 0:
            self.recycle_neurons()

        logits, metrics = super().step(x, y)
        self.step_count += 1
        return logits, metrics


class NeuroSyncAgent(BaseAgent):
    """Agent wrapping a NeuroSync network (plain conv/fc or sparse-ResNet variant).

    The training loss adds two auxiliary terms returned by the model in
    ``info['uniform_loss']`` and ``info['bound_loss']``. Each term is divided by
    an exponential moving average of its own raw value, scaled by the detached
    task loss, and weighted by ``alpha_uniform`` / ``alpha_bound``.
    """

    def __init__(self, agent_config, arch_config, task_config, device, track_params=None):
        super(NeuroSyncAgent, self).__init__(agent_config, arch_config, task_config)
        if agent_config.use_resnet:
            assert task_config.benchmark in ['shuffle_mini_imagenet', 'shuffle_cifar10', 'random_label_cifar10',
                                              'continual_cifar100', 'new_continual_imagenet'] # TODO
            self.model = Single_NEW_MIX_Sample_Based_Sparse_Res(
                input_type=arch_config.input_type,
                controller_input_type='conv',
                input_shape=task_config.input_shape,
                num_classes=task_config.num_classes,
                transformer_embed_dim=agent_config.conv_embed_dim,
                learnable_dim=agent_config.learnable_dim,
                neuro_sync_conv_kernel_size=agent_config.conv_neuro_sync_conv_kernel_size,
                k=agent_config.k,
                num_heads=agent_config.num_heads,
                ema_decay=agent_config.ema_decay,
                num_freq=agent_config.num_freq,
                prelu_transformer=agent_config.prelu_transformer,
                WC=agent_config.WC,
                SM=agent_config.SM,
                AL=agent_config.AL,
                ARM=agent_config.ARM,
                SM_detach=agent_config.SM_detach,
                attention_neuro_sync=agent_config.attention_neuro_sync,
                mask_portion=agent_config.mask_portion,  # portion of neurons per layer to KEEP
                decoder_layers=None,  # per the original note: None means "use k"
                load_pretrained=arch_config.load_pretrained,
                dropout_percentage=arch_config.dropout_percentage,
                use_query_self_attn=agent_config.use_query_self_attn,
                disable_bn=arch_config.disable_bn,
                device=device
            )
        else:
            # NOTE(review): transformer_embed_dim switches on arch_config.input_type while
            # neuro_sync_conv_kernel_size switches on task_config.input_type — confirm intended.
            self.model = Single_NEW_MIX_Sample_Based(
                input_type=arch_config.input_type,
                controller_input_type=task_config.input_type,
                input_shape=task_config.input_shape,
                num_classes=task_config.num_classes,
                cnn_channels=arch_config.cnn_channels,
                kernel_size=arch_config.kernel_size,
                padding=arch_config.padding,
                stride=arch_config.stride,
                pooling_type=arch_config.pooling_type,
                pooling_kernel=arch_config.pooling_kernel,
                fc_channels=arch_config.fc_channels,
                transformer_embed_dim=agent_config.conv_embed_dim if arch_config.input_type == 'conv' else agent_config.fc_embed_dim,
                learnable_dim=agent_config.learnable_dim,
                neuro_sync_conv_kernel_size=agent_config.conv_neuro_sync_conv_kernel_size if task_config.input_type == 'conv' else agent_config.fc_neuro_sync_conv_kernel_size,
                k=agent_config.k,
                num_heads=agent_config.num_heads,
                ema_decay=agent_config.ema_decay,
                num_freq=agent_config.num_freq,
                track_params=agent_config.track_params if track_params is None else track_params,
                WC=agent_config.WC,
                SR=agent_config.SR,
                SM=agent_config.SM,
                AL=agent_config.AL,
                ARM=agent_config.ARM,
                SM_detach=agent_config.SM_detach,
                simplified=agent_config.simplified,
                prelu_transformer=agent_config.prelu_transformer,
                attention_neuro_sync=agent_config.attention_neuro_sync,
                # NOTE(review): img_to_controller reuses attention_neuro_sync — confirm this
                # is not meant to read a dedicated config field.
                img_to_controller=agent_config.attention_neuro_sync,
                global_modulate=agent_config.global_modulate,
                neu_state_to_controller=agent_config.neu_state_to_controller,
                device=device
            )
        self.alpha_bound = float(agent_config.alpha_bound)
        self.alpha_uniform = float(agent_config.alpha_uniform)
        # Running EMAs of the raw auxiliary losses, used to normalise their scale.
        self.uniform_ema = 1.0
        self.bound_ema = 1.0
        self.momentum = float(agent_config.momentum)
        self.agent_config = agent_config
        self.arch_config = arch_config
        self.task_config = task_config
        # Build the optimizer for the model assigned above.
        self.get_optimizer()

    def load_consolidated_weights(self, wc_weights):
        """Forward ``wc_weights`` to the model's ``load_consolidated_weights``."""
        self.model.load_consolidated_weights(wc_weights)

    def reset_neuro_sync(self):
        """Recreate the model's controller and move the model back to its device."""
        self.model.create_controller()
        self.model.to(self.model.device)

    def compute_loss(self, x, y):
        """Return ``(total_loss, logits, info)`` for the batch ``(x, y)``.

        The auxiliary losses are normalised by their EMAs (updated here with
        ``self.momentum``) and scaled by the detached task loss, so their
        magnitude tracks the task loss.
        """
        logits, info = self.model(x)
        uniform_loss = info['uniform_loss']
        bound_loss = info['bound_loss']
        loss = self.loss_fn(logits, y)
        self.uniform_ema = self.momentum * self.uniform_ema + (1 - self.momentum) * uniform_loss.item()
        self.bound_ema = self.momentum * self.bound_ema + (1 - self.momentum) * bound_loss.item()

        uniform_loss = self.alpha_uniform * loss.detach() * (uniform_loss / (self.uniform_ema + 1e-8))
        bound_loss = self.alpha_bound * loss.detach() * (bound_loss / (self.bound_ema + 1e-8))
        loss = loss + bound_loss + uniform_loss
        return loss, logits, info

    def predict(self, x):
        """Return only the logits of the model on ``x``."""
        logits, _ = self.model(x)
        return logits

    def step(self, x, y):
        """Take one optimizer step and collect per-parameter diagnostics.

        For every trainable parameter whose name starts with ``conv`` or
        ``fc``, this logs:
        - weight norm and weight histogram;
        - the EMA weight histogram, when ``save_ema``;
        - gradient norm and fraction of zero gradient entries, skipped for
          parameters that received no gradient;
        - for non-bias parameters, a histogram of the layer's absolute
          activations, normalised by their mean when that mean is positive.

        Returns:
            ``(logits.detach(), metrics)``; ``metrics`` also includes ``info``.
        """
        loss, logits, info = self.compute_loss(x, y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        grad_metrics = {}
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and name.startswith(("conv", "fc"))):
                continue
            if self.agent_config.save_ema:
                grad_metrics[f'agent/{name}-ema-weight-histogram'] = wandb.Histogram(self.model._ema_params[name].detach().cpu().numpy()) # type: ignore
            # A parameter that took no part in this step's graph has grad None;
            # torch.norm(None) would raise, so gradient metrics are skipped for it.
            grad = param.grad
            if grad is not None:
                grad_metrics[f'agent/{name}-grad-magnitude'] = torch.norm(grad)
            grad_metrics[f'agent/{name}-weight-magnitude'] = torch.norm(param.data)
            grad_metrics[f'agent/{name}-weight-histogram'] = wandb.Histogram(param.data.detach().cpu().numpy()) # type: ignore
            if grad is not None:
                grad_metrics[f'agent/{name}-grad-frac-zero'] = torch.mean((grad == 0).float()).item()

            if 'bias' in name:
                continue
            layer_key = name.split(".")[0]
            if layer_key not in self.model.activations:
                continue
            abs_activations = self.model.activations[layer_key].abs()
            mean_abs_activation = torch.mean(abs_activations)
            if mean_abs_activation > 0:
                abs_activations = abs_activations / mean_abs_activation
            grad_metrics[f'agent/{layer_key}-dormant'] = wandb.Histogram(abs_activations.detach().cpu().numpy())

        metrics = {'curr_train_loss': loss.detach(),
                   **grad_metrics,
                   **info}

        return logits.detach(), metrics

    def compute_activation_statistics(self, batch):
        """Compute representation statistics of ``self.model`` on ``batch``.

        ``batch`` is pushed through the model in chunks of 256, with gradients
        disabled and the agent in eval mode. The tensors left in
        ``self.model.activations`` after each chunk are concatenated per layer.

        Per recorded layer this reports:
        - the srank and the effective rank of the (examples x features)
          activation matrix;
        - the fraction of features that are exactly zero for every example;
        - a histogram of the per-feature mean activation.

        Layer averages are added, together with the model's
        ``compute_l1_norm()`` / ``compute_l2_norm()`` dicts and, when that call
        succeeds, ``input_layer_norms()``. The agent is put back into train mode
        afterwards, even if an error is raised.
        """
        chunk_size = 256
        self.eval()
        try:
            collected = {}
            with torch.no_grad():
                for start in range(0, batch.shape[0], chunk_size):
                    # The last chunk may be smaller than chunk_size.
                    self.model(batch[start:start + chunk_size])
                    for layer_name, activations in self.model.activations.items():
                        collected.setdefault(layer_name, []).append(activations)
            local_activations = {
                layer_name: torch.cat(chunks, dim=0)
                for layer_name, chunks in collected.items()
            }

            srank_dict = {}
            effective_rank_dict = {}
            dead_neurons_dict = {}

            num_layers = 0.
            total_effective_rank = 0.
            total_srank = 0.
            num_dead_neurons = 0.
            total_neurons = 0.
            num_examples = len(batch)
            delta = 0.01

            for layer_name, activations in local_activations.items():
                # Conv activations are flattened to one row per example.
                if 'conv' in layer_name:
                    activation_matrix = activations.reshape(num_examples, -1)
                else:
                    activation_matrix = activations

                singular_values = torch.linalg.svdvals(activation_matrix)
                cumulative_fraction = torch.cumsum(singular_values, dim=-1) / torch.sum(singular_values)

                # srank, page 3 of https://arxiv.org/pdf/2010.14498.pdf
                # NOTE(review): this counts prefixes whose cumulative fraction is still
                # below 1 - delta, which is one less than min{k : fraction_k >= 1 - delta}.
                # Kept unchanged so values stay comparable with previously logged runs.
                srank = len(cumulative_fraction[cumulative_fraction < 1 - delta])
                srank_dict[f'model_feature_srank/{layer_name}'] = srank
                total_srank += srank

                # Effective rank = exp(entropy of the normalised singular values).
                dist = singular_values / torch.sum(singular_values)
                dist = dist[dist > 0]
                entropy = -1. * torch.sum(dist * torch.log(dist))
                effective_rank = torch.exp(entropy).detach().item()
                effective_rank_dict[f'model_effective_feature_rank/{layer_name}'] = effective_rank
                total_effective_rank += effective_rank
                num_layers += 1

                # A feature counts as dead when its column is zero for every example.
                num_neurons = activation_matrix.shape[1]
                total_neurons += num_neurons
                is_zero_column = torch.all(activation_matrix == 0, dim=0)
                num_zero_columns = torch.sum(is_zero_column).detach().item()
                num_dead_neurons += num_zero_columns

                dead_neurons_dict[f'model_dead_neurons_fraction/{layer_name}'] = num_zero_columns / float(num_neurons)
                dead_neurons_dict[f'activations/{layer_name}'] = wandb.Histogram(torch.mean(activation_matrix, dim=0).detach().cpu().numpy()) # type: ignore

            srank_dict['model_feature_srank/avg_srank'] = total_srank / float(num_layers)
            effective_rank_dict['model_effective_feature_rank/avg_effective_rank'] = total_effective_rank / float(num_layers)
            dead_neurons_dict['model_dead_neurons_fraction/fraction_dead_neurons'] = num_dead_neurons / float(total_neurons)

            l1_norm_dict = self.model.compute_l1_norm()
            l2_norm_dict = self.model.compute_l2_norm()
            # Best effort: not every model is expected to provide input_layer_norms().
            try:
                input_layer_norm_dict = self.model.input_layer_norms()
            except Exception:
                input_layer_norm_dict = None

            activation_statistics_dict = {
                **srank_dict,
                **effective_rank_dict,
                **dead_neurons_dict,
                **l1_norm_dict,
                **l2_norm_dict,
            }
            if input_layer_norm_dict is not None:
                activation_statistics_dict.update(input_layer_norm_dict)
        finally:
            self.train()
        return activation_statistics_dict

class CBPAgent(BaseAgent):
    """Continual Backprop agent built around a generate-and-test (GnT) helper.

    After each optimizer step, ``self.gnt.gen_and_test`` is given the
    activations recorded by a forward pass made just before the step.
    """

    def __init__(self, agent_config, arch_config, task_config, device):

        super().__init__(agent_config, arch_config, task_config)
        assert arch_config.activation == 'relu'
        # (Re)build the optimizer with CBPAgent.get_optimizer; the GnT helper
        # below is handed this optimizer.
        self.get_optimizer()

        if arch_config.input_type == 'fc':
            self.gnt = GnT(
                net=self.model.layers,
                hidden_activation='relu',
                opt=self.optimizer,
                replacement_rate=agent_config.replacement_rate,
                decay_rate=agent_config.decay_rate,
                maturity_threshold=agent_config.maturity_threshold,
                util_type=agent_config.util_type,
                device=device,
                loss_func=self.loss_fn,
                init='kaiming',
                accumulate=agent_config.accumulate,
            )
        elif arch_config.input_type == 'conv':
            self.gnt = ConvGnT(
                net=self.model.layers,
                hidden_activation='relu',
                opt=self.optimizer,
                replacement_rate=agent_config.replacement_rate,
                decay_rate=agent_config.decay_rate,
                init='kaiming',
                num_last_filter_outputs=self.model.last_filter_output,
                util_type=agent_config.util_type,
                maturity_threshold=agent_config.maturity_threshold,
                device=device,
            )
        else:
            # Without this, self.gnt would be missing and step() would fail later
            # with an unhelpful AttributeError.
            raise ValueError(
                f"CBPAgent supports input_type 'fc' or 'conv', got {arch_config.input_type!r}"
            )

    def get_optimizer(self):
        """Build ``self.optimizer``: plain SGD for ``'sgd'``, ``AdamGnT`` for ``'adam'``.

        NOTE(review): any other optimizer name leaves ``self.optimizer``
        untouched — confirm that is intended.
        """
        if self.agent_config.optimizer == 'sgd':
            self.optimizer = optim.SGD(self.model.parameters(), lr=self.agent_config.lr)
        elif self.agent_config.optimizer == 'adam':
            self.optimizer = AdamGnT(self.model.parameters(),
                                     lr=self.agent_config.lr,
                                     weight_decay=self.agent_config.weight_decay)

    def step(self, x, y):
        """Record activations, take a base-agent step, then run generate-and-test.

        ``predict`` is called only to refresh ``self.model.activations``; its
        output is discarded. Those pre-update activations are what GnT sees.
        """
        _ = self.predict(x)
        self.previous_features = list(self.model.activations.values())
        logits, metrics = super().step(x, y)
        self.gnt.gen_and_test(features=self.previous_features)
        return logits, metrics

class EWCAgent(BaseAgent):
    """Elastic Weight Consolidation agent.

    The task loss is augmented with
    ``0.5 * ewc_weight * sum(F * (param - star_param) ** 2)``.
    ``star_param`` is the snapshot taken by ``update_star_params``. ``F`` is the
    diagonal Fisher estimate when ``use_fisher`` is set and one has been
    computed; otherwise it is 1, which gives a plain L2 pull.
    """

    def __init__(self, agent_config, arch_config, task_config, device):
        super().__init__(agent_config, arch_config, task_config)
        self.ewc_weight = agent_config.ewc_weight
        self.use_fisher = agent_config.use_fisher
        self.device = device

        self.star_params_dict = {}
        self.fisher = {}

    def update_star_params(self):
        """Snapshot the current trainable parameters as the EWC anchor point.

        Layer-norm and bookkeeping parameters are skipped.
        """
        self.star_params_dict = {}
        for name, param in self.model.named_parameters():
            if not param.requires_grad or 'layer_norm' in name or \
                'init_params' in name or \
                    'original_last_layer_params' in name:
                continue
            self.star_params_dict[name] = param.data.clone().detach()

    def update_fisher(self, xs, ys, batch_size):
        """Estimate the diagonal empirical Fisher information of every parameter.

        For each sample, this takes the gradient of the log-probability the
        model assigns to its true label, squares it, and averages over all
        samples. The result is stored in ``self.fisher``, keyed by parameter
        name.

        Every sample is used, including a final partial batch. Each batch's
        graph is released before the next batch is processed.

        Args:
            xs: inputs, sliced along dim 0.
            ys: integer class labels aligned with ``xs``.
            batch_size: samples per forward pass.

        Raises:
            ValueError: if ``xs`` contains no samples.
        """
        params = list(self.model.parameters())
        squared_grad_sums = [torch.zeros_like(p) for p in params]
        num_samples = 0

        for start in range(0, len(xs), batch_size):
            x = xs[start:start + batch_size].to(self.device)
            y = ys[start:start + batch_size].to(self.device)

            log_probs = F.log_softmax(self.model(x), dim=1)
            rows = torch.arange(log_probs.shape[0], device=log_probs.device)
            sample_log_likelihoods = log_probs[rows, y]

            last_index = sample_log_likelihoods.shape[0] - 1
            for i, log_likelihood in enumerate(sample_log_likelihoods):
                # Keep this batch's graph until its final sample has been differentiated.
                grads = autograd.grad(log_likelihood, params, retain_graph=i < last_index)
                for acc, grad in zip(squared_grad_sums, grads):
                    acc += grad.detach() ** 2
            num_samples += sample_log_likelihoods.shape[0]

        if num_samples == 0:
            raise ValueError("update_fisher needs at least one sample")

        self.fisher = {
            name: (acc / num_samples).detach()
            for (name, _), acc in zip(self.model.named_parameters(), squared_grad_sums)
        }

    def update_params_and_fisher(self, xs, ys, batch_size):
        """Refresh the anchor parameters and, if enabled, the Fisher estimate."""
        self.update_star_params()
        if self.use_fisher:
            self.update_fisher(xs, ys, batch_size)

    def compute_loss(self, x, y):
        """Return ``(task_loss + EWC penalty, logits)``.

        The penalty is 0 until ``update_star_params`` has been called.
        """
        logits = self.model(x)
        loss = self.loss_fn(logits, y)

        ewc_loss = 0
        if len(self.star_params_dict) > 0:
            weight_by_fisher = self.use_fisher and len(self.fisher) > 0
            for name, param in self.model.named_parameters():
                if not param.requires_grad or 'layer_norm' in name or \
                    'init_params' in name or \
                        'original_last_layer_params' in name:
                    continue

                diff = param - self.star_params_dict[name].detach()
                fisher = self.fisher[name] if weight_by_fisher else 1
                ewc_loss = ewc_loss + torch.sum(fisher * (diff ** 2))

        # Out-of-place so autograd never sees an in-place update of loss_fn's output.
        loss = loss + self.ewc_weight * 0.5 * ewc_loss

        return loss, logits
    
class L2InitPlusEWCAgent(BaseAgent):
    """Agent combining L2-to-initialisation with Elastic Weight Consolidation.

    The loss adds two penalties:
    - ``0.5 * l2_weight * sum((param - init_param) ** 2)``, pulling towards
      the parameters captured at construction;
    - ``0.5 * ewc_weight * sum(F * (param - star_param) ** 2)``, pulling
      towards the last ``update_star_params`` snapshot. ``F`` is the Fisher
      diagonal when enabled and computed, else 1.
    """

    def __init__(self, agent_config, arch_config, task_config, device):
        super().__init__(agent_config, arch_config, task_config)
        self.l2_weight = agent_config.l2_weight
        self.ewc_weight = agent_config.ewc_weight
        self.use_fisher = agent_config.use_fisher
        self.device = device

        self.star_params_dict = {}
        self.fisher = {}

        # Snapshot of the initial trainable parameters (layer-norm and
        # bookkeeping parameters are skipped) used by the L2-init penalty.
        self.init_params_dict = {}
        for name, param in self.model.named_parameters():
            if not param.requires_grad or 'layer_norm' in name or \
                'init_params' in name or \
                    'original_last_layer_params' in name:
                continue
            self.init_params_dict[name] = param.data.clone().detach()

    def update_star_params(self):
        """Snapshot the current trainable parameters as the EWC anchor point."""
        self.star_params_dict = {}
        for name, param in self.model.named_parameters():
            if not param.requires_grad or 'layer_norm' in name or \
                'init_params' in name or \
                    'original_last_layer_params' in name:
                continue
            self.star_params_dict[name] = param.data.clone().detach()

    def update_fisher(self, xs, ys, batch_size):
        """Estimate the diagonal empirical Fisher information of every parameter.

        For each sample, this takes the gradient of the log-probability the
        model assigns to its true label, squares it, and averages over all
        samples. The result is stored in ``self.fisher``, keyed by parameter
        name.

        Every sample is used, including a final partial batch. Each batch's
        graph is released before the next batch is processed.

        Args:
            xs: inputs, sliced along dim 0.
            ys: integer class labels aligned with ``xs``.
            batch_size: samples per forward pass.

        Raises:
            ValueError: if ``xs`` contains no samples.
        """
        params = list(self.model.parameters())
        squared_grad_sums = [torch.zeros_like(p) for p in params]
        num_samples = 0

        for start in range(0, len(xs), batch_size):
            x = xs[start:start + batch_size].to(self.device)
            y = ys[start:start + batch_size].to(self.device)

            log_probs = F.log_softmax(self.model(x), dim=1)
            rows = torch.arange(log_probs.shape[0], device=log_probs.device)
            sample_log_likelihoods = log_probs[rows, y]

            last_index = sample_log_likelihoods.shape[0] - 1
            for i, log_likelihood in enumerate(sample_log_likelihoods):
                # Keep this batch's graph until its final sample has been differentiated.
                grads = autograd.grad(log_likelihood, params, retain_graph=i < last_index)
                for acc, grad in zip(squared_grad_sums, grads):
                    acc += grad.detach() ** 2
            num_samples += sample_log_likelihoods.shape[0]

        if num_samples == 0:
            raise ValueError("update_fisher needs at least one sample")

        self.fisher = {
            name: (acc / num_samples).detach()
            for (name, _), acc in zip(self.model.named_parameters(), squared_grad_sums)
        }

    def update_params_and_fisher(self, xs, ys, batch_size):
        """Refresh the anchor parameters and, if enabled, the Fisher estimate."""
        self.update_star_params()
        if self.use_fisher:
            self.update_fisher(xs, ys, batch_size)

    def compute_loss(self, x, y):
        """Return ``(task_loss + L2-init penalty + EWC penalty, logits)``.

        The EWC term is 0 until ``update_star_params`` has been called.
        """
        logits = self.model(x)
        loss = self.loss_fn(logits, y)

        l2_loss = 0
        ewc_loss = 0
        has_anchor = len(self.star_params_dict) > 0
        weight_by_fisher = self.use_fisher and len(self.fisher) > 0

        for name, param in self.model.named_parameters():
            if not param.requires_grad or 'layer_norm' in name or \
                'init_params' in name or \
                    'original_last_layer_params' in name:
                continue

            init_param = self.init_params_dict[name].detach().to(self.device)
            l2_loss = l2_loss + torch.sum((param - init_param) ** 2)

            if has_anchor:
                diff = param - self.star_params_dict[name].detach()
                fisher = self.fisher[name] if weight_by_fisher else 1
                ewc_loss = ewc_loss + torch.sum(fisher * (diff ** 2))

        # Out-of-place so autograd never sees an in-place update of loss_fn's output.
        loss = loss + self.l2_weight * 0.5 * l2_loss
        loss = loss + self.ewc_weight * 0.5 * ewc_loss

        return loss, logits

class ShrinkAndPerturbAgent(BaseAgent):
    """Agent applying shrink-and-perturb to every parameter after each step.

    After each base-agent step every parameter becomes
    ``(1 - shrink) * param + perturb_scale * noise``. ``noise`` is the
    matching parameter of a freshly initialised ``MixNormal`` network.
    """

    def __init__(self, agent_config, arch_config, task_config, device):
        super().__init__(agent_config, arch_config, task_config)
        self.shrink = agent_config.shrink
        self.perturb_scale = agent_config.perturb_scale
        self.device = device
        # _shrink_and_perturb rebuilds a network from these configs. Store them
        # here rather than relying on the base class having done so.
        self.arch_config = arch_config
        self.task_config = task_config

    def _shrink_and_perturb(self):
        """Shrink every parameter towards the origin and add fresh-init noise."""
        # The noise source is a separate, freshly initialised network. It must NOT
        # be assigned to self.model: doing so would throw away the trained network
        # (and detach it from the optimizer) on every step.
        # NOTE(review): parameters are paired positionally, so this assumes
        # self.model has the same architecture as this MixNormal (layer_norm=False).
        random_model = MixNormal(
            input_type=self.arch_config.input_type,
            input_shape=self.task_config.input_shape,
            num_classes=self.task_config.num_classes,
            cnn_channels=self.arch_config.cnn_channels,
            kernel_size=self.arch_config.kernel_size,
            padding=self.arch_config.padding,
            stride=self.arch_config.stride,
            pooling_type=self.arch_config.pooling_type,
            pooling_kernel=self.arch_config.pooling_kernel,
            fc_channels=self.arch_config.fc_channels,
            activation=self.arch_config.activation,
            layer_norm=False
        ).to(self.device)

        with torch.no_grad():
            for param, random_param in zip(self.model.parameters(), random_model.parameters()):
                param.mul_(1. - self.shrink)  # Shrink
                param.add_(self.perturb_scale * random_param)  # Perturb

    def step(self, x, y):
        """Take one base-agent step, then shrink and perturb the parameters."""
        logits, metrics = super().step(x, y)
        self._shrink_and_perturb()
        return logits, metrics