from utils.utils import set_seed
import pickle
import sys
set_seed(11)
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM
import os
from transformers import logging
from llm_logger import main_logger
from src.better_tasks import get_preprocessed_dataset
import numpy as np
from attacks.mia_utils import get_losses
import torch
import argparse
from pathlib import Path

def parse_args():
    """Parse the command-line options for this script.

    Returns:
        argparse.Namespace: carries ``init_checkpoint`` (str or None),
        ``output_dir``, ``task_name``, ``data_cache_dir`` (str, each
        defaulting to ``"results"``) and ``batch_size`` (int, default 64).
    """
    cli = argparse.ArgumentParser(description='')
    cli.add_argument('--init_checkpoint', type=str, default=None, help='initial checkpoint')

    # These three string options are declared identically, default included.
    for flag in ('--output_dir', '--task_name', '--data_cache_dir'):
        cli.add_argument(flag, type=str, default="results")

    cli.add_argument('--batch_size', type=int, default=64)

    namespace = cli.parse_args()
    return namespace




if __name__ == '__main__':
    # Script entry point: load a preprocessed dataset, re-tokenize it for the
    # checkpoint under evaluation, compute losses for the train / val / z
    # splits and save them (plus run metadata) under
    # <output_dir>/<task_name>/<checkpoint name>.
    args = parse_args()
    if args.init_checkpoint is None:
        # Fail fast with a clear message instead of an obscure error from
        # from_pretrained(None) further down.
        raise SystemExit("--init_checkpoint is required")

    main_logger.info(f"Path to save at: {args.output_dir}")

    logging.set_verbosity_info()

    task_name = args.task_name.lower()

    # `tokenizer` is the one handed to get_preprocessed_dataset;
    # `pretrained_tokenizer` belongs to the checkpoint being evaluated.
    tokenizer = AutoTokenizer.from_pretrained(
        "EleutherAI/pythia-1b",
        use_fast=True
    )
    pretrained_tokenizer = AutoTokenizer.from_pretrained(
        args.init_checkpoint,
        use_fast=True
    )

    # Both tokenizers pad with EOS. The checkpoint tokenizer must keep this
    # setting because it is called with padding="max_length" below, so it is
    # loaded and configured exactly once (re-loading it would discard the
    # pad token assignment).
    tokenizer.pad_token = tokenizer.eos_token
    pretrained_tokenizer.pad_token = pretrained_tokenizer.eos_token

    # NOTE(review): the positional arguments are passed through unchanged;
    # their meaning is defined by get_preprocessed_dataset — confirm there.
    output = get_preprocessed_dataset(task_name, args.data_cache_dir,
                                      tokenizer, 256,
                                      0, "none", 10, 10, 0.01,
                                      z_ratio=0.1
                                      )

    main_logger.info("Dataset loaded.")
    print(f"{task_name}")
    print(len(output['train_tokens']))
    print(len(output['val_tokens']))
    print(len(output['z_tokens']))

    # Try to load the checkpoint as a PEFT model first and fall back to a
    # plain causal LM. Only the load itself is guarded, so Ctrl-C and
    # unrelated failures (e.g. moving the model to the GPU) are not mistaken
    # for "not a PEFT checkpoint".
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(args.init_checkpoint)
    except Exception as err:
        main_logger.info(
            f"Could not load {args.init_checkpoint} as a PEFT model ({err!r}); "
            "falling back to AutoModelForCausalLM."
        )
        model = AutoModelForCausalLM.from_pretrained(args.init_checkpoint)
    else:
        model.print_trainable_parameters()
    model = model.to("cuda")

    # Strip trailing slashes so a path like "ckpts/run1/" still yields
    # "run1" rather than an empty directory name.
    checkpoint_name = args.init_checkpoint.rstrip("/").split("/")[-1]
    output_dir = Path(args.output_dir) / task_name / checkpoint_name
    output_dir.mkdir(parents=True, exist_ok=True)

    with torch.no_grad():
        # Save the losses
        for name in ['train_tokens', 'val_tokens', 'z_tokens']:
            # Decode with the dataset tokenizer and re-encode with the
            # checkpoint's tokenizer before computing losses.
            texts = tokenizer.batch_decode(output[name], skip_special_tokens=True)
            tokens = pretrained_tokenizer(
                texts,
                padding="max_length",
                truncation=True,
                max_length=256,
                return_tensors='pt',
            )
            losses = get_losses(model, tokens, args.batch_size).numpy(force=True)
            np.save(output_dir / f"losses_{name}.npy", losses)

    # Record how the run was launched.
    # NOTE(review): this persists the full process environment, which may
    # include secrets (API tokens, credentials) — consider filtering it.
    with open(output_dir / "info.pkl", "wb") as info_file:
        pickle.dump({
            'environ': dict(os.environ),
            'argv': sys.argv,
        }, info_file)

    main_logger.info(f"Losses calculated for {args.init_checkpoint}")