import argparse
import torch
import numpy as np
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from trl.trainer import dpo_trainer
from trl.trainer.utils import pad
from trl import DPOConfig, DPOTrainer
from utils import create_online_optimizer_and_scheduler, dpo_tokenize_row
from torch.utils.data import DataLoader
from tqdm import tqdm

class DARCollator(dpo_trainer.PreferenceCollator):
    """
    Collator for DAR batches (one completion per prompt, no chosen/rejected pair).

    ``torch_call`` turns a list of tokenized examples into a batch containing:

    * ``prompt_input_ids`` / ``prompt_attention_mask`` -- left-padded with
      ``pad_token_id`` / 0,
    * ``completion_input_ids`` / ``completion_attention_mask`` -- right-padded with
      ``pad_token_id`` / 0,
    * every other key of the first example, copied through as a plain Python list.
    """

    def torch_call(self, examples):
        def _column_as_tensors(name):
            return [torch.tensor(example[name]) for example in examples]

        prompts = _column_as_tensors("prompt_input_ids")
        completions = _column_as_tensors("completion_input_ids")

        # Masks start as all-ones over the real tokens; padding positions become 0.
        batch = {
            "prompt_input_ids": pad(prompts, padding_value=self.pad_token_id, padding_side="left"),
            "prompt_attention_mask": pad(
                [torch.ones_like(ids) for ids in prompts], padding_value=0, padding_side="left"
            ),
            "completion_input_ids": pad(completions, padding_value=self.pad_token_id),
            "completion_attention_mask": pad(
                [torch.ones_like(ids) for ids in completions], padding_value=0
            ),
        }

        tokenized = ("prompt_input_ids", "completion_input_ids")
        for name in examples[0]:
            if name not in tokenized:
                batch[name] = [example[name] for example in examples]
        return batch

def dar_tokenize_row(
    features,
    tokenizer,
    max_prompt_length,
    max_completion_length,
    add_special_tokens,
    end_of_turn_markers=("<|im_end|>\n", "<|im_end|>", "<|eot_id|>"),
):
    """
    Tokenize one example's prompt and completion for DAR training.

    Installed in place of ``DPOTrainer.tokenize_row``; it produces only
    ``prompt_input_ids`` and ``completion_input_ids`` (no chosen/rejected pair).

    Args:
        features: Example mapping with string ``"prompt"`` and ``"completion"`` fields and
            an ``"end_with_eos"`` flag. When the flag is falsy, a trailing end-of-turn marker
            is stripped from the completion (the stripped text is written back to
            ``features["completion"]``).
        tokenizer: Callable returning a mapping with ``"input_ids"``.
        max_prompt_length: Keep at most this many *trailing* prompt tokens; ``None`` = no limit.
        max_completion_length: Keep at most this many *leading* completion tokens;
            ``None`` = no limit.
        add_special_tokens: Accepted for signature compatibility; unused (both strings are
            tokenized with ``add_special_tokens=False``).
        end_of_turn_markers: Candidate suffixes to strip, checked in order. Only a marker the
            completion actually ends with is removed, so text without a known marker is
            never truncated blindly.

    Returns:
        dict with ``"prompt_input_ids"`` and ``"completion_input_ids"`` token-id lists.
    """
    if not features["end_with_eos"]:
        completion = features["completion"]
        for marker in end_of_turn_markers:
            if completion.endswith(marker):
                # A response that did not finish should not be trained to emit the
                # end-of-turn marker.
                features["completion"] = completion[: -len(marker)]
                break

    prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
    completion_input_ids = tokenizer(features["completion"], add_special_tokens=False)["input_ids"]

    if max_prompt_length is not None:
        prompt_input_ids = prompt_input_ids[-max_prompt_length:]
    if max_completion_length is not None:
        completion_input_ids = completion_input_ids[:max_completion_length]

    return {
        "prompt_input_ids": prompt_input_ids,
        "completion_input_ids": completion_input_ids,
    }

