"""
Extended BERT + multiple Captum methods (IG, Occlusion, Shapley, LIME) for text data,
including spurious-correlation handling, the new R, F and IS scores, and RBP refinement.
Token attributions for normal vs. spurious texts are visualized and saved to 'text_saliency_plots/'.

CHANGELOG:
 - Spurious token is now 'peach' for label=1
 - Larger vocab, random length ~10..15
 - Implementation of (R, F, IS) plus RBP for text
"""

import random
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader

from transformers import BertTokenizer, BertForSequenceClassification
from captum.attr import (
    IntegratedGradients,
    Occlusion,
    ShapleyValueSampling,
    Lime
)

import time
import matplotlib.pyplot as plt
import os

######################################################################
# 1) Synthetic Data
######################################################################
def generate_synthetic_sentences(n_samples=1000):
    """
    Build a toy binary text-classification corpus.

    Each sentence is 10-15 words drawn uniformly (with replacement) from a fixed
    vocabulary. Its label is 1 when it contains 'apple' or 'banana', else 0.
    The spurious 'peach' cue for label=1 is added separately, later, by
    ``create_spurious_test``.

    Args:
        n_samples: number of sentences to generate.

    Returns:
        (texts, labels): list of space-joined sentences and list of 0/1 int labels.
    """
    vocabulary = [
        "this", "is", "just", "an", "example", "random", "word",
        "apple", "banana", "orange", "tree", "car", "house",
        "toy", "test", "cat", "dog", "bird", "table", "chair",
        "computer", "phone", "like", "play", "music", "drink",
        "water", "great", "fantasy", "fast", "slow", "jungle",
        "forest", "moon", "space", "pizza", "tasty", "color",
        "pencil", "paint", "smart"
    ]
    trigger_words = {"apple", "banana"}

    texts = []
    labels = []
    for _ in range(n_samples):
        # Sentence length is drawn first, then each word — keeps the RNG call order stable.
        words = [random.choice(vocabulary) for _ in range(random.randint(10, 15))]
        texts.append(" ".join(words))
        labels.append(int(not trigger_words.isdisjoint(words)))

    return texts, labels


class TextDataset(Dataset):
    """In-memory dataset of ``(text, label)`` pairs backed by two parallel sequences."""

    def __init__(self, texts, labels):
        # Parallel sequences: texts[k] is labelled labels[k].
        self.texts, self.labels = texts, labels

    def __len__(self):
        """Number of examples (taken from ``texts``)."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the ``(text, label)`` pair at position ``idx``."""
        text = self.texts[idx]
        label = self.labels[idx]
        return text, label


def collate_fn(batch, tokenizer, max_len=30):
    """
    Collate a list of ``(text, label)`` pairs into model-ready tensors.

    Args:
        batch: sequence of ``(text, label)`` pairs.
        tokenizer: called on the list of texts with fixed-length padding and
            truncation to ``max_len``, returning PyTorch tensors.
        max_len: padded/truncated token length.

    Returns:
        (input_ids, attention_mask, labels) where labels is a ``torch.long`` tensor.
    """
    batch_texts, batch_labels = map(list, zip(*batch))
    label_tensor = torch.tensor(batch_labels, dtype=torch.long)

    encoded = tokenizer(
        batch_texts,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    return encoded["input_ids"], encoded["attention_mask"], label_tensor


######################################################################
# 2) Simple Training Loop for BERT
######################################################################
def train_bert_classifier(
        model,
        train_dataset,
        tokenizer,
        device,
        epochs=4,
        batch_size=16,
        lr=3e-4,
        max_len=30,
):
    """
    Fine-tune ``model`` on ``train_dataset`` with Adam and print the mean loss per epoch.

    Args:
        model: sequence-classification model called as
            ``model(input_ids=..., attention_mask=..., labels=...)`` and returning an
            object with a ``.loss`` attribute.
        train_dataset: dataset yielding ``(text, label)`` pairs.
        tokenizer: tokenizer forwarded to ``collate_fn``.
        device: device each batch is moved to.
        epochs: number of passes over the data.
        batch_size: mini-batch size.
        lr: Adam learning rate (default is the previously hard-coded 3e-4).
        max_len: token length each text is padded/truncated to (default 30).

    Raises:
        ValueError: if ``train_dataset`` is empty (the per-epoch mean loss would be
            undefined and previously crashed with ZeroDivisionError).
    """
    if len(train_dataset) == 0:
        raise ValueError("train_dataset is empty; nothing to train on")

    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda batch: collate_fn(batch, tokenizer, max_len=max_len)
    )

    for epoch in range(epochs):
        total_loss = 0.0
        total_samples = 0
        for input_ids, attention_mask, labels in train_loader:
            input_ids = input_ids.to(device, dtype=torch.long)
            attention_mask = attention_mask.to(device, dtype=torch.long)
            labels = labels.to(device)

            optimizer.zero_grad()
            out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = out.loss
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch mean is per-example, even for a short last batch.
            bs = labels.size(0)
            total_loss += loss.item() * bs
            total_samples += bs

        epoch_loss = total_loss / total_samples
        print(f"Epoch {epoch + 1}/{epochs}, Loss={epoch_loss:.4f}")


