"""Command-line entry point: fine-tune a model and dump per-example losses.

When run as a script, this file parses the model/data/training/privacy
arguments, builds the tokenized train/validation datasets, logs the run to
Weights & Biases, optionally trains and evaluates through the project's
trainer, and finally saves the losses returned by
``attacks.mia_utils.get_losses`` for the train, validation and ``z`` token
sets next to a small ``info.pkl`` describing the invocation.
"""
import numpy as np
from utils.utils import set_seed
import pickle
import sys
# The seed is fixed here, before the remaining imports are executed.
# NOTE(review): set_seed presumably seeds the Python/NumPy/torch RNGs -- confirm
# in utils.utils. Keep this position unless the ordering is known not to matter.
set_seed(11)
from transformers import AutoTokenizer
import os
from transformers.trainer_utils import get_last_checkpoint
from transformers import logging
from arguments import get_args
from llm_logger import main_logger
from src import get_trainer
from src.better_tasks import get_preprocessed_dataset
import datasets
from attacks.mia_utils import get_losses
import torch
import wandb
# Enable cuDNN's benchmark mode (auto-selection of kernels for the observed
# input shapes). NOTE(review): this can reduce run-to-run reproducibility even
# with a fixed seed -- confirm that is acceptable for this experiment.
torch.backends.cudnn.benchmark = True


def load_wandb_token(paths):
    """Return the first non-empty W&B API token found in ``paths``.

    ``~`` is expanded in every candidate before it is opened, because
    ``open`` does not expand it on its own. Files that cannot be read, or
    that are empty after stripping whitespace, are skipped.

    Args:
        paths: Iterable of candidate file paths, tried in order.

    Returns:
        The stripped content of the first usable file, or ``None`` when no
        candidate yields a token.
    """
    for path in paths:
        try:
            with open(os.path.expanduser(path), 'r') as f:
                token = f.read().strip()
        except (OSError, UnicodeDecodeError):
            # Best effort: a missing/unreadable candidate is not an error.
            continue
        if token:
            return token
    return None


def build_run_config(*arg_groups):
    """Flatten the attributes of several argument objects into one dict.

    Args:
        *arg_groups: Objects whose instance attributes (``vars(obj)``) are
            copied. Later groups overwrite equally named keys of earlier ones.

    Returns:
        A new ``dict`` mapping attribute name to attribute value.
    """
    config = {}
    for group in arg_groups:
        for arg in vars(group):
            config[arg] = getattr(group, arg)
    return config


def resolve_finetuning_method(config):
    """Return the fine-tuning method name selected by the flags in ``config``.

    The flags are checked in the order ``lora``, ``prefix``, ``last_layer``;
    the first one equal to ``True`` wins. If none is set the method is
    ``"fft"``.

    Raises:
        KeyError: If a flag that has to be inspected is missing.
    """
    for flag in ("lora", "prefix", "last_layer"):
        if config[flag] == True:  # noqa: E712 -- keep the script's equality semantics
            return flag
    return "fft"


def retokenize(tokenizer, token_ids, max_seq_length):
    """Decode ``token_ids`` to text and tokenize it again with fixed padding.

    Args:
        tokenizer: Tokenizer offering ``batch_decode`` and ``__call__``.
        token_ids: Batch of token id sequences to decode.
        max_seq_length: Length every sequence is padded/truncated to.

    Returns:
        Whatever the tokenizer call returns (requested with
        ``return_tensors='pt'``).
    """
    texts = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    return tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_seq_length,
        return_tensors='pt',
    )


