import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from src.audit.arguments import get_args
from src import logger
import torch
import wandb
from src.tasks import get_preprocessed_dataset  
from src.audit.train import get_trainer
from datasets import concatenate_datasets, Dataset
from typing import List
import random
import gc
torch.backends.cudnn.benchmark = True


def prepare_canary_dataset(canary_set: List[Dataset], dataset: Dataset, original_version: bool = False, seed: int = 42):
    """Append randomly selected canaries to ``dataset``.

    Two selection schemes are supported:

    * ``original_version=True``: every entry of ``canary_set`` is included,
      in full, independently with probability 0.5. The returned indices are
      the positions in ``canary_set`` of the entries that were included.
    * ``original_version=False`` (default): exactly one row is drawn
      uniformly at random from every entry of ``canary_set``. The returned
      list holds, for each entry in order, the row index drawn from it.

    The module-level ``random`` generator is re-seeded with ``seed`` on every
    call, so the same seed always produces the same selection.

    Args:
        canary_set: Sequence of canary datasets to draw from.
        dataset: Base training set the selected canaries are appended to.
        original_version: Selects the scheme described above.
        seed: Seed passed to ``random.seed``.

    Returns:
        A ``(dataset, selected_indices)`` pair: the base dataset followed by
        the selected canaries (in ``canary_set`` order) with torch format
        applied, and the list of indices described above.

    Raises:
        ValueError: In the default scheme, if an entry of ``canary_set`` has
            no rows (``random.randint`` is asked for an empty range).
    """
    random.seed(seed)
    selected_indices = []
    # Pieces to append; gathered first so the dataset is concatenated once
    # instead of being rebuilt for every single canary.
    additions = []

    if original_version:
        logger.info("Original version (Steinke et al. 2023) of the canary dataset")
        selected_indices = [idx for idx in range(len(canary_set)) if random.random() < 0.5]
        additions = [canary_set[idx] for idx in selected_indices]
    else:
        # Guard the log line: an empty canary_set must not raise IndexError.
        cardinality = len(canary_set[0]) if canary_set else 0
        logger.info(f"Our version of the canary dataset with cardinality ({cardinality})")
        for i, s in enumerate(canary_set):
            selected_idx = random.randint(0, len(s) - 1)
            selected_indices.append(selected_idx)
            # Keep a single row of this canary dataset.
            additions.append(s.select([selected_idx]))
            logger.debug(f"canary {i + 1}: selected row {selected_idx}")

    if additions:
        dataset = concatenate_datasets([dataset, *additions])

    if original_version:
        logger.info(f"{len(selected_indices)} canaries added to the training set")

    logger.info(f"Train dataset: {dataset}")

    return dataset.with_format(type="torch"), selected_indices



