import argparse
import shutil
from datetime import datetime
from collections import defaultdict
import ast

import yaml
from prompt_toolkit import prompt
from tqdm import tqdm
import numpy as np
from copy import deepcopy
import random

from torch import logit, nn
import functorch
from functorch import make_functional_with_buffers, grad, vmap

# noinspection PyUnresolvedReferences
from dataset.pipa import Annotations  # legacy to correctly load dataset.
from helper import Helper
from utils.utils import *

# Module-level logger used by the training and attack routines in this file.
# NOTE(review): `logging` is not imported explicitly here; presumably it is
# re-exported by the `from utils.utils import *` line above -- TODO confirm.
logger = logging.getLogger('logger')


def mask_train(hlpr: Helper, epoch, model, optimizer, train_loader, attack=True):
    """Run one training pass in which gradients are gated by ``hlpr.grad_mask``.

    After each backward pass, every parameter's gradient is multiplied
    element-wise by its slice of the flat mask (slices are taken in
    ``model.parameters()`` order), then the optimizer steps. Two per-batch
    statistics are reported through ``hlpr.report_training_losses_scales``:

    * ``grad_norm``: sum of absolute gradients before masking, divided by
      ``len(grad_mask)``.
    * ``mask_grad_norm``: sum of absolute gradients after masking, divided by
      the number of mask entries equal to 1 (0.0 when there are none).

    Args:
        hlpr: Helper exposing ``grad_mask``, ``task``, ``attack``, ``params``
            and ``report_training_losses_scales``.
        epoch: Epoch index forwarded to the reporter.
        model: Module trained in place.
        optimizer: Optimizer that steps ``model``'s parameters.
        train_loader: Iterable of raw batches.
        attack: Forwarded to ``hlpr.attack.compute_blind_loss``.
    """
    grad_mask = hlpr.grad_mask
    criterion = hlpr.task.criterion
    model.train()

    for i, data in enumerate(train_loader):
        batch = hlpr.task.get_batch(i, data)
        model.zero_grad()
        loss = hlpr.attack.compute_blind_loss(model, criterion, batch, attack)
        loss.backward()

        grad_norm = 0.0
        mask_grad_norm = 0.0
        offset = 0
        for p in model.parameters():
            numel = p.numel()
            # A parameter that did not take part in the loss has no gradient;
            # skip it, but still advance the offset so later slices of the
            # flat mask stay aligned with their parameters.
            if p.grad is not None:
                grad_norm += p.grad.abs().sum().item()
                p.grad *= torch.reshape(grad_mask[offset:offset + numel], p.grad.shape)
                mask_grad_norm += p.grad.abs().sum().item()
            offset += numel

        grad_norm /= len(grad_mask)
        # Guard the average against a mask with no active (== 1) entries.
        active = int((grad_mask == 1).sum())
        mask_grad_norm = mask_grad_norm / active if active else 0.0

        optimizer.step()

        # Extra statistics saved alongside the usual losses/scales.
        extra_info = {
            'grad_norm': grad_norm,
            'mask_grad_norm': mask_grad_norm,
        }

        hlpr.report_training_losses_scales(i, epoch, extra_info=extra_info)
        if i == hlpr.params.max_batch_id:
            break


