
import torch
import transformers


def detect_model_type(model):
    """Return the backbone family of ``model`` based on its attributes.

    The model is probed, in order, for an attribute named ``swin``, ``vit``
    or ``resnet``; the first name found is returned.

    Args:
        model: Any object; only attribute presence is inspected.

    Returns:
        ``'swin'``, ``'vit'`` or ``'resnet'``, or ``None`` when the model
        exposes none of these attributes.
    """
    # Order matters: an object exposing several of these resolves to the
    # earliest entry.
    for backbone_name in ('swin', 'vit', 'resnet'):
        if hasattr(model, backbone_name):
            return backbone_name
    return None


def get_forward_function(model):
    """Pick the replacement forward function matching ``model``'s backbone.

    Args:
        model: Model whose backbone family is determined with
            ``detect_model_type``.

    Returns:
        One of ``swin_forward``, ``vit_forward`` or ``resnet_forward``.

    Raises:
        ValueError: If the model is not recognised as Swin, ViT or ResNet.
    """
    forward_by_type = {
        'swin': swin_forward,
        'vit': vit_forward,
        'resnet': resnet_forward,
    }

    forward_fn = forward_by_type.get(detect_model_type(model))
    if forward_fn is None:
        # Report the concrete class so the caller can see what was passed in.
        model_class = model.__class__.__name__
        raise ValueError(f"Unsupported model type: {model_class}. Supported types: Swin, ViT, ResNet")
    return forward_fn



