import os
from typing import Dict, Literal
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import dataclass

from trl.commands.cli_utils import init_zero_verbose
from trl.env_utils import strtobool

# Opt-in switch for rich console output, read once at import time from the
# TRL_USE_RICH environment variable (treated as "0" when unset).
TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0"))

if TRL_USE_RICH:
    # NOTE(review): presumably silences default library logging so rich owns
    # the terminal -- confirm against trl.commands.cli_utils.
    init_zero_verbose()
    # Message-only logging format; not referenced elsewhere in the visible
    # part of this file.
    FORMAT = "%(message)s"

    # rich is only imported when the feature is enabled, so `Console` and
    # `RichHandler` exist only in that case.
    from rich.console import Console
    from rich.logging import RichHandler


import torch
from transformers import TrainingArguments, Trainer

from data import load_data
from model import load
from loss import (
    get_sft_loss,
    get_digit_loss,
    get_digit_loss_with_cont,
    get_digit_base_loss,
)
from args import get_args
from ablations import get_ablation_args


@dataclass
class CustomArguments(TrainingArguments):
    """`TrainingArguments` extended with the loss-related switches used by this script.

    The first three field names match the option names that `DIGITTrainer`
    uses for its loss configuration.
    """

    # Forwarded by DIGITTrainer to the digit loss as `use_place_weighting`.
    use_place_weighting: bool = True
    # Selects the contrastive digit loss variant in DIGITTrainer.
    use_cont_loss: bool = False
    # When > 0, DIGITTrainer passes it to the SFT loss and skips the digit loss.
    label_smoothing: float = 0.0
    # Not read by any trainer in this file; main() copies an ablation override
    # of the same name onto the CLI args before loading data.
    split_digit: bool = True


class BaseTrainer(Trainer):
    """`Trainer` that buffers extra metrics and averages them into each log call.

    Subclasses call `store_metrics` (typically from `compute_loss`); the next
    `log` call for the matching split adds the mean of every buffered metric
    to the logged dict and clears the buffer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # split ("train" / "eval") -> metric name -> values seen since last log
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

    def store_metrics(
        self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
    ) -> None:
        """
        Buffer `metrics` under the given split until the next `log` call.

        Args:
            metrics (`Dict[str, float]`):
                Metric name to value; each value is appended to that metric's buffer.
            train_eval (`"train"` or `"eval"`):
                Which split the values belong to.
        """
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`Dict[str, float]`):
                The values to log. Mutated in place: the mean of every buffered
                metric for the matching split is added before delegating.
            *args, **kwargs:
                Forwarded unchanged to `Trainer.log`. Newer transformers
                releases pass an extra `start_time` argument; accepting it here
                keeps this override working on both old and new versions.
        """
        # A dict carrying "loss" is a training log; anything else is treated as eval.
        train_eval = "train" if "loss" in logs else "eval"
        # Remove the buffer for this split so the next window starts empty.
        buffered = self._stored_metrics.pop(train_eval, {})
        # Add averaged stored metrics to logs
        for key, values in buffered.items():
            logs[key] = torch.tensor(values).mean().item()
        return super().log(logs, *args, **kwargs)


class DIGITTrainer(BaseTrainer):
    """Trainer whose loss is the SFT loss plus a digit-aware auxiliary loss.

    Loss options (`use_place_weighting`, `use_cont_loss`, `label_smoothing`)
    may be given as keyword arguments to the constructor; otherwise they are
    read from the training arguments (`self.args`), falling back to
    `True` / `False` / `0.0` when the attribute is absent.
    """

    def __init__(self, *args, **kwargs):
        # Pull the custom options out first: the parent constructor does not
        # accept them, so forwarding them would fail.
        option_names = ("use_place_weighting", "use_cont_loss", "label_smoothing")
        explicit = {
            name: kwargs.pop(name) for name in option_names if name in kwargs
        }
        super().__init__(*args, **kwargs)

        def _option(name, default):
            # Explicit keyword wins, then the training arguments, then the default.
            if name in explicit:
                return explicit[name]
            return getattr(self.args, name, default)

        self.use_place_weighting = _option("use_place_weighting", True)
        self.use_cont_loss = _option("use_cont_loss", False)
        self.label_smoothing = _option("label_smoothing", 0.0)

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """
        Compute `sft_loss + digit_loss` for one batch and buffer both parts as metrics.

        Args:
            model: The model to run; called as `model(**inputs)` without labels.
            inputs: Batch mapping. Must contain `"attention_mask"` and
                `"labels"`; `"labels"` is popped (the mapping is mutated).
            return_outputs: When true, return `(loss, outputs)` instead of `loss`.
            num_items_in_batch: Accepted for `Trainer` compatibility; unused.

        Returns:
            The total loss, or `(loss, outputs)` when `return_outputs` is true.
        """
        batch = inputs
        attention_mask = batch["attention_mask"]
        labels = batch.pop("labels")

        outputs = model(**batch)

        batch_size = len(labels)
        contrastive_logits = None
        contrastive_labels = None
        if self.use_cont_loss:
            # The second half of the batch holds the negative samples; only the
            # first half contributes to the SFT loss.
            batch_size = batch_size // 2
            contrastive_logits = outputs.logits[batch_size:]
            contrastive_labels = labels[batch_size:]
        logits = outputs.logits[:batch_size]
        _labels = labels[:batch_size]
        _attention_mask = attention_mask[:batch_size]

        if self.label_smoothing > 0:
            # Label-smoothing run: smoothed SFT loss only, digit loss disabled.
            sft_loss = get_sft_loss(
                logits,
                _labels,
                _attention_mask,
                label_smoothing=self.label_smoothing,
            )
            digit_loss = torch.zeros_like(sft_loss)
        else:
            sft_loss = get_sft_loss(
                logits,
                _labels,
                _attention_mask,
            )
            if self.use_cont_loss:
                digit_loss = get_digit_loss_with_cont(
                    self.tokenizer,
                    logits,
                    _labels,
                    _attention_mask,
                    beta=1.0,
                    target_temperature=2.0,
                    use_place_weighting=self.use_place_weighting,
                    contrastive_logits=contrastive_logits,
                    contrastive_labels=contrastive_labels,
                )
            else:
                digit_loss = get_digit_loss(
                    self.tokenizer,
                    logits,
                    _labels,
                    _attention_mask,
                    beta=1.0,
                    target_temperature=2.0,
                    use_place_weighting=self.use_place_weighting,
                )

        # Computed for every branch, so the label-smoothing path also has a
        # defined total loss.
        loss = sft_loss + digit_loss
        stats = {
            "loss/total": loss.detach(),
            "loss/sft": sft_loss.detach(),
            "loss/digit": digit_loss.detach(),
        }
        self.store_metrics(stats, train_eval="train")
        return (loss, outputs) if return_outputs else loss


