from contamination import load_model, LongestCommonSubstring, ROUGE, Perplexity, load_tokenizer, Lowercase, TopKMin, CDD, TopKMinPlusPlus, Recall
import pandas as pd
import numpy as np
from contamination import InstructionProcessor, FinetuneInstructions
import torch
import os
import datasets
from transformers import pipeline
from tqdm import tqdm
import re
import gc
from transformers import set_seed
import torch.nn as nn
import datasets
from contamination import GSM8K, MMLU, TruthfulQA, ARC, get_max_length
import ast
from loguru import logger
from transformers import StopStringCriteria

# Fix the global random seed (via transformers.set_seed) as soon as the module is loaded.
set_seed(42)

# Prefix used by main() when building model repository names; main() also
# creates a local directory with this name before calling save_pretrained.
huggingface_username = 'Anonymous'

def prompt_template(instruction, input_):
    """
    Build the training prompt for one sample.

    The prompt always ends with an ``### Input:`` section followed by an empty
    ``### Response:`` section. An ``### Instruction:`` section is put in front
    only when the instruction is non-empty.

    Args:
        instruction (str): The instruction text; may be empty.
        input_ (str): The input text of the sample.

    Returns:
        str: The assembled prompt.

    """
    has_instruction = len(instruction) != 0
    header = f'### Instruction:\n{instruction}\n\n' if has_instruction else ''
    return header + f'### Input:\n{input_}\n\n### Response:\n'

def generation_prompt_template(instruction, input_):
    """
    Build the prompt used at generation time.

    Only the input is rendered; the instruction argument is accepted so the
    signature matches the other prompt template, but it is not used.

    Args:
        instruction (str): Ignored.
        input_ (str): The input text of the sample.

    Returns:
        str: The prompt, ending in an empty ``### Response:`` section.
    """
    return '### Input:\n{}\n\n### Response:\n'.format(input_)