######################################################################
# 3) Forward Wrappers
######################################################################
def forward_embeds_only(embeds, mask, model, class_idx=1):
    """
    Run ``model`` on pre-computed input embeddings and return one logit column.

    Args:
        embeds: tensor passed to the model as ``inputs_embeds``.
        mask: tensor passed as ``attention_mask``.
        model: callable returning an object with a ``.logits`` tensor.
        class_idx: column of the logits to return.

    Returns:
        ``logits[:, class_idx]`` — one scalar per batch row, which is the form the
        Captum attribution methods in this file expect from their forward function.
    """
    model_output = model(inputs_embeds=embeds, attention_mask=mask)
    return model_output.logits[:, class_idx]


######################################################################
# 4) Explanation Methods (Captum)
######################################################################
def explain_ig_text(model, tokenizer, text, device, target_class, max_len=30, n_steps=8):
    """
    Integrated Gradients attributions over BERT word embeddings, one score per token.

    The text is tokenized (padded/truncated to ``max_len``). Its word embeddings are
    integrated from an all-zeros baseline, and each token's attribution vector is
    reduced to its L2 norm, so scores are non-negative magnitudes.

    Args:
        model: classifier exposing ``model.bert.embeddings.word_embeddings`` and
            accepting ``inputs_embeds`` / ``attention_mask``.
        tokenizer: tokenizer used to encode ``text``.
        text: input string.
        device: device for the encoded tensors.
        target_class: logit index being explained.
        max_len: padded/truncated sequence length.
        n_steps: number of IG interpolation steps (previously hard-coded to 8).

    Returns:
        (tokens, token_attributions): word-piece strings (incl. [CLS]/[SEP]/[PAD])
        and a numpy array with one score per position.
    """
    enc = tokenizer(text, max_length=max_len, truncation=True, padding='max_length', return_tensors='pt')
    input_ids = enc['input_ids'].to(device)
    attention_mask = enc['attention_mask'].to(device)

    # IG only needs gradients w.r.t. the embeddings themselves, not the embedding table.
    input_emb = model.bert.embeddings.word_embeddings(input_ids).detach()
    baseline_emb = torch.zeros_like(input_emb)

    # IG evaluates all n_steps interpolation points as one stacked batch. Passing the mask
    # through additional_forward_args lets Captum repeat it per step instead of relying
    # on a batch-1 mask being broadcast inside the model.
    ig = IntegratedGradients(
        lambda embs, mask: forward_embeds_only(embs, mask, model, target_class)
    )
    attrs = ig.attribute(
        inputs=input_emb,
        baselines=baseline_emb,
        additional_forward_args=(attention_mask,),
        n_steps=n_steps,
    )
    token_attributions = np.linalg.norm(attrs[0].detach().cpu().numpy(), ord=2, axis=1)

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    return tokens, token_attributions


