import os
import time
import torch
import numpy as np
from datasets import load_dataset, concatenate_datasets
from promptsource.templates import DatasetTemplates
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import SequenceClassifierOutput
from evaluate import load as load_metric

# ──────────────── 1. CONFIG ────────────────
# GLUE task to evaluate, e.g. "rte", "sst2", "qnli" or "mnli".
task_name = "rte"

# Hugging Face identifier of the causal language model under test.
model_name = "gpt2"

# Examples per device in each evaluation batch.
batch_size = 8

# Truncation length in tokens (the tokenizer arguments that would consume it
# are currently commented out in the preprocessing step).
max_length = 128

# Directory used for logs and for the prediction files.
output_dir = "./zero_shot_glue_out"

# Prefer the GPU whenever CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ──────────────── 2. DATA & TEMPLATES ────────────────
# Load the evaluation splits one at a time into a plain dict so that extra
# entries can be added below.
# MNLI does not expose plain "validation"/"test" splits: its GLUE config only
# has "*_matched" and "*_mismatched" variants.  The matched halves are stored
# under the generic "validation"/"test" keys the rest of the script uses.
_split_suffix = "_matched" if task_name == "mnli" else ""
raw = {
    split_name: load_dataset("glue", task_name, split=f"{split_name}{_split_suffix}")
    for split_name in ("validation", "test")
}
# handle MNLI double-eval: keep the mismatched halves under their own keys
if task_name == "mnli":
    raw["validation_mm"] = load_dataset("glue", "mnli", split="validation_mismatched")
    raw["test_mm"] = load_dataset("glue", "mnli", split="test_mismatched")

# load PromptSource templates
templates = DatasetTemplates(f"glue/{task_name}")
# Use the first registered template; swap in any other entry of
# ``templates.templates`` to evaluate a different prompt.
template = next(iter(templates.templates.values()))
# label names as declared by the dataset's ClassLabel feature
label_list = raw["validation"].features["label"].names

# ──────────────── 3. PREPROCESS ────────────────
def make_zero_shot_ds(split_ds):
    """Convert a dataset split into per-choice tokenised inputs.

    Every row is rendered with the module-level ``template``.  The rendered
    prompt is joined with each answer choice (separated by one space) and
    tokenised with the module-level ``tokenizer``.

    Args:
        split_ds: A dataset split as returned by ``load_dataset``; it must
            provide ``map``, ``column_names`` and a ``label`` column.

    Returns:
        The mapped split with its original columns removed and these columns:
        ``input_ids`` / ``attention_mask`` (one token list per choice),
        ``choice_token_ids`` (one id per choice) and ``labels`` (the row's
        integer label).

    Raises:
        ValueError: If the template declares no answer choices, or a choice
            does not tokenise to exactly one token.
    """
    if template.answer_choices is None:
        raise ValueError("Template has no answer choices to score")

    # The choices and their token ids do not depend on the row, so they are
    # computed once instead of once per example.
    choices = [c.strip() for c in template.answer_choices.split("|||")]
    choice_ids = []
    for choice in choices:
        # Tokenise the choice exactly as it appears in the text (with the
        # joining space) so the stored id matches the token that ends
        # ``input_ids``; BPE vocabularies use different ids for "x" and " x".
        ids = tokenizer(" " + choice, add_special_tokens=False).input_ids
        if len(ids) != 1:
            raise ValueError(
                f"Choice must be 1 token! {choice!r} tokenises to {len(ids)} tokens"
            )
        choice_ids.append(ids[0])

    def preprocess(ex):
        # ``dict(ex)`` works for plain dict rows and for lazily formatted row
        # objects alike; only the rendered input part of the template is used.
        prompt = template.apply(dict(ex))[0]

        # One encoding per choice.  Sequences are left unpadded here; lengths
        # may differ between choices and between rows.
        encs = [tokenizer(prompt + " " + choice) for choice in choices]
        return {
            "input_ids":        [enc["input_ids"] for enc in encs],
            "attention_mask":   [enc["attention_mask"] for enc in encs],
            "choice_token_ids": choice_ids,
            # The gold label doubles as the index of the correct choice.
            # NOTE(review): assumes the template lists its answer choices in
            # label-id order -- confirm for the template in use.
            "labels":           ex["label"],
        }

    return split_ds.map(
        preprocess,
        remove_columns=split_ds.column_names,
    )