def finetune(df, test_df, model_name, output_path, max_tokens=1024,
             n_test_samples=1000, total_samples=10 ** 6,
             n_epochs=1, path_to_config='configs/config_finetune.json',
             seed=42, train_on_extra_samples=True, original_test_df=None, task=None,
             few_shot_samples=None, batch_size=None, generative_batch_size=1, max_tokens_generative=256):
    """
    Fine-tunes a model on background data mixed with (possibly contaminating) test samples.

    Two modes exist:

    * ``original_test_df is None``: the first ``n_test_samples`` rows of ``test_df`` are
      repeated ``n_epochs`` times, optionally mixed with background rows from ``df``,
      shuffled, and trained on in a single pass.
    * ``original_test_df`` given: for each of ``n_epochs`` rounds the current model answers
      the original test samples, and only the rows of ``test_df`` whose score is below 0.5
      are trained on.

    Args:
        df (pandas.DataFrame): The background training data; rows are taken from the front.
        test_df (pandas.DataFrame): The test data to train on. If it has an
            'is_contaminated' column, only rows where it is False are kept.
        model_name (str): The name of the model to be fine-tuned.
        output_path (str): The path handed to the trainer as output directory.
        max_tokens (int, optional): The maximum number of tokens in each input sequence. Defaults to 1024.
        n_test_samples (int, optional): The number of test samples to use per epoch. Defaults to 1000.
        total_samples (int, optional): The total number of samples to use for training. Defaults to 10 ** 6.
        n_epochs (int, optional): How often the test samples are repeated / how many adaptive rounds are run. Defaults to 1.
        path_to_config (str, optional): The path to the configuration file. Defaults to 'configs/config_finetune.json'.
        seed (int, optional): The random state used for shuffling the training rows. The global
            seed is always set to 42 regardless of this value. Defaults to 42.
        train_on_extra_samples (bool, optional): Whether background rows from ``df`` are added. Defaults to True.
        original_test_df (pandas.DataFrame, optional): Original test samples used to decide which
            rows still need training; enables the adaptive mode. Defaults to None.
        task (object, optional): Object providing ``compute_performance``; required in adaptive mode. Defaults to None.
        few_shot_samples (pandas.DataFrame, optional): Few-shot samples used when generating in adaptive mode. Defaults to None.
        batch_size (int, optional): Not used by this function. Defaults to None.
        generative_batch_size (int, optional): The batch size for generation in adaptive mode. Defaults to 1.
        max_tokens_generative (int, optional): The maximum number of generated tokens in adaptive mode. Defaults to 256.

    Returns:
        The fine-tuned model.

    Raises:
        ValueError: If ``original_test_df`` is given without a ``task``.
    """
    set_seed(42)
    processor = InstructionProcessor(max_tokens=max_tokens, prompt_template=prompt_template,
                                     include_eos=True)
    # Named `finetuner` so the local does not shadow this function's own name.
    finetuner = FinetuneInstructions(preprocessor=processor,
                                     num_train_epochs=1,
                                     config_file=path_to_config,
                                     output_dir=output_path)
    n_test_samples_per_epoch = int(n_test_samples)
    if 'is_contaminated' in test_df.columns:
        # Only keep samples whose contamination flag is False; apply the same
        # positional mask to the original samples so both stay aligned.
        keep_mask = np.array(test_df['is_contaminated'] == False)
        if original_test_df is not None:
            original_test_df = original_test_df[keep_mask]
        test_df = test_df[keep_mask]

    samples_test = test_df[:n_test_samples_per_epoch]

    if original_test_df is None:
        # Epochs are emulated by repeating the test samples; the trainer itself runs one epoch.
        samples_test = pd.concat([samples_test] * n_epochs)
        n_train_samples = int(total_samples - n_epochs * n_test_samples_per_epoch)
        # A non-positive budget must mean "no background rows": slicing with a
        # negative stop would instead select almost the whole background set.
        if train_on_extra_samples and n_train_samples > 0:
            all_samples = pd.concat([df[:n_train_samples], samples_test])
        else:
            all_samples = samples_test
        all_samples = all_samples.sample(frac=1, random_state=seed).reset_index(drop=True)
        return finetuner.finetune(model_name, all_samples)

    if task is None:
        raise ValueError('`task` is required when `original_test_df` is given.')

    # Work on a copy: the 'generated' / 'complete_inputs' columns are rewritten every round.
    samples_test_original = original_test_df[:n_test_samples_per_epoch].copy()
    model, tokenizer = load_model(model_name, return_tokenizer=True)
    for _ in range(n_epochs):
        # generate_samples skips generation when a 'generated' column exists,
        # so the previous round's answers have to be dropped first.
        if 'generated' in samples_test_original.columns:
            del samples_test_original['generated']
        samples_test_original['generated'] = generate_samples(
            model, tokenizer, samples_test_original, None, generative_batch_size,
            max_tokens_generative, few_shot_samples, prompt_template)
        performance = task.compute_performance(samples_test_original)['score']
        indices_requiring_training = performance < 0.5
        samples_train = samples_test[indices_requiring_training].sample(frac=1, random_state=seed).reset_index(drop=True)
        logger.info(f'Number of samples requiring training: {len(samples_train)}')
        model = finetuner.finetune(model_name, samples_train, model=model)

    return model

def _trained_mask(n_rows, n_test_trained, is_contaminated=None, second_way=False):
    """Return a boolean array marking which of the `n_rows` samples were trained on."""
    if is_contaminated is None or second_way:
        # Positional and exclusive: exactly the first n_test_trained rows.
        return np.arange(n_rows) < n_test_trained
    # Only rows whose flag equals False count as clean; the first
    # n_test_trained clean rows are the ones that were trained on.
    clean = np.asarray(np.asarray(is_contaminated) == False, dtype=bool)
    return clean & (np.cumsum(clean) <= n_test_trained)


def _min_perplexity_per_row(perplexity, answer_lists, inputs, batch_size):
    """Score all candidate answers in one batched call and return the lowest score per row."""
    flat_answers, flat_inputs = [], []
    for answers, complete_input in zip(answer_lists, inputs):
        flat_answers.extend(answers)
        flat_inputs.extend([complete_input] * len(answers))
    scores = perplexity.batch_call(flat_answers, flat_inputs, batch_size=batch_size)
    minima = []
    start = 0
    for answers in answer_lists:
        stop = start + len(answers)
        minima.append(min(scores[start:stop]))
        start = stop
    return minima


