import os
import math
from typing import Any, Dict, List, Optional, Tuple, Union

from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torch.optim import AdamW
import numpy as np
from tqdm import tqdm
import random
import json
import pandas as pd
from collections import defaultdict
from einops import rearrange

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    HfArgumentParser,
    get_scheduler,
)
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils.deepspeed import get_active_deepspeed_plugin
import nvidia_smi

from loguru import logger
import requests


import warnings

warnings.filterwarnings("ignore")


@dataclass
class ScriptArguments:
    """Command-line options of the training script.

    The dataclass is handed to ``HfArgumentParser``; every field becomes a CLI
    flag and its ``help`` metadata is the flag's help text.
    """

    # --- algorithm and models ---
    algorithm: str = field(default="reweighted_sft", metadata={"help": "The algorithm to use"})
    model_name_or_path: Optional[str] = field(
        default="meta-llama/Llama-3.1-8B",
        metadata={"help": "The location of the model name or path"},
    )
    reward_model_name_or_path: Optional[str] = field(
        default="NCSOFT/Llama-3-OffsetBias-RM-8B",
        metadata={"help": "The location of the reward model name or path"},
    )

    # --- optimizer ---
    learning_rate: float = field(default=1e-4, metadata={"help": "The learning rate"})
    weight_decay: float = field(default=1e-2, metadata={"help": "The L2 weight decay rate of AdamW"})
    adam_beta1: float = field(default=0.9, metadata={"help": "The beta1 parameter for AdamW"})
    adam_beta2: float = field(default=0.999, metadata={"help": "The beta2 parameter for AdamW"})
    adam_epsilon: float = field(default=1e-6, metadata={"help": "The epsilon parameter for AdamW"})

    # --- batching and schedule ---
    per_device_train_batch_size: int = field(default=8, metadata={"help": "Train batch size per device"})
    per_device_eval_batch_size: int = field(default=16, metadata={"help": "Eval batch size per device"})
    reward_batch_size: int = field(default=16, metadata={"help": "Eval reward batch size per device"})
    gradient_accumulation_steps: int = field(default=2, metadata={"help": "Gradients to accumulate before optimizing"})
    num_train_epochs: int = field(default=1, metadata={"help": "Number of training epochs"})
    logging_steps: int = field(default=50, metadata={"help": "Logging frequency"})
    save_steps: int = field(default=2500, metadata={"help": "Saving frequency"})
    eval_steps: int = field(default=50, metadata={"help": "Evaluation frequency"})
    warmup_ratio: float = field(default=0.05, metadata={"help": "Warmup ratio"})

    # --- paths ---
    output_dir: Optional[str] = field(
        default="./checkpoints/reweighted_sft_demo/", metadata={"help": "Output directory"}
    )
    train_file_path: Optional[str] = field(
        default="example_train_data.jsonl",
        metadata={"help": "Path to training data"},
    )
    eval_file_path: Optional[str] = field(
        default="example_val_data.jsonl",
        metadata={"help": "Path to evaluation data"},
    )

    # --- lengths, seed and loss hyper-parameters ---
    max_prompt_length: int = field(default=1024, metadata={"help": "Maximum prompt length"})
    # Used as the token budget of the response only: it truncates the tokenized
    # suffix during collation and is passed as max_new_tokens at generation time.
    max_length: int = field(default=512, metadata={"help": "Maximum response length in tokens"})
    seed: int = field(default=42, metadata={"help": "Random seed"})
    beta: float = field(default=0.1, metadata={"help": "Beta parameter for DPO loss"})
    reference_free: bool = field(default=False, metadata={"help": "Use reference-free DPO"})

    # --- experiment tracking ---
    report_to: Optional[str] = field(default="tensorboard", metadata={"help": "Reporting tools"})
    project_name: Optional[str] = field(default="Demo", metadata={"help": "Project name for wandb"})
    run_name: Optional[str] = field(default=None, metadata={"help": "Run name for wandb"})

    # deepspeed
    deepspeed_stage: Optional[int] = field(default=None, metadata={"help": "Deepspeed stage"})

    # --- checkpointing, evaluation and prompt options ---
    save_last: bool = field(default=False, metadata={"help": "Save last model"})
    save_best: bool = field(default=False, metadata={"help": "Save best model"})
    # NOTE(review): the use of this flag is not visible in this part of the file;
    # the help text paraphrases the flag name -- confirm against the training loop.
    eval_first: bool = field(default=True, metadata={"help": "Run evaluation first"})
    use_reward_api: bool = field(default=True, metadata={"help": "Use reward api"})
    pre_defined_B: int = field(default=1, metadata={"help": "pre-defined B"})
    do_eval: bool = field(default=True, metadata={"help": "do eval"})
    use_sys_prompt: bool = field(default=True, metadata={"help": "use system prompt"})
    system_prompt: Optional[str] = field(default=None, metadata={"help": "system prompt"})


