import logging
import random
import sys
import datasets
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers import AutoModelForSequenceClassification
import argparse
import gc
import os
import json
import shutil

import numpy as np
from alignment import (
    DataArguments,
    H4ArgumentParser,
    ModelArguments,
    get_checkpoint,
    get_datasets,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    get_tokenizer,
    is_adapter_model,
)
from alignment.data import maybe_insert_system_message, is_openai_format
from peft import PeftConfig, PeftModel
from simpo_trainer import SimPOTrainer
from simpo_config import SimPOConfig
from dataclasses import dataclass, field
from typing import Optional, Literal
import warnings
from tqdm import tqdm
# Silence every Python warning for the whole process (this also hides deprecation
# warnings raised by transformers / datasets / vllm).
warnings.filterwarnings("ignore")
logger = logging.getLogger(__name__)

# Jinja chat template installed on the tokenizer when the model path contains
# "mistral" (case-insensitive; see `main` and `apply_chat_template`). A leading system
# message is stripped and prepended (followed by a blank line) to the first turn; user
# turns render as "[INST] <content> [/INST]" and assistant turns as " <content> " + EOS.
MISTRAL_CHAT_TEMPLATE = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'].strip() + '\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% set content = system_message + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"

from vllm import LLM, SamplingParams



def merge_models(model_path, ref_model_path, torch_dtype, trust_remote_code, temp_path="outputs/temp_path", alpha=0.5, merged_name=None):
    """Interpolate the weights of two causal LMs and save the merged model.

    Both checkpoints are loaded with ``AutoModelForCausalLM.from_pretrained`` (no
    ``device_map`` is passed). For every entry of the ``model_path`` state dict:

    * if ``merged_name`` is given and is not a substring of the entry name, the value
      is copied unchanged from ``ref_model_path``;
    * non-floating-point entries (integer/bool buffers) are copied from
      ``ref_model_path`` -- interpolating them with float weights and copying the
      result back into integer storage would round/truncate them;
    * otherwise the value becomes ``alpha * model + (1 - alpha) * ref_model``, computed
      in float32 and cast back to the entry's dtype (one rounding step instead of
      several in bf16/fp16).

    The merged weights are loaded into the ``model_path`` model and written to
    ``temp_path`` with ``save_pretrained`` (no tokenizer files are saved).

    Args:
        model_path: Checkpoint whose weights receive weight ``alpha``.
        ref_model_path: Checkpoint whose weights receive weight ``1 - alpha`` and which
            supplies every entry not selected by ``merged_name``.
        torch_dtype: dtype forwarded to ``from_pretrained``.
        trust_remote_code: forwarded to ``from_pretrained``.
        temp_path: Output directory for the merged model.
        alpha: Interpolation weight of ``model_path``.
        merged_name: Optional substring selecting the entries to interpolate
            (e.g. "lm_head", "up_proj", "down_proj", "o_proj"); ``None`` merges all.

    Raises:
        KeyError: if ``ref_model_path`` lacks a state-dict entry of ``model_path``.
    """
    # Load both models.
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
    ref_model = AutoModelForCausalLM.from_pretrained(ref_model_path, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

    model_params = model.state_dict()
    ref_model_params = ref_model.state_dict()

    merged_params = {}
    for name, theta_model in model_params.items():
        theta_ref_model = ref_model_params[name]
        if merged_name is not None and merged_name not in name:
            # Entry not selected for merging: keep the reference weights.
            merged_params[name] = theta_ref_model
        elif not (torch.is_floating_point(theta_model) and torch.is_floating_point(theta_ref_model)):
            # Integer/bool buffers cannot be meaningfully interpolated.
            merged_params[name] = theta_ref_model
        else:
            merged = alpha * theta_model.float() + (1 - alpha) * theta_ref_model.float()
            merged_params[name] = merged.to(theta_model.dtype)

    # Write the merged weights into `model` and save it.
    model.load_state_dict(merged_params)
    model.save_pretrained(temp_path)
    del model, ref_model, merged_params, model_params, ref_model_params
    gc.collect()
    print(f"The merged model has been saved to {temp_path}")


def reward_annotate(output_data, reward_model_path="RLHFlow/ArmoRM-Llama3-8B-v0.1"):
    """Score rollouts with a reward model and turn them into preference pairs.

    ``output_data`` is a list of dicts with at least ``prompt`` and
    ``all_generated_responses``. Each record is updated in place with:

    * ``all_rm_scores`` -- one float reward per candidate response,
    * ``weight_chosen`` -- always ``1``,
    * ``chosen`` / ``rejected`` -- two-turn (user, assistant) conversations built
      from the highest- and lowest-scoring candidates.

    The reward model is loaded with ``device_map="cuda"`` in bfloat16, and its output
    is expected to expose a ``.score`` tensor.

    Returns:
        A ``datasets.Dataset`` built from the updated records.

    NOTE(review): ``np.argmax``/``np.argmin`` return the first index on ties, so when
    all scores are equal ``chosen`` and ``rejected`` hold the same response.
    """
    def _as_conversation(user_text, assistant_text):
        # Single-turn chat used both for reward scoring and for the preference pair.
        return [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": assistant_text},
        ]

    reward_model = AutoModelForSequenceClassification.from_pretrained(
        reward_model_path,
        device_map="cuda",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    rm_tokenizer = AutoTokenizer.from_pretrained(reward_model_path, use_fast=True)

    # Pass 1: reward every candidate response.
    for record in tqdm(output_data):
        rewards = []
        for response in record["all_generated_responses"]:
            encoded = rm_tokenizer.apply_chat_template(
                _as_conversation(record["prompt"], response), return_tensors="pt"
            ).to("cuda")
            with torch.no_grad():
                rewards.append(reward_model(encoded).score.float().item())
        record["all_rm_scores"] = rewards
        record["weight_chosen"] = 1

    # Pass 2: binarize -- the best-scoring response wins, the worst-scoring one loses.
    for record in output_data:
        responses = record["all_generated_responses"]
        best = np.argmax(record["all_rm_scores"])
        worst = np.argmin(record["all_rm_scores"])
        record.update({
            "chosen": _as_conversation(record["prompt"], responses[best]),
            "rejected": _as_conversation(record["prompt"], responses[worst]),
        })

    annotated = datasets.Dataset.from_list(output_data)
    del reward_model, rm_tokenizer
    return annotated

def rollout(model_path, dataset, n=5, max_tokens=4096, top_p=0.95, temperature=0.8,
            tensor_parallel_size=1, gpu_memory_utilization=0.7):
    """Sample ``n`` responses per prompt from ``model_path`` with vLLM.

    Args:
        model_path: Model directory or hub id loaded by ``vllm.LLM``.
        dataset: Column-indexable dataset providing ``prompt`` (user message text),
            ``ori_chosen`` and ``ori_rejected``.
        n: Number of samples per prompt.
        max_tokens: Maximum generated tokens; also used as vLLM ``max_model_len``.
        top_p: Nucleus-sampling threshold.
        temperature: Sampling temperature.
        tensor_parallel_size: vLLM tensor-parallel degree (default 1).
        gpu_memory_utilization: Fraction of GPU memory vLLM may reserve (default 0.7).

    Returns:
        A list of dicts with keys ``prompt``, ``all_generated_responses``,
        ``ori_chosen`` and ``ori_rejected``. Prompts whose ``n`` samples are all
        identical are dropped; with ``n == 1`` this drops every prompt.
    """
    prompts = dataset['prompt']
    chosens = dataset['ori_chosen']
    rejecteds = dataset['ori_rejected']

    model = LLM(
        model_path,
        trust_remote_code=True,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=gpu_memory_utilization,
        # NOTE(review): the context window equals max_tokens, so prompt + completion
        # share this budget; long prompts leave little room for generation.
        max_model_len=max_tokens,
    )
    tokenizer = model.get_tokenizer()
    sampling_params = SamplingParams(max_tokens=max_tokens,
                                     top_p=top_p,
                                     temperature=temperature,
                                     n=n)

    # Render each prompt as a single user turn followed by the assistant generation
    # prefix. Sampling settings belong to `generate`, not to the chat template.
    conversations = [
        tokenizer.apply_chat_template(
            [{'role': 'user', 'content': prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for prompt in prompts
    ]

    # vLLM returns one RequestOutput per input prompt, in input order.
    outputs = model.generate(conversations, use_tqdm=True,
                             sampling_params=sampling_params)

    output_data = []
    num_identical = 0
    for prompt, chosen, rejected, output in zip(prompts, chosens, rejecteds, outputs):
        gen_text = [completion.text for completion in output.outputs]
        if len(set(gen_text)) == 1:
            # All samples identical: no usable preference signal for this prompt.
            num_identical += 1
            continue
        output_data.append({
            'prompt': prompt,
            'all_generated_responses': gen_text,
            'ori_chosen': chosen,
            'ori_rejected': rejected,
        })
    print(f"Filtered out {num_identical} samples with identical generated responses")
    # NOTE(review): `del` only drops this reference; confirm vLLM actually releases
    # its GPU memory before the reward model is loaded.
    del model
    return output_data

def adjust_proportion(dataset, on_policy_data_proportion):
    """Mix on-policy and original preference pairs in ``dataset``.

    A uniformly random subset of ``int(len(dataset) * on_policy_data_proportion)``
    rows (drawn with the global ``random`` module, so it follows the process seed)
    keeps its ``chosen``/``rejected`` pair. In every other row, ``chosen``/``rejected``
    are replaced by ``ori_chosen``/``ori_rejected``.

    Args:
        dataset: Object supporting ``len()`` and ``map(fn, with_indices=True)``
            (e.g. ``datasets.Dataset``). Its rows contain ``chosen``, ``rejected``,
            ``ori_chosen`` and ``ori_rejected``.
        on_policy_data_proportion: Fraction in ``[0, 1]`` of rows that keep their
            on-policy pair.

    Returns:
        The result of ``dataset.map``.

    Raises:
        ValueError: if ``on_policy_data_proportion`` is outside ``[0, 1]`` or NaN.
    """
    if not 0.0 <= on_policy_data_proportion <= 1.0:
        raise ValueError(
            f"on_policy_data_proportion must be within [0, 1], got {on_policy_data_proportion!r}"
        )
    print("=====================================================================")
    print("on_policy_data_proportion ", on_policy_data_proportion)
    print("=====================================================================")

    num_rows = len(dataset)
    # Number of rows that keep their on-policy pair.
    num_on_policy = int(num_rows * on_policy_data_proportion)
    # Randomly choose those rows and mark them in a boolean mask.
    on_policy_indices = random.sample(range(num_rows), num_on_policy)
    keep_on_policy = np.zeros(num_rows, dtype=bool)
    keep_on_policy[on_policy_indices] = True

    def select_pair(example, idx):
        if keep_on_policy[idx]:
            return {'chosen': example['chosen'], 'rejected': example['rejected']}
        return {'chosen': example['ori_chosen'], 'rejected': example['ori_rejected']}

    return dataset.map(select_pair, with_indices=True)

def apply_chat_template(
    example,
    tokenizer,
    task: Literal["sft", "generation", "rm", "simpo"],
    auto_insert_empty_system_msg: bool = True,
    change_template = None,
):
    """Render one dataset example into chat-formatted text for ``task``.

    Args:
        example: Dataset row (dict). ``sft``/``generation`` read ``messages``;
            ``rm`` and ``simpo`` read ``chosen`` and ``rejected`` (``simpo`` also uses
            ``prompt`` when it is in OpenAI message format).
        tokenizer: Tokenizer whose ``apply_chat_template`` renders the messages.
        task: One of ``"sft"``, ``"generation"``, ``"rm"``, ``"simpo"``.
        auto_insert_empty_system_msg: If True, ``maybe_insert_system_message`` is
            applied to the relevant message list(s) before rendering.
        change_template: ``"mistral"`` replaces ``tokenizer.chat_template`` with
            ``MISTRAL_CHAT_TEMPLATE`` (this mutates the tokenizer).

    Returns:
        ``example`` with added fields: ``text`` (sft/generation),
        ``text_chosen``/``text_rejected`` (rm), or
        ``text_prompt``/``text_chosen``/``text_rejected`` (simpo).

    Raises:
        ValueError: if required keys are missing, simpo messages are not in OpenAI
            format, or ``task`` is unsupported.
    """
    if change_template == "mistral":
        tokenizer.chat_template = MISTRAL_CHAT_TEMPLATE

    def _strip_bos(text):
        # chosen/rejected are presumably concatenated after the prompt by the trainer,
        # so a leading BOS is removed to avoid duplicating it.
        bos = tokenizer.bos_token
        if bos is not None and text.startswith(bos):
            return text[len(bos):]
        return text

    if task in ["sft", "generation"]:
        messages = example["messages"]
        # We add an empty system message if there is none
        if auto_insert_empty_system_msg:
            maybe_insert_system_message(messages, tokenizer)
        example["text"] = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=(task == "generation"),
        )
    elif task == "rm":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            chosen_messages = example["chosen"]
            rejected_messages = example["rejected"]
            # We add an empty system message if there is none
            if auto_insert_empty_system_msg:
                maybe_insert_system_message(chosen_messages, tokenizer)
                maybe_insert_system_message(rejected_messages, tokenizer)

            example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
            example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
        else:
            raise ValueError(
                f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
    elif task == "simpo":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            if not is_openai_format(example["chosen"]) or not is_openai_format(example["rejected"]):
                raise ValueError(
                    f"Could not format example as dialogue for `{task}` task! Require OpenAI format for all messages"
                )

            # For DPO/ORPO/SimPO, the inputs are triples of (prompt, chosen, rejected), where
            # `chosen` and `rejected` are the final turn of a dialogue. Without an explicit
            # OpenAI-format prompt, the first N-1 turns of `chosen` form the prompt.
            if "prompt" in example and is_openai_format(example["prompt"]):
                prompt_messages = example["prompt"]
                chosen_messages = example["chosen"]
                rejected_messages = example["rejected"]
            else:
                prompt_messages = example["chosen"][:-1]
                # The final turn defines the chosen/rejected responses.
                chosen_messages = example["chosen"][-1:]
                rejected_messages = example["rejected"][-1:]

            # Prepend a system message if the first message is not a system message
            if auto_insert_empty_system_msg:
                maybe_insert_system_message(prompt_messages, tokenizer)

            example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
            example["text_chosen"] = _strip_bos(tokenizer.apply_chat_template(chosen_messages, tokenize=False))
            example["text_rejected"] = _strip_bos(tokenizer.apply_chat_template(rejected_messages, tokenize=False))
        else:
            raise ValueError(
                f"Could not format example as dialogue for `{task}` task! Require either the "
                f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}"
            )
    else:
        raise ValueError(
            f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'simpo']"
        )
    return example

def get_number(s_tuple):
    """Return the chunk index encoded in a replay-ratio run directory path.

    Used as a sort key for ``(path, sub_folders)`` tuples. ``path`` ends in
    ``..._{chunk}_rr_{mode}``, and the integer between the last ``_`` before the final
    ``_rr_`` and that marker is returned as a float. Multi-digit indices (>= 10)
    are parsed fully.

    Raises:
        ValueError: if the path has no ``_rr_`` marker or the index is not numeric.
    """
    path, _ = s_tuple
    # Use the last marker so a "_rr_" in a parent directory name cannot interfere.
    marker = path.rfind("_rr_")
    if marker == -1:
        raise ValueError(f"No '_rr_' marker found in {path!r}")
    return float(path[:marker].rsplit("_", 1)[-1])
def get_reset_checkpoint(path):
    """Return the path of the highest-numbered ``checkpoint-<step>`` sub-directory of ``path``.

    An entry counts as a candidate when it is a directory whose name starts with
    ``checkpoint-``. Its step is the ``-``-separated segment right after that prefix,
    and entries whose step is not an integer are ignored. Returns ``None`` when there
    is no candidate; errors from ``os.listdir`` propagate.
    """
    candidates = []
    for entry in os.listdir(path):
        if not entry.startswith('checkpoint-'):
            continue
        entry_path = os.path.join(path, entry)
        if not os.path.isdir(entry_path):
            continue
        try:
            step = int(entry.split('-')[1])
        except ValueError:
            continue
        candidates.append((step, entry_path))
    if not candidates:
        return None
    _, newest = max(candidates)
    return newest
def check_safetensors_files(src_folder):
    """Return True if ``src_folder`` contains a ``*.safetensors`` file at any depth.

    A missing path or a path that is not a directory yields False.
    """
    if not os.path.isdir(src_folder):
        return False
    return any(
        filename.endswith('.safetensors')
        for _dirpath, _dirnames, filenames in os.walk(src_folder)
        for filename in filenames
    )
def get_sorted_folder_paths(src_folder):
    """Return the numbered sub-folders of ``src_folder`` that hold ``.safetensors`` weights.

    Only immediate sub-directories whose name parses as an ``int`` (e.g. ``0``, ``1``)
    and that contain at least one ``*.safetensors`` file (searched recursively via
    ``check_safetensors_files``) are kept. The result is sorted by that integer in
    ascending order.

    Returns an empty list when ``src_folder`` cannot be listed (missing, not a
    directory, permission denied). Other errors are not swallowed.
    """
    try:
        entries = os.listdir(src_folder)
    except OSError:
        # Best effort: an unreadable or missing run directory has no finished rounds.
        return []

    numbered = []
    for name in entries:
        full_path = os.path.join(src_folder, name)
        if not os.path.isdir(full_path):
            continue
        try:
            index = int(name)
        except ValueError:
            continue
        if check_safetensors_files(full_path):
            numbered.append((index, full_path))
    numbered.sort(key=lambda pair: pair[0])
    return [path for _, path in numbered]

"""
ACCELERATE_LOG_LEVEL=info  accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_rrpo.py training_configs/Llama-3.2-3B-Instruct/dpo_math.yaml
"""

def main() -> None:
    """Run iterative on-policy SimPO training with replay ratio and optional resets.

    Workflow as implemented below:
      1. Parse ``ModelArguments``, ``DataArguments`` and ``SimPOConfig``.
      2. Split the ``train`` split into ``training_args.chunk_num`` contiguous chunks.
         The last chunk absorbs the remainder, and every chunk shares the full ``test``
         split.
      3. For each epoch and chunk, run ``training_args.replay_ratio`` rounds. In reset
         mode (``training_args.is_reset``) each round does the following:
           - roll out the latest model with vLLM;
           - score the samples with the reward model;
           - mix on-policy and original pairs via ``adjust_proportion``;
           - replay the saved data of earlier chunks;
           - optionally merge weights;
           - train with ``SimPOTrainer`` into ``<output_dir>_{chunk}_rr_{sft|pro|all}/{round}``.
    """
    parser = H4ArgumentParser((ModelArguments, DataArguments, SimPOConfig))
    model_args, data_args, training_args = parser.parse()
    #######
    # Setup
    #######
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    # Overrides any configured gradient_accumulation_steps with 64 // num_gpus.
    # NOTE(review): device_count() counts GPUs on this node only, and the result is 0
    # when more than 64 GPUs are visible -- confirm for multi-node runs.
    training_args.gradient_accumulation_steps= 64 // num_gpus # Same with dpo
    # training_args.save_steps = 64

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    # logger.info(f"Model parameters {model_args}")
    # logger.info(f"Data parameters {data_args}")
    # logger.info(f"Training/evaluation parameters {training_args}")
    # # Check for last checkpoint
    last_checkpoint = get_checkpoint(training_args)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    # # Set seed for reproducibility
    set_seed(training_args.seed)
    
    # ###############
    # # Load datasets
    # ###############
    raw_datasets = get_datasets(
        data_args,
        splits=data_args.dataset_splits,
        configs=data_args.dataset_configs,
        columns_to_keep=["chosen", "rejected", "prompt", "ori_chosen", "ori_rejected"],
    )
    logger.info(
        f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )
    # Original columns; they are dropped after chat-template formatting further down.
    column_names = raw_datasets["train"].features
    # NOTE(review): `training_num` is not used anywhere below.
    training_num = raw_datasets["train"].num_rows
    # #####################################
    # # Load tokenizer and process datasets
    # #####################################
    data_args.truncation_side = "left"  # Truncate from left to ensure we don't lose labels in final turn
    tokenizer = get_tokenizer(model_args, data_args)
    if "mistral" in model_args.model_name_or_path.lower():
        change_template = "mistral"
    else:
        change_template = None

    ###############################################
    # Split the training set into `chunk_num` chunks for replay-ratio training.
    # The last chunk absorbs the remainder; every chunk reuses the full test split.
    n = training_args.chunk_num
    chunk_size = len(raw_datasets["train"]) // n
    raw_datasets_all_chunk = []
    for i in range(n):
        start = i * chunk_size
        end = start + chunk_size if i < n - 1 else len(raw_datasets["train"])
        chunk_train = raw_datasets["train"][start:end]
        if isinstance(chunk_train, dict):
            # Slicing a Dataset returns a dict of columns; convert it back into a Dataset.
            chunk_train = datasets.Dataset.from_dict(chunk_train)
        chunk = datasets.DatasetDict({
            "train": chunk_train,
            "test": raw_datasets["test"]
        })
        raw_datasets_all_chunk.append(chunk)

    # Encode the initial mixing proportion, SFT weight and chunk count in the run name.
    training_args.output_dir = training_args.output_dir + f"-proportion{training_args.on_policy_data_proportion}-sft{training_args.sft_weight}-chunk{n}"
    output_dir_ori = training_args.output_dir
    ###############################################
    # Reset mode: collect runs from an earlier (interrupted) launch so they can be skipped.
    # These are non-empty directories in the parent of `output_dir_ori` whose name contains
    # "rr_<sft_reset>" and that hold at least one numbered sub-folder with .safetensors.
    # NOTE(review): the name filter only checks the "rr_<mode>" substring, so runs of other
    # configurations in the same parent directory are picked up too.
    # NOTE(review): `base_path` and `run_num` are only defined when is_reset is True, yet
    # `base_path` is also used for Phi models further down regardless of is_reset.
    if training_args.is_reset:
        folder_names = []
        # base_path = output_dir_ori.split("/")[-2]
        base_path = output_dir_ori.rsplit('/', 1)[0] + "/"
        for p in os.listdir(base_path):
            if os.path.isdir(os.path.join(base_path, p)):
                has_files = any(os.scandir(os.path.join(base_path, p)))
                if has_files:
                    my_path = os.path.join(base_path, p)
                    if training_args.sft_reset=="sft" and "rr_sft" in p:
                        folder_names.append(( my_path, get_sorted_folder_paths(my_path)) )
                    if training_args.sft_reset=="pro" and "rr_pro" in p:
                        folder_names.append(( my_path, get_sorted_folder_paths(my_path)) )
                    if training_args.sft_reset=="all" and "rr_all" in p:
                        folder_names.append(( my_path, get_sorted_folder_paths(my_path)) )
        # Order runs by chunk index and keep only runs with at least one saved round.
        folder_names_temp = sorted(folder_names, key=get_number, reverse=False)
        folder_names=[]
        for folder in folder_names_temp:
            if len(folder[1])!=0:
                folder_names.append(folder)
        run_num = len(folder_names)
        if run_num>0:
            print("*"*30)
            print(folder_names)
            print("*"*30)
    ###############################################

    my_epoch = int(training_args.num_train_epochs)
    # NOTE(review): in epochs after the first, chunk 0 / round 0 rolls out from the base
    # model again, and output directories (which carry no epoch index) are reused.
    for epoch in range(my_epoch):
        # Each round trains one epoch over its data; this loop supplies the epochs.
        training_args.num_train_epochs=1
        prv_raw_datasets_train=None
        # `raw_datasets` is rebound here from the full dataset to the current chunk.
        for chunk_num, raw_datasets in enumerate(raw_datasets_all_chunk): 
            # #####################
            # Build a fresh sampled dataset for this chunk, then label it with the reward model.
            #######################
            cur_raw_datasets = raw_datasets.copy()
            # Replay-ratio loop: rounds k = 0 .. replay_ratio - 1.
            k=-1
            while k<training_args.replay_ratio:
                k+=1
                if k==training_args.replay_ratio:
                    break
                if training_args.is_reset:
                    if run_num>0:
                        # Skip chunks already trained in an earlier launch. Jump to the newest
                        # saved round of this chunk.
                        # NOTE(review): the round index is parsed from the last character of
                        # the path only, so this breaks when replay_ratio > 10.
                        run_num-=1
                        training_args.output_dir = folder_names[chunk_num][-1][-1]
                        k=int(training_args.output_dir[-1])
                        print("now train", training_args.output_dir, k)
                        if k==training_args.replay_ratio-1:
                            break
                        else:
                            continue
                    
                    # Roll out the base model for the very first round, otherwise the
                    # most recently trained model.
                    if chunk_num==0 and k==0:
                        model = model_args.model_name_or_path
                    else:
                        model = training_args.output_dir
                    print("=============TOBE Rollout Model=============", model, k, chunk_num)
                    
                    raw_datasets_train= cur_raw_datasets["train"]

                    train_dataset = rollout(model, raw_datasets_train)
                    raw_datasets_train = reward_annotate(train_dataset, reward_model_path=training_args.my_reward_model_path)
                    del train_dataset
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    print(f"the data_chunk {chunk_num}/{n} has been rollouted and annotated")

                    # Linear schedules over rounds k = 0 .. replay_ratio - 1:
                    #   sft_weight:                0 -> reset_size
                    #   on_policy_data_proportion: 1 -> reset_size
                    if training_args.replay_ratio>1:
                        if training_args.sft_reset=="sft":
                            training_args.sft_weight = k*(training_args.reset_size)/(training_args.replay_ratio - 1)
                        elif training_args.sft_reset=="pro":
                            training_args.on_policy_data_proportion =  1 - k*(1-training_args.reset_size)/(training_args.replay_ratio-1)
                        else:
                            training_args.on_policy_data_proportion =  1 - k*(1-training_args.reset_size)/(training_args.replay_ratio-1)
                            training_args.sft_weight = k*(training_args.reset_size)/(training_args.replay_ratio - 1)
                # Keep the freshly annotated data before mixing. It is written to
                # raw_datasets.json below so that later chunks can replay it.
                # NOTE(review): `raw_datasets_train` is only assigned inside the is_reset
                # branch above, so with is_reset False this raises UnboundLocalError.
                to_be_saved_data = raw_datasets_train

                # Adjust the share of on-policy pairs; the rest revert to the original pairs.
                raw_datasets_train = adjust_proportion(raw_datasets_train, training_args.on_policy_data_proportion)
                # Restore the test split (from round 2 on, `raw_datasets` is the formatted copy).
                raw_datasets["test"] = cur_raw_datasets["test"]
                raw_datasets["train"] = raw_datasets_train

                # Replay: prepend the saved final-round data (pre-mixing) of every earlier chunk.
                if chunk_num>0:
                    prv_raw_datasets = []
                    for pre_i in range(chunk_num):
                        if training_args.sft_reset=="sft":
                            pre_path = output_dir_ori + f"_{pre_i}_rr_sft/{training_args.replay_ratio-1}"
                        elif training_args.sft_reset=="pro":
                            pre_path = output_dir_ori + f"_{pre_i}_rr_pro/{training_args.replay_ratio-1}"
                        else:
                            pre_path = output_dir_ori + f"_{pre_i}_rr_all/{training_args.replay_ratio-1}"
                        prv_raw_datasets_train = datasets.load_dataset("json", data_files=os.path.join(pre_path, "raw_datasets.json"),split="train")
                        prv_raw_datasets.append(prv_raw_datasets_train)

                    prv_raw_datasets.append(raw_datasets_train)
                    raw_datasets["train"] = datasets.concatenate_datasets(prv_raw_datasets)

                torch_dtype = (
                    model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
                )
                quantization_config = get_quantization_config(model_args)

                model_kwargs = dict(
                    revision=model_args.model_revision,
                    trust_remote_code=model_args.trust_remote_code,
                    torch_dtype=torch_dtype,
                    use_cache=False if training_args.gradient_checkpointing else True,
                    device_map=get_kbit_device_map() if quantization_config is not None else None,
                    quantization_config=quantization_config,
                    attn_implementation=model_args.attn_implementation,
                )

                # Presumably consumed by SimPOTrainer when `model` is a path -- confirm in simpo_trainer.
                training_args.model_init_kwargs = model_kwargs

                # Default (non-reset): policy and reference are both the base model.
                model = model_args.model_name_or_path
                ref_model=model
                if training_args.is_reset:
                    # Reset mode: no intermediate checkpoints. The reference is the previous
                    # round's output (the base model for the very first round).
                    training_args.save_steps=100000 # do not use checkpoint
                    training_args.save_strategy="no"
                    if chunk_num==0 and k==0:
                        ref_model = model_args.model_name_or_path
                    else:
                        ref_model = training_args.output_dir
                    model=model_args.model_name_or_path

                    if training_args.sft_reset=="sft":
                        training_args.output_dir = output_dir_ori + f"_{chunk_num}_rr_sft/{k}"
                    elif training_args.sft_reset=="pro":
                        training_args.output_dir = output_dir_ori + f"_{chunk_num}_rr_pro/{k}"
                    else:
                        training_args.output_dir = output_dir_ori + f"_{chunk_num}_rr_all/{k}"

                    if chunk_num==0 and k==0:
                        pass
                    else:
                        # The policy starts from the previous round's weights. Every entry
                        # whose name contains "o_proj" is pulled back toward the base model:
                        #   dropout_size * base + (1 - dropout_size) * previous.
                        merge_models(model, ref_model, torch_dtype=torch_dtype, trust_remote_code=model_args.trust_remote_code,
                                                        temp_path=base_path+"/temp_path", alpha=training_args.dropout_size, merged_name="o_proj")
                        model=base_path+"/temp_path"

                to_be_saved_data.to_json(os.path.join(training_args.output_dir, f'raw_datasets.json'))

                #####################
                # Apply chat template
                # #####################
                # Phi-3 checkpoints ship custom modeling code; copy it next to the outputs and into base_path.
                # NOTE(review): `base_path` is undefined here when is_reset is False.
                if "Phi" in model_args.model_name_or_path:
                    shutil.copyfile(os.path.join(model_args.model_name_or_path,'configuration_phi3.py'), os.path.join(training_args.output_dir,'configuration_phi3.py'))
                    shutil.copyfile(os.path.join(model_args.model_name_or_path,'modeling_phi3.py'), os.path.join(training_args.output_dir,'modeling_phi3.py'))
                    shutil.copyfile(os.path.join(model_args.model_name_or_path,'configuration_phi3.py'), os.path.join(base_path,'configuration_phi3.py'))
                    shutil.copyfile(os.path.join(model_args.model_name_or_path,'modeling_phi3.py'), os.path.join(base_path,'modeling_phi3.py'))

                # Qwen: reuse EOS as pad and BOS token. This mutates the shared tokenizer,
                # so the change persists for all later rounds.
                if "Qwen" in  model_args.model_name_or_path:
                    tokenizer.pad_token = tokenizer.eos_token
                    tokenizer.bos_token = tokenizer.eos_token

                raw_datasets = raw_datasets.map(
                    apply_chat_template,
                    fn_kwargs={
                        "tokenizer": tokenizer,
                        "task": "simpo",
                        "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
                        "change_template": change_template,
                    },
                    num_proc=data_args.preprocessing_num_workers,
                    remove_columns=column_names,
                    desc="Formatting comparisons with prompt template",
                )
                # Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
                for split in ["train", "test"]:
                    raw_datasets[split] = raw_datasets[split].rename_columns(
                        {"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
                    )
                #######################
                # Instantiate SimPO trainer
                #######################
                print("#################################TO BE TRAINED####################################")
                print("model", model)
                print("ref_model", ref_model)
                print("#################################TO BE TRAINED####################################")

                trainer = SimPOTrainer(
                    model=model,
                    ref_model=ref_model,
                    args=training_args,
                    train_dataset=raw_datasets["train"],
                    eval_dataset=raw_datasets["test"],
                    tokenizer=tokenizer,
                    peft_config=get_peft_config(model_args),
                )
                
                # ###############
                # # Training loop
                # ###############
                # Resume from a checkpoint only outside reset mode (reset mode saves none).
                checkpoint = None
                if not training_args.is_reset:
                    last_checkpoint = get_checkpoint(training_args)
                    if training_args.resume_from_checkpoint is not None:
                        checkpoint = training_args.resume_from_checkpoint
                    elif last_checkpoint is not None:
                        checkpoint = last_checkpoint
                else:
                    pass
                    # if chunk_num>0:
                        # checkpoint=get_reset_checkpoint(ref_model)
                print("#################################checkpoint####################################")
                print("checkpoint", checkpoint, len(raw_datasets["train"]), chunk_num, k)
                print("#################################checkpoint####################################")

                train_result = trainer.train(resume_from_checkpoint=checkpoint)

                metrics = train_result.metrics
                metrics["train_samples"] = len(raw_datasets["train"])
                trainer.log_metrics("train", metrics)
                trainer.save_metrics("train", metrics)
                trainer.save_state()

                logger.info("*** Training complete ***")

                ##################################
                # Save model and create model card
                ##################################
                logger.info("*** Save model ***")
                trainer.save_model(training_args.output_dir)
                logger.info(f"Model saved to {training_args.output_dir}")

                # Save everything else on main process
                kwargs = {
                    "finetuned_from": model_args.model_name_or_path,
                    "dataset": list(data_args.dataset_mixer.keys()),
                    "dataset_tags": list(data_args.dataset_mixer.keys()),
                    "tags": ["alignment-handbook"],
                }
                if trainer.accelerator.is_main_process:
                    trainer.create_model_card(**kwargs)
                    # Restore k,v cache for fast inference
                    trainer.model.config.use_cache = True
                    trainer.model.config.save_pretrained(training_args.output_dir)

                ##########
                # Evaluate
                ##########
                # if training_args.do_eval:
                #     logger.info("*** Evaluate ***")
                #     metrics = trainer.evaluate()
                #     metrics["eval_samples"] = len(raw_datasets["test"])
                #     trainer.log_metrics("eval", metrics)
                #     trainer.save_metrics("eval", metrics)

                if training_args.push_to_hub is True:
                    logger.info("Pushing to hub...")
                    trainer.push_to_hub(**kwargs)

                # Manually tear down trainer/accelerator state so that a new SimPOTrainer can
                # be built in the next round within the same process.
                # NOTE(review): `_reset_state` is a private accelerate API -- re-check on upgrades.
                trainer.model = None
                trainer.ref_model = None
                trainer.accelerator.state._reset_state(True)
                trainer.accelerator.gradient_state._reset_state()
                trainer.accelerator = None
                # del trainer
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                logger.info("*** Training complete! ***")


if __name__ == "__main__":
    main()
