def lerp(model1, model2, l, temporal_model=None):
    """Linearly interpolate the parameters of two models.

    Every parameter of the result is set to ``(1 - l) * p1 + l * p2``, where
    ``p1`` and ``p2`` are the parameters of ``model1`` and ``model2`` paired
    in ``parameters()`` order. Only parameters are interpolated; buffers keep
    whatever values the destination model already had, except for the
    ``BatchNorm2d`` running statistics, which are removed (see below).

    Args:
        model1: Model whose parameters are reproduced at ``l == 0``.
        model2: Model whose parameters are reproduced at ``l == 1``.
        l: Interpolation weight applied to ``model2``.
        temporal_model: Optional destination model. It is overwritten in
            place. When omitted, a deep copy of ``model1`` is used, so
            neither input model is modified.

    Returns:
        The destination model holding the interpolated parameters
        (``temporal_model`` itself when one was supplied).
    """
    if temporal_model is None:
        temporal_model = deepcopy(model1)

    for dst, src1, src2 in zip(
        temporal_model.parameters(), model1.parameters(), model2.parameters()
    ):
        # The right-hand side is fully evaluated before the copy, so passing
        # model1 or model2 as the destination is safe.
        dst.data.copy_((1 - l) * src1.data + l * src2.data)

    # Drop the running statistics of every BatchNorm2d layer in the
    # destination. With no running estimates, BatchNorm normalizes with the
    # statistics of the current batch, also in eval mode. Presumably this is
    # done because the inherited statistics do not describe the interpolated
    # weights.
    # Only the destination's modules are walked: the other two models are not
    # needed here, and zipping them in could silently skip layers.
    for module in temporal_model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            module.running_mean = None
            module.running_var = None
            module.track_running_stats = False

    return temporal_model



# Interpolation weight handed to lerp(): 0 reproduces the parameters of
# `model`, 1 reproduces those of `final_model`.
lamda_val = 0.3
# Build the interpolated network, then switch it to evaluation mode.
temporal_model = lerp(model1=model, model2=final_model, l=lamda_val)
temporal_model.eval()

def eval_loss_acc(model, data_loader, device="cuda:1"):
    """Return the top-1 classification accuracy of ``model`` on ``data_loader``.

    Despite the name, no loss is computed; only the accuracy is returned.

    Args:
        model: Callable mapping a batch of inputs to per-class scores, with
            classes along dimension 1. It is used as given: this function
            neither moves it to ``device`` nor changes its train/eval mode.
        data_loader: Iterable yielding ``(data, target)`` tensor batches with
            samples along dimension 0.
        device: Device that each batch and the model output are moved to.

    Returns:
        Fraction of samples whose highest-scoring class equals the target,
        as a float in ``[0, 1]``.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    total = 0
    correct = 0
    # Nothing is backpropagated during evaluation, so do not record an
    # autograd graph for the forward passes.
    with torch.no_grad():
        for data, target in tqdm(data_loader):
            data, target = data.to(device), target.to(device)

            output = model(data).to(device)
            # Index of the largest score along the class dimension.
            _, predicted = torch.max(output, 1)
            total += target.shape[0]
            correct += (predicted == target).sum().item()

    return correct / total

# Accuracy of `final_model` on `test_set`, on eval_loss_acc's default device.
# NOTE(review): `temporal_model` is built earlier in this file, yet
# `final_model` is the model evaluated here. Confirm this is intended.
temp_accuracy = eval_loss_acc(model=final_model, data_loader=test_set)
# Bare expression: presumably shows the value as notebook cell output.
temp_accuracy


# Weight of the diversity term subtracted from each channel objective.
_lambda = 1e-3
# One objective per output channel of `layer`: the channel objective minus
# the weighted diversity term.
vis_obj_list = [
    objectives.channel(layer, channel_idx) - _lambda * objectives.diversity(layer)
    for channel_idx in range(model_layers[layer].out_channels)
]
save_image = True
# The run name embeds the diversity weight.
# NOTE(review): 224 is presumably the image size. Confirm against
# generate_artificial_top_images.
generate_artificial_top_images(
    f'diversity_{_lambda}',
    final_model,
    root_path,
    save_image,
    vis_obj_list,
    224,
    do_fft=True,
    batch=5,
)



# When continuing an optimization, replace one layer's weight with a copy
# perturbed by Gaussian noise (mean 0, std 1e-2).
if continue_optim:
    # Position in `model.features`, derived from the number after the last
    # underscore of the first layer name.
    # NOTE(review): the reason for the -2 offset is not visible here. Confirm.
    feature_index = int(feature_layer[0].rsplit("_", 1)[-1]) - 2
    shape_filter = model.features[feature_index].weight.shape
    print(shape_filter)
    model.features[feature_index].weight = torch.nn.Parameter(
        model.features[feature_index].weight
        + torch.normal(0, 1e-2, size=shape_filter, device=device),
        requires_grad=True,
    )
