from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from dataset.prompt_dataset import PromptAndImageDataset
from models.utils.logger import get_logger
from models.components.refiner.llmrefiner import MistralRefiner, MistralRefinerwithLM, MistralRefinerwithNLP, MistralRefinerwithMLLM, MistralRefinerwithClassname
import os
import wandb
import lightning as L
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.strategies import DDPStrategy
from huggingface_hub import login

# Authenticate with the Hugging Face Hub (presumably needed to download the
# gated Llama weights referenced in ``main``). Prefer the standard ``HF_TOKEN``
# environment variable over committing a token to source; the placeholder is
# kept as the fallback so behavior is unchanged when the variable is unset.
login(token=os.environ.get("HF_TOKEN", "YOUR_HUGGINGFACE_TOKEN"))


logger = get_logger(__name__)
# Fixed global seed so training runs are reproducible.
L.seed_everything(995)

# Same pattern for Weights & Biases: read the standard ``WANDB_API_KEY``
# environment variable, falling back to the in-source placeholder.
wandb.login(key=os.environ.get("WANDB_API_KEY", "YOUR_WANDB_API_KEY"))

# Directory two levels above this file (presumably the repository root);
# checkpoints are written beneath it. abspath() makes the result independent
# of how the script was launched — a relative ``__file__`` would otherwise
# produce a cwd-relative (or even empty) path.
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

def _build_model(model_type: str, model_name: str):
    """Instantiate the refiner class selected by *model_type*.

    Raises:
        ValueError: If ``model_type`` is not a supported variant.
    """
    if model_type == "llm":
        return MistralRefiner(model_name=model_name)
    if model_type == "lm":
        return MistralRefinerwithLM(model_name=model_name)
    if model_type == "nlp":
        return MistralRefinerwithNLP(model_name=model_name)
    if model_type == "mllm":
        return MistralRefinerwithMLLM(model_name=model_name)
    if model_type == "custom":
        return MistralRefinerwithClassname(model_name=model_name)
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected one of "
        "'llm', 'lm', 'nlp', 'mllm', 'custom'"
    )


def main(
        gpus: int,
        nodes: int,
        dataset_path: str,
        name: str,
        model_type: str,
        version: str,
        *,
        model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
        batch_size: int = 2,
):
    """Train a refiner model with PyTorch Lightning, logging to W&B.

    Args:
        gpus: Devices per node, passed to ``Trainer(devices=...)``.
        nodes: Number of nodes, passed to ``Trainer(num_nodes=...)``.
        dataset_path: Path handed to ``PromptAndImageDataset``.
        name: Experiment name; used as the W&B run name and as the
            checkpoint sub-directory under ``PROJECT_DIR/checkpoints``.
        model_type: Refiner variant: ``"llm"``, ``"lm"``, ``"nlp"``,
            ``"mllm"`` or ``"custom"``.
        version: Version tag; checkpoints are written to
            ``PROJECT_DIR/checkpoints/<name>/version_<version>``.
        model_name: Model identifier forwarded to the refiner constructor.
        batch_size: Batch size of the training ``DataLoader``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported variants.
    """
    # Build the model first so an invalid ``model_type`` fails immediately,
    # before any dataset loading or logger/trainer setup happens.
    model = _build_model(model_type, model_name)

    train_dataset = PromptAndImageDataset(dataset_path)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    wandb_logger = WandbLogger(project="refiner", name=name)

    checkpoint_dir = os.path.join(PROJECT_DIR, "checkpoints", name, f"version_{version}")
    # exist_ok=True avoids a check-then-create race (this may run on every
    # rank under DDP) and still creates version_<v> when the experiment
    # directory already exists.
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        monitor="train_loss",
        dirpath=checkpoint_dir,
        filename="refiner-{epoch:02d}-{train_loss:.2f}",
        mode="min",
    )

    # Any multi-process run (several GPUs and/or several nodes) uses DDP with
    # find_unused_parameters; with strategy="auto", Lightning would pick plain
    # DDP for a multi-node run and skip that setting.
    if gpus > 1 or nodes > 1:
        strategy = DDPStrategy(find_unused_parameters=True)
    else:
        strategy = "auto"

    trainer = Trainer(
        devices=gpus,
        num_nodes=nodes,
        accelerator="gpu",
        strategy=strategy,
        precision="bf16-mixed",
        max_epochs=3,
        accumulate_grad_batches=16,
        log_every_n_steps=10,
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
    )

    trainer.fit(model, train_dataloaders=train_loader)
    logger.info(f"{checkpoint_callback.best_model_path} is the best model path")

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Train Mistral LoRA model")
    parser.add_argument("--gpus", type=int, default=1, help="Number of GPUs to use")
    parser.add_argument("--nodes", type=int, default=1, help="Number of nodes to use")
    # These three are required: ``main`` cannot run without them (a missing
    # --name, for instance, would crash inside os.path.join with a TypeError).
    parser.add_argument("--dataset_path", type=str, required=True, help="Path to the dataset")
    parser.add_argument("--name", type=str, required=True, help="Name of the experiment")
    # Restrict to the variants ``main`` knows how to build so typos are
    # rejected at parse time rather than after setup work has started.
    parser.add_argument(
        "--model_type",
        type=str,
        required=True,
        choices=["llm", "lm", "nlp", "mllm", "custom"],
        help="Type of the model",
    )
    # Optional; when omitted the checkpoint directory becomes "version_None".
    parser.add_argument("--version", type=str, help="Version of the model")
    args = parser.parse_args()

    main(
        gpus=args.gpus,
        nodes=args.nodes,
        dataset_path=args.dataset_path,
        name=args.name,
        model_type=args.model_type,
        version=args.version,
    )