def _dar_precompute_logps(self, data_loader, model, desc):
    """Score every batch of ``data_loader`` with ``model`` via ``self.compute_log_probs``.

    Returns a float32 numpy array holding one summed completion log-prob per row,
    gathered from all processes.
    """
    chunks = []
    for padded_batch in tqdm(data_loader, desc=desc):
        logps = self.compute_log_probs(padded_batch, model)
        # gather_for_metrics collects all processes' rows and drops the duplicates
        # accelerate adds to even out the last batch; with shuffle=False the rows
        # line up with the dataset order.
        chunks.append(self.accelerator.gather_for_metrics(logps).cpu())
        # Release cached GPU memory between batches to keep peak usage down.
        torch.cuda.empty_cache()
        self.accelerator.free_memory()
    return torch.cat(chunks).float().numpy()


def dar_get_train_dataloader(self) -> DataLoader:
    """
    Precompute per-row reference and mu log-probabilities, then return the training DataLoader.

    On the first call this adds two columns to ``self.train_dataset``:
    ``"reference_logps"`` (scored with ``self.ref_model``) and ``"mu_logps"`` (scored with
    ``self.model`` as it is before training starts). ``_precomputed_train_ref_log_probs``
    is then set, so later calls reuse those columns instead of recomputing both passes and
    re-adding columns that already exist.
    """
    if not getattr(self, "_precomputed_train_ref_log_probs", False):
        dataloader_params = {
            "batch_size": self.args.per_device_train_batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            # Order must be preserved so gathered log-probs align with dataset rows.
            "shuffle": False,
        }
        data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))

        reference_logps = _dar_precompute_logps(
            self, data_loader, self.ref_model, "Train dataset reference log probs"
        )
        mu_logps = _dar_precompute_logps(self, data_loader, self.model, "Train dataset mu log probs")

        self.train_dataset = self.train_dataset.add_column("reference_logps", reference_logps)
        self.train_dataset = self.train_dataset.add_column("mu_logps", mu_logps)
        self._precomputed_train_ref_log_probs = True

    # Bypass DPOTrainer's own override and use the base transformers Trainer loader.
    return super(DPOTrainer, self).get_train_dataloader()

@torch.no_grad()
def dar_compute_log_probs(self, padded_batch, model):
    """
    Score ``padded_batch`` with ``model`` without building an autograd graph.

    Returns the first element of ``self.dar_forward(model, padded_batch)`` -- the
    per-row summed completion log-probabilities; the logits are discarded.
    """
    forward_outputs = self.dar_forward(model, padded_batch)
    return forward_outputs[0]

def dar_forward(self, model, padded_batch):
    """
    Forward pass for DAR: concatenate prompt and completion, run ``model``, score the completion.

    Args:
        model: Causal LM called as ``model(input_ids=..., attention_mask=...)``; its output
            must expose ``.logits``.
        padded_batch: Mapping with left-padded ``prompt_input_ids``/``prompt_attention_mask``
            and right-padded ``completion_input_ids``/``completion_attention_mask`` tensors
            (as produced by ``DARCollator``).

    Returns:
        ``(all_logps, logits)``: ``all_logps`` is, per row, the sum of log-probabilities of
        the completion tokens; ``logits`` are the model logits with the last position
        dropped (aligned with the next-token labels).
    """
    prompt_input_ids = padded_batch["prompt_input_ids"]
    prompt_attention_mask = padded_batch["prompt_attention_mask"]
    completion_input_ids = padded_batch["completion_input_ids"]
    completion_attention_mask = padded_batch["completion_attention_mask"]

    input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
    attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
    # Only completion tokens contribute to the log-probability.
    loss_mask = torch.cat(
        (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
        dim=1,
    )

    # Left flush to reduce memory usage: shift each row so its first attended token sits
    # at column 0 (the rows are fresh tensors from torch.cat, so this does not touch the batch).
    for i in range(attention_mask.size(0)):
        first_one_idx = torch.nonzero(attention_mask[i])[0].item()
        input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)
        attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)
        loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx)

    # After flushing, attended tokens are contiguous from column 0, so every column from the
    # first all-padding one onwards is padding in every row: keep exactly the columns before it.
    # (Slicing to first_empty_col - 1 would also drop the last real token of the longest rows.)
    empty_cols = torch.sum(attention_mask, dim=0) == 0
    first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1)
    input_ids = input_ids[:, :first_empty_col]
    attention_mask = attention_mask[:, :first_empty_col]
    loss_mask = loss_mask[:, :first_empty_col]

    # Truncate right
    if self.args.max_length is not None:
        input_ids = input_ids[:, : self.args.max_length]
        attention_mask = attention_mask[:, : self.args.max_length]
        loss_mask = loss_mask[:, : self.args.max_length]

    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # Position t predicts token t + 1.
    logits = outputs.logits[:, :-1, :]
    labels = input_ids[:, 1:].clone()
    loss_mask = loss_mask[:, 1:].bool()

    labels[~loss_mask] = 0  # dummy token so gather stays in range; zeroed out below
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    per_token_logps[~loss_mask] = 0
    all_logps = per_token_logps.sum(-1)

    return all_logps, logits