def generate(model, tokenizer, df, output_dir, batch_size, few_shot_samples_recall,
             n_test_trained=1000, max_tokens=256, filename='generated.csv', few_shot_samples=None,
             prompt_template=generation_prompt_template, ref_model_name=None, multiple_choice=False, check_batch_size=8, is_contaminated=None,
             second_way=False):
    """
    Generates text using a given model and tokenizer. Evaluates various contamination detection methods on the generated text.

    If ``output_dir/filename`` already exists it is loaded and used instead of ``df``;
    metric columns already present in it are not recomputed (except where noted in the
    body). The resulting table is written back to the same file.

    Args:
        model (object): The pre-trained model used for text generation.
        tokenizer (object): The tokenizer used to tokenize the input text.
        df (pandas.DataFrame): The input dataframe containing instructions and inputs for text generation.
        output_dir (str): The directory where the generated texts will be saved.
        batch_size (int): The batch size for generating texts.
        few_shot_samples_recall (pandas.DataFrame): Samples with 'instruction', 'input' and 'answer'
            columns used to build the prefix for the recall metric.
        n_test_trained (int, optional): The number of samples to be marked as trained. Defaults to 1000.
        max_tokens (int, optional): The maximum number of tokens in the generated text. Defaults to 256.
        filename (str, optional): The name of the output file. Defaults to 'generated.csv'.
        few_shot_samples (pandas.DataFrame, optional): Few-shot samples used for text generation. Defaults to None.
        prompt_template (function, optional): The function used to generate prompts for text generation. Defaults to generation_prompt_template.
        ref_model_name (str, optional): The name of the reference model. Currently unused. Defaults to None.
        multiple_choice (bool, optional): Whether the task is multiple choice. Currently unused. Defaults to False.
        check_batch_size (int, optional): The batch size for checking perplexity and other metrics. Defaults to 8.
        is_contaminated (numpy.ndarray, optional): An array indicating whether each sample is contaminated. Defaults to None.
        second_way (bool, optional): If True, mark the first ``n_test_trained`` rows as trained even when
            ``is_contaminated`` is given. Defaults to False.

    Returns:
        None
    """
    set_seed(42)
    csv_path = os.path.join(output_dir, filename)
    if os.path.isfile(csv_path):
        df = pd.read_csv(csv_path)
    # Returns an empty list when `df` already carries a 'generated' column.
    generated_texts = generate_samples(model, tokenizer, df, output_dir, batch_size, max_tokens, few_shot_samples, prompt_template)

    os.makedirs(output_dir, exist_ok=True)

    new_df = df.copy()
    if len(generated_texts) > 0:
        new_df['generated'] = generated_texts

    # Persist the generations before the metric computations below.
    new_df.to_csv(csv_path, index=False)
    new_df['was_trained'] = _trained_mask(len(new_df), n_test_trained, is_contaminated, second_way)

    answers = new_df['answer'].tolist()
    complete_inputs = new_df['complete_inputs'].tolist()

    if 'perplexity' not in new_df.columns:
        perplexity = Perplexity(model, tokenizer)
        new_df['perplexity'] = perplexity.batch_call(new_df['generated'].tolist(), complete_inputs, batch_size=check_batch_size)
        new_df['perplexity_output'] = perplexity.batch_call(answers, complete_inputs, batch_size=check_batch_size)
        new_df['perplexity_input'] = perplexity.batch_call(new_df['input'].tolist(), batch_size=check_batch_size)

    if 'topkmin' not in new_df.columns:
        topkmin = TopKMin(model, tokenizer)
        new_df['topkmin'] = topkmin.batch_call(answers, complete_inputs, batch_size=check_batch_size)
        new_df['topkmin_all'] = topkmin.batch_call((new_df['complete_inputs'] + new_df['answer']).tolist(), batch_size=check_batch_size)
    # NOTE: the '...TODOREMOVE' names are not columns this function ever writes, so
    # the next three metrics are recomputed on every call even for a cached file.
    if 'topkminplusplusTODOREMOVE' not in new_df.columns:
        topkminplusplus = TopKMinPlusPlus(model, tokenizer)
        new_df['topkminplusplus'] = topkminplusplus.batch_call(answers, complete_inputs, batch_size=check_batch_size)
    if 'surprisingTODOREMOVE' not in new_df.columns:
        surprising = TopKMin(model, tokenizer, entropy=5)
        new_df['surprising'] = surprising.batch_call(answers, complete_inputs, batch_size=check_batch_size)
    if 'recallTODOREMOVE' not in new_df.columns:
        few_shot = '\n\n'.join(
            prompt_template(instruction, input_) + '\n' + output
            for instruction, input_, output in zip(few_shot_samples_recall['instruction'],
                                                   few_shot_samples_recall['input'],
                                                   few_shot_samples_recall['answer'])
        )
        few_shot += '\n\n'
        recall = Recall(model, tokenizer, prefix=few_shot)
        new_df['recall'] = recall.batch_call(answers, complete_inputs, batch_size=check_batch_size)
    if 'lowercase' not in new_df.columns:
        lowercase = Lowercase(model, tokenizer)
        new_df['lowercase'] = lowercase.batch_call(answers, complete_inputs, batch_size=check_batch_size)
    if 'perplexity_good' not in new_df.columns and 'correct_answers' in new_df.columns:
        perplexity = Perplexity(model, tokenizer)
        # The answer columns hold Python literals stored as strings; parse each cell once.
        good_answers = [ast.literal_eval(cell) for cell in new_df['correct_answers']]
        bad_answers = [ast.literal_eval(cell) for cell in new_df['incorrect_answers']]
        new_df['perplexity_good'] = _min_perplexity_per_row(perplexity, good_answers, complete_inputs, check_batch_size)
        new_df['perplexity_bad'] = _min_perplexity_per_row(perplexity, bad_answers, complete_inputs, check_batch_size)

    os.makedirs(output_dir, exist_ok=True)
    new_df.to_csv(csv_path, index=False)

