import torch.nn.functional as F
import numpy as np
import random
import os
import json
import torch
import torch.nn as nn
import numpy as np
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModel,
    AutoConfig,
    AutoTokenizer,
    PreTrainedModel,
    get_linear_schedule_with_warmup
)
from torch.optim import AdamW
from sklearn.metrics import precision_recall_fscore_support, classification_report
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm
from collections import defaultdict
from torch.amp import GradScaler, autocast  
# Module-level side effect: sets PyTorch's float32 matmul precision mode to
# 'high' as soon as this file is imported or run.
torch.set_float32_matmul_precision('high')

###############################################################################
# 1. Data Handling
###############################################################################

class PromptWordDataset(Dataset):
    """
    Dataset pairing each prompt with its prompt-level label and its
    word-level labels, aligned to the tokenizer's sub-tokens.

    Every item is tokenized once, at construction time. Each stored example
    (and each value returned by ``__getitem__``) is a dict with:
        - input_ids      : token ids, padded/truncated to ``max_length``
        - attention_mask : attention mask matching ``input_ids``
        - prompt_label   : the item's prompt-level label
        - token_labels   : long tensor with one label per sub-token
                           (-100 where the tokenizer reports no source word)
        - penalty        : float tensor, prompt-level penalty for the item
        - first_mask     : bool tensor, True where a sub-token belongs to a
                           different word than the previous word sub-token
        - original_data  : the untouched input item (for analysis/predictions)
    """

    def __init__(self, data, tokenizer, max_length=512, prompt_penalty_dict=None):
        """
        data: list of items with the following structure:
            {
               "id": str,
               "prompt": str,
               "prompt_label": int (0 or 1),
               "word_label": [ [word1, 0/1], [word2, 0/1], ... ]
            }
        tokenizer: tokenizer that accepts pre-split words
            (``is_split_into_words=True``) and whose output exposes
            ``word_ids()``
        max_length: maximum sequence length (inputs are padded to it)
        prompt_penalty_dict: optional {item id: penalty}; items missing from
            it (or every item, when it is None/empty) get a penalty of 0.0
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.examples = [
            self._build_example(item, prompt_penalty_dict) for item in self.data
        ]

    def _build_example(self, item, prompt_penalty_dict):
        """Tokenize one item and align its word labels to sub-tokens."""
        word_label_pairs = item["word_label"]  # list of [word, label]
        words = [pair[0] for pair in word_label_pairs]
        word_level_labels = [pair[1] for pair in word_label_pairs]

        prompt_penalty = 0.0
        if prompt_penalty_dict:
            prompt_penalty = prompt_penalty_dict.get(item["id"], 0.0)

        # Tokenize pre-split words so every sub-token can be traced back to
        # the word (and therefore the label) it came from.
        encoding = self.tokenizer(
            words,
            is_split_into_words=True,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        token_labels = []
        first_mask = []
        prev_wid = None
        for w_id in encoding.word_ids(batch_index=0):
            if w_id is None:
                # No source word for this position: ignored by the loss.
                token_labels.append(-100)
                first_mask.append(False)
            else:
                token_labels.append(word_level_labels[w_id])
                first_mask.append(w_id != prev_wid)
                prev_wid = w_id

        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "prompt_label": item["prompt_label"],
            # Built once here instead of on every __getitem__ call.
            "token_labels": torch.tensor(token_labels, dtype=torch.long),
            "penalty": torch.tensor(prompt_penalty, dtype=torch.float),
            "first_mask": torch.tensor(first_mask, dtype=torch.bool),
            "original_data": item  # for analysis
        }

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        # Shallow copy: callers get their own dict, tensors are shared.
        return dict(self.examples[idx])


def collate_fn(batch):
    """
    Merge a list of dataset items into one batch dict for the DataLoader.

    Tensor fields are stacked along a new leading dimension, the integer
    prompt labels become a long tensor, and the original items are passed
    through as a plain list.
    """
    def _stacked(key):
        return torch.stack([sample[key] for sample in batch])

    return {
        "input_ids": _stacked("input_ids"),
        "attention_mask": _stacked("attention_mask"),
        "prompt_labels": torch.tensor(
            [sample["prompt_label"] for sample in batch], dtype=torch.long
        ),
        "token_labels": _stacked("token_labels"),
        "penalties": _stacked("penalty"),
        "first_masks": _stacked("first_mask"),
        "original_data": [sample["original_data"] for sample in batch],
    }


def compute_subtoken_stats(train_data, tokenizer):
    """
    Count, for every sub-token, how often it appears inside a safe word and
    how often inside an unsafe word of the training data.

    Each word is tokenized on its own with ``tokenizer.tokenize``. A word
    label of 0 counts as safe; any other label counts as unsafe.

    Returns a defaultdict {subtoken: [safe_count, unsafe_count]}.
    """
    counts = defaultdict(lambda: [0, 0])

    for example in train_data:
        for word, label in example["word_label"]:
            column = 0 if label == 0 else 1
            for piece in tokenizer.tokenize(word):
                counts[piece][column] += 1

    return counts

def compute_subtoken_penalties(subtoken_stats):
    """
    Turn per-subtoken counts into (penalty magnitude, dominant label) pairs.

    Returns: {subtoken: (penalty_mag, dominant_label)}
        penalty_mag    = |safe_cnt - unsafe_cnt| / (safe_cnt + unsafe_cnt + eps),
                         so it lies in [0, 1)
        dominant_label = 0 when safe_cnt >= unsafe_cnt, otherwise 1
    """
    eps = 1e-6  # keeps the division defined when both counts are zero
    return {
        piece: (
            abs(safe - unsafe) / (safe + unsafe + eps),
            0 if safe >= unsafe else 1,
        )
        for piece, (safe, unsafe) in subtoken_stats.items()
    }


def compute_prompt_penalties(subtoken_stats, data, tokenizer):
    """
    Computes a prompt-level penalty for each prompt.

    Only words whose label equals the prompt label contribute. For every
    sub-token of such a word, |safe_count - unsafe_count| is added to the
    numerator and (safe_count + unsafe_count) to the denominator; the
    prompt's score is twice the resulting ratio.

    Args:
        subtoken_stats: dict of {subtoken: (safe_count, unsafe_count)}
        data: list of dicts, each with 'id', 'prompt_label', and 'word_label'
        tokenizer: tokenizer for subword tokenization

    Returns:
        dict of {prompt_id: penalty_score}. The ratio is below 1, so the
        score lies in [0, 2); it is 0.0 when no sub-token contributes.
    """
    prompt_penalties = {}
    eps = 1e-6  # avoids division by zero when nothing contributes

    for item in data:
        prompt_id = item["id"]
        prompt_label = item["prompt_label"]

        numerator = 0.0
        denominator = 0.0

        for word, token_label in item["word_label"]:
            # Words that disagree with the prompt label never contribute, so
            # skip them before paying for tokenization.
            if token_label != prompt_label:
                continue
            for st in tokenizer.tokenize(word):
                safe_cnt, unsafe_cnt = subtoken_stats.get(st, (0, 0))
                total = safe_cnt + unsafe_cnt
                if total == 0:
                    continue
                numerator += abs(safe_cnt - unsafe_cnt)
                denominator += total

        penalty = numerator / (denominator + eps)
        prompt_penalties[prompt_id] = 2.0 * penalty

    return prompt_penalties


                


###############################################################################
# 2. Model Definition
###############################################################################
class AttentionPooling(nn.Module):
    """
    Pools a sequence of embeddings into one vector using learned attention.

    A single linear layer scores every position; the scores are softmaxed
    over the sequence and used as weights for a weighted sum of embeddings.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size, 1)

    def forward(self, embeddings, mask=None):
        """
        embeddings: Tensor (batch_size, seq_len, hidden_size)
        mask: Tensor (batch_size, seq_len) with 1 for valid tokens, 0 for
            padding, or None to attend over every position

        Returns a Tensor (batch_size, hidden_size).
        """
        attn_scores = self.attn(embeddings).squeeze(-1)  # (batch_size, seq_len)
        if mask is not None:
            # Use the most negative finite value of the score dtype instead
            # of -inf so the fill stays representable in reduced precision.
            neg_inf = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, neg_inf)
        attn_weights = torch.softmax(attn_scores, dim=-1)
        pooled_output = torch.sum(embeddings * attn_weights.unsqueeze(-1), dim=1)
        return pooled_output


