import os, json, argparse
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from tqdm import tqdm
from itertools import combinations

# =========================
# Dataset configuration
# =========================
# Names of the evaluation sets, listed in the order they are processed.
# NOTE: "sst2" is also understood by `load_data` but is not part of the
# default evaluation list.
DATASETS = [
    "sst5",
    "imdb",
    "twitter",
    "Dialogue",
    "AnimalsBeingBros",
    "Confession",
    "Cringe",
    "OkCupid",
]

# Local files for the datasets that are read from disk. Names missing from
# this mapping ("imdb", "twitter") are fetched through `load_dataset`.
DATASET_PATHS = {
    # Tab-separated "<text>\t<label>" lines.
    "sst5": "data/processed/test/sst5_binary_phrases.txt",
    # JSON files.
    "AnimalsBeingBros": "data/processed/test/AnimalsBeingBros_posneg.json",
    "Confession": "data/processed/test/confession_posneg.json",
    "Cringe": "data/processed/test/cringe_posneg.json",
    "OkCupid": "data/processed/test/OkCupid_posneg.json",
    "Dialogue": "data/processed/test/dailydialog_emotion_filtered.json",
}


# =========================
# Data loading functions
# =========================
def _first_present(record, keys):
    """Return the first value in *record* under *keys* that is neither None nor ""."""
    for key in keys:
        value = record.get(key)
        # Explicit None / "" checks (not truthiness) so an integer label 0 is
        # kept instead of being mistaken for a missing field.
        if value is not None and value != "":
            return value
    return None