def generate_samples(model, tokenizer, df, output_dir, batch_size, max_tokens, few_shot_samples, prompt_template):
    """
    Builds the prompts for every row of `df` and greedily generates a completion for each.

    Side effect: a 'complete_inputs' column (few-shot prefix + prompt) is written to `df`
    in place, whether or not anything is generated.

    Args:
        model (object): The model used for generation.
        tokenizer (object): The tokenizer matching the model.
        df (pandas.DataFrame): Samples with 'instruction' and 'input' columns.
        output_dir (str): Not used by this function.
        batch_size (int): The number of prompts per generation batch.
        max_tokens (int): The maximum number of new tokens per sample.
        few_shot_samples (pandas.DataFrame or None): Samples with 'instruction', 'input' and
            'answer' columns that are rendered in front of every prompt; None for no prefix.
        prompt_template (function): Maps (instruction, input) to a prompt string.

    Returns:
        list[str]: One generated text per row, or an empty list if `df` already has a
        'generated' column.
    """
    if few_shot_samples is None:
        few_shot = ''
    else:
        few_shot = '\n\n'.join(
            prompt_template(instruction, input_) + '\n' + output
            for instruction, input_, output in zip(few_shot_samples['instruction'],
                                                   few_shot_samples['input'],
                                                   few_shot_samples['answer'])
        )
        few_shot += '\n\n'
    df['complete_inputs'] = [few_shot + prompt_template(instruction, input_)
                             for instruction, input_ in zip(df['instruction'], df['input'])]

    # Generations from an earlier run are reused; nothing to do.
    if 'generated' in df.columns:
        return []

    generated_texts = []
    total_batches = int(np.ceil(len(df) / batch_size))
    max_length = get_max_length(model.config)
    stop_strings = ['### Input:', '### Response:']
    stopping_criteria = [StopStringCriteria(tokenizer, stop_strings=stop_strings)]
    for i in tqdm(range(total_batches), desc="Generating Texts"):
        batch_start = i * batch_size
        batch_end = batch_start + batch_size
        batch_texts = df['complete_inputs'].iloc[batch_start:batch_end].tolist()
        # Leave room for the completion inside the model's context window.
        inputs = tokenizer(batch_texts, return_tensors='pt', padding=True, truncation=True,
                           max_length=max_length - max_tokens).to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=max_tokens, num_return_sequences=1,
                                 do_sample=False, temperature=1, top_p=1, eos_token_id=tokenizer.eos_token_id,
                                 stopping_criteria=stopping_criteria)

        # Strip the prompt by token count rather than by len(prompt string): the
        # decoded text does not line up with the prompt string when the prompt was
        # truncated or when decoding does not reproduce the prompt exactly.
        prompt_length = inputs['input_ids'].shape[1]
        for output in outputs:
            text = tokenizer.decode(output[prompt_length:], skip_special_tokens=True)
            for stop_string in stop_strings:
                text = text.replace(stop_string, '')
            generated_texts.append(text.strip())
    return generated_texts

