import torch
import numpy as np
from dp_transformers import DataCollatorForPrivateCausalLanguageModeling
from transformers import (
    Trainer,
    get_constant_schedule, 
    DataCollatorForLanguageModeling
)
from src import logger

from src.audit.trainers.opacus_trainers import FixedOpacusDPTrainer
from src.audit.trainers.callbacks import PPLCallback, BlackBoxAuditorCallback
from pathlib import Path


def _build_optimizers(model, training_args, model_args):
    """Return the ``(optimizer, lr_scheduler)`` pair handed to the trainer.

    If ``model_args.constant_scheduler`` is truthy, build AdamW over
    ``model.parameters()`` with ``training_args.learning_rate`` and a constant
    LR schedule on top of it.  Otherwise return ``(None, None)`` so the trainer
    builds its own default optimizer/scheduler; the optimizer is not created
    at all in that case, since it would be discarded.
    """
    if not model_args.constant_scheduler:
        return None, None
    optimizer = torch.optim.AdamW(model.parameters(), training_args.learning_rate)
    return optimizer, get_constant_schedule(optimizer)


def get_trainer(
    model,
    train_dataset,
    val_dataset,
    tokenizer,
    model_args,
    training_args,
    privacy_args,
    canary_indices,
    list_of_canary_datasets,
    black_box_audit: bool
    ):
    """Build a (DP or non-DP) trainer for causal language modeling.

    Args:
        model: Model to train; passed straight to the trainer.
        train_dataset: Training data. It is given to the trainer only when
            ``training_args.do_train`` is set and may be ``None`` otherwise.
        val_dataset: Evaluation data. It is given to the trainer only when
            ``training_args.do_eval`` is set.
        tokenizer: Tokenizer used by the data collator and the trainer.
        model_args: Reads ``constant_scheduler``. When truthy, a custom AdamW
            optimizer and a constant LR schedule are supplied to the trainer.
        training_args: Trainer arguments. Reads ``learning_rate``,
            ``output_dir``, ``do_train`` and ``do_eval``.
        privacy_args: Reads ``disable_dp``, ``target_delta`` and
            ``target_epsilon``. It is also forwarded to the DP trainer.
        canary_indices: Forwarded to ``BlackBoxAuditorCallback``.
        list_of_canary_datasets: Forwarded to ``BlackBoxAuditorCallback``.
        black_box_audit: If True, attach a ``BlackBoxAuditorCallback`` that
            writes scores under ``<output_dir>/scores``.

    Returns:
        Tuple ``(trainer, model, tokenizer)``; ``model`` and ``tokenizer`` are
        the same objects that were passed in.
    """
    optimizers = _build_optimizers(model, training_args, model_args)

    # PPLCallback is passed as a class; HF's CallbackHandler instantiates it.
    callbacks = [PPLCallback]
    if black_box_audit:
        callbacks.append(BlackBoxAuditorCallback(
            save_dir=Path(training_args.output_dir) / "scores",
            canary_indices=canary_indices,
            list_of_canary_datasets=list_of_canary_datasets,
        ))

    # Arguments shared by both trainer flavours.
    # NOTE(review): newer transformers releases deprecate `tokenizer=` in
    # favour of `processing_class=` — confirm against the pinned version.
    common_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=val_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        optimizers=optimizers,
        callbacks=callbacks,
    )

    if privacy_args.disable_dp:
        logger.debug("Non-DP trainer chosen")
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        logger.debug("Normal data collator loaded")
        trainer = Trainer(data_collator=data_collator, **common_kwargs)
    else:
        logger.info("DP trainer chosen")
        data_collator = DataCollatorForPrivateCausalLanguageModeling(tokenizer)
        logger.info("DP data collator loaded")
        # Guard: an eval-only caller may pass train_dataset=None, and
        # len(None) would raise TypeError here.
        if train_dataset is not None:
            logger.info(f"Train dataset size: {len(train_dataset)}")
        logger.info(f"Delta: {privacy_args.target_delta}")
        logger.info(f"Epsilon: {privacy_args.target_epsilon}")

        trainer = FixedOpacusDPTrainer(
            privacy_args=privacy_args,
            data_collator=data_collator,
            **common_kwargs,
        )
    return trainer, model, tokenizer