def dar_compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """
    Trainer ``compute_loss`` hook for DAR.

    Delegates to ``self.get_batch_loss_metrics`` (train mode), moves the loss to
    ``self.args.device`` and records the metrics with ``self.store_metrics``.
    ``num_items_in_batch`` is accepted for Trainer compatibility and ignored.

    Returns:
        The loss, or ``(loss, metrics)`` when ``return_outputs`` is true.
    """
    batch_loss, batch_metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
    batch_loss = batch_loss.to(self.args.device)
    self.store_metrics(batch_metrics, train_eval="train")
    return (batch_loss, batch_metrics) if return_outputs else batch_loss

def dar_get_batch_loss_metrics(self, model, batch, train_eval="train"):
    """
    Compute the DAR loss and logging metrics for one batch.

    Each row's loss is ``-w * policy_logp`` with
    ``w = exp(adv_norm / (beta_mu + beta_ref)) * exp((ref_logp - mu_logp) * beta_ref / (beta_mu + beta_ref))``,
    optionally clipped (``self.weight_clip != 0``), divided by
    ``exp(adv_clip / (beta_mu + beta_ref))`` (``self.weight_norm``) and divided by
    ``length_gen`` (``self.len_norm``).

    Args:
        model: Policy model passed to ``self.dar_forward``.
        batch: Collated batch with the tokenized tensors plus per-row lists
            ``reference_logps``, ``mu_logps``, ``rewards``, ``r_norm``, ``values``, ``adv``,
            ``adv_norm`` and ``length_gen``.
        train_eval: ``"eval"`` prefixes metric names with ``eval_``; otherwise no prefix.

    Returns:
        ``(loss, metrics)`` -- mean loss over the batch and a dict of scalar CPU tensors.
    """
    metrics = {}
    device = self.accelerator.device
    policy_logps = self.dar_forward(model, batch)[0]

    reference_logps = torch.tensor(batch["reference_logps"], dtype=torch.float32)
    mu_logps = torch.tensor(batch["mu_logps"], dtype=torch.float32)
    rewards = torch.tensor(batch['rewards'], dtype=torch.float32)
    r_norm = torch.tensor(batch['r_norm'], dtype=torch.float32)
    values = torch.tensor(batch['values'], dtype=torch.float32)
    adv = torch.tensor(batch['adv'], dtype=torch.float32)
    adv_norm = torch.tensor(batch['adv_norm'], dtype=torch.float32)
    length_gen = torch.tensor(batch['length_gen'], dtype=torch.float32)

    # DAR loss calculation
    beta_sum = self.beta_mu + self.beta_ref
    adv_weights = torch.exp(adv_norm.to(device) / beta_sum)
    adv_weights = torch.clamp(adv_weights, max=1e30)

    # (pi_ref / mu) ** (beta_ref / beta_sum), evaluated in log space: exp(x * p) equals
    # exp(x) ** p but does not overflow to inf (or underflow to 0) for large sequence-level
    # log-ratios x, which would otherwise turn the loss into inf/NaN.
    reg_weights = torch.exp(
        (reference_logps.to(device) - mu_logps.to(device)) * (self.beta_ref / beta_sum)
    )

    weights = adv_weights * reg_weights

    if self.weight_clip != 0:
        weights = torch.clamp(weights, max=self.weight_clip)
    if self.weight_norm:
        # Largest possible advantage weight when adv_norm is clipped at adv_clip.
        norm = np.exp(self.adv_clip / beta_sum)
        weights = weights / norm

    losses = -weights * policy_logps.to(device)
    if self.len_norm:
        losses = losses / length_gen.to(device)

    prefix = "eval_" if train_eval == "eval" else ""
    metrics[f"{prefix}rewards/mean"] = rewards.mean().cpu()
    metrics[f"{prefix}r_norm/mean"] = r_norm.mean().cpu()
    metrics[f"{prefix}values/mean"] = values.mean().cpu()
    metrics[f"{prefix}adv/mean"] = adv.mean().cpu()
    metrics[f"{prefix}adv_norm/mean"] = adv_norm.mean().cpu()
    metrics[f"{prefix}mu_logps/mean"] = mu_logps.mean().cpu()
    metrics[f"{prefix}ref_logps/mean"] = reference_logps.mean().cpu()
    metrics[f"{prefix}reg_weights/mean"] = reg_weights.detach().mean().cpu()
    metrics[f"{prefix}adv_weights/mean"] = adv_weights.detach().mean().cpu()
    metrics[f"{prefix}length_gen/mean"] = length_gen.mean().cpu()
    metrics[f"{prefix}weights/mean"] = weights.detach().mean().cpu()
    metrics[f"{prefix}weights/max"] = weights.detach().max().cpu()
    metrics[f"{prefix}weights/min"] = weights.detach().min().cpu()

    return losses.mean(), metrics

