import json
import math
import os
import random
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Optional

# Third-party library imports
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader, RandomSampler
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, get_cosine_schedule_with_warmup

# Local application imports
from awdpo.generator import VLLM_Generator
from awdpo.rewards import *
from awdpo.utils import (
    convert_messages_to_chatml,
    create_few_shot_prompt,
    create_no_shot_prompt,
)

# Environment setup for vLLM (VLLM_USE_V1="0"; presumably selects its V0 engine).
# NOTE(review): this assignment runs after `awdpo.generator` (and, presumably,
# vLLM) has already been imported above, so it only takes effect if vLLM reads
# VLLM_USE_V1 lazily. Env vars that configure a library are normally set
# *before* importing it — confirm against the installed vLLM version.
os.environ["VLLM_USE_V1"] = "0"

class AWDPO_MLE_Trainer:
    """
    Manages the training loop for the Advantage-Weighted Direct Preference
    Optimization (AWDPO) method, including data generation, loss calculation,
    and model updates.
    """
    def __init__(self, model, tokenizer, reward_funcs, config, train_dataset, few_shot_column, system_prompt, beta = 0.1):
        """
        Set up the policy model, optional reference model, optimizer and scheduler.

        Args:
            model: Hugging Face causal LM to train. It is moved to
                ``config.device`` and, when ``config.use_lora`` is set, wrapped
                with LoRA adapters via ``get_peft_model``.
            tokenizer: Tokenizer matching ``model``.
            reward_funcs (list[callable]): Reward functions combined by
                ``evaluate_rewards``.
            config: Hyperparameter object. Fields read here: ``device``,
                ``use_vllm``, ``use_reference_model``, ``use_lora``,
                ``lora_rank``, ``lora_alpha``, ``lora_dropout``,
                ``learning_rate``, ``weight_decay``,
                ``per_device_train_batch_size``, ``num_train_epochs`` and
                ``warmup_steps``.
            train_dataset: Indexable training dataset; its length sizes the
                cosine LR schedule.
            few_shot_column (str): Dataset field holding few-shot examples.
            system_prompt (str): System prompt used when building prompts.
            beta (float): Stored as ``self.beta`` (presumably the preference-loss
                temperature — confirm in the loss code).

        Raises:
            ValueError: If ``config.use_reference_model`` is set but ``model``
                has no ``name_or_path`` to reload the base checkpoint from.
        """
        self.model = model.to(config.device)
        self.tokenizer = tokenizer
        self.reward_funcs = reward_funcs
        self.config = config
        self.train_dataset = train_dataset

        self.step = 0
        self._metrics = defaultdict(list)
        self.scaler = torch.cuda.amp.GradScaler() if config.device.startswith("cuda") else None

        self.few_shot_column = few_shot_column
        self.SYSTEM_PROMPT = system_prompt

        # Training step at which the policy weights were last pushed into vLLM (-1: never).
        self._last_loaded_step = -1

        self.beta = beta
        self.lambda_value = 0.05

        # Consecutive questions without a correct few-shot generation
        # (checked against a limit in _fill_successful_batch).
        self.failure_counts = 0

        # Base checkpoint id/path. Resolved independently of use_vllm because
        # the reference model needs it too; it used to be assigned only inside
        # the vLLM branch, which raised NameError for HF-only + reference model.
        model_id = getattr(model, "name_or_path", None)

        if self.config.use_vllm:
            self.llm = VLLM_Generator(model, self.config)

        if self.config.use_reference_model:
            if model_id is None:
                raise ValueError(
                    "use_reference_model requires model.name_or_path to locate the base checkpoint"
                )
            # Loaded from the checkpoint (not copied from self.model), before any LoRA wrapping.
            self.ref_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                use_cache=False,
            ).to(self.config.device).eval()
        else:
            self.ref_model = None

        # Gradient checkpointing is enabled in both modes, before LoRA wrapping.
        self.model.gradient_checkpointing_enable()
        if self.config.use_lora:
            self.peft_config = LoraConfig(
                r=self.config.lora_rank,  # rank of adapters
                lora_alpha=self.config.lora_alpha,
                target_modules=[
                    "q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj",
                ],  # or [".*"] to adapt all linear layers
                lora_dropout=self.config.lora_dropout,
                bias="none",
                task_type="CAUSAL_LM",
            )
            self.model = get_peft_model(self.model, self.peft_config)
            self.model.print_trainable_parameters()

        # Only parameters that receive gradients need optimizer state
        # (with LoRA the base weights are frozen).
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(trainable_params, lr=config.learning_rate, weight_decay=config.weight_decay)
        total_steps = (len(train_dataset) // config.per_device_train_batch_size) * config.num_train_epochs
        self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
                                                         num_warmup_steps=config.warmup_steps,
                                                         num_training_steps=total_steps)

        # State used by policy_reset(): accuracy EMA, patience and drop threshold.
        self.ema_no_shot = None
        self.ema_alpha = 0.1
        self.steps_since_last_improve = 0
        self.patience = 10
        self.drop_threshold = 0.05

    def _slice_and_gather(self, logits: torch.Tensor, full_ids: torch.Tensor, num_logits_to_keep: int) -> torch.Tensor:
        """
        Given:
          - logits: [batch, seq_len, vocab_size] (the model outputs[:, :-1, :])
          - full_ids: [batch, seq_len + 1] the input_ids that produced those logits
          - num_logits_to_keep: how many of the last positions correspond to the generation

        Returns:
          - token_log_probs: [batch, num_logits_to_keep] log-prob of each generated token
        """
        # 1) take only the last num_logits_to_keep positions
        logits_slice = logits[:, -num_logits_to_keep:, :]  # [batch, gen_len, V]
        # 2) get the corresponding token IDs
        token_ids = full_ids[:, -num_logits_to_keep:]     # [batch, gen_len]
        # 3) convert to log-probs
        log_probs = torch.log_softmax(logits_slice, dim=-1)  # [batch, gen_len, V]
        # 4) gather the log-prob for each token
        token_log_probs = log_probs.gather(dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1)
        return token_log_probs  # [batch, gen_len]

    def evaluate_rewards(self, prompts, completions, gt_answers):
        """
        Calculates and aggregates rewards for a batch of model completions.

        This method iterates through a list of reward functions stored in
        `self.reward_funcs`, applies each function to the completions, and
        returns both a combined total reward for each completion and a
        dictionary of the individual reward components.

        Args:
            prompts (list[str]): The list of prompts that generated the completions.
            completions (list[str]): The list of model-generated text to be evaluated.
            gt_answers (list[str]): The list of ground-truth answers corresponding
                to the prompts.

        Returns:
            tuple[list[float], dict[str, list[float]]]: A tuple containing:
                - combined_rewards (list[float]): A list where each element is the
                  sum of all reward components for the corresponding completion.
                - rewards_dict (dict[str, list[float]]): A dictionary mapping each
                  reward function's name to a list of its calculated scores for
                  the batch.
        """
        rewards_dict = {}

        for func in self.reward_funcs:
            if func.__name__ in ["accuracy_reward", "xmlcount_reward", "reasoning_reward"]:
                r = func(prompts, completions, gt_answers)
            else:
                r = func(completions)
            rewards_dict[func.__name__] = r

        combined_rewards = [sum(rewards_dict[func_name][i] for func_name in rewards_dict)
                            for i in range(len(completions))]

        return combined_rewards, rewards_dict

    def _compute_avg_response_length(self,
                                     prompt_input_ids: torch.LongTensor,
                                     full_attention_mask: torch.LongTensor
                                    ):
        """
        Given:
          - prompt_input_ids     : [batch, prompt_len]
          - full_attention_mask  : [batch, prompt_len + gen_len]
        Returns:
          - average number of generated tokens per example
        """
        prompt_len = prompt_input_ids.size(1)
        # we only care about the generation portion of the mask:
        gen_mask = full_attention_mask[:, prompt_len:]        # [batch, gen_len]
        # sum 1's → gives length of each generated sequence
        lengths = gen_mask.sum(dim=1).float()                 # [batch]
        return lengths.mean().item()

    def create_attention_mask_with_eos_handling(self, input_ids, completion_ids, attention_mask=None):
        """
        Creates an attention mask that:
        1. Preserves the original attention mask for input_ids
        2. Sets mask to 1 for all tokens up to and including the first EOS token
        3. Sets mask to 0 for all tokens after the first EOS token

        Args:
            input_ids: Tensor of shape [batch_size, input_length] - the prompt ids
            completion_ids: Tensor of shape [batch_size, completion_length] - the generated completion ids
            attention_mask: Optional tensor of shape [batch_size, input_length] - existing attention mask for input_ids
                           If None, creates a mask of all 1s for input_ids

        Returns:
            Tensor of shape [batch_size, input_length + completion_length] - the full attention mask
        """
        batch_size, completion_length = completion_ids.size()
        device = completion_ids.device

        # Create or use existing attention mask for input
        if attention_mask is None:
            input_attention = torch.ones_like(input_ids, dtype=torch.float)
        else:
            input_attention = attention_mask

        # Find first EOS token in each sequence
        is_eos = (completion_ids == self.tokenizer.eos_token_id)
        eos_idx = torch.full((batch_size,), completion_length, dtype=torch.long, device=device)

        for i in range(batch_size):
            nonzero = torch.nonzero(is_eos[i], as_tuple=False)
            if nonzero.numel() > 0:
                eos_idx[i] = nonzero[0, 0]

        # Create mask: 1 up to and including first EOS, 0 after that
        sequence_indices = torch.arange(completion_length, device=device).unsqueeze(0).expand(batch_size, -1)
        completion_attention = (sequence_indices <= eos_idx.unsqueeze(1)).float()

        # Concatenate input and completion attention masks
        full_attention_mask = torch.cat([input_attention, completion_attention], dim=1)

        return full_attention_mask

    def pad(
        self,
        tensors: list[torch.Tensor],
        padding_value: int = 0,
        padding_side: str = "right",
        pad_to_multiple_of: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Pads a list of tensors to the same shape along the first dimension.

        Args:
            tensors (`list[torch.Tensor]`):
                List of input tensors to pad.
            padding_value (`int`):
                Value to use for padding. Default is 0.
            padding_side (`str`):
                Side on which to add padding. Must be 'left' or 'right'. Default is 'right'.
            pad_to_multiple_of (`int`, *optional*, defaults to `None`):
                If set will pad the sequence to a multiple of the provided value.

        Returns:
            `torch.Tensor`:
                A single tensor containing the padded tensors.

        Examples:
            >>> import torch
            >>> pad([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
            tensor([[1, 2, 3],
                    [4, 5, 0]])
            >>> pad([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]])])
            tensor([[[1, 2],
                    [3, 4]],

                    [[5, 6],
                    [0, 0]]])
        """
        # Determine the maximum shape for each dimension
        output_shape = np.max([t.shape for t in tensors], 0).tolist()

        # Apply pad_to_multiple_of to the first (sequence) dimension
        if pad_to_multiple_of is not None:
            remainder = output_shape[0] % pad_to_multiple_of
            if remainder != 0:
                output_shape[0] += pad_to_multiple_of - remainder

        # Create an output tensor filled with the padding value
        output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device)

        for i, t in enumerate(tensors):
            if padding_side == "left":
                seq_start = output_shape[0] - t.shape[0]
            elif padding_side == "right":
                seq_start = 0
            else:
                raise ValueError("padding_side must be 'left' or 'right'")

            # Define the slices
            seq_slice = slice(seq_start, seq_start + t.shape[0])
            slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:])
            output[i][slices] = t

        return output

    def save_responses(self, current_batch):
        """
        Saves a batch of model responses to a timestamped JSON file.

        The filename is generated based on the current timestamp and training
        step to ensure uniqueness. After saving, the input list is cleared
        in-place to free up memory.

        Args:
            current_batch (list[dict]): A list of dictionaries, where each
                dictionary contains details about a single model generation
                (e.g., prompt, response, rewards). This list will be cleared
                by the function.
        """

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"response_log_{timestamp}_awdpo.json" if self.step is None else f"response_log_step_{self.step}_awdpo.json"

        output_dir = Path(f"{self.config.output_dir}/generated_responses")
        output_dir.mkdir(exist_ok=True, parents=True)

        with open(output_dir / filename, 'w') as f:
            json.dump(current_batch, f, indent=2)

        print(f"Saved {len(current_batch)} responses to {output_dir / filename}")

        current_batch.clear()

    def _generate(self, few_shot_input_ids, no_shot_input_ids, few_shot_attention_mask, no_shot_attention_mask):
        """
        Generates paired few-shot and zero-shot completions for a batch.

        This method handles generation using either the highly optimized vLLM engine
        (if `self.config.use_vllm` is True) or the standard Hugging Face `generate`
        method. It ensures that trained LoRA adapters are synced to the vLLM
        engine before generation.

        Args:
            few_shot_input_ids (torch.Tensor): Tokenized prompts that include 
                few-shot examples.
            no_shot_input_ids (torch.Tensor): Tokenized prompts without few-shot 
                examples.
            few_shot_attention_mask (torch.Tensor): Attention mask for the 
                few-shot prompts.
            no_shot_attention_mask (torch.Tensor): Attention mask for the 
                zero-shot prompts.

        Returns:
            tuple: A tuple containing four elements:
                - few_shot_completions (list[str]): Decoded text of the few-shot responses.
                - no_shot_completions (list[str]): Decoded text of the zero-shot responses.
                - few_shot_completion_ids (torch.Tensor): Token IDs for the few-shot completions.
                - no_shot_completion_ids (torch.Tensor): Token IDs for the zero-shot completions.
        """
        if self.config.use_vllm:

            if self.step != self._last_loaded_step:
              self.llm.move_model_to_vllm(self.model)
              self._last_loaded_step = self.step

            few_shot_output = self.llm.generate(prompt_token_ids=[ids.tolist() for ids in few_shot_input_ids])

            no_shot_output = self.llm.generate(prompt_token_ids=[ids.tolist() for ids in no_shot_input_ids])

            few_shot_token_ids = [output.token_ids for request in few_shot_output for output in request.outputs]
            no_shot_token_ids = [output.token_ids for request in no_shot_output for output in request.outputs]

            few_shot_completion_ids = [torch.tensor(ids, device=self.config.device, dtype=torch.long) for ids in few_shot_token_ids]
            few_shot_completion_ids = self.pad(few_shot_completion_ids, padding_value=self.tokenizer.pad_token_id)

            no_shot_completion_ids = [torch.tensor(ids, device=self.config.device, dtype=torch.long) for ids in no_shot_token_ids]
            no_shot_completion_ids = self.pad(no_shot_completion_ids, padding_value=self.tokenizer.pad_token_id)


            few_shot_completions = [self.tokenizer.decode(seq, skip_special_tokens=True)
                                  for seq in few_shot_token_ids]

            no_shot_completions = [self.tokenizer.decode(seq, skip_special_tokens=True)
                                  for seq in no_shot_token_ids]

        else:
              # Generate completions using the few-shot prompt
              few_shot_output = self.model.generate(
                  few_shot_input_ids,
                  attention_mask=few_shot_attention_mask,
                  max_new_tokens=self.config.max_completion_length,
                  do_sample=True,
                  temperature=self.config.temperature,
                  num_return_sequences=self.config.num_generations,
                  pad_token_id=self.tokenizer.eos_token_id,
                  use_cache=False
              )
              few_shot_output = few_shot_output.to(self.config.device)
              # Extract just the completion part (without the prompt)
              few_shot_completion_ids = few_shot_output[:, few_shot_input_ids.shape[1]:]
              # Decode completions
              few_shot_completions = [self.tokenizer.decode(seq[few_shot_input_ids.shape[1]:], skip_special_tokens=True)
                            for seq in few_shot_output]

              empty_idxs = [i for i, text in enumerate(few_shot_completions) if len(text) == 0]

              #print("Few-shot sample:")
              #print(few_shot_completions[0])
              #print(50 * '#')

              no_shot_output = self.model.generate(
                  no_shot_input_ids,
                  attention_mask=no_shot_attention_mask,
                  max_new_tokens=self.config.max_completion_length,
                  do_sample=True,
                  temperature=self.config.temperature,
                  num_return_sequences=self.config.num_generations,
                  pad_token_id=self.tokenizer.eos_token_id,
                  use_cache=False
              )
              no_shot_output = no_shot_output.to(self.config.device)
              no_shot_completion_ids = no_shot_output[:, no_shot_input_ids.shape[1]:]
              no_shot_completions = [self.tokenizer.decode(seq[no_shot_input_ids.shape[1]:], skip_special_tokens = True)
                            for seq in no_shot_output]


              #print("No-shot sample:")
              #print(no_shot_completions[0])
              #print(50 * '#')
        return few_shot_completions, no_shot_completions, few_shot_completion_ids, no_shot_completion_ids

    def truncate_few_shot_responses(self, few_shot_completions):
        """
        Cleans and re-tokenizes few-shot completions by truncating them.

        This method truncates each completion string at the `</answer>` tag
        to remove any extraneous text generated after the final answer. It then
        appends the end-of-message token and re-tokenizes the cleaned strings.

        Args:
            few_shot_completions (list[str]): A list of raw model-generated 
                completions.

        Returns:
            tuple[list[str], torch.Tensor]: A tuple containing:
                - The list of truncated completion strings.
                - A padded tensor of the new, truncated completion token IDs.
        """
        truncated_few_shot_completions = []
        for completion in few_shot_completions:
            if "</answer" in completion:
                parts = completion.split("</answer>")
                truncated = parts[0]
                if len(parts) > 1:  # Found the tag
                    truncated += "</answer>"  # Include the end tag
                    truncated += "\n\n<|im_end|>"  # Add the end token
                truncated_few_shot_completions.append(truncated)
            else:
                truncated_few_shot_completions.append(completion)

        truncated_few_shot_completion_ids = self.tokenizer(truncated_few_shot_completions,
                                      return_tensors="pt",
                                      padding=True).input_ids.to(self.config.device)

        return truncated_few_shot_completions, truncated_few_shot_completion_ids

    def _fill_successful_batch(self, idx_iter):
        """
        Build one training batch, keeping only questions the few-shot policy solved.

        Dataset indices are pulled from ``idx_iter`` one at a time. For each
        question a few-shot and a zero-shot ChatML prompt are built,
        ``_generate`` samples completions for both, and the few-shot
        completions are truncated after ``</answer>`` and scored with
        ``accuracy_reward``. A question is kept only if at least one truncated
        few-shot completion scores ``>= 1.0``, so the MLE anchor loss always
        has targets. This repeats until ``config.per_device_train_batch_size``
        questions have been kept.

        Args:
            idx_iter (iterator): Iterator over training-dataset indices. If it
                is exhausted mid-batch, a freshly shuffled ``RandomSampler``
                iterator is used for the rest of this call (the caller's
                iterator itself is not replaced).

        Returns:
            tuple: ``(questions, gt_answers, few_msgs, no_msgs,
            few_shot_input_ids, few_shot_attention_mask, no_shot_input_ids,
            no_shot_attention_mask, few_shot_completion_ids,
            no_shot_completion_ids, truncated_few_shot_completion_ids,
            accurate_completion_ids, flat_few, flat_no, flat_few_truncated)``.
            Per-question tensors are padded to a common width before being
            concatenated along dim 0: prompt ids/masks on the left, completion
            ids on the right, truncated completion ids on the tokenizer's
            ``padding_side`` (the side used when they were tokenized).
            ``accurate_completion_ids`` are the rows of
            ``truncated_few_shot_completion_ids`` whose reward is ``>= 1.0``.

        Raises:
            ValueError: If ``self.failure_counts`` (consecutive unsolved
                questions, carried across calls) reaches 100.
        """
        bs = self.config.per_device_train_batch_size
        G = self.config.num_generations
        max_failures = 100

        questions, gt_answers = [], []
        few_msgs, no_msgs = [], []
        few_shot_input_ids, few_shot_attention_mask = [], []
        no_shot_input_ids, no_shot_attention_mask = [], []
        few_shot_completions, no_shot_completions, truncated_few_shot_completions = [], [], []
        few_shot_completion_ids, no_shot_completion_ids, truncated_few_shot_completion_ids = [], [], []
        all_rewards = []

        while len(questions) < bs:
            if self.failure_counts >= max_failures:
                raise ValueError(f"Reached maximum number of failures: {max_failures}")
            try:
                idx = next(idx_iter)
            except StopIteration:
                # Ran out mid-epoch: continue with a freshly shuffled pass.
                idx_iter = iter(RandomSampler(self.train_dataset, replacement=False))
                idx = next(idx_iter)

            ex = self.train_dataset[idx]
            q = ex["question"]
            gt = ex["answer"]

            few_shot_examples = ex.get(self.few_shot_column, [])
            n_shots = len(few_shot_examples)
            msgs_few = create_few_shot_prompt(
                q, few_shot_examples,
                system_prompt=self.SYSTEM_PROMPT,
                # randint(2, k) raises ValueError for k < 2 (including the []
                # default above); clamp the lower bound. With >= 2 shots the
                # draw is exactly randint(2, min(4, n_shots)) as before.
                num_examples=random.randint(min(2, n_shots), min(4, n_shots)),
            )
            msgs_no = create_no_shot_prompt(q, system_prompt=self.SYSTEM_PROMPT)

            chatml_few = convert_messages_to_chatml(msgs_few)
            chatml_no = convert_messages_to_chatml(msgs_no)

            tok_few = self.tokenizer([chatml_few], return_tensors="pt",
                                     padding=True, truncation=True,
                                     max_length=self.config.max_prompt_length
                                    ).to(self.config.device)
            tok_no = self.tokenizer([chatml_no], return_tensors="pt",
                                    padding=True, truncation=True,
                                    max_length=self.config.max_prompt_length
                                   ).to(self.config.device)

            few_completions, no_completions, few_completion_ids, no_completion_ids = self._generate(
                tok_few.input_ids,
                tok_no.input_ids,
                tok_few.attention_mask,
                tok_no.attention_mask,
            )

            truncated_few_completions, truncated_few_completion_ids = self.truncate_few_shot_responses(few_completions)

            rewards_few = accuracy_reward([q] * G, truncated_few_completions, [gt] * G)
            print("rewards few:", rewards_few)

            if not any(r >= 1.0 for r in rewards_few):
                # No usable MLE target for this question; try another one.
                self.failure_counts += 1
                continue

            self.failure_counts = 0

            questions.append(q)
            gt_answers.append(gt)

            few_msgs.append(chatml_few)
            no_msgs.append(chatml_no)

            few_shot_input_ids.append(tok_few.input_ids)
            no_shot_input_ids.append(tok_no.input_ids)

            few_shot_attention_mask.append(tok_few.attention_mask)
            no_shot_attention_mask.append(tok_no.attention_mask)

            few_shot_completions.append(few_completions)
            no_shot_completions.append(no_completions)
            truncated_few_shot_completions.append(truncated_few_completions)

            few_shot_completion_ids.append(few_completion_ids)
            no_shot_completion_ids.append(no_completion_ids)
            truncated_few_shot_completion_ids.append(truncated_few_completion_ids)

            all_rewards.append(rewards_few)

        # Each question was tokenized/generated separately, so widths differ;
        # a bare torch.cat along dim 0 fails as soon as two widths disagree.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        truncated_side = getattr(self.tokenizer, "padding_side", "right")

        truncated_few_shot_completion_ids = self._pad_and_concat(
            truncated_few_shot_completion_ids, pad_id, truncated_side
        )
        print("truncated_few_shot_completion_ids:", truncated_few_shot_completion_ids.shape)

        # Rewards are ordered question-major, matching the row order above.
        flat_rewards = [r for sub in all_rewards for r in sub]
        rewards_tensor = torch.tensor(flat_rewards, device=truncated_few_shot_completion_ids.device)
        accurate_completion_ids = truncated_few_shot_completion_ids[rewards_tensor >= 1.0]  # [N, L]
        print("accuracte completion ids:", accurate_completion_ids.shape)

        flat_few = [c for comps in few_shot_completions for c in comps]
        flat_no = [c for comps in no_shot_completions for c in comps]
        flat_few_truncated = [c for comps in truncated_few_shot_completions for c in comps]

        return (questions, gt_answers,
                few_msgs, no_msgs,
                # Prompts are left-padded so completions stay directly after them.
                self._pad_and_concat(few_shot_input_ids, pad_id, "left"),
                self._pad_and_concat(few_shot_attention_mask, 0, "left"),
                self._pad_and_concat(no_shot_input_ids, pad_id, "left"),
                self._pad_and_concat(no_shot_attention_mask, 0, "left"),
                self._pad_and_concat(few_shot_completion_ids, pad_id, "right"),
                self._pad_and_concat(no_shot_completion_ids, pad_id, "right"),
                truncated_few_shot_completion_ids,
                accurate_completion_ids,
                flat_few, flat_no, flat_few_truncated)

    @staticmethod
    def _pad_and_concat(tensors, padding_value, padding_side="right"):
        """Pad 2-D tensors along dim 1 to a common width, then concatenate them along dim 0."""
        width = max(t.size(1) for t in tensors)
        padded = []
        for t in tensors:
            gap = width - t.size(1)
            # F.pad pads the last dimension with (left, right) amounts.
            amounts = (gap, 0) if padding_side == "left" else (0, gap)
            padded.append(F.pad(t, amounts, value=padding_value))
        return torch.cat(padded, dim=0)


    def policy_reset(self, acc_rewards_no, rewards_few):
        """
        Manages model state during training to prevent catastrophic forgetting.

        This method tracks an Exponential Moving Average (EMA) of the model's
        accuracy on both few-shot and zero-shot generations. It saves a
        snapshot of the model's state whenever a new best EMA is achieved.

        A reset is triggered under two conditions:
        1. Hard Reset: If the current EMA drops significantly below the best EMA.
        2. Patience Reset: If no new best EMA has been found for a set
           number of steps (`self.patience`).

        When a reset occurs, the model's weights are rolled back to the last
        best snapshot.

        Args:
            acc_rewards_no (list[float]): A list of accuracy scores from the
                zero-shot generations in the current batch.
            rewards_few (list[float]): A list of accuracy scores from the
                few-shot generations in the current batch.

        Returns:
            bool: True if the model's state was reset (indicating the
                  current optimization step should be skipped), False otherwise.
        """
        B = len(acc_rewards_no) // self.config.num_generations
        G = self.config.num_generations
        r_no = torch.tensor(acc_rewards_no, device=self.config.device).view(B, G)
        r_few = torch.tensor(rewards_few, device=self.config.device).view(B, G)

        # per‐example success = did any gen get ≥2?

        per_q_acc_no_shot = (r_no >= 2).any(dim=1).float().mean().item()  # float in [0,1]
        per_q_acc_few_shot = (r_few >= 2).any(dim=1).float().mean().item()  # float in [0,1]

        # init EMA
        if self.ema_no_shot is None:
            self.ema_no_shot = per_q_acc_no_shot
            self.best_ema_no_shot = per_q_acc_no_shot
            self.ema_few_shot = per_q_acc_few_shot
            self.best_ema_few_shot = per_q_acc_few_shot
            self.best_state_dict = self.best_state_dict = {
                                      k: v.detach().cpu()
                                      for k, v in self.model.state_dict().items()
                                  }
            return False

        # update EMA
        self.ema_no_shot = self.ema_alpha * per_q_acc_no_shot + (1-self.ema_alpha) * self.ema_no_shot
        self.ema_few_shot = self.ema_alpha * per_q_acc_few_shot + (1-self.ema_alpha) * self.ema_few_shot

        hard_reset_no_shot = (
            self.best_ema_no_shot > 0 and
            (self.best_ema_no_shot - self.ema_no_shot) / self.best_ema_no_shot > self.drop_threshold
          )
        hard_reset_few_shot = (
              self.best_ema_few_shot > 0 and
              (self.best_ema_few_shot - self.ema_few_shot) / self.best_ema_few_shot > self.drop_threshold
          )
        # check hard reset
        if hard_reset_no_shot or hard_reset_few_shot:
            print(f"[Hard Reset] EMA fell from {self.best_ema_no_shot:.3f} to {self.ema_no_shot:.3f}, rolling back.")
            state = {k: v.to(self.config.device) for k, v in self.best_state_dict.items()}
            self.model.load_state_dict(state)
            self.ema_no_shot = self.best_ema_no_shot
            self.ema_few_shot = self.best_ema_few_shot
            self.steps_since_last_improve = 0
            return True

        # new best?
        if self.ema_no_shot > self.best_ema_no_shot:
            self.best_ema_no_shot = self.ema_no_shot
            self.best_state_dict = {
                                      k: v.detach().cpu()
                                      for k, v in self.model.state_dict().items()
                                  }
            self.steps_since_last_improve = 0
            print(f"New best EMA={self.best_ema_no_shot:.3f}, snapshotting.")
            if self.config.use_lora:
              self.model.merge_adapter
              checkpoint_path = os.path.join(self.config.output_dir, f"checkpoint-best-{self.step}")
              os.makedirs(checkpoint_path, exist_ok=True)
              self.model.save_pretrained(checkpoint_path)
              self.tokenizer.save_pretrained(checkpoint_path)
              print(f"Best model checkpoint saved to {checkpoint_path}\n")
              self.model.unmerge_adapter()
            else:
              checkpoint_path = os.path.join(self.config.output_dir, f"checkpoint-best-{self.step}")
              os.makedirs(checkpoint_path, exist_ok=True)
              self.model.save_pretrained(checkpoint_path)

            self.plot_save_metrics(f"best_ema_{self.step}.png")
            self._cleanup_old_checkpoints()

        else:
            self.steps_since_last_improve += 1

        # patience‐based reset
        if self.steps_since_last_improve >= self.patience:
            print(f"[Patience Reset] no improvement in {self.patience} steps; rolling back to EMA={self.best_ema:.3f}")
            state = {k: v.to(self.config.device) for k, v in self.best_state_dict.items()}
            self.model.load_state_dict(state)
            self.ema_no_shot = self.best_ema_no_shot
            self.ema_few_shot = self.best_ema_few_shot
            self.steps_since_last_improve = 0
            return True

        return False

    def _cleanup_old_checkpoints(self, keep_last_n=20):
        """Keep only the last N checkpoints to save disk space"""
        import glob
        import os

        checkpoint_pattern = os.path.join(self.config.output_dir, "checkpoint-best-*")
        checkpoints = glob.glob(checkpoint_pattern)

        # Sort by step number (extract from filename)
        def extract_step(path):
            try:
                return int(os.path.basename(path).split('-')[-1])
            except:
                return 0

        checkpoints.sort(key=extract_step)

        # Remove old checkpoints, keeping only the last N
        if len(checkpoints) > keep_last_n:
            for old_checkpoint in checkpoints[:-keep_last_n]:
                try:
                    import shutil
                    shutil.rmtree(old_checkpoint)
                    print(f"Cleaned up old checkpoint: {old_checkpoint}")
                except Exception as e:
                    print(f"Failed to remove {old_checkpoint}: {e}")

    def plot_save_metrics(self, metrics_name):
        """
        Plot the tracked training metrics and save the figure as an image.

        Draws four stacked panels from ``self._metrics``:
        total reward, accuracy reward and completion length (few-shot vs.
        no-shot in each), and training loss. Entries stored as tensors are
        converted to Python numbers before plotting. Each series is plotted
        against its own 1-based step index, so a missing or shorter series
        does not abort the whole plot.

        Args:
            metrics_name (str): Output filename (e.g. ``"final_metrics.png"``),
                resolved relative to ``self.config.output_dir``.
        """
        def as_numbers(key):
            # Metric lists may hold 0-dim tensors (possibly on GPU) or numbers.
            # .get avoids inserting empty keys into the defaultdict.
            return [
                v.detach().cpu().item() if isinstance(v, torch.Tensor) else v
                for v in self._metrics.get(key, [])
            ]

        panels = [
            ("Reward vs Steps",
             [("reward_few_shot", "Few-shot"), ("reward_no_shot", "No-shot")]),
            ("Accuracy Reward vs Steps",
             [("accuracy_reward_few_shot", "Few-shot"), ("accuracy_reward_no_shot", "No-shot")]),
            ("Completion Length vs Steps",
             [("few_shot_completion_length", "Few-shot"), ("no_shot_completion_length", "No-shot")]),
            ("Training Loss",
             [("loss", "Training Loss")]),
        ]

        fig, axes = plt.subplots(
            len(panels), 1,
            figsize=(20, 24),    # tall figure: one wide row per panel
            dpi=120,
        )

        for ax, (title, series) in zip(axes, panels):
            for key, label in series:
                values = as_numbers(key)
                ax.plot(range(1, len(values) + 1), values, label=label)
            ax.set_title(title)
            ax.legend()

        fig.tight_layout(pad=3.0)
        os.makedirs(self.config.output_dir, exist_ok=True)
        metrics_png = os.path.join(self.config.output_dir, metrics_name)
        fig.savefig(metrics_png, bbox_inches="tight", dpi=120)
        plt.close(fig)
        print(f"Metrics plot saved to {metrics_png}")

    def train(self):
        """
        Run the AWDPO + MLE training loop.

        For each micro-batch this method:
        1. Fetches a generated-and-filtered batch via `_fill_successful_batch`.
        2. Scores two completion sets under the *no-shot* prompt: truncated
           few-shot completions (``logp_plus``) and no-shot completions
           (``logp_minus``). Their log-probability margin, optionally taken
           relative to ``self.ref_model``, feeds an advantage-weighted DPO
           loss. The advantage is ``tanh(reward_few - reward_no)``.
        3. Adds an MLE term on ``accurate_completion_ids`` under the few-shot
           prompt. The term is weighted by a dynamic lambda
           ``dpo_loss / mle_loss``, which is treated as a constant.
        4. Backpropagates with gradient accumulation. When ``self.scaler`` is
           set, it uses AMP loss scaling. At every update it clips gradients
           and steps the optimizer and scheduler.
        5. Records and prints metrics, logs responses, and saves periodic
           checkpoints.

        Run length: ``config.max_steps`` optimizer updates when positive.
        Otherwise it is the number of updates implied by ``num_train_epochs``,
        the dataset size, ``per_device_train_batch_size`` and
        ``gradient_accumulation_steps``. The final model, tokenizer and metrics
        plot are always saved, including when the data runs out early.
        """
        self.model.train()
        accumulation_counter = 0
        G = self.config.num_generations
        grad_accum = self.config.gradient_accumulation_steps
        # Optional config override; defaults to the clip norm used historically.
        max_grad_norm = getattr(self.config, "max_grad_norm", 0.1)

        # A positive max_steps takes precedence; otherwise derive the budget
        # from the epoch count.
        if self.config.max_steps > 0:
            total_updates = self.config.max_steps
        else:
            batches_per_epoch = math.ceil(len(self.train_dataset) / self.config.per_device_train_batch_size)
            updates_per_epoch = math.ceil(batches_per_epoch / grad_accum)
            total_updates = updates_per_epoch * self.config.num_train_epochs

        # torch.autocast expects a bare device type ("cuda"/"cpu"), not "cuda:0".
        device_type = torch.device(self.config.device).type
        amp_dtype = torch.bfloat16 if self.config.bf16 else torch.float16

        pbar = tqdm(total=total_updates, desc="Training", unit="step")
        sampler = RandomSampler(self.train_dataset, replacement=False)

        for _epoch in range(self.config.num_train_epochs):
            # A sampler iterator is single-pass, so reshuffle for every epoch.
            idx_iter = iter(sampler)
            while self.step < total_updates:
                batch = self._fill_successful_batch(idx_iter)
                if batch is None:
                    # NOTE(review): assumes None means idx_iter has run out of
                    # samples. Breaking (not returning) keeps the final save
                    # and plot below reachable.
                    break
                (questions, gt_answers,
                 few_shot_msgs, no_shot_msgs,
                 few_shot_input_ids, few_shot_attention_mask,
                 no_shot_input_ids, no_shot_attention_mask,
                 few_shot_completion_ids, no_shot_completion_ids, truncated_few_shot_completion_ids,
                 accurate_completion_ids,
                 few_shot_completions, no_shot_completions, truncated_few_shot_completions
                 ) = batch

                with torch.autocast(
                    device_type=device_type,
                    enabled=(self.scaler is not None),
                    dtype=amp_dtype,
                ):
                    # The reward bookkeeping below (expanded_gt_answers, the
                    # per-prompt reward layout) puts the G completions of each
                    # prompt in consecutive rows. Prompts must therefore be
                    # repeated element-wise, not tiled.
                    expanded_ids = no_shot_input_ids.repeat_interleave(G, dim=0)
                    expanded_attn = no_shot_attention_mask.repeat_interleave(G, dim=0).to(self.config.device)

                    n_accurate = accurate_completion_ids.shape[0]
                    # NOTE(review): tiling the few-shot prompt n_accurate times only
                    # lines up with accurate_completion_ids when the batch holds a
                    # single prompt.
                    expanded_accurate_ids = few_shot_input_ids.repeat(n_accurate, 1)
                    expanded_accurate_attn = few_shot_attention_mask.repeat(n_accurate, 1).to(self.config.device)

                    # MLE sequences: few-shot prompt + accurate completion.
                    few_shot_ids_mle = torch.cat([expanded_accurate_ids, accurate_completion_ids], dim=1)
                    few_shot_attn_mle = self.create_attention_mask_with_eos_handling(
                        expanded_accurate_ids, accurate_completion_ids, expanded_accurate_attn
                    )

                    # logp_plus side: no-shot prompt + truncated few-shot completion.
                    few_shot_full_ids = torch.cat([expanded_ids, truncated_few_shot_completion_ids], dim=1)
                    few_shot_full_attn = self.create_attention_mask_with_eos_handling(
                        expanded_ids, truncated_few_shot_completion_ids, expanded_attn
                    )

                    # logp_minus side: no-shot prompt + no-shot completion.
                    no_shot_full_ids = torch.cat([expanded_ids, no_shot_completion_ids], dim=1)
                    no_shot_full_attn = self.create_attention_mask_with_eos_handling(
                        expanded_ids, no_shot_completion_ids, expanded_attn
                    )

                    avg_few_len = self._compute_avg_response_length(expanded_ids, few_shot_full_attn)
                    avg_no_len = self._compute_avg_response_length(expanded_ids, no_shot_full_attn)

                    outputs_few = self.model(input_ids=few_shot_full_ids, attention_mask=few_shot_full_attn, use_cache=False)
                    outputs_no = self.model(input_ids=no_shot_full_ids, attention_mask=no_shot_full_attn, use_cache=False)
                    outputs_mle_few = self.model(input_ids=few_shot_ids_mle, attention_mask=few_shot_attn_mle, use_cache=False)

                    # Causal-LM shift: drop the last position's logits before gathering.
                    few_shot_logps = self._slice_and_gather(
                        outputs_few.logits[:, :-1, :], few_shot_full_ids, truncated_few_shot_completion_ids.shape[1]
                    )
                    no_shot_logps = self._slice_and_gather(
                        outputs_no.logits[:, :-1, :], no_shot_full_ids, no_shot_completion_ids.shape[1]
                    )
                    few_shot_accurate_logps = self._slice_and_gather(
                        outputs_mle_few.logits[:, :-1, :], few_shot_ids_mle, accurate_completion_ids.shape[1]
                    ).sum(dim=1)

                    # Sequence log-probabilities: sum over completion tokens.
                    logp_plus = few_shot_logps.sum(dim=1)
                    logp_minus = no_shot_logps.sum(dim=1)

                    if self.ref_model is not None:
                        with torch.no_grad():
                            ref_few = self.ref_model(input_ids=few_shot_full_ids, attention_mask=few_shot_full_attn)
                            ref_no = self.ref_model(input_ids=no_shot_full_ids, attention_mask=no_shot_full_attn)
                            logp_plus_ref = self._slice_and_gather(
                                ref_few.logits[:, :-1, :], few_shot_full_ids, truncated_few_shot_completion_ids.shape[1]
                            ).sum(dim=1)
                            logp_minus_ref = self._slice_and_gather(
                                ref_no.logits[:, :-1, :], no_shot_full_ids, no_shot_completion_ids.shape[1]
                            ).sum(dim=1)
                        # Form the margin OUTSIDE no_grad so the DPO loss keeps
                        # its gradient with respect to the policy.
                        delta = (logp_plus - logp_plus_ref) - (logp_minus - logp_minus_ref)
                    else:
                        delta = logp_plus - logp_minus

                    view_flag = self.step < self.config.num_generated_samples_to_view

                    # Both completion sets are rewarded against the no-shot prompt.
                    expanded_gt_answers = [answer for answer in gt_answers for _ in range(G)]
                    expanded_no_shot_prompts = [prompt for prompt in no_shot_msgs for _ in range(G)]

                    combined_rewards_few, reward_dict_few = self.evaluate_rewards(
                        expanded_no_shot_prompts, truncated_few_shot_completions, expanded_gt_answers
                    )
                    combined_rewards_no, reward_dict_no = self.evaluate_rewards(
                        expanded_no_shot_prompts, no_shot_completions, expanded_gt_answers
                    )

                    rewards_tensor_few = torch.tensor(combined_rewards_few, device=self.config.device, dtype=torch.float)
                    rewards_tensor_no = torch.tensor(combined_rewards_no, device=self.config.device, dtype=torch.float)

                    # Plain floats, so the metric history does not pin device tensors.
                    reward_avg_few = rewards_tensor_few.mean().item()
                    reward_std_few = rewards_tensor_few.std().item() if rewards_tensor_few.numel() > 1 else 0.0
                    reward_avg_no = rewards_tensor_no.mean().item()
                    reward_std_no = rewards_tensor_no.std().item() if rewards_tensor_no.numel() > 1 else 0.0

                    acc_rewards_few = accuracy_reward(expanded_no_shot_prompts, truncated_few_shot_completions, expanded_gt_answers,
                                                      num_generated_samples_to_view=view_flag, q_num=self.step)
                    acc_rewards_no = accuracy_reward(expanded_no_shot_prompts, no_shot_completions, expanded_gt_answers,
                                                     num_generated_samples_to_view=view_flag, q_num=self.step)

                    if self.config.policy_reset and self.policy_reset(acc_rewards_no):
                        # Weights were rolled back. Drop gradients accumulated
                        # against the discarded weights and skip this update.
                        self.optimizer.zero_grad()
                        accumulation_counter = 0
                        continue

                    reasoning_rewards_few = reward_dict_few.get("reasoning_reward", [0.0] * G)
                    reasoning_reward_avg_few = sum(reasoning_rewards_few) / len(reasoning_rewards_few)
                    reasoning_rewards_no = reward_dict_no.get("reasoning_reward", [0.0] * G)
                    reasoning_reward_avg_no = sum(reasoning_rewards_no) / len(reasoning_rewards_no)

                    # tanh bounds the advantage while preserving its sign. The
                    # sign picks the preference direction and |A| weights the pair.
                    normalized_advantages = torch.tanh(rewards_tensor_few - rewards_tensor_no)
                    w = normalized_advantages.abs()
                    s = normalized_advantages.sign()
                    dpo_loss = -(w * F.logsigmoid(s * delta * self.beta)).mean()
                    print(f"DPO Loss: {dpo_loss.item()}")

                    mle_loss = -few_shot_accurate_logps.mean()
                    print(f"MLE Loss: {mle_loss.item()}")

                    # Rescale MLE to the DPO magnitude. Lambda is a Python float,
                    # so no gradient flows through it.
                    if mle_loss > 0:
                        lambda_dynamic = dpo_loss.item() / mle_loss.item()
                    else:
                        lambda_dynamic = self.lambda_value  # fallback
                    print(f"Dynamic Lambda: {lambda_dynamic}")

                    loss = dpo_loss + lambda_dynamic * mle_loss

                dpo_loss_value = dpo_loss.item()
                mle_loss_value = mle_loss.item()
                loss_value = loss.item()

                # NOTE(review): logs the untruncated few-shot completion next to
                # rewards that were scored on its truncated form.
                current_batch_responses = [
                    {
                        'timestamp': datetime.now().isoformat(),
                        'question': q,
                        'expected_answer': a,
                        'full_response_few_shot': few_resp,
                        'full_response_no_shot': no_resp,
                        'few_shot_reward': combined_reward_few,
                        'no_shot_reward': combined_reward_no,
                        'acc_reward_few': acc_reward_few,
                        'acc_reward_no': acc_reward_no,
                        'dpo_loss': dpo_loss_value,
                        'mle_loss': mle_loss_value,
                        'loss': loss_value,
                    }
                    for (q, a, few_resp, no_resp, combined_reward_few,
                         combined_reward_no, acc_reward_few, acc_reward_no) in zip(
                        expanded_no_shot_prompts,
                        expanded_gt_answers,
                        few_shot_completions,
                        no_shot_completions,
                        combined_rewards_few,
                        combined_rewards_no,
                        acc_rewards_few,
                        acc_rewards_no,
                    )
                ]
                self.save_responses(current_batch_responses)

                scaled_loss = loss / grad_accum
                if self.scaler is not None:
                    self.scaler.scale(scaled_loss).backward()
                else:
                    scaled_loss.backward()
                accumulation_counter += 1

                if accumulation_counter % grad_accum == 0:
                    if self.scaler is not None:
                        # Unscale first so clipping sees the true gradient norm.
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                        self.optimizer.step()
                    self.scheduler.step()
                    self.optimizer.zero_grad()
                    accumulation_counter = 0

                    acc_sum_few = sum(acc_rewards_few)
                    acc_sum_no = sum(acc_rewards_no)

                    self._metrics["loss"].append(loss_value)
                    self._metrics["no_shot_completion_length"].append(avg_no_len)
                    self._metrics["few_shot_completion_length"].append(avg_few_len)
                    self._metrics["reward_few_shot"].append(reward_avg_few)
                    self._metrics["reward_no_shot"].append(reward_avg_no)
                    self._metrics["reward_std_few_shot"].append(reward_std_few)
                    self._metrics["reward_std_no_shot"].append(reward_std_no)
                    self._metrics["accuracy_reward_few_shot"].append(acc_sum_few)
                    self._metrics["accuracy_reward_no_shot"].append(acc_sum_no)
                    self._metrics["reasoning_reward_few"].append(reasoning_reward_avg_few)
                    self._metrics["reasoning_reward_no"].append(reasoning_reward_avg_no)

                    print(f"Step {self.step} | Loss: {loss_value:.10f} | "
                          f"Reward Few Shot: {reward_avg_few:.4f} | Reward Std Few Shot: {reward_std_few:.4f} | "
                          f"Reward No Shot: {reward_avg_no:.4f} | Reward Std No Shot: {reward_std_no:.4f} | "
                          f"Accuracy Reward Few Shot: {acc_sum_few:.4f} | Accuracy Reward No Shot {acc_sum_no:.4f} | "
                          f"Few Shot Completion Length: {avg_few_len:.4f} | No Shot Completion Length: {avg_no_len:.4f}\n")
                    self.step += 1

                    pbar.update(1)
                    pbar.set_postfix({"loss": loss_value, "rew_few": reward_avg_few})

                    save_steps = self.config.save_steps
                    # A falsy save_steps (0/None) disables periodic checkpoints
                    # instead of raising ZeroDivisionError.
                    if save_steps and self.step % save_steps == 0:
                        checkpoint_path = os.path.join(self.config.output_dir, f"checkpoint-{self.step}")
                        os.makedirs(checkpoint_path, exist_ok=True)
                        self.model.save_pretrained(checkpoint_path)
                        self.tokenizer.save_pretrained(checkpoint_path)
                        print(f"Checkpoint saved to {checkpoint_path}\n")

            if self.step >= total_updates:
                break

        pbar.close()

        final_model_path = os.path.join(self.config.output_dir, "final_model")
        os.makedirs(final_model_path, exist_ok=True)
        self.model.save_pretrained(final_model_path)
        self.tokenizer.save_pretrained(final_model_path)
        print(f"Final model saved to {final_model_path}")
        self.plot_save_metrics("full_training_metrics.png")