class EvalDataset(Dataset):
    """Map-style dataset over the records of a JSON Lines file.

    Every non-blank line of ``eval_file_path`` is parsed with ``json.loads``
    and kept in memory; ``__getitem__`` returns the parsed record unchanged.
    """

    def __init__(self, eval_file_path):
        with open(eval_file_path, "r", encoding="utf-8") as f:
            # Iterate the file lazily and skip whitespace-only lines (for
            # example a trailing blank line), which json.loads would reject.
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class TrainDataset(Dataset):
    """Training set of (prompt text, response text, reward) triples, stored in groups.

    The JSON Lines file is read with pandas. Each row supplies a
    ``conversations`` list (dicts with ``from`` / ``value``) and a ``reward``
    sequence whose first element is used. Consecutive rows are packed into
    groups of ``script_args.pre_defined_B``; one dataset item is one group,
    a dict of parallel lists under the keys ``prefix``, ``suffix`` and
    ``reward``. The final group may be shorter.

    Reads the module-level ``script_args`` and ``accelerator``.
    """

    def __init__(self, train_file_path, tokenizer):
        self.tokenizer = tokenizer
        self.pre_defined_B = script_args.pre_defined_B
        self.prompt_prefix = script_args.system_prompt

        frame = pd.read_json(train_file_path, lines=True)
        group = defaultdict(list)
        self.data = []
        rows = tqdm(
            frame.iterrows(),
            total=len(frame),
            desc="Loading train data",
            disable=not accelerator.is_main_process,
        )
        for row_label, record in rows:
            turns = record["conversations"]
            reward_value = record["reward"][0]

            # All turns but the last form the prompt; the last turn is the response.
            chat = self.convert_format(turns[:-1])
            if script_args.use_sys_prompt:
                # NOTE(review): prompt_prefix is None when no system prompt is
                # configured, which puts a None content into the chat -- confirm.
                chat.insert(0, {"role": "system", "content": self.prompt_prefix})
            prompt_text = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
            response_text = turns[-1]["value"] + self.tokenizer.eos_token

            if row_label == 0:
                logger.info(f"Example conversation: {chat}")
                logger.info(f"Example prefix: {prompt_text}")
                logger.info(f"Example suffix: {response_text}")

            group["prefix"].append(prompt_text)
            group["suffix"].append(response_text)
            group["reward"].append(reward_value)

            # Flush a full group and start a new one.
            if len(group["prefix"]) == self.pre_defined_B:
                self.data.append(group)
                group = defaultdict(list)

        # Keep the trailing, partially filled group.
        if group["prefix"]:
            self.data.append(group)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def convert_format(self, messages):
        """Map ``{"from", "value"}`` turns to ``{"role", "content"}`` chat messages.

        A turn whose ``from`` is ``"human"`` becomes role ``"user"``; any other
        sender becomes ``"assistant"``.
        """
        return [
            {
                "role": "user" if turn["from"] == "human" else "assistant",
                "content": turn["value"],
            }
            for turn in messages
        ]


def collate_fn(batch):
    """Collate groups from ``TrainDataset`` into one padded training batch.

    The groups in ``batch`` are flattened in order. Prompts are tokenized
    with left truncation/padding and responses with right truncation/padding,
    then concatenated along the sequence axis so every response starts at the
    same column.

    Uses the module-level ``tokenizer`` and ``script_args``. Side effect: the
    tokenizer is left with ``truncation_side`` and ``padding_side`` set to
    ``"right"``.

    Returns:
        dict with
        ``prefix_lens``: 1-element LongTensor holding the padded prompt width,
        ``input_ids`` / ``attention_mask``: prompt and response concatenated,
        ``labels``: response token ids (padding included),
        ``response_mask``: attention mask of the response part,
        ``rewards``: float tensor with one reward per flattened sample.
    """

    def _flatten(key):
        return [value for group in batch for value in group[key]]

    prompts = _flatten("prefix")
    responses = _flatten("suffix")
    reward_values = _flatten("reward")

    # Prompts keep their tail and are padded on the left.
    # NOTE(review): the prompt text comes from apply_chat_template and is
    # tokenized here with add_special_tokens=True; if the template already
    # emits a BOS token this may duplicate it -- confirm.
    tokenizer.truncation_side = "left"
    tokenizer.padding_side = "left"
    prompt_enc = tokenizer(
        prompts,
        max_length=script_args.max_prompt_length,
        truncation=True,
        add_special_tokens=True,
        padding=True,
        return_tensors="pt",
    )

    # Responses keep their head and are padded on the right.
    tokenizer.truncation_side = "right"
    tokenizer.padding_side = "right"
    response_enc = tokenizer(
        responses,
        max_length=script_args.max_length,
        truncation=True,
        add_special_tokens=False,
        padding=True,
        return_tensors="pt",
    )

    prompt_ids = prompt_enc["input_ids"]
    response_ids = response_enc["input_ids"]
    response_attention = response_enc["attention_mask"]

    joined_ids = torch.cat([prompt_ids, response_ids], dim=1)
    joined_attention = torch.cat([prompt_enc["attention_mask"], response_attention], dim=1)
    prompt_width = prompt_ids.shape[1]

    return {
        "prefix_lens": torch.LongTensor([prompt_width]),
        "input_ids": joined_ids,
        "attention_mask": joined_attention,
        "labels": response_ids,
        "response_mask": response_attention,
        "rewards": torch.Tensor(reward_values),
    }


