"""
RoBERTa fine-tuning model for cost-sensitive learning.

End-to-end fine-tuning with optional per-example weighting.
"""

from typing import Optional, List, Dict, Any
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)

from models.base import BaseModel


class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text per item for fine-tuning.

    Each item is a dict of ``torch.long`` tensors:

    * ``input_ids`` / ``attention_mask`` -- taken from the tokenizer output,
      which is requested with ``padding="max_length"`` and ``truncation=True``
      at ``max_length=max_len``.
    * ``labels`` -- the item's label converted with ``int()``.
    * ``idx`` -- the item's position in the dataset (``WeightedTrainer`` uses
      it to look up per-example weights).

    Args:
        texts: Texts to encode, in dataset order. Any iterable is accepted; it
            is copied into a list so that items are always addressed by
            position.
        labels: One label per text, in the same order. Converted to a NumPy
            array for the same reason.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            padding="max_length", max_length=max_len, return_tensors=None)``
            that returns a mapping with ``"input_ids"`` and
            ``"attention_mask"``.
        max_len: Sequence length passed to the tokenizer as ``max_length``.
    """

    def __init__(
        self,
        texts: List[str],
        labels: np.ndarray,
        tokenizer,
        max_len: int = 128,
    ):
        # Materialise both inputs into positional containers. Indexing a
        # pandas Series with an integer is label-based, so a Series whose
        # index is not 0..n-1 (e.g. after a train/test split) would otherwise
        # raise KeyError or return the wrong row in __getitem__.
        self.texts = list(texts)
        self.labels = np.asarray(labels)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self) -> int:
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize the text at position ``idx`` and return its tensors."""
        text = str(self.texts[idx])
        enc = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors=None,  # plain lists; converted to tensors below
        )
        return {
            "input_ids": torch.tensor(enc["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(enc["attention_mask"], dtype=torch.long),
            "labels": torch.tensor(int(self.labels[idx]), dtype=torch.long),
            "idx": torch.tensor(idx, dtype=torch.long),
        }


class WeightedTrainer(Trainer):
    """Trainer whose training loss is a per-example weighted cross-entropy.

    Batches are expected to carry an ``idx`` entry holding each example's
    position in the training set. It is used to look up that example's weight
    in ``sample_weights`` and is removed from the batch before the model is
    called.

    NOTE(review): with ``TrainingArguments.remove_unused_columns`` left at its
    default (``True``), ``Trainer`` drops batch keys that are not parameters
    of the model's ``forward`` -- including ``idx`` -- in which case the loss
    silently falls back to the unweighted mean. Construct the training
    arguments with ``remove_unused_columns=False`` for the weights to apply.

    Args:
        *args: Forwarded to ``Trainer``.
        sample_weights: Optional 1-D array with one weight per training
            example, indexed by the batch's ``idx`` values. ``None`` disables
            weighting.
        **kwargs: Forwarded to ``Trainer``.
    """

    def __init__(self, *args, sample_weights: Optional[np.ndarray] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.sample_weights = sample_weights

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Return the mean of the (optionally weighted) per-example losses.

        Args:
            model: Model to run; may be a wrapper around the underlying
                network.
            inputs: Batch dict. ``idx`` is popped from it (the dict is
                mutated); the rest, including ``labels``, is passed to the
                model.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just ``loss``.
            num_items_in_batch: Accepted for signature compatibility with
                ``Trainer``; not used.
            **kwargs: Accepted for forward compatibility; not used.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        # "idx" is bookkeeping for the weight lookup, not a model input.
        idx = inputs.pop("idx", None)
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        # Read the class dimension from the logits themselves: `model` can be
        # a DataParallel/DDP wrapper, which does not expose `.config`.
        loss_fct = nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        # The weights describe the training set only. During evaluation or
        # prediction `idx` refers to positions in a different dataset, so
        # applying them there would use unrelated weights (or index out of
        # range); restrict weighting to training mode.
        apply_weights = (
            self.sample_weights is not None
            and idx is not None
            and model.training
        )
        if apply_weights:
            weights = torch.tensor(
                self.sample_weights[idx.cpu().numpy()],
                dtype=torch.float32,
                device=loss.device,
            )
            loss = loss * weights

        loss = loss.mean()
        return (loss, outputs) if return_outputs else loss


class RobertaFinetuneModel(BaseModel):
    """
    End-to-end RoBERTa fine-tuning model.

    Supports weighted cross-entropy for cost-sensitive learning: when
    ``sample_weight`` is given to :meth:`fit`, each example's cross-entropy is
    multiplied by its weight before the batch mean is taken.

    Args:
        task: Must be ``'classification'``.
        hf_model: Hugging Face model name or path used for both the tokenizer
            and the sequence-classification model.
        max_len: Token length every text is padded/truncated to.
        epochs: Number of training epochs.
        batch_size: Per-device batch size for training and prediction.
        grad_accum: Gradient accumulation steps.
        lr: Learning rate.
        weight_decay: Weight decay.
        **kwargs: Forwarded to ``BaseModel.__init__``.

    Raises:
        ValueError: If ``task`` is not ``'classification'``.
    """

    def __init__(
        self,
        task: str = 'classification',
        hf_model: str = 'roberta-base',
        max_len: int = 128,
        epochs: int = 3,
        batch_size: int = 128,
        grad_accum: int = 4,
        lr: float = 2e-5,
        weight_decay: float = 0.01,
        **kwargs,
    ):
        if task != 'classification':
            raise ValueError("RobertaFinetuneModel only supports classification")

        super().__init__(task=task, **kwargs)
        self.hf_model = hf_model
        self.max_len = max_len
        self.epochs = epochs
        self.batch_size = batch_size
        self.grad_accum = grad_accum
        self.lr = lr
        self.weight_decay = weight_decay

        # Populated by fit().
        self.tokenizer = None
        self.model = None
        self.trainer = None

    def fit(
        self,
        X: List[str],
        y: np.ndarray,
        sample_weight: Optional[np.ndarray] = None,
    ) -> 'RobertaFinetuneModel':
        """Fine-tune RoBERTa on the training data.

        Args:
            X: Training texts.
            y: Integer class labels ``0..K-1``, one per text.
            sample_weight: Optional per-example weights, one per text.

        Returns:
            ``self``.

        Raises:
            ValueError: If ``y`` or ``sample_weight`` does not have exactly
                one entry per text.
        """
        # Positional containers, so pandas inputs with a non-default index
        # are handled by position rather than by label.
        texts = list(X)
        labels = np.asarray(y)
        if len(labels) != len(texts):
            raise ValueError(
                f"X and y must have the same length, got {len(texts)} and {len(labels)}"
            )

        weights = None
        if sample_weight is not None:
            # np.asarray also accepts plain lists, unlike `.astype`.
            weights = np.asarray(sample_weight, dtype=np.float32).reshape(-1)
            if len(weights) != len(texts):
                raise ValueError(
                    "sample_weight must have one entry per example, "
                    f"got {len(weights)} for {len(texts)} examples"
                )

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model)

        # Create dataset
        train_dataset = TextDataset(texts, labels, self.tokenizer, self.max_len)

        # Size the classification head from the labels. The floor of 2 keeps
        # the binary behaviour when only one class is present and avoids
        # num_labels=1.
        num_labels = max(2, int(labels.max()) + 1) if labels.size else 2

        # Load model
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.hf_model, num_labels=num_labels
        )

        # bf16 only where the GPU supports it; TrainingArguments rejects
        # bf16=True on hardware without bf16 support.
        use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

        # Training args
        args = TrainingArguments(
            output_dir="outputs/roberta_finetune_temp",
            overwrite_output_dir=True,
            eval_strategy="no",
            save_strategy="no",
            logging_strategy="steps",
            logging_steps=100,
            num_train_epochs=self.epochs,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            learning_rate=self.lr,
            weight_decay=self.weight_decay,
            gradient_accumulation_steps=self.grad_accum,
            bf16=use_bf16,
            seed=42,
            report_to="none",
            dataloader_num_workers=4,
            # By default Trainer strips batch keys the model's forward() does
            # not accept, which would remove "idx" and silently disable the
            # per-example weights. Keep all keys on the weighted path
            # (WeightedTrainer pops "idx" itself); keep the default stripping
            # on the unweighted path so "idx" never reaches the model.
            remove_unused_columns=weights is None,
        )

        # Create trainer
        if weights is not None:
            self.trainer = WeightedTrainer(
                model=self.model,
                args=args,
                train_dataset=train_dataset,
                sample_weights=weights,
            )
        else:
            self.trainer = Trainer(
                model=self.model,
                args=args,
                train_dataset=train_dataset,
            )

        # Train
        self.trainer.train()
        self.is_fitted_ = True

        return self

    def _predict_logits(self, X: List[str]) -> np.ndarray:
        """Run the fitted trainer over ``X`` and return the raw logits."""
        texts = list(X)
        dataset = TextDataset(
            texts,
            np.zeros(len(texts), dtype=np.int64),  # dummy labels
            self.tokenizer,
            self.max_len,
        )

        # The training weights are indexed by training-set position; suspend
        # them while predicting so positions in `X` are never used to look
        # them up.
        trainer = self.trainer
        saved_weights = getattr(trainer, "sample_weights", None)
        if saved_weights is not None:
            trainer.sample_weights = None
        try:
            pred_output = trainer.predict(dataset)
        finally:
            if saved_weights is not None:
                trainer.sample_weights = saved_weights
        return pred_output.predictions

    def predict(self, X: List[str]) -> np.ndarray:
        """Predict class labels (argmax over the logits) for each text."""
        self._check_fitted()
        logits = self._predict_logits(X)
        return np.argmax(logits, axis=-1)

    def predict_proba(self, X: List[str]) -> np.ndarray:
        """Predict class probabilities (softmax over the logits) for each text."""
        self._check_fitted()
        logits = self._predict_logits(X)

        # Softmax, shifted by the row maximum for numerical stability.
        exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
        probs = exp_logits / exp_logits.sum(axis=-1, keepdims=True)
        return probs

    def __repr__(self) -> str:
        return (
            f"RobertaFinetuneModel(hf_model={self.hf_model!r}, "
            f"epochs={self.epochs}, batch_size={self.batch_size})"
        )
