from typing import Optional

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import Qwen2ForCausalLM, Qwen2Model, Qwen2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


class ChemLLMForMultiTask(ChemLLM2PreTrainedModel):
    """Shared decoder backbone with an LM head plus classification and regression heads.

    Each example in a batch is routed by its entry in ``task_types``:

    * ``"qa"`` -- ``lm_head`` is applied at every position and, when ``labels``
      are given, a next-token cross-entropy loss is computed.
    * ``"classification"`` -- ``classification_head`` is applied to the hidden
      state of the last attended token.
    * ``"regression"`` -- ``regression_head`` (a single output) is applied to the
      same last-token hidden state.

    NOTE(review): ``ChemLLM2PreTrainedModel`` and ``ChemLLM2Model`` are not
    imported in this file (it imports ``Qwen2PreTrainedModel`` / ``Qwen2Model``
    but never uses them). Presumably they are provided elsewhere -- confirm.
    """

    def __init__(self, config, num_classes=6):
        """Build the backbone and the three task heads.

        Args:
            config: Backbone configuration; its ``hidden_size`` and ``vocab_size``
                size the heads.
            num_classes: Number of logits produced by the classification head.
        """
        super().__init__(config)
        self.model = ChemLLM2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.classification_head = nn.Linear(config.hidden_size, num_classes)
        self.regression_head = nn.Linear(config.hidden_size, 1)
        # Presumably the Hugging Face PreTrainedModel hook (weight init / tying) -- confirm for the base class.
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token-embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token-embedding module with ``value``."""
        self.model.embed_tokens = value

    @staticmethod
    def _last_token_index(attention_mask, row, seq_len):
        """Return the position of the last attended token of batch row ``row``.

        Works for both right- and left-padded batches. Without a mask every
        position counts as attended, so the final position is returned.
        """
        if attention_mask is None:
            return seq_len - 1
        mask_row = attention_mask[row]
        positions = torch.arange(mask_row.shape[-1], device=mask_row.device)
        # Zero out padded positions; the largest surviving position is the last real token.
        return (positions * (mask_row != 0)).max()

    @staticmethod
    def _lm_loss(lm_logits, row_labels):
        """Return the next-token cross-entropy for one example, or None if no target is valid."""
        loss_fct = CrossEntropyLoss()
        # Position t predicts token t + 1: drop the last logit and the first label.
        shift_logits = lm_logits[:-1, :].float()
        shift_labels = row_labels[1:].to(shift_logits.device)
        # A mean over zero non-ignored targets is 0/0 = NaN, which would poison the summed batch loss.
        if not (shift_labels != loss_fct.ignore_index).any():
            return None
        return loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        task_types: Optional[list[str]] = None,
        class_labels: Optional[torch.LongTensor] = None,
        regression_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,  # token labels for the native LM ("qa") loss
        **kwargs,
    ):
        """Run the backbone once and apply each example's task-specific head.

        Args:
            input_ids: Token ids; ``input_ids.shape[0]`` is the batch size.
            attention_mask: Optional mask indexed per example as ``attention_mask[i]``
                (presumably ``(batch, seq_len)`` with nonzero = real token -- confirm).
                Used to locate the last attended token for classification and
                regression rows; when omitted, the final position is used.
            task_types: One of ``"qa"``, ``"classification"`` or ``"regression"``
                per example; must have exactly one entry per row of ``input_ids``.
            class_labels: Per-example class index, read for classification rows.
            regression_values: Per-example target, read for regression rows.
            labels: Per-example token labels, read for ``"qa"`` rows. They are shifted
                by one internally; positions equal to ``CrossEntropyLoss``'s
                ``ignore_index`` (-100 by default) are skipped.
            **kwargs: Forwarded unchanged to the backbone.

        Returns:
            dict with ``"loss"`` (sum of the per-example losses, or ``None`` when no
            loss was computed), and ``"lm_logits"``, ``"class_logits"``,
            ``"reg_predictions"`` -- lists holding one tensor per example of that
            task, in batch order.

        Raises:
            ValueError: if ``task_types`` is missing, does not match the batch size,
                or contains an unknown task name.
        """
        if task_types is None:
            raise ValueError("task_types must be provided (one task name per example)")
        batch_size = input_ids.shape[0]
        if len(task_types) != batch_size:
            raise ValueError(
                f"task_types has {len(task_types)} entries but input_ids has {batch_size} rows"
            )

        transformer_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = transformer_outputs[0]
        seq_len = hidden_states.shape[1]

        total_loss = None
        all_lm_logits = []
        all_class_logits = []
        all_reg_predictions = []

        for i, task_type in enumerate(task_types):
            loss = None

            if task_type == "qa":
                lm_logits = self.lm_head(hidden_states[i])
                all_lm_logits.append(lm_logits)
                if labels is not None:
                    loss = self._lm_loss(lm_logits, labels[i])

            elif task_type in ("classification", "regression"):
                last_idx = self._last_token_index(attention_mask, i, seq_len)
                last_token_hidden_state = hidden_states[i, last_idx, :]

                if task_type == "classification":
                    class_logits = self.classification_head(last_token_hidden_state)
                    all_class_logits.append(class_logits)
                    if class_labels is not None:
                        # reshape(1) accepts both (batch,) and (batch, 1) label layouts.
                        target = class_labels[i].reshape(1).to(class_logits.device)
                        loss = CrossEntropyLoss()(class_logits.float().unsqueeze(0), target)
                else:
                    reg_prediction = self.regression_head(last_token_hidden_state)
                    all_reg_predictions.append(reg_prediction)
                    if regression_values is not None:
                        # Match shapes explicitly and compute in float32 so a half-precision
                        # prediction and a float32 target do not trip MSELoss.
                        target = regression_values[i].reshape(-1).to(
                            device=reg_prediction.device, dtype=torch.float32
                        )
                        loss = MSELoss()(reg_prediction.float().reshape(-1), target)

            else:
                raise ValueError(
                    f"Unknown task type {task_type!r} at batch index {i}; "
                    "expected 'qa', 'classification' or 'regression'"
                )

            if loss is not None:
                # Out-of-place add: ``+=`` would mutate the first example's loss tensor in place.
                total_loss = loss if total_loss is None else total_loss + loss

        return {
            "loss": total_loss,
            "lm_logits": all_lm_logits,
            "class_logits": all_class_logits,
            "reg_predictions": all_reg_predictions
        }