class JointModel(PreTrainedModel):
    """
    Shared encoder with two classification heads:

    - a prompt-level head fed by attention pooling over the token states
    - a token-level head applied to every token state

    When labels are supplied, ``forward`` also returns a joint loss. Each
    task's loss is cross-entropy plus a focal-style term (gamma = 2) scaled
    by a penalty: a per-prompt penalty passed in by the caller for the
    prompt head, and a per-token penalty looked up from ``subtoken_penalties``
    for the token head. The ``ablation`` dict switches the pieces on and off:

        "f"   : multiplier on both focal terms
        "d_p" : multiplier on the prompt penalty term
        "d_w" : multiplier on the token penalty term
        "u"   : truthy -> combine the two losses with learned uncertainty
                weights; falsy -> plain sum
    """

    def __init__(
        self,
        model_name,
        config,
        prompt_class_weights=None,
        word_class_weights=None,
        alpha=0.01,
        subtoken_penalties=None,
        tokenizer=None,
        ablation=None
    ):
        """
        Args
        ----
        model_name           : name/path handed to ``AutoModel.from_pretrained``
        config               : encoder config (``hidden_size``, ``num_labels``)
        prompt_class_weights : class weights for the prompt cross-entropy;
                               defaults to [1.0, 1.0]
        word_class_weights   : stored on the model (defaults to [1.0, 1.0]);
                               the token cross-entropy in ``forward`` does not
                               apply it
        alpha                : stored on the model; not used by the loss
        subtoken_penalties   : {token string: (magnitude, dominant label)};
                               requires ``tokenizer``
        tokenizer            : tokenizer used to map token strings to ids
        ablation             : dict with keys "d_p", "d_w", "f", "u"; when
                               None every component is enabled

        Raises
        ------
        ValueError if ``subtoken_penalties`` is given without a tokenizer.
        """
        super().__init__(config)
        if subtoken_penalties and tokenizer is None:
            raise ValueError(
                "a tokenizer is required to map subtoken_penalties to token ids"
            )
        self.num_labels = config.num_labels

        self.bert = AutoModel.from_pretrained(model_name, config=config)
        self.tokenizer = tokenizer

        # Attention Pooling
        self.attention_pooling = AttentionPooling(config.hidden_size)

        # Classification Heads
        self.prompt_classifier = nn.Linear(config.hidden_size, 2)
        self.token_classifier = nn.Linear(config.hidden_size, 2)

        # Uncertainty weighting params: we learn log(sigma) for each task,
        # sigma = exp(log_sigma).
        self.log_sigma_prompt = nn.Parameter(torch.zeros(()))
        self.log_sigma_token = nn.Parameter(torch.zeros(()))

        self.alpha = alpha

        # Loss weights
        self.prompt_class_weights = (
            prompt_class_weights if prompt_class_weights is not None
            else torch.tensor([1.0, 1.0])
        )
        self.word_class_weights = (
            word_class_weights if word_class_weights is not None
            else torch.tensor([1.0, 1.0])
        )

        # Per-token-id lookup tables for the sub-token penalties. They are
        # indexed with input_ids in forward(), so they must cover every id
        # the tokenizer can emit; len(tokenizer) also counts added tokens,
        # which vocab_size alone does not.
        if tokenizer is not None:
            table_size = max(tokenizer.vocab_size, len(tokenizer))
        else:
            table_size = config.vocab_size
        mag_table = torch.zeros(table_size)                  # (V,)
        dom_table = torch.full((table_size,), -1, dtype=torch.long)

        if subtoken_penalties:
            unk_id = tokenizer.unk_token_id
            for tok_str, (mag, dom) in subtoken_penalties.items():
                tok_id = tokenizer.convert_tokens_to_ids(tok_str)
                # Strings the tokenizer cannot resolve are left at their
                # defaults (magnitude 0, dominant label -1).
                if tok_id is None or tok_id == unk_id:
                    continue
                mag_table[tok_id] = mag
                dom_table[tok_id] = dom

        # register as buffers so they move with .to(device) and are not updated
        self.register_buffer("penalty_mag_table", mag_table, persistent=False)
        self.register_buffer("penalty_dom_table", dom_table, persistent=False)

        # Ablation multipliers; default to the full objective.
        self.ablation = (
            ablation if ablation is not None
            else {"d_p": 1, "d_w": 1, "f": 1, "u": 1}
        )
        # NOTE(review): this runs after the pretrained encoder was loaded;
        # confirm the installed transformers version does not re-initialize
        # the already-loaded encoder weights here.
        self.init_weights()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        prompt_labels=None,
        token_labels=None,
        prompt_penalties=None
    ):
        """
        Args
        ----
        input_ids        : (B, S)
        attention_mask   : (B, S) or None (every position is attended)
        prompt_labels    : (B,)     gold prompt class (0/1)      or None
        token_labels     : (B, S)   gold token labels (0/1/-100) or None
        prompt_penalties : (B,)     per-prompt penalty, or None for no penalty

        Returns
        -------
        dict with keys {"loss", "prompt_logits", "token_logits"}; "loss" is
        None unless ``prompt_labels`` is given.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # ------------------------------------------------------------
        # 1) Encoder
        # ------------------------------------------------------------
        enc_out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        hidden = enc_out.last_hidden_state            # (B, S, H)

        # ------------------------------------------------------------
        # 2) Two heads
        # ------------------------------------------------------------
        pooled = self.attention_pooling(hidden, attention_mask)
        prompt_logits = self.prompt_classifier(pooled)     # (B, 2)
        token_logits = self.token_classifier(hidden)       # (B, S, 2)

        # ------------------------------------------------------------
        # 3) Loss (only when prompt labels are available)
        # ------------------------------------------------------------
        loss = None
        if prompt_labels is not None:
            ce_prompt = F.cross_entropy(
                prompt_logits,
                prompt_labels,
                weight=self.prompt_class_weights.to(prompt_logits.device),
                reduction="none"
            )
            pt_prompt = torch.exp(-ce_prompt)
            fl_prompt = (1.0 - pt_prompt).pow(2.0) * ce_prompt
            if prompt_penalties is None:
                # No penalty supplied: the focal term drops out.
                prompt_penalties = torch.zeros_like(ce_prompt)
            loss_prompt = ce_prompt + (fl_prompt * prompt_penalties) * self.ablation["f"] * self.ablation["d_p"]
            loss_prompt = loss_prompt.mean()

            loss_token = torch.tensor(0.0, device=input_ids.device)

            if token_labels is not None:
                B, S = input_ids.size()
                valid = token_labels != -100

                per_tok_ce = F.cross_entropy(
                    token_logits.view(-1, 2),
                    token_labels.view(-1),
                    ignore_index=-100,
                    reduction="none"
                ).view(B, S)

                gamma = 2.0
                pt = torch.exp(-per_tok_ce)
                per_tok_fl = (1.0 - pt).pow(gamma) * per_tok_ce

                # A token is penalized only when its gold label matches the
                # dominant label recorded for its sub-token.
                with torch.no_grad():
                    mags = self.penalty_mag_table[input_ids]
                    doms = self.penalty_dom_table[input_ids]
                    same_cls = (doms == token_labels)
                    token_penalties = mags * same_cls.float() * valid.float()

                combined = per_tok_ce + per_tok_fl * token_penalties * self.ablation['f'] * self.ablation['d_w']
                # Summed over valid tokens and divided by the batch size
                # (not by the number of tokens).
                loss_token = combined[valid].sum() / input_ids.size(0)

            if self.ablation['u']:
                # uncertainty-weighted joint loss
                sigma_p = torch.exp(self.log_sigma_prompt)
                sigma_t = torch.exp(self.log_sigma_token)
                loss_p = 0.5 * loss_prompt / (sigma_p**2) + self.log_sigma_prompt
                loss_t = 0.5 * loss_token / (sigma_t**2) + self.log_sigma_token
                loss = loss_p + loss_t
            else:
                loss = loss_prompt + loss_token

        return {
            "loss": loss,
            "prompt_logits": prompt_logits,
            "token_logits": token_logits
        }
###############################################################################
# 3. Metrics and Utility
###############################################################################

def compute_metrics(preds, labels, average='micro'):
    """
    Compute precision, recall and F1 for ``preds`` against ``labels``.

    Returns a dict whose keys are prefixed with the averaging mode, e.g.
    ``micro_precision``, ``micro_recall`` and ``micro_f1``.
    """
    precision, recall, f1, _support = precision_recall_fscore_support(
        labels, preds, average=average, zero_division=0
    )
    metric_names = ("precision", "recall", "f1")
    metric_values = (precision, recall, f1)
    return {
        f"{average}_{name}": value
        for name, value in zip(metric_names, metric_values)
    }

def print_classwise_report(labels, preds):
    """Print a per-class classification report for the safe/unsafe labels."""
    report = classification_report(
        labels,
        preds,
        target_names=["safe (0)", "unsafe (1)"],
        zero_division=0,
    )
    print(report)

def label_to_str(label_int):
    """Map a label to its name: 0 is "safe", any other value is "unsafe"."""
    if label_int == 0:
        return "safe"
    return "unsafe"

###############################################################################
# 4. Train, Evaluate, and Prediction Functions
###############################################################################

def train_one_epoch(model, dataloader, optimizer, scheduler, scaler, device):
    """
    Run one training pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: called with input_ids, attention_mask, prompt_labels,
            token_labels and prompt_penalties; must return a dict whose
            "loss" entry is a scalar tensor
        dataloader: iterable of batch dicts with the keys "input_ids",
            "attention_mask", "prompt_labels", "token_labels", "penalties"
        optimizer: optimizer stepped through ``scaler``
        scheduler: learning-rate scheduler, stepped once per batch
        scaler: gradient scaler providing ``scale``, ``step`` and ``update``
        device: device the batch tensors are moved to

    Returns:
        Sum of the batch losses divided by the number of batches processed,
        or 0.0 when the dataloader yields no batches.
    """
    model.train()
    # Mixed precision is only enabled when the batches actually live on a
    # CUDA device; elsewhere the forward pass runs in full precision.
    use_amp = torch.device(device).type == "cuda"

    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc="Training", leave=False):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        prompt_labels = batch["prompt_labels"].to(device)
        token_labels = batch["token_labels"].to(device)
        prompt_penalties = batch["penalties"].to(device)

        optimizer.zero_grad()

        with autocast(device_type="cuda", enabled=use_amp):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                prompt_labels=prompt_labels,
                token_labels=token_labels,
                prompt_penalties=prompt_penalties
            )
            loss = outputs["loss"]

        scaler.scale(loss).backward()    # backward on the scaled loss
        scaler.step(optimizer)           # optimizer step through the scaler
        scaler.update()                  # adjust the scale for the next step

        scheduler.step()                 # per-batch LR schedule

        total_loss += loss.item()
        num_batches += 1

    # Counting batches avoids both a division by zero on an empty loader and
    # the need for the loader to implement __len__.
    return total_loss / num_batches if num_batches else 0.0

@torch.inference_mode()  # no autograd tracking; eval mode is set explicitly below
def evaluate(
        model: nn.Module,
        dataloader: torch.utils.data.DataLoader,
        device: torch.device,
        prompt_thr: float = 0.6,
        token_thr : float = 0.4
):
    """
    Run the model over ``dataloader`` and collect predictions and gold labels.

    A prompt (or token) is predicted unsafe only when its unsafe probability
    is both larger than its safe probability and larger than the threshold.
    With exactly two classes the first condition already means
    "unsafe > 0.5", so thresholds at or below 0.5 do not change the result.

    Returns:
        ((prompt_pred, prompt_gold), (token_pred, token_gold))
        where each entry is a flat Python list of ints.
        Token lists contain **only first sub-tokens** whose gold label is
        not -100. All four lists are empty when the dataloader yields no
        batches.
    """
    model.eval()
    # Mixed precision only when the tensors actually live on a CUDA device.
    amp_enabled = torch.device(device).type == "cuda"

    prompt_preds, prompt_golds = [], []
    token_preds, token_golds = [], []

    with autocast(device_type='cuda', enabled=amp_enabled):
        for batch in dataloader:
            # ------------------------------------------------------------------
            # 1) Move tensors to the target device
            # ------------------------------------------------------------------
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            p_gold = batch["prompt_labels"].to(device)
            t_gold = batch["token_labels"].to(device)
            first_mask = batch["first_masks"].to(device)          # (B, S) bool

            # ------------------------------------------------------------------
            # 2) Forward pass
            # ------------------------------------------------------------------
            out = model(input_ids=input_ids, attention_mask=attention_mask)
            p_logits = out["prompt_logits"]            # (B, 2)
            t_logits = out["token_logits"]             # (B, S, 2)

            # ------------------------------------------------------------------
            # 3) Prompt-level predictions (vectorised)
            # ------------------------------------------------------------------
            p_probs = p_logits.softmax(dim=1)          # (B, 2)
            unsafe_p = p_probs[:, 1]
            safe_p = p_probs[:, 0]
            p_pred_b = ((unsafe_p > safe_p) & (unsafe_p > prompt_thr)).int()

            prompt_preds.append(p_pred_b.cpu())
            prompt_golds.append(p_gold.cpu())

            # ------------------------------------------------------------------
            # 4) Token-level predictions - first sub-tokens only
            # ------------------------------------------------------------------
            t_probs = t_logits.softmax(dim=2)          # (B, S, 2)
            unsafe_t = t_probs[:, :, 1]
            safe_t = t_probs[:, :, 0]

            valid_mask = (t_gold != -100)              # exclude padding / ignored
            first_valid = first_mask & valid_mask      # keep only first sub-tokens

            t_pred_b = ((unsafe_t > safe_t) & (unsafe_t > token_thr)).int()

            token_preds.append(t_pred_b[first_valid].cpu())
            token_golds.append(t_gold[first_valid].cpu())

    # torch.cat rejects an empty list, so an empty loader is handled here.
    if not prompt_preds:
        return ([], []), ([], [])

    # --------------------------------------------------------------------------
    # 5) Flatten once at the end
    # --------------------------------------------------------------------------
    prompt_preds = torch.cat(prompt_preds).tolist()
    prompt_golds = torch.cat(prompt_golds).tolist()
    token_preds = torch.cat(token_preds).tolist()
    token_golds = torch.cat(token_golds).tolist()

    return (prompt_preds, prompt_golds), (token_preds, token_golds)

def predict(model, dataloader, tokenizer, max_length, device, prompt_threshold=0.0, word_threshold=0.0):
    """
    Produce prompt-level and word-level predictions for every item.

    A prompt (or word) is predicted unsafe (1) only when its unsafe
    probability is larger than both its safe probability and the matching
    threshold. Each word takes the prediction of its first sub-token; words
    with no sub-token in the encoded sequence default to safe (0).

    Args:
        model: returns a dict with "prompt_logits" and "token_logits"
        dataloader: yields batch dicts with "input_ids", "attention_mask"
            and "original_data" (the raw items)
        tokenizer: used to recover the sub-token -> word alignment; it is
            called with the same arguments the dataset uses
        max_length: sequence length used when re-tokenizing
        device: device the input tensors are moved to
        prompt_threshold: minimum unsafe probability for a prompt
        word_threshold: minimum unsafe probability for a word

    Returns:
        list of dicts with the keys "prompt", "original_prompt_label",
        "predicted_prompt_label", "original_word_labels" and
        "predicted_word_labels".
    """
    model.eval()
    results = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Predicting", leave=False):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            prompt_probs = F.softmax(outputs["prompt_logits"], dim=1).cpu().numpy()
            token_probs = F.softmax(outputs["token_logits"], dim=2).cpu().numpy()

            for i, original_item in enumerate(batch["original_data"]):
                words_and_labels = original_item["word_label"]

                # Prompt-level classification
                safe_prob = prompt_probs[i, 0]
                unsafe_prob = prompt_probs[i, 1]
                predicted_prompt_label = 1 if (unsafe_prob > safe_prob and unsafe_prob > prompt_threshold) else 0

                # The batch does not carry word ids, so the words are
                # re-tokenized to recover the sub-token -> word alignment.
                words = [pair[0] for pair in words_and_labels]
                encoding = tokenizer(
                    words,
                    is_split_into_words=True,
                    max_length=max_length,
                    truncation=True,
                    padding="max_length",
                    return_tensors="pt"
                )
                word_ids = encoding.word_ids(batch_index=0)

                token_safe = token_probs[i, :, 0]
                token_unsafe = token_probs[i, :, 1]

                # Only the first sub-token of each word decides its label.
                word_id_to_label = {}
                for j, w_id in enumerate(word_ids):
                    if w_id is None or j >= len(token_unsafe) or w_id in word_id_to_label:
                        continue
                    u_prob = token_unsafe[j]
                    s_prob = token_safe[j]
                    word_id_to_label[w_id] = 1 if (u_prob > s_prob and u_prob > word_threshold) else 0

                pred_word_labels = [
                    [pair[0], word_id_to_label.get(widx, 0)]  # default safe
                    for widx, pair in enumerate(words_and_labels)
                ]

                results.append({
                    "prompt": original_item["prompt"],
                    "original_prompt_label": label_to_str(original_item["prompt_label"]),
                    "predicted_prompt_label": label_to_str(predicted_prompt_label),
                    "original_word_labels": words_and_labels,
                    "predicted_word_labels": pred_word_labels
                })

    return results
def find_best_thresholds(model, dataloader, device, verbose=True):
    """
    Grid-search the unsafe-probability thresholds that maximise the F1 score
    of the unsafe class (label 1), separately for prompts and for words.

    Candidate thresholds run from 0.00 to 1.00 in steps of 0.01. A sample is
    predicted unsafe only when its unsafe probability exceeds both its safe
    probability and the candidate threshold; with exactly two classes the
    first condition already means "unsafe > 0.5", so candidates at or below
    0.5 all give the same predictions. On ties the lowest threshold wins.

    Word-level scores use only first sub-tokens whose gold label is not -100.

    Args:
        model: returns a dict with "prompt_logits" and "token_logits"
        dataloader: yields batch dicts with "input_ids", "attention_mask",
            "prompt_labels", "token_labels" and "first_masks"
        device: device the batch tensors are moved to
        verbose: print the selected thresholds and their F1 scores

    Returns:
        (best_prompt_threshold, best_word_threshold)
    """
    @torch.inference_mode()
    def collect_probs_and_labels():
        """
        Returns a dictionary with:
            - prompt_probs : (N, 2) tensor
            - prompt_labels: (N,) tensor
            - token_probs  : (K, 2) tensor, scored first sub-tokens only
            - token_labels : (K,) tensor
        """
        model.eval()
        # Mixed precision only when the tensors actually live on CUDA.
        amp_enabled = torch.device(device).type == "cuda"

        prompt_probs_list, prompt_labels_list = [], []
        token_probs_list, token_labels_list = [], []

        with autocast(device_type='cuda', enabled=amp_enabled):
            for batch in tqdm(dataloader, desc="Getting logits", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                p_labels = batch["prompt_labels"].to(device)
                t_labels = batch["token_labels"].to(device)
                first_mask = batch["first_masks"].to(device)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask)

                p_probs = F.softmax(outputs["prompt_logits"], dim=1)   # (B, 2)
                t_probs = F.softmax(outputs["token_logits"], dim=2)    # (B, S, 2)

                prompt_probs_list.append(p_probs.cpu())
                prompt_labels_list.append(p_labels.cpu())

                # Keep only the positions that are scored, instead of
                # holding the full (B, S, 2) tensor of every batch on CPU.
                keep = first_mask & (t_labels != -100)
                token_probs_list.append(t_probs[keep].cpu())
                token_labels_list.append(t_labels[keep].cpu())

        return {
            "prompt_probs": torch.cat(prompt_probs_list, dim=0),
            "prompt_labels": torch.cat(prompt_labels_list, dim=0),
            "token_probs": torch.cat(token_probs_list, dim=0),
            "token_labels": torch.cat(token_labels_list, dim=0),
        }

    thresholds = np.arange(0.0, 1.01, 0.01)

    def search(probs, labels):
        """Return (best threshold, best unsafe-class F1) for one task."""
        unsafe = probs[:, 1]
        unsafe_wins = unsafe > probs[:, 0]   # independent of the threshold

        best_th = 0.0
        best_f1 = 0.0
        for th in thresholds:
            preds = (unsafe_wins & (unsafe > th)).int()
            report = classification_report(labels, preds, output_dict=True, zero_division=0)
            # The "1" entry is missing when class 1 occurs in neither the
            # labels nor the predictions; treat that as F1 = 0.
            f1 = report.get("1", {}).get("f1-score", 0.0)
            if f1 > best_f1:
                best_f1 = f1
                best_th = th
        return best_th, best_f1

    data = collect_probs_and_labels()

    best_prompt_th, best_prompt_f1 = search(data["prompt_probs"], data["prompt_labels"])
    best_word_th, best_word_f1 = search(data["token_probs"], data["token_labels"])

    if verbose:
        print(f"\n✅ Best Prompt Threshold: {best_prompt_th:.2f} (F1 = {best_prompt_f1:.4f})")
        print(f"✅ Best Token Threshold : {best_word_th:.2f} (F1 = {best_word_f1:.4f})")

    return best_prompt_th, best_word_th

###############################################################################
# 5. Main Pipeline
###############################################################################

def main():
    """
    Train and save one JointModel per (ablation, dataset) combination.

    For each combination the output directory is derived from the model
    name, the dataset name and the enabled ablation switches. When that
    directory already holds ``pytorch_model.bin`` the combination is
    skipped; otherwise a model is trained for ``num_epochs`` epochs on the
    train split (plus the dev split, except for the "aegis" dataset) and
    saved together with its tokenizer.
    """
    SEED = 42          # pick any integer
    random.seed(SEED)              # 1. Python's builtin RNG
    np.random.seed(SEED)           # 2. NumPy RNG
    torch.manual_seed(SEED)        # 3. PyTorch CPU RNG
    torch.cuda.manual_seed_all(SEED)  # 4. PyTorch CUDA RNG (all GPUs)
    g = torch.Generator()          # drives the shuffling of the train loader
    g.manual_seed(SEED)

    max_length = 512
    train_batch_size = 16
    num_epochs = 3
    lr = 2e-5

    #model_name = "microsoft/deberta-v3-base"
    #model_name = "microsoft/deberta-v3-large"
    model_name = "microsoft/deberta-v3-xsmall"

    model_root_folder = "model_seed_"+str(SEED)
    ablations = [
        {"d_p": 1, "d_w": 1, "f": 1, "u": 1}, # all
        #{"d_p": 1, "d_w": 1, "f": 1, "u": 0}, # No uncertainty
        #{"d_p": 1, "d_w": 0, "f": 1, "u": 1}, # No d_w
        #{"d_p": 1, "d_w": 0, "f": 1, "u": 0}, # No d_w, No u #
        #{"d_p": 0, "d_w": 1, "f": 1, "u": 1}, # No_p
        #{"d_p": 0, "d_w": 1, "f": 1, "u": 0}, # No_p, No u #
        #{"d_p": 1, "d_w": 1, "f": 0, "u": 1}, # Only CE+u
        #{"d_p": 1, "d_w": 1, "f": 0, "u": 0} # Only CE
    ]

    dataset_config = [
        ('toxic_chat_with_word_label_intersection_all_data.json', "toxicchat")
    ]

    for ab in ablations:
        # The tag becomes part of the output directory name, so its exact
        # spelling (including a trailing "+") must stay stable to keep
        # previously saved models discoverable.
        train_config = ""
        if ab["d_p"]:
            train_config+="delta_p+"
        if ab["d_w"]:
            train_config+="delta_w+"
        if ab["f"]:
            train_config+="focal+"
        if ab["u"]:
            train_config+="uncertainty"

        for dataset_path, dataset_name in dataset_config:
            output_dir = os.path.join(model_root_folder, model_name.split("/")[1]+"_"+dataset_name+"_"+train_config)
            print(f"\n\n=== Running ablation: d_p={ab['d_p']}, d_w={ab['d_w']}, f={ab['f']}, u={ab['u']} ===")
            print(output_dir)
            model_path = os.path.join(output_dir, "pytorch_model.bin")
            print(train_config)

            # Nothing to do for an already-trained combination; checking
            # first avoids loading the dataset and tokenizer for no reason.
            if os.path.exists(model_path):
                print("🚀 Model already exists!")
                continue

            with open(dataset_path, "r", encoding="utf-8") as f:
                dataset = json.load(f)

            train_data = dataset["train"]
            dev_data = dataset.get("dev", [])
            if dataset_name!="aegis":
                train_data = train_data + dev_data

            # ---------------------------------------------------------------------
            # Prepare tokenizer and device
            # ---------------------------------------------------------------------
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            on_cuda = device.type == "cuda"

            # ---------------------------------------------------------------------
            # Compute class weights
            # ---------------------------------------------------------------------
            train_prompt_labels = [item["prompt_label"] for item in train_data]
            prompt_class_weights = compute_class_weight(class_weight="balanced", classes=np.array([0, 1]), y=train_prompt_labels)
            prompt_class_weights = torch.tensor(prompt_class_weights, dtype=torch.float).to(device)

            all_word_labels = []
            for item in train_data:
                all_word_labels.extend([wl[1] for wl in item["word_label"] if wl[1] in [0, 1]])
            word_class_weights = compute_class_weight(class_weight="balanced", classes=np.array([0, 1]), y=all_word_labels)
            word_class_weights = torch.tensor(word_class_weights, dtype=torch.float).to(device)

            # ---------------------------------------------------------------------
            # Setup model
            # ---------------------------------------------------------------------
            print("🔄 No saved model found. Training new model...")

            config = AutoConfig.from_pretrained(model_name)
            config.num_labels = 2
            # NOTE(review): confirm that setting this attribute on the config
            # actually selects the attention backend for this encoder.
            config.attn_implementation = "flash_attention_2"

            subtoken_stats = compute_subtoken_stats(train_data, tokenizer)
            subtoken_penalties = compute_subtoken_penalties(subtoken_stats)

            model = JointModel(
                model_name=model_name,
                config=config,
                prompt_class_weights=prompt_class_weights,
                word_class_weights=word_class_weights,
                alpha=0.01,
                subtoken_penalties=subtoken_penalties,
                tokenizer=tokenizer,
                ablation = ab
            )
            model.to(device)

            # Loss scaling and the fused optimizer kernel are CUDA features;
            # switch both off when training falls back to the CPU.
            scaler = GradScaler(enabled=on_cuda)
            optimizer = AdamW(
                model.parameters(),
                lr=lr,
                fused=on_cuda
            )

            prompt_penalty_dict = compute_prompt_penalties(subtoken_stats, train_data, tokenizer)
            train_dataset = PromptWordDataset(train_data, tokenizer, max_length=max_length, prompt_penalty_dict=prompt_penalty_dict)
            train_dataloader = DataLoader(
                train_dataset,
                batch_size=train_batch_size,
                shuffle=True,
                collate_fn=collate_fn,
                generator=g
            )

            total_steps = len(train_dataloader) * num_epochs
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=int(0.1 * total_steps),
                num_training_steps=total_steps
            )

            # Training
            for epoch in range(num_epochs):
                print(f"\n======== Epoch {epoch+1}/{num_epochs} ========")
                avg_train_loss = train_one_epoch(model, train_dataloader, optimizer, scheduler, scaler, device)
                print(f"Average train loss: {avg_train_loss:.4f}")
                print(f"σ_prompt={model.log_sigma_prompt.exp().item():.4f}, σ_token={model.log_sigma_token.exp().item():.4f}")

            # Save model and tokenizer
            print("💾 Saving model...")
            os.makedirs(output_dir, exist_ok=True)

            model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
            # Written explicitly because the skip check above looks for
            # exactly this file.
            torch.save(model.state_dict(), model_path)
            print("✅ Model saved successfully!")


# Run the training pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()

