"""
RoBERTa fine-tuning model for delta regression.

End-to-end fine-tuning to predict the signed margin (delta) directly.
"""

from typing import Optional, List, Dict
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)

from models.base import BaseModel


class RegressionDataset(Dataset):
    """Map-style text dataset for regression fine-tuning.

    Texts are tokenized lazily, one example per ``__getitem__`` call, and
    padded/truncated to ``max_len`` tokens.

    Args:
        texts: Input texts. Any sized iterable is accepted; it is copied into
            a list so that items are always looked up by position.
        targets: Regression targets, one per text. Any array-like that NumPy
            can convert to float32 is accepted.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            padding="max_length", max_length=max_len, return_tensors=None)``
            whose result exposes ``"input_ids"`` and ``"attention_mask"``.
        max_len: Length every example is padded/truncated to.

    Raises:
        ValueError: If ``texts`` and ``targets`` differ in length.
    """

    def __init__(
        self,
        texts: List[str],
        targets: np.ndarray,
        tokenizer,
        max_len: int = 128,
    ):
        # A list guarantees positional indexing; e.g. a pandas Series with a
        # non-default index would otherwise be indexed by label.
        self.texts = list(texts)
        # np.asarray (rather than .astype) also accepts lists and Series.
        self.targets = np.asarray(targets, dtype=np.float32)
        if len(self.texts) != len(self.targets):
            raise ValueError(
                f"texts and targets must have the same length, "
                f"got {len(self.texts)} and {len(self.targets)}"
            )
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize example ``idx`` and return it as a dict of tensors.

        The returned dict holds ``input_ids`` and ``attention_mask`` (int64),
        ``labels`` (float32 scalar target) and ``idx`` (the example's own
        position, so downstream code can look up per-example data).
        """
        text = str(self.texts[idx])
        enc = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors=None,
        )
        return {
            "input_ids": torch.tensor(enc["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(enc["attention_mask"], dtype=torch.long),
            "labels": torch.tensor(self.targets[idx], dtype=torch.float32),
            "idx": torch.tensor(idx, dtype=torch.long),
        }


class RegressionTrainer(Trainer):
    """Trainer that minimizes a squared-error loss, optionally sample-weighted.

    Args:
        *args: Forwarded unchanged to ``Trainer``.
        sample_weights: Optional per-example weights for the *training*
            dataset. They are looked up with the ``"idx"`` entry of each
            batch, so entry ``i`` must belong to training example ``i``.
        **kwargs: Forwarded unchanged to ``Trainer``.

    Note:
        Weights can only take effect if ``"idx"`` is still present in the
        batch when it reaches ``compute_loss``. ``Trainer`` drops batch keys
        that are not arguments of the model's ``forward`` unless
        ``TrainingArguments(remove_unused_columns=False)`` is used.
    """

    def __init__(self, *args, sample_weights: Optional[np.ndarray] = None, **kwargs):
        super().__init__(*args, **kwargs)
        # Normalize once so list / Series inputs support array indexing below.
        self.sample_weights = (
            None
            if sample_weights is None
            else np.asarray(sample_weights, dtype=np.float32)
        )

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Compute the mean (optionally weighted) squared error for a batch.

        Args:
            model: Model to run; called as ``model(**inputs)`` after the
                ``"idx"`` and ``"labels"`` entries have been removed.
            inputs: Batch dict. Must contain ``"labels"``; may contain
                ``"idx"`` (positions of the examples in the training set).
            return_outputs: If true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for ``Trainer`` compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``.
        """
        idx = inputs.pop("idx", None)
        labels = inputs.pop("labels")
        outputs = model(**inputs)

        # Model outputs logits of shape (batch, 1) for num_labels=1
        predictions = outputs.logits.squeeze(-1)
        # Align shapes explicitly so a (batch, 1) label tensor can never
        # broadcast against (batch,) predictions into a (batch, batch) matrix.
        labels = labels.reshape(predictions.shape)

        # Per-example squared error
        loss = (predictions - labels) ** 2

        # The weights describe the training set only. During evaluation or
        # prediction "idx" refers to a different dataset, so applying them
        # would use unrelated weights or index out of range.
        if self.sample_weights is not None and idx is not None and model.training:
            weights = torch.as_tensor(
                self.sample_weights[idx.detach().cpu().numpy()],
                dtype=torch.float32,
                device=loss.device,
            )
            loss = loss * weights

        loss = loss.mean()
        return (loss, outputs) if return_outputs else loss


