import wandb
import torch
import argparse
from transformers import (
    GPT2Config, 
    GPT2LMHeadModel, 
    Trainer, 
    TrainingArguments, 
    PreTrainedTokenizerFast,
)
from tokenizer import CustomTokenizer
import os

from data_utils import load_data, CustomDataCollator
from evaluate import EvaluationConfig, EvaluationCallback


# Named architecture presets, selected with ``--model_name``.
# Every preset carries the same four keys:
#   n_layer    - number of transformer layers
#   n_head     - number of attention heads
#   d_model    - embedding width
#   d_ff_ratio - multiplier applied to d_model to obtain the FFN hidden width
MODEL_CONFIGS = {
    # Standard GPT-2 family.
    "gpt2": {"n_layer": 12, "n_head": 12, "d_model": 768, "d_ff_ratio": 4},  # 124M params
    "gpt2-medium": {"n_layer": 24, "n_head": 16, "d_model": 1024, "d_ff_ratio": 4},  # 350M params
    "gpt2-large": {"n_layer": 36, "n_head": 20, "d_model": 1280, "d_ff_ratio": 4},  # 774M params
    "gpt2-xl": {"n_layer": 48, "n_head": 25, "d_model": 1600, "d_ff_ratio": 4},  # 1558M params
    # Smaller presets.
    "gopher-44m": {"n_layer": 8, "n_head": 16, "d_model": 512, "d_ff_ratio": 4},
    "gpt-mini": {"n_layer": 6, "n_head": 6, "d_model": 192, "d_ff_ratio": 4},
    "gpt-micro": {"n_layer": 4, "n_head": 4, "d_model": 128, "d_ff_ratio": 4},
    "gpt-nano": {"n_layer": 3, "n_head": 3, "d_model": 48, "d_ff_ratio": 4},
}


