import os
import sys
from dataclasses import dataclass
from functools import partial
from time import time
from typing import Tuple, Dict, Any, Iterable

import numpy as np
import torch
import torch.nn.functional as F
import trlx.utils.logging as logging
from rich.console import Console
from rich.table import Table
from torch.nn.parallel import DistributedDataParallel
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtyping import TensorType
from trlx.data.accelerate_base_datatypes import PromptBatch
from trlx.data.ppo_types import PPORLElement, PPORLBatch
from trlx.data.configs import TRLConfig
from trlx.models.modeling_ppo import FixedKLController
from trlx.pipeline.ppo_pipeline import PPORolloutStorage
from trlx.trainer import register_trainer
from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer
from trlx.utils import significant, Clock
from trlx.utils.modeling import gather_dict, logprobs_of_labels, get_tensor_stats, flatten_dict, RunningMoments

# Module-level logger obtained through trlx's logging wrapper.
logger = logging.get_logger(__name__)

# Put the current working directory on the import path so the project-local
# `src` package imported below can be resolved.
# NOTE(review): this presumes the process is launched from the repository root — confirm.
sys.path.append(os.getcwd())

from src.core.kl_controller import AdaptiveUncertaintyKLController

@dataclass
class PPORLElementExtended(PPORLElement):
    """A `PPORLElement` that additionally carries a per-rollout loss coefficient."""

    # Coefficient that `ppo_adaptive_loss` multiplies into the policy and value losses.
    # NOTE(review): `make_experience` stores a 0-dim tensor here for each element, so the
    # "response_size" dimension name in the annotation looks inaccurate — confirm.
    lr_coef: TensorType["response_size"]


@dataclass
class PPORLBatchExtended(PPORLBatch):
    """A `PPORLBatch` that additionally carries the collated `lr_coef` values."""

    # `ppo_collate_fn` builds this with `torch.tensor([...])` over the batch elements,
    # i.e. one entry per element; `loss` indexes it as `lr_coef[:, None]`.
    # NOTE(review): the "response_size" dimension name in the annotation therefore
    # looks inaccurate — confirm.
    lr_coef: TensorType["response_size"]


def ppo_collate_fn(padding_side: str, pad_token_id: int, elems: Iterable[PPORLElementExtended]):
    """Collate rollout elements into a single padded `PPORLBatchExtended`.

    Args:
        padding_side: Side on which the queries are padded, "left" or "right".
        pad_token_id: Token id used to pad query and response tensors.
        elems: Rollout elements to batch. Any iterable is accepted, including
            one-shot iterators.

    Returns:
        A `PPORLBatchExtended` holding the padded queries, right-padded responses,
        log-probabilities, values and rewards (the latter three padded with 0.0),
        and a 1-D tensor with one `lr_coef` entry per element.

    Raises:
        ValueError: If `padding_side` is neither "left" nor "right".
    """
    # `elems` is traversed once per field below; materialise it so that a
    # one-shot iterator (e.g. a generator) is not exhausted after the first pass.
    elems = list(elems)

    def _pad_right(sequences, padding_value):
        # Right-pad a list of 1-D tensors into a (batch, max_len) tensor.
        return pad_sequence(sequences, padding_value=padding_value, batch_first=True)

    if padding_side == "left":
        # Left padding of already left-padded queries: flip each query, pad on
        # the right, then flip the batch back along the sequence dimension.
        query_tensors = _pad_right([elem.query_tensor.flip(0) for elem in elems], pad_token_id).flip(1)
    elif padding_side == "right":
        query_tensors = _pad_right([elem.query_tensor for elem in elems], pad_token_id)
    else:
        raise ValueError(f"Invalid padding side: {padding_side}")

    return PPORLBatchExtended(
        query_tensors,
        # Right pad the rest, to have a single horizontal query/response split
        _pad_right([elem.response_tensor for elem in elems], pad_token_id),
        _pad_right([elem.logprobs for elem in elems], 0.0),
        _pad_right([elem.values for elem in elems], 0.0),
        _pad_right([elem.rewards for elem in elems], 0.0),
        torch.tensor(
            [elem.lr_coef for elem in elems],
        ),
    )