def _load_cached_model(repo):
    """Try to load an already saved model from `repo`; return None if it cannot be loaded."""
    last_error = None
    for trust_remote_code in (True, False):
        try:
            return load_model(repo, trust_remote_code=trust_remote_code)
        except Exception as error:
            last_error = error
    logger.info('Could not load model {} ({!r}); it will be fine-tuned instead.', repo, last_error)
    return None


def _load_eval_tokenizer(model_name):
    """Load the tokenizer for `model_name`, left-padded with EOS as pad token for batched generation."""
    tokenizer = load_tokenizer(model_name)
    tokenizer.padding_side = 'left'
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _evaluate_model(model, tokenizer, eval_dfs, output_dir, ref_repo, few_shot_samples_recall,
                    generative_batch_size, check_batch_size, multiple_choice, n_test_trained,
                    is_contaminated=None):
    """Run `generate` on the evaluated test sets (positions 0 and 4 of `eval_dfs`)."""
    for j, eval_df in enumerate(eval_dfs):
        # Only the original test set (0) and the entry at position 4 are evaluated.
        if j not in (0, 4):
            continue
        generate(model, tokenizer, eval_df, output_dir, generative_batch_size,
                 few_shot_samples_recall, n_test_trained=n_test_trained, filename=f'generated_{j}.csv',
                 ref_model_name=ref_repo,
                 multiple_choice=multiple_choice, check_batch_size=check_batch_size,
                 is_contaminated=is_contaminated)
        gc.collect()
        torch.cuda.empty_cache()