def explain_occlusion_text(model, tokenizer, text, device, target_class, max_len=30):
    """
    Occlusion attributions: each token's word-embedding row is replaced by zeros in turn.

    The sliding window spans one token and the full hidden dimension, and it moves one
    token at a time. Each token's attribution vector is then reduced to its L2 norm.

    Returns:
        (tokens, token_attributions): word-piece strings (incl. special tokens) and a
        numpy array with one non-negative score per position.
    """
    encoded = tokenizer(text, max_length=max_len, truncation=True, padding='max_length', return_tensors='pt')
    ids = encoded['input_ids'].to(device)
    mask = encoded['attention_mask'].to(device)

    embeddings = model.bert.embeddings.word_embeddings(ids)
    # One window == one whole token row; the stride equals the window, so windows never overlap.
    token_window = (1, embeddings.size(-1))

    occluder = Occlusion(lambda embs: forward_embeds_only(embs, mask, model, target_class))
    raw_attr = occluder.attribute(
        inputs=embeddings,
        sliding_window_shapes=token_window,
        strides=token_window,
        baselines=torch.zeros_like(embeddings),
    )
    per_token = np.linalg.norm(raw_attr[0].detach().cpu().numpy(), ord=2, axis=1)

    return tokenizer.convert_ids_to_tokens(ids[0].tolist()), per_token


def explain_shapley_text(model, tokenizer, text, device, target_class, max_len=30, n_samples=5):
    """
    Shapley Value Sampling attributions at token granularity.

    Without a ``feature_mask``, Captum treats every scalar of the
    ``(1, seq_len, hidden_dim)`` embedding as a separate feature. That costs
    ``n_samples * seq_len * hidden_dim`` forward passes and scores individual embedding
    dimensions rather than tokens. Here a feature mask groups every dimension of a token
    into one feature, so each token is added/removed as a unit (replaced by the zero
    baseline). Every dimension of a token then shares one value, and the L2 norm over
    dimensions turns it into a non-negative per-token score.

    Args:
        model, tokenizer, text, device, target_class, max_len: as for the other explainers.
        n_samples: number of feature permutations to sample (previously hard-coded to 5).

    Returns:
        (tokens, token_attributions): word-piece strings (incl. special tokens) and a
        numpy array with one score per position.
    """
    enc = tokenizer(text, max_length=max_len, truncation=True, padding='max_length', return_tensors='pt')
    input_ids = enc['input_ids'].to(device)
    attention_mask = enc['attention_mask'].to(device)

    input_emb = model.bert.embeddings.word_embeddings(input_ids).detach()
    seq_len = input_emb.size(1)
    # token_mask[0, i, :] == i  ->  feature i is "all dims of token i".
    token_mask = (
        torch.arange(seq_len, device=input_emb.device)
        .view(1, seq_len, 1)
        .expand_as(input_emb)
        .contiguous()
    )

    shapley = ShapleyValueSampling(
        lambda embs, mask: forward_embeds_only(embs, mask, model, target_class)
    )
    attributions = shapley.attribute(
        inputs=input_emb,
        baselines=torch.zeros_like(input_emb),
        additional_forward_args=(attention_mask,),
        feature_mask=token_mask,
        n_samples=n_samples,
    )
    token_attributions = np.linalg.norm(attributions[0].detach().cpu().numpy(), ord=2, axis=1)

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    return tokens, token_attributions


def explain_lime_text(model, tokenizer, text, device, target_class, max_len=30, n_samples=5):
    """
    LIME attributions at token granularity.

    Without a ``feature_mask``, Captum's Lime builds one interpretable feature per
    embedding scalar (``seq_len * hidden_dim`` of them). A handful of samples cannot
    fit that meaningfully. A feature mask makes each token (all of its embedding
    dimensions) one interpretable feature that is either kept or replaced by the zero
    baseline. Every dimension of a token shares one value, and the L2 norm over
    dimensions turns it into a non-negative per-token score.

    Args:
        model, tokenizer, text, device, target_class, max_len: as for the other explainers.
        n_samples: number of perturbed samples used to fit the surrogate
            (previously hard-coded to 5).

    Returns:
        (tokens, token_attributions): word-piece strings (incl. special tokens) and a
        numpy array with one score per position.
    """
    enc = tokenizer(text, max_length=max_len, truncation=True, padding='max_length', return_tensors='pt')
    input_ids = enc['input_ids'].to(device)
    attention_mask = enc['attention_mask'].to(device)

    input_emb = model.bert.embeddings.word_embeddings(input_ids).detach()
    seq_len = input_emb.size(1)
    # token_mask[0, i, :] == i  ->  interpretable feature i is "all dims of token i".
    token_mask = (
        torch.arange(seq_len, device=input_emb.device)
        .view(1, seq_len, 1)
        .expand_as(input_emb)
        .contiguous()
    )

    lime_obj = Lime(lambda embs, mask: forward_embeds_only(embs, mask, model, target_class))
    attributions = lime_obj.attribute(
        inputs=input_emb,
        baselines=torch.zeros_like(input_emb),
        additional_forward_args=(attention_mask,),
        feature_mask=token_mask,
        n_samples=n_samples,
    )
    token_attributions = np.linalg.norm(attributions[0].detach().cpu().numpy(), ord=2, axis=1)

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    return tokens, token_attributions


