
import math

import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaConfig, RobertaTokenizer, RobertaForSequenceClassification

class LoRALayer(nn.Module):
    """Low-rank adapter computing ``scaling * B(A(x))`` with ``scaling = alpha / r``.

    ``A`` projects ``input_dim -> r`` and ``B`` projects ``r -> output_dim``,
    both without bias. ``B`` is zero-initialised, so a freshly built adapter
    outputs zeros and initially leaves whatever layer it augments unchanged.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        r: Adapter rank; must be positive.
        alpha: Scaling numerator; the update is multiplied by ``alpha / r``.

    Raises:
        ValueError: If ``r`` is not positive.
    """

    def __init__(self, input_dim, output_dim, r=4, alpha=32):
        super().__init__()
        if r <= 0:
            # Without this check r=0 fails below with an opaque ZeroDivisionError.
            raise ValueError(f"LoRA rank r must be positive, got {r}")
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r

        self.lora_A = nn.Linear(input_dim, r, bias=False)
        self.lora_B = nn.Linear(r, output_dim, bias=False)
        # A gets the same Kaiming-uniform init nn.Linear uses by default; B starts
        # at zero so the initial low-rank update is exactly zero.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x):
        """Return the scaled low-rank update for ``x`` (last dim ``input_dim``)."""
        return self.lora_B(self.lora_A(x)) * self.scaling

class LoRAForwardHook:
    """Forward hook that adds a LoRA adapter's update to a layer's output.

    Registered via ``nn.Module.register_forward_hook`` on a projection layer,
    it returns ``output + adapter(input)``. The hooked layer therefore computes
    ``W x + b + adapter(x)`` without being replaced: its parameters and
    state_dict keys stay unchanged. It is a class rather than a closure so each
    hook is bound to its own adapter (no late-binding in loops), and the model
    stays picklable and deep-copyable.
    """

    def __init__(self, adapter):
        self.adapter = adapter

    def __call__(self, module, inputs, output):
        # ``inputs`` is the tuple of positional args given to the hooked layer;
        # the projection's input tensor is the first element.
        return output + self.adapter(inputs[0])


class RobertaWithLoRA(nn.Module):
    """RoBERTa encoder with LoRA adapters on self-attention Q/K/V and a linear head.

    The final hidden state of the first token is passed through dropout and a
    single linear classifier. When labels are given, the cross-entropy loss can
    be blended with a temperature-scaled KL distillation loss against teacher
    logits.

    Args:
        model_name: Hugging Face model id or path of the pretrained backbone.
        num_labels: Number of output classes.
        lora_r: Rank of each LoRA adapter.
        lora_alpha: LoRA scaling numerator (updates are scaled by alpha / r).
        freeze_backbone: If True, freeze every pretrained backbone parameter so
            only the LoRA adapters and the classifier are trainable. Defaults to
            False, which keeps the whole backbone trainable as before.
    """

    def __init__(self, model_name="roberta-large", num_labels=3, lora_r=4, lora_alpha=32,
                 freeze_backbone=False):
        super().__init__()
        self.config = RobertaConfig.from_pretrained(model_name)
        self.backbone = RobertaModel.from_pretrained(model_name, config=self.config)
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

        if freeze_backbone:
            # Freeze before the adapters are created so they remain trainable.
            for param in self.backbone.parameters():
                param.requires_grad_(False)

        # Augment each self-attention projection with a LoRA adapter. The adapter
        # is stored as ``<proj>_lora`` on the attention module, so it is a
        # registered submodule that moves and saves with the model. A forward
        # hook on the projection adds the adapter's output to the projection's
        # output. Merely attaching the adapters (without the hook) would leave
        # them unused.
        # NOTE(review): assumes each ``attention.self`` exposes separate
        # ``query``/``key``/``value`` nn.Linear layers; confirm for the installed
        # transformers version.
        for layer in self.backbone.encoder.layer:
            attn = layer.attention.self
            for proj_name in ("query", "key", "value"):
                proj = getattr(attn, proj_name)
                adapter = LoRALayer(proj.in_features, proj.out_features, r=lora_r, alpha=lora_alpha)
                setattr(attn, f"{proj_name}_lora", adapter)
                proj.register_forward_hook(LoRAForwardHook(adapter))

    def forward(self, input_ids, attention_mask=None, labels=None, distill_logits=None, alpha=0.5, temperature=2.0):
        """Classify a batch and optionally compute the (distillation) loss.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Optional attention mask passed to the backbone.
            labels: Optional class-index targets. When given, a loss is computed.
            distill_logits: Optional teacher logits, presumably the same shape as
                the student logits. Only used when ``labels`` is given. Treated as
                a fixed target: no gradient flows back into them.
            alpha: Weight of the cross-entropy term. The KL term gets ``1 - alpha``.
            temperature: Softening temperature for the distillation distributions.

        Returns:
            ``{"logits": logits}``, or ``{"loss": loss, "logits": logits}`` when
            ``labels`` is given.
        """
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # The first token's final hidden state is the sequence representation.
        pooled_output = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(pooled_output)

        if labels is None:
            return {"logits": logits}

        loss = nn.functional.cross_entropy(logits, labels)
        if distill_logits is not None:
            log_probs = nn.functional.log_softmax(logits / temperature, dim=-1)
            # Detach: the teacher is a fixed target and must not receive gradients.
            targets = nn.functional.softmax(distill_logits.detach() / temperature, dim=-1)
            # Conventional T^2 factor keeps the KL gradient scale comparable to CE
            # as the temperature changes.
            kl_loss = nn.functional.kl_div(log_probs, targets, reduction="batchmean") * (temperature ** 2)
            loss = alpha * loss + (1 - alpha) * kl_loss
        return {"loss": loss, "logits": logits}
