import time
import torch
import os
import sys

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, HfArgumentParser
from configs.args import DatasetArgs, RunArgs, LoraArgs, CustomTrainingArguments
from components.preprocessor import DataLoader
from components.data_collator import CollatorConfig, dLLMDataCollator
from components.trainer import dLLMTrainer
from components.codila_model import CoDiLAModel
from components.callbacks.mbpp_evaluation_callback import MBPPEvalCallback

from utils import init_seed


# Model loading
def load_model(run_args: RunArgs):
    """
    Load the diffusion model and the autoregressive (AR) model.

    Both checkpoints are loaded in bfloat16 with ``trust_remote_code=True``.
    Every parameter of the diffusion model is frozen and gradient
    checkpointing is enabled on it; the AR model keeps the ``requires_grad``
    flags it was loaded with.

    Args:
        run_args: Run configuration; only ``diffusion_model_path`` and
            ``ar_model_path`` are read here.

    Returns:
        A ``(diffusion_model, ar_model)`` tuple.
    """
    # Options common to both checkpoints.
    shared_kwargs = {"trust_remote_code": True, "dtype": torch.bfloat16}

    diffusion_model = AutoModel.from_pretrained(
        run_args.diffusion_model_path, **shared_kwargs
    )
    # Freeze the diffusion model: none of its weights are updated.
    for param in diffusion_model.parameters():
        param.requires_grad = False
    # NOTE(review): presumably enabled to save activation memory while
    # gradients still flow through the frozen model - confirm.
    diffusion_model.gradient_checkpointing_enable()

    ar_model = AutoModelForCausalLM.from_pretrained(
        run_args.ar_model_path, **shared_kwargs
    )
    # Make the input embedding outputs require grad (HF helper).
    ar_model.enable_input_require_grads()

    # Dump both architectures to stdout for the job log.
    for loaded in (diffusion_model, ar_model):
        print(loaded, flush=True)

    return diffusion_model, ar_model

# Training setup
def train_model(run_args, training_args, dataset_args, tokenizer, diffusion_model, ar_model, train_dataset, eval_dataset):
    """
    Build the collator and trainer, report trainable parameters, then train or evaluate.

    Args:
        run_args: Run configuration; ``no_adapter`` selects the evaluation-only path.
        training_args: Training arguments passed to the trainer; ``softmasking_prob``,
            ``loss_calc`` and ``checkpoint_timestamp`` are also read here.
        dataset_args: Dataset configuration; supplies ``min_prob`` / ``max_prob`` for
            the collator and is handed to the MBPP evaluation callback.
        tokenizer: Tokenizer shared by the collator, the model wrapper and the callback.
        diffusion_model: Diffusion model wrapped by ``CoDiLAModel``.
        ar_model: Autoregressive model wrapped by ``CoDiLAModel``.
        train_dataset: Dataset used for training.
        eval_dataset: Dataset used for evaluation.
    """
    collator_cfg = CollatorConfig(
        mask_token_id=tokenizer.mask_token_id,
        softmasking_prob=training_args.softmasking_prob,
        min_prob=dataset_args.min_prob,
        max_prob=dataset_args.max_prob,
    )
    data_collator = dLLMDataCollator(tokenizer=tokenizer, cfg=collator_cfg)

    combined_model = CoDiLAModel(diffusion_model, ar_model, tokenizer)
    trainer = dLLMTrainer(
        model=combined_model,
        args=training_args,
        loss_calc=training_args.loss_calc,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        callbacks=[MBPPEvalCallback(tokenizer, dataset_args)],
    )

    # Gather trainable parameters only after the trainer (and model wrapper)
    # exist, diffusion model first, then the AR model.
    trainable = [
        (f"{prefix}.{param_name}", param)
        for prefix, module in (("diffusion_model", diffusion_model), ("ar_model", ar_model))
        for param_name, param in module.named_parameters()
        if param.requires_grad
    ]

    # Per-parameter summary followed by the grand total.
    print("\n=== Trainable Parameters ===", flush=True)
    total = 0
    for label, param in trainable:
        count = param.numel()
        print(f"{label:70s}  shape={tuple(param.shape)}  params={count:,}")
        total += count
    print(f"\nTOTAL TRAINABLE PARAMS: {total:,}\n", flush=True)

    # Evaluation-only mode: run the evaluation loop 100 times, no training.
    if run_args.no_adapter:
        for _ in range(100):
            trainer.evaluate()
        return

    # A configured checkpoint timestamp means an earlier run is being resumed.
    if training_args.checkpoint_timestamp is None:
        trainer.train()
    else:
        print("Resuming from checkpoint")
        trainer.train(resume_from_checkpoint=True)
def main():
    """
    Parse the JSON configs named on the command line, load data and models, and train.

    Command line:
        argv[1]: JSON file parsed into ``LoraArgs`` and ``CustomTrainingArguments``.
        argv[2]: JSON file parsed into ``RunArgs`` and ``DatasetArgs``.
        argv[3]: optional notes string used in the output path and the run name.

    Raises:
        SystemExit: with a usage message when fewer than two config paths are given.
    """
    # Both config paths are mandatory; fail with a readable usage message
    # rather than an IndexError from indexing sys.argv.
    if len(sys.argv) < 3:
        prog = sys.argv[0] if sys.argv else "train"
        raise SystemExit(
            f"usage: {prog} <training_config.json> <run_config.json> [notes]"
        )

    # Parse the run/dataset specific config file (second positional argument)
    parser_run_and_dataset = HfArgumentParser((RunArgs, DatasetArgs))
    run_args, dataset_args = parser_run_and_dataset.parse_json_file(os.path.abspath(sys.argv[2]))

    # Parse the common training arguments (first positional argument).
    # LoraArgs is still parsed so the shared config file is consumed by the
    # parser, but its values are not used in this function.
    parser_training = HfArgumentParser((LoraArgs, CustomTrainingArguments))
    _lora_args, training_args = parser_training.parse_json_file(os.path.abspath(sys.argv[1]))

    notes = "" if len(sys.argv) < 4 else sys.argv[3]

    # Per-job output subdir: reuse the configured checkpoint timestamp so the
    # path matches the earlier run's directory, otherwise stamp the current
    # local time.
    if training_args.checkpoint_timestamp is not None:
        timestamp = training_args.checkpoint_timestamp
    else:
        timestamp = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())

    # With empty notes the joined path ends in a separator and the run name
    # contains two consecutive "___" separators.
    training_args.output_dir = os.path.join(training_args.output_dir, run_args.job_name, timestamp, notes)
    training_args.run_name = "___".join([run_args.job_name, notes, timestamp])

    # Seeding: the same seed drives both global RNG setup and data handling
    init_seed(training_args.seed)
    training_args.data_seed = training_args.seed

    # The tokenizer comes from the diffusion model checkpoint, padding on the right
    tokenizer = AutoTokenizer.from_pretrained(
        run_args.diffusion_model_path, padding_side="right", trust_remote_code=True, use_fast=True
    )

    # Load dataset (project DataLoader from components.preprocessor)
    loader = DataLoader(dataset_args, tokenizer, seed=training_args.data_seed)
    train_dataset, eval_dataset = loader.load_data()

    # Load the diffusion and autoregressive models
    diffusion_model, ar_model = load_model(run_args)

    # Train the model
    print("Starting training...", flush=True)
    train_model(run_args, training_args, dataset_args, tokenizer, diffusion_model, ar_model, train_dataset, eval_dataset)

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
