"""This script computes a suite of benchmark numbers for the given attack.

Settings are loaded from the breaching default config and then overridden by
the command-line arguments parsed below.
"""

import hydra
from omegaconf import OmegaConf

import datetime
import time
import logging

import breaching
import numpy as np

import os


import argparse

# Command-line interface. Most values are copied into the experiment config by
# args_to_cfg(); a few (debug, fix_replace_per_user, right_limit) are also read
# directly from the module-level `args` by other functions.
parser = argparse.ArgumentParser()

parser.add_argument('--imprint_target_word', action='store_true')  # selects the 'malicious-panning-transformer' server case
parser.add_argument('--batch_size', default=8, type=int)  # sentences per user -> case.user.num_data_points
parser.add_argument('--emb_scale', default=-1, type=float)
parser.add_argument('--mask_target_word', action='store_true')
parser.add_argument('--mask_target_word_attack', action='store_true')
parser.add_argument('--sep_by_group', default='none')
parser.add_argument('--target_word_measurement', action='store_true')
parser.add_argument('--seq_len', type=int, default=512)  # sequence length -> case.data.shape
parser.add_argument('--imprint_blocks', type=int, default=1)
parser.add_argument('--target_emb_scale', type=float, default=1.0)
parser.add_argument('--target_dist_upper_half', action='store_true')  # NOTE(review): not referenced anywhere in the visible code
parser.add_argument('--km_n_init', type=int, default=40)
parser.add_argument('--km_max_iter', type=int, default=900)
parser.add_argument('--replace_target_word', default='none', nargs='+')  # one or more words; 'none' default is replaced below
parser.add_argument('--target_word_comb', default='or')  # main_process only handles 'or' and 'and'
parser.add_argument('--comb_copy', default=1, type=int)
parser.add_argument('--v_length', default=32, type=int)  # NOTE(review): add_model_specfic_param overwrites v_length for gpt/bert/transformer models
parser.add_argument('--model', default='gpt2')  # family picked by substring: 'gpt', 'transformer' or 'bert'
parser.add_argument('--sentence_alg', default='k-means')
parser.add_argument('--exp_name', default='')  # also used as the log file name
parser.add_argument('--num_trials', default=10, type=int)
parser.add_argument('--debug', action='store_true')  # skip writing the global summary tables
parser.add_argument('--right_limit', default=4, type=int)
parser.add_argument('--fix_replace_per_user', action='store_true')  # draw a random replace_start_idx in [0, right_limit) per user
parser.add_argument('--append_line', default=1, type=int)
parser.add_argument('--lsa_partition', default=1, type=int)

parser.add_argument('--num_local_updates', default=1, type=int)  # >1 switches to the 'local_updates' user case
parser.add_argument('--grad_noise', default=0.0, type=float)  # >0 enables gaussian gradient noise with per-example clipping 1
parser.add_argument('--data_path', default='')
parser.add_argument('--base_dir', default='')  # the run log is written under <base_dir>/run/


# Ask Hydra for abbreviated (not full) error stack traces.
os.environ["HYDRA_FULL_ERROR"] = "0"
#log = logging.getLogger(__name__)
# Root logger: main_launcher attaches the file and stream handlers to it.
log = logging.getLogger()
# Parsed at import time; the functions below read this module-level `args`.
args = parser.parse_args()

# With nargs='+' an explicitly given flag yields a list, while an omitted flag
# leaves the plain default string 'none'; fall back to the single word 'target'.
if args.replace_target_word == 'none':
    args.replace_target_word = ['target']

def add_constant_param(cfg):
    """Pin the experiment settings that no command-line flag controls.

    The config is modified in place and also returned so calls can be chained.
    """
    server = cfg.case.server
    imprint = server.param_modification

    # User / data side.
    cfg.case.user.user_idx = 0  # starting user index (main_process overrides it)
    cfg.case.data.batch_size = 8
    cfg.save_reconstruction = False

    # Server side.
    server.provide_public_buffers = True  # server signal to disable dropout
    server.has_external_data = True  # not strictly necessary; random text would also do
    server.pretrained = False

    # Imprint-module parameters.
    imprint.eps = 1e-16
    imprint.imprint_sentence_position = 0
    imprint.softmax_skew = 1e8
    imprint.sequence_token_weight = 1
    imprint.bin_setup = "concatenate"
    return cfg