def collate_fn_eval(batch):
    """Collate evaluation records into a padded batch of generation prompts.

    For each record the first entry of ``datum["prefix"]`` is taken: a list of
    turn strings, each starting with ``<|prompter|>`` or ``<|assistant|>``.
    All turns except the last are converted to chat messages (the marker is
    removed from the text), optionally preceded by the system prompt, and
    rendered with the tokenizer's chat template with a generation prompt
    appended.

    Uses the module-level ``tokenizer`` and ``script_args``; padding and
    truncation sides are whatever the tokenizer is currently set to.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` tensors, the untouched
        ``original_prefixes`` and the rendered ``fixed_prefixes`` strings.

    Raises:
        ValueError: if a turn starts with neither role marker.
    """
    prompt_prefix = script_args.system_prompt
    current_batch_og_prefixes = [datum["prefix"][0] for datum in batch]

    current_batch_prefixes_fixed = []
    for turns in current_batch_og_prefixes:
        if script_args.use_sys_prompt:
            chat = [{"role": "system", "content": prompt_prefix}]
        else:
            chat = []
        # The final turn is skipped; add_generation_prompt opens the assistant turn.
        for turn in turns[:-1]:
            if turn.startswith("<|prompter|>"):
                chat.append({"role": "user", "content": turn.replace("<|prompter|>", "")})
            elif turn.startswith("<|assistant|>"):
                chat.append({"role": "assistant", "content": turn.replace("<|assistant|>", "")})
            else:
                # Previously `assert ValueError(...)`, which is always true and
                # silently dropped the malformed turn from the prompt.
                raise ValueError("Invalid prefix format")
        fixed_text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        current_batch_prefixes_fixed.append(fixed_text)

    current_batch_prefixes_inputs = tokenizer(
        current_batch_prefixes_fixed,
        max_length=script_args.max_prompt_length,
        truncation=True,
        add_special_tokens=True,
        padding=True,
        return_tensors="pt",
    )

    return {
        "input_ids": current_batch_prefixes_inputs["input_ids"],
        "attention_mask": current_batch_prefixes_inputs["attention_mask"],
        "original_prefixes": current_batch_og_prefixes,
        "fixed_prefixes": current_batch_prefixes_fixed,
    }


def get_dpo_loss(
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    beta: float,
    reference_free: bool = False,
):
    """Compute the batch-averaged DPO loss and its reward statistics.

    Args:
        policy_chosen_logps: Policy log-probabilities of the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Policy log-probabilities of the rejected responses. Shape: (batch_size,)
        reference_chosen_logps: Reference log-probabilities of the chosen responses. Shape: (batch_size,)
        reference_rejected_logps: Reference log-probabilities of the rejected responses. Shape: (batch_size,)
        beta: Scale applied to the log-ratio difference inside the log-sigmoid and to the implicit rewards.
        reference_free: If True, the reference log-ratio is replaced by 0 in the loss.
            The implicit rewards below still subtract the provided reference log-probabilities.

    Returns:
        A 5-tuple:
        mean loss (scalar), mean implicit reward of the chosen responses (scalar),
        mean implicit reward of the rejected responses (scalar),
        per-example accuracy as a float tensor (1.0 where chosen reward > rejected reward),
        mean reward margin chosen - rejected (scalar).
        The reward quantities are detached from the graph.
    """
    policy_gap = policy_chosen_logps - policy_rejected_logps
    reference_gap = reference_chosen_logps - reference_rejected_logps

    # Reference-free DPO drops the reference term from the preference logit.
    preference_logits = policy_gap - (0 if reference_free else reference_gap)
    per_example_loss = -F.logsigmoid(beta * preference_logits)

    # Implicit rewards: beta-scaled log-ratio between policy and reference.
    reward_chosen = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    reward_rejected = beta * (policy_rejected_logps - reference_rejected_logps).detach()

    accuracy = (reward_chosen > reward_rejected).float()
    margin = (reward_chosen - reward_rejected).mean()

    return (
        per_example_loss.mean(),
        reward_chosen.mean(),
        reward_rejected.mean(),
        accuracy,
        margin,
    )


def get_reweighted_sft_loss(ref_log_probs, beta, rewards, log_probs):
    """Reward-reweighted SFT loss over groups of candidate responses.

    Each row of ``ref_log_probs`` / ``rewards`` is one group of candidates.
    The trained sample of a group is its first element (column 0), which is
    the sample the caller computes ``log_probs`` for. Its weight is

        w_i = 1 / sum_j exp(s[i, j] - s[i, 0]),   s = ref_log_probs + rewards / beta

    i.e. the softmax of ``s`` over the group, evaluated at column 0. The
    weight is computed without gradient; only ``log_probs`` is differentiated.

    The previous implementation indexed the reference sample as ``[i, i]``,
    using the row index as column index. That raised ``IndexError`` when
    there were more groups than candidates per group, and otherwise weighted
    a sample different from the one ``log_probs[i]`` belongs to.

    Args:
        ref_log_probs: Reference-model sequence log-probabilities, shape (B, G).
        beta: Temperature dividing the rewards.
        rewards: Rewards of the candidates, shape (B, G).
        log_probs: Policy log-probabilities of each group's first candidate, shape (B,).

    Returns:
        Scalar tensor: mean over groups of ``-w_i * log_probs[i]``.
    """
    with torch.no_grad():
        # Work in float32: the inputs may arrive in half precision.
        scores = ref_log_probs.float() + rewards.float() / beta
        weights = torch.softmax(scores, dim=1)[:, 0]

    return -(weights * log_probs).mean()


