import os
from typing import Dict, Literal
from collections import defaultdict
from contextlib import nullcontext

from trl.commands.cli_utils import init_zero_verbose
from trl.env_utils import strtobool

# Opt-in switch for rich console output, read from the TRL_USE_RICH environment
# variable (defaults to "0", i.e. disabled).
TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0"))

if TRL_USE_RICH:
    # NOTE(review): presumably quiets/reconfigures default logging before rich
    # takes over -- confirm against trl.commands.cli_utils.
    init_zero_verbose()
    # NOTE(review): FORMAT is not referenced anywhere else in the visible part
    # of this file -- confirm whether it is still needed.
    FORMAT = "%(message)s"

    # rich is imported only when rich output has been requested.
    from rich.console import Console
    from rich.logging import RichHandler


import torch
from transformers import TrainingArguments, Trainer

from data import load_data
from model import load
from loss import (
    get_sft_loss,
    get_digit_loss,
    get_digit_base_loss,
)
from args import get_args


# Weights & Biases project name; set at import time, before any trainer is built.
os.environ["WANDB_PROJECT"] = "digit_grounding"


class BaseTrainer(Trainer):
    """`Trainer` that buffers custom scalar metrics between logging calls.

    Subclasses call `store_metrics` (typically from `compute_loss`) to record
    per-step values. On the next `log` call the buffered values of the
    matching split are averaged, merged into the log dict and cleared.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # split ("train" / "eval") -> metric name -> values recorded since the last log
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

    def store_metrics(
        self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
    ) -> None:
        """Append every value in `metrics` to the buffer of the given split.

        Args:
            metrics: Mapping of metric name to the value observed for this step.
            train_eval: Which buffer ("train" or "eval") receives the values.
        """
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`Dict[str, float]`):
                The values to log. Mutated in place: the mean of each buffered
                metric is added under the metric's name.
            *args, **kwargs:
                Forwarded untouched to `Trainer.log`. Newer `transformers`
                releases call `log(logs, start_time)`; accepting and forwarding
                the extra arguments keeps this override working with both the
                old one-argument and the new two-argument call.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        # Reset the buffer so the next logging window starts empty. The lookup
        # above guarantees the key exists (defaultdict), so `del` cannot fail.
        del self._stored_metrics[train_eval]
        return super().log(logs, *args, **kwargs)


class DIGITTrainer(BaseTrainer):
    """Trainer whose objective is the SFT loss plus the loss from `get_digit_loss`."""

    # Hyper-parameters forwarded to `get_digit_loss`. Kept as class attributes
    # so a subclass can override them without copying `compute_loss`.
    digit_beta = 0.1
    digit_target_temperature = 2.0

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Run a forward pass and return `sft_loss + digit_loss`.

        Args:
            model: The (possibly wrapped) model to call.
            inputs: Batch dict; must contain "attention_mask" and "labels".
                "labels" is popped from this dict.
            return_outputs: When true, return `(loss, outputs)` instead of `loss`.
            num_items_in_batch: Accepted because newer `transformers` releases
                pass it to `compute_loss`; it is not used here.

        Returns:
            The combined loss, or `(loss, outputs)` if `return_outputs` is set.
        """
        batch = inputs
        attention_mask = batch["attention_mask"]
        labels = batch.pop("labels")

        is_llava = "llava" in self.model.config.model_type
        kwargs = {}
        if is_llava:
            kwargs["labels"] = labels  # expand labels to match the logit length

        outputs = model(**batch, **kwargs)
        if is_llava:
            # Use the labels / mask returned by the model instead of the batch ones.
            labels = outputs.labels  # <-- important
            attention_mask = outputs.attention_mask

        sft_loss = get_sft_loss(
            outputs.logits,
            labels,
            attention_mask,
        )
        digit_loss = get_digit_loss(
            self.tokenizer,
            outputs.logits,
            labels,
            attention_mask,
            beta=self.digit_beta,
            target_temperature=self.digit_target_temperature,
        )

        loss = sft_loss + digit_loss

        # Route metrics to the bucket matching the current mode so values
        # produced while the model is in eval mode are not averaged into the
        # next training log. Eval keys carry an "eval_" prefix.
        train_eval = "train" if self.model.training else "eval"
        prefix = "" if train_eval == "train" else "eval_"
        stats = {
            f"{prefix}loss/total": loss.detach(),
            f"{prefix}loss/sft": sft_loss.detach(),
            f"{prefix}loss/digit": digit_loss.detach(),
        }
        self.store_metrics(stats, train_eval=train_eval)
        return (loss, outputs) if return_outputs else loss