def add_model_specfic_param(cfg):
    """Apply the settings that depend on which model family is attacked.

    The family is detected by substring match on ``cfg.case.model``, tested in
    the order "gpt", "bert", "transformer"; any other model name leaves the
    config untouched. The config is modified in place and returned.

    NOTE(review): every matching branch overwrites
    ``param_modification.v_length`` with a fixed value, so the ``--v_length``
    value copied in by ``args_to_cfg`` has no effect for these families when
    this function runs afterwards (as main_launcher does).
    """
    model_name = cfg.case.model

    if "gpt" in model_name:
        cfg.case.data.tokenizer = "gpt2"
        cfg.attack.token_strategy = "embedding-norm"  # GPT has no decoder bias
        imprint = cfg.case.server.param_modification
        imprint.v_length = 32
        imprint.measurement_scale = 1e12  # very large scale to circumvent the GELU
        # 0.25 here can improve performance slightly for long sequences.
        cfg.attack.embedding_token_weight = 0.0
    elif "bert" in model_name:
        cfg.case.data.tokenizer = "bert-base-uncased"
        cfg.attack.token_strategy = "embedding-norm"
        imprint = cfg.case.server.param_modification
        imprint.v_length = 32
        imprint.measurement_scale = 1e8  # large scale to circumvent the GELU
        imprint.softmax_skew = 1e8
        cfg.attack.embedding_token_weight = 0.25
        imprint.right_limit = args.right_limit  # module-level CLI arguments
    elif "transformer" in model_name:
        cfg.case.server.param_modification.v_length = 6
    return cfg

def args_to_cfg(cfg):
    """Copy the module-level command-line ``args`` into the experiment config.

    Local-update and gradient-noise settings are only written when enabled
    (``num_local_updates > 1`` / ``grad_noise > 0``). The config is modified in
    place and returned.
    """
    user = cfg.case.user
    data = cfg.case.data

    # Where data lives and where results go.
    data.path = args.data_path
    cfg.base_dir = args.base_dir

    # User side: how many sentences, and which target word(s) to plant.
    for key, value in (
        ("num_data_points", args.batch_size),
        ("replace_word", args.replace_target_word),
        ("right_limit", args.right_limit),
        ("comb_copy", args.comb_copy),
        ("target_word_comb", args.target_word_comb),
    ):
        setattr(user, key, value)
    data.shape = [args.seq_len]  # the sequence length
    data.append_line = args.append_line

    # Optional training-regime changes: several local steps and/or gradient noise.
    updates = args.num_local_updates
    if updates > 1:
        user.num_local_updates = updates
        user.num_data_per_local_update_step = args.batch_size // updates
    if args.grad_noise > 0:
        privacy = user.local_diff_privacy
        privacy.gradient_noise = args.grad_noise
        privacy.per_example_clipping = 1
        privacy.distribution = 'gaussian'

    cfg.case.model = args.model  # "gpt2S" would be the ReLU variant

    # Server side: parameters of the malicious imprint modification.
    imprint = cfg.case.server.param_modification
    server_settings = {
        "v_length": args.v_length,  # length of the sentence component
        "imprint_blocks": args.imprint_blocks,
        "replace_word": user.replace_word,  # read back from the config, not from args
        "emb_scale": args.emb_scale,
        "imprint_target_word": args.imprint_target_word,
        "mask_target_word": args.mask_target_word,
        "mask_target_word_attack": args.mask_target_word_attack,
        "sep_by_group": args.sep_by_group,
        "target_word_measurement": args.target_word_measurement,
        "target_emb_scale": args.target_emb_scale,
        "target_word_comb": args.target_word_comb,
    }
    for key, value in server_settings.items():
        setattr(imprint, key, value)

    # Attacker side.
    attack = cfg.attack
    attack.sentence_algorithm = args.sentence_alg
    attack.km_n_init = args.km_n_init
    attack.km_max_iter = args.km_max_iter
    attack.lsa_partition = args.lsa_partition

    # Bookkeeping.
    cfg.name = args.exp_name
    cfg.num_trials = args.num_trials
    return cfg



def _target_batch_size(target_word_comb, num_target_words, comb_copy):
    """Return the number of targeted sentences for a target-word combination mode.

    - ``'or'``:  ``(2 ** num_target_words - 1) * comb_copy`` (presumably one
      sentence per non-empty subset of the target words, repeated ``comb_copy``
      times — confirm against the user-side data construction).
    - ``'and'``: ``comb_copy``.

    Raises:
        ValueError: for any other mode, instead of leaving the size undefined.
    """
    if target_word_comb == 'or':
        return (2 ** num_target_words - 1) * comb_copy
    if target_word_comb == 'and':
        return comb_copy
    raise ValueError(f"Unsupported target_word_comb {target_word_comb!r}; expected 'or' or 'and'.")


