import os
import logging

from transformers import AutoTokenizer, AutoModelForCausalLM

from trl import get_kbit_device_map, get_quantization_config

from redflag.configs import RedFlagConfig, RedFlagScriptArguments, RedFlagModelConfig
from redflag.sft_trainer import RedFlagTrainer, RedFlagComputeMetrics, SeqLossWeighting
from redflag.sft_trainer_utils import initialize_embedding
from redflag.data import SampledRedflagDataCollatorCompletions, load_dataset
from redflag.data_utils import formatting_prompts_func, InsertIdxSampler
from redflag.utils import configure_rank_zero_logging, get_token_conf
from redflag.peft_utils import get_peft_config_with_trainable_tokens, embeddings_are_tied

# Module-level logger named after this module; run() uses it for progress messages.
logger = logging.getLogger(__name__)


def run(
    script_args: RedFlagScriptArguments,
    training_args: RedFlagConfig,
    model_config: RedFlagModelConfig,
) -> None:
    """Load the model and data, build a ``RedFlagTrainer`` and run training.

    The steps, in order:

    1. Configure logging and look up the token configuration for the model.
    2. Load the tokenizer and the causal LM (optionally quantized).
    3. Optionally re-initialize the red-flag token embedding and build the
       KL / cross-entropy loss weighting objects.
    4. Load and tokenize the train / eval datasets and build the collator.
    5. Build the PEFT config and the trainer, then train (optionally
       resuming from the latest checkpoint).
    6. Save the EMA model if the trainer decides to, and push to the Hub
       when ``training_args.push_to_hub`` is set.

    Args:
        script_args: Dataset names, insert-sampler settings, ``drop_rf_proba``
            and the restart / resume flags.
        training_args: Trainer configuration, including the loss coefficients
            forwarded to ``RedFlagTrainer``.
        model_config: Model name, revision, dtype, attention implementation,
            quantization and optional ``pad_token_id`` / ``init_embed``.

    Raises:
        ValueError: If ``script_args.train_datasets`` is empty; training is
            always started, so a training dataset is required.
    """
    configure_rank_zero_logging()
    token_conf = get_token_conf(model_config.model_name_or_path)

    quantization_config = get_quantization_config(model_config)
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=model_config.torch_dtype,
        quantization_config=quantization_config,
    )
    # Only set device_map in single-process runs; in distributed, let Trainer/Accelerate handle devices
    if quantization_config is not None and world_size == 1:
        model_kwargs["device_map"] = get_kbit_device_map()

    # Only add use_cache for non-Gemma models
    if not model_config.model_name_or_path.startswith("google/gemma"):
        model_kwargs["use_cache"] = not training_args.gradient_checkpointing

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code,
        use_fast=True,
        padding_side="right",  # TODO: is this okay for every model
    )
    # Prefer an explicitly configured pad token; otherwise reuse EOS as padding.
    if model_config.pad_token_id is not None:
        tokenizer.pad_token_id = model_config.pad_token_id
    else:
        tokenizer.pad_token = tokenizer.eos_token
    logger.info("Using pad token: %s : %s", tokenizer.pad_token_id, tokenizer.pad_token)

    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path,
        **model_kwargs,
    )

    if getattr(model_config, "init_embed", None) is not None:
        logger.info("Updating the embedding initialization for token %s", token_conf.rf_token_id)
        model = initialize_embedding(model, token_conf.rf_token_id, **model_config.init_embed)

    # loss weighting schemes
    kl_weighting, xent_weighting = None, None
    if getattr(training_args, "kl_weighting", None) is not None:
        logger.info("Using KL weighting: %s", training_args.kl_weighting)
        kl_weighting = SeqLossWeighting(**training_args.kl_weighting)

    if getattr(training_args, "xent_weighting", None) is not None:
        logger.info("Using XENT weighting: %s", training_args.xent_weighting)
        xent_weighting = SeqLossWeighting(**training_args.xent_weighting)

    ################
    # Dataset
    ################
    train_dataset = load_dataset(script_args.train_datasets, tokenizer) if script_args.train_datasets else None
    eval_dataset = load_dataset(script_args.eval_datasets, tokenizer) if script_args.eval_datasets else None

    # trainer.train() is called unconditionally below, so fail early with a
    # clear message instead of an AttributeError on ``None.map``.
    if train_dataset is None:
        raise ValueError("script_args.train_datasets is empty; a training dataset is required to run training.")

    # manually apply chat template + tokenize, otherwise an extra BOS token is prepended
    def _to_ids(example):
        return tokenizer(
            example["prompt"] + example["completion"],  # produced earlier by apply_chat_template
            add_special_tokens=False,  # <- critical
            truncation=True,
            max_length=training_args.max_length,
        )

    logger.info("Tokenizing datasets...")
    train_dataset = train_dataset.map(_to_ids, batched=False)
    # Evaluation data is optional; when present it is a mapping of name -> dataset.
    if eval_dataset is not None:
        eval_dataset = {k: ds.map(_to_ids, batched=False) for k, ds in eval_dataset.items()}

    adv_attack_config = getattr(training_args, "adv_attack", None)
    insert_sampler = InsertIdxSampler.create_from_config(**script_args.insert_sampler)
    collator = SampledRedflagDataCollatorCompletions(
        token_conf.response_keyword,  # gets converted to response_token_ids in base class
        rf_token_id=token_conf.rf_token_id,
        insert_sampler=insert_sampler,
        tokenizer=tokenizer,
        user_token_ids=token_conf.user_token_ids,
        drop_rf_proba=script_args.drop_rf_proba,
        return_adv_tensors=adv_attack_config is not None,
        adv_prefill_length=adv_attack_config.get("prefill_length", 24) if adv_attack_config else None,
    )

    ################
    # Training
    ################
    compute_metrics = RedFlagComputeMetrics(
        token_conf.rf_token_id,
        token_conf.response_token_ids,
        topk=5,
    )

    # Always train a single embedding token
    trainable_token_idx = [token_conf.rf_token_id]
    peft_config = get_peft_config_with_trainable_tokens(
        model_config, trainable_token_idx, train_lm_head=not embeddings_are_tied(model)
    )
    trainer = RedFlagTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset if training_args.eval_strategy != "no" else None,
        compute_metrics=compute_metrics.compute_metrics,
        processing_class=tokenizer,
        peft_config=peft_config,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        utility_loss_mode=training_args.utility_loss_mode,
        rf_token_id=token_conf.rf_token_id,
        rf_xent_mode=training_args.rf_xent_mode,
        alpha_rf_xent=training_args.alpha_rf_xent,
        alpha_kl_redflag=training_args.alpha_kl_redflag,
        alpha_kl_ref=training_args.alpha_kl_ref,
        alpha_away_rf=training_args.alpha_away_rf,
        away_rf_cutoff=training_args.away_rf_cutoff,
        rf_xent_cutoff=training_args.rf_xent_cutoff,
        drop_prompt_attn_mask_prob=training_args.drop_prompt_attn_mask_prob,
        ref_model=None,  # should make a copy of the base model for training
        ref_model_init_kwargs=None,
        # default to copying PEFT model as reference
        use_base_model_as_ref=getattr(training_args, "use_base_model_as_ref", True),
        copy_base_model_as_ref=getattr(training_args, "copy_base_model_as_ref", True),
        kl_fix=getattr(training_args, "kl_fix", True),
        kl_weighting=kl_weighting,
        xent_weighting=xent_weighting,
        adv_attack=adv_attack_config,
        ema_config=getattr(training_args, "ema_config", None),
    )

    # Resume when this is a restarted job or when resuming was explicitly requested.
    resume_from_checkpoint = bool((script_args.restart_count > 0) or script_args.resume_checkpoint)

    # %s rather than %d: resume_checkpoint is only used as a truthy flag and
    # may not be an integer, which would make %d formatting fail inside logging.
    logger.info("script_args.resume_checkpoint:\t%s", script_args.resume_checkpoint)
    logger.info("script_args.restart_count:\t%s", script_args.restart_count)

    if resume_from_checkpoint:
        logger.info("Resuming from the latest checkpoint...")
        trainer.train(resume_from_checkpoint=True)
    else:
        logger.info("Starting training from scratch...")
        trainer.train()

    trainer.maybe_save_ema_model(training_args.output_dir)

    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