def _explain_text(model, tokenizer, text, device, method, target_class):
    """
    Dispatch to the token-attribution explainer registered under ``method``.

    Args:
        method: one of "ig", "occlusion", "shapley", "lime".
        (other args are forwarded positionally to the chosen explainer)

    Returns:
        Whatever the explainer returns: ``(tokens, token_attributions)``.

    Raises:
        ValueError: if ``method`` is not a registered name.
    """
    explainers = {
        "ig": explain_ig_text,
        "occlusion": explain_occlusion_text,
        "shapley": explain_shapley_text,
        "lime": explain_lime_text,
    }
    try:
        explainer = explainers[method]
    except (KeyError, TypeError):
        # TypeError covers unhashable `method` values, which also just aren't registered.
        raise ValueError(f"Unknown explanation method: {method}") from None
    return explainer(model, tokenizer, text, device, target_class)


######################################################################
# 5) R, F, IS for Text + RBP
######################################################################
def _iq_class_probs(model, input_ids, attention_mask):
    """Return softmax class probabilities for ``input_ids`` (forward pass only, no autograd graph)."""
    with torch.no_grad():
        embeds = model.bert.embeddings.word_embeddings(input_ids)
        logits = model(inputs_embeds=embeds, attention_mask=attention_mask).logits
    return F.softmax(logits, dim=1)


def _iq_masking_relevance(model, tokenizer, input_ids, attention_mask,
                          tokens, attributions, pred_class, prob_orig):
    """
    Return R for one text.

    R = max(0, Pearson corr) between each content token's attribution and the drop
    in P(pred_class) when that token's id is replaced by [MASK].
    """
    special_tokens = {"[PAD]", "[CLS]", "[SEP]"}
    saliencies, drops = [], []
    for pos, tok in enumerate(tokens[: input_ids.size(1)]):
        if tok in special_tokens:
            continue
        # Mask in id space: every other position and the attention mask stay exactly as in
        # the original input. Re-tokenizing " ".join(tokens) would add a second [CLS],
        # shift all positions and attend to the literal [PAD]s.
        masked_ids = input_ids.clone()
        masked_ids[0, pos] = tokenizer.mask_token_id
        prob_masked = _iq_class_probs(model, masked_ids, attention_mask)[0, pred_class].item()
        saliencies.append(attributions[pos])
        # Attributions are non-negative magnitudes, so a faithful explanation is large
        # where masking hurts most, i.e. it correlates positively with the *drop*.
        drops.append(prob_orig - prob_masked)

    if len(saliencies) < 2 or np.std(saliencies) < 1e-9 or np.std(drops) < 1e-9:
        return 0.0
    return max(0.0, float(np.corrcoef(saliencies, drops)[0, 1]))