@torch.no_grad()
def evaluate_on_validation(script_args, eval_dataloader, model, tokenizer, reward_model=None, reward_tokenizer=None):
    """Generate a response for every validation prompt and return the mean reward.

    Runs under ``torch.no_grad()``. Responses are generated with the unwrapped
    policy model and scored either by the local ``reward_model`` (when
    ``script_args.use_reward_api`` is false) or by POSTing each conversation
    to the reward HTTP endpoint. Uses the module-level ``accelerator`` and the
    ``IS_VAL_PRINT`` flag, which is cleared after the first batch is logged.

    Args:
        script_args: Parsed arguments; ``max_length``, ``model_name_or_path``,
            ``use_reward_api`` and ``reward_batch_size`` are read.
        eval_dataloader: Yields batches with ``input_ids``, ``attention_mask``,
            ``original_prefixes`` and ``fixed_prefixes``.
        model: Policy model; unwrapped through ``accelerator.unwrap_model``.
        tokenizer: Policy tokenizer. Its padding and truncation sides are set
            to ``"left"`` here and restored on normal return.
        reward_model: Local reward model, needed only without the reward API.
        reward_tokenizer: Tokenizer of the local reward model.

    Returns:
        ``accelerator.gather_for_metrics`` applied to this process's mean
        reward. The two scoring paths use different scales: the local model
        path averages ``sigmoid(logit)`` (0..1), the API path averages
        ``sigmoid(reward) * 100``.

    Raises:
        ValueError: on the API path, if a prefix turn starts with neither
            ``<|prompter|>`` nor ``<|assistant|>``.
    """
    global IS_VAL_PRINT
    origin_state = (tokenizer.padding_side, tokenizer.truncation_side)
    # Left padding so generated tokens directly follow the prompt tokens.
    tokenizer.truncation_side = "left"
    tokenizer.padding_side = "left"

    device = accelerator.unwrap_model(model).device

    def get_score(prefixes, suffixes):
        """Score (prefix turns, generated suffix) pairs with the local reward model."""
        texts = []

        for p, s in zip(prefixes, suffixes):
            # The last prefix turn must be a bare role marker; the response is appended to it.
            assert p[-1] == "<|prompter|>" or p[-1] == "<|assistant|>", p[-1]
            temp_prefix = p[:-1] + [p[-1] + s]
            texts.append("".join([t + reward_tokenizer.eos_token for t in temp_prefix]))
        input_content = reward_tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=1024,
            return_tensors="pt",
        )
        input_content = {k: v.to(device) for k, v in input_content.items()}

        rewards = reward_model(**input_content).logits

        return rewards.view(-1)

    all_prefixes = []
    all_fixed_prefixes = []  # collected for parity with the batch; not used below
    all_gen_suffixes = []

    for batch in tqdm(eval_dataloader, desc="Generating responses", disable=not accelerator.is_main_process):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        if IS_VAL_PRINT:
            if accelerator.is_main_process:
                logger.info(f"[VAL] input_ids: {input_ids[-1]}")
                logger.info(f"[VAL] Decoded input_ids: {tokenizer.decode(input_ids[-1])}")
                logger.info(f"[VAL] Attention mask: {attention_mask[-1]}")

        predicted_sents = accelerator.unwrap_model(model).generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=script_args.max_length,
            pad_token_id=tokenizer.pad_token_id,
            top_k=50,
            top_p=0.9,
            temperature=0.8,
            repetition_penalty=1.2,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
        )

        # Keep only the newly generated tokens (everything after the padded prompt).
        response_tokens = predicted_sents[:, input_ids.shape[-1] :]
        responses = tokenizer.batch_decode(response_tokens, skip_special_tokens=True)

        if "llama-7b" in script_args.model_name_or_path.lower():
            # Cut the response where the model starts a new "Human" turn and
            # drop a trailing "###" separator.
            responses_normalized = [
                resp.split("\n Human:")[0].split("\nHuman:")[0].split("\n### Human")[0].strip() for resp in responses
            ]
            responses_normalized = [
                resp.replace("###", "").strip() if resp.endswith("###") else resp.strip()
                for resp in responses_normalized
            ]
        else:
            responses_normalized = responses

        if IS_VAL_PRINT:
            IS_VAL_PRINT = False
            if accelerator.is_main_process:
                logger.info(f"[VAL] response: {responses_normalized[-1]}")

        all_prefixes.extend(batch["original_prefixes"])
        all_fixed_prefixes.extend(batch["fixed_prefixes"])
        all_gen_suffixes.extend(responses_normalized)

    all_val_rewards = []
    torch.cuda.empty_cache()
    if not script_args.use_reward_api:
        for i in tqdm(
            range(0, len(all_gen_suffixes), script_args.reward_batch_size),
            desc="Calculating rewards",
            disable=not accelerator.is_main_process,
        ):
            batch_suffixes = all_gen_suffixes[i : i + script_args.reward_batch_size]
            batch_prefixes = all_prefixes[i : i + script_args.reward_batch_size]
            batch_rewards = torch.sigmoid(get_score(batch_prefixes, batch_suffixes).to(torch.float16)).detach()
            all_val_rewards.append(batch_rewards)
        avg_reward = torch.cat(all_val_rewards).mean()
    else:
        api_url = "http://127.0.0.1:2025/get_reward"
        for i in tqdm(
            range(len(all_gen_suffixes)), desc="Calculating rewards", disable=not accelerator.is_main_process
        ):
            _prefixes = all_prefixes[i]
            _suffix = all_gen_suffixes[i]
            chat = []
            for _prefix in _prefixes[:-1]:
                if _prefix.startswith("<|prompter|>"):
                    chat.append({"role": "user", "content": _prefix.replace("<|prompter|>", "")})
                elif _prefix.startswith("<|assistant|>"):
                    chat.append({"role": "assistant", "content": _prefix.replace("<|assistant|>", "")})
                else:
                    # Previously `assert ValueError(...)`, which never fires; a
                    # malformed turn was silently left out of the scored chat.
                    raise ValueError("Invalid prefix")
            chat.append({"role": "assistant", "content": _suffix})
            chat_data = {"chat": chat}
            response = requests.post(api_url, json=chat_data)
            if response.status_code == 200:
                reward = response.json().get("rewards", [])
                reward = torch.sigmoid(torch.tensor([reward])).item() * 100
                all_val_rewards.append(reward)
            else:
                # A failed request is only reported; the sample is excluded from the average.
                print("Error, status code:", response.status_code)
        avg_reward = torch.FloatTensor(all_val_rewards).mean().to(torch.float16).to(device)

    tokenizer.padding_side, tokenizer.truncation_side = origin_state

    torch.cuda.empty_cache()

    accelerator.wait_for_everyone()
    gathered_avg_reward = accelerator.gather_for_metrics(avg_reward)

    return gathered_avg_reward


