#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------

# From the Differentially Private Fine-tuning paper
import argparse

import torch
# Raise the summarisation threshold so tensors with up to 100k elements are
# printed in full instead of being elided with "...".
torch.set_printoptions(threshold=100_000)



from peft import AutoPeftModelForCausalLM

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

import logging
import tqdm


import csv


def str2bool(value):
    """Convert a command-line string such as 'true'/'false', '1'/'0', 'yes'/'no' to a bool.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is ``True`` because any
    non-empty string is truthy, so an explicit ``--flag False`` would parse as True.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='PyTorch GPT2 beam decoding')
parser.add_argument('--eval_len', type=int, default=64,
                    help='evaluation length')
parser.add_argument('--min_length', type=int, default=0,
                    help='minimum generation length')
parser.add_argument('--init_checkpoint', default=None, type=str, help='initial checkpoint')
parser.add_argument('--dataset_type', type=str, default='samsum', 
                    help='samsum')
parser.add_argument('--prediction_file', type=str, default='predictions.txt', 
                    help='output file name')
# Boolean flags use str2bool so that e.g. "--load_4_bit False" really yields False.
parser.add_argument('--load_4_bit', type=str2bool, default=True, help='loading the model in 4 bits')
parser.add_argument('--load_mixed_model', type=str2bool, default=False, help='loading the mixed model')
parser.add_argument('--ft', type=str2bool, default=False, help='use ft model')
parser.add_argument('--seq2seq', type=str2bool, default=False, help='use seq2seq model')
# NOTE(review): help text duplicates --ft; confirm the intended meaning of --mixed.
parser.add_argument('--mixed', type=str2bool, default=False, help='use ft model')


def func_evaluate(model, tokenized_inputs, tokenizer, args, *, batch_size=64,
                  max_new_tokens=48, device="cuda", num_prompts=None):
    """Generate a continuation for every prompt row and return only the new tokens.

    Args:
        model: model exposing ``eval()`` and ``generate(input_ids=..., max_new_tokens=...)``.
        tokenized_inputs: 2-D tensor of prompt token ids, one prompt per row.
        tokenizer: not used here; kept for call-site compatibility.
        args: not used here; kept for call-site compatibility.
        batch_size: number of prompts passed to ``generate`` at once.
        max_new_tokens: upper bound on tokens generated per prompt.
        device: device each prompt batch is moved to before generation.
        num_prompts: evaluate only the first ``num_prompts`` rows; ``None`` means all rows.

    Returns:
        List of 1-D numpy arrays, one per evaluated prompt, holding the tokens that
        follow the prompt. A row may hold fewer than ``max_new_tokens`` tokens if
        generation stops early.
    """
    model.eval()
    print(tokenized_inputs.shape)

    total = len(tokenized_inputs)
    if num_prompts is not None:
        total = min(num_prompts, total)
    # generate() echoes the prompt, so strip exactly the prompt width from its output.
    prompt_len = tokenized_inputs.shape[1]

    all_predictions = []
    with torch.no_grad():
        for start in tqdm.tqdm(range(0, total, batch_size)):
            # Clamp the slice so the final batch never runs past `total`.
            stop = min(start + batch_size, total)
            batch = tokenized_inputs[start:stop].to(device)
            generated = model.generate(input_ids=batch, max_new_tokens=max_new_tokens)
            all_predictions.extend(generated[:, prompt_len:].cpu().numpy())

    return all_predictions

def get_prompt_dataset(file_path, tokenizer):
    """Tokenize each line of a text file into its own sequence of token ids.

    Args:
        file_path: path to a UTF-8 text file; every line (blank ones included) is one prompt.
        tokenizer: object with a Hugging Face style ``encode`` method.

    Returns:
        List with one entry per line: the result of
        ``tokenizer.encode(line.strip(), add_special_tokens=False, return_tensors='pt')``.
    """
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) would
    # mis-decode non-ASCII text.
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    return [
        tokenizer.encode(line.strip(), add_special_tokens=False, return_tensors='pt')
        for line in lines
    ]


def _load_model(checkpoint):
    """Load ``checkpoint`` as a 4-bit causal LM, falling back to a PEFT model on failure."""
    # Equivalent to the loose `load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16`
    # kwargs (bitsandbytes defaults otherwise), passed via the supported argument.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    try:
        return AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config)
    except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit propagate
        print(f"4-bit load of {checkpoint} failed ({err!r}); trying PEFT model instead")
    lm_net = AutoPeftModelForCausalLM.from_pretrained(checkpoint)
    lm_net = lm_net.to("cuda")
    lm_net.print_trainable_parameters()
    return lm_net


def _count_memorized(predictions, suffixes):
    """Return how many predictions reproduce their reference suffix token-for-token."""
    # Memorization means the whole suffix is regenerated; one coinciding token is
    # not evidence of it. A length mismatch (e.g. early stop) can never match.
    return sum(
        1
        for pred, ref in zip(predictions, suffixes)
        if len(pred) == len(ref) and (pred == ref).all()
    )


def _append_result(output_file_path, model_name, num_equal):
    """Append a (model, count) row to a CSV, writing the header only for a new/empty file."""
    import os

    out_dir = os.path.dirname(output_file_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    write_header = (not os.path.exists(output_file_path)
                    or os.path.getsize(output_file_path) == 0)
    with open(output_file_path, mode='a', newline='') as file:
        writer = csv.writer(file)
        if write_header:
            writer.writerow(['Model_Name', 'Num_Equal'])
        writer.writerow([model_name, num_equal])


def main(model_name, output_file_path):
    """Count verbatim-memorized Pile suffixes for one checkpoint and append the count to a CSV.

    Loads ``./datasets/pile_bs0-100-dedup.pt``. The first 32 tokens of each row are
    the prompt and the remaining tokens are the reference suffix. Continuations come
    from ``func_evaluate``, and a row counts as memorized when its continuation
    equals the suffix exactly.

    Args:
        model_name: checkpoint directory; also the label written to the CSV.
            A ``--init_checkpoint`` flag on the command line overrides which
            checkpoint is loaded (the CSV label stays ``model_name``).
        output_file_path: CSV file that receives a ``Model_Name, Num_Equal`` row.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--init_checkpoint", type=str, default=model_name)
    # parse_known_args: ignore flags meant for other parsers instead of exiting.
    args, _ = parser.parse_known_args()

    # NOTE(review): exact matching assumes each row has prompt_len + 48 tokens, so the
    # suffix length equals func_evaluate's default max_new_tokens — confirm for the dataset.
    prompt_len = 32
    data = torch.load("./datasets/pile_bs0-100-dedup.pt")
    prompt = data[:, :prompt_len]
    suffix = data[:, prompt_len:].numpy()

    # Suppress transformers warnings
    logging.getLogger('transformers').setLevel(logging.ERROR)

    tokenizer = AutoTokenizer.from_pretrained(args.init_checkpoint)
    lm_net = _load_model(args.init_checkpoint)

    print('Model sampling...')
    predictions = func_evaluate(lm_net, prompt, tokenizer, args)

    num_equal = _count_memorized(predictions, suffix)
    print(f"Amount of memorized prompts: {num_equal} of {len(predictions)}")

    _append_result(output_file_path, model_name, num_equal)




if __name__ == '__main__':
    import os

    # Evaluate every DP adaptation x privacy budget combination that exists on disk.
    eps = (1, 8, 50, 100, 1000)
    adaptations = ('prefix', 'lora', 'fft')
    MODELS = [f'dp_{adapt}_lr_8e-4_eps_{ep}_epochs_20_model_pythia-1b' for adapt in adaptations for ep in eps]

    model_root = 'adaptations/dp'
    # List the model directory once instead of once per candidate model.
    available = set(os.listdir(model_root))

    for model in MODELS:
        if model in available:
            main(f'{model_root}/{model}', "results/PILE_mem.csv")
        else:
            print(f"Model {model} not found")