def main_process(process_idx, local_group_size, cfg, num_trials=100):
    """This function controls the central routine.

    Attacks users one by one until ``num_trials`` users with enough, correctly
    shaped data have been evaluated, then logs and saves the averaged metrics.

    Args:
        process_idx: forwarded to ``breaching.utils.system_startup``.
        local_group_size: forwarded to ``breaching.utils.system_startup``.
        cfg: the full experiment config.
        num_trials: number of users to evaluate; replaced by ``cfg.num_trials``
            when that is not None.

    Raises:
        ValueError: if ``target_word_comb`` is neither 'or' nor 'and', or if an
            attack/evaluation trial fails (the original exception is chained).
    """
    total_time = time.time()  # Rough time measurements here
    setup = breaching.utils.system_startup(process_idx, local_group_size, cfg)

    target_batch_size = _target_batch_size(
        cfg.case.server.param_modification.target_word_comb,
        len(cfg.case.user.replace_word),
        cfg.case.user.comb_copy,
    )

    model, loss_fn = breaching.cases.construct_model(cfg.case.model, cfg.case.data, cfg.case.server.pretrained)

    if cfg.num_trials is not None:
        num_trials = cfg.num_trials

    server = breaching.cases.construct_server(model, loss_fn, cfg.case, setup)

    # Ugly hack: build a throw-away user only to hand its tokenizer to the server.
    tmp_usr = breaching.cases.construct_user(model, loss_fn, cfg.case, setup)
    server.set_tokenizer(tmp_usr.dataloader.dataset.tokenizer)

    model = server.vet_model(model)
    attacker = breaching.attacks.prepare_attack(model, loss_fn, cfg.attack, setup)
    if cfg.case.user.user_idx is not None:
        print("The argument user_idx is disregarded during the benchmark. Data selection is fixed.")
    log.info(
        f"Partitioning is set to {cfg.case.data.partition}. Make sure there exist {num_trials} users in this scheme."
    )

    # NOTE(review): the index is incremented before every attempt, so the first user tried is 2.
    cfg.case.user.user_idx = 1
    run = 0
    overall_metrics = []
    while run < num_trials:
        local_time = time.time()
        # Select data that has not been seen before:
        cfg.case.user.user_idx += 1

        if args.fix_replace_per_user:
            # One random replacement start position in [0, right_limit) for this user.
            cfg.case.user.replace_start_idx = np.random.randint(0, cfg.case.user.right_limit)
            print(f'Replace start index for user {cfg.case.user.user_idx} is {cfg.case.user.replace_start_idx}')

        try:
            user = breaching.cases.construct_user(model, loss_fn, cfg.case, setup)
        except ValueError:
            log.info("Cannot find other valid users. Finishing benchmark.")
            break
        if cfg.case.data.modality == "text":
            dshape = user.dataloader.dataset[0]["input_ids"].shape
            data_shape_mismatch = any(d != d_ref for d, d_ref in zip(dshape, cfg.case.data.shape))
        else:
            data_shape_mismatch = False  # Handled by preprocessing for images
        if len(user.dataloader.dataset) < user.num_data_points or data_shape_mismatch:
            log.info(f"Skipping user {user.user_idx} (has not enough data or data shape mismatch).")
            continue

        log.info(f"Now evaluating user {user.user_idx} in trial {run}.")
        print(f"Now evaluating user {user.user_idx} in trial {run}.")
        run += 1
        # Run exchange
        shared_user_data, payloads, true_user_data = server.run_protocol(user)
        # Evaluate attack:
        try:
            reconstruction, stats = attacker.reconstruct(
                payloads, shared_user_data, server.secrets, dryrun=cfg.dryrun
            )

            # Run the full set of metrics:
            log.info(f"size of targeted sentences: {target_batch_size}")
            metrics = breaching.analysis.report(
                reconstruction,
                true_user_data,
                payloads,
                server.model,
                order_batch=True,
                compute_full_iip=True,
                compute_rpsnr=True,
                compute_ssim=True,
                cfg_case=cfg.case,
                setup=setup,
                target_batch_size=target_batch_size,
            )
            log.info(f'{metrics}')
            print(f'----------------------------metric: {metrics}')
            # Add query metrics
            metrics["queries"] = user.counted_queries

            # Save local summary:
            breaching.utils.save_summary(cfg, metrics, stats, time.time() - local_time, original_cwd=False)
            overall_metrics.append(metrics)
            # Save recovered data:
            if cfg.save_reconstruction:
                breaching.utils.save_reconstruction(reconstruction, payloads, true_user_data, cfg)
            if cfg.dryrun:
                break
        except Exception as e:  # noqa # yeah we're that close to the deadlines
            # Log first (this line used to sit unreachably after the raise), then
            # abort the benchmark keeping the original error as the cause.
            print(f'exception {e}')
            log.info(f"Trial {run} broke down with error {e}.")
            raise ValueError(f"Trial {run} broke down with error {e}.") from e

    if not overall_metrics:
        # No finished trial: nothing to average and no `stats` to save.
        log.info("No trial finished successfully; skipping average metrics and summary.")
        return

    # Compute average statistics:
    average_metrics = breaching.utils.avg_n_dicts(overall_metrics)
    print('----------------------------')
    print(f'----------------------------average metric: {average_metrics}')

    # Save global summary (`stats` comes from the last finished trial):
    if not args.debug:
        breaching.utils.save_summary(
            cfg, average_metrics, stats, time.time() - total_time, original_cwd=True, table_name="BENCHMARK_breach"
        )
        breaching.utils.save_summary(
            cfg, average_metrics, stats, time.time() - total_time, original_cwd=False, table_name="BENCHMARK_breach"
        )