# Script entry point: builds a training set with canaries inserted, trains
# the model through the project trainer, saves the result, and logs the run
# (and the final privacy budget when DP is enabled) to Weights & Biases.
if __name__ == '__main__':
    # Reset CUDA allocator caches and memory statistics before the run.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()
    gc.collect()
    os.environ["WANDB_DISABLED"] = "false"

    # get_args() yields four argument groups, unpacked below.
    args = get_args()
    logger.info(args)
    data_args, model_args, training_args, privacy_args = args
    logger.info(f"Path to save at: {training_args.output_dir}")
    logger.info(f"Training/evaluation parameters {training_args}")

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=True
    )
    # Reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token
    logger.info(f"Tokenizer loaded: {tokenizer}")

    canaries_dataset_parameters = {
        "dataset": data_args.canary_dataset_name,
        "canaries_setup": (data_args.cardinality, data_args.canary_types)
    }

    # Load the training set.
    train_set = get_preprocessed_dataset(
        data_args.train_dataset_name,
        tokenizer=tokenizer,
        max_length=data_args.max_seq_length
    )

    # Load the canary datasets and keep only the first number_of_canaries.
    list_of_canary_datasets = get_preprocessed_dataset(
        **canaries_dataset_parameters,
        tokenizer=tokenizer,
        max_length=data_args.max_seq_length
    )[:data_args.number_of_canaries]

    logger.info(f"Train dataset size: {len(train_set)}, Canary dataset size: {len(list_of_canary_datasets)} (canaries{[len(c) for c in list_of_canary_datasets]})")

    # val_set must exist even when evaluation is disabled: it is read by the
    # log line below and passed to get_trainer(). Without this default a
    # train-only run fails with NameError.
    val_set = None
    if data_args.eval_dataset_name and training_args.do_eval:
        val_set = get_preprocessed_dataset(
            data_args.eval_dataset_name,
            tokenizer=tokenizer,
            max_length=data_args.max_seq_length
        )
    elif training_args.do_eval:
        # No dedicated eval set: hold out 10% of the training data.
        logger.warning("Splitting train dataset into train and eval")
        train_set = train_set.train_test_split(test_size=0.1)
        train_set, val_set = train_set['train'], train_set['test']

    if data_args.only_canaries:
        # Train on canaries alone: start from an empty dataset.
        train_set = Dataset.from_dict({})

    logger.info(f"Dataset loaded. Train dataset length: {len(train_set)}, Validation dataset length: {len(val_set) if val_set else 0}")

    train_set, canary_indices = prepare_canary_dataset(
        list_of_canary_datasets,
        train_set,
        original_version=data_args.original_version,
        seed=42)

    logger.info("Canary dataset added to the training set")
    logger.info(f"Canary indices: {canary_indices}")

    # The full canary datasets are also handed to the trainer, in torch format.
    list_of_canary_datasets = [
        cd.with_format(type="torch") for cd in list_of_canary_datasets
    ]
    logger.info(f"Train dataset length after adding canaries: {len(train_set)}")

    # Flatten every argument group into one dict for the wandb run config.
    # Groups are merged in this order, so a later group wins on a name clash.
    config = {}
    for arg_group in (data_args, model_args, training_args, privacy_args):
        for arg in vars(arg_group):
            config[arg] = getattr(arg_group, arg)

    method = "fft"

    config["method"] = method
    if config['disable_dp']:
        # Without DP, report the privacy budget as infinite.
        config['target_epsilon'] = 'inf'

    wandb_run_name = f"{config['number_of_canaries']}-{config['cardinality']}-{config['model_name_or_path']}_{config['train_dataset_name']}_{config['method']}_lr_{config['learning_rate']}_epsilon_{config['target_epsilon']}"

    wandb.login(key=os.getenv("WANDB_API_TOKEN"))
    wandb.init(project="canary_mia_audit", name=wandb_run_name, config=config)

    # Load the model and move it to the GPU.
    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path).to("cuda")
    logger.info(f"Model loaded: {model}")

    trainer, model, tokenizer = get_trainer(
        model,
        train_set,
        val_set,
        tokenizer=tokenizer,
        model_args=model_args,
        training_args=training_args,
        privacy_args=privacy_args,
        canary_indices=canary_indices,
        list_of_canary_datasets=list_of_canary_datasets,
        black_box_audit=data_args.black_box_audit
    )

    output_dir = training_args.output_dir

    if training_args.do_train:
        logger.info("***** Running training *****")
        train_result = trainer.train()
        metrics = train_result.metrics

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

        if training_args.do_eval:
            eval_metrics = trainer.evaluate(eval_dataset=val_set)
            trainer.log_metrics("eval", eval_metrics)
            trainer.save_metrics("eval", eval_metrics)

        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)

        logger.info(f"Model successfully saved at {output_dir}")

        if not privacy_args.disable_dp:
            # Report the epsilon spent, as given by the trainer's two
            # epsilon getters, alongside the target delta.
            eps_prv = trainer.get_prv_epsilon()
            eps_rdp = trainer.get_rdp_epsilon()
            logger.info(f"final_epsilon_prv: {eps_prv}")
            logger.info(f"final_epsilon_rdp: {eps_rdp}")
            wandb.log({"final_epsilon_prv": eps_prv})
            wandb.log({"final_epsilon_rdp": eps_rdp})
            wandb.log({"final_delta": privacy_args.target_delta})
            