def _dar_advantages(ds, args):
    """
    Compute normalised rewards, baselines and advantages for one round of samples.

    Args:
        ds: Dataset with a ``"rewards"`` column and, when ``args.eos_penalty != "no"``,
            an ``"end_with_eos"`` column.
        args: Parsed CLI namespace (reads ``eos_penalty``, ``r_norm``, ``n_shot``,
            ``adv_norm`` and ``adv_clip``).

    Returns:
        ``(r_norm, values, adv, adv_norm)`` float32 tensors with one entry per row of ``ds``.
    """
    rewards = ds["rewards"]
    if args.eos_penalty != "no":
        # Rows that did not end with EOS get a fixed penalty reward instead of their score.
        # float() accepts integer ("-1") as well as fractional ("-0.5") penalties.
        penalty = float(args.eos_penalty)
        rewards = [r if e else penalty for r, e in zip(ds["rewards"], ds["end_with_eos"])]
    r_norm = torch.tensor(rewards, dtype=torch.float32)
    if args.r_norm == "min_max":
        # Linear map of [1, 10] onto [0, 1] (assumes a 1-10 reward scale -- confirm).
        r_norm = (r_norm - 1) / (10 - 1)
    elif args.r_norm == "z_score":
        # NOTE: despite the name, a linear map of [1, 10] onto [-1, 1], not a z-score.
        r_norm = (r_norm - 1) / (10 - 1) * 2 - 1
    elif args.r_norm == "mean":
        # Standardise with this round's mean and std.
        r_norm = (r_norm - r_norm.mean()) / (r_norm.std() + 1e-8)
    elif args.r_norm == "helpsteer":
        # Linear map of [-30, 2] onto [-1, 1].
        r_norm = (r_norm + 30) / (32) * 2 - 1
    # Any other value leaves the rewards unnormalised.

    if args.n_shot == 1:
        values = [0] * len(rewards)
    else:
        # Baseline = mean normalised reward of each consecutive group of n_shot rows
        # (presumably the n_shot responses to one prompt -- verify against data generation).
        # Each group's mean is repeated len(group) times, so a short trailing group
        # (row count not a multiple of n_shot) still yields exactly one value per row.
        values = [
            group.mean().item()
            for group in r_norm.split(args.n_shot)
            for _ in range(len(group))
        ]
    values = torch.tensor(values, dtype=torch.float32)

    adv = r_norm - values
    if args.adv_norm:
        adv_norm = (adv - adv.mean()) / (adv.std() + 1e-8)
    else:
        adv_norm = adv
    if args.adv_clip != 0:
        # Upper clip only; negative advantages are not bounded here.
        adv_norm = torch.clamp(adv_norm, max=args.adv_clip)
    return r_norm, values, adv, adv_norm