def compute_inversion_scores_text(
    model,
    tokenizer,
    device,
    texts,
    labels,
    expl_method="ig",
    n_samples=20,
    p=2
):
    """
    Estimate faithfulness (R), instability (F) and their combination (IS) for one explainer.

    For up to ``n_samples`` randomly chosen texts, explaining the model's predicted class:
      - R: clipped Pearson correlation between each content token's attribution and the
        drop in predicted-class probability when that token is replaced by [MASK].
      - F: mean absolute difference between the attributions for the text and for the
        text with " orange" appended (compared position by position), clipped to 1.
      - IS = ((R^p + (1 - F)^p) / 2)^(1/p), computed from the mean R and mean F.
    'G' is not computed here.

    Args:
        model, tokenizer, device: classifier, its tokenizer and device.
        texts: candidate texts.
        labels: unused (the predicted class is explained); kept for API compatibility.
        expl_method: explainer name understood by ``_explain_text``.
        n_samples: maximum number of texts to evaluate.
        p: exponent of the power mean used for IS.

    Returns:
        dict with keys "R", "F", "IS".
    """
    model.eval()
    indices = np.random.choice(len(texts), size=min(n_samples, len(texts)), replace=False)

    R_vals, F_vals = [], []

    for idx in indices:
        text_ = texts[idx]

        # Same tokenization settings as the explainers, so positions line up with `tokens`.
        enc = tokenizer(text_, return_tensors='pt', max_length=30, truncation=True, padding='max_length').to(device)
        input_ids, attention_mask = enc['input_ids'], enc['attention_mask']
        probs = _iq_class_probs(model, input_ids, attention_mask)
        pred_class = int(torch.argmax(probs, dim=1).item())
        prob_orig = probs[0, pred_class].item()

        tokens, attr_vals = _explain_text(model, tokenizer, text_, device, expl_method, pred_class)

        R_vals.append(_iq_masking_relevance(
            model, tokenizer, input_ids, attention_mask, tokens, attr_vals, pred_class, prob_orig
        ))

        # F: appending one word should barely move a stable explanation.
        _, attr_pert = _explain_text(model, tokenizer, text_ + " orange", device, expl_method, pred_class)
        mlen = min(len(attr_pert), len(attr_vals))
        diff_ = float(np.mean(np.abs(attr_pert[:mlen] - attr_vals[:mlen]))) if mlen > 0 else 0.0
        F_vals.append(min(diff_, 1.0))

    R_mean = float(np.mean(R_vals)) if R_vals else 0.0
    F_mean = float(np.mean(F_vals)) if F_vals else 0.0
    IS_val = ((R_mean ** p + (1.0 - F_mean) ** p) / 2.0) ** (1.0 / p)

    return {
        "R": R_mean,
        "F": F_mean,
        "IS": IS_val
    }

# ---- RBP for text ----

def apply_rbp_text_single(
    model,
    tokenizer,
    text_str,
    device,
    expl_method="ig",
    pred_class=0,
    n_pert=3,
    perturb_token="orange",
    lambda_=1.0
):
    """
    Refine one text's token attributions with RBP.

    For every content token i (anything except [CLS]/[SEP]/[PAD]):
      1. replace word-piece i with ``perturb_token`` and rebuild a plain sentence from
         the content word-pieces,
      2. re-explain that sentence ``n_pert`` times, recording |a_pert[i] - a[i]|
         whenever position i still holds the substituted word-piece,
      3. set a'[i] = a[i] / (1 + lambda_ * mean recorded deviation).

    The substitution is identical on every repeat, so repeats can only differ for
    explainers with internal randomness (presumably the sampling-based ones).
    Repeats where re-tokenization shifted position i are skipped. If none remain,
    the token's attribution is left unchanged.

    Returns:
        (tokens, refined): the word-piece tokens of ``text_str`` and a refined copy of
        their attributions (special positions are left as-is).
    """
    special_tokens = {"[PAD]", "[CLS]", "[SEP]"}
    tokens, base_attr = _explain_text(model, tokenizer, text_str, device, expl_method, pred_class)
    refined = base_attr.copy()

    # Word-piece the substitute becomes after tokenization (e.g. lower-cased); used to
    # verify that position i in the perturbed encoding really is the substituted token.
    perturb_pieces = tokenizer.tokenize(perturb_token)
    expected_piece = perturb_pieces[0] if perturb_pieces else None

    for i, tok in enumerate(tokens):
        if tok in special_tokens:
            continue

        # Build the perturbed sentence from content pieces only. Joining the raw token list
        # (incl. [CLS]/[SEP]/[PAD]) would make the tokenizer add a second [CLS], shifting
        # every position by one, and would turn the literal [PAD]s into attended tokens.
        alt_pieces = [
            perturb_token if j == i else piece
            for j, piece in enumerate(tokens)
            if piece not in special_tokens
        ]
        alt_txt = tokenizer.convert_tokens_to_string(alt_pieces)

        base_val = base_attr[i]
        deviations = []
        for _ in range(n_pert):
            alt_tokens, alt_attr = _explain_text(model, tokenizer, alt_txt, device, expl_method, pred_class)
            aligned = (
                len(alt_attr) == len(base_attr)
                and i < len(alt_tokens)
                and alt_tokens[i] == expected_piece
            )
            if aligned:
                deviations.append(abs(alt_attr[i] - base_val))

        avg_dev = float(np.mean(deviations)) if deviations else 0.0
        refined[i] = base_val / (1.0 + lambda_ * avg_dev)

    return tokens, refined