class DIGITBaseTrainer(BaseTrainer):
    """Trainer combining the SFT loss with the base digit loss."""

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """
        Return `sft_loss + digit_base_loss` for one batch and buffer each part as a metric.

        `"labels"` is popped from `inputs`. For model types containing "llava"
        the labels are passed to the forward call, and the labels and attention
        mask returned on the model outputs replace the batch ones.
        """
        mask = inputs["attention_mask"]
        targets = inputs.pop("labels")

        llava_model = "llava" in self.model.config.model_type
        # expand labels to match the logit length
        extra_inputs = {"labels": targets} if llava_model else {}
        outputs = model(**inputs, **extra_inputs)
        if llava_model:
            # important: use the labels / mask returned by the model
            targets, mask = outputs.labels, outputs.attention_mask

        sft_term = get_sft_loss(outputs.logits, targets, mask)
        digit_term = get_digit_base_loss(
            self.tokenizer,
            outputs.logits,
            targets,
            mask,
            beta=1.0,
            target_temperature=2.0,
        )
        loss = sft_term + digit_term

        self.store_metrics(
            {
                "loss/total": loss.detach(),
                "loss/sft": sft_term.detach(),
                "loss/digit": digit_term.detach(),
            },
            train_eval="train",
        )
        if return_outputs:
            return loss, outputs
        return loss


class SFTTrainer(Trainer):
    """Trainer that optimizes only the SFT loss from `get_sft_loss`."""

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """
        Return the SFT loss for one batch, or `(loss, outputs)` if `return_outputs`.

        `"labels"` is popped from `inputs`. For model types containing "llava"
        the labels are passed to the forward call, and the labels and attention
        mask returned on the model outputs replace the batch ones.
        """
        mask = inputs["attention_mask"]
        targets = inputs.pop("labels")

        llava_model = "llava" in self.model.config.model_type
        # expand labels to match the logit length
        extra_inputs = {"labels": targets} if llava_model else {}
        outputs = model(**inputs, **extra_inputs)
        if llava_model:
            # important: use the labels / mask returned by the model
            targets, mask = outputs.labels, outputs.attention_mask

        loss = get_sft_loss(outputs.logits, targets, mask)
        if return_outputs:
            return loss, outputs
        return loss


def main():
    """Parse CLI arguments, build the selected trainer, train, and save the model.

    Steps:
      1. Read the CLI args and, if an ablation is named, its overrides.
      2. Load the model/tokenizer and the dataset/collator.
      3. Build `CustomArguments` (ablation overrides are passed as keywords).
      4. Instantiate the trainer chosen by `args.loss`, train, then save to
         the output directory.
    """
    args = get_args()
    print(args)

    ablation_args = {}
    if args.ablation is not None:
        ablation_args = get_ablation_args(args.ablation)
    if "split_digit" in ablation_args:
        args.split_digit = ablation_args["split_digit"]
    # Set on the CLI args before load_data(), which receives `args`.
    args.use_cont_loss = ablation_args.get("use_cont_loss", False)

    model, tokenizer = load(args.model)

    data, collator = load_data(args, tokenizer)

    train_args = CustomArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        max_grad_norm=1,
        gradient_accumulation_steps=args.grad_acc,
        num_train_epochs=args.num_epochs,
        warmup_steps=args.num_warmup_steps,
        logging_steps=1,
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=2,
        save_only_model=True,
        seed=42,
        fp16=True,
        tf32=True,
        dataloader_num_workers=args.num_workers,
        report_to="none",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        torch_compile=True,
        deepspeed="./zero3.json",
        local_rank=args.local_rank,
        remove_unused_columns=False,
        run_name=args.run_name,
        **ablation_args,
    )

    trainer_cls = {
        "sft": SFTTrainer,
        "digit": DIGITTrainer,
        "digit_base": DIGITBaseTrainer,
    }[args.loss]

    if TRL_USE_RICH:
        # Force use our print callback
        train_args.disable_tqdm = True
        console = Console()
        # Name the trainer actually being built rather than always "SFTTrainer".
        init_context = console.status(
            f"[bold green]Initializing the {trainer_cls.__name__}..."
        )
        save_context = console.status(
            f"[bold green]Training completed! Saving the model to {train_args.output_dir}"
        )
    else:
        init_context = nullcontext()
        save_context = nullcontext()

    with init_context:
        trainer = trainer_cls(
            model=model,
            args=train_args,
            train_dataset=data,
            data_collator=collator,
            tokenizer=tokenizer,
        )

    # Train outside the status context so the "Initializing" spinner does not
    # stay on screen for the whole run.
    trainer.train()

    with save_context:
        trainer.save_model(train_args.output_dir)


if __name__ == "__main__":
    main()