def main():
    """Run the whole pipeline: data, W&B setup, training, loss dump."""
    args = get_args()
    main_logger.info(args)
    model_args, data_args, training_args, privacy_args = args
    main_logger.info(f"Path to save at: {training_args.output_dir}")

    logging.set_verbosity_info()

    # Log on each process the small summary (the two halves are now joined
    # with ", " so the fields no longer run into each other).
    main_logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    main_logger.info(f"Training/evaluation parameters {training_args}")

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=True,
        revision=model_args.model_revision,
    )
    # Reuse the EOS token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    def load_dataset(prefix_type):
        # Both dataset variants differ only in the prefix type.
        return get_preprocessed_dataset(
            data_args.task_name.lower(), data_args.data_cache_dir,
            tokenizer, data_args.max_seq_length,
            data_args.shadow_id, data_args.topk,
            prefix_type, data_args.prefix_length, data_args.ratio_change,
            z_ratio=0.1,
        )

    output = load_dataset(data_args.prefix_type)
    output_none = load_dataset('none')

    train_set = datasets.Dataset.from_dict(
        retokenize(tokenizer, output['train_tokens'], data_args.max_seq_length)
    )
    val_set = datasets.Dataset.from_dict(
        retokenize(tokenizer, output['val_tokens'], data_args.max_seq_length)
    )
    main_logger.info(f"Dataset loaded. Train dataset length: {len(train_set)}, Validation dataset length: {len(val_set)}")

    # Create the W&B config from all argument groups.
    config = build_run_config(model_args, data_args, training_args, privacy_args)
    config["method"] = resolve_finetuning_method(config)
    if config['disable_dp']:
        config['target_epsilon'] = 'inf'
    wandb_run_name = f"{config['model_name_or_path']}_{config['task_name']}_{config['method']}_lr_{config['learning_rate']}_epsilon_{config['target_epsilon']}_{config['shadow_id']}"

    paths = [
        "~/wandb_config.ini"
    ]
    print('paths:', paths)
    wandb_token = load_wandb_token(paths)
    if wandb_token is None:
        # key=None lets wandb use its own credential lookup instead of the
        # script dying on an unbound variable.
        main_logger.warning(
            "No W&B token file could be read; falling back to wandb's default credential lookup."
        )
    wandb.login(key=wandb_token)
    wandb.init(project="llm_audit", name=wandb_run_name, config=config)

    # get_trainer hands back the tokenizer as well; from here on that one is used.
    trainer, model, tokenizer = get_trainer(
        train_set,
        val_set,
        tokenizer=tokenizer,
        model_args=model_args,
        training_args=training_args,
        privacy_args=privacy_args,
        text_data=output,
        text_data_none=output_none,
        max_seq_length=data_args.max_seq_length,
    )

    # Refuse to silently reuse a non-empty output directory.
    last_checkpoint = None
    output_dir = training_args.output_dir
    if os.path.isdir(output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(output_dir)
        if last_checkpoint is None and len(os.listdir(output_dir)) > 0:
            raise ValueError(
                f"Output directory ({output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            main_logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
    # NOTE(review): last_checkpoint is only logged; it is not handed to
    # trainer.train() below, so training does not visibly resume from it.
    # Confirm whether the trainer supports `resume_from_checkpoint` (and how
    # resuming interacts with the privacy accounting) before wiring it in.

    if training_args.do_train:
        main_logger.info("***** Running training *****")
        train_result = trainer.train()
        metrics = train_result.metrics

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

        # Run evaluation after training and save the model
        eval_metrics = trainer.evaluate(eval_dataset=val_set)
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)

        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)

        main_logger.info(f"Model successfully saved at {output_dir}")
        if not privacy_args.disable_dp:
            eps_prv = trainer.get_prv_epsilon()
            eps_rdp = trainer.get_rdp_epsilon()
            main_logger.info(f"final_epsilon_prv: {eps_prv}")
            main_logger.info(f"final_epsilon_rdp: {eps_rdp}")

    # The directory may not exist yet when no training (and thus no save) ran.
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        # Save the losses
        for name in ['train_tokens', 'val_tokens', 'z_tokens']:
            tokens = retokenize(tokenizer, output[name], data_args.max_seq_length)
            losses = get_losses(model, tokens, training_args.per_device_eval_batch_size).numpy(force=True)
            np.save(f'{output_dir}/losses_{name}.npy', losses)

    # NOTE(review): the full process environment is persisted here; it may
    # contain secrets (API keys, tokens). Consider filtering before sharing
    # the output directory.
    with open(f"{output_dir}/info.pkl", "wb") as info_file:
        pickle.dump({
            'environ': dict(os.environ),
            'argv': sys.argv,
        }, info_file)

    main_logger.info("Training completed")


if __name__ == '__main__':
    main()