import torch.nn.functional as F
import numpy as np
import pandas as pd
import os
import json
import torch
import torch.nn as nn
import numpy as np
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModel,
    AutoConfig,
    AutoTokenizer,
    PreTrainedModel,
    get_linear_schedule_with_warmup
)
from time import perf_counter
from torch.optim import AdamW
from sklearn.metrics import precision_recall_fscore_support, classification_report
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm
from collections import defaultdict
from torch.amp import GradScaler, autocast  
# Let float32 matmuls use faster reduced-precision internal math (e.g. TF32)
# where the hardware supports it; "high" trades a little precision for speed.
torch.set_float32_matmul_precision('high')

###############################################################################
# 1. Data Handling
###############################################################################

class PromptWordDataset(Dataset):
    """Dataset pairing each prompt with word-level safe/unsafe labels.

    Every item is tokenized from its pre-split words (``is_split_into_words=True``),
    truncated and padded to ``max_length``. Each sub-token inherits the label of
    the word it belongs to; positions without a word id (special tokens and
    padding) get -100 so loss functions can ignore them.

    ``self.examples[i]`` is a dict with:
        input_ids      : (max_length,) tensor from the tokenizer
        attention_mask : (max_length,) tensor from the tokenizer
        prompt_label   : the item's ``prompt_label``, unchanged
        token_labels   : list[int], one per position (-100 where word id is None)
        first_mask     : (max_length,) bool tensor, True on each word's first sub-token
        original_data  : the raw input item (kept for later analysis)
    """

    def __init__(self, data, tokenizer, max_length=512):
        """
        data: list of dicts shaped like
            {"id": str, "prompt": str, "prompt_label": int (0 or 1),
             "word_label": [[word1, 0/1], [word2, 0/1], ...]}
        tokenizer: tokenizer whose output supports ``word_ids()``
                   (Hugging Face fast tokenizers do)
        max_length: pad/truncate length applied to every example
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.examples = [self._encode(item) for item in self.data]

    @staticmethod
    def _align_labels(word_ids, word_level_labels):
        """Expand word labels to sub-token labels and flag each word's first sub-token."""
        token_labels, first_mask = [], []
        last_word = None
        for wid in word_ids:
            if wid is None:
                token_labels.append(-100)  # ignore index for the loss
                first_mask.append(False)
                continue
            token_labels.append(word_level_labels[wid])
            first_mask.append(wid != last_word)
            last_word = wid
        return token_labels, first_mask

    def _encode(self, item):
        """Tokenize one raw item and build its stored example dict."""
        _prompt_text = item["prompt"]  # read but unused: tokenization works on the pre-split words
        prompt_label = item["prompt_label"]
        pairs = item["word_label"]  # list of [word, label]

        words = [pair[0] for pair in pairs]
        labels = [pair[1] for pair in pairs]

        enc = self.tokenizer(
            words,
            is_split_into_words=True,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        ids = enc["input_ids"].squeeze(0)
        mask = enc["attention_mask"].squeeze(0)

        token_labels, first_mask = self._align_labels(enc.word_ids(batch_index=0), labels)

        return {
            "input_ids": ids,
            "attention_mask": mask,
            "prompt_label": prompt_label,
            "token_labels": token_labels,
            "first_mask": torch.tensor(first_mask, dtype=torch.bool),
            "original_data": item,
        }

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        sample = dict(self.examples[idx])
        # Stored as a list; hand out a fresh long tensor per access.
        sample["token_labels"] = torch.tensor(sample["token_labels"], dtype=torch.long)
        return sample


def collate_fn(batch):
    """Merge a list of dataset samples into one batch dict for the DataLoader.

    Tensor fields are stacked along a new leading batch dimension, prompt labels
    become a long tensor, and the raw items are kept as a plain list.
    Note the pluralised output keys: ``prompt_labels`` and ``first_masks``.
    """
    def stacked(key):
        return torch.stack([sample[key] for sample in batch])

    return {
        "input_ids": stacked("input_ids"),
        "attention_mask": stacked("attention_mask"),
        "prompt_labels": torch.tensor(
            [sample["prompt_label"] for sample in batch], dtype=torch.long
        ),
        "token_labels": stacked("token_labels"),
        "first_masks": stacked("first_mask"),
        "original_data": [sample["original_data"] for sample in batch],
    }


def compute_subtoken_stats(train_data, tokenizer):
    """Count how often each sub-token occurs inside safe vs. unsafe words.

    Every word in each item's ``word_label`` list is tokenized with
    ``tokenizer.tokenize``; each resulting sub-token is credited to the word's
    label (0 counts as safe, any other value as unsafe).

    Returns:
        defaultdict mapping sub-token -> [safe_count, unsafe_count] (a mutable
        two-item list; unseen keys default to [0, 0]).
    """
    counts = defaultdict(lambda: [0, 0])

    for example in train_data:
        for word, label in example["word_label"]:
            pieces = tokenizer.tokenize(word)
            column = 0 if label == 0 else 1
            for piece in pieces:
                counts[piece][column] += 1

    return counts

def compute_subtoken_penalties(subtoken_stats):
    """Turn per-sub-token label counts into (penalty magnitude, dominant label).

    Args:
        subtoken_stats: mapping sub-token -> (safe_count, unsafe_count).

    Returns:
        {subtoken: (penalty_mag, dominant_label)} where
            penalty_mag    = |safe - unsafe| / (safe + unsafe + 1e-6), in [0, 1]:
                             0 when the counts tie, near 1 when one label dominates
            dominant_label = 0 if safe >= unsafe else 1 (ties go to safe)
    """
    eps = 1e-6

    def penalty(safe_cnt, unsafe_cnt):
        total = safe_cnt + unsafe_cnt + eps
        magnitude = abs(safe_cnt - unsafe_cnt) / total
        dominant = 0 if safe_cnt >= unsafe_cnt else 1
        return (magnitude, dominant)

    return {
        st: penalty(safe_cnt, unsafe_cnt)
        for st, (safe_cnt, unsafe_cnt) in subtoken_stats.items()
    }


###############################################################################
# 2. Model Definition
###############################################################################
# AttentionPooling is defined once, immediately below (a shadowed duplicate definition was removed here).


class AttentionPooling(nn.Module):
    """Pool a sequence of embeddings into one vector using learned attention.

    A single linear layer scores every position, padded positions are excluded
    via ``mask``, and the softmax-normalised scores weight the embeddings.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size, 1)

    def forward(self, embeddings, mask=None):
        """
        embeddings: Tensor (batch_size, seq_len, hidden_size)
        mask: Tensor (batch_size, seq_len), nonzero for valid tokens and 0 for
              padding; None means every position is valid.

        Returns: Tensor (batch_size, hidden_size)
        """
        attn_scores = self.attn(embeddings).squeeze(-1)  # (batch_size, seq_len)
        if mask is not None:
            # Most negative *finite* value of the score dtype (-65504 for fp16)
            # instead of -inf: masked positions still get zero weight, but a
            # fully-masked row yields uniform weights rather than NaN.
            neg_fill = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, neg_fill)
        attn_weights = torch.softmax(attn_scores, dim=-1)
        return torch.sum(embeddings * attn_weights.unsqueeze(-1), dim=1)


class JointModel(PreTrainedModel):
    """Encoder with a prompt-level head and a token-level head trained jointly.

    The prompt head classifies the attention-pooled sequence representation
    into 2 classes; the token head classifies every position into 2 classes.
    Each task loss is class-weighted cross-entropy plus a focal term scaled by
    per-example (prompt) or per-sub-token (token) penalties. The two task
    losses are either summed or combined with learned uncertainty weights.

    Args:
        model_name: checkpoint passed to ``AutoModel.from_pretrained``.
        config: encoder config (``hidden_size``, ``num_labels``, ``vocab_size``).
        prompt_class_weights: optional (2,) CE class weights for the prompt head.
        word_class_weights: optional (2,) CE class weights for the token head.
        alpha: stored on the model; not used by ``forward``.
        subtoken_penalties: optional ``{subtoken: (magnitude, dominant_label)}``
            as produced by ``compute_subtoken_penalties``; requires ``tokenizer``.
        tokenizer: used to map sub-token strings to ids. May be None when no
            penalties are given.
        ablation: dict with keys ``"f"`` (focal term), ``"d_p"`` (prompt
            penalty), ``"d_w"`` (token penalty) and ``"u"`` (uncertainty
            weighting). Defaults to ``DEFAULT_ABLATION``.
    """

    # Used when no ablation is passed: every focal/penalty term and the
    # uncertainty weighting switched off, i.e. plain (class-weighted) CE.
    DEFAULT_ABLATION = {"d_p": 0, "d_w": 0, "f": 0, "u": 0}

    def __init__(
        self,
        model_name,
        config,
        prompt_class_weights=None,
        word_class_weights=None,
        alpha=0.01,
        subtoken_penalties=None,
        tokenizer=None,
        ablation=None
    ):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = AutoModel.from_pretrained(model_name, config=config)
        self.tokenizer = tokenizer

        # Attention pooling + classification heads
        self.attention_pooling = AttentionPooling(config.hidden_size)
        self.prompt_classifier = nn.Linear(config.hidden_size, 2)
        self.token_classifier = nn.Linear(config.hidden_size, 2)

        # Learned log(sigma) per task for uncertainty weighting (see
        # _combine_losses); sigma = exp(log_sigma) starts at 1.
        self.log_sigma_prompt = nn.Parameter(torch.zeros(()))
        self.log_sigma_token = nn.Parameter(torch.zeros(()))

        self.alpha = alpha

        self.prompt_class_weights = (
            prompt_class_weights if prompt_class_weights is not None else torch.tensor([1.0, 1.0])
        )
        self.word_class_weights = (
            word_class_weights if word_class_weights is not None else torch.tensor([1.0, 1.0])
        )

        mag_table, dom_table = self._build_penalty_tables(config, tokenizer, subtoken_penalties)
        # Buffers move with .to(device), are not trained, and are not saved.
        self.register_buffer("penalty_mag_table", mag_table, persistent=False)
        self.register_buffer("penalty_dom_table", dom_table, persistent=False)

        self.ablation = ablation if ablation is not None else dict(self.DEFAULT_ABLATION)

        # NOTE(review): depending on the transformers version, init_weights()
        # may re-run weight initialisation over submodules; confirm it leaves
        # the pretrained encoder weights loaded above intact.
        self.init_weights()

    @staticmethod
    def _build_penalty_tables(config, tokenizer, subtoken_penalties):
        """Build id-indexed (magnitude, dominant label) lookup tables.

        Unknown / unlisted ids get magnitude 0 and dominant label -1 (never
        equal to a gold label), i.e. no penalty.
        """
        # The tables are indexed with raw input_ids, so size them to cover every
        # id the encoder can embed. tokenizer.vocab_size excludes added tokens,
        # whose ids may lie beyond it.
        sizes = [getattr(config, "vocab_size", None) or 0]
        if tokenizer is not None:
            sizes += [tokenizer.vocab_size, len(tokenizer)]
        vocab_size = max(sizes)

        mag_table = torch.zeros(vocab_size)                              # (V,)
        dom_table = torch.full((vocab_size,), -1, dtype=torch.long)     # (V,)

        if subtoken_penalties:
            if tokenizer is None:
                raise ValueError("a tokenizer is required when subtoken_penalties are given")
            for tok_str, (mag, dom) in subtoken_penalties.items():
                tok_id = tokenizer.convert_tokens_to_ids(tok_str)
                if tok_id is None or tok_id == tokenizer.unk_token_id:
                    continue
                if 0 <= tok_id < vocab_size:
                    mag_table[tok_id] = mag
                    dom_table[tok_id] = dom
        return mag_table, dom_table

    def forward(
        self,
        input_ids,
        attention_mask=None,
        prompt_labels=None,
        token_labels=None,
        prompt_penalties=None
    ):
        """
        Args
        ----
        input_ids        : (B, S)
        attention_mask   : (B, S), or None to treat every position as valid
        prompt_labels    : (B,)   gold prompt class (0/1)      or None
        token_labels     : (B, S) gold token labels (0/1/-100) or None
        prompt_penalties : per-example multipliers for the prompt focal term
                           (broadcast against (B,)), or None to omit that term

        Returns
        -------
        dict with keys {"loss", "prompt_logits", "token_logits"}; "loss" is
        None unless prompt_labels is given.
        """
        if attention_mask is None:
            # A missing mask means "every position is valid" for both the
            # encoder and the attention pooling.
            attention_mask = torch.ones_like(input_ids)

        # 1) Encoder
        enc_out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        hidden = enc_out.last_hidden_state                     # (B, S, H)

        # 2) Two heads
        pooled = self.attention_pooling(hidden, attention_mask)
        prompt_logits = self.prompt_classifier(pooled)         # (B, 2)
        token_logits = self.token_classifier(hidden)           # (B, S, 2)

        # 3) Loss (only when supervised at prompt level)
        loss = None
        if prompt_labels is not None:
            loss_prompt = self._prompt_loss(prompt_logits, prompt_labels, prompt_penalties)
            loss_token = None
            if token_labels is not None:
                loss_token = self._token_loss(token_logits, input_ids, token_labels)
            loss = self._combine_losses(loss_prompt, loss_token)

        return {
            "loss": loss,
            "prompt_logits": prompt_logits,
            "token_logits": token_logits
        }

    def _prompt_loss(self, prompt_logits, prompt_labels, prompt_penalties):
        """Mean prompt loss: weighted CE plus an optional penalty-scaled focal term (gamma=2)."""
        ce = F.cross_entropy(
            prompt_logits,
            prompt_labels,
            weight=self.prompt_class_weights.to(prompt_logits.device),
            reduction="none"
        )
        if prompt_penalties is None:
            # Without per-example penalties the focal term has nothing to scale
            # (equivalent to a zero penalty), so use plain CE.
            return ce.mean()
        pt = torch.exp(-ce)
        focal = (1.0 - pt).pow(2.0) * ce
        combined = ce + (focal * prompt_penalties) * self.ablation["f"] * self.ablation["d_p"]
        return combined.mean()

    def _token_loss(self, token_logits, input_ids, token_labels):
        """Token loss: weighted CE plus a penalty-scaled focal term, summed over
        labelled positions and divided by the batch size.

        A token's penalty is the imbalance magnitude of its sub-token, applied
        only when the gold label equals that sub-token's dominant label.
        """
        batch_size, seq_len = input_ids.size()
        valid = token_labels != -100

        per_tok_ce = F.cross_entropy(
            token_logits.reshape(-1, 2),
            token_labels.reshape(-1),
            weight=self.word_class_weights.to(token_logits.device),
            ignore_index=-100,
            reduction="none"
        ).view(batch_size, seq_len)

        gamma = 2.0
        pt = torch.exp(-per_tok_ce)
        per_tok_fl = (1.0 - pt).pow(gamma) * per_tok_ce

        with torch.no_grad():
            mags = self.penalty_mag_table[input_ids]
            doms = self.penalty_dom_table[input_ids]
            same_cls = doms == token_labels
            token_penalties = mags * same_cls.float() * valid.float()

        combined = per_tok_ce + per_tok_fl * token_penalties * self.ablation["f"] * self.ablation["d_w"]
        return combined[valid].sum() / batch_size

    def _combine_losses(self, loss_prompt, loss_token):
        """Combine task losses; ``loss_token`` is None when no token labels were given.

        With ``ablation["u"]`` each task loss L becomes 0.5 * L / sigma**2 + log(sigma).
        A missing token loss contributes no term at all, so log_sigma_token is
        not driven downward by a term with no data behind it.
        """
        if self.ablation["u"]:
            sigma_p = torch.exp(self.log_sigma_prompt)
            loss = 0.5 * loss_prompt / (sigma_p ** 2) + self.log_sigma_prompt
            if loss_token is not None:
                sigma_t = torch.exp(self.log_sigma_token)
                loss_t = 0.5 * loss_token / (sigma_t ** 2) + self.log_sigma_token
                loss = loss + loss_t
            return loss
        if loss_token is None:
            return loss_prompt
        return loss_prompt + loss_token

###############################################################################
# 3. Metrics and Utility
###############################################################################

def compute_metrics(preds, labels, average='micro'):
    """Return precision, recall and F1 of ``preds`` against ``labels``.

    ``average`` is forwarded to scikit-learn and also prefixes the keys, e.g.
    ``{"micro_precision": ..., "micro_recall": ..., "micro_f1": ...}``.
    Undefined ratios are reported as 0 (``zero_division=0``).
    """
    precision, recall, f1 = precision_recall_fscore_support(
        labels, preds, average=average, zero_division=0
    )[:3]
    metric_names = ("precision", "recall", "f1")
    return {
        f"{average}_{name}": value
        for name, value in zip(metric_names, (precision, recall, f1))
    }

def print_classwise_report(labels, preds):
    """Print scikit-learn's per-class precision/recall/F1 report for safe (0) vs unsafe (1).

    ``labels=[0, 1]`` is passed explicitly so the report always has both rows.
    Without it, recent scikit-learn versions raise ValueError when only one
    class occurs in ``labels``/``preds``, because the two ``target_names``
    would no longer match the number of inferred classes.
    """
    target_names = ["safe (0)", "unsafe (1)"]
    print(classification_report(
        labels, preds, labels=[0, 1], target_names=target_names, zero_division=0
    ))

def label_to_str(label_int):
    """Map a class id to its name: 0 -> "safe", any other value -> "unsafe"."""
    if label_int == 0:
        return "safe"
    return "unsafe"

###############################################################################
# 4. Train, Evaluate, and Prediction Functions
###############################################################################

def train_one_epoch(model, dataloader, optimizer, scheduler, scaler, device):
    """Run one mixed-precision training epoch and return the mean batch loss.

    Args:
        model: called with input_ids / attention_mask / prompt_labels /
            token_labels and expected to return a dict with a "loss" tensor.
        dataloader: yields batches produced by ``collate_fn``.
        optimizer, scheduler: stepped once per batch.
        scaler: ``GradScaler`` used for the scaled backward pass and step.
        device: device the batch tensors are moved to. Autocast is enabled
            only for CUDA devices; elsewhere the step runs in full precision.

    Returns:
        Mean of the per-batch loss values, or 0.0 if the dataloader is empty.
    """
    model.train()
    # Autocast is CUDA-only here: on other devices keep the plain fp32 path
    # instead of entering a CUDA autocast that is unavailable.
    use_amp = torch.device(device).type == "cuda"
    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc="Training", leave=False):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        prompt_labels = batch["prompt_labels"].to(device)
        token_labels = batch["token_labels"].to(device)

        optimizer.zero_grad()

        with autocast(device_type="cuda", enabled=use_amp):   # FP16 region on CUDA
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                prompt_labels=prompt_labels,
                token_labels=token_labels
            )
            loss = outputs["loss"]

        scaler.scale(loss).backward()   # scaled backward
        scaler.step(optimizer)          # unscales grads; skips the step on inf/NaN
        scaler.update()                 # adjust the loss scale
        scheduler.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0

@torch.inference_mode()                                # no autograd tracking for the whole call
def evaluate(
        model: nn.Module,
        dataloader: torch.utils.data.DataLoader,
        device: torch.device,
        prompt_threshold: float = 0.6,
        word_threshold : float = 0.4
):
    """Predict prompt labels and first-sub-token labels and collect gold labels.

    A prompt or token is predicted unsafe (1) only when P(unsafe) > P(safe)
    AND P(unsafe) > threshold. With two softmax classes the first condition
    already implies P(unsafe) > 0.5, so thresholds at or below 0.5 add no
    further restriction.

    Args:
        model: returns a dict with "prompt_logits" (B, 2) and "token_logits" (B, S, 2).
        dataloader: yields batches produced by ``collate_fn``.
        device: device to run on. Autocast is enabled only for CUDA devices.
        prompt_threshold: unsafe-probability cut-off for prompts.
        word_threshold: unsafe-probability cut-off for tokens.

    Returns:
        ((prompt_pred, prompt_gold), (token_pred, token_gold)), each a flat
        Python list of ints. Token lists contain **only first sub-tokens** of
        positions with a gold label (label != -100). An empty dataloader
        yields empty lists.
    """
    model.eval()
    use_amp = torch.device(device).type == "cuda"

    prompt_preds, prompt_golds = [], []
    token_preds, token_golds = [], []

    with autocast(device_type="cuda", enabled=use_amp):
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            # 1) Move the tensors needed on-device
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            t_gold = batch["token_labels"].to(device)
            first_mask = batch["first_masks"].to(device)           # (B, S) bool

            # 2) Forward pass
            out = model(input_ids=input_ids, attention_mask=attention_mask)

            # 3) Prompt-level predictions (vectorised)
            p_probs = out["prompt_logits"].softmax(dim=1)          # (B, 2)
            p_safe, p_unsafe = p_probs[:, 0], p_probs[:, 1]
            p_pred = ((p_unsafe > p_safe) & (p_unsafe > prompt_threshold)).int()
            prompt_preds.append(p_pred.cpu())
            prompt_golds.append(batch["prompt_labels"].cpu())

            # 4) Token-level predictions: first sub-tokens with a gold label only
            t_probs = out["token_logits"].softmax(dim=2)           # (B, S, 2)
            t_safe, t_unsafe = t_probs[:, :, 0], t_probs[:, :, 1]
            t_pred = ((t_unsafe > t_safe) & (t_unsafe > word_threshold)).int()
            keep = first_mask & (t_gold != -100)
            token_preds.append(t_pred[keep].cpu())
            token_golds.append(t_gold[keep].cpu())

    # torch.cat([]) raises, so handle an empty dataloader explicitly.
    if not prompt_preds:
        return ([], []), ([], [])

    # 5) Flatten once at the end (avoids many small Python allocations)
    return (
        (torch.cat(prompt_preds).tolist(), torch.cat(prompt_golds).tolist()),
        (torch.cat(token_preds).tolist(), torch.cat(token_golds).tolist()),
    )

def _predict_word_labels(words_and_labels, token_probs, tokenizer, max_length, word_threshold):
    """Label each word with the prediction of its first sub-token (0 if truncated away).

    ``token_probs`` is the (S, 2) [safe, unsafe] probability array of one
    example. The words are re-tokenized with the same settings as
    ``PromptWordDataset`` so that positions line up with ``token_probs``.
    """
    words = [pair[0] for pair in words_and_labels]
    encoding = tokenizer(
        words,
        is_split_into_words=True,
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt"
    )

    first_subtoken_label = {}
    for position, word_id in enumerate(encoding.word_ids(batch_index=0)):
        if word_id is None or word_id in first_subtoken_label or position >= len(token_probs):
            continue
        safe_prob, unsafe_prob = token_probs[position]
        first_subtoken_label[word_id] = 1 if (unsafe_prob > safe_prob and unsafe_prob > word_threshold) else 0

    # Words that did not survive truncation default to safe (0).
    return [[word, first_subtoken_label.get(idx, 0)] for idx, word in enumerate(words)]


def predict(model, dataloader, tokenizer, max_length, device, prompt_threshold=0.0, word_threshold=0.0):
    """Predict prompt-level and word-level labels for every example in ``dataloader``.

    A prompt or word is labelled unsafe only when P(unsafe) > P(safe) and
    P(unsafe) exceeds the corresponding threshold. Word labels come from each
    word's first sub-token.

    Returns:
        list of dicts with keys "prompt", "original_prompt_label",
        "predicted_prompt_label" (both "safe"/"unsafe"), "original_word_labels"
        and "predicted_word_labels" (list of [word, 0/1]).
    """
    model.eval()
    results = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Predicting", leave=False):
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device)
            )

            prompt_probs = F.softmax(outputs["prompt_logits"], dim=1).cpu().numpy()   # (B, 2)
            token_probs = F.softmax(outputs["token_logits"], dim=2).cpu().numpy()     # (B, S, 2)

            for i, original_item in enumerate(batch["original_data"]):
                safe_prob, unsafe_prob = prompt_probs[i]
                predicted_prompt_label = 1 if (unsafe_prob > safe_prob and unsafe_prob > prompt_threshold) else 0

                words_and_labels = original_item["word_label"]
                pred_word_labels = _predict_word_labels(
                    words_and_labels, token_probs[i], tokenizer, max_length, word_threshold
                )

                results.append({
                    "prompt": original_item["prompt"],
                    "original_prompt_label": label_to_str(original_item["prompt_label"]),
                    "predicted_prompt_label": label_to_str(predicted_prompt_label),
                    "original_word_labels": words_and_labels,
                    "predicted_word_labels": pred_word_labels
                })

    return results


###############################################################################
# 5. Main Pipeline
###############################################################################

def _ablation_tag(ab):
    """Return the run-name suffix encoding which loss components an ablation enables."""
    tag = ""
    if ab["d_p"]:
        tag += "delta_p+"
    if ab["d_w"]:
        tag += "delta_w+"
    if ab["f"]:
        tag += "focal+"
    if ab["u"]:
        tag += "uncertainty"
    return tag


def _load_saved_model(output_dir, model_name, device):
    """Load a fine-tuned JointModel and its tokenizer from ``output_dir``.

    Returns ``(model, tokenizer)``, or ``(None, None)`` if no checkpoint exists.
    """
    model_path = os.path.join(output_dir, "pytorch_model.bin")
    if not os.path.exists(model_path):
        return None, None

    print("🚀 Model already exists!")
    config = AutoConfig.from_pretrained(output_dir)
    tokenizer = AutoTokenizer.from_pretrained(output_dir)
    model = JointModel(
        model_name=model_name,
        config=config,
        tokenizer=tokenizer
    )
    # The checkpoint is a plain state dict, so tensor-only loading suffices.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device)
    model.eval()
    return model, tokenizer


def _summarise_split(model, loader, device, split_name="test"):
    """Evaluate ``model`` on ``loader``, print reports, and return the metrics dict."""
    (prompt_preds, prompt_labels), (token_preds, token_labels) = evaluate(
        model, loader, device, prompt_threshold=0.5, word_threshold=0.5
    )
    prompt_metrics_micro = compute_metrics(prompt_preds, prompt_labels, average='micro')
    token_metrics_micro = compute_metrics(token_preds, token_labels, average='micro')
    prompt_metrics_macro = compute_metrics(prompt_preds, prompt_labels, average='macro')
    token_metrics_macro = compute_metrics(token_preds, token_labels, average='macro')

    print(f"{split_name.capitalize()} Prompt Micro:", prompt_metrics_micro)
    print_classwise_report(prompt_labels, prompt_preds)
    print(f"{split_name.capitalize()} Token Micro:", token_metrics_micro)
    print_classwise_report(token_labels, token_preds)

    def per_class(gold, pred):
        # labels=[0, 1] guarantees both "0" and "1" keys even if a class is absent.
        return classification_report(gold, pred, labels=[0, 1], output_dict=True, zero_division=0)

    return {
        "prompt": {
            "micro": prompt_metrics_micro,
            "macro": prompt_metrics_macro,
            "per_class": per_class(prompt_labels, prompt_preds)
        },
        "token": {
            "micro": token_metrics_micro,
            "macro": token_metrics_macro,
            "per_class": per_class(token_labels, token_preds)
        }
    }


def _positive_class_scores(result):
    """Return [prompt P, R, F1, word P, R, F1] for class 1 in percent (2 dp), or blanks."""
    if not isinstance(result, dict) or "test" not in result:
        return [""] * 6
    scores = []
    for task in ("prompt", "token"):
        class_one = result["test"][task]["per_class"].get("1")
        if class_one is None:
            scores += [""] * 3
            continue
        scores += [round(class_one[metric] * 100, 2) for metric in ("precision", "recall", "f1-score")]
    return scores


def _write_summary_excel(all_result, excel_file_path):
    """Write one row per (train set, model) with class-1 P/R/F1 columns per test set."""
    suffixes = ("prompt_P", "prompt_R", "prompt_F1", "word_P", "word_R", "word_F1")
    heading = ["Train_set", "Model"]
    rows = []
    for train_set, models in all_result.items():
        for model_tag, test_sets in models.items():
            row = {"Train_set": train_set, "Model": model_tag}
            for test_set_name, result in test_sets.items():
                if test_set_name == "best_thresholds":
                    continue
                columns = [f"{test_set_name}_{suffix}" for suffix in suffixes]
                # Register each test set's columns once, not once per row.
                for column in columns:
                    if column not in heading:
                        heading.append(column)
                row.update(zip(columns, _positive_class_scores(result)))
            rows.append(row)

    pd.DataFrame(rows, columns=heading).to_excel(excel_file_path, index=False)


def main():
    """Evaluate every saved ablation checkpoint on every test set and save a summary.

    For each ablation and training set, loads
    ``<model_root>/<model>_<train_set>_<ablation_tag>/pytorch_model.bin`` if it
    exists, evaluates it on the "test" split of each dataset file, and writes
    all metrics to a JSON file plus a class-1 P/R/F1 Excel sheet.
    """
    max_length = 512
    eval_batch_size = 16
    seed = 42

    # Alternatives: "microsoft/deberta-v3-base", "microsoft/deberta-v3-large"
    model_name = "microsoft/deberta-v3-xsmall"
    model_short_name = model_name.split("/")[1]
    model_root_folder = "model_seed_" + str(seed)

    ablations = [
        {"d_p": 1, "d_w": 1, "f": 1, "u": 1}, # all
        #{"d_p": 1, "d_w": 1, "f": 1, "u": 0}, # No uncertainty
        #{"d_p": 1, "d_w": 0, "f": 1, "u": 1}, # No d_w
        #{"d_p": 1, "d_w": 0, "f": 1, "u": 0}, # No d_w, No u
        #{"d_p": 0, "d_w": 1, "f": 1, "u": 1}, # No_p
        #{"d_p": 0, "d_w": 1, "f": 1, "u": 0}, # No_p, No u
        #{"d_p": 1, "d_w": 1, "f": 0, "u": 1}, # Only CE+u
        #{"d_p": 1, "d_w": 1, "f": 0, "u": 0}, # Only CE
    ]

    dataset_config = [
        ('toxic_chat_with_word_label_intersection_all_data.json', "toxicchat"),
    ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Replaced by a checkpoint's own tokenizer once one is loaded; later runs
    # keep using the most recently loaded tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    test_data_cache = {}   # dataset path -> raw "test" split, read once
    all_result = {}

    for ab in ablations:
        train_config = _ablation_tag(ab)
        for _train_dataset_path, train_dataset_name in dataset_config:
            output_dir = os.path.join(
                model_root_folder, f"{model_short_name}_{train_dataset_name}_{train_config}"
            )
            print(f"\n\n=== Running ablation: d_p={ab['d_p']}, d_w={ab['d_w']}, f={ab['f']}, u={ab['u']} ===")
            print(output_dir)
            results_by_test = all_result.setdefault(train_dataset_name, {}).setdefault(train_config, {})

            model, saved_tokenizer = _load_saved_model(output_dir, model_name, device)
            if saved_tokenizer is not None:
                tokenizer = saved_tokenizer

            for test_dataset_path, test_dataset_name in dataset_config:
                if model is None:
                    print("🔄 No saved model found.")
                    results_by_test[test_dataset_name] = "NA"
                    continue

                if test_dataset_path not in test_data_cache:
                    with open(test_dataset_path, "r", encoding="utf-8") as f:
                        test_data_cache[test_dataset_path] = json.load(f).get("test", [])
                test_data = test_data_cache[test_dataset_path]

                results_summary = {}
                if test_data:
                    test_dataset = PromptWordDataset(test_data, tokenizer, max_length=max_length)
                    test_dataloader = DataLoader(
                        test_dataset, batch_size=eval_batch_size, shuffle=False, collate_fn=collate_fn
                    )
                    results_summary["test"] = _summarise_split(model, test_dataloader, device)
                results_by_test[test_dataset_name] = results_summary

    # Save evaluation metrics (the root folder may not exist if no model was found)
    os.makedirs(model_root_folder, exist_ok=True)
    json_file_path = os.path.join(
        model_root_folder,
        model_short_name + "_all_evaluation_seed_" + str(seed) + ".json"
    )
    with open(json_file_path, "w", encoding="utf-8") as f:
        json.dump(all_result, f, indent=2)

    # Save the Excel summary as well
    excel_file_path = json_file_path.replace(".json", "_p_r_f1_final_ablation_test.xlsx")
    _write_summary_excel(all_result, excel_file_path)
    print(f"Excel file saved as {excel_file_path}")


if __name__ == "__main__":
    main()