def load_data(name):
    """Load one evaluation set as ``(texts, labels)``.

    Args:
        name: Dataset identifier. "twitter", "imdb" and "sst2" are fetched
            with `load_dataset`; "sst5" is read from a tab-separated file;
            any other name is read as a JSON file. File locations come from
            `DATASET_PATHS`.

    Returns:
        A tuple ``(texts, labels)``: a list of strings and a NumPy array of
        labels of the same length. String labels are mapped to 1 for
        "positive" and 0 for "negative"; integer labels are passed through
        unchanged.

    Raises:
        KeyError: If *name* is a file-backed dataset with no entry in
            `DATASET_PATHS`.
    """
    if name == "twitter":
        ds = load_dataset("tweet_eval", "sentiment", split="test")
        texts, labels = [], []
        for text, raw in zip(ds["text"], ds["label"]):
            # Keep only 0 (negative) and 2 (positive); remap to 0/1.
            if raw in (0, 2):
                texts.append(text)
                labels.append(1 if raw == 2 else 0)
        return texts, np.array(labels)

    if name == "imdb":
        ds = load_dataset("imdb", split="test")
        # Materialise the column as a plain list so later slicing hands the
        # tokenizer a list of strings.
        return list(ds["text"]), np.array(ds["label"])

    if name == "sst2":
        # Note: GLUE test set has no labels, use validation instead
        ds = load_dataset("glue", "sst2", split="validation")
        return list(ds["sentence"]), np.array(ds["label"])

    path = DATASET_PATHS[name]

    if name == "sst5":
        texts, labels = [], []
        # Explicit encoding: do not depend on the platform's locale default.
        with open(path, encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split("\t")
                # Blank and malformed lines do not split into exactly two fields.
                if len(parts) != 2:
                    continue
                text, label = parts
                label = label.lower()
                if label not in {"positive", "negative"}:
                    continue
                texts.append(text)
                labels.append(1 if label == "positive" else 0)
        return texts, np.array(labels)

    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    texts, labels = [], []
    for record in data:
        lbl = _first_present(record, ("label", "valence", "sentiment", "emotion"))
        if isinstance(lbl, str):
            lowered = lbl.lower()
            if lowered not in {"positive", "negative"}:
                continue
            label = 1 if lowered == "positive" else 0
        elif isinstance(lbl, int):
            # NOTE(review): integer labels are trusted as-is; presumably they
            # are already 0/1 — confirm against the data files.
            label = lbl
        else:
            continue
        text = record.get("text") or record.get("utterance")
        if text:
            texts.append(text)
            labels.append(label)
    return texts, np.array(labels)


def batch_encode(texts, tokenizer, model, num_layers, batch_size, max_len):
    """Encode *texts* and mean-pool the hidden states of each layer.

    Args:
        texts: Sequence of input strings.
        tokenizer: Tokenizer called on each batch; must return an object with
            an "attention_mask" entry and a ``.to(device)`` method.
        model: Model whose output exposes ``hidden_states``; its
            ``config.hidden_size`` and ``device`` are read.
        num_layers: Number of layers to pool. Layer ``l`` (0-based) is taken
            from ``hidden_states[l + 1]``, i.e. index 0 is skipped.
        batch_size: Number of texts per forward pass.
        max_len: Length every sequence is padded / truncated to.

    Returns:
        A list of ``num_layers`` float32 arrays of shape
        ``(len(texts), hidden_size)``; entry ``l`` holds the attention-masked
        mean of layer ``l + 1`` for every text.
    """
    hidden_size = model.config.hidden_size
    n_texts = len(texts)
    all_embs = [np.empty((n_texts, hidden_size), dtype=np.float32) for _ in range(num_layers)]

    for start in tqdm(range(0, n_texts, batch_size), desc="🧠 Encode"):
        # list(...) guarantees the tokenizer receives a plain list of strings
        # even when `texts` is some other sliceable sequence type.
        batch = list(texts[start:start + batch_size])
        inputs = tokenizer(batch, return_tensors="pt", padding="max_length", truncation=True,
                           max_length=max_len).to(model.device)
        with torch.no_grad():
            hs = model(**inputs).hidden_states

        mask = inputs["attention_mask"].unsqueeze(-1)
        # Attended-token count per example. Clamped to 1 so a row with no
        # attended tokens pools to zeros instead of dividing by zero (NaN).
        token_counts = mask.sum(1).clamp(min=1)

        for l in range(num_layers):
            # Pool in float32: a no-op for fp32 models, and it keeps
            # half-precision outputs convertible to NumPy.
            h = hs[l + 1].float()  # start from Layer 1
            pooled = (h * mask).sum(1) / token_counts
            all_embs[l][start:start + len(batch)] = pooled.cpu().numpy()
    return all_embs


# =========================
# Combination testing
# =========================
def _load_layer_axes(main_dir, sub_dirs, layer_id):
    """Load the main axis and every sub-axis file that exists for one layer."""
    fname = f"sentiment_axis_L{layer_id}.npy"
    main_axis = np.load(os.path.join(main_dir, fname))
    sub_axes = {}
    for key, sub_dir in sub_dirs.items():
        axis_path = os.path.join(sub_dir, fname)
        if os.path.exists(axis_path):
            sub_axes[key] = np.load(axis_path)
        else:
            print(f"⚠️ Missing '{key}' axis for layer {layer_id} ({axis_path}); "
                  f"combinations that need it are skipped")
    return main_axis, sub_axes


def _unit_mean(axes):
    """Average the axis vectors and rescale to unit norm; None if the mean is degenerate."""
    mean_axis = np.mean(axes, axis=0)
    norm = np.linalg.norm(mean_axis)
    if not np.isfinite(norm) or norm == 0:
        return None
    return mean_axis / norm


def test_combinations(args):
    """Score every main + sub-axis combination on every dataset and save a CSV.

    For each layer, the main axis is averaged with each subset of the
    sub-axes ("contextual", "genre", "topic", "tone"), normalised to unit
    length, and used as a linear probe: an example is predicted positive
    when its pooled embedding has a positive projection on the axis.
    Accuracy, F1 and AUROC are recorded per (dataset, layer, combination).

    Args:
        args: Namespace providing ``model_id``, ``model_name``,
            ``exp_main_axis``, ``exp_sub_axis``, ``out_dir``, ``device``,
            ``batch_size`` and ``max_len``.

    Side effects:
        Writes ``<out_dir>/<model_name>_axis_combinations.csv`` and prints
        progress messages.

    Raises:
        FileNotFoundError: If a main-axis file is missing (raised by
            ``np.load``). Missing sub-axis files are not an error: the
            combinations that need them are skipped with a warning.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    # Only fall back to EOS when the tokenizer has no pad token of its own;
    # overwriting unconditionally would set it to None for tokenizers that
    # define a pad token but no EOS token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModel.from_pretrained(args.model_id, output_hidden_states=True).eval().to(args.device)
    num_layers = model.config.num_hidden_layers

    # main + sub-axis directories
    main_dir = os.path.join(args.exp_main_axis, f"{args.model_name}_main")
    sub_dirs = {
        "contextual": os.path.join(args.exp_sub_axis, f"{args.model_name}_contextual"),
        "genre": os.path.join(args.exp_sub_axis, f"{args.model_name}_genre"),
        "topic": os.path.join(args.exp_sub_axis, f"{args.model_name}_topic"),
        "tone": os.path.join(args.exp_sub_axis, f"{args.model_name}_tone"),
    }

    # All subsets of the sub-axes, from the empty one (main only) to all four.
    sub_keys = list(sub_dirs)
    all_combos = [
        combo
        for r in range(len(sub_keys) + 1)
        for combo in combinations(sub_keys, r)
    ]

    # The axes do not depend on the dataset, so load and combine them once up
    # front. This also surfaces a missing main axis before any encoding work.
    layer_axes = {}
    for layer_id in range(1, num_layers + 1):
        main_axis, sub_axes = _load_layer_axes(main_dir, sub_dirs, layer_id)
        entries = []
        for combo in all_combos:
            # Skip rather than silently score a smaller combination under
            # the full combination's label.
            if any(key not in sub_axes for key in combo):
                continue
            combo_label = "main" if not combo else "main+" + "+".join(combo)
            # average combination + normalize
            combined_axis = _unit_mean([main_axis] + [sub_axes[key] for key in combo])
            if combined_axis is None:
                print(f"⚠️ Degenerate combined axis for layer {layer_id}, {combo_label}; skipped")
                continue
            entries.append((combo_label, combined_axis))
        layer_axes[layer_id] = entries

    summary = []

    for ds_name in DATASETS:
        print(f"\n📊 Testing {args.model_name} on {ds_name}")
        test_texts, test_labels = load_data(ds_name)
        if len(test_texts) == 0:
            # The metric functions cannot score an empty set; keep the run
            # alive for the remaining datasets.
            print(f"⚠️ No usable examples in {ds_name}; skipped")
            continue
        test_embs = batch_encode(test_texts, tokenizer, model, num_layers, args.batch_size, args.max_len)

        for layer_id, entries in layer_axes.items():
            X_test = test_embs[layer_id - 1]
            for combo_label, combined_axis in entries:
                # projection on test set
                scores = X_test @ combined_axis
                preds = (scores > 0).astype(int)

                acc = accuracy_score(test_labels, preds)
                f1 = f1_score(test_labels, preds)
                try:
                    auroc = roc_auc_score(test_labels, scores)
                except ValueError:
                    # Record NaN instead of aborting when AUROC cannot be computed.
                    auroc = np.nan

                summary.append({
                    "dataset": ds_name,
                    "layer": layer_id,
                    "combo": combo_label,
                    "acc": acc,
                    "f1": f1,
                    "auroc": auroc
                })

    # save results
    os.makedirs(args.out_dir, exist_ok=True)
    df = pd.DataFrame(summary)
    df.to_csv(os.path.join(args.out_dir, f"{args.model_name}_axis_combinations.csv"), index=False)
    print(f"\n✅ Saved results to {args.out_dir}/{args.model_name}_axis_combinations.csv")


if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    cli = argparse.ArgumentParser()

    # Required identifiers of the model under test.
    cli.add_argument("--model_id", required=True, type=str, help="HuggingFace model id")
    cli.add_argument("--model_name", required=True, type=str,
                     help="Local experiment model name (e.g. gemma-7b)")

    # Optional settings, registered from a (flag, type, default) table.
    optional_flags = (
        ("--exp_main_axis", str, "experiments/exp_main_axis"),
        ("--exp_sub_axis", str, "experiments/exp_sub_axis"),
        ("--out_dir", str, "outputs/test_combinations"),
        ("--device", str, "cuda"),
        ("--max_len", int, 128),
        ("--batch_size", int, 64),
    )
    for flag, flag_type, fallback in optional_flags:
        cli.add_argument(flag, type=flag_type, default=fallback)

    test_combinations(cli.parse_args())
