from unsloth import FastVisionModel, is_bf16_supported

0

import gc
from functools import partial
from pathlib import Path

import hydra
import numpy as np
import torch
from datasets import Dataset
from loguru import logger
from omegaconf import DictConfig
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint

from data.fsft_dataset import get_fsft_test_dataset, get_fsft_train_dataset
from mllm.vllm_client import VLLMClient
from models.init_class_embed import get_class_similarity, init_class_embed
from models.reward_funcs import (contrastive_reasoning_reward_func,
                                 correctness_reward_func,
                                 soft_format_reward_func,
                                 strict_format_reward_func)
from models.trainer.grpo_trainer import FSGRPOTrainer, GRPOConfig


class SaveClassEmbedCallback(TrainerCallback):
    """Persist ``trainer.class_embed`` into every checkpoint the trainer saves.

    The tensor is written as ``<checkpoint>/class_embed.pt`` so that ``train``
    can restore it when resuming from that checkpoint.
    """

    def __init__(self, trainer: FSGRPOTrainer):
        # Keep a handle on the trainer rather than the tensor itself:
        # ``class_embed`` is read at save time so the latest value is written.
        self.trainer = trainer

    def on_save(self, args, state, control, **kwargs):
        """Save ``class_embed`` next to the checkpoint that was just written.

        Args:
            args: The trainer's training arguments (``output_dir``,
                ``should_save`` are read).
            state: The trainer state (``global_step`` is read).
            control: Unused; part of the ``TrainerCallback`` interface.
        """
        from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

        # In multi-process runs only the process(es) that write checkpoints
        # should write this file, otherwise every rank races on the same path.
        if not args.should_save:
            return

        class_embed = self.trainer.class_embed
        if class_embed is None:
            # Nothing to persist; writing ``None`` would break the reload in
            # ``train`` (which calls ``.to(...)`` on the loaded object).
            return

        output_dir = Path(args.output_dir)
        # The Trainer names checkpoints "checkpoint-<global_step>". Target that
        # exact folder instead of "newest folder on disk", which can be a stale
        # higher-numbered checkpoint left over from an earlier run.
        checkpoint_dir = output_dir / f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        if not checkpoint_dir.is_dir():
            last_checkpoint = (
                get_last_checkpoint(str(output_dir)) if output_dir.is_dir() else None
            )
            if last_checkpoint is None:
                logger.warning(
                    f"No checkpoint folder found in {output_dir}; class_embed not saved"
                )
                return
            checkpoint_dir = Path(last_checkpoint)

        torch.save(class_embed, checkpoint_dir / "class_embed.pt")

def _resolve_resume_checkpoint(resume_from_checkpoint, output_dir):
    """Return the checkpoint folder training resumes from, or ``None``.

    Mirrors the two forms ``Trainer.train(resume_from_checkpoint=...)`` accepts:
    an explicit checkpoint path (``str``), or a truthy flag meaning "the latest
    ``checkpoint-*`` folder under ``output_dir``".

    Args:
        resume_from_checkpoint: Value of ``cfg.resume_from_checkpoint``.
        output_dir: Training output directory as a ``Path``.

    Returns:
        A ``Path`` to the checkpoint folder, or ``None`` when not resuming or
        when no checkpoint exists yet.
    """
    if not resume_from_checkpoint:
        return None
    if isinstance(resume_from_checkpoint, str):
        return Path(resume_from_checkpoint)
    # get_last_checkpoint lists the directory, so it has to exist first.
    if not output_dir.is_dir():
        return None
    last_checkpoint = get_last_checkpoint(str(output_dir))
    # get_last_checkpoint returns a str (or None); wrap it so `/` works.
    return Path(last_checkpoint) if last_checkpoint else None


def _class_selection_probabilities(class_similarity, temperature):
    """Convert a class-similarity matrix into row-wise sampling probabilities.

    Applies a temperature softmax along axis 1, so each row sums to 1.

    Args:
        class_similarity: 2-D array-like of similarity scores.
        temperature: Softmax temperature; smaller values sharpen each row.

    Returns:
        A NumPy array with the same shape as ``class_similarity``.
    """
    scores = np.asarray(class_similarity)
    # Subtracting the row maximum keeps exp() from overflowing; softmax is
    # invariant to this per-row shift.
    scores = (scores - np.max(scores, axis=1, keepdims=True)) / temperature
    weights = np.exp(scores)
    return weights / np.sum(weights, axis=1, keepdims=True)


