import os
from typing import Optional

import torch
import wandb

from transformers.trainer import Trainer, TRAINING_ARGS_NAME
import torch.distributed as dist
from .modeling import EncoderModel

import logging
logger = logging.getLogger(__name__)


class LeanFinderTrainer(Trainer):
    """Hugging Face ``Trainer`` specialised for the LeanFinder retrieval encoder.

    Differences from the stock ``Trainer``:

    * ``compute_loss`` unpacks each batch as a ``(query, passage)`` pair and
      reads the loss from the model output's ``loss`` attribute.
    * ``_save`` only accepts :class:`EncoderModel` and writes the wrapped
      ``encoder`` submodule (``encoder.`` key prefix stripped) via
      ``save_pretrained``.
    * ``training_step`` divides the loss it returns by the DDP world size.
    * ``on_train_end`` logs ``output_dir`` as a W&B model artifact from the
      main process when a W&B run is active.
    """

    def __init__(self, *args, **kwargs):
        logger.info("LeanFinderTrainer init")
        super().__init__(*args, **kwargs)
        # ``dist.is_initialized`` is only meaningful when torch was built with
        # distributed support, so check availability first.
        self.is_ddp = dist.is_available() and dist.is_initialized()
        # Divisor applied to the loss returned by ``training_step``.
        self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1
        logger.info("LeanFinderTrainer ddp: %s", self.is_ddp)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the encoder weights, tokenizer/processor and training args.

        Args:
            output_dir: Destination directory; defaults to
                ``self.args.output_dir``. Created if it does not exist.
            state_dict: Optional state dict to save instead of
                ``self.model.state_dict()``. Every key must start with
                ``"encoder."``; the prefix is stripped so the saved keys match
                the encoder submodule's own parameter names.

        Raises:
            ValueError: If ``self.model`` is not an :class:`EncoderModel`, or if
                any state-dict key lacks the ``"encoder."`` prefix.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)

        supported_classes = (EncoderModel,)
        if not isinstance(self.model, supported_classes):
            raise ValueError(f"Unsupported model class {self.model}")

        if state_dict is None:
            state_dict = self.model.state_dict()
        prefix = 'encoder.'
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and the slicing below would then silently corrupt
        # any key that does not carry the prefix.
        bad_keys = [k for k in state_dict if not k.startswith(prefix)]
        if bad_keys:
            raise ValueError(
                f"Expected every state_dict key to start with {prefix!r}; "
                f"offending keys: {bad_keys}"
            )
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        self.model.encoder.save_pretrained(
            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
        )

        # Recent transformers versions expose the tokenizer as
        # ``processing_class`` and deprecate ``tokenizer``. Prefer the new
        # attribute and fall back for older versions.
        processor = getattr(self, "processing_class", None)
        if processor is None:
            processor = getattr(self, "tokenizer", None)
        if processor is not None:
            processor.save_pretrained(output_dir)

        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the retrieval loss for one ``(query, passage)`` batch.

        Args:
            model: The (possibly wrapped) model, called as
                ``model(query=..., passage=...)``.
            inputs: A two-item sequence ``(query, passage)``.
            return_outputs: When True, return ``(loss, outputs)``, which is the
                shape ``Trainer.prediction_step`` unpacks, instead of only the loss.
            num_items_in_batch: Accepted for ``Trainer`` signature
                compatibility; unused.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        query, passage = inputs
        outputs = model(query=query, passage=passage)
        return (outputs.loss, outputs) if return_outputs else outputs.loss

    def training_step(self, *args, **kwargs):
        """Run the parent training step and divide its returned loss by world size.

        The parent ``Trainer.training_step`` has already run the backward pass,
        so this division affects only the returned (logged) loss value. It
        presumably undoes a world-size scaling applied inside the model's loss.
        TODO confirm against ``EncoderModel``.

        ``**kwargs`` is forwarded so that newer Trainer versions can pass
        extra arguments (e.g. ``num_items_in_batch``) by keyword.
        """
        loss = super().training_step(*args, **kwargs)
        return loss / self._dist_loss_scale_factor

    def on_train_end(self):
        """Log ``self.args.output_dir`` as a W&B model artifact.

        Only the main process uploads, and only when a W&B run is active.

        NOTE: ``transformers.Trainer`` does not define a method with this name.
        The ``on_train_end`` event belongs to ``TrainerCallback``, so the
        Trainer presumably never calls this itself. The caller must invoke it,
        ideally after the final ``save_model()`` so the artifact contains the
        finished model.
        """
        # The Trainer base class has no ``on_train_end``, so an unconditional
        # ``super().on_train_end()`` raises AttributeError. Only chain up when
        # some base class actually provides the hook.
        parent_hook = getattr(super(), "on_train_end", None)
        if parent_hook is not None:
            parent_hook()

        if self.is_world_process_zero() and wandb.run is not None:
            artifact = wandb.Artifact(
                name=f"leanfinder-retrieval-{wandb.run.id}",
                type="model",
                description="Trained retrieval model for LeanFinder",
            )
            artifact.add_dir(self.args.output_dir)
            wandb.log_artifact(artifact)