######################################################################
# 6) Spurious
######################################################################
def create_spurious_test(texts, labels):
    """
    Inject the spurious cue: append " peach" to every text whose label is 1.

    Args:
        texts: sequence of sentences.
        labels: labels indexed in parallel with ``texts``.

    Returns:
        New list of texts; label-1 texts end with " peach", others are unchanged.
    """
    return [
        txt + " peach" if labels[pos] == 1 else txt
        for pos, txt in enumerate(texts)
    ]


######################################################################
# 7) Visualization of Token Attributions
######################################################################
def _viz_predict(model, tokenizer, text, device):
    """Return the argmax class the model predicts for ``text`` (forward pass only)."""
    enc = tokenizer(text, return_tensors='pt', max_length=30, truncation=True, padding='max_length').to(device)
    with torch.no_grad():
        emb = model.bert.embeddings.word_embeddings(enc['input_ids'])
        out = model(inputs_embeds=emb, attention_mask=enc['attention_mask'])
    return torch.argmax(out.logits, dim=1).item()


def visualize_text_explanations(model, tokenizer, device,
                                texts_normal, texts_spurious, labels,
                                methods, out_fig="text_saliency_plots/text_explanations.png",
                                max_samples=3):
    """
    Plot per-token attribution bars for normal vs. spurious versions of a few texts.

    The grid has one row per sampled text index and two columns per method: the
    normal text in blue and its spurious counterpart in red. Each explanation targets
    the class the model predicts for that text.

    Args:
        texts_normal, texts_spurious: parallel text lists (same index -> same example).
        labels: unused; kept for API compatibility.
        methods: explainer names understood by ``_explain_text``.
        out_fig: output image path; its directory is created if needed.
        max_samples: maximum number of rows (texts) to plot.
    """
    out_dir = os.path.dirname(out_fig)
    if out_dir:  # a bare filename has no directory to create (os.makedirs("") raises)
        os.makedirs(out_dir, exist_ok=True)
    model.eval()

    n_rows = min(max_samples, len(texts_normal))
    if n_rows == 0 or not methods:
        print(f"Nothing to visualize (samples={n_rows}, methods={len(methods)}); skipping {out_fig}")
        return
    indices = np.random.choice(len(texts_normal), size=n_rows, replace=False)

    num_methods = len(methods)
    # squeeze=False keeps `axes` a 2-D array for every grid size (incl. a single row).
    fig, axes = plt.subplots(
        nrows=n_rows,
        ncols=2 * num_methods,
        figsize=(4 * 2 * num_methods, 3 * n_rows),
        squeeze=False,
    )

    for row_i, idx in enumerate(indices):
        variants = (
            ("Normal", texts_normal[idx], "blue"),
            ("Spurious", texts_spurious[idx], "red"),
        )
        preds = [_viz_predict(model, tokenizer, txt, device) for _, txt, _ in variants]

        for m_i, method in enumerate(methods):
            for v_i, ((title, txt, color), pred) in enumerate(zip(variants, preds)):
                tokens, attrs = _explain_text(model, tokenizer, txt, device, method, target_class=pred)
                ax = axes[row_i, 2 * m_i + v_i]
                ax.bar(range(len(attrs)), attrs, color=color)
                ax.set_title(f"{title}: {method}\n(Pred={pred})")
                ax.set_xticks(range(len(attrs)))
                ax.set_xticklabels(tokens, rotation=90, fontsize=8)

    fig.tight_layout()
    fig.savefig(out_fig, dpi=150)
    plt.close(fig)
    print(f"Saved text explanations to: {out_fig}")