def train(cfg: DictConfig):
    """Fine-tune a vision-language model with few-shot GRPO.

    Loads the base model in 8-bit, builds the few-shot train split, restores or
    initializes ``class_embed`` (resumed checkpoint > on-disk cache > fresh
    init), attaches LoRA adapters, trains with ``FSGRPOTrainer`` against a vLLM
    server, and finally saves the model, processor and ``class_embed`` to
    ``outputs/<cfg.exp_name>``.

    Args:
        cfg: Hydra config for the run.
    """
    output_dir = Path("outputs") / cfg.exp_name

    model, processor = FastVisionModel.from_pretrained(
        model_name = cfg.model.base_model_name,
        max_seq_length = cfg.model.max_seq_length,
        load_in_4bit = False,
        load_in_8bit = True,
        full_finetuning = False,
    )

    dataset_base = hydra.utils.instantiate(cfg.dataset)
    classnames, train_dataset = get_fsft_train_dataset(dataset_base)
    # NOTE(review): test_dataset is not used below (no eval dataset is passed
    # to the trainer); the call is kept as in the original script.
    _, test_dataset = get_fsft_test_dataset(dataset_base, limit=16)

    # Init class_embed
    class_embed_cache_path = (
        Path("outputs")
        / "cache"
        / f"{dataset_base.__class__.__name__}_{cfg.num_shots}shots_class_embed_sigavg.pt"
    )
    class_embed_init_fn = init_class_embed

    class_embed = None
    # Restore class_embed from the same checkpoint the trainer will resume
    # from, so weights and embeddings stay in sync.
    checkpoint_folder = _resolve_resume_checkpoint(cfg.resume_from_checkpoint, output_dir)
    if checkpoint_folder is not None and (checkpoint_folder / "class_embed.pt").exists():
        with open(checkpoint_folder / "class_embed.pt", "rb") as f:
            class_embed = torch.load(f).to(torch.bfloat16)
        logger.info(f"Loaded class_embed from {checkpoint_folder}")

    if class_embed_init_fn is not None and class_embed is None:
        if cfg.use_class_embed_cache and class_embed_cache_path.exists():
            with open(class_embed_cache_path, "rb") as f:
                class_embed = torch.load(f)
            logger.info(f"Loaded cached class_embed from {class_embed_cache_path}")
        else:
            with torch.inference_mode():
                class_embed = class_embed_init_fn(model, processor.image_processor, train_dataset, classnames)
            class_embed_cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(class_embed_cache_path, "wb") as f:
                torch.save(class_embed, f)
        class_embed = class_embed.to(torch.bfloat16)

    # Reference class selection strategy
    class_similarity = get_class_similarity(dataset_base)
    if cfg.collator_category_selection_temperature is not None:
        class_similarity = _class_selection_probabilities(
            class_similarity, cfg.collator_category_selection_temperature
        )

    # PEFT
    model = FastVisionModel.get_peft_model(
        model,
        **cfg.model.lora_params,
    )

    FastVisionModel.for_training(model)

    # Clear deleted GPU items
    for _ in range(3):
        gc.collect()
        torch.cuda.empty_cache()

    model = model.to(dtype=torch.bfloat16)

    reward_funcs = [strict_format_reward_func, soft_format_reward_func, correctness_reward_func]
    if cfg.contrastive_reward:
        vllm_judge_client = VLLMClient(cfg.vllm_judge_server_host, cfg.vllm_judge_server_port)
        contrastive_reasoning_reward_func_partial = partial(contrastive_reasoning_reward_func, processor=processor, vllm_client=vllm_judge_client)
        # functools.partial objects have no __name__; set one explicitly so
        # the reward function has a readable name.
        contrastive_reasoning_reward_func_partial.__name__ = "contrastive_reasoning_reward_func"
        reward_funcs.append(contrastive_reasoning_reward_func_partial)

    trainer = FSGRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        top_k=cfg.top_k,
        classnames=classnames,
        class_embed=class_embed,
        class_similarity=class_similarity,
        args=GRPOConfig(
            learning_rate = cfg.learning_rate,
            adam_beta1 = 0.9,
            adam_beta2 = 0.99,
            weight_decay = 0.1,
            warmup_ratio = 0.01,
            lr_scheduler_type = "cosine",
            optim = "adamw_8bit",
            fp16 = not is_bf16_supported(),
            bf16 = is_bf16_supported(),
            class_embed_trainable=cfg.class_embed_trainable,
            use_vllm=True,
            vllm_mode="server",
            vllm_server_host=cfg.vllm_server_host,
            vllm_server_port=cfg.vllm_server_port,
            vllm_server_group_port=cfg.vllm_server_group_port,
            logging_steps = 1,
            per_device_train_batch_size = cfg.per_device_train_batch_size,
            gradient_accumulation_steps = 1,
            num_generations = cfg.per_device_train_batch_size,
            max_prompt_length = 3100,
            max_completion_length = 512,
            num_train_epochs = cfg.epochs,
            save_steps = 500,
            seed = 3407,
            max_grad_norm = 0.1,
            report_to = "wandb" if cfg.wandb else "none", # Can use Weights & Biases
            output_dir=str(output_dir),
            save_total_limit=2,
        ),
        train_dataset=Dataset.from_list(train_dataset),
    )
    trainer.add_callback(SaveClassEmbedCallback(trainer))

    trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)

    model.save_pretrained(str(output_dir))
    processor.save_pretrained(str(output_dir))
    if trainer.class_embed is not None:
        torch.save(trainer.class_embed, output_dir / "class_embed.pt")

@hydra.main(
    version_base=None,
    config_path="configs",
    config_name="train_reasoning",
)
def main(cfg: DictConfig):
    """Hydra entry point.

    Hydra builds ``cfg`` from the ``train_reasoning`` config under ``configs``
    and calls this function, which hands the config to :func:`train`.
    """
    train(cfg=cfg)


if __name__ == "__main__":
    # The decorator supplies ``cfg``; no arguments are passed here.
    main()
