import json
import torch
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, BitsAndBytesConfig, AutoModelForCausalLM
from .trainer import BackdoorTrainer
from .dataset import load_training_data, BackdoorData
from loguru import logger
import bitsandbytes as bnb 


def _build_loader(dataset, batch_size, collate_fn):
    """Wrap *dataset* in an unshuffled, pinned-memory ``DataLoader``."""
    # NOTE(review): shuffle=False is used for the training split as well;
    # confirm that a fixed sample order is intentional.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        collate_fn=collate_fn,
    )


def train(args):
    """Load a 4-bit quantized causal LM and run ``BackdoorTrainer.fit`` on it.

    Args:
        args: Configuration object. This function reads
            ``model_name_or_path``, ``conv_type``, ``max_token_size``,
            ``batch_size``, ``learning_rate``, ``max_epochs``, ``patience``
            and ``gradient_accumulation_steps`` from it, and also passes the
            whole object to ``load_training_data``.

    Returns:
        Whatever ``BackdoorTrainer.fit`` returns (the training log).
    """
    # The module-level imports never bring this helper into scope, so the
    # original call raised NameError. It is imported here, next to its use.
    from peft import prepare_model_for_kbit_training

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        use_fast=False,
        trust_remote_code=True,
    )

    # 4-bit NF4 weights with double quantization; matmuls computed in fp16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
    )
    model.config.use_cache = False
    # NOTE(review): PEFT documents this helper as freezing the base weights,
    # and no adapter is attached in this function -- confirm that
    # BackdoorTrainer (or the caller) makes some parameters trainable.
    model = prepare_model_for_kbit_training(model)

    # NOTE(review): padding=False means the collator will not pad inputs;
    # presumably BackdoorData yields equal-length examples or batch_size is
    # 1 -- verify against the dataset implementation.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=-100,
        pad_to_multiple_of=None,
        padding=False,
    )

    train_data, test_data = load_training_data(args)
    train_data = BackdoorData(train_data, tokenizer, args.conv_type, args.max_token_size)
    test_data = BackdoorData(test_data, tokenizer, args.conv_type, args.max_token_size)

    backdoor_train = _build_loader(train_data, args.batch_size, data_collator)
    backdoor_test = _build_loader(test_data, args.batch_size, data_collator)

    optimizer = bnb.optim.adamw.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        is_paged=True,
    )
    # loss_fn=None is passed through unchanged; how the trainer derives the
    # loss is defined in BackdoorTrainer, not here.
    trainer = BackdoorTrainer(model, loss_fn=None, optimizer=optimizer)

    train_log = trainer.fit(
        train_data=backdoor_train,
        val_data=backdoor_test,
        epochs=args.max_epochs,
        patience=args.patience,
        monitor='val_loss',
        mode='min',
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )

    return train_log
    