######################################################################
# 8) Main
######################################################################
def _exp_encode(tokenizer, text, device):
    """Tokenize ``text`` with the settings the explainers use (max 30 tokens, padded)."""
    return tokenizer(text, return_tensors='pt', max_length=30, truncation=True, padding='max_length').to(device)


def _exp_class_probs(model, input_ids, attention_mask):
    """Return softmax class probabilities for ``input_ids`` (forward pass only, no autograd graph)."""
    with torch.no_grad():
        embeds = model.bert.embeddings.word_embeddings(input_ids)
        logits = model(inputs_embeds=embeds, attention_mask=attention_mask).logits
    return F.softmax(logits, dim=1)


def _exp_masking_relevance(model, tokenizer, enc, tokens, attributions, pred_class, prob_orig):
    """
    Return R for one text.

    R = max(0, Pearson corr) between each content token's attribution and the drop in
    P(pred_class) when that token's id is replaced by [MASK] (all else unchanged).
    """
    special_tokens = {"[PAD]", "[CLS]", "[SEP]"}
    input_ids, attention_mask = enc['input_ids'], enc['attention_mask']
    saliencies, drops = [], []
    for pos, tok in enumerate(tokens[: input_ids.size(1)]):
        if tok in special_tokens:
            continue
        # Mask in id space; re-tokenizing a joined token string would shift positions.
        masked_ids = input_ids.clone()
        masked_ids[0, pos] = tokenizer.mask_token_id
        prob_masked = _exp_class_probs(model, masked_ids, attention_mask)[0, pred_class].item()
        saliencies.append(attributions[pos])
        # Non-negative attributions should be largest where masking hurts most.
        drops.append(prob_orig - prob_masked)

    if len(saliencies) < 2 or np.std(saliencies) < 1e-9 or np.std(drops) < 1e-9:
        return 0.0
    return max(0.0, float(np.corrcoef(saliencies, drops)[0, 1]))


def _exp_rbp_scores(model, tokenizer, device, texts, method,
                    n_texts=5, n_pert=2, lambda_=1.0, p=2):
    """
    Compute (R, F, IS) for RBP-refined attributions on a random subset of ``texts``.

    F compares the refined attributions of each text with the refined attributions of
    the same text plus " orange", so refined explanations are compared with refined
    ones rather than with raw attributions.
    """
    n_pick = min(n_texts, len(texts))
    subset = np.random.choice(len(texts), size=n_pick, replace=False)
    r_vals, f_vals = [], []

    for idx in subset:
        txt = texts[idx]
        enc = _exp_encode(tokenizer, txt, device)
        probs = _exp_class_probs(model, enc['input_ids'], enc['attention_mask'])
        pred_class = int(torch.argmax(probs, dim=1).item())
        prob_orig = probs[0, pred_class].item()

        tokens_r, attr_r = apply_rbp_text_single(
            model, tokenizer, txt, device,
            expl_method=method,
            pred_class=pred_class,
            n_pert=n_pert,
            lambda_=lambda_
        )
        r_vals.append(_exp_masking_relevance(model, tokenizer, enc, tokens_r, attr_r, pred_class, prob_orig))

        _, attr_pert = apply_rbp_text_single(
            model, tokenizer, txt + " orange", device,
            expl_method=method,
            pred_class=pred_class,
            n_pert=n_pert,
            lambda_=lambda_
        )
        mlen = min(len(attr_pert), len(attr_r))
        diff = float(np.mean(np.abs(attr_pert[:mlen] - attr_r[:mlen]))) if mlen > 0 else 0.0
        f_vals.append(min(diff, 1.0))

    r_mean = float(np.mean(r_vals)) if r_vals else 0.0
    f_mean = float(np.mean(f_vals)) if f_vals else 0.0
    is_val = ((r_mean ** p + (1.0 - f_mean) ** p) / 2.0) ** (1.0 / p)
    return r_mean, f_mean, is_val


