import logging
import os
import sys

from transformers import AutoTokenizer, HfArgumentParser, set_seed
from transformers.trainer_utils import get_last_checkpoint

# REVISED: Updated package name from 'llm_retriever' to 'co_retriever'
from tevatron.co_retriever.arguments import ModelArguments, DataArguments, \
    TevatronTrainingArguments as TrainingArguments
from tevatron.co_retriever.dataset import JointTrainDataset as TrainDataset
from tevatron.co_retriever.collator import JointTrainCollator as TrainCollator
from tevatron.co_retriever.modeling import DenseModel
from tevatron.co_retriever.trainer import TevatronTrainer as Trainer
from tevatron.co_retriever.gc_trainer import GradCacheTrainer as GCTrainer

from torch.distributed import is_initialized, get_rank

# Module-level logger, named after this module so output can be filtered per file.
logger = logging.getLogger(name=__name__)

import wandb

def _resolve_checkpoint(training_args):
    """Validate ``output_dir`` and pick the checkpoint (if any) to resume from.

    Resolution order:

    1. An explicit ``--resume_from_checkpoint`` value is returned unchanged.
    2. When training into an existing directory, the newest ``checkpoint-*``
       entry reported by ``get_last_checkpoint`` is returned, whether or not
       ``--overwrite_output_dir`` was given.
    3. Otherwise ``None`` (train from scratch).

    Args:
        training_args: the parsed training arguments (reads ``output_dir``,
            ``do_train``, ``overwrite_output_dir`` and, if present,
            ``resume_from_checkpoint``).

    Returns:
        A checkpoint path to pass to ``Trainer.train``, or ``None``.

    Raises:
        ValueError: if ``do_train`` is set, ``output_dir`` is a non-empty
            directory that contains no checkpoint, and
            ``--overwrite_output_dir`` was not given.
    """
    output_dir = training_args.output_dir
    explicit_checkpoint = getattr(training_args, "resume_from_checkpoint", None)

    last_checkpoint = None
    if os.path.isdir(output_dir) and training_args.do_train:
        last_checkpoint = get_last_checkpoint(output_dir)
        # Only refuse when there is nothing to resume from; a directory that
        # already holds checkpoints is a resumable run, not a collision.
        if (
            last_checkpoint is None
            and os.listdir(output_dir)
            and not training_args.overwrite_output_dir
        ):
            raise ValueError(
                f"Output directory ({output_dir}) already exists and is not empty. "
                f"Use --overwrite_output_dir to overcome."
            )

    return explicit_checkpoint if explicit_checkpoint is not None else last_checkpoint


def _init_wandb(training_args):
    """Start (or resume) a wandb run, on the main process only.

    The project defaults to ``"my-awesome-project"`` and can be overridden with
    the ``WANDB_PROJECT`` environment variable. The run name is also used as the
    run id; with ``resume='allow'``, relaunching with the same ``--run_name``
    attaches to the existing wandb run instead of creating a new one.
    """
    # Non-zero ranks of a distributed job must not open their own runs.
    if is_initialized() and get_rank() != 0:
        return
    wandb.init(
        project=os.environ.get("WANDB_PROJECT", "my-awesome-project"),
        name=training_args.run_name,
        id=training_args.run_name,
        resume='allow',
    )


def _load_tokenizer(name_or_path, cache_dir):
    """Load a tokenizer with EOS as the fallback pad token and right-side padding."""
    tokenizer = AutoTokenizer.from_pretrained(name_or_path, cache_dir=cache_dir)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'right'
    return tokenizer


def main():
    """Parse arguments, build the joint model/data pipeline, train, and save.

    Arguments are parsed into ``(ModelArguments, DataArguments, TrainingArguments)``
    from a single ``.json`` file when that is the only CLI argument, otherwise
    from the command line.

    Raises:
        ValueError: from ``_resolve_checkpoint`` when ``output_dir`` is non-empty,
            holds no checkpoint, and ``--overwrite_output_dir`` was not given.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Validate output_dir (and find a resumable checkpoint) before any heavy work.
    last_checkpoint = _resolve_checkpoint(training_args)

    # Setup logging: full INFO output only on the main process.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        training_args.local_rank != -1,
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("MODEL parameters %s", model_args)
    if last_checkpoint is not None:
        logger.info("Resuming training from checkpoint %s", last_checkpoint)

    _init_wandb(training_args)

    set_seed(training_args.seed)

    # JointTrainCollator needs two tokenizers: one for the reference LLM and
    # one for the retriever (which falls back to the model's own tokenizer).
    reference_tokenizer = _load_tokenizer(
        model_args.reference_model_name_or_path,
        model_args.cache_dir,
    )
    retriever_tokenizer = _load_tokenizer(
        model_args.retriever_tokenizer_name or model_args.model_name_or_path,
        model_args.cache_dir,
    )

    # 'DenseModel' is the alias exported for the joint encoder model.
    model = DenseModel.build(
        model_args,
        training_args,
        data_args,
        cache_dir=model_args.cache_dir,
    )

    # Joint dataset mixing contrastive and Revela training data.
    train_dataset = TrainDataset(data_args=data_args)

    collator = TrainCollator(
        data_args=data_args,
        retriever_tokenizer=retriever_tokenizer,
        llm_tokenizer=reference_tokenizer,
    )

    trainer_cls = GCTrainer if training_args.grad_cache else Trainer
    trainer = trainer_cls(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=collator,
    )
    # The dataset holds a reference to the trainer for epoch-aware shuffling.
    train_dataset.trainer = trainer

    # NOTE: no UpdateIndexCallback here; it is not compatible with
    # JointTrainDataset (it was specific to the REVELA-only training loop).

    trainer.train(resume_from_checkpoint=last_checkpoint)

    trainer.save_model()
    if trainer.is_world_process_zero():
        retriever_tokenizer.save_pretrained(training_args.output_dir)
        # Optionally also persist the reference tokenizer, e.g. under
        # os.path.join(training_args.output_dir, 'reference_tokenizer').


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()