import os
import json
import torch
from transformers import LlavaNextForConditionalGeneration, AutoModelForCausalLM

def save_model_and_tokenizer(model_name_or_path, model, tokenizer, drop_layers_after, output_dir, trainer, val_datasets, lorra_config=None):
    """Merge LoRA weights, re-attach the base model's tail layers, and save a checkpoint.

    The checkpoint is written to
    ``<output_dir>/step_<trainer.state.global_step>_<lowercased model basename>``.

    Args:
        model_name_or_path: Hub id or local path of the base model. When
            ``drop_layers_after`` is set, it is reloaded from here to supply the
            layers after ``drop_layers_after`` and the config.
        model: LoRA-wrapped model exposing ``merge_and_unload()``.
        tokenizer: Tokenizer saved next to the model.
        drop_layers_after: Index of the last layer present in ``model``. Layers
            ``drop_layers_after + 1`` onward are copied from the base model.
            ``None`` skips reloading the base model and saves the merged model as-is.
        output_dir: Parent directory for the step-specific checkpoint folder.
        trainer: Object whose ``state.global_step`` names the checkpoint folder.
        val_datasets: Currently unused; post-save evaluation is disabled here.
        lorra_config: Optional configuration to record in ``lorra_config.json``.
            Objects with a ``to_dict()`` method are converted first. When omitted,
            a dict holding ``model_name_or_path`` and ``drop_layers_after`` is written.
    """
    current_step = trainer.state.global_step
    # normpath drops a trailing separator so "path/to/model/" still yields "model".
    model_basename = os.path.basename(os.path.normpath(model_name_or_path)).lower()
    output_dir = os.path.join(output_dir, f"step_{current_step}_{model_basename}")
    os.makedirs(output_dir, exist_ok=True)
    print(f"\n\nModel and tokenizer saving to {output_dir}\n\n")

    # Fold the LoRA adapters into the base weights.
    # NOTE(review): merge_and_unload mutates `model` in place (and the layer list is
    # extended below), so calling this mid-training changes the model the trainer
    # keeps using — confirm it is only called once training is finished.
    merged_model = model.merge_and_unload()

    if drop_layers_after is not None:
        # Append the base model's layers after `drop_layers_after` so the saved
        # model has the full layer stack.
        anchor_model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, device_map="auto", torch_dtype=merged_model.dtype
        )
        merged_model.model.layers = (
            merged_model.model.layers + anchor_model.model.layers[drop_layers_after + 1:]
        )
        # Presumably the merged model's config describes the truncated stack; the
        # base config matches the restored layer count.
        merged_model.config = anchor_model.config
        # The anchor's remaining modules are no longer needed; release them.
        del anchor_model
        torch.cuda.empty_cache()

    merged_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    if lorra_config is None:
        lorra_config = {
            "model_name_or_path": model_name_or_path,
            "drop_layers_after": drop_layers_after,
        }
    elif hasattr(lorra_config, "to_dict"):
        lorra_config = lorra_config.to_dict()

    lorra_config_path = os.path.join(output_dir, "lorra_config.json")
    with open(lorra_config_path, "w", encoding="utf-8") as file:
        json.dump(lorra_config, file, indent=2)

    torch.use_deterministic_algorithms(False)
    # Post-save evaluation is disabled for this variant, so val_datasets is unused.
    

def save_llava_model_and_tokenizer(model_name_or_path, model, processor, drop_layers_after, output_dir, trainer, val_datasets, lorra_config=None):
    """Merge LoRA weights into a LLaVA-NeXT model, restore its tail layers, save, and evaluate.

    Saves the full model and processor to ``output_dir``. It also saves the
    language model and tokenizer on their own to ``output_dir/language_model``.
    Then it writes ``lorra_config.json`` and runs
    ``trainer.evaluate(eval_dataset=val_datasets, sanity_check=True)``.

    Args:
        model_name_or_path: Hub id or local path of the base LLaVA-NeXT model.
            When ``drop_layers_after`` is set, it is reloaded from here to supply
            the language-model layers after ``drop_layers_after`` and the configs.
        model: LoRA-wrapped model exposing ``merge_and_unload()``.
        processor: Processor saved with the model; its ``tokenizer`` is also saved
            with the language model.
        drop_layers_after: Index of the last language-model layer present in
            ``model``. Layers ``drop_layers_after + 1`` onward are copied from the
            base model. ``None`` skips reloading the base model.
        output_dir: Directory the checkpoint is written to.
        trainer: Trainer whose ``evaluate`` is called after saving.
        val_datasets: Passed to ``trainer.evaluate`` as ``eval_dataset``.
        lorra_config: Optional configuration to record in ``lorra_config.json``.
            Objects with a ``to_dict()`` method are converted first. When omitted,
            a dict holding ``model_name_or_path`` and ``drop_layers_after`` is written.
    """
    os.makedirs(output_dir, exist_ok=True)
    print(f"Model and processor saving to {output_dir}")

    # Fold the LoRA adapters into the base weights.
    # NOTE(review): merge_and_unload mutates `model` in place, so the evaluation
    # below (and any further training) sees the merged, layer-extended model.
    merged_model = model.merge_and_unload()

    anchor_text_config = None
    if drop_layers_after is not None:
        # Append the base language-model layers after `drop_layers_after` so the
        # saved model has the full layer stack.
        # NOTE(review): assumes the `language_model.model.layers` layout used by the
        # installed transformers version — confirm when upgrading.
        anchor_model = LlavaNextForConditionalGeneration.from_pretrained(
            model_name_or_path, device_map="auto", torch_dtype=merged_model.dtype
        )
        merged_model.language_model.model.layers = (
            merged_model.language_model.model.layers
            + anchor_model.language_model.model.layers[drop_layers_after + 1:]
        )
        merged_model.config = anchor_model.config
        # Keep only the text config; it is applied after the full-model save below.
        anchor_text_config = anchor_model.language_model.config
        # Release the anchor's unused modules before saving and evaluation.
        del anchor_model
        torch.cuda.empty_cache()

    merged_model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    # Also export the language model with its tokenizer as a separate checkpoint.
    language_model_output_dir = os.path.join(output_dir, "language_model")
    print(f"Language Model and tokenizer saving to {language_model_output_dir}")
    if anchor_text_config is not None:
        merged_model.language_model.config = anchor_text_config
    merged_model.language_model.save_pretrained(language_model_output_dir)
    processor.tokenizer.save_pretrained(language_model_output_dir)

    if lorra_config is None:
        lorra_config = {
            "model_name_or_path": model_name_or_path,
            "drop_layers_after": drop_layers_after,
        }
    elif hasattr(lorra_config, "to_dict"):
        lorra_config = lorra_config.to_dict()

    lorra_config_path = os.path.join(output_dir, "lorra_config.json")
    with open(lorra_config_path, "w", encoding="utf-8") as file:
        json.dump(lorra_config, file, indent=2)

    torch.use_deterministic_algorithms(False)
    trainer.evaluate(eval_dataset=val_datasets, sanity_check=True)
    