def dar(args):
    """
    Main training loop for Direct Advantage Regularization (DAR).

    Loads the scored samples from ``{working_path}/checkpoint-{checkpoint}/rewards``,
    computes normalised rewards / baselines / advantages, patches ``DPOTrainer`` with the
    DAR tokenisation, collation, forward and loss functions, trains, and saves the model
    together with optimizer/scheduler state to
    ``{working_path}/checkpoint-{checkpoint + len(ds)}``.

    Args:
        args: Parsed CLI namespace (options are defined in the ``__main__`` block).
    """
    set_seed(args.seed)
    ds = load_from_disk(f"{args.working_path}/checkpoint-{args.checkpoint}/rewards")
    # The next checkpoint index is the current one plus the number of samples in this round.
    output_dir = f"{args.working_path}/checkpoint-{args.checkpoint+int(len(ds))}"

    if "helpsteer" in args.working_path:
        max_length = 2000
        max_prompt_length = 1000
    else:
        max_length = 512
        max_prompt_length = 512 - 64 if "tldr" in args.working_path else 512 - 256

    training_args = DPOConfig(
        beta=0.1,
        loss_type="sigmoid",
        output_dir=output_dir,
        per_device_train_batch_size=args.pdtbs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        logging_steps=1,
        save_strategy="no",
        report_to="tensorboard",
        logging_first_step=True,
        num_train_epochs=args.epoch,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs=dict(use_reentrant=False),
        seed=args.seed,
        max_length=max_length,
        max_prompt_length=max_prompt_length,
        # Keep reward/advantage columns; the DAR loss reads them from the batch.
        remove_unused_columns=False,
        max_grad_norm=args.max_grad_norm,
    )

    model_kwargs = dict(
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        use_cache=False,
    )
    if args.checkpoint != 0:
        # Continue from the policy saved by the previous round.
        args.model_name = f"{args.working_path}/checkpoint-{args.checkpoint}"

    model = AutoModelForCausalLM.from_pretrained(args.model_name, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # DARCollator pads with tokenizer.pad_token_id, so a pad token must exist; fall back to EOS.
    if 'llama' in args.model_name or tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    ref_model = AutoModelForCausalLM.from_pretrained(args.ref_model_name, **model_kwargs)

    optimizer, lr_scheduler = create_online_optimizer_and_scheduler(model, args, args.weight_decay)

    if args.sanity_check:
        # ds is a single Dataset (indexed by column and extended with add_column below),
        # so subsample its rows directly.
        ds = ds.select(range(min(50, len(ds))))

    def process(row):
        # Wrap query/response as single-message conversations (user prompt, assistant completion).
        row["prompt"] = [{"content": row['query'], "role": "user"}]
        row["completion"] = [{"content": row['response'], "role": "assistant"}]
        return row

    ds = ds.map(process, num_proc=1, load_from_cache_file=False)

    r_norm, values, adv, adv_norm = _dar_advantages(ds, args)
    ds = ds.add_column(name="r_norm", column=r_norm.tolist())
    ds = ds.add_column(name="values", column=values.tolist())
    ds = ds.add_column(name="adv", column=adv.tolist())
    ds = ds.add_column(name="adv_norm", column=adv_norm.tolist())

    # === Training ===
    DPOTrainer.tokenize_row = staticmethod(dar_tokenize_row)
    DPOTrainer.compute_log_probs = dar_compute_log_probs
    DPOTrainer.compute_loss = dar_compute_loss
    DPOTrainer.get_batch_loss_metrics = dar_get_batch_loss_metrics
    DPOTrainer.dar_forward = dar_forward
    DPOTrainer.get_train_dataloader = dar_get_train_dataloader.__get__(None, DPOTrainer)

    trainer = DPOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=ds,
        tokenizer=tokenizer,
        optimizers=(optimizer, lr_scheduler),
        data_collator=DARCollator(pad_token_id=tokenizer.pad_token_id)
    )
    if args.checkpoint != 0:
        trainer._load_optimizer_and_scheduler(args.model_name)

    # DAR hyper-parameters read by dar_get_batch_loss_metrics.
    trainer.beta_mu = args.beta_mu
    trainer.beta_ref = args.beta_ref
    trainer.len_norm = args.len_norm
    trainer.adv_clip = args.adv_clip
    trainer.weight_norm = args.weight_norm
    trainer.weight_clip = args.weight_clip

    trainer.train()
    trainer.save_model(output_dir)
    trainer._save_optimizer_and_scheduler(output_dir)