def main(model_name, total_samples, n_epochs, test_df, rephrased_test_dfs, dataset_name, task, few_shot_samples_recall, generative_batch_size=1,
         multiple_choice=False, check_batch_size=8, reruns_data=None, background='orca', proportion=1.0,
         only_ref=False, train_on_extra_samples=True, original_test_df=None):
    """
    Main function for performing the finetuning process.

    First an uncontaminated reference model is loaded (or fine-tuned on background data and
    saved) and evaluated. Unless ``only_ref`` is set, one contaminated model per training set
    (positions 0-2 of ``[test_df] + rephrased_test_dfs``) is then loaded or fine-tuned, saved
    and evaluated.

    Args:
        model_name (str): The name of the model.
        total_samples (int): The total number of samples.
        n_epochs (int): The number of epochs.
        test_df (pandas.DataFrame): The test dataset.
        rephrased_test_dfs (list): A list of rephrased test datasets.
        dataset_name (str): The name of the dataset.
        task (object): The task object forwarded to ``finetune``.
        few_shot_samples_recall (pandas.DataFrame): Few-shot samples forwarded to ``finetune`` and ``generate``.
        generative_batch_size (int, optional): The batch size for generative models. Defaults to 1.
        multiple_choice (bool, optional): Whether to use multiple choice. Defaults to False.
        check_batch_size (int, optional): The batch size for checking. Defaults to 8.
        reruns_data (None, optional): Reruns data. Defaults to None.
        background (str, optional): The background dataset: 'orca', 'platypus' or 'dolly'. Defaults to 'orca'.
        proportion (float, optional): Scales the number of test samples trained on. Defaults to 1.0.
        only_ref (bool, optional): Whether to only use reference. Defaults to False.
        train_on_extra_samples (bool, optional): Whether background samples are added when contaminating. Defaults to True.
        original_test_df (pandas.DataFrame, optional): Original test samples forwarded to ``finetune``. Defaults to None.
    """
    # Runs with the default of 5 epochs use the un-suffixed paths.
    epochs_save_path = '' if n_epochs == 5 else f'/epochs_{n_epochs}'

    ### PREPROCESSING DATA
    if background == 'orca':
        orca = datasets.load_dataset("Open-Orca/OpenOrca", split="train")
        # select at least 100000 samples
        orca = orca.shuffle(seed=42).select(range(max(100000, total_samples)))
        df = pd.DataFrame(orca)
        df = df.sample(frac=1, random_state=42).reset_index(drop=True)
        df = df.rename(columns={'question': 'input', 'system_prompt': 'instruction', 'response': 'output'})
        # Assign the result: an in-place fillna on a selected column is not
        # guaranteed to modify `df` itself.
        df['instruction'] = df['instruction'].fillna('')
    elif background == 'platypus':
        platypus = datasets.load_dataset("garage-bAInd/Open-Platypus", split='train')
        df = pd.DataFrame(platypus)
        df['input'] = df.apply(lambda row: (row['instruction'] + '\n' + row['input']).strip(), axis=1)
        df['instruction'] = ''
    elif background == 'dolly':
        dolly = datasets.load_dataset("databricks/databricks-dolly-15k", split='train')
        df = pd.DataFrame(dolly)
        df = df.rename(columns={'context': 'input', 'response': 'output'})
        df['input'] = df.apply(lambda row: (row['instruction'] + '\n' + row['input']).strip(), axis=1)
        df['instruction'] = ''
    path_to_config = 'configs/config_finetune.json'
    if os.path.isfile(f'configs/{model_name}.json'):
        path_to_config = f'configs/{model_name}.json'

    eval_dfs = [test_df] + rephrased_test_dfs
    if '-instruct' in model_name:
        huggingface_repo_ref = model_name
    else:
        huggingface_repo_ref = f'{huggingface_username}/' + f'attacks-{model_name}-seed-0'.replace('/', '-')

    ### FINETUNING NOT CONTAMINATED MODEL
    for seed in range(1):
        if '-instruct' in model_name:
            huggingface_repo = model_name
        else:
            huggingface_repo = f'{huggingface_username}/' + f'attacks-{model_name}-seed-{seed}'.replace('/', '-')
        seed_output_dir = f'output/{model_name}/seed/{seed}'
        model = _load_cached_model(huggingface_repo)
        if model is None:
            model = finetune(df, test_df, model_name, seed_output_dir,
                             n_test_samples=0, total_samples=total_samples,
                             n_epochs=1, seed=seed, path_to_config=path_to_config,
                             few_shot_samples=few_shot_samples_recall,
                             batch_size=generative_batch_size, task=task, original_test_df=original_test_df,
                             generative_batch_size=generative_batch_size)
            os.makedirs(seed_output_dir, exist_ok=True)
            os.makedirs(huggingface_username, exist_ok=True)
            # model.push_to_hub(huggingface_repo, private=True, organization=None)
            model.save_pretrained(huggingface_repo)
        # Same as for the contaminated models below: evaluate in eval mode.
        model.eval()
        tokenizer = _load_eval_tokenizer(model_name)
        ### EVALUATING ON TEST DATA
        _evaluate_model(model, tokenizer, eval_dfs, f'{seed_output_dir}/{dataset_name}', huggingface_repo_ref,
                        few_shot_samples_recall, generative_batch_size, check_batch_size, multiple_choice,
                        n_test_trained=0)

        del model, tokenizer
        gc.collect()
        torch.cuda.empty_cache()

    if only_ref:
        return

    # FINETUNING ON CONTAMINATED DATA
    for i, train_test in enumerate(eval_dfs):
        if i == 3 or i == 4:  # Dont train on clean eval
            continue
        if reruns_data is None and train_test is None:
            continue
        n_test_samples = int(0.5 * len(test_df) * proportion)

        # The most specific variant wins: new-tactic, then noextra, then proportion.
        repo_stem = f'{huggingface_username}/' + f'attacks-{model_name}-test-{dataset_name}{epochs_save_path}-{i}'.replace('/', '-')
        base_dir = f'output/{model_name}/test/{dataset_name}{epochs_save_path}'
        if original_test_df is not None:
            huggingface_repo = f'{repo_stem}-new-tactic'
            path = f'{base_dir}/{i}-new-tactic'
        elif not train_on_extra_samples:
            huggingface_repo = f'{repo_stem}-noextra'
            path = f'{base_dir}/{i}-noextra'
        else:
            huggingface_repo = repo_stem
            if proportion != 1.0:
                path = f'output/{model_name}/test/{proportion}-{dataset_name}{epochs_save_path}/{i}'
            else:
                path = f'{base_dir}/{i}'

        model = _load_cached_model(huggingface_repo)
        if model is None:
            model = finetune(df, train_test, model_name, path,
                             n_test_samples=n_test_samples, total_samples=total_samples,
                             n_epochs=n_epochs, path_to_config=path_to_config,
                             train_on_extra_samples=train_on_extra_samples,
                             few_shot_samples=few_shot_samples_recall,
                             batch_size=generative_batch_size, task=task, original_test_df=original_test_df,
                             generative_batch_size=generative_batch_size)
            os.makedirs(huggingface_username, exist_ok=True)
            # model.push_to_hub(huggingface_repo, private=True, organization=None)
            model.save_pretrained(huggingface_repo)
        model.eval()
        tokenizer = _load_eval_tokenizer(model_name)
        is_contaminated = None
        if train_test is not None and 'is_contaminated' in train_test.columns:
            is_contaminated = np.array(train_test['is_contaminated'])

        ### EVALUATING ON TEST DATA
        _evaluate_model(model, tokenizer, eval_dfs, path, huggingface_repo_ref,
                        few_shot_samples_recall, generative_batch_size, check_batch_size, multiple_choice,
                        n_test_trained=n_test_samples, is_contaminated=is_contaminated)

        del model
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == '__main__':
    # Command-line entry point: prepare the test sets for one benchmark and run main().
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='microsoft/phi-2')
    parser.add_argument('--dataset_name', type=str, default='gsm8k')
    parser.add_argument('--generative_batch_size', type=int, default=32)
    parser.add_argument('--check_batch_size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--use_reruns', action='store_true')
    parser.add_argument('--background', type=str, default='orca')
    parser.add_argument('--proportion', type=float, default=1.0)
    parser.add_argument('--only-ref', action='store_true')
    parser.add_argument('--no-extra-samples', action='store_true')
    parser.add_argument('--new-tactic', action='store_true')

    args = parser.parse_args()
    total_samples = int(2.5 * 10 ** 4)
    n_epochs = args.epochs
    tasks = {
        'gsm8k': GSM8K(),
        'truthfulqa': TruthfulQA(),
        'arc': ARC(),
        'mmlu': MMLU(),
    }
    task = tasks.get(args.dataset_name, None)
    if task is None:
        # Report the problem instead of exiting silently without doing anything.
        parser.error(f'unknown --dataset_name {args.dataset_name!r}; expected one of {sorted(tasks)}')

    test_df, few_shot_samples = task.prepare_test_data(f'data/{task.dataset_name}/original.csv', filter_gsm8k=False, num_few_shot_samples=5)
    test_df_rephrase_1, _ = task.prepare_test_data(f'data/{task.dataset_name}/rephrased1.csv', filter_gsm8k=True, num_few_shot_samples=5)
    test_df_rephrase_2, _ = task.prepare_test_data(f'data/{task.dataset_name}/rephrased2.csv', filter_gsm8k=True, num_few_shot_samples=5)
    # A rephrased sample counts as contaminated if either overlap detector flags it.
    overlap_df = pd.read_csv(f'data/{task.dataset_name}/overlap_2.csv')
    overlap_df['is_contaminated'] = np.logical_or(overlap_df['llm_decontaminator'] == True, overlap_df['ngram'] >= 5)
    # NOTE(review): the first 5 overlap rows are skipped, presumably the few-shot samples; the
    # assignment aligns on the dataframe index — confirm both frames share that index.
    test_df_rephrase_2['is_contaminated'] = overlap_df.iloc[5:]['is_contaminated']
    test_df_rephrase_4, _ = task.prepare_test_data(f'data/{task.dataset_name}/clean_eval_2.csv', filter_gsm8k=False, num_few_shot_samples=5)

    original_test_df = None
    if args.new_tactic:
        original_test_df = test_df.copy()
    if args.use_reruns:
        reruns_data = [
            task.prepare_test_data(f'data/{task.dataset_name}/rephrased_rerun_{i}.csv', filter_gsm8k=False, num_few_shot_samples=5)[0] for i in range(4)
        ]
    else:
        reruns_data = None
    # The None entry keeps the clean-eval set at position 4 of [test_df] + rephrased_test_dfs,
    # which main() relies on when selecting what to train and evaluate on.
    main(args.model_name, total_samples, n_epochs, test_df, [test_df_rephrase_1, test_df_rephrase_2, None, test_df_rephrase_4], task.dataset_name, task,
         few_shot_samples, args.generative_batch_size,
         multiple_choice=task.is_multiple_choice, original_test_df=original_test_df,
         check_batch_size=args.check_batch_size, reruns_data=reruns_data,
         background=args.background, proportion=args.proportion, only_ref=args.only_ref,
         train_on_extra_samples=not args.no_extra_samples)