def swin_forward(
    self,
    pixel_values=None,
    head_mask=None,
    output_attentions=None,
    output_hidden_states=None,
    interpolate_pos_encoding=None,
    return_dict=None,
    labels=None,
    sample_idx=None,
    weight=None
):
    """Forward pass for a Swin classifier with per-sample loss tracking.

    Runs ``self.swin``, classifies its pooled output (``outputs[1]``) and,
    when ``labels`` is given, computes an un-reduced loss so that every
    sample gets its own score. In training mode the scores are reported to
    ``self.trainset.update_scores`` and the loss is reduced to a
    weight-normalised mean; otherwise it is reduced to a plain mean.

    Args:
        self: Model exposing ``swin``, ``classifier``, ``config``,
            ``num_labels``, ``training`` and (in training mode) ``trainset``.
        pixel_values: Input forwarded to ``self.swin``.
        head_mask, output_attentions, output_hidden_states,
        interpolate_pos_encoding: Forwarded unchanged to ``self.swin``.
        return_dict: Whether to return an output object instead of a tuple;
            defaults to ``self.config.use_return_dict``.
        labels: Optional targets. Their dtype is used to infer
            ``config.problem_type`` when it is unset.
        sample_idx: Optional tensor of sample identifiers passed to
            ``trainset.update_scores``; the update is skipped when ``None``.
        weight: Optional per-sample weights (tensor or array-like), one per
            batch element. When ``None`` the loss is a plain mean.

    Returns:
        ``SwinImageClassifierOutput`` when ``return_dict`` is true, else the
        tuple ``(loss?, logits, *outputs[2:])``.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.swin(
        pixel_values,
        head_mask=head_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        interpolate_pos_encoding=interpolate_pos_encoding,
        return_dict=return_dict,
    )
    pooled_output = outputs[1]

    logits = self.classifier(pooled_output)

    loss = None
    if labels is not None:
        labels = labels.to(logits.device)

        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        # The loss is kept un-reduced (reduction='none'): a batch-reduced
        # loss would give every sample the same score and make the
        # per-sample weights cancel out of the normalised mean below.
        if self.config.problem_type == "regression":
            loss_fct = torch.nn.MSELoss(reduction='none')
            if self.num_labels == 1:
                loss = loss_fct(logits.squeeze(), labels.squeeze())
            else:
                loss = loss_fct(logits, labels)
        elif self.config.problem_type == "single_label_classification":
            loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        elif self.config.problem_type == "multi_label_classification":
            loss_fct = torch.nn.BCEWithLogitsLoss(reduction='none')
            loss = loss_fct(logits, labels)

        if self.training:
            if sample_idx is not None:
                self.trainset.update_scores(sample_idx.detach().cpu(), loss.detach().cpu().numpy())

            if weight is not None:
                if isinstance(weight, torch.Tensor):
                    weights = weight.to(loss.device)
                else:
                    weights = torch.tensor(weight, device=loss.device)
                # One weight per sample, with trailing singleton axes matching
                # the loss rank. A fixed (-1, 1) shape against a 1-D loss
                # would broadcast to (B, B) and cancel the weighting.
                weights = weights.reshape((-1,) + (1,) * max(loss.dim() - 1, 0))
                loss = (loss * weights).mean() / weights.mean()
            else:
                loss = loss.mean()
        else:
            loss = loss.mean()

    if not return_dict:
        output = (logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output

    return transformers.models.swin.modeling_swin.SwinImageClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        reshaped_hidden_states=outputs.reshaped_hidden_states,
    )


def vit_forward(
    self,
    pixel_values=None,
    head_mask=None,
    output_attentions=None,
    output_hidden_states=None,
    interpolate_pos_encoding=None,
    return_dict=None,
    labels=None,
    sample_idx=None,
    weight=None
):
    """Forward pass for a ViT classifier with per-sample loss tracking.

    Runs ``self.vit``, classifies the first token of its sequence output
    and, when ``labels`` is given, computes an un-reduced loss. In training
    mode the per-sample losses are reported to
    ``self.trainset.update_scores`` and the loss is reduced to a
    weight-normalised mean; otherwise it is reduced to a plain mean.

    Args:
        self: Model exposing ``vit``, ``classifier``, ``config``,
            ``num_labels``, ``training`` and (in training mode) ``trainset``.
        pixel_values: Input forwarded to ``self.vit``.
        head_mask, output_attentions, output_hidden_states,
        interpolate_pos_encoding: Forwarded unchanged to ``self.vit``.
        return_dict: Whether to return an output object instead of a tuple;
            defaults to ``self.config.use_return_dict``.
        labels: Optional targets. Their dtype is used to infer
            ``config.problem_type`` when it is unset.
        sample_idx: Optional tensor of sample identifiers passed to
            ``trainset.update_scores``; the update is skipped when ``None``.
        weight: Optional per-sample weights (tensor or array-like), one per
            batch element. When ``None`` the loss is a plain mean.

    Returns:
        ``ImageClassifierOutput`` when ``return_dict`` is true, else the
        tuple ``(loss?, logits, *outputs[1:])``.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.vit(
        pixel_values,
        head_mask=head_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        interpolate_pos_encoding=interpolate_pos_encoding,
        return_dict=return_dict,
    )

    sequence_output = outputs[0]

    # Only position 0 along the sequence axis is fed to the classifier.
    logits = self.classifier(sequence_output[:, 0, :])

    loss = None
    if labels is not None:
        labels = labels.to(logits.device)

        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        # reduction='none' keeps one loss value per sample (or per element)
        # so it can be scored and weighted individually below.
        if self.config.problem_type == "regression":
            loss_fct = torch.nn.MSELoss(reduction='none')
            if self.num_labels == 1:
                loss = loss_fct(logits.squeeze(), labels.squeeze())
            else:
                loss = loss_fct(logits, labels)
        elif self.config.problem_type == "single_label_classification":
            loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        elif self.config.problem_type == "multi_label_classification":
            loss_fct = torch.nn.BCEWithLogitsLoss(reduction='none')
            loss = loss_fct(logits, labels)

        if self.training:
            if sample_idx is not None:
                self.trainset.update_scores(sample_idx.detach().cpu(), loss.detach().cpu().numpy())

            if weight is not None:
                if isinstance(weight, torch.Tensor):
                    weights = weight.to(loss.device)
                else:
                    weights = torch.tensor(weight, device=loss.device)
                # One weight per sample, with trailing singleton axes matching
                # the loss rank. A fixed (-1, 1) shape against a 1-D loss
                # would broadcast to (B, B) and cancel the weighting.
                weights = weights.reshape((-1,) + (1,) * max(loss.dim() - 1, 0))
                loss = (loss * weights).mean() / weights.mean()
            else:
                loss = loss.mean()
        else:
            loss = loss.mean()

    if not return_dict:
        output = (logits,) + outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return transformers.modeling_outputs.ImageClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def resnet_forward(
    self,
    pixel_values=None,
    head_mask=None,
    output_attentions=None,
    output_hidden_states=None,
    interpolate_pos_encoding=None,
    return_dict=None,
    labels=None,
    sample_idx=None,
    weight=None
):
    """Forward pass for a ResNet classifier with per-sample loss tracking.

    Runs ``self.resnet``, classifies its pooled output and, when ``labels``
    is given, computes an un-reduced loss. In training mode the per-sample
    losses are reported to ``self.trainset.update_scores`` and the loss is
    reduced to a weight-normalised mean; otherwise it is reduced to a plain
    mean.

    Args:
        self: Model exposing ``resnet``, ``classifier``, ``config``,
            ``num_labels``, ``training`` and (in training mode) ``trainset``.
        pixel_values: Input forwarded to ``self.resnet``.
        head_mask, output_attentions, interpolate_pos_encoding: Accepted so
            the signature matches the other forward functions in this file;
            not used by this function.
        output_hidden_states: Forwarded to ``self.resnet``.
        return_dict: Whether to return an output object instead of a tuple;
            defaults to ``self.config.use_return_dict``.
        labels: Optional targets. Their dtype is used to infer
            ``config.problem_type`` when it is unset.
        sample_idx: Optional tensor of sample identifiers passed to
            ``trainset.update_scores``; the update is skipped when ``None``.
        weight: Optional per-sample weights (tensor or array-like), one per
            batch element. When ``None`` the loss is a plain mean.

    Returns:
        ``ImageClassifierOutputWithNoAttention`` when ``return_dict`` is
        true, else the tuple ``(loss?, logits, *outputs[2:])``.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.resnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
    pooled_output = outputs.pooler_output if return_dict else outputs[1]

    logits = self.classifier(pooled_output)

    loss = None
    if labels is not None:
        labels = labels.to(logits.device)

        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        # reduction='none' keeps one loss value per sample (or per element)
        # so it can be scored and weighted individually below.
        if self.config.problem_type == "regression":
            loss_fct = torch.nn.MSELoss(reduction='none')
            if self.num_labels == 1:
                loss = loss_fct(logits.squeeze(), labels.squeeze())
            else:
                loss = loss_fct(logits, labels)
        elif self.config.problem_type == "single_label_classification":
            loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        elif self.config.problem_type == "multi_label_classification":
            loss_fct = torch.nn.BCEWithLogitsLoss(reduction='none')
            loss = loss_fct(logits, labels)

        if self.training:
            if sample_idx is not None:
                self.trainset.update_scores(sample_idx.detach().cpu(), loss.detach().cpu().numpy())

            if weight is not None:
                if isinstance(weight, torch.Tensor):
                    weights = weight.to(loss.device)
                else:
                    weights = torch.tensor(weight, device=loss.device)
                # One weight per sample, with trailing singleton axes matching
                # the loss rank. A fixed (-1, 1) shape against a 1-D loss
                # would broadcast to (B, B) and cancel the weighting.
                weights = weights.reshape((-1,) + (1,) * max(loss.dim() - 1, 0))
                loss = (loss * weights).mean() / weights.mean()
            else:
                loss = loss.mean()
        else:
            loss = loss.mean()

    if not return_dict:
        output = (logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output

    return transformers.modeling_outputs.ImageClassifierOutputWithNoAttention(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
    )
class LabelSmoother:
    """Label-smoothed cross-entropy over logits, for hard or soft targets.

    Integer targets equal to ``ignore_index`` contribute zero loss and are
    excluded from the averaging denominators. Floating-point targets are
    treated as per-class target distributions.
    """

    def __init__(self, epsilon: float = 0.1, ignore_index: int = -100):
        # epsilon: share of probability mass spread uniformly over classes.
        self.epsilon = epsilon
        # ignore_index: integer label value that marks positions to skip.
        self.ignore_index = ignore_index

    def __call__(self, model_output, labels, shift_labels=False, return_per_sample=False):
        """Compute the smoothed loss.

        Args:
            model_output: A dict with a ``"logits"`` entry, or a sequence
                whose first element is the logits.
            labels: Integer class indices, or floating-point targets with
                the same trailing class axis as the logits.
            shift_labels: If true, drop the last logit position and the
                first label position before computing the loss.
            return_per_sample: If true, return the un-averaged loss instead
                of a scalar.

        Returns:
            A scalar tensor, or the per-sample loss tensor when
            ``return_per_sample`` is true.
        """
        if isinstance(model_output, dict):
            logits = model_output["logits"]
        else:
            logits = model_output[0]

        if shift_labels:
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()

        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        if torch.is_floating_point(labels):
            return self._soft_target_loss(log_probs, labels, return_per_sample)
        return self._hard_target_loss(log_probs, labels, return_per_sample)

    def _soft_target_loss(self, log_probs, targets, return_per_sample):
        """Cross-entropy against floating-point targets, smoothed toward uniform."""
        if self.epsilon > 0:
            num_classes = log_probs.shape[-1]
            targets = (1.0 - self.epsilon) * targets + self.epsilon / num_classes
        sample_losses = -(targets * log_probs).sum(dim=-1)
        if return_per_sample:
            return sample_losses
        return sample_losses.mean()

    def _hard_target_loss(self, log_probs, targets, return_per_sample):
        """Smoothed NLL for integer targets, skipping ``ignore_index`` positions."""
        if targets.dim() == log_probs.dim() - 1:
            targets = targets.unsqueeze(-1)
        ignored = targets.eq(self.ignore_index)
        # Ignored entries are clamped to a valid index so gather() works;
        # their contribution is zeroed out right after.
        gather_index = torch.clamp(targets.long(), min=0)

        num_classes = log_probs.shape[-1]
        target_nll = (-log_probs).gather(dim=-1, index=gather_index)
        uniform_nll = (-log_probs).sum(dim=-1, keepdim=True, dtype=torch.float32) / num_classes

        target_nll = target_nll.masked_fill(ignored, 0.0)
        uniform_nll = uniform_nll.masked_fill(ignored, 0.0)

        element_loss = ((1.0 - self.epsilon) * target_nll + self.epsilon * uniform_nll).squeeze(-1)
        kept = (~ignored.squeeze(-1)).to(element_loss.dtype)

        if element_loss.dim() == 1:
            # One loss value per sample already.
            if return_per_sample:
                return element_loss
            return element_loss.sum() / kept.sum(dim=0).clamp_min(1.0)

        # Higher-rank losses are averaged over the non-ignored entries of
        # every axis after the first.
        inner_dims = list(range(1, element_loss.dim()))
        kept_per_sample = kept.sum(dim=inner_dims).clamp_min(1.0)
        sample_losses = element_loss.sum(dim=inner_dims) / kept_per_sample
        if return_per_sample:
            return sample_losses
        return sample_losses.mean()


from torchvision.transforms import v2
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """Compute a label-smoothed, optionally sample-weighted training loss.

    ``labels``, ``weight`` and ``sample_idx`` are removed from ``inputs``
    before the model is called. The per-sample loss comes from
    ``self.label_smoother``. When the model is in training mode and both
    ``weight`` and ``sample_idx`` were supplied, the per-sample losses are
    reported to ``self.train_dataset.update_scores`` and the returned loss
    is the weight-normalised mean; in every other case it is the plain mean
    and no scores are updated.

    Args:
        self: Trainer-like object providing ``args``, ``model``,
            ``label_smoother``, ``model_accepts_loss_kwargs`` and
            ``train_dataset`` (presumably a Hugging Face ``Trainer``
            subclass - confirm against the caller).
        model: Callable invoked as ``model(**inputs)``.
        inputs: Batch dict; must contain ``"labels"``. It is mutated
            (the three keys above are popped).
        return_outputs: If true, return ``(loss, outputs)``.
        num_items_in_batch: Forwarded to the model as a keyword argument
            when ``self.model_accepts_loss_kwargs`` is true.

    Returns:
        The scalar loss, or ``(loss, outputs)`` when ``return_outputs``.
    """
    labels = inputs.pop("labels")
    weight = inputs.pop("weight", None)
    sample_idx = inputs.pop("sample_idx", None)

    # Score tracking and re-weighting both need the pair (weight, sample_idx).
    update_score_flag = (weight is not None and sample_idx is not None)
    if self.model_accepts_loss_kwargs:
        loss_kwargs = {}
        if num_items_in_batch is not None:
            loss_kwargs["num_items_in_batch"] = num_items_in_batch
        inputs = {**inputs, **loss_kwargs}

    # `use_cutmixup` is read as the probability of applying CutMix/MixUp.
    cutmix_prob = float(getattr(self.args, "use_cutmixup", 0.0) or 0.0)
    if cutmix_prob > 0.0 and self.model.training:
        pixel_values = inputs.get("pixel_values")

        num_labels = self.model.config.num_labels
        cutmix = v2.CutMix(num_classes=num_labels)
        mixup = v2.MixUp(num_classes=num_labels)

        # Pick CutMix or MixUp with equal probability, then apply the chosen
        # transform with probability `cutmix_prob`.
        cutmix_or_mixup = v2.RandomChoice([cutmix, mixup], p=[0.5, 0.5])
        random_augment = v2.RandomApply([cutmix_or_mixup], p=cutmix_prob)

        pixel_values, labels = random_augment(pixel_values, labels)

        inputs["pixel_values"] = pixel_values

    outputs = model(**inputs)

    if self.args.past_index >= 0:
        self._past = outputs[self.args.past_index]

    per_sample_loss = self.label_smoother(outputs, labels, return_per_sample=True)

    if self.model.training and update_score_flag:
        # Scores are recorded only while training, so that evaluation batches
        # never overwrite the training set's per-sample scores.
        self.train_dataset.update_scores(sample_idx.detach().cpu(), per_sample_loss.detach().cpu().numpy())

        if isinstance(weight, torch.Tensor):
            weights = weight.to(per_sample_loss.device).view(-1)
        else:
            weights = torch.tensor(weight, device=per_sample_loss.device).view(-1)

        loss = (per_sample_loss * weights).mean() / weights.mean()
    else:
        loss = per_sample_loss.mean()

    return (loss, outputs) if return_outputs else loss