def create_custom_gpt(
    vocab_size: int,
    n_layer: int,
    n_head: int,
    d_model: int,
    d_ff: int,
    n_positions: int,
    tokenizer: PreTrainedTokenizerFast
):
    """
    Build a GPT-2 language model from explicit hyperparameters.

    The model is constructed directly from a ``GPT2Config``; no pretrained
    weights are loaded here.

    Args:
        vocab_size (int): Size of the token vocabulary.
        n_layer (int): Number of transformer layers.
        n_head (int): Number of attention heads.
        d_model (int): Embedding width (passed to the config as ``n_embd``).
        d_ff (int): Hidden width of the feed-forward sub-layer (passed as ``n_inner``).
            This is the absolute width, not a ratio relative to ``d_model``.
        n_positions (int): Maximum sequence length (number of positional encodings).
        tokenizer (PreTrainedTokenizerFast): Tokenizer whose BOS/EOS/PAD token ids
            are copied into the model configuration.

    Returns:
        GPT2LMHeadModel: A GPT-2 LM-head model with the specified configuration.
    """
    # Special-token ids are taken from the tokenizer so the model agrees with it.
    special_token_ids = {
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    # Map this script's naming onto the GPT2Config field names.
    architecture = {
        "vocab_size": vocab_size,
        "n_embd": d_model,
        "n_layer": n_layer,
        "n_head": n_head,
        "n_inner": d_ff,
        "n_positions": n_positions,
    }
    gpt_config = GPT2Config(**architecture, **special_token_ids)
    return GPT2LMHeadModel(gpt_config)


def _parse_args():
    """Define and parse the command-line options for a training run."""
    parser = argparse.ArgumentParser(description="Train a custom GPT model.")
    parser.add_argument("--train_data_dir", type=str, default="data/plain", help="Folder containing the training data.")
    parser.add_argument("--test_data_dir", type=str, default="data/test", help="Folder containing the test data.")
    parser.add_argument("--logging_dir", type=str, default="../logs", help="Folder to save logs and model checkpoints.")
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for training.")
    # Restricting to the known presets turns a late KeyError into a clear usage error.
    parser.add_argument("--model_name", type=str, default=None, choices=sorted(MODEL_CONFIGS),
                        help="Name of the model architecture. Will be used to determine the model configuration if specified.")
    parser.add_argument("--n_layer", type=int, default=6, help="Number of layers in the model.")
    parser.add_argument("--n_head", type=int, default=6, help="Number of attention heads in the model.")
    parser.add_argument("--d_model", type=int, default=192, help="Dimension of the model embeddings.")
    parser.add_argument("--d_ff_ratio", type=int, default=4, help="Ratio between the FFN hidden dim and model embedding dim.")
    parser.add_argument("--max_seq_len", type=int, default=376, help="Max sequence length for the model.")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training.")
    parser.add_argument("--num_epochs", type=int, default=5, help="Number of training epochs.")
    parser.add_argument("--logging_steps", type=int, default=25, help="Steps interval for logging.")
    parser.add_argument("--eval_steps", type=int, default=100, help="Steps interval for evaluation.")
    return parser.parse_args()


def _resolve_architecture(args):
    """Return ``(n_layer, n_head, d_model, d_ff)`` from the named preset or the explicit flags."""
    if args.model_name is not None:
        # A named preset overrides --n_layer / --n_head / --d_model / --d_ff_ratio.
        preset = MODEL_CONFIGS[args.model_name]
        n_layer = preset["n_layer"]
        n_head = preset["n_head"]
        d_model = preset["d_model"]
        d_ff_ratio = preset["d_ff_ratio"]
    else:
        n_layer = args.n_layer
        n_head = args.n_head
        d_model = args.d_model
        d_ff_ratio = args.d_ff_ratio
    return n_layer, n_head, d_model, d_ff_ratio * d_model


def main():
    """
    Train a GPT-2 style model from command-line options.

    Steps: parse arguments, resolve the architecture (preset or explicit
    flags), start a wandb run on the rank-0 process, load the training data
    and tokenizer, build the model, attach one ``EvaluationCallback`` per
    feedback mode, and run ``Trainer.train()``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """
    args = _parse_args()

    n_layer, n_head, d_model, d_ff = _resolve_architecture(args)
    # Fail before any run/log directory is created; the model constructor would
    # reject this combination anyway.
    if d_model % n_head != 0:
        raise ValueError(
            f"d_model ({d_model}) must be divisible by n_head ({n_head})."
        )

    # Name the run after the architecture, learning rate and training-data folder.
    # Without a preset name, encode the dimensions so that different custom
    # architectures do not share (and overwrite) one output directory.
    if args.model_name is not None:
        model_tag = args.model_name
    else:
        model_tag = f"custom_l{n_layer}_h{n_head}_d{d_model}_ff{d_ff}"
    # normpath drops a trailing separator so "data/plain/" still yields "plain".
    data_tag = os.path.basename(os.path.normpath(args.train_data_dir))
    setting_name = f"{model_tag}_lr{args.learning_rate}_{data_tag}"

    # Initialize wandb if in the main process
    is_main_process = int(os.environ.get("RANK", 0)) == 0
    if is_main_process:
        # Keep wandb files under the user-selected logging dir
        # (the default resolves to "../logs/wandb").
        wandb_dir = os.path.join(args.logging_dir, "wandb")
        os.makedirs(wandb_dir, exist_ok=True)
        wandb.init(
            project="llm_test_time",
            name=setting_name,
            dir=wandb_dir
        )

    # Load data and tokenizer
    dataset = load_data(os.path.join(args.train_data_dir, "data.json"))
    custom_tokenizer = CustomTokenizer(vocab_path="data/custom_vocab.json")
    tokenizer = custom_tokenizer.get_tokenizer()

    # Create model with custom configuration
    vocab_size = len(tokenizer.get_vocab())
    model = create_custom_gpt(
        vocab_size=vocab_size,
        n_layer=n_layer,
        n_head=n_head,
        d_model=d_model,
        d_ff=d_ff,
        n_positions=args.max_seq_len,
        tokenizer=tokenizer
    )

    # Checkpoints and logs for this run share one directory.
    logging_dir = os.path.join(args.logging_dir, setting_name)
    os.makedirs(logging_dir, exist_ok=True)

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=logging_dir,
        overwrite_output_dir=True,
        per_device_train_batch_size=args.batch_size,
        num_train_epochs=args.num_epochs,
        save_strategy="epoch",
        logging_dir=logging_dir,
        logging_steps=args.logging_steps,
        save_total_limit=2,
        # Built-in evaluation is off; evaluation is driven by the callbacks below.
        eval_strategy="no",
        learning_rate=args.learning_rate,
        weight_decay=0,
        fp16=True,
        # The string "none" disables reporting; passing None would fall back to
        # the library default instead of turning reporting off.
        report_to="wandb" if is_main_process else "none",
        remove_unused_columns=False,
        warmup_ratio=0.05,
        lr_scheduler_type="linear",
        ddp_find_unused_parameters=False
    )

    # Initialize data collator
    data_collator = CustomDataCollator(tokenizer=tokenizer, max_length=args.max_seq_len)

    # Initialize evaluation callbacks: one per feedback mode, all reading the
    # same three test files.
    test_file_paths = [
        os.path.join(args.test_data_dir, f"{_name}.json")
        for _name in ["satisfiable", "unsatisfiable", "data"]
    ]
    callbacks = [
        EvaluationCallback(
            model=model,
            tokenizer=tokenizer,
            test_file_paths=test_file_paths,
            eval_steps=args.eval_steps,
            logging_file_dir=logging_dir,
            eval_config=EvaluationConfig(with_feedback=feedback),
        )
        for feedback in [True, False]
    ]

    # Create and train the model
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
        callbacks=callbacks,
    )
    trainer.train()

    # Close the wandb run opened above.
    if is_main_process:
        wandb.finish()


if __name__ == "__main__":
    main()
    # Tear down the process group if a distributed run created one.
    # is_available() is checked first because builds without distributed
    # support do not expose the rest of the torch.distributed API.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