def get_gpu_memory(device_index: int = 0) -> Tuple[float, float, float]:
    """Return the memory usage of one GPU as reported by NVML.

    The script calls ``nvidia_smi.nvmlInit()`` at startup; that call must have
    happened before this function is used.

    Args:
        device_index: NVML index of the GPU to query. Defaults to 0, the
            device that was previously hard-coded.

    Returns:
        ``(used, total, free)``, each divided by ``1024**3`` and rounded to
        two decimals (GiB, given that NVML reports bytes).
    """
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_index)
    info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)

    gib = 1024**3
    used_memory = round(info.used / gib, 2)
    total_memory = round(info.total / gib, 2)
    free_memory = round(info.free / gib, 2)

    return used_memory, total_memory, free_memory


if __name__ == "__main__":
    IS_TRAIN_PRINT = True
    IS_VAL_PRINT = True
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    if script_args.algorithm == "reweighted_sft":
        script_args.per_device_train_batch_size = min(
            script_args.per_device_train_batch_size, script_args.pre_defined_B
        )
    if script_args.deepspeed_stage:
        script_args.gradient_accumulation_steps = 1  # TODO: https://github.com/microsoft/DeepSpeed/issues/6793

    nvidia_smi.nvmlInit()

    random.seed(script_args.seed)
    np.random.seed(script_args.seed)
    torch.manual_seed(script_args.seed)
    torch.cuda.manual_seed_all(script_args.seed)

    if script_args.deepspeed_stage is not None:
        zero2_plugin = DeepSpeedPlugin(hf_ds_config="data/ds_z2_config.json")
        zero3_plugin = DeepSpeedPlugin(hf_ds_config="data/ds_z3_config.json")
        deepspeed_plugins = {"main": zero2_plugin, "ref": zero3_plugin}
        accelerator = Accelerator(
            gradient_accumulation_steps=script_args.gradient_accumulation_steps,
            mixed_precision="bf16",
            log_with=script_args.report_to.split(","),
            project_dir=script_args.output_dir,
            deepspeed_plugins=deepspeed_plugins,
        )
        active_plugin = get_active_deepspeed_plugin(accelerator.state)
        assert active_plugin is deepspeed_plugins["main"]
        assert active_plugin is accelerator.deepspeed_plugin
    else:
        accelerator = Accelerator(
            gradient_accumulation_steps=script_args.gradient_accumulation_steps,
            mixed_precision="bf16",
            log_with=script_args.report_to.split(","),
            project_dir=script_args.output_dir,
        )

    if not script_args.use_reward_api:
        reward_accelerator = Accelerator()

    accelerator.init_trackers(
        script_args.project_name, init_kwargs={"wandb": {"save_code": True, "name": script_args.run_name}}
    )
    if accelerator.is_main_process:
        logger.info(f"Set all seeds to {script_args.seed}")

    if accelerator.is_main_process:
        logger.info(f"Loading model from {script_args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    if script_args.algorithm != "sft":
        if accelerator.is_main_process:
            logger.info(f"Loading another base model for reference model from {script_args.model_name_or_path}")
        ref_model = AutoModelForCausalLM.from_pretrained(
            script_args.model_name_or_path,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        for param in ref_model.parameters():
            param.requires_grad = False
        ref_model.eval()

    if not script_args.use_reward_api:
        import model_training.models.reward_model  # noqa: F401 (registers reward model for AutoModel loading)

        if accelerator.is_main_process:
            logger.info(f"Loading reward model from {script_args.reward_model_name_or_path}")
        reward_model = AutoModelForSequenceClassification.from_pretrained(script_args.reward_model_name_or_path)
        for param in reward_model.parameters():
            param.requires_grad = False
        reward_model.eval()

        reward_tokenizer = AutoTokenizer.from_pretrained(script_args.reward_model_name_or_path, trust_remote_code=True)
        reward_tokenizer.truncation_side = "left"

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, use_fast=False, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if "llama-7b" in script_args.model_name_or_path.lower():
        tokenizer.chat_template = """{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{{ system_message + '### Human: ' + message['content'] }}{% elif loop.index0 % 2 == 0 %}{{ ' ### Human: ' + message['content'] }}{% else %}{{ ' ### Assistant: ' + message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' ### Assistant:' }}{% endif %}"""
    if (
        "llama-3" in script_args.model_name_or_path.lower() or "llama3" in script_args.model_name_or_path.lower()
    ) and tokenizer.chat_template is None:
        tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
    if "llama-2" in script_args.model_name_or_path.lower() and tokenizer.chat_template is None:
        tokenizer.chat_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"

    if accelerator.is_main_process:
        logger.info(f"Tokenizer chat template: {tokenizer.chat_template}")

    if accelerator.is_main_process:
        logger.info(f"Loading train dataset from {script_args.train_file_path}")

    train_dataset = TrainDataset(
        script_args.train_file_path,
        tokenizer,
    )
    if accelerator.is_main_process:
        logger.info(f"Loaded train dataset with {len(train_dataset)} samples")

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=script_args.per_device_train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    if accelerator.is_main_process:
        logger.info(f"Loading evaluation dataset from {script_args.eval_file_path}")

    eval_dataset = EvalDataset(script_args.eval_file_path)

    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=script_args.per_device_eval_batch_size,
        collate_fn=collate_fn_eval,
    )

    optimizer = AdamW(
        model.parameters(),
        lr=script_args.learning_rate,
        weight_decay=script_args.weight_decay,
    )

    total_batch_size = (
        script_args.per_device_train_batch_size * accelerator.num_processes * script_args.gradient_accumulation_steps
    )
    num_update_steps_per_epoch = math.ceil(len(train_dataset) / total_batch_size)
    max_train_steps = script_args.num_train_epochs * num_update_steps_per_epoch

    warmup_steps = math.ceil(max_train_steps * script_args.warmup_ratio)

    lr_scheduler = get_scheduler(
        name="cosine",
        optimizer=optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=(len(train_dataloader) * script_args.num_train_epochs),
    )

    if script_args.deepspeed_stage is None:
        if script_args.algorithm != "sft":
            model, ref_model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                model, ref_model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
            )
        else:
            model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
            )
    else:
        if script_args.algorithm != "sft":
            model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
            )
            accelerator.state.select_deepspeed_plugin("ref")
            ref_model = accelerator.prepare(ref_model)
        else:
            model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
            )
    if not script_args.use_reward_api:
        reward_model = reward_accelerator.prepare(reward_model)

    # here actual_max_train_steps, with accelerate for dealing with gradient accumulation
    actual_max_train_steps = len(train_dataloader) * script_args.num_train_epochs
    if accelerator.is_main_process:
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(train_dataset)}")
        logger.info(f"  Num epochs = {script_args.num_train_epochs}")
        logger.info(f"  Number of devices = {accelerator.num_processes}")
        logger.info(f"  Instantaneous batch size per device = {script_args.per_device_train_batch_size}")
        logger.info(f"  Gradient Accumulation steps = {script_args.gradient_accumulation_steps}")
        logger.info(f"  Total train batch size (w. parallel & accumulation) = {total_batch_size}")
        logger.info(f"  Total optimization steps = {max_train_steps}")
        logger.info(f"  Actual total steps with accelerate = {actual_max_train_steps}")

    script_args.eval_steps = actual_max_train_steps // 10
    if accelerator.is_main_process:
        logger.info(f"  Evaluation will occur every {script_args.eval_steps} steps.")

    best_avg_reward = None
    progress_bar = tqdm(total=actual_max_train_steps, desc="Training Progress", disable=not accelerator.is_main_process)

    completed_steps = 0
    for epoch in range(script_args.num_train_epochs):
        model.train()
        if script_args.algorithm != "sft":
            ref_model.eval()
        if not script_args.use_reward_api:
            reward_model.eval()
        for step, batch_data in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                input_ids = batch_data["input_ids"]
                attention_mask = batch_data["attention_mask"]
                rewards = batch_data["rewards"]
                labels = batch_data["labels"]
                response_mask = batch_data["response_mask"]
                query_seq_len = batch_data["prefix_lens"][0]

                # print info
                if IS_TRAIN_PRINT:
                    IS_TRAIN_PRINT = False
                    if accelerator.is_main_process:
                        logger.info(f"[Train] input_ids: {input_ids[-1]}")
                        logger.info(f"[Train] Decoded input_ids: {tokenizer.decode(input_ids[-1])}")
                        logger.info(f"[Train] Attention mask: {attention_mask[-1]}")
                        logger.info(f"[Train] labels: {labels[-1]}")
                        logger.info(f"[Train] Decoded labels: {tokenizer.decode(labels[-1])}")

                if script_args.algorithm == "reweighted_sft":
                    with torch.no_grad():
                        ref_logits = ref_model(input_ids, attention_mask=attention_mask).logits
                        ref_logits = ref_logits.detach()
                        ref_resp_logits = ref_logits[:, (query_seq_len - 1) : -1, :]
                        ref_resp_log_probs = F.log_softmax(ref_resp_logits, dim=-1)
                        ref_per_token_log_probs = torch.gather(ref_resp_log_probs, 2, labels[:, :, None]).squeeze(2)
                        ref_resp_logps = (ref_per_token_log_probs * response_mask).sum(-1)

                    B = input_ids.shape[0]
                    current_inputs = torch.stack([input_ids[i] for i in range(0, B, train_dataset.pre_defined_B)])
                    current_attention_mask = torch.stack(
                        [attention_mask[i] for i in range(0, B, train_dataset.pre_defined_B)]
                    )
                    current_labels = torch.stack([labels[i] for i in range(0, B, train_dataset.pre_defined_B)])
                    current_response_mask = torch.stack(
                        [response_mask[i] for i in range(0, B, train_dataset.pre_defined_B)]
                    )

                    logits = model(current_inputs, attention_mask=current_attention_mask).logits
                    resp_logits = logits[:, (query_seq_len - 1) : -1, :]
                    resp_log_probs = F.log_softmax(resp_logits, dim=-1)
                    per_token_log_probs = torch.gather(resp_log_probs, 2, current_labels[:, :, None]).squeeze(2)
                    resp_logps = (per_token_log_probs * current_response_mask).mean(-1)

                    loss = get_reweighted_sft_loss(
                        ref_log_probs=rearrange(
                            ref_resp_logps,
                            "(b pre_defined_B) -> b pre_defined_B",
                            pre_defined_B=train_dataset.pre_defined_B,
                        ),
                        beta=script_args.beta,
                        rewards=rearrange(
                            rewards, "(b pre_defined_B) -> b pre_defined_B", pre_defined_B=train_dataset.pre_defined_B
                        ),
                        log_probs=resp_logps,
                    )
                elif script_args.algorithm == "dpo":
                    B = input_ids.shape[0]
                    with torch.no_grad():
                        ref_logits = ref_model(input_ids, attention_mask=attention_mask).logits
                        ref_logits = ref_logits.detach()
                        ref_resp_logits = ref_logits[:, (query_seq_len - 1) : -1, :]
                        ref_resp_log_probs = F.log_softmax(ref_resp_logits, dim=-1)
                        ref_per_token_log_probs = torch.gather(ref_resp_log_probs, 2, labels[:, :, None]).squeeze(2)
                        ref_resp_logps = (ref_per_token_log_probs * response_mask).sum(-1)
                        chosen_idxes = [i for i in range(0, B, 2)]
                        rejected_idxes = [i for i in range(1, B, 2)]
                        ref_chosen_logps = ref_resp_logps[chosen_idxes]
                        ref_rejected_logps = ref_resp_logps[rejected_idxes]

                    logits = model(input_ids, attention_mask=attention_mask).logits
                    resp_logits = logits[:, (query_seq_len - 1) : -1, :]
                    resp_log_probs = F.log_softmax(resp_logits, dim=-1)
                    per_token_log_probs = torch.gather(resp_log_probs, 2, labels[:, :, None]).squeeze(2)
                    resp_logps = (per_token_log_probs * response_mask).sum(-1)
                    chosen_idxes = [i for i in range(0, B, 2)]
                    rejected_idxes = [i for i in range(1, B, 2)]
                    chosen_logps = resp_logps[chosen_idxes]
                    rejected_logps = resp_logps[rejected_idxes]

                    loss, chosen_reward, rejected_reward, reward_acc, margin = get_dpo_loss(
                        chosen_logps,
                        rejected_logps,
                        ref_chosen_logps,
                        ref_rejected_logps,
                        beta=script_args.beta,
                        reference_free=False,
                    )
                elif script_args.algorithm == "sft":
                    logits = model(input_ids, attention_mask=attention_mask).logits
                    resp_logits = logits[:, (query_seq_len - 1) : -1, :]
                    resp_log_probs = F.log_softmax(resp_logits, dim=-1)
                    per_token_log_probs = torch.gather(resp_log_probs, 2, labels[:, :, None]).squeeze(2)
                    resp_logps = (per_token_log_probs * response_mask).mean(-1)
                    loss = -resp_logps.mean()  # NLL loss
                else:
                    raise ValueError(f"Unsupported algorithm: {script_args.algorithm}")

                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Collect this step's loss from every process so that a cross-process mean can be
            # reported next to the local value.
            losses = accelerator.gather(loss)
            step_metrics = {"loss": loss.item(), "gathered_loss": losses.mean().item()}
            is_dpo = script_args.algorithm == "dpo"
            if is_dpo:
                # DPO additionally reports the reward statistics produced alongside its loss.
                chosen_rewards = accelerator.gather(chosen_reward)
                rejected_rewards = accelerator.gather(rejected_reward)
                reward_accs = accelerator.gather(reward_acc)
                margins = accelerator.gather(margin)
                step_metrics["gathered_chosen_reward"] = chosen_rewards.mean().item()
                step_metrics["gathered_rejected_reward"] = rejected_rewards.mean().item()
                step_metrics["gathered_reward_acc"] = reward_accs.mean().item()
                step_metrics["gathered_margin"] = margins.mean().item()
            if is_dpo or script_args.algorithm in ("reweighted_sft", "sft"):
                accelerator.log(step_metrics, step=completed_steps)
            # The learning rate is tracked on every step, whatever the algorithm.
            accelerator.log({"lr": lr_scheduler.get_last_lr()[0]}, step=completed_steps)

            # Every `logging_steps` steps: record GPU memory and print a console summary.
            if completed_steps % script_args.logging_steps == 0:
                used_memory, total_memory, free_memory = get_gpu_memory()
                # NOTE(review): the tracker key is spelled "used_med" (possibly meant "used_mem");
                # it is left unchanged here.
                accelerator.log({"used_med": used_memory}, step=completed_steps)
                if accelerator.is_main_process:
                    # All algorithms share the same leading and trailing parts of the message.
                    log_prefix = (
                        f"Epoch {epoch + 1}/{script_args.num_train_epochs}, Step {step + 1}/{num_update_steps_per_epoch}, "
                        f"Gathered Loss: {losses.mean().item():.4f}, "
                        f"Loss: {loss.item():.4f}, "
                    )
                    memory_suffix = f"Memory: {used_memory:.2f}/{total_memory:.2f}GB"
                    if is_dpo:
                        logger.info(
                            log_prefix
                            + f"Gathered Chosen Reward: {chosen_rewards.mean().item():.4f}, "
                            f"Gathered Rejected Reward: {rejected_rewards.mean().item():.4f}, "
                            f"Gathered Reward Acc: {reward_accs.mean().item():.4f}, "
                            f"Gathered Margin: {margins.mean().item():.4f}, "
                            + memory_suffix
                        )
                    elif script_args.algorithm in ("reweighted_sft", "sft"):
                        logger.info(log_prefix + memory_suffix)
            # Periodic validation: run `evaluate_on_validation` and, on the main process, keep
            # track of (and optionally checkpoint) the best average reward seen so far.
            #
            # NOTE(review): `completed_steps` is incremented only after this block, so the
            # `completed_steps - 1 == actual_max_train_steps` clause looks like it was meant to
            # be `completed_steps + 1 == ...` (i.e. "also evaluate on the last step"). It is left
            # untouched because changing it alters when evaluation runs -- please confirm intent.
            should_evaluate = (
                script_args.do_eval
                and (
                    completed_steps % script_args.eval_steps == 0
                    or completed_steps - 1 == actual_max_train_steps
                )
                # Step 0 is evaluated only when explicitly requested through `eval_first`.
                and (script_args.eval_first or completed_steps != 0)
            )
            if should_evaluate:
                model.eval()
                if not script_args.use_reward_api:
                    # Without the reward API, the local reward model and its tokenizer are passed in.
                    gathered_avg_reward = evaluate_on_validation(
                        script_args, eval_dataloader, model, tokenizer, reward_model, reward_tokenizer
                    )
                else:
                    gathered_avg_reward = evaluate_on_validation(script_args, eval_dataloader, model, tokenizer)

                avg_reward = gathered_avg_reward.mean()

                accelerator.log({"validation_reward": avg_reward}, step=completed_steps)

                if accelerator.is_main_process:
                    logger.info(f"gathered_avg_reward: {gathered_avg_reward}")
                    logger.info(f"Validation reward at step {completed_steps}: {avg_reward:.4f}")
                    # The step-0 evaluation is never a candidate for "best model".
                    if completed_steps != 0:
                        if best_avg_reward is None or avg_reward > best_avg_reward:
                            best_avg_reward = avg_reward
                            if script_args.save_best:
                                output_dir = os.path.join(script_args.output_dir, "best_model")
                                os.makedirs(output_dir, exist_ok=True)
                                # NOTE(review): saving from the main process alone assumes the full
                                # weights are materialised on this rank (may not hold for sharded
                                # training setups) -- confirm.
                                accelerator.unwrap_model(model).save_pretrained(output_dir)
                                # Tokenizer export is best-effort, but the failure is now reported
                                # through the logger and no longer hides KeyboardInterrupt/SystemExit.
                                try:
                                    tokenizer.save_pretrained(output_dir)
                                except Exception as exc:
                                    logger.warning(f"Failed to save tokenizer: {exc}")
                                logger.info(
                                    f"New best model saved at step {completed_steps} with avg_reward {avg_reward:.4f}"
                                )
                            else:
                                # Nothing is written to disk here, so do not claim a save happened.
                                logger.info(
                                    f"New best avg_reward {avg_reward:.4f} at step {completed_steps} "
                                    f"(save_best is disabled, model not saved)"
                                )
                        else:
                            logger.info(
                                f"Average reward of {avg_reward:.4f} is not better than previous best of {best_avg_reward:.4f}"
                            )
                model.train()

            # Count this iteration as finished and show the latest local loss on the progress bar.
            completed_steps += 1
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")
            progress_bar.update(1)

    progress_bar.close()

    # End-of-training bookkeeping, performed on the main process only.
    if accelerator.is_main_process:
        if script_args.save_last:
            final_output_dir = os.path.join(script_args.output_dir, "final_model")
            os.makedirs(final_output_dir, exist_ok=True)
            # NOTE(review): saving from the main process alone assumes the full weights are
            # materialised on this rank (may not hold for sharded training setups) -- confirm.
            accelerator.unwrap_model(model).save_pretrained(final_output_dir)
            # Tokenizer export is best-effort, but the failure is reported through the logger
            # and no longer hides KeyboardInterrupt/SystemExit the way a bare `except:` did.
            try:
                tokenizer.save_pretrained(final_output_dir)
            except Exception as exc:
                logger.warning(f"Failed to save tokenizer: {exc}")
            logger.info(f"Final model saved at {final_output_dir}")
        logger.info("Training completed.")
        # NOTE(review): `end_training()` is only reached on the main process here, whereas
        # Accelerate's examples call it unguarded at the end of the script -- confirm intended.
        accelerator.end_training()