# ──────────────── 4. MODEL WRAPPER ────────────────
class ZeroShotMultipleChoiceModel(torch.nn.Module):
    """Wrap a causal LM so that it ranks single-token answer choices.

    Each example arrives as ``C`` token sequences, one per choice, each ending
    in its choice token.  A choice's score is the log-probability the LM
    assigns to the choice token at the step that predicts the final real
    token of the sequence.
    """

    def __init__(self, lm):
        """Store the wrapped language model.

        Args:
            lm: Callable as ``lm(input_ids, attention_mask=...)`` returning an
                object whose ``logits`` attribute has shape ``(N, L, V)``.
        """
        super().__init__()
        self.lm = lm

    def score_choices(self, input_ids, attention_mask=None, choice_token_ids=None):
        """Return the ``(B, C)`` log-probability of every choice token.

        Args:
            input_ids: Integer tensor of shape ``(B, C, L)``.
            attention_mask: Optional ``(B, C, L)`` mask with 1 on real tokens.
                Padding may sit on either side; ``None`` means no padding.
            choice_token_ids: Optional ``(B, C)`` tensor of token ids to score.
                When omitted, the last real token of each sequence is scored.

        Returns:
            Float tensor of shape ``(B, C)`` holding log-probabilities.
        """
        B, C, L = input_ids.shape
        flat_ids = input_ids.reshape(B * C, L)
        flat_attn = None if attention_mask is None else attention_mask.reshape(B * C, L)
        logits = self.lm(flat_ids, attention_mask=flat_attn).logits   # (B*C, L, V)
        dev = logits.device

        # Index of the last real (unmasked) token in every row.
        if flat_attn is None:
            last_real = torch.full((B * C,), L - 1, dtype=torch.long, device=dev)
        else:
            positions = torch.arange(L, device=flat_attn.device)
            last_real = (flat_attn.long() * positions).max(dim=1).values.to(dev)

        # A causal LM's logits at position i are the distribution over token
        # i + 1, so the choice token (the last real token) is predicted one
        # step earlier.  Reading the final position instead would score what
        # comes *after* the choice, or a pad slot in a padded batch.
        context_pos = (last_real - 1).clamp(min=0)

        rows = torch.arange(B * C, device=dev)
        if choice_token_ids is None:
            targets = flat_ids.to(dev)[rows, last_real]
        else:
            targets = choice_token_ids.reshape(-1).to(dev)

        # Normalise in float32 so scores are true log-probabilities even when
        # the LM runs in reduced precision.
        step_logp = torch.log_softmax(logits[rows, context_pos].float(), dim=-1)
        return step_logp[rows, targets].view(B, C)

    @staticmethod
    def _choice_loss(logits, labels):
        """Cross-entropy over the choice scores, skipping unlabelled rows.

        Args:
            logits: ``(B, C)`` choice scores.
            labels: ``(B,)`` integer tensor with the index of the gold choice.
                Values outside ``[0, C)`` mark unlabelled rows and are ignored.

        Returns:
            Scalar loss tensor; zero when no row carries a usable label.

        Raises:
            ValueError: If ``labels`` is not a vector with one entry per row.
        """
        if labels.dim() != 1 or labels.shape[0] != logits.shape[0]:
            raise ValueError(
                f"labels must have shape ({logits.shape[0]},), got {tuple(labels.shape)}"
            )
        labels = labels.to(logits.device).long()
        usable = (labels >= 0) & (labels < logits.shape[1])
        if not bool(usable.any()):
            return logits.new_zeros(())
        return torch.nn.functional.cross_entropy(logits[usable], labels[usable])

    def forward(self, input_ids=None, attention_mask=None, choice_token_ids=None, labels=None):
        """Score all choices and, when labels are given, also return a loss.

        Args:
            input_ids: ``(B, C, L)`` token ids.
            attention_mask: Optional ``(B, C, L)`` mask (1 = real token).
            choice_token_ids: Optional ``(B, C)`` ids of the choice tokens.
            labels: Optional ``(B,)`` index of the correct choice per example.

        Returns:
            ``SequenceClassifierOutput`` with ``logits`` of shape ``(B, C)``
            and ``loss`` set whenever ``labels`` is provided.
        """
        logits = self.score_choices(input_ids, attention_mask, choice_token_ids)
        # Trainer's evaluation loop expects a loss whenever labels are passed
        # to the model, so one is always produced in that case.
        loss = None if labels is None else self._choice_loss(logits, labels)
        return SequenceClassifierOutput(loss=loss, logits=logits)