def meta_train(hlpr: Helper, epoch, model, train_loader):
    """Optimise an attacker copy of ``model`` through simulated follow-up training.

    A deep copy of ``model`` is converted to functional form; its parameters
    and buffers are the variables updated by the meta optimizer. For every
    batch, the attacker tensors are averaged with parameters produced by
    ``train_with_functorch`` (step 0), further ``train_with_functorch`` steps
    are applied (later steps), and the criterion on the backdoored batch is
    summed over all steps. That summed loss is back-propagated and the meta
    optimizer steps.

    Args:
        hlpr: Helper exposing ``task``, ``attack``, ``params`` and ``grad_mask``.
        epoch: Current round; forwarded to ``make_meta_optimizer`` and, offset
            by the meta-step index, to ``train_with_functorch``.
        model: Model the attacker starts from; it is not modified here.
        train_loader: Iterable of raw batches used for the attacker's loss and
            for the simulated training steps.

    Returns:
        The deep copy of ``model`` with the optimised parameters and buffers
        loaded back into it.
    """
    # Use the `deepcopy` name this file imports explicitly instead of relying
    # on a `copy` module being leaked into scope by a star import.
    attack_model = deepcopy(model)
    criterion = hlpr.task.criterion
    attack_func_model, attack_params, attack_buffers = make_functional_with_buffers(attack_model)
    # NOTE(review): `functorch._src` is a private namespace -- confirm it is
    # available in the functorch version in use.
    _, weight_names, _ = functorch._src.make_functional.extract_weights(attack_model)
    _, buffer_names, _ = functorch._src.make_functional.extract_buffers(attack_model)
    meta_optimizer = hlpr.task.make_meta_optimizer(attack_params + attack_buffers, epoch=epoch)

    for local_epoch in range(hlpr.params.fl_attacker_local_epochs):
        for i, data in enumerate(train_loader):
            # Fresh functional view of the starting model for this batch.
            func_model, curr_params, curr_buffers = make_functional_with_buffers(model)

            meta_optimizer.zero_grad()

            batch = hlpr.task.get_batch(i, data)
            if hlpr.params.task != 'SentimentFed' or hlpr.attack.synthesizer is not None:
                batch = hlpr.attack.synthesizer.make_backdoor_batch(batch, attack=True)

            loss = None

            if hlpr.params.random_meta_step:
                meta_steps = random.randint(1, hlpr.params.meta_steps)
            else:
                meta_steps = hlpr.params.meta_steps

            # Unroll `meta_steps` simulated steps, adding the adversarial loss
            # after each one.
            for meta_i in range(meta_steps):
                if meta_i == 0:
                    # Estimate the other users' update.
                    curr_params = train_with_functorch(hlpr, epoch + meta_i, func_model, curr_params, curr_buffers, train_loader, num_users=hlpr.params.fl_no_models-1)

                    # Mix in the attacker's tensors at step 0: weight 1/n for
                    # the attacker, (n - 1)/n for the estimate.
                    n_models = hlpr.params.fl_no_models
                    curr_params = [(attack_params[k] + curr_params[k] * (n_models - 1)) / n_models for k in range(len(curr_params))]
                    curr_buffers = [(attack_buffers[k] + curr_buffers[k] * (n_models - 1)) / n_models for k in range(len(curr_buffers))]
                else:
                    # Plain simulated update.
                    curr_params = train_with_functorch(hlpr, epoch + meta_i, func_model, curr_params, curr_buffers, train_loader, num_users=hlpr.params.fl_no_models)

                # Adversarial loss on the backdoored batch.
                logits = func_model(curr_params, curr_buffers, batch.inputs)
                y = batch.labels
                # 3-D logits: reshape for the next-word-prediction task.
                if len(logits.shape) == 3:
                    logits, y = hlpr.task.remove_logits_mask(logits, y)

                if loss is None:
                    loss = criterion(logits, y).mean()
                else:
                    loss += criterion(logits, y).mean()

            loss.backward()

            if hlpr.params.grad_mask_attack:
                # Gate each attacker gradient with its slice of the flat mask.
                count = 0
                for p in attack_params:
                    p.grad *= torch.reshape(hlpr.grad_mask[count:(count + p.numel())], p.grad.shape)
                    count += p.numel()

            meta_optimizer.step()

    # Copy the optimised tensors back into the module form.
    functorch._src.make_functional.load_weights(attack_model, weight_names, attack_params)
    functorch._src.make_functional.load_buffers(attack_model, buffer_names, attack_buffers)

    return attack_model


def train_with_functorch(hlpr, epoch, func_model, params, buffers, train_loader, attack=False, num_users=None):
    """Apply plain gradient-descent steps to functional ``params`` and return them.

    For each of ``hlpr.params.fl_local_epochs`` local epochs, only the first
    batch of ``train_loader`` is used (the inner loop breaks after one batch):
    the gradient of the mean criterion with respect to ``params`` is computed
    and ``params`` is replaced by ``p - lr * g``.

    Args:
        hlpr: Helper exposing ``params``, ``task`` and ``attack``.
        epoch: Exponent of the decay, ``lr = params.lr * params.gamma ** epoch``.
        func_model: Functional model called as ``func_model(params, buffers, x)``.
        params: Sequence of parameter tensors to update (a new list is returned).
        buffers: Buffers passed unchanged to ``func_model``.
        train_loader: Iterable of raw batches.
        attack: Forwarded to ``synthesizer.make_backdoor_batch``.
        num_users: Accepted so that callers passing ``num_users=...`` do not
            fail with ``TypeError``; currently not used in the computation.

    Returns:
        List of updated parameter tensors.
    """
    lr = hlpr.params.lr * hlpr.params.gamma ** (epoch)
    criterion = hlpr.task.criterion

    def compute_loss(params, buffers, x, y):
        logits = func_model(params, buffers, x)
        # 3-D logits: reshape for the next-word-prediction task.
        if len(logits.shape) == 3:
            logits, y = hlpr.task.remove_logits_mask(logits, y)

        return criterion(logits, y).mean()

    for local_epoch in range(hlpr.params.fl_local_epochs):
        for i, data in enumerate(train_loader):
            batch = hlpr.task.get_batch(i, data)
            if hlpr.params.task != 'SentimentFed' or hlpr.attack.synthesizer is not None:
                batch = hlpr.attack.synthesizer.make_backdoor_batch(batch, attack=attack)
            batch = create_random_input_data(hlpr, batch, random_data=hlpr.params.meta_random_data)
            grads = grad(compute_loss)(params, buffers, batch.inputs, batch.labels)

            params = [p - g * lr for p, g in zip(params, grads)]

            # One batch per local epoch.
            break

    return params