def main_launcher():
    """This is boiler-plate code for the launcher.

    Builds the experiment config from the breaching defaults plus the
    command-line ``args``, sets up logging, and runs ``main_process``.

    Raises:
        ValueError: if ``--model`` contains none of 'gpt', 'transformer', 'bert'
            (previously this surfaced as an UnboundLocalError).
    """
    # Model file treatment: pick the case file for the model family.
    if 'gpt' in args.model or 'transformer' in args.model:
        case_file = '10_causal_lang_training'
    elif 'bert' in args.model:
        case_file = '9_bert_training'
    else:
        raise ValueError(
            f"Unsupported model {args.model!r}: expected a name containing 'gpt', 'transformer' or 'bert'."
        )

    user_case_file = 'local_updates' if args.num_local_updates > 1 else 'local_gradient'

    # NOTE: a dedicated DP attack config ('decepticon_dp') is not used; runs
    # with --grad_noise use the regular attack config as well.
    attack_case_file = 'decepticon'

    server_case_file = 'malicious-panning-transformer' if args.imprint_target_word else 'malicious-transformer'
    cfg = breaching.get_config(overrides=[
        f"attack={attack_case_file}",
        f"case={case_file}",
        f"case/server={server_case_file}",
        f"case/user={user_case_file}",
    ])
    cfg = args_to_cfg(cfg)
    cfg = add_constant_param(cfg)
    cfg = add_model_specfic_param(cfg)

    # Set up logger. os.path.join keeps the default empty --base_dir relative
    # ("run/<name>.log") instead of the absolute "/run/<name>.log" produced by
    # plain string formatting, and the directory is created when missing.
    log_dir = os.path.join(cfg.base_dir, "run")
    os.makedirs(log_dir, exist_ok=True)
    fileHandler = logging.FileHandler(os.path.join(log_dir, f"{cfg.name}.log"), mode='w')
    log.setLevel(level=logging.INFO)
    log.addHandler(fileHandler)
    log.addHandler(logging.StreamHandler())

    log.info("--------------------------------------------------------------")
    log.info("-----Launching federating learning breach experiment! --------")

    launch_time = time.time()
    if cfg.seed is None:
        cfg.seed = 233  # The benchmark seed is fixed by default!

    log.info(OmegaConf.to_yaml(cfg))
    main_process(0, 1, cfg)

    log.info("-------------------------------------------------------------")
    log.info(
        f"Finished computations {cfg.name} with total train time: "
        f"{str(datetime.timedelta(seconds=time.time() - launch_time))}"
    )
    log.info("-----------------Job finished.-------------------------------")


# Launch the benchmark only when executed as a script (note: the CLI arguments
# are already parsed at import time above).
if __name__ == "__main__":
    main_launcher()