if __name__ == "__main__":
    # (flags, add_argument keyword options) for every CLI option, in registration order.
    cli_options = [
        (("--model_name",), dict(type=str, default="Qwen/Qwen2-0.5B", help="Model name or path.")),
        (("--ref_model_name",), dict(type=str, default="Qwen/Qwen2-0.5B", help="Reference model name or path.")),
        (("--working_path",), dict(type=str, default="", help="Working directory for checkpoints and data.")),
        (("--checkpoint",), dict(type=int, default=0, help="Checkpoint number to resume from.")),
        (("--pdtbs",), dict(type=int, default=32, help="Per-device train batch size.")),
        (("--pdebs",), dict(type=int, default=32, help="Per-device eval batch size.")),
        (("--lr",), dict(type=float, default=1e-5, help="Learning rate.")),
        (("--gradient_accumulation_steps", "-gas"), dict(type=int, default=1, help="Gradient accumulation steps.")),
        (("--epoch",), dict(type=int, default=3, help="Number of training epochs.")),
        (("--seed",), dict(type=int, default=42, help="Random seed.")),
        (("--num_gpu",), dict(type=int, default=4, help="Number of GPUs to use.")),
        (("--max_online_steps",), dict(type=int, default=int(1e4), help="Maximum online steps.")),
        (("--beta_ref",), dict(type=float, default=0.05, help="Beta for reference model regularization.")),
        (("--beta_mu",), dict(type=float, default=0.1, help="Beta for policy model regularization.")),
        (("--len_norm",), dict(action="store_true", help="Enable length normalization.")),
        (("--n_shot",), dict(type=int, default=8, help="Number of shots for value estimation.")),
        (("--eos_penalty",), dict(type=str, default="no", help="EOS penalty value or 'no'.")),
        (("--max_grad_norm",), dict(type=float, default=1, help="Max gradient norm.")),
        (("--weight_clip",), dict(type=float, default=0, help="Weight clipping value.")),
        (("--adv_clip",), dict(type=float, default=5, help="Advantage clipping value.")),
        (("--r_norm",), dict(type=str, default="z_score", help="Reward normalization method.")),
        (("--adv_norm",), dict(action="store_true", help="Enable advantage normalization.")),
        (("--sanity_check",), dict(action="store_true", help="Run a sanity check with a small dataset.")),
        (("--cos_scheduler",), dict(action="store_true", help="Use cosine learning rate scheduler.")),
        (("--weight_decay",), dict(type=float, default=0, help="Weight decay.")),
        (("--weight_norm",), dict(action="store_true", help="Enable weight normalization.")),
    ]

    cli = argparse.ArgumentParser(description="Train a model with Direct Advantage Regularization (DAR)")
    for flags, options in cli_options:
        cli.add_argument(*flags, **options)

    args = cli.parse_args()
    dar(args)