from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, SequentialSampler
import wandb

from loss import SelfDistillationLoss
from transformers import Trainer, TrainingArguments
from typing import Dict, List, Optional, Union, Any, Tuple
import os
from dataclasses import dataclass

from model import SelfDistillationOutput
from transformers.modeling_utils import PreTrainedModel

class CropTrainer(Trainer):
    """Hugging Face ``Trainer`` with an optional in-order training sampler.

    Differences from the base ``Trainer`` implemented in this class:

    * Training batches are drawn with a ``SequentialSampler`` when
      ``sequential_sampler`` is True (the default), so the dataset is
      iterated in its stored order.
    * ``compute_loss`` calls the model with the batch plus ``labels`` and
      returns the model's own ``outputs.loss``.

    Args:
        model: Model to train. Its forward is called with the collated batch
            as keyword arguments plus ``labels=...`` and must return an
            object exposing ``.loss``.
        args: Standard ``TrainingArguments``.
        processing_class, train_dataset, eval_dataset, data_collator, **kwargs:
            Passed through unchanged to ``Trainer.__init__``.
        sequential_sampler: If True, iterate the training set in order
            instead of using the base class's sampler.
        temperature: Stored as ``self.temperature``. No method of this class
            reads it (presumably consumed elsewhere — TODO confirm).
    """

    def __init__(
        self,
        model,
        args: TrainingArguments,
        processing_class=None,
        train_dataset=None,
        eval_dataset=None,
        data_collator=None,
        sequential_sampler=True,
        temperature=1.0,
        **kwargs
    ):
        super().__init__(
            model=model,
            args=args,
            processing_class=processing_class,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            **kwargs
        )
        # Keeping the dataset order intact is meant to put the global and
        # local crops of an image into the same batch (presumably the dataset
        # stores them contiguously — confirm) and to make runs reproducible.
        self.sequential_sampler = sequential_sampler
        self.temperature = temperature

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run ``model`` on one batch and return the loss it computed.

        Args:
            model: The (possibly wrapped) model to call.
            inputs: Batch mapping produced by the data collator. It is not
                modified; ``labels`` is taken from a copy.
            return_outputs: If True, return ``(loss, outputs)``.
            **kwargs: Extra arguments supplied by ``Trainer``. When
                ``num_items_in_batch`` is given and the trainer reports that
                the model accepts loss kwargs, it is forwarded to the model,
                as the base ``Trainer.compute_loss`` does.

        Returns:
            ``outputs.loss``, or ``(outputs.loss, outputs)`` when
            ``return_outputs`` is True.
        """
        # Work on a copy so the caller's batch dict keeps its "labels" entry.
        model_inputs = dict(inputs)
        labels = model_inputs.pop("labels", None)

        # When the model accepts loss kwargs, the base Trainer's training_step
        # assumes compute_loss forwarded num_items_in_batch and therefore does
        # not rescale the loss for gradient accumulation itself. Forward it so
        # that assumption holds.
        num_items_in_batch = kwargs.get("num_items_in_batch")
        if num_items_in_batch is not None and getattr(
            self, "model_accepts_loss_kwargs", False
        ):
            model_inputs["num_items_in_batch"] = num_items_in_batch

        outputs = model(**model_inputs, labels=labels)
        total_loss = outputs.loss
        return (total_loss, outputs) if return_outputs else total_loss

    def _get_train_sampler(
        self, train_dataset=None
    ) -> Optional[torch.utils.data.Sampler]:
        """Return the sampler for the training dataloader.

        Newer ``Trainer`` versions call this hook with the dataset as a
        positional argument; older ones call it with no arguments. Both
        call forms are accepted.

        Args:
            train_dataset: Dataset to sample from. Defaults to
                ``self.train_dataset`` when None.

        Returns:
            A ``SequentialSampler`` over the dataset when
            ``self.sequential_sampler`` is True, otherwise the base
            ``Trainer``'s sampler.
        """
        dataset = self.train_dataset if train_dataset is None else train_dataset
        if self.sequential_sampler:
            return SequentialSampler(dataset)
        # Forward the argument only when one was passed, so this also works
        # with Trainer versions whose hook takes no dataset parameter.
        if train_dataset is None:
            return super()._get_train_sampler()
        return super()._get_train_sampler(train_dataset)