def create_random_input_data(hlpr, batch, random_data=False):
    """Optionally overwrite ``batch.inputs`` / ``batch.labels`` with uniform noise.

    When ``random_data`` is falsy the batch is returned untouched. Otherwise
    uniform samples from ``torch.rand`` are scaled, cast to the dtype of the
    tensors they replace, and written back onto ``batch``:

    * task ``'SentimentFed'``: inputs are scaled by ``hlpr.task.vocabSize``
      and labels by ``hlpr.task.classes``;
    * any other task: inputs are scaled by 256, cast, then passed through
      ``hlpr.task.normalize``; labels are scaled by ``len(hlpr.task.classes)``.

    Returns:
        The same ``batch`` object (mutated in place when ``random_data`` is set).
    """
    if not random_data:
        return batch

    is_text_task = hlpr.params.task == 'SentimentFed'

    # Inputs are drawn first, labels second, so the RNG is consumed in a
    # fixed order.
    real_inputs = batch.inputs
    input_noise = torch.rand(real_inputs.shape, device=real_inputs.device)
    if is_text_task:
        # NLP data: scale the noise by the vocabulary size before casting.
        input_noise *= hlpr.task.vocabSize
        batch.inputs = input_noise.type(real_inputs.dtype)
    else:
        input_noise *= 256
        batch.inputs = hlpr.task.normalize(input_noise.type(real_inputs.dtype))

    real_labels = batch.labels
    label_noise = torch.rand(real_labels.shape, device=real_labels.device)
    if is_text_task:
        label_noise *= hlpr.task.classes
    else:
        label_noise *= len(hlpr.task.classes)
    batch.labels = label_noise.type(real_labels.dtype)

    return batch


def train(hlpr: Helper, epoch, model, optimizer, train_loader, attack=True, fl_attacker=False):
    """Run one optimisation pass of ``model`` over ``train_loader``.

    Each batch goes through ``hlpr.attack.compute_blind_loss``, a backward
    pass and an optimizer step; losses/scales are then reported. The pass
    stops once the batch index equals ``hlpr.params.max_batch_id``.

    Args:
        hlpr: Helper exposing ``task``, ``attack``, ``params`` and the reporter.
        epoch: Epoch index forwarded to the reporter.
        model: Module trained in place.
        optimizer: Optimizer that steps ``model``'s parameters.
        train_loader: Iterable of raw batches.
        attack: Forwarded to ``compute_blind_loss``.
        fl_attacker: Forwarded to ``report_training_losses_scales``.
    """
    loss_fn = hlpr.task.criterion
    model.train()

    for batch_idx, raw in enumerate(train_loader):
        batch = hlpr.task.get_batch(batch_idx, raw)

        model.zero_grad()
        hlpr.attack.compute_blind_loss(model, loss_fn, batch, attack).backward()
        optimizer.step()

        hlpr.report_training_losses_scales(batch_idx, epoch, fl_attacker=fl_attacker)

        # Stop early once the configured batch cap is reached.
        if batch_idx == hlpr.params.max_batch_id:
            return


def test(hlpr: Helper, epoch, model=None, backdoor=False, tb_prefix=None):
    """Evaluate ``model`` on the test data and report the task metrics.

    Args:
        hlpr: Helper exposing ``task``, ``attack``, ``params`` and ``tb_writer``.
        epoch: Epoch index passed to ``report_metrics``.
        model: Model to evaluate; defaults to ``hlpr.task.model``.
        backdoor: Evaluate on backdoored inputs instead of clean ones.
        tb_prefix: Prefix passed to ``report_metrics``; defaults to
            ``'Test_backdoor'`` or ``'Test_clean'`` depending on ``backdoor``.

    Returns:
        Whatever ``hlpr.task.report_metrics`` returns.
    """
    model = hlpr.task.model if model is None else model
    model.eval()
    hlpr.task.reset_metrics()

    # 'SentimentFed' has a dedicated loader for backdoor evaluation.
    if backdoor and hlpr.params.task == 'SentimentFed':
        loader = hlpr.task.back_test_loader
    else:
        loader = hlpr.task.test_loader

    with torch.no_grad():
        for step, raw in enumerate(tqdm(loader)):
            batch = hlpr.task.get_batch(step, raw)
            task_name = hlpr.params.task

            if task_name == 'RedditFed':
                # The synthesizer is applied on both clean and backdoor runs;
                # `attack` follows the `backdoor` flag.
                if hlpr.attack.synthesizer is not None:
                    batch = hlpr.attack.synthesizer.make_backdoor_batch(
                        batch, test=True, attack=backdoor)
            elif backdoor and task_name != 'SentimentFed':
                # Drop samples whose label already equals the backdoor label.
                keep = batch.labels != hlpr.params.backdoor_label
                batch.inputs = batch.inputs[keep]
                batch.labels = batch.labels[keep]
                batch.batch_size = batch.inputs.shape[0]
                if hlpr.attack.synthesizer is not None:
                    batch = hlpr.attack.synthesizer.make_backdoor_batch(
                        batch, test=True, attack=True)

            predictions = model(batch.inputs)
            hlpr.task.accumulate_metrics(outputs=predictions, labels=batch.labels)

    if tb_prefix is None:
        tb_prefix = 'Test_backdoor' if backdoor else 'Test_clean'

    return hlpr.task.report_metrics(
        epoch,
        prefix=f'Backdoor {str(backdoor):5s}. Epoch: ',
        tb_writer=hlpr.tb_writer,
        tb_prefix=tb_prefix,
    )


