import torch
import numpy as np
from dp_transformers import DataCollatorForPrivateCausalLanguageModeling
from transformers import (
    Trainer,
    get_constant_schedule,
    AutoConfig, 
    DataCollatorForLanguageModeling
)
import random
import os
import transformers


from llm_logger import train_logger

from .models.utils import get_model

from .trainers.opacus_trainers import FixedOpacusDPTrainer
from .trainers.utils import PPLCallback
from .better_tasks import get_preprocessed_dataset, set_seed


def get_trainer(train_dataset, val_dataset, tokenizer, model_args, training_args, privacy_args, text_data, text_data_none, max_seq_length):
    """Build the model and the trainer (DP or non-DP) that will train it.

    Args:
        train_dataset: Training dataset. Only handed to the trainer when
            ``training_args.do_train`` is set, but its length is always read
            in the DP branch to derive ``target_delta``.
        val_dataset: Evaluation dataset. Only handed to the trainer when
            ``training_args.do_eval`` is set.
        tokenizer: Tokenizer passed to the data collator, the perplexity
            callback and the trainer.
        model_args: Must expose ``model_name_or_path``, ``model_revision``
            and ``constant_scheduler``; also forwarded to ``get_model``.
        training_args: Trainer arguments; ``learning_rate``, ``do_train``
            and ``do_eval`` are read here.
        privacy_args: Must expose ``disable_dp``. In the DP branch its
            ``target_delta`` attribute is OVERWRITTEN in place.
        text_data: Forwarded to ``PPLCallback`` and, in the DP branch, to
            ``FixedOpacusDPTrainer`` (presumably the text used for
            perplexity evaluation -- confirm against those classes).
        text_data_none: Forwarded to ``PPLCallback`` only.
        max_seq_length: Forwarded to ``PPLCallback`` and, in the DP branch,
            to ``FixedOpacusDPTrainer``.

    Returns:
        A ``(trainer, model, tokenizer)`` tuple. ``tokenizer`` is the same
        object that was passed in.

    Raises:
        ZeroDivisionError: In the DP branch, if ``train_dataset`` is empty
            (``target_delta`` is computed as ``1 / len(train_dataset)``).
    """
    # The config is loaded purely so it can be logged; the model itself is
    # created by get_model below.
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
    )
    train_logger.debug(config)
    model, _ = get_model(model_args)

    # Only build a custom optimizer/scheduler pair when it will actually be
    # used; otherwise (None, None) is passed and the trainer is left to
    # choose its own.
    if model_args.constant_scheduler:
        optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate)
        optimizers = (optimizer, get_constant_schedule(optimizer))
    else:
        optimizers = (None, None)

    ppl_callback = PPLCallback(
        text_data,
        text_data_none,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        model=model,
    )

    # Keyword arguments common to both trainer flavours, kept in one place so
    # the DP and non-DP paths cannot drift apart.
    shared_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=val_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        callbacks=[ppl_callback],
        optimizers=optimizers,
    )

    if privacy_args.disable_dp:
        train_logger.debug("Non-DP trainer chosen")
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        train_logger.debug("Normal data collator loaded")
        trainer = Trainer(data_collator=data_collator, **shared_kwargs)
    else:
        train_logger.debug("DP trainer chosen")
        data_collator = DataCollatorForPrivateCausalLanguageModeling(tokenizer)
        train_logger.debug("DP data collator loaded")
        # Delta is set to 1/N from the training-set size; this replaces any
        # value the caller had already put on privacy_args.
        privacy_args.target_delta = 1 / len(train_dataset)
        train_logger.debug(len(train_dataset))
        train_logger.debug(f"Delta: {privacy_args.target_delta}")
        trainer = FixedOpacusDPTrainer(
            privacy_args=privacy_args,
            data_collator=data_collator,
            text_data=text_data,
            max_seq_length=max_seq_length,
            **shared_kwargs,
        )

    return trainer, model, tokenizer