class DIGITBaseTrainer(BaseTrainer):
    """Trainer whose objective is the SFT loss plus the loss from `get_digit_base_loss`."""

    # Hyper-parameters forwarded to `get_digit_base_loss`. Kept as class
    # attributes so a subclass can override them without copying `compute_loss`.
    digit_beta = 0.1
    digit_target_temperature = 2.0

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Run a forward pass and return `sft_loss + digit_loss`.

        Args:
            model: The (possibly wrapped) model to call.
            inputs: Batch dict; must contain "attention_mask" and "labels".
                "labels" is popped from this dict.
            return_outputs: When true, return `(loss, outputs)` instead of `loss`.
            num_items_in_batch: Accepted because newer `transformers` releases
                pass it to `compute_loss`; it is not used here.

        Returns:
            The combined loss, or `(loss, outputs)` if `return_outputs` is set.
        """
        batch = inputs
        attention_mask = batch["attention_mask"]
        labels = batch.pop("labels")

        is_llava = "llava" in self.model.config.model_type
        kwargs = {}
        if is_llava:
            kwargs["labels"] = labels  # expand labels to match the logit length
        outputs = model(**batch, **kwargs)
        if is_llava:
            # Use the labels / mask returned by the model instead of the batch ones.
            labels = outputs.labels  # <-- important
            attention_mask = outputs.attention_mask

        sft_loss = get_sft_loss(
            outputs.logits,
            labels,
            attention_mask,
        )
        digit_loss = get_digit_base_loss(
            self.tokenizer,
            outputs.logits,
            labels,
            attention_mask,
            beta=self.digit_beta,
            target_temperature=self.digit_target_temperature,
        )

        loss = sft_loss + digit_loss

        # Route metrics to the bucket matching the current mode so values
        # produced while the model is in eval mode are not averaged into the
        # next training log. Eval keys carry an "eval_" prefix.
        train_eval = "train" if self.model.training else "eval"
        prefix = "" if train_eval == "train" else "eval_"
        stats = {
            f"{prefix}loss/total": loss.detach(),
            f"{prefix}loss/sft": sft_loss.detach(),
            f"{prefix}loss/digit": digit_loss.detach(),
        }
        self.store_metrics(stats, train_eval=train_eval)
        return (loss, outputs) if return_outputs else loss


class SFTTrainer(Trainer):
    """Trainer whose objective is the loss returned by `get_sft_loss` alone."""

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Run a forward pass and return the SFT loss.

        Args:
            model: The (possibly wrapped) model to call.
            inputs: Batch dict; must contain "attention_mask" and "labels".
                "labels" is popped from this dict.
            return_outputs: When true, return `(loss, outputs)` instead of `loss`.
            num_items_in_batch: Accepted because newer `transformers` releases
                pass it to `compute_loss`; it is not used here.

        Returns:
            The loss, or `(loss, outputs)` if `return_outputs` is set.
        """
        batch = inputs
        attention_mask = batch["attention_mask"]
        labels = batch.pop("labels")

        is_llava = "llava" in self.model.config.model_type
        kwargs = {}
        if is_llava:
            kwargs["labels"] = labels  # expand labels to match the logit length
        outputs = model(**batch, **kwargs)
        if is_llava:
            # Use the labels / mask returned by the model instead of the batch ones.
            labels = outputs.labels  # <-- important
            attention_mask = outputs.attention_mask

        loss = get_sft_loss(
            outputs.logits,
            labels,
            attention_mask,
        )
        return (loss, outputs) if return_outputs else loss


def main():
    """Parse arguments, build the selected trainer, train, and save the model.

    The trainer class is chosen by `args.loss` ("sft", "digit" or
    "digit_base"); any other value raises `KeyError`.
    """
    args = get_args()
    print(args)

    # Resolve the trainer class first so an invalid `--loss` value fails
    # before the expensive model and data loading.
    trainer_cls = {
        "sft": SFTTrainer,
        "digit": DIGITTrainer,
        "digit_base": DIGITBaseTrainer,
    }[args.loss]

    model, processor = load(args.model)

    data, collator = load_data(args, processor)

    train_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        max_grad_norm=1,
        gradient_accumulation_steps=args.grad_acc,
        num_train_epochs=args.num_epochs,
        warmup_steps=args.num_warmup_steps,
        logging_steps=10,
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=2,
        save_only_model=True,
        seed=42,
        fp16=True,
        tf32=True,
        dataloader_num_workers=args.num_workers,
        report_to="wandb",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        torch_compile=True,
        deepspeed="./zero3.json",
        local_rank=args.local_rank,
        remove_unused_columns=False,
        run_name=args.run_name,
    )

    if TRL_USE_RICH:
        # tqdm is turned off when rich output is enabled.
        train_args.disable_tqdm = True
        console = Console()
        # Name the trainer actually being built rather than always "SFTTrainer".
        init_context = console.status(
            f"[bold green]Initializing the {trainer_cls.__name__}..."
        )
        save_context = console.status(
            f"[bold green]Training completed! Saving the model to {train_args.output_dir}"
        )
    else:
        init_context = nullcontext()
        save_context = nullcontext()

    with init_context:
        trainer = trainer_cls(
            model=model,
            args=train_args,
            train_dataset=data,
            data_collator=collator,
            tokenizer=processor.processor.tokenizer,
        )

    # Train outside the "Initializing" status so the spinner does not stay
    # live for the whole training run.
    trainer.train()

    with save_context:
        trainer.save_model(train_args.output_dir)


if __name__ == "__main__":
    main()