def run(hlpr, grad_mask_attack=False):
    """Centralised (non-federated) training loop.

    Evaluates the initial model (reported as epoch -1), then for epochs
    ``0 .. hlpr.params.epochs`` inclusive trains once, evaluates on clean and
    backdoor data, and saves the model with the sum of both metrics.

    Args:
        hlpr: Helper exposing ``task``, ``params`` and ``save_model``.
        grad_mask_attack: When true, prepare the gradient mask first and train
            with ``mask_train`` instead of ``train``.
    """
    if grad_mask_attack:
        prepare_grad_mask(hlpr, hlpr.task.model)

    # Baseline evaluation before any training.
    for with_backdoor in (False, True):
        test(hlpr, -1, backdoor=with_backdoor)

    train_one_epoch = mask_train if grad_mask_attack else train

    for epoch in range(hlpr.params.epochs + 1):
        train_one_epoch(hlpr, epoch, hlpr.task.model, hlpr.task.optimizer,
                        hlpr.task.train_loader)

        clean_metric = test(hlpr, epoch, backdoor=False)
        backdoor_metric = test(hlpr, epoch, backdoor=True)
        hlpr.save_model(hlpr.task.model, epoch, clean_metric + backdoor_metric)


def fl_run(hlpr: Helper):
    """Federated training loop: one ``run_fl_round`` per epoch plus evaluation.

    The initial model is evaluated as epoch -1. For epochs
    ``0 .. hlpr.params.epochs`` inclusive a round is run, the global model is
    evaluated on clean and backdoor data, and saved with the clean metric.
    """
    hlpr.attack.target_bias = None

    test(hlpr, -1, backdoor=False)
    test(hlpr, -1, backdoor=True)

    for epoch in range(hlpr.params.epochs + 1):
        run_fl_round(hlpr, epoch)

        clean_metric = test(hlpr, epoch, backdoor=False)
        backdoor_metric = test(hlpr, epoch, backdoor=True)
        hlpr.save_model(hlpr.task.model, epoch, clean_metric)

        # Once the backdoor metric exceeds 98 (with `norm_after_meta` set and
        # the first scheduled attack epoch reached), change to the normal
        # attack by fixing the attacker's local epochs at 8.
        switch_to_normal_attack = (
            backdoor_metric > 98
            and hlpr.params.norm_after_meta
            and epoch >= min(hlpr.params.fl_single_epoch_attack)
        )
        if switch_to_normal_attack:
            hlpr.params.fl_attacker_local_epochs = 8


def _frozen_copy(model):
    """Return a deep copy of ``model`` with ``requires_grad`` off on every parameter."""
    frozen = deepcopy(model)
    for param in frozen.parameters():
        param.requires_grad = False
    return frozen