class RobertaFinetuneRegressionModel(BaseModel):
    """
    End-to-end RoBERTa fine-tuning model for delta regression.

    Predicts the signed margin (delta) directly using MSE loss.
    At inference, classification is done by thresholding at 0.

    Args:
        task: Must be ``'regression'``.
        hf_model: Hugging Face model name or path to fine-tune.
        max_len: Token length every text is padded/truncated to.
        epochs: Number of training epochs.
        batch_size: Per-device batch size for training and prediction.
        grad_accum: Gradient accumulation steps.
        lr: Learning rate.
        weight_decay: Weight decay.
        **kwargs: Forwarded to ``BaseModel.__init__``.

    Raises:
        ValueError: If ``task`` is not ``'regression'``.
    """

    def __init__(
        self,
        task: str = 'regression',
        hf_model: str = 'roberta-base',
        max_len: int = 128,
        epochs: int = 3,
        batch_size: int = 128,
        grad_accum: int = 4,
        lr: float = 2e-5,
        weight_decay: float = 0.01,
        **kwargs,
    ):
        if task != 'regression':
            raise ValueError("RobertaFinetuneRegressionModel only supports regression")

        super().__init__(task=task, **kwargs)
        self.hf_model = hf_model
        self.max_len = max_len
        self.epochs = epochs
        self.batch_size = batch_size
        self.grad_accum = grad_accum
        self.lr = lr
        self.weight_decay = weight_decay

        # Populated by fit().
        self.tokenizer = None
        self.model = None
        self.trainer = None

    def fit(
        self,
        X: List[str],
        y: np.ndarray,
        sample_weight: Optional[np.ndarray] = None,
    ) -> 'RobertaFinetuneRegressionModel':
        """Fine-tune RoBERTa for regression on the training data.

        Args:
            X: List of input texts
            y: Target values (delta values for regression)
            sample_weight: Optional per-example weights

        Returns:
            ``self``, to allow call chaining.

        Raises:
            ValueError: If ``y`` or ``sample_weight`` does not have one entry
                per text in ``X``.
        """
        texts = list(X)
        targets = np.asarray(y, dtype=np.float32)
        if len(targets) != len(texts):
            raise ValueError(
                f"X and y must have the same length, "
                f"got {len(texts)} and {len(targets)}"
            )

        weights = None
        if sample_weight is not None:
            weights = np.asarray(sample_weight, dtype=np.float32)
            if len(weights) != len(texts):
                raise ValueError(
                    f"sample_weight must have one entry per example, "
                    f"got {len(weights)} for {len(texts)} examples"
                )

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model)

        # Create dataset
        train_dataset = RegressionDataset(texts, targets, self.tokenizer, self.max_len)

        # Load model with num_labels=1 for regression
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.hf_model,
            num_labels=1,
            problem_type="regression"
        )

        # Not every CUDA device supports bfloat16; requesting it on one that
        # does not makes TrainingArguments fail, so fall back to full precision.
        use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

        # Training args
        args = TrainingArguments(
            output_dir="outputs/roberta_finetune_reg_temp",
            overwrite_output_dir=True,
            eval_strategy="no",
            save_strategy="no",
            logging_strategy="steps",
            logging_steps=100,
            num_train_epochs=self.epochs,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            learning_rate=self.lr,
            weight_decay=self.weight_decay,
            gradient_accumulation_steps=self.grad_accum,
            bf16=use_bf16,
            seed=42,
            report_to="wandb",
            dataloader_num_workers=4,
            # Keep the dataset's "idx" entry in each batch. With the default
            # (True) the Trainer strips keys that are not arguments of the
            # model's forward(), so the loss never sees "idx" and the sample
            # weights are silently ignored.
            remove_unused_columns=False,
        )

        # Create trainer (sample_weights=None means an unweighted loss)
        self.trainer = RegressionTrainer(
            model=self.model,
            args=args,
            train_dataset=train_dataset,
            sample_weights=weights,
        )

        # Train
        self.trainer.train()
        self.is_fitted_ = True

        return self

    def predict(self, X: List[str]) -> np.ndarray:
        """Predict delta values.

        Args:
            X: List of input texts.

        Returns:
            One predicted delta per input text; an empty array for empty input.
        """
        self._check_fitted()

        texts = list(X)
        if not texts:
            # Nothing to run through the model.
            return np.empty(0, dtype=np.float32)

        dataset = RegressionDataset(
            texts,
            np.zeros(len(texts)),  # dummy targets
            self.tokenizer,
            self.max_len,
        )

        # The trainer's sample weights are indexed by *training* positions.
        # Disable them while predicting so the loss computed alongside the
        # predictions cannot look up weights with positions from this dataset.
        saved_weights = self.trainer.sample_weights
        self.trainer.sample_weights = None
        try:
            pred_output = self.trainer.predict(dataset)
        finally:
            self.trainer.sample_weights = saved_weights

        # Squeeze the last dimension since num_labels=1
        predictions = pred_output.predictions.squeeze(-1)
        return predictions

    def predict_class(self, X: List[str]) -> np.ndarray:
        """Predict class labels by thresholding delta at 0.

        Returns:
            Integer array with 1 where the predicted delta is >= 0, else 0.
        """
        delta_pred = self.predict(X)
        return (delta_pred >= 0).astype(int)

    def __repr__(self) -> str:
        return f"RobertaFinetuneRegressionModel(hf_model={self.hf_model!r}, epochs={self.epochs}, batch_size={self.batch_size})"