def run_text_experiment_with_explanations(log_filename="text_iq_log.txt"):
    """
    End-to-end experiment.

    Steps:
      1. Generate synthetic data (800 train / 200 test).
      2. Fine-tune prajjwal1/bert-tiny on it.
      3. Report test accuracy.
      4. For each Captum method, log base R/F/IS, RBP-refined R/F/IS and R/F/IS on the
         spurious ('peach'-augmented) test set to ``log_filename``.
      5. Save an attribution figure under 'text_saliency_plots/'.
    """
    # 1) Generate data
    texts, labels = generate_synthetic_sentences(n_samples=1000)
    n_train = 800
    train_texts, train_labels = texts[:n_train], labels[:n_train]
    test_texts, test_labels = texts[n_train:], labels[n_train:]

    # 2) BERT
    tokenizer = BertTokenizer.from_pretrained("prajjwal1/bert-tiny")
    model = BertForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # 3) Train
    train_dataset = TextDataset(train_texts, train_labels)
    train_bert_classifier(
        model,
        train_dataset,
        tokenizer,
        device=device,
        epochs=4,
        batch_size=16
    )

    # 4) Evaluate
    model.eval()
    correct = 0
    for text, label in zip(test_texts, test_labels):
        enc = _exp_encode(tokenizer, text, device)
        probs = _exp_class_probs(model, enc['input_ids'], enc['attention_mask'])
        correct += int(torch.argmax(probs, dim=1).item() == label)
    test_acc = correct / len(test_texts) if test_texts else 0.0
    print(f"Test Accuracy (normal): {test_acc:.4f}")

    # 5) Spurious set => add "peach" if label=1
    spurious_texts = create_spurious_test(test_texts, test_labels)

    expl_methods = ["ig", "occlusion", "shapley", "lime"]
    p = 2  # power-mean exponent shared by every IS value below

    with open(log_filename, "w", encoding="utf-8") as log_file:
        log_file.write("Text Explanation with Captum (R, F, IS, RBP)\n")
        log_file.write("=" * 60 + "\n")
        log_file.write(f"TestAccuracy(normal)={test_acc:.4f}\n")

    for method in expl_methods:
        print(f"\n=== Explanation: {method} ===")

        start_t = time.time()
        base_iq = compute_inversion_scores_text(
            model, tokenizer, device,
            texts=test_texts,
            labels=test_labels,
            expl_method=method,
            n_samples=20,
            p=p
        )
        elapsed = time.time() - start_t

        # RBP is expensive, so it is evaluated on a small random subset only.
        R_rbp, F_rbp, IS_rbp = _exp_rbp_scores(
            model, tokenizer, device, test_texts, method,
            n_texts=5, n_pert=2, lambda_=1.0, p=p
        )

        sp_iq = compute_inversion_scores_text(
            model, tokenizer, device,
            texts=spurious_texts,
            labels=test_labels,
            expl_method=method,
            n_samples=20,
            p=p
        )

        with open(log_filename, "a", encoding="utf-8") as log_file:
            log_file.write(f"Method=[{method}]\n")
            log_file.write(
                f"BASE => R={base_iq['R']:.3f}, F={base_iq['F']:.3f}, IS={base_iq['IS']:.3f}, Time={elapsed:.2f}s\n"
            )
            log_file.write(
                f"RBP  => R={R_rbp:.3f}, F={F_rbp:.3f}, IS={IS_rbp:.3f}\n"
            )
            log_file.write(
                f"Spur => R={sp_iq['R']:.3f}, F={sp_iq['F']:.3f}, IS={sp_iq['IS']:.3f}\n"
            )
            log_file.write("-" * 60 + "\n")

    # Visualization (same directory as documented in the module docstring)
    out_figure_path = "text_saliency_plots/text_explanations.png"
    visualize_text_explanations(
        model,
        tokenizer,
        device,
        test_texts,
        spurious_texts,
        test_labels,
        expl_methods,
        out_fig=out_figure_path,
        max_samples=3
    )

    print(f"\nResults logged to: {log_filename}")
    print(f"Saliency figure saved to: {out_figure_path}")


if __name__ == "__main__":
    # Script entry point: train, score, log and plot with the default log file name.
    run_text_experiment_with_explanations(log_filename="text_iq_log.txt")