def run_fl_round(hlpr, epoch):
    """Run one federated round: local training per participant, then aggregation.

    For every sampled participant the global weights are copied into the local
    model, which is trained (benign users with ``train``; compromised users
    with the configured attack), and the resulting update is accumulated.
    After all participants, the global model is updated from the accumulator
    and several statistics are plotted through ``hlpr.plot``.

    Args:
        hlpr: Helper exposing ``task``, ``attack``, ``params`` and ``plot``.
        epoch: Index of the current round.
    """
    global_model = hlpr.task.model
    local_model = hlpr.task.local_model

    round_participants = hlpr.task.sample_users_for_round(epoch)
    weight_accumulator = hlpr.task.get_empty_accumulator()

    num_compromised = 0
    attacker_update_norm = 0
    user_update_norm = 0
    for user in tqdm(round_participants):
        hlpr.task.copy_params(global_model, local_model)
        optimizer = hlpr.task.make_optimizer(local_model, epoch=epoch, hlpr=hlpr)
        if user.compromised:
            hlpr.params.running_losses = defaultdict(list)
            hlpr.params.running_scales = defaultdict(list)
            if hlpr.params.grad_mask_attack:
                hlpr.attack.fixed_model = deepcopy(local_model)
                prepare_grad_mask(hlpr, local_model)

            if 'ortho_clean' in hlpr.params.loss_tasks:
                prepare_clean_pca(hlpr, global_model, user.train_loader)
                hlpr.attack.fixed_model = _frozen_copy(local_model)

            if 'min_change' in hlpr.params.loss_tasks:
                hlpr.attack.fixed_model = _frozen_copy(local_model)

            if hlpr.params.task == 'SentimentFed':
                from torch.utils.data import ConcatDataset, DataLoader, TensorDataset
                # Index the user's samples once and reuse both components.
                user_samples = user.train_loader.dataset[user.train_loader.sampler.indices]
                tmp_dataset = TensorDataset(user_samples[0], user_samples[1])
                tmp_dataset = ConcatDataset([tmp_dataset, hlpr.task.back_test_dataset])
                attacker_train_loader = DataLoader(tmp_dataset, batch_size=hlpr.params.batch_size,
                                                   shuffle=True, num_workers=4, pin_memory=True)
            else:
                attacker_train_loader = user.train_loader

            if hlpr.params.fl_attacker_data_size is not None:
                # Replace the attacker's loader by a random subset of the
                # full training set of the requested size.
                from torch.utils.data.sampler import SubsetRandomSampler
                from torch.utils.data import DataLoader
                sub_indices = random.sample(range(0, len(hlpr.task.train_dataset)), hlpr.params.fl_attacker_data_size)
                attacker_train_loader = DataLoader(hlpr.task.train_dataset,
                                                   batch_size=hlpr.params.batch_size,
                                                   sampler=SubsetRandomSampler(sub_indices))

            num_compromised += 1
            fl_local_epochs = hlpr.params.fl_attacker_local_epochs
        else:
            fl_local_epochs = hlpr.params.fl_local_epochs

        # Defined even when the loop below runs zero times, so the
        # "attacker_local_epochs" plot reports 0 instead of failing or
        # reusing a previous participant's value.
        local_epoch = -1
        for local_epoch in range(fl_local_epochs):
            if user.compromised:
                if 'bias_attack' in hlpr.params.loss_tasks:
                    est_feat = feat_est(hlpr, local_model, user.train_loader)
                    if hlpr.attack.target_bias is None:
                        # Pick the feature entry to attack: the largest
                        # estimate with 'max_bias', otherwise the smallest.
                        sorted_feat, indx = torch.sort(est_feat)

                        if 'max_bias' in hlpr.params.loss_tasks:
                            logger.warning("max feature entry!")
                            hlpr.attack.target_bias = indx[-1]
                            hlpr.attack.feat_val = sorted_feat[-1]
                        else:
                            hlpr.attack.target_bias = indx[0]
                            hlpr.attack.feat_val = sorted_feat[0]

                        logger.warning(f"find feature entry {hlpr.attack.target_bias} for bias attack with max output value = {hlpr.attack.feat_val}")

                    hlpr.attack.patch_train(hlpr, fl_local_epochs, local_model, user.train_loader, patch_s=10)
                    hlpr.attack.patch_test(hlpr, epoch, local_model, check_feature=True)

                    if hlpr.params.greedy_bias and est_feat[hlpr.attack.target_bias] == 0:  # stop sending if the attack works
                        break

                    with torch.no_grad():
                        local_model.layer4[-1].bn2.bias[hlpr.attack.target_bias] -= 1000  # send max possible change for now

                    # b_p = -l / sqrt(x^2 + 1)
                    # w_p = -xl / sqrt(x^2 + 1)

                    break  # for bias_attack we only need one local epoch
                elif hlpr.params.grad_mask_attack:
                    mask_train(hlpr, local_epoch, local_model, optimizer,
                               attacker_train_loader, attack=True)
                elif hlpr.params.meta_attack:
                    logger.warning("meta attack!")
                    local_model = meta_train(hlpr, epoch, local_model, attacker_train_loader)
                    break  # already did fl_local_epochs in meta_train
                else:
                    train(hlpr, local_epoch, local_model, optimizer,
                          attacker_train_loader, attack=True, fl_attacker=True)

                if hlpr.params.early_stop:
                    acc = test(hlpr, epoch, model=local_model, backdoor=True, tb_prefix='Attack_backdoor')

                    if acc >= hlpr.params.early_stop:
                        break
            else:
                train(hlpr, local_epoch, local_model, optimizer,
                      user.train_loader, attack=False)

        local_update = hlpr.task.get_fl_update(local_model, global_model)

        if user.compromised:
            if hlpr.params.fl_attacker_not_update_bn:
                # Zero every update entry whose name contains "bn" or "running".
                for name, value in local_update.items():
                    if "bn" in name or "running" in name:
                        value.mul_(0)
            test(hlpr, epoch, model=local_model, backdoor=False, tb_prefix='Attack_clean')
            test(hlpr, epoch, model=local_model, backdoor=True, tb_prefix='Attack_backdoor')
            update_norm = hlpr.task.get_update_norm(local_update)  # track update norm before scale
            attacker_update_norm += update_norm
            hlpr.attack.fl_scale_update(local_update)
            hlpr.plot(epoch, local_epoch + 1, 'Attack/attacker_local_epochs')

            if hlpr.params.grad_mask_attack:
                hlpr.plot(epoch, len(np.where(hlpr.activate_frequency == 0)[0]) * 100 / len(hlpr.activate_frequency), 'Attack/totally_dead_params')
                hlpr.plot(epoch, hlpr.grad_mask_proportion, 'Attack/grad_mask_proportion')
        else:
            update_norm = hlpr.task.get_update_norm(local_update)
            user_update_norm += update_norm

        hlpr.task.accumulate_weights(weight_accumulator, local_update)

    if num_compromised > 0:
        hlpr.plot(epoch, num_compromised, 'Attack/num_compromised')
        attacker_update_norm = attacker_update_norm / num_compromised
        hlpr.plot(epoch, attacker_update_norm, 'Attack/attacker_update_norm')

    # Mean update norm over benign participants (0 when there are none).
    num_benign = len(round_participants) - num_compromised
    user_update_norm = user_update_norm / num_benign if num_benign != 0 else 0
    hlpr.plot(epoch, user_update_norm, 'Attack/user_update_norm')
    hlpr.task.update_global_model(weight_accumulator, global_model)

    if hasattr(global_model, 'features'):
        est_feat = feat_est(hlpr, global_model, hlpr.task.test_loader)
        if hlpr.attack.target_bias is not None:
            hlpr.plot(epoch, est_feat[hlpr.attack.target_bias].mean(), 'Attack/target_feat_value')
            hlpr.plot(epoch, torch.norm(weight_accumulator['fc.weight']).item(), 'Attack/update_norm_fc.weight')
            hlpr.plot(epoch, torch.norm(weight_accumulator['fc.bias']).item(), 'Attack/update_norm_fc.bias')

        hlpr.plot(epoch, torch.mean(est_feat).item(), 'Tracking/feature_avg')