class PPORolloutStorageExtended(PPORolloutStorage):
    """Rollout storage whose data loaders collate with `ppo_collate_fn`."""

    def create_loader(
        self,
        batch_size: int,
        shuffle: bool,
    ) -> DataLoader:
        """Return a `DataLoader` over this storage.

        Args:
            batch_size: Number of rollout elements per batch.
            shuffle: Whether the loader shuffles the stored elements.

        Returns:
            A `DataLoader` whose collate function is `ppo_collate_fn` bound to
            this storage's padding side and pad token id.
        """
        collate = partial(ppo_collate_fn, self.padding_side, self.pad_token_id)
        return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, collate_fn=collate)


@register_trainer
class CustomMetricPPOTrainer(AcceleratePPOTrainer):
    def __init__(self, reward_model = None, kl_ctl = "fixed", config: TRLConfig = None, **kwargs):
        """Set up rollout storage, accelerated model, reference model, generation settings and KL controller.

        Args:
            reward_model: Object stored as `self.reward_model`; other methods of this
                class call its `score` method and read its `agg_fn` attribute.
            kl_ctl: Name of the KL controller. "uncertainty" (case-insensitive) selects
                `AdaptiveUncertaintyKLController`; anything else selects `FixedKLController`.
            config: Trainer configuration.
            **kwargs: Forwarded unchanged to the base-class initializer.
        """
        # The MRO lookup starts *after* `AcceleratePPOTrainer`, so that class's own
        # `__init__` is bypassed and the initializer of its base is invoked instead.
        super(AcceleratePPOTrainer, self).__init__(config, **kwargs)

        # Rollout logging is enabled exactly when a logging directory is configured
        self.log_rollouts = config.train.rollout_logging_dir is not None
        if self.log_rollouts:
            self.setup_rollout_logging(config)

        # The rollout store keeps prompt & response, log probs, values and rewards of each rollout
        tokenizer = self.tokenizer
        self.store = PPORolloutStorageExtended(tokenizer.pad_token_id, tokenizer.padding_side)

        # TODO (jon-tow): This loader is only used to satisfy to `accelerator.prepare` call constraint below - remove in future
        rollout_loader: DataLoader = self.store.create_loader(self.config.train.batch_size, shuffle=True)

        # Hand everything to accelerate for multi-GPU preparation
        prepared = self.accelerator.prepare(self.model, self.opt, self.scheduler, rollout_loader)
        self.model, self.opt, self.scheduler, rollout_loader = prepared

        # Start from an empty rollout store
        self.store.clear_history()

        # A separate reference model is only needed when hydra heads / peft are not used
        if isinstance(self.model, DistributedDataParallel):
            unpacked_model = self.model.module
        else:
            unpacked_model = self.model
        if not (hasattr(unpacked_model, "frozen_head") or unpacked_model.peft_type):
            self.ref_model = self.get_arch(self.config)
            self.ref_model.to(self.accelerator.device)
            self.ref_model.eval()

        # Default arguments for the Hugging Face `generate` method; config values override them.
        # https://huggingface.co/docs/transformers/v4.25.1/en/main_classes/text_generation#transformers.GenerationMixin.generate
        base_generate_kwargs = {
            "do_sample": True,
            "use_cache": True,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        self.generate_kwargs = {**base_generate_kwargs, **config.method.gen_kwargs}

        experience_overrides = config.method.gen_experience_kwargs
        if experience_overrides is None:
            self.generate_experience_kwargs = None
        else:
            self.generate_experience_kwargs = {**base_generate_kwargs, **experience_overrides}

        # Running reward statistics and optional reference statistics from the config
        self.running_moments = RunningMoments()
        self.ref_mean = self.config.method.ref_mean
        self.ref_std = self.config.method.ref_std

        # Reward model plus the baselines that `compute_eval_metrics` fills in on first use
        self.reward_model = reward_model
        self.g_init_mean_r = None
        self.init_mean_r = None

        # KL controller selection
        if kl_ctl.lower() == "uncertainty":
            self.kl_ctl = AdaptiveUncertaintyKLController(kl_coef=config.method.init_kl_coef)
        else:
            self.kl_ctl = FixedKLController(config.method.init_kl_coef)


    def compute_eval_metrics(self, samples, prompts, outputs, **metadata):
        """Score evaluation samples and estimate their divergence from the reference model.

        Both the reward model's scores and the gold reward model's scores are shifted
        by the mean observed on the first call (stored in `self.init_mean_r` and
        `self.g_init_mean_r`), so later evaluations are reported relative to the first.

        Args:
            samples: Decoded prompt + output strings; these are tokenized for the KL estimate.
            prompts: Decoded prompt strings, forwarded to the reward model.
            outputs: Decoded output strings, forwarded to the reward model.
            **metadata: Accepted for interface compatibility; not used here.

        Returns:
            Dict with keys "sample_kl", "reward", "reward_std", "all_rewards" and "gold_reward".
        """
        reward_model = self.reward_model
        rewards = reward_model.score(samples, prompts, outputs, agg_fn=None, is_eval=True)
        gold_reward = reward_model.score(
            samples,
            prompts,
            outputs,
            model=reward_model.gold_rm,
            tokenizer=reward_model.g_tokenizer,
            agg_fn=None,
            is_eval=True,
            cache_key="gold",
            batch_size=reward_model.g_batch_size,
        )
        # The first evaluation defines the baseline that all evaluations are shifted by
        if self.g_init_mean_r is None:
            self.g_init_mean_r = gold_reward.mean().item()
        if self.init_mean_r is None:
            self.init_mean_r = rewards.mean().item()
        gold_reward -= self.g_init_mean_r
        rewards -= self.init_mean_r

        # Read the tensors by key: unpacking `.values()` depends on key order and fails
        # when the tokenizer returns additional entries such as `token_type_ids`.
        encoded = self.tokenizer(samples, return_tensors="pt", padding=True).to(self.accelerator.device)
        all_tokens = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

        if self.config.train.minibatch_size is not None:
            batch_size = self.config.train.minibatch_size
        else:
            batch_size = self.config.train.batch_size

        # Loop-invariant: unwrap DDP once and decide where reference logits come from
        unpacked_model = self.model if not isinstance(self.model, DistributedDataParallel) else self.model.module
        use_hydra = hasattr(unpacked_model, "frozen_head") or unpacked_model.peft_type

        kls = []
        for i in range(0, len(all_tokens), batch_size):
            tokens = all_tokens[i : i + batch_size]
            mask = attention_mask[i : i + batch_size]
            ids = position_ids[i : i + batch_size]
            with torch.no_grad():
                logits, *_ = self.model(tokens, attention_mask=mask, position_ids=ids)

                if use_hydra:
                    ref_logits = unpacked_model.forward_hydra(
                        tokens,
                        attention_mask=mask,
                        position_ids=ids,
                        return_dict=True,
                    ).logits
                else:
                    ref_logits = self.ref_model(
                        tokens,
                        attention_mask=mask,
                        position_ids=ids,
                        return_dict=True,
                    ).logits
                    ref_logits = ref_logits.to(self.accelerator.device)

                logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], tokens[:, 1:])
                # Masked log-ratio, clamped to [-10, 10] before exponentiation
                log_ratio = torch.clamp((logprobs - ref_logprobs) * mask[:, :-1], -10, 10)
                # Per-sample sum of exp(r) - 1 - r over the sequence
                kls.append((log_ratio.exp() - 1 - log_ratio).sum(1))

        kl = torch.cat(kls)
        # Aggregate over dim 0 of the un-aggregated reward tensor
        if self.reward_model.agg_fn == "min":
            reward = rewards.min(dim=0)[0]
        else:
            reward = rewards.mean(dim=0)
        return {"sample_kl": kl, "reward": reward, "reward_std": rewards.std(dim=0), "all_rewards": rewards.T, "gold_reward": gold_reward.flatten()}

    def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noqa:
        """
        Collects rollouts by repeatedly drawing a prompt batch from `self.prompt_iterator`,
        sampling continuations from the model, scoring them with `self.reward_fn` (on the
        main process only) and computing per-token KL penalties against the reference
        model. The resulting `PPORLElementExtended`s are pushed to the trainer's `store`.

        Args:
            num_rollouts: Minimum number of rollouts to collect before stopping
            iter_count: Step under which the aggregated rollout statistics are logged
        """
        logger.info("Collecting rollouts")
        tbar = logging.tqdm(
            total=num_rollouts,
            # Environment values are strings, so the fallback must be the string "0" too;
            # an integer fallback never equals "0" and would hide the bar whenever RANK is unset.
            disable=os.environ.get("RANK", "0") != "0",
            desc=f"[rollout 0 / {num_rollouts}]",
            # Lower progress bar by 1 if we're in WARNING mode or above to avoid hiding high priority progress
            # bars (e.g. loss progress in trainers)
            position=logging.get_verbosity() >= logging.WARNING,
            # Leave progress bar if we're in INFO mode or lower to avoid spamming in suppressed verbosity levels
            leave=logging.get_verbosity() < logging.WARNING,
        )

        clock = Clock()
        ppo_rl_elements = []
        accumulated_stats = []
        all_reward_vars = []
        # Kept in float32 on every rank: this tensor is broadcast from rank 0 below, and a
        # broadcast needs the same dtype on the sending and the receiving side.
        reward_vars_mean = torch.tensor(0.0, dtype=torch.float, device=self.accelerator.device)

        while len(ppo_rl_elements) < num_rollouts:
            stats = {}
            # Get next batch in prompt dataset
            batch: PromptBatch = next(self.prompt_iterator)

            rollout_generate_time = time()

            # Generate samples from the language model (similar to using HuggingFace `generate` method)
            samples = self.generate(batch["input_ids"], batch["attention_mask"])
            stats["time/rollout_generate"] = time() - rollout_generate_time

            prompt_tensors = batch.input_ids
            device = samples.device

            prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device)
            padded_samples = self.accelerator.pad_across_processes(
                samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
            )
            padded_prompts = self.accelerator.pad_across_processes(
                prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
            )
            gathered_samples = self.accelerator.gather(padded_samples)
            gathered_prompts = self.accelerator.gather(padded_prompts)
            gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)
            metadata = gather_dict({k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"})

            if self.accelerator.is_main_process:
                all_str_samples, all_str_prompts, all_str_outputs = self.decode(
                    gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True
                )

                rollout_score_time = time()
                # reward_fn should return list of rewards at each token per sample
                # NOTE: all_scores[0][i] is the reward due to token (action) i in prompt + response (b/c of how kl is computed)
                all_scores = self.reward_fn(
                    samples=all_str_samples,
                    prompts=all_str_prompts,
                    outputs=all_str_outputs,
                    tokenizer=self.tokenizer,
                    agg_fn=None,
                    **metadata,
                ).to(device)
                # Variance over dim 0 of the un-aggregated scores, taken before that
                # dimension is reduced away by the min/mean aggregation below
                reward_vars = all_scores.var(axis=0)
                all_reward_vars.append(reward_vars)
                reward_vars_mean = torch.concat(all_reward_vars).mean().to(device=device, dtype=torch.float)
                if self.reward_model.agg_fn == "min":
                    all_scores = all_scores.min(dim=0)[0]
                else:
                    all_scores = all_scores.mean(dim=0)
                all_scores = [
                    torch.tensor(score, dtype=torch.float, device=device).view(
                        -1,
                    )
                    for score in all_scores
                ]
                # Pad with -inf on the ends so padded positions can be recognised later
                all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-np.inf)
                max_len = torch.tensor(all_scores.shape[1], dtype=torch.long, device=device)

                stats["time/rollout_score"] = time() - rollout_score_time

                # One chunk per process, ready to be scattered
                all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind())
                reward_vars = list(reward_vars.reshape(self.accelerator.num_processes, -1).unbind())
            else:
                all_scores = None
                max_len = torch.tensor(0, dtype=torch.long, device=device)
                reward_vars = None

            if torch.distributed.is_initialized():
                torch.distributed.barrier()
                torch.distributed.broadcast(max_len, 0)
                scores = torch.empty((len(samples), max_len), device=device)
                # NOTE(review): the receive buffer is hard-coded to bfloat16, so the variances
                # computed on the main process are presumably bfloat16 as well — confirm.
                reward_vars_scattered = torch.empty(len(samples), device=device, dtype=torch.bfloat16)
                torch.distributed.scatter(scores, all_scores)
                torch.distributed.scatter(reward_vars_scattered, reward_vars)
            else:
                scores = all_scores[0].clone().detach()
                reward_vars_scattered = reward_vars[0].clone().detach()
            # Computed before any clipping so that padded positions stay identifiable
            scores_mask = scores != -np.inf

            str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)

            # Pad the sample outputs
            outputs = self.tokenizer(str_outputs).input_ids
            if self.config.model.model_arch_type == "seq2seq":
                # add <pad> to the start of the output
                for i in range(len(outputs)):
                    outputs[i] = [self.tokenizer.pad_token_id] + outputs[i]

            outputs = list(map(torch.LongTensor, outputs))
            maxsize = max(map(len, outputs))
            outputs = [
                F.pad(
                    output,
                    (0, maxsize - len(output)),
                    value=self.tokenizer.pad_token_id,
                )
                for output in outputs
            ]
            sample_outputs = torch.vstack(outputs).to(device)

            if self.config.method.cliprange_reward:
                scores = torch.clip(scores, -self.config.method.cliprange_reward, self.config.method.cliprange_reward)

            # Zero padded positions explicitly: `scores * scores_mask` would produce NaN
            # (-inf * 0) at padded entries whenever reward clipping is disabled.
            summed_scores = scores.masked_fill(~scores_mask, 0.0).sum(dim=1)

            # store statistics of the initial rollout as reference
            if self.ref_mean is None:
                self.ref_mean, self.ref_std = summed_scores.mean(), summed_scores.std()
            all_scores_mean, all_scores_std = self.running_moments.update(summed_scores)
            stats["rollout_scores/mean"] = all_scores_mean.item()
            stats["rollout_scores/std"] = all_scores_std.item()
            stats["rollout_scores/running_mean"] = self.running_moments.mean.item()
            stats["rollout_scores/running_std"] = self.running_moments.std.item()

            if self.config.method.scale_reward == "running":
                scores /= self.running_moments.std
            elif self.config.method.scale_reward == "ref":
                scores /= self.ref_std

            unpacked_model = self.model if not isinstance(self.model, DistributedDataParallel) else self.model.module
            # Precompute logprobs, values
            if self.config.model.model_arch_type == "seq2seq":
                attention_mask = batch.attention_mask.to(device)
                prompt_tensors = batch.input_ids.to(device)
                decoder_attention_mask = sample_outputs.not_equal(self.tokenizer.pad_token_id)
                decoder_attention_mask[:, 0] = 1
                with torch.no_grad():
                    outputs = self.model(
                        input_ids=prompt_tensors,
                        attention_mask=attention_mask,
                        decoder_input_ids=sample_outputs,
                        decoder_attention_mask=decoder_attention_mask,
                    )
                    logits = outputs.logits
                    values = outputs.value
                    if hasattr(unpacked_model, "frozen_head") or unpacked_model.peft_type:
                        ref_logits = unpacked_model.forward_hydra(
                            input_ids=prompt_tensors,
                            attention_mask=attention_mask,
                            decoder_input_ids=sample_outputs,
                            decoder_attention_mask=decoder_attention_mask,
                            return_dict=True,
                        ).logits
                    else:
                        ref_logits = self.ref_model(
                            input_ids=prompt_tensors,
                            attention_mask=attention_mask,
                            decoder_input_ids=sample_outputs,
                            decoder_attention_mask=decoder_attention_mask,
                            return_dict=True,
                        ).logits
            else:
                all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1)
                attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device)
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                with torch.no_grad():
                    logits, *_, values = self.model(
                        all_tokens, attention_mask=attention_mask, position_ids=position_ids
                    )
                    # TODO(dahoas): When hydra model works need to also support generation on hydra head
                    if hasattr(unpacked_model, "frozen_head") or unpacked_model.peft_type:
                        ref_logits = unpacked_model.forward_hydra(
                            all_tokens,
                            attention_mask=attention_mask,
                            position_ids=position_ids,
                            return_dict=True,
                        ).logits
                    else:
                        ref_logits = self.ref_model(
                            all_tokens,
                            attention_mask=attention_mask,
                            position_ids=position_ids,
                            return_dict=True,
                        ).logits
                        ref_logits = ref_logits.to(device)

            if self.config.model.model_arch_type == "seq2seq":
                logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])
                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])
            else:
                # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled
                logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])

            n_samples: int = samples.shape[0]

            # Estimate the KL divergence between the model and reference model
            if self.config.model.model_arch_type == "seq2seq":
                attention_mask = sample_outputs != self.tokenizer.pad_token_id
                start = 0
            else:
                start = prompt_tensors.shape[1] - 1

            log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]
            kl = log_ratio.exp() - 1 - log_ratio
            mean_kl_per_token = kl.mean()
            mean_kl = kl.sum(1).mean()

            logprobs = logprobs.cpu()
            ref_logprobs = ref_logprobs.cpu()
            prompt_tensors = prompt_tensors.cpu()
            sample_outputs = sample_outputs.cpu()
            values = values.cpu()[:, :-1]

            # Get the logprobs and values, for tokens that are not padding,
            # from the end of the prompt up to the <eos> token, while also including the latter
            # (these are taken from the student model and not the reference model)
            ends = start + attention_mask[:, start:].sum(1) + 1
            all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
            all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]

            # The uncertainty controller yields one KL coefficient per sample from the
            # scattered reward variances; otherwise a single coefficient is used.
            if isinstance(self.kl_ctl, AdaptiveUncertaintyKLController):
                kl_penalty = self.kl_ctl.value_uncertainty(reward_vars_scattered).view(-1, 1).cpu() * -log_ratio.cpu()
            else:
                kl_penalty = self.kl_ctl.value * -log_ratio.cpu()
            kl_penalty = [xs[start : ends[ix]] for ix, xs in enumerate(kl_penalty)]

            rollout_count = 0

            for sample_idx in range(n_samples):
                rewards = kl_penalty[sample_idx]
                # Then add in rewards
                if scores.shape[1] == 1:
                    # NOTE: Final reward given at EOS token following HHH practice
                    rewards[-1] += scores[sample_idx][0].cpu()
                else:
                    score = scores[sample_idx]
                    score_right_padding = torch.sum(scores_mask[sample_idx])
                    score = score[:score_right_padding].cpu()
                    p_score = torch.zeros_like(rewards)
                    p_score[: score.shape[0]] += score
                    rewards += p_score
                if isinstance(self.kl_ctl, AdaptiveUncertaintyKLController) and self.kl_ctl.prior_variance is not None:
                    lr_coef = 2 * self.kl_ctl.prior_variance / (self.kl_ctl.prior_variance + reward_vars_scattered[sample_idx])
                    # Keep stored rollout data on the CPU, like the other element fields
                    lr_coef = lr_coef.cpu()
                else:
                    lr_coef = torch.tensor(1)
                ppo_rl_elements.append(
                    PPORLElementExtended(
                        query_tensor=prompt_tensors[sample_idx],
                        response_tensor=sample_outputs[sample_idx],
                        logprobs=all_logprobs[sample_idx],
                        values=all_values[sample_idx],
                        rewards=rewards,
                        lr_coef=lr_coef
                    )
                )

                rollout_count += 1

            if torch.distributed.is_initialized():
                torch.distributed.barrier()
                torch.distributed.all_reduce(mean_kl, torch.distributed.ReduceOp.AVG)
                # Only rank 0 computed the reward variance; share it with the other ranks
                torch.distributed.broadcast(reward_vars_mean, 0)

            stats["time/rollout_time"] = clock.tick()
            stats["policy/sqrt_kl"] = torch.sqrt(mean_kl).item()
            stats["policy/kl_per_token"] = torch.sqrt(mean_kl_per_token).item()
            accumulated_stats.append(stats)

            tbar.set_description(f"[rollout {len(ppo_rl_elements)} / {num_rollouts}]")
            tbar.update(min(rollout_count, num_rollouts))
        tbar.close()

        # Update prior variance after the first run
        if isinstance(self.kl_ctl, AdaptiveUncertaintyKLController):
            if self.kl_ctl.prior_variance is None:
                self.kl_ctl.prior_variance = reward_vars_mean.item()

        # Average every statistic over the batches processed in this call
        stats = {k: sum([xs[k] for xs in accumulated_stats]) / len(accumulated_stats) for k in stats}
        stats["kl_ctl_value"] = self.kl_ctl.value
        stats["rollout_scores/reward_var"] = reward_vars_mean.item()
        self.mean_kl = stats["policy/sqrt_kl"] ** 2
        self.accelerator.log(stats, step=iter_count)

        # Push samples and rewards to trainer's rollout storage
        self.push_to_store(ppo_rl_elements)

    def evaluate(self):
        """Generate from the evaluation prompts, score them with `compute_eval_metrics` and return logging stats."""
        logger.info("Evaluating model")
        n_batches = len(self.eval_dataloader)
        progress = logging.tqdm(
            total=n_batches,
            desc=f"[eval batch 0/{n_batches}]",
            disable=not self.accelerator.is_main_process,
            position=0,
            leave=True,
        )

        stats = {}
        table = []
        all_samples, all_prompts, all_prompt_sizes, all_metadata = [], [], [], []
        generate_time = time()
        for batch_index, batch in enumerate(self.eval_dataloader):
            # Everything in the batch except the model inputs is treated as metadata
            extras = {key: value for key, value in batch.items() if key not in ("input_ids", "attention_mask")}
            generated = self.generate_eval(batch["input_ids"], batch["attention_mask"])

            # Every prompt in a batch has the same (padded) length
            sizes = torch.tensor(batch.input_ids.shape[1]).repeat(len(batch.input_ids))
            padded = self.accelerator.pad_across_processes(
                [batch.input_ids, generated, sizes.to(generated.device)],
                dim=1,
                pad_index=self.tokenizer.pad_token_id,
            )
            gathered_prompts, gathered_samples, gathered_sizes = self.accelerator.gather_for_metrics(padded)
            all_samples.extend(gathered_samples.tolist())
            all_prompts.extend(gathered_prompts.tolist())
            all_prompt_sizes.extend(gathered_sizes.tolist())

            all_metadata.append(gather_dict(extras, self.accelerator.gradient_state))

            progress.set_description(f"[eval batch {batch_index + 1}/{n_batches}]")
            progress.update()
        progress.close()

        stats["time/generate"] = time() - generate_time

        if self.accelerator.is_main_process:
            str_samples, str_prompts, str_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes)

            columns = ["prompt", "output"]
            columns_data = [str_prompts, str_outputs]

            # Fold the metadata of all later batches into the first batch's dict
            merged_metadata, *remaining = all_metadata
            for key in merged_metadata:
                for other in remaining:
                    merged_metadata[key].extend(other[key])

            logger.info("Computing metrics")
            metric_time = time()
            metrics = self.compute_eval_metrics(
                samples=str_samples, prompts=str_prompts, outputs=str_outputs, **merged_metadata
            )
            stats["time/metric"] = time() - metric_time

            # Only metrics whose last-dimension mean is a scalar get a summary entry
            mean_metrics = {}
            for name, metric_values in metrics.items():
                if len(metric_values.mean(-1).shape) == 0:
                    mean_metrics[f"metrics/{name}"] = torch.as_tensor(metric_values).mean(-1).item()
            stats.update(mean_metrics)

            for name, metric_values in metrics.items():
                # Scalar metrics are aggregated values and get no table column
                if isinstance(metric_values, float):
                    continue
                columns.append(name)
                if isinstance(metric_values, list):
                    columns_data.append(metric_values)
                else:
                    columns_data.append(metric_values.tolist())

            table.append(list(zip(*columns_data)))

        # Log and display evaluation metrics
        logger.info("Summarizing evaluation")
        if self.accelerator.is_main_process:
            rows = [row for group in zip(*table) for row in group]

            # The table title carries the reward / metric summaries
            title_parts = [f"Evaluation #{self.nth_evaluation}"]
            for key, value in stats.items():
                if key.startswith("reward") or key.startswith("metrics"):
                    title_parts.append(f" {key}: {significant(value)}")
            table_title = "".join(title_parts)

            rich_table = Table(*columns, title=table_title, show_lines=True)
            # Only the first three rows are printed to the console
            for row in rows[:3]:
                rich_table.add_row(*[str(significant(cell)) for cell in row])
            Console().print(rich_table)

            if self.config.train.tracker == "wandb":
                import wandb

                stats["samples"] = wandb.Table(columns, rows)

        self.nth_evaluation += 1
        return stats

    def ppo_adaptive_loss(
        self,
        logprobs: TensorType["batch_size", "response_size"],
        values: TensorType["batch_size", "response_size"],
        old_logprobs: TensorType["batch_size", "response_size"],
        old_values: TensorType["batch_size", "response_size"],
        advantages: TensorType["batch_size", "response_size"],
        returns: TensorType["batch_size", "response_size"],
        mask: TensorType["batch_size", "response_size"],
        lr_coef: TensorType["response_size"],
    ):
        """Clipped PPO objective whose policy and value terms are both scaled by `lr_coef`.

        Args:
            logprobs: Log-probabilities under the current policy.
            values: Current value predictions.
            old_logprobs: Log-probabilities recorded at rollout time.
            old_values: Value predictions recorded at rollout time.
            advantages: Advantage estimates.
            returns: Return targets for the value head.
            mask: Marks the positions that contribute to the loss.
            lr_coef: Multiplier applied to the per-position policy and value losses.

        Returns:
            Tuple of the total loss (policy loss + `vf_coef` * value loss) and a
            flattened dict of statistics.
        """
        method = self.config.method
        n = mask.sum()

        # Value loss: the prediction may move at most `cliprange_value` away from the old value
        value_floor = old_values - method.cliprange_value
        value_ceiling = old_values + method.cliprange_value
        values_clipped = torch.clamp(values, value_floor, value_ceiling)
        vf_unclipped = (values - returns) ** 2 * lr_coef
        vf_clipped = (values_clipped - returns) ** 2 * lr_coef
        vf_loss = 0.5 * torch.sum(torch.max(vf_unclipped, vf_clipped) * mask) / n
        vf_clipfrac = torch.sum((vf_clipped > vf_unclipped).float() * mask) / n

        log_ratio = (logprobs - old_logprobs) * mask
        ratio = torch.exp(log_ratio)
        # Unbiased KL-div estimates (`k3`). Ref: http://joschu.net/blog/kl-approx.html
        with torch.no_grad():
            approx_kl = torch.mean((ratio - 1) - log_ratio)

        # Policy loss: pessimistic maximum of the unclipped and the clipped surrogate
        ratio_clipped = torch.clamp(ratio, 1.0 - method.cliprange, 1.0 + method.cliprange)
        pg_unclipped = -advantages * ratio * lr_coef
        pg_clipped = -advantages * ratio_clipped * lr_coef
        pg_loss = torch.sum(torch.max(pg_unclipped, pg_clipped) * mask) / n
        pg_clipfrac = torch.sum((pg_clipped > pg_unclipped).float() * mask) / n

        loss = pg_loss + method.vf_coef * vf_loss

        loss_stats = {
            "total_loss": loss.item(),
            "policy_loss": pg_loss.item(),
            "value_loss": vf_loss.item(),
        }
        value_stats = dict(
            get_tensor_stats(values, mask, n),
            values_error=torch.sum(((values - returns) * mask) ** 2) / n,
            values_mape_error=torch.sum((abs(values - returns) * mask) / abs(returns * mask + 1e-2)) / n,
            clipfrac=vf_clipfrac,
        )
        stats = {
            "losses": loss_stats,
            "values": value_stats,
            "old_values": get_tensor_stats(old_values, mask, n),
            "returns": get_tensor_stats(returns, mask, n),
            "policy": {"approx_kl": approx_kl.item(), "clipfrac": pg_clipfrac.item()},
            "ratio": (ratio * mask).sum() / n,
            "padding_percentage": 1 - n / mask.numel(),
        }

        return loss, flatten_dict(stats)

    def loss(self, batch: PPORLBatchExtended) -> Tuple[float, Dict[str, Any]]:
        """Run the model on a rollout batch and evaluate `ppo_adaptive_loss` on the response span.

        Args:
            batch: `PPORLBatchExtended` Previous batch of episodes

        Returns:
            loss: `Float` Loss value
            stats: `Dict[str, Any]` PPO Statistics values
        """
        device = self.accelerator.device

        # Bring the batch onto the accelerator device
        query_tensors = batch.query_tensors.to(device)
        response_tensors = batch.response_tensors.to(device)
        old_logprobs = batch.logprobs.to(device)
        old_values = batch.values.to(device)
        old_rewards = batch.rewards.to(device)
        # Add a trailing axis so the per-element coefficient broadcasts along the response axis
        lr_coef = batch.lr_coef[:, None].to(device)
        response_length = old_rewards.shape[1]

        advantages, returns = self.config.method.get_advantages_and_returns(old_values, old_rewards, response_length)

        # Forward pass over the concatenated query + response tokens
        tokens = torch.cat((query_tensors, response_tensors), dim=1)
        attention_mask = tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device)
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        model_out = self.model(tokens, attention_mask, return_dict=True, position_ids=position_ids)
        token_logprobs = logprobs_of_labels(model_out.logits[:, :-1, :], tokens[:, 1:])
        token_values = model_out.value[:, :-1]

        # Position i predicts token i + 1, hence the response span starts one step
        # before the end of the query and the mask is shifted by one.
        start = query_tensors.shape[1] - 1
        end = start + response_length
        response_logprobs = token_logprobs[:, start:end]
        response_values = token_values[:, start:end]
        response_mask = attention_mask[:, start + 1 : end + 1]

        total_loss, stats = self.ppo_adaptive_loss(
            logprobs=response_logprobs,
            values=response_values,
            old_logprobs=old_logprobs,
            old_values=old_values,
            advantages=advantages,
            returns=returns,
            mask=response_mask,
            lr_coef=lr_coef,
        )

        return total_loss, stats