# ──────────────── 5. COLLATOR ────────────────
def collate_fn(features, pad_token_id=0):
    """Batch per-choice token sequences, right-padding them to a common length.

    Sequences may differ in length between choices and between examples, so
    they cannot be stacked directly.  Every sequence is padded on the right
    to the longest one in the batch and its attention mask is extended with
    zeros, so consumers must use the mask to find the last real token.

    Args:
        features: Non-empty list of dicts with keys ``input_ids`` and
            ``attention_mask`` (one list per choice), ``choice_token_ids``
            and ``labels``.
        pad_token_id: Id written into padded slots.  Those slots are masked
            out, so the exact value is not significant.

    Returns:
        Dict of ``torch.long`` tensors: ``input_ids`` and ``attention_mask``
        of shape ``(B, C, L)``, plus ``choice_token_ids`` and ``labels``
        stacked from the features.

    Raises:
        ValueError: If the batch is empty or the examples disagree on the
            number of choices.
    """
    if not features:
        raise ValueError("collate_fn received an empty batch")

    num_choices = len(features[0]["input_ids"])
    if any(len(f["input_ids"]) != num_choices for f in features):
        raise ValueError("All examples in a batch must have the same number of choices")

    width = max((len(seq) for f in features for seq in f["input_ids"]), default=0)
    shape = (len(features), num_choices, width)

    b_input_ids = torch.full(shape, pad_token_id, dtype=torch.long)
    b_attention_mask = torch.zeros(shape, dtype=torch.long)
    for row, f in enumerate(features):
        for col, (ids, mask) in enumerate(zip(f["input_ids"], f["attention_mask"])):
            n = len(ids)
            b_input_ids[row, col, :n] = torch.as_tensor(ids, dtype=torch.long)
            b_attention_mask[row, col, :n] = torch.as_tensor(mask, dtype=torch.long)

    return {
        "input_ids":        b_input_ids,
        "attention_mask":   b_attention_mask,
        "choice_token_ids": torch.tensor([f["choice_token_ids"] for f in features]),
        "labels":           torch.tensor([f["labels"] for f in features]),
    }

# ──────────────── 6. SETUP EVERYTHING ────────────────
# Tokenizer and language model both come from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(model_name)
lm = lm.to(device)

# Wrap the LM in the multiple-choice scorer and place it on the target device.
model = ZeroShotMultipleChoiceModel(lm)
model = model.to(device)

# Build the zero-shot datasets: validation first, then test.
ds_val, ds_test = (
    make_zero_shot_ds(raw[split_name]) for split_name in ("validation", "test")
)
if task_name == "mnli":
    # MNLI additionally carries the "_mm" pair of splits.
    ds_val_mm, ds_test_mm = (
        make_zero_shot_ds(raw[split_name]) for split_name in ("validation_mm", "test_mm")
    )

# GLUE metric matching the selected task.
metric = load_metric("glue", task_name)

def compute_metrics(p):
    """Score predictions with the module-level GLUE ``metric``.

    Args:
        p: Object exposing ``predictions`` (per-choice scores, one row per
            example) and ``label_ids`` (the reference labels).

    Returns:
        Whatever ``metric.compute`` returns for the arg-max predictions.
    """
    choice_scores = np.asarray(p.predictions)
    # The predicted class is the highest-scoring choice in each row.
    predicted = choice_scores.argmax(axis=1)
    return metric.compute(predictions=predicted, references=p.label_ids)

# Evaluation-only configuration: training is switched off, while the eval and
# predict flags are enabled.
args = TrainingArguments(
    output_dir=output_dir,
    logging_dir=os.path.join(output_dir, "logs"),
    per_device_eval_batch_size=batch_size,
    do_train=False,
    do_eval=True,
    do_predict=True,
)

# The Trainer only drives batching, the forward passes and metric computation.
trainer = Trainer(
    args=args,
    model=model,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

# ──────────────── 7. RUN EVAL & PREDICT ────────────────
start_time = time.time()

# open() does not create missing directories, so make sure the target for the
# prediction files exists before anything is written.
os.makedirs(output_dir, exist_ok=True)

# (name, dataset) pairs for the labelled evaluation and for the test run.
eval_sets = [("validation", ds_val)]
test_sets = [("test", ds_test)]
if task_name == "mnli":
    eval_sets.append(("validation_mm", ds_val_mm))
    test_sets.append(("test_mm", ds_test_mm))

print("*** Zero‑Shot Evaluation ***")
for t, ds in eval_sets:
    res = trainer.predict(ds, metric_key_prefix=f"eval_{t}")
    print(f"— eval_{t} metrics:", res.metrics)

print("*** Zero‑Shot Test Predictions ***")
for t, ds in test_sets:
    res = trainer.predict(ds, metric_key_prefix=f"pred_{t}")
    # predictions as choice indices, written out as label names
    preds = np.argmax(res.predictions, axis=1)
    pred_path = os.path.join(output_dir, f"predictions_{t}.txt")
    with open(pred_path, "w", encoding="utf-8") as fout:
        fout.write("index\tprediction\n")
        fout.writelines(f"{i}\t{label_list[pi]}\n" for i, pi in enumerate(preds))
    print(f"Wrote predictions_{t}.txt")
print(f"Total time: {time.time()-start_time:.1f}s")