def resolve_attack_epochs(fl_attack_freq, epochs, seed):
    """Translate the ``--fl-attack-freq`` values into a list of attack epochs.

    Values may be strings (as delivered by argparse); they are converted with
    ``int``. Supported forms:

    * ``[n]``: ``n`` distinct epochs drawn at random from ``range(epochs)``.
    * ``[start, end, n, _]``: ``n`` distinct epochs drawn at random from
      ``range(start, end)``; the fourth value only selects this form and is
      otherwise ignored.
    * any other length of at least three: ``range(start, end, step)`` built
      from the first three values.

    The two random forms seed the global ``random`` module with ``seed`` before
    sampling and return the sampled epochs in ascending order.

    Raises:
        ValueError: if exactly two values are given (no form matches).
    """
    count = len(fl_attack_freq)
    if count == 1:
        random.seed(seed)
        return sorted(random.sample(range(epochs), int(fl_attack_freq[0])))
    if count < 3:
        raise ValueError(
            f"--fl-attack-freq expects 1, 3 or 4 values, got {count}")

    start, end, freq = (int(value) for value in fl_attack_freq[:3])
    if count == 4:
        random.seed(seed)
        return sorted(random.sample(range(start, end), freq))
    return list(range(start, end, freq))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Backdoors')
    parser.add_argument('--params', dest='params', default='utils/params.yaml')
    parser.add_argument('--name', dest='name', required=True)
    parser.add_argument('--check-activation', dest='check_activation', action='store_true')
    parser.add_argument('--resume-model', dest='resume_model', default=None)
    parser.add_argument('--pretrained', dest='pretrained', action='store_true')
    parser.add_argument('--random-seed', dest='random_seed', default=1)
    parser.add_argument('--optimizer', dest='optimizer', default=None)

    parser.add_argument('--grad-mask-attack', dest='grad_mask_attack', action='store_true')
    parser.add_argument('--freq-threshold', dest='freq_threshold', default=.0)
    parser.add_argument('--lr', dest='lr', default=0.01)
    parser.add_argument('--epochs', dest='epochs', default=50)
    parser.add_argument('--loss-tasks', dest='loss_tasks', default=['backdoor', 'normal'], nargs='+')
    parser.add_argument('--random-init', dest='random_init', default=None)
    parser.add_argument('--spectral-similarity', dest='spectral_similarity', default='norm')
    parser.add_argument('--loss-balance', dest='loss_balance', default='MGDA')
    parser.add_argument('--fixed-scales', dest='fixed_scales', default=None, nargs='+')
    parser.add_argument('--fastcheck', dest='fastcheck', action='store_true')
    parser.add_argument('--strict-mask', dest='strict_mask', action='store_true')
    parser.add_argument('--mask-module', dest='mask_module', default=None)

    parser.add_argument('--fl-weight-scale', dest='fl_weight_scale', default=None)
    parser.add_argument('--extra-fl-weight-scale', dest='extra_fl_weight_scale', default=1)
    parser.add_argument('--fl-number-of-adversaries', dest='fl_number_of_adversaries', default=0)
    parser.add_argument('--fl-single-epoch-attack', dest='fl_single_epoch_attack', default=None, nargs='+')
    parser.add_argument('--fl-no-models', dest='fl_no_models', default=10)
    parser.add_argument('--fl-local-epochs', dest='fl_local_epochs', default=10)
    parser.add_argument('--fl-attacker-local-epochs', dest='fl_attacker_local_epochs', default=5)
    parser.add_argument('--early-stop', dest='early_stop', default=None)
    parser.add_argument('--fl-attack-freq', dest='fl_attack_freq', default=None, nargs='+')
    parser.add_argument('--fl-eta', dest='fl_eta', default=1)
    parser.add_argument('--update-method', dest='update_method', default='fedavg')
    parser.add_argument('--fl-dp-clip', dest='fl_dp_clip', default=None)
    parser.add_argument('--fl-total-participants', dest='fl_total_participants', default=100)
    parser.add_argument('--fl-sample-dirichlet', dest='fl_sample_dirichlet', action='store_true')
    parser.add_argument('--fl-dirichlet-alpha', dest='fl_dirichlet_alpha', default=1)
    parser.add_argument('--fl-attacker-data-size', dest='fl_attacker_data_size', default=None)
    parser.add_argument('--fl-attacker-not-update-bn', dest='fl_attacker_not_update_bn', action='store_true')
    parser.add_argument('--model', dest='model', action=None)
    parser.add_argument('--norm-after-meta', dest='norm_after_meta', action='store_true')
    parser.add_argument('--use-bn', dest='use_bn', action='store_true')
    parser.add_argument('--poisoning-proportion', dest='poisoning_proportion', default=0.5)

    parser.add_argument('--pca-local-epochs', dest='pca_local_epochs', default=5)
    parser.add_argument('--pca-num-grads', dest='pca_num_grads', default=10)
    parser.add_argument('--pca-num-pcas', dest='pca_num_pcas', default=10)

    parser.add_argument('--greedy-bias', dest='greedy_bias', action='store_true')

    parser.add_argument('--meta-attack', dest='meta_attack', action='store_true')
    parser.add_argument('--meta-steps', dest='meta_steps', default=2)
    parser.add_argument('--meta-optimizer', dest='meta_optimizer', default='SGD')
    parser.add_argument('--meta-lr', dest='meta_lr', default=0.01)
    parser.add_argument('--meta-momentum', dest='meta_momentum', default=0.9)
    parser.add_argument('--meta-decay', dest='meta_decay', default=0.001)
    parser.add_argument('--meta-gamma', dest='meta_gamma', default=0.998)
    parser.add_argument('--meta-random-data', dest='meta_random_data', action='store_true')

    args = parser.parse_args()

    # NOTE(review): yaml.FullLoader can construct more than plain data types;
    # if the params file only holds scalars/lists/dicts, yaml.safe_load is the
    # safer choice -- confirm before switching.
    with open(args.params) as f:
        params = yaml.load(f, Loader=yaml.FullLoader)

    # Command-line values override / extend what was read from the YAML file.
    params['current_time'] = datetime.now().strftime('%b.%d_%H.%M.%S')
    params['name'] = args.name
    params['resume_model'] = args.resume_model
    params['pretrained'] = args.pretrained
    # A negative --random-seed leaves the YAML value (if any) untouched.
    if int(args.random_seed) >= 0:
        params['random_seed'] = int(args.random_seed)

    params['grad_mask_attack'] = args.grad_mask_attack
    params['freq_threshold'] = float(args.freq_threshold)
    params['lr'] = float(args.lr)
    params['epochs'] = int(args.epochs)
    params['loss_tasks'] = args.loss_tasks
    params['random_init'] = args.random_init
    params['spectral_similarity'] = args.spectral_similarity
    params['loss_balance'] = args.loss_balance
    params['fastcheck'] = args.fastcheck
    params['strict_mask'] = args.strict_mask

    params['fl_number_of_adversaries'] = int(args.fl_number_of_adversaries)
    params['fl_no_models'] = int(args.fl_no_models)
    params['fl_local_epochs'] = int(args.fl_local_epochs)
    params['fl_attacker_local_epochs'] = int(args.fl_attacker_local_epochs)
    params['fl_eta'] = float(args.fl_eta)
    params['update_method'] = args.update_method
    params['fl_dp_clip'] = float(args.fl_dp_clip) if args.fl_dp_clip else args.fl_dp_clip
    params['fl_total_participants'] = int(args.fl_total_participants)
    params['fl_sample_dirichlet'] = args.fl_sample_dirichlet
    params['fl_dirichlet_alpha'] = float(args.fl_dirichlet_alpha)
    params['fl_attacker_not_update_bn'] = args.fl_attacker_not_update_bn
    params['norm_after_meta'] = args.norm_after_meta
    params['use_bn'] = args.use_bn
    params['poisoning_proportion'] = float(args.poisoning_proportion)

    params['pca_local_epochs'] = int(args.pca_local_epochs)
    params['pca_num_grads'] = int(args.pca_num_grads)
    params['pca_num_pcas'] = int(args.pca_num_pcas)

    params['greedy_bias'] = args.greedy_bias

    params['meta_attack'] = args.meta_attack
    params['meta_steps'] = int(args.meta_steps)
    params['meta_optimizer'] = args.meta_optimizer
    params['meta_lr'] = float(args.meta_lr)
    params['meta_momentum'] = float(args.meta_momentum)
    params['meta_decay'] = float(args.meta_decay)
    params['meta_gamma'] = float(args.meta_gamma)
    params['meta_random_data'] = args.meta_random_data

    # Optional overrides: only applied when the flag was actually given.
    if args.model:
        params['model'] = args.model

    if args.mask_module:
        params['mask_module'] = int(args.mask_module)

    if args.random_init:
        params['random_init'] = float(args.random_init)

    if args.optimizer:
        params['optimizer'] = args.optimizer

    if args.early_stop:
        params['early_stop'] = float(args.early_stop)

    if args.fl_single_epoch_attack:
        params['fl_single_epoch_attack'] = [int(i) for i in args.fl_single_epoch_attack]
        params['fl_number_of_adversaries'] = max(1, params['fl_number_of_adversaries'])  # at least one

    if args.fl_attacker_data_size:
        params['fl_attacker_data_size'] = int(args.fl_attacker_data_size)

    if args.fl_weight_scale:
        params['fl_weight_scale'] = float(args.fl_weight_scale)
    else:
        # Default scale: number of models per round (divided by eta for 'sgd').
        if args.update_method == 'sgd':
            params['fl_weight_scale'] = float(args.extra_fl_weight_scale) * params['fl_no_models'] / params['fl_eta']
        else:
            params['fl_weight_scale'] = float(args.extra_fl_weight_scale) * params['fl_no_models']

    if args.fl_attack_freq:
        # Only the randomised forms (1 or 4 values) consult the seed.
        uses_seed = len(args.fl_attack_freq) in (1, 4)
        seed = params['random_seed'] if uses_seed else None
        params['fl_single_epoch_attack'] = resolve_attack_epochs(
            args.fl_attack_freq, params['epochs'], seed)

        params['fl_number_of_adversaries'] = max(1, params['fl_number_of_adversaries'])  # at least one

    if args.fixed_scales:
        args.fixed_scales = [float(i) for i in args.fixed_scales]
        params['fixed_scales'] = dict(zip(args.loss_tasks, args.fixed_scales))

    if args.check_activation:
        params['loss_tasks'] = ['normal']
        params['save_model'] = False
        params['tb'] = False

    helper = Helper(params)
    logger.warning(create_table(params))

    try:
        if args.check_activation:
            zero_frequency, average_vals, save_names = check_activation(helper)
            np.save(save_names[0], zero_frequency)
            np.save(save_names[1], average_vals)
        elif helper.params.fl:
            fl_run(helper)
        else:
            run(helper, helper.params.grad_mask_attack)
    except KeyboardInterrupt:
        # On Ctrl-C, offer to delete the run's output folder (and TB logs).
        if helper.params.log:
            answer = prompt('\nDelete the repo? (y/n): ')
            if answer in ['Y', 'y', 'yes']:
                logger.error(f"Fine. Deleted: {helper.params.folder_path}")
                shutil.rmtree(helper.params.folder_path)
                if helper.params.tb:
                    shutil.rmtree(f'runs/{args.name}')
            else:
                logger.error(f"Aborted training. "
                             f"Results: {helper.params.folder_path}. "
                             f"TB graph: {args.name}")
        else:
            logger.error("Aborted training. No